From c1875934181153e19965a68ecb36028314ec7406 Mon Sep 17 00:00:00 2001 From: eellison Date: Thu, 13 Jun 2024 17:48:49 -0700 Subject: [PATCH 001/171] Prevent expansion of cat indexing to avoid int64 intermediate (#127815) Fix for https://github.com/pytorch/pytorch/issues/127652 Pull Request resolved: https://github.com/pytorch/pytorch/pull/127815 Approved by: https://github.com/shunting314, https://github.com/peterbell10 --- test/inductor/test_cuda_repro.py | 48 +++++++++++++++++++++++++++++- torch/_inductor/bounds.py | 9 ++++++ torch/_inductor/codegen/common.py | 3 ++ torch/_inductor/lowering.py | 12 ++++++-- torch/_inductor/utils.py | 10 +++++-- torch/utils/_sympy/functions.py | 16 ++++++++++ torch/utils/_sympy/interp.py | 2 ++ torch/utils/_sympy/value_ranges.py | 4 +++ 8 files changed, 99 insertions(+), 5 deletions(-) diff --git a/test/inductor/test_cuda_repro.py b/test/inductor/test_cuda_repro.py index eb5543a053ba44..8d0b6652be1e3b 100644 --- a/test/inductor/test_cuda_repro.py +++ b/test/inductor/test_cuda_repro.py @@ -18,7 +18,10 @@ from torch._inductor.utils import run_and_get_code from torch.fx.experimental.proxy_tensor import make_fx from torch.testing import FileCheck -from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FLASH_ATTENTION +from torch.testing._internal.common_cuda import ( + PLATFORM_SUPPORTS_FLASH_ATTENTION, + SM80OrLater, +) from torch.testing._internal.common_utils import ( DeterministicGuard, freeze_rng_state, @@ -27,6 +30,8 @@ TEST_WITH_ASAN, ) +from torch.testing._internal.inductor_utils import skipCUDAIf + try: try: import triton @@ -1239,6 +1244,47 @@ def triton_(in_ptr0, in_ptr1, out_ptr0, xnumel, XBLOCK : tl.constexpr): tl.store(out_ptr0 + (x3), tmp2, xmask)""", # noqa: B950 ) + @skipCUDAIf(not SM80OrLater, "uses bfloat16 which requires SM >= 80") + def test_int64_index_intermediate(self): + def foo(inp): + view_23 = torch.ops.aten.view.default(inp, [-1, 8192, 8192]) + split_1 = torch.ops.aten.split.Tensor(view_23, 1024, 1) + view_23 = None + getitem_17 = split_1[0] + getitem_18 = split_1[1] + getitem_19 = split_1[2] + getitem_20 = split_1[3] + getitem_21 = split_1[4] + getitem_22 = split_1[5] + getitem_23 = split_1[6] + getitem_24 = split_1[7] + split_1 = None + cat_1 = torch.ops.aten.cat.default( + [ + getitem_17, + getitem_18, + getitem_19, + getitem_20, + getitem_21, + getitem_22, + getitem_23, + getitem_24, + ] + ) + getitem_17 = ( + getitem_18 + ) = ( + getitem_19 + ) = getitem_20 = getitem_21 = getitem_22 = getitem_23 = getitem_24 = None + return cat_1 + + for mark_dynamic in [False, True]: + inp = torch.rand((65536, 8192), dtype=torch.bfloat16, device="cuda") + if mark_dynamic: + torch._dynamo.mark_dynamic(inp, 0) + foo_c = torch.compile(foo) + torch.testing.assert_allclose(foo(inp), foo_c(inp)) + if __name__ == "__main__": from torch._inductor.test_case import run_tests diff --git a/torch/_inductor/bounds.py b/torch/_inductor/bounds.py index 8c62ef2ba3c94b..b7bb37e5ee68f5 100644 --- a/torch/_inductor/bounds.py +++ b/torch/_inductor/bounds.py @@ -45,6 +45,15 @@ def upper_bound(v): # To access this variable call `get_bounds()` self._bounds: Dict[torch.fx.Node, ValueRanges[Expr]] = {} + def __repr__(self): + return ( + f"{self.__class__.__name__}(" + f"loop_body={self.loop_body},\n " + f"replacement_vals={self.replacement_vals}, \n" + f"unbounded_vars={self.unbounded_vars}, \n" + f"_bounds={self._bounds})" + ) + @cache_on_self def get_bounds(self) -> Dict[torch.fx.Node, ValueRanges[Expr]]: submodules = self.swap_submodules(self.loop_body.submodules) diff --git a/torch/_inductor/codegen/common.py b/torch/_inductor/codegen/common.py index 78f284e03d37e9..295912ab50bf4c 100644 --- a/torch/_inductor/codegen/common.py +++ b/torch/_inductor/codegen/common.py @@ -451,6 +451,9 @@ def _print_FloatTrueDiv(self, expr): def _print_CleanDiv(self, expr): return self._print_FloorDiv(expr) + def _print_Identity(self, expr): + return self._print(expr.args[0]) + def _print_GreaterThan(self, expr): # GreaterThan: >= # StrictlyGreaterThan: > diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py index 2ca80b16ab0cdf..449e512352fa1e 100644 --- a/torch/_inductor/lowering.py +++ b/torch/_inductor/lowering.py @@ -35,7 +35,13 @@ Number, ) from torch.fx.experimental.sym_node import magic_methods, method_to_operator -from torch.utils._sympy.functions import CeilDiv, FloorDiv, IntTrueDiv, ModularIndexing +from torch.utils._sympy.functions import ( + CeilDiv, + FloorDiv, + Identity, + IntTrueDiv, + ModularIndexing, +) from .._dynamo.utils import import_submodule from . import config, inductor_prims, ir, test_operators # NOQA: F401 @@ -1021,7 +1027,9 @@ def inner_fn(idx): # if we're concatting [4], [2] # when we index the second tensor for 5 we want to index 5 - 4 - idx_load[dim] -= inputs_ranges[i][0] + # Use Identity to prevent expansion of index * stride to keep expression + # in same int bitwidth as shape + idx_load[dim] = Identity(idx_load[dim] - inputs_ranges[i][0]) masked_loads.append( ops.masked( diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py index 60c59a24d71955..51216e09522572 100644 --- a/torch/_inductor/utils.py +++ b/torch/_inductor/utils.py @@ -53,7 +53,13 @@ from torch.autograd.profiler_util import EventList from torch.fx.passes.graph_transform_observer import GraphTransformObserver from torch.fx.passes.shape_prop import ShapeProp -from torch.utils._sympy.functions import CeilDiv, CleanDiv, FloorDiv, ModularIndexing +from torch.utils._sympy.functions import ( + CeilDiv, + CleanDiv, + FloorDiv, + Identity, + ModularIndexing, +) from torch.utils._sympy.symbol import make_symbol, SymT from torch.utils._sympy.value_ranges import bound_sympy, ValueRanges from . import config @@ -574,7 +580,7 @@ def sympy_str(expr: sympy.Expr) -> str: if isinstance(expr, sympy.Mul): return " * ".join(map(sympy_str, expr.args)) - if isinstance(expr, (ModularIndexing, CleanDiv, FloorDiv)): + if isinstance(expr, (ModularIndexing, CleanDiv, FloorDiv, Identity)): return f"{expr.func.__name__}({', '.join(map(sympy_str, expr.args))})" return str(expr) diff --git a/torch/utils/_sympy/functions.py b/torch/utils/_sympy/functions.py index fd9921848d6092..3c845f58117bc2 100644 --- a/torch/utils/_sympy/functions.py +++ b/torch/utils/_sympy/functions.py @@ -24,6 +24,7 @@ "ToFloat", "FloatPow", "PowByNatural", + "Identity", ] @@ -719,6 +720,21 @@ def eval(cls, number): return -sympy.oo +class Identity(sympy.Function): + """ + Prevents expansion and other optimizations + """ + + def __repr__(self): + return f"Identity({self.args[0]})" + + def _eval_is_real(self): + return self.args[0].is_real + + def _eval_is_integer(self): + return self.args[0].is_integer # type: ignore[attr-defined] + + def make_opaque_unary_fn(name): class OpaqueUnaryFn(sympy.Function): """ diff --git a/torch/utils/_sympy/interp.py b/torch/utils/_sympy/interp.py index 36ff6fc23d4a94..3bcb369bcebcd7 100644 --- a/torch/utils/_sympy/interp.py +++ b/torch/utils/_sympy/interp.py @@ -23,6 +23,7 @@ FloatTrueDiv, FloorDiv, FloorToInt, + Identity, IntTrueDiv, IsNonOverlappingAndDenseIndicator, Mod, @@ -92,6 +93,7 @@ def handlers(): ModularIndexing: "modular_indexing", sympy.functions.elementary.piecewise.ExprCondPair: "expr_cond_pair", sympy.Piecewise: "piecewise", + Identity: "identity", IsNonOverlappingAndDenseIndicator: "is_non_overlapping_and_dense_indicator", RoundDecimal: "round_decimal", } diff --git a/torch/utils/_sympy/value_ranges.py b/torch/utils/_sympy/value_ranges.py index 7c2f36d159d6a5..c1ed0b02946dfe 100644 --- a/torch/utils/_sympy/value_ranges.py +++ b/torch/utils/_sympy/value_ranges.py @@ -473,6 +473,10 @@ def eq(a, b): def ne(cls, a, b): return cls.not_(cls.eq(a, b)) + @classmethod + def identity(cls, a): + return ValueRanges.wrap(a) + @classmethod def lt(cls, a, b): a = ValueRanges.wrap(a) From ee140a198fbe8e7edb5f51806f63e9fa80f523fe Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Fri, 14 Jun 2024 15:51:00 +0000 Subject: [PATCH 002/171] Revert "[Port][Quant][Inductor] Bug fix: mutation nodes not handled correctly for QLinearPointwiseBinaryPT2E (#128591)" This reverts commit 03e8a4cf45ee45611de77b55b515a8936f60ce31. Reverted https://github.com/pytorch/pytorch/pull/128591 on behalf of https://github.com/atalman due to Contains release only changes should not be landed ([comment](https://github.com/pytorch/pytorch/pull/128591#issuecomment-2168308233)) --- .ci/pytorch/common_utils.sh | 2 +- .circleci/scripts/binary_populate_env.sh | 4 +- .github/ci_commit_pins/xla.txt | 2 +- .github/scripts/filter_test_configs.py | 4 +- .github/templates/common.yml.j2 | 2 +- .../linux_binary_build_workflow.yml.j2 | 4 +- .../macos_binary_build_workflow.yml.j2 | 4 +- .../windows_binary_build_workflow.yml.j2 | 8 +- .github/workflows/_android-build-test.yml | 12 +- .../workflows/_android-full-build-test.yml | 12 +- .github/workflows/_bazel-build-test.yml | 14 +- .github/workflows/_binary-build-linux.yml | 11 +- .github/workflows/_binary-test-linux.yml | 13 +- .github/workflows/_binary-upload.yml | 2 +- .github/workflows/_buck-build-test.yml | 6 +- .github/workflows/_docs.yml | 10 +- .github/workflows/_ios-build-test.yml | 6 +- .github/workflows/_linux-build-label.yml | 4 +- .github/workflows/_linux-build-rg.yml | 2 +- .github/workflows/_linux-build.yml | 10 +- .github/workflows/_linux-test-label.yml | 2 +- .github/workflows/_linux-test-rg.yml | 2 +- .github/workflows/_linux-test.yml | 12 +- .github/workflows/_mac-build.yml | 10 +- .github/workflows/_mac-test-mps.yml | 6 +- .github/workflows/_mac-test.yml | 8 +- .github/workflows/_rocm-test.yml | 6 +- .github/workflows/_run_android_tests.yml | 6 +- .github/workflows/_runner-determinator.yml | 2 +- .github/workflows/_win-build.yml | 6 +- .github/workflows/_win-test.yml | 6 +- .github/workflows/_xpu-test.yml | 6 +- .github/workflows/build-triton-wheel.yml | 34 ++- .github/workflows/check-labels.yml | 2 +- .../close-nonexistent-disable-issues.yml | 2 +- .github/workflows/docker-builds.yml | 8 +- .github/workflows/docker-release.yml | 8 +- ...linux-aarch64-binary-manywheel-nightly.yml | 50 ++-- .../generated-linux-binary-conda-nightly.yml | 120 ++++----- ...d-linux-binary-libtorch-cxx11-abi-main.yml | 4 +- ...inux-binary-libtorch-cxx11-abi-nightly.yml | 46 ++-- ...d-linux-binary-libtorch-pre-cxx11-main.yml | 4 +- ...inux-binary-libtorch-pre-cxx11-nightly.yml | 46 ++-- .../generated-linux-binary-manywheel-main.yml | 12 +- ...nerated-linux-binary-manywheel-nightly.yml | 238 +++++++++--------- ...d-linux-s390x-binary-manywheel-nightly.yml | 30 +-- ...rated-macos-arm64-binary-conda-nightly.yml | 25 +- ...rm64-binary-libtorch-cxx11-abi-nightly.yml | 5 +- ...rated-macos-arm64-binary-wheel-nightly.yml | 25 +- ...generated-windows-binary-conda-nightly.yml | 120 ++++++--- ...ted-windows-binary-libtorch-debug-main.yml | 6 +- ...-windows-binary-libtorch-debug-nightly.yml | 24 +- ...d-windows-binary-libtorch-release-main.yml | 6 +- ...indows-binary-libtorch-release-nightly.yml | 24 +- ...generated-windows-binary-wheel-nightly.yml | 120 ++++++--- .github/workflows/lint-bc.yml | 2 +- .github/workflows/lint.yml | 19 +- .github/workflows/llm_td_retrieval.yml | 2 +- .github/workflows/nightly-rockset-uploads.yml | 2 +- .github/workflows/nightly.yml | 6 +- .../target-determination-indexer.yml | 8 +- .github/workflows/target_determination.yml | 2 +- .github/workflows/update-viablestrict.yml | 2 +- .github/workflows/update_pytorch_labels.yml | 2 +- .github/workflows/upload-alerts.yml | 2 +- .github/workflows/upload-test-stats.yml | 2 +- .../upload-torch-dynamo-perf-stats.yml | 2 +- .../upload_test_stats_intermediate.yml | 2 +- .github/workflows/weekly.yml | 4 +- tools/stats/import_test_stats.py | 4 +- 70 files changed, 681 insertions(+), 543 deletions(-) diff --git a/.ci/pytorch/common_utils.sh b/.ci/pytorch/common_utils.sh index 68a7d5559f5a1c..51297f7bfff886 100644 --- a/.ci/pytorch/common_utils.sh +++ b/.ci/pytorch/common_utils.sh @@ -178,7 +178,7 @@ function install_torchrec_and_fbgemm() { function clone_pytorch_xla() { if [[ ! -d ./xla ]]; then - git clone --recursive -b r2.4 https://github.com/pytorch/xla.git + git clone --recursive --quiet https://github.com/pytorch/xla.git pushd xla # pin the xla hash so that we don't get broken by changes to xla git checkout "$(cat ../.github/ci_commit_pins/xla.txt)" diff --git a/.circleci/scripts/binary_populate_env.sh b/.circleci/scripts/binary_populate_env.sh index 437fee524926db..a45a2c9754ba5f 100755 --- a/.circleci/scripts/binary_populate_env.sh +++ b/.circleci/scripts/binary_populate_env.sh @@ -79,7 +79,7 @@ if [[ "$PACKAGE_TYPE" =~ .*wheel.* && -n "${PYTORCH_EXTRA_INSTALL_REQUIREMENTS: # Only linux Python < 3.13 are supported wheels for triton TRITON_CONSTRAINT="platform_system == 'Linux' and platform_machine == 'x86_64' and python_version < '3.13'" TRITON_REQUIREMENT="triton==${TRITON_VERSION}; ${TRITON_CONSTRAINT}" - if [[ -n "$PYTORCH_BUILD_VERSION" ]]; then + if [[ -n "$PYTORCH_BUILD_VERSION" && "$PYTORCH_BUILD_VERSION" =~ .*dev.* ]]; then TRITON_SHORTHASH=$(cut -c1-10 $PYTORCH_ROOT/.ci/docker/ci_commit_pins/triton.txt) TRITON_REQUIREMENT="pytorch-triton==${TRITON_VERSION}+${TRITON_SHORTHASH}; ${TRITON_CONSTRAINT}" fi @@ -89,7 +89,7 @@ fi # Set triton via PYTORCH_EXTRA_INSTALL_REQUIREMENTS for triton rocm package if [[ "$PACKAGE_TYPE" =~ .*wheel.* && -n "$PYTORCH_BUILD_VERSION" && "$PYTORCH_BUILD_VERSION" =~ .*rocm.* && $(uname) == "Linux" && "$DESIRED_PYTHON" != "3.12" ]]; then TRITON_REQUIREMENT="pytorch-triton-rocm==${TRITON_VERSION}" - if [[ -n "$PYTORCH_BUILD_VERSION" ]]; then + if [[ -n "$PYTORCH_BUILD_VERSION" && "$PYTORCH_BUILD_VERSION" =~ .*dev.* ]]; then TRITON_SHORTHASH=$(cut -c1-10 $PYTORCH_ROOT/.ci/docker/ci_commit_pins/triton-rocm.txt) TRITON_REQUIREMENT="pytorch-triton-rocm==${TRITON_VERSION}+${TRITON_SHORTHASH}" fi diff --git a/.github/ci_commit_pins/xla.txt b/.github/ci_commit_pins/xla.txt index ac6456c938d23f..c1a5561c5308f9 100644 --- a/.github/ci_commit_pins/xla.txt +++ b/.github/ci_commit_pins/xla.txt @@ -1 +1 @@ -r2.4 +6f0b61e5d782913a0fc7743812f2a8e522189111 diff --git a/.github/scripts/filter_test_configs.py b/.github/scripts/filter_test_configs.py index c5f664be1b454c..c2e45bac81100b 100755 --- a/.github/scripts/filter_test_configs.py +++ b/.github/scripts/filter_test_configs.py @@ -38,9 +38,9 @@ def is_cuda_or_rocm_job(job_name: Optional[str]) -> bool: } # The link to the published list of disabled jobs -DISABLED_JOBS_URL = "https://ossci-metrics.s3.amazonaws.com/disabled-jobs.json?versionId=tIl0Qo224T_NDVw0dtG4hU1cZJM97inV" +DISABLED_JOBS_URL = "https://ossci-metrics.s3.amazonaws.com/disabled-jobs.json" # and unstable jobs -UNSTABLE_JOBS_URL = "https://ossci-metrics.s3.amazonaws.com/unstable-jobs.json?versionId=GPyRZRsOo26Gfk_WjAoNNxEMGXkIxIes" +UNSTABLE_JOBS_URL = "https://ossci-metrics.s3.amazonaws.com/unstable-jobs.json" # Some constants used to handle disabled and unstable jobs JOB_NAME_SEP = "/" diff --git a/.github/templates/common.yml.j2 b/.github/templates/common.yml.j2 index 8c82cec0268f5b..38b90e919c3341 100644 --- a/.github/templates/common.yml.j2 +++ b/.github/templates/common.yml.j2 @@ -8,7 +8,7 @@ # NOTE: If testing pytorch/builder changes you can change this variable to change what pytorch/builder reference # the binary builds will check out {%- set builder_repo = "pytorch/builder" -%} -{%- set builder_branch = "release/2.4" -%} +{%- set builder_branch = "main" -%} {%- macro concurrency(build_environment) -%} concurrency: diff --git a/.github/templates/linux_binary_build_workflow.yml.j2 b/.github/templates/linux_binary_build_workflow.yml.j2 index b7688dbdd82a8d..c5903005aa1ba3 100644 --- a/.github/templates/linux_binary_build_workflow.yml.j2 +++ b/.github/templates/linux_binary_build_workflow.yml.j2 @@ -113,8 +113,8 @@ jobs: with: name: !{{ config["build_name"] }} path: "${{ runner.temp }}/artifacts/" - !{{ common.checkout(deep_clone=False, directory="pytorch", checkout_pr_head=False) }} - !{{ common.checkout(deep_clone=False, directory="builder", repository=common.builder_repo, branch=common.builder_branch, checkout_pr_head=False) }} + !{{ common.checkout(deep_clone=False, directory="pytorch") }} + !{{ common.checkout(deep_clone=False, directory="builder", repository=common.builder_repo, branch=common.builder_branch) }} - name: ROCm set GPU_FLAG run: | echo "GPU_FLAG=--device=/dev/mem --device=/dev/kfd --device=/dev/dri --group-add video --group-add daemon" >> "${GITHUB_ENV}" diff --git a/.github/templates/macos_binary_build_workflow.yml.j2 b/.github/templates/macos_binary_build_workflow.yml.j2 index 2649146a326bea..591dc52ef9c011 100644 --- a/.github/templates/macos_binary_build_workflow.yml.j2 +++ b/.github/templates/macos_binary_build_workflow.yml.j2 @@ -81,8 +81,8 @@ jobs: elif [ -d "/Applications/Xcode_13.3.1.app" ]; then echo "DEVELOPER_DIR=/Applications/Xcode_13.3.1.app/Contents/Developer" >> "${GITHUB_ENV}" fi - !{{ common.checkout(deep_clone=False, directory="pytorch", checkout_pr_head=False) }} - !{{ common.checkout(deep_clone=False, directory="builder", repository=common.builder_repo, branch=common.builder_branch, checkout_pr_head=False) }} + !{{ common.checkout(deep_clone=False, directory="pytorch") }} + !{{ common.checkout(deep_clone=False, directory="builder", repository=common.builder_repo, branch=common.builder_branch) }} - name: Install sccache (only for non-forked PRs, and pushes to trunk) uses: nick-fields/retry@v2.8.2 if: ${{ github.event_name == 'push' || github.event.pull_request.head.repo.full_name == github.repository }} diff --git a/.github/templates/windows_binary_build_workflow.yml.j2 b/.github/templates/windows_binary_build_workflow.yml.j2 index c9815fe4bd03f6..d5aca578b9024a 100644 --- a/.github/templates/windows_binary_build_workflow.yml.j2 +++ b/.github/templates/windows_binary_build_workflow.yml.j2 @@ -65,8 +65,8 @@ jobs: steps: !{{ common.setup_ec2_windows() }} !{{ set_runner_specific_vars() }} - !{{ common.checkout(deep_clone=False, directory="pytorch", checkout_pr_head=False) }} - !{{ common.checkout(deep_clone=False, directory="builder", repository=common.builder_repo, branch=common.builder_branch, checkout_pr_head=False) }} + !{{ common.checkout(deep_clone=False, directory="pytorch") }} + !{{ common.checkout(deep_clone=False, directory="builder", repository=common.builder_repo, branch=common.builder_branch) }} - name: Populate binary env shell: bash run: | @@ -105,8 +105,8 @@ jobs: with: name: !{{ config["build_name"] }} path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}" - !{{ common.checkout(deep_clone=False, directory="pytorch", checkout_pr_head=False) }} - !{{ common.checkout(deep_clone=False, directory="builder", repository=common.builder_repo, branch=common.builder_branch, checkout_pr_head=False) }} + !{{ common.checkout(deep_clone=False, directory="pytorch") }} + !{{ common.checkout(deep_clone=False, directory="builder", repository=common.builder_repo, branch=common.builder_branch) }} - name: Populate binary env shell: bash run: | diff --git a/.github/workflows/_android-build-test.yml b/.github/workflows/_android-build-test.yml index 7238870a96096b..d599e769b8b6a0 100644 --- a/.github/workflows/_android-build-test.yml +++ b/.github/workflows/_android-build-test.yml @@ -37,7 +37,7 @@ jobs: keep-going: ${{ steps.filter.outputs.keep-going }} steps: - name: Checkout PyTorch - uses: pytorch/pytorch/.github/actions/checkout-pytorch@release/2.4 + uses: pytorch/pytorch/.github/actions/checkout-pytorch@main with: fetch-depth: 1 submodules: false @@ -59,25 +59,25 @@ jobs: runs-on: ${{ matrix.runner }} steps: - name: Setup SSH (Click me for login details) - uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.4 + uses: pytorch/test-infra/.github/actions/setup-ssh@main with: github-secret: ${{ secrets.GITHUB_TOKEN }} # [see note: pytorch repo ref] - name: Checkout PyTorch - uses: pytorch/pytorch/.github/actions/checkout-pytorch@release/2.4 + uses: pytorch/pytorch/.github/actions/checkout-pytorch@main - name: Setup Linux uses: ./.github/actions/setup-linux - name: Calculate docker image id: calculate-docker-image - uses: pytorch/test-infra/.github/actions/calculate-docker-image@release/2.4 + uses: pytorch/test-infra/.github/actions/calculate-docker-image@main with: docker-image-name: ${{ inputs.docker-image-name }} - name: Pull docker image - uses: pytorch/test-infra/.github/actions/pull-docker-image@release/2.4 + uses: pytorch/test-infra/.github/actions/pull-docker-image@main with: docker-image: ${{ steps.calculate-docker-image.outputs.docker-image }} @@ -141,5 +141,5 @@ jobs: if: always() - name: Teardown Linux - uses: pytorch/test-infra/.github/actions/teardown-linux@release/2.4 + uses: pytorch/test-infra/.github/actions/teardown-linux@main if: always() diff --git a/.github/workflows/_android-full-build-test.yml b/.github/workflows/_android-full-build-test.yml index b4dd46b23951fc..7a0c4377eca4e6 100644 --- a/.github/workflows/_android-full-build-test.yml +++ b/.github/workflows/_android-full-build-test.yml @@ -37,7 +37,7 @@ jobs: keep-going: ${{ steps.filter.outputs.keep-going }} steps: - name: Checkout PyTorch - uses: pytorch/pytorch/.github/actions/checkout-pytorch@release/2.4 + uses: pytorch/pytorch/.github/actions/checkout-pytorch@main with: fetch-depth: 1 submodules: false @@ -59,25 +59,25 @@ jobs: runs-on: ${{ matrix.runner }} steps: - name: Setup SSH (Click me for login details) - uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.4 + uses: pytorch/test-infra/.github/actions/setup-ssh@main with: github-secret: ${{ secrets.GITHUB_TOKEN }} # [see note: pytorch repo ref] - name: Checkout PyTorch - uses: pytorch/pytorch/.github/actions/checkout-pytorch@release/2.4 + uses: pytorch/pytorch/.github/actions/checkout-pytorch@main - name: Setup Linux uses: ./.github/actions/setup-linux - name: Calculate docker image id: calculate-docker-image - uses: pytorch/test-infra/.github/actions/calculate-docker-image@release/2.4 + uses: pytorch/test-infra/.github/actions/calculate-docker-image@main with: docker-image-name: ${{ inputs.docker-image-name }} - name: Pull docker image - uses: pytorch/test-infra/.github/actions/pull-docker-image@release/2.4 + uses: pytorch/test-infra/.github/actions/pull-docker-image@main with: docker-image: ${{ steps.calculate-docker-image.outputs.docker-image }} @@ -186,5 +186,5 @@ jobs: if: always() - name: Teardown Linux - uses: pytorch/test-infra/.github/actions/teardown-linux@release/2.4 + uses: pytorch/test-infra/.github/actions/teardown-linux@main if: always() diff --git a/.github/workflows/_bazel-build-test.yml b/.github/workflows/_bazel-build-test.yml index 1fcff5d02647e5..ca65ce64bc657d 100644 --- a/.github/workflows/_bazel-build-test.yml +++ b/.github/workflows/_bazel-build-test.yml @@ -42,7 +42,7 @@ jobs: reenabled-issues: ${{ steps.filter.outputs.reenabled-issues }} steps: - name: Checkout PyTorch - uses: pytorch/pytorch/.github/actions/checkout-pytorch@release/2.4 + uses: pytorch/pytorch/.github/actions/checkout-pytorch@main with: fetch-depth: 1 submodules: false @@ -64,25 +64,25 @@ jobs: runs-on: ${{ matrix.runner }} steps: - name: Setup SSH (Click me for login details) - uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.4 + uses: pytorch/test-infra/.github/actions/setup-ssh@main with: github-secret: ${{ secrets.GITHUB_TOKEN }} # [see note: pytorch repo ref] - name: Checkout PyTorch - uses: pytorch/pytorch/.github/actions/checkout-pytorch@release/2.4 + uses: pytorch/pytorch/.github/actions/checkout-pytorch@main - name: Setup Linux uses: ./.github/actions/setup-linux - name: Calculate docker image id: calculate-docker-image - uses: pytorch/test-infra/.github/actions/calculate-docker-image@release/2.4 + uses: pytorch/test-infra/.github/actions/calculate-docker-image@main with: docker-image-name: ${{ inputs.docker-image-name }} - name: Pull docker image - uses: pytorch/test-infra/.github/actions/pull-docker-image@release/2.4 + uses: pytorch/test-infra/.github/actions/pull-docker-image@main with: docker-image: ${{ steps.calculate-docker-image.outputs.docker-image }} @@ -92,7 +92,7 @@ jobs: run: echo "IN_ARC_RUNNER=$([ -f /.inarc ] && echo true || echo false)" >> "$GITHUB_OUTPUT" - name: Install nvidia driver, nvidia-docker runtime, set GPU_FLAG - uses: pytorch/test-infra/.github/actions/setup-nvidia@release/2.4 + uses: pytorch/test-infra/.github/actions/setup-nvidia@main if: ${{ inputs.cuda-version != 'cpu' && steps.check_arc_runner.outputs.IN_ARC_RUNNER == 'false' }} - name: Output disk space left @@ -201,5 +201,5 @@ jobs: file-suffix: bazel-${{ github.job }}_${{ steps.get-job-id.outputs.job-id }} - name: Teardown Linux - uses: pytorch/test-infra/.github/actions/teardown-linux@release/2.4 + uses: pytorch/test-infra/.github/actions/teardown-linux@main if: always() diff --git a/.github/workflows/_binary-build-linux.yml b/.github/workflows/_binary-build-linux.yml index 4200b946bd41e7..e54cb4a14c9016 100644 --- a/.github/workflows/_binary-build-linux.yml +++ b/.github/workflows/_binary-build-linux.yml @@ -145,13 +145,13 @@ jobs: - name: "[FB EMPLOYEES] Enable SSH (Click me for login details)" if: inputs.build_environment != 'linux-s390x-binary-manywheel' - uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.4 + uses: pytorch/test-infra/.github/actions/setup-ssh@main continue-on-error: true with: github-secret: ${{ secrets.github-token }} - name: Checkout PyTorch - uses: pytorch/pytorch/.github/actions/checkout-pytorch@release/2.4 + uses: pytorch/pytorch/.github/actions/checkout-pytorch@main with: no-sudo: ${{ inputs.build_environment == 'linux-aarch64-binary-manywheel' || inputs.build_environment == 'linux-s390x-binary-manywheel' }} @@ -181,6 +181,7 @@ jobs: - name: Checkout PyTorch to pytorch dir uses: malfet/checkout@silent-checkout with: + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} submodules: recursive path: pytorch quiet-checkout: true @@ -194,7 +195,7 @@ jobs: - name: Checkout pytorch/builder to builder dir uses: malfet/checkout@silent-checkout with: - ref: release/2.4 + ref: main submodules: recursive repository: pytorch/builder path: builder @@ -220,7 +221,7 @@ jobs: - name: Pull Docker image if: ${{ steps.filter.outputs.is-test-matrix-empty == 'False' && inputs.build_environment != 'linux-s390x-binary-manywheel' }} - uses: pytorch/test-infra/.github/actions/pull-docker-image@release/2.4 + uses: pytorch/test-infra/.github/actions/pull-docker-image@main with: docker-image: ${{ inputs.DOCKER_IMAGE }} @@ -277,7 +278,7 @@ jobs: - name: Teardown Linux if: always() && inputs.build_environment != 'linux-s390x-binary-manywheel' - uses: pytorch/test-infra/.github/actions/teardown-linux@release/2.4 + uses: pytorch/test-infra/.github/actions/teardown-linux@main - name: Chown workspace if: always() && inputs.build_environment != 'linux-s390x-binary-manywheel' diff --git a/.github/workflows/_binary-test-linux.yml b/.github/workflows/_binary-test-linux.yml index 2c92128b1aedc3..25a6b24223f99e 100644 --- a/.github/workflows/_binary-test-linux.yml +++ b/.github/workflows/_binary-test-linux.yml @@ -128,14 +128,14 @@ jobs: - name: "[FB EMPLOYEES] Enable SSH (Click me for login details)" if: inputs.build_environment != 'linux-s390x-binary-manywheel' - uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.4 + uses: pytorch/test-infra/.github/actions/setup-ssh@main continue-on-error: true with: github-secret: ${{ secrets.github-token }} # Setup the environment - name: Checkout PyTorch - uses: pytorch/pytorch/.github/actions/checkout-pytorch@release/2.4 + uses: pytorch/pytorch/.github/actions/checkout-pytorch@main with: no-sudo: ${{ inputs.build_environment == 'linux-aarch64-binary-manywheel' || inputs.build_environment == 'linux-s390x-binary-manywheel' }} @@ -158,6 +158,7 @@ jobs: - name: Checkout PyTorch to pytorch dir uses: malfet/checkout@silent-checkout with: + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} submodules: recursive path: pytorch @@ -170,7 +171,7 @@ jobs: - name: Checkout pytorch/builder to builder dir uses: malfet/checkout@silent-checkout with: - ref: release/2.4 + ref: main submodules: recursive repository: pytorch/builder path: builder @@ -201,12 +202,12 @@ jobs: path: "${{ runner.temp }}/artifacts/" - name: Install nvidia driver, nvidia-docker runtime, set GPU_FLAG - uses: pytorch/test-infra/.github/actions/setup-nvidia@release/2.4 + uses: pytorch/test-infra/.github/actions/setup-nvidia@main if: ${{ inputs.GPU_ARCH_TYPE == 'cuda' && steps.filter.outputs.is-test-matrix-empty == 'False' }} - name: Pull Docker image if: ${{ steps.filter.outputs.is-test-matrix-empty == 'False' && inputs.build_environment != 'linux-s390x-binary-manywheel' }} - uses: pytorch/test-infra/.github/actions/pull-docker-image@release/2.4 + uses: pytorch/test-infra/.github/actions/pull-docker-image@main with: docker-image: ${{ inputs.DOCKER_IMAGE }} @@ -216,7 +217,7 @@ jobs: - name: Teardown Linux if: always() && inputs.build_environment != 'linux-s390x-binary-manywheel' - uses: pytorch/test-infra/.github/actions/teardown-linux@release/2.4 + uses: pytorch/test-infra/.github/actions/teardown-linux@main - name: Chown workspace if: always() && inputs.build_environment != 'linux-s390x-binary-manywheel' diff --git a/.github/workflows/_binary-upload.yml b/.github/workflows/_binary-upload.yml index af990e3ec71a3c..1231dd0e8c7d46 100644 --- a/.github/workflows/_binary-upload.yml +++ b/.github/workflows/_binary-upload.yml @@ -95,7 +95,7 @@ jobs: SHA1: ${{ github.event.pull_request.head.sha || github.sha }} steps: - name: Checkout PyTorch - uses: pytorch/pytorch/.github/actions/checkout-pytorch@release/2.4 + uses: pytorch/pytorch/.github/actions/checkout-pytorch@main with: no-sudo: true diff --git a/.github/workflows/_buck-build-test.yml b/.github/workflows/_buck-build-test.yml index 4af0d2a29a3779..43eb72fc9181b2 100644 --- a/.github/workflows/_buck-build-test.yml +++ b/.github/workflows/_buck-build-test.yml @@ -23,7 +23,7 @@ jobs: keep-going: ${{ steps.filter.outputs.keep-going }} steps: - name: Checkout PyTorch - uses: pytorch/pytorch/.github/actions/checkout-pytorch@release/2.4 + uses: pytorch/pytorch/.github/actions/checkout-pytorch@main with: fetch-depth: 1 submodules: false @@ -44,7 +44,7 @@ jobs: runs-on: ${{ matrix.runner }} steps: - name: Checkout PyTorch - uses: pytorch/pytorch/.github/actions/checkout-pytorch@release/2.4 + uses: pytorch/pytorch/.github/actions/checkout-pytorch@main - name: Set up JDK 8 uses: actions/setup-java@v3 @@ -53,7 +53,7 @@ jobs: distribution: 'temurin' - name: Setup miniconda - uses: pytorch/test-infra/.github/actions/setup-miniconda@release/2.4 + uses: pytorch/test-infra/.github/actions/setup-miniconda@main with: python-version: 3.8 environment-file: .github/requirements/conda-env-${{ runner.os }}-${{ runner.arch }} diff --git a/.github/workflows/_docs.yml b/.github/workflows/_docs.yml index fd978f9b86b2b7..069bcb4d2a14e4 100644 --- a/.github/workflows/_docs.yml +++ b/.github/workflows/_docs.yml @@ -80,7 +80,7 @@ jobs: name: build-docs-${{ matrix.docs_type }}-${{ inputs.push }} steps: - name: Setup SSH (Click me for login details) - uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.4 + uses: pytorch/test-infra/.github/actions/setup-ssh@main with: github-secret: ${{ secrets.GITHUB_TOKEN }} instructions: | @@ -91,7 +91,7 @@ jobs: # [see note: pytorch repo ref] - name: Checkout PyTorch - uses: pytorch/pytorch/.github/actions/checkout-pytorch@release/2.4 + uses: pytorch/pytorch/.github/actions/checkout-pytorch@main - name: Setup Linux uses: ./.github/actions/setup-linux @@ -106,12 +106,12 @@ jobs: - name: Calculate docker image id: calculate-docker-image - uses: pytorch/test-infra/.github/actions/calculate-docker-image@release/2.4 + uses: pytorch/test-infra/.github/actions/calculate-docker-image@main with: docker-image-name: ${{ inputs.docker-image }} - name: Pull docker image - uses: pytorch/test-infra/.github/actions/pull-docker-image@release/2.4 + uses: pytorch/test-infra/.github/actions/pull-docker-image@main with: docker-image: ${{ steps.calculate-docker-image.outputs.docker-image }} @@ -218,5 +218,5 @@ jobs: s3-prefix: pytorch/pytorch/${{ github.event.pull_request.number }}/functorchdocs - name: Teardown Linux - uses: pytorch/test-infra/.github/actions/teardown-linux@release/2.4 + uses: pytorch/test-infra/.github/actions/teardown-linux@main if: always() diff --git a/.github/workflows/_ios-build-test.yml b/.github/workflows/_ios-build-test.yml index 9269f719ecbba8..0282a0482104d9 100644 --- a/.github/workflows/_ios-build-test.yml +++ b/.github/workflows/_ios-build-test.yml @@ -46,7 +46,7 @@ jobs: keep-going: ${{ steps.filter.outputs.keep-going }} steps: - name: Checkout PyTorch - uses: pytorch/pytorch/.github/actions/checkout-pytorch@release/2.4 + uses: pytorch/pytorch/.github/actions/checkout-pytorch@main with: fetch-depth: 1 submodules: false @@ -80,7 +80,7 @@ jobs: steps: # [see note: pytorch repo ref] - name: Checkout PyTorch - uses: pytorch/pytorch/.github/actions/checkout-pytorch@release/2.4 + uses: pytorch/pytorch/.github/actions/checkout-pytorch@main - name: Populate CI build options shell: bash @@ -102,7 +102,7 @@ jobs: brew install libtool - name: Setup miniconda for iOS - uses: pytorch/test-infra/.github/actions/setup-miniconda@release/2.4 + uses: pytorch/test-infra/.github/actions/setup-miniconda@main with: python-version: "3.9" environment-file: .github/requirements/conda-env-iOS.txt diff --git a/.github/workflows/_linux-build-label.yml b/.github/workflows/_linux-build-label.yml index 38908f869e4a06..427f993b48530a 100644 --- a/.github/workflows/_linux-build-label.yml +++ b/.github/workflows/_linux-build-label.yml @@ -81,7 +81,7 @@ jobs: test-matrix: ${{ steps.linux-build.outputs.test-matrix }} steps: - name: Setup SSH (Click me for login details) - uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.4 + uses: pytorch/test-infra/.github/actions/setup-ssh@main with: github-secret: ${{ secrets.GITHUB_TOKEN }} @@ -90,7 +90,7 @@ jobs: # checkout because when we run this action we don't *have* a local # checkout. In other cases you should prefer a local checkout. - name: Checkout PyTorch - uses: pytorch/pytorch/.github/actions/checkout-pytorch@release/2.4 + uses: pytorch/pytorch/.github/actions/checkout-pytorch@main - name: Linux Build id: linux-build diff --git a/.github/workflows/_linux-build-rg.yml b/.github/workflows/_linux-build-rg.yml index 3df755b4425207..6c6a4827e16723 100644 --- a/.github/workflows/_linux-build-rg.yml +++ b/.github/workflows/_linux-build-rg.yml @@ -86,7 +86,7 @@ jobs: # checkout because when we run this action we don't *have* a local # checkout. In other cases you should prefer a local checkout. - name: Checkout PyTorch - uses: pytorch/pytorch/.github/actions/checkout-pytorch@release/2.4 + uses: pytorch/pytorch/.github/actions/checkout-pytorch@main - name: Linux Build id: linux-build diff --git a/.github/workflows/_linux-build.yml b/.github/workflows/_linux-build.yml index cc38698e5b37ef..c3bcb0d888dfc8 100644 --- a/.github/workflows/_linux-build.yml +++ b/.github/workflows/_linux-build.yml @@ -90,7 +90,7 @@ jobs: test-matrix: ${{ steps.filter.outputs.test-matrix }} steps: - name: Setup SSH (Click me for login details) - uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.4 + uses: pytorch/test-infra/.github/actions/setup-ssh@main with: github-secret: ${{ secrets.GITHUB_TOKEN }} @@ -99,7 +99,7 @@ jobs: # checkout because when we run this action we don't *have* a local # checkout. In other cases you should prefer a local checkout. - name: Checkout PyTorch - uses: pytorch/pytorch/.github/actions/checkout-pytorch@release/2.4 + uses: pytorch/pytorch/.github/actions/checkout-pytorch@main - name: Setup Linux uses: ./.github/actions/setup-linux @@ -114,7 +114,7 @@ jobs: - name: Calculate docker image id: calculate-docker-image - uses: pytorch/test-infra/.github/actions/calculate-docker-image@release/2.4 + uses: pytorch/test-infra/.github/actions/calculate-docker-image@main with: docker-image-name: ${{ inputs.docker-image-name }} @@ -128,7 +128,7 @@ jobs: echo "docker pull ghcr.io/pytorch/ci-image:${tag/:/-}" - name: Pull docker image - uses: pytorch/test-infra/.github/actions/pull-docker-image@release/2.4 + uses: pytorch/test-infra/.github/actions/pull-docker-image@main with: docker-image: ${{ steps.calculate-docker-image.outputs.docker-image }} @@ -238,5 +238,5 @@ jobs: s3-bucket: ${{ inputs.s3-bucket }} - name: Teardown Linux - uses: pytorch/test-infra/.github/actions/teardown-linux@release/2.4 + uses: pytorch/test-infra/.github/actions/teardown-linux@main if: always() diff --git a/.github/workflows/_linux-test-label.yml b/.github/workflows/_linux-test-label.yml index 3bca19ccfbd699..7056c0168a19ed 100644 --- a/.github/workflows/_linux-test-label.yml +++ b/.github/workflows/_linux-test-label.yml @@ -67,7 +67,7 @@ jobs: timeout-minutes: ${{ matrix.mem_leak_check == 'mem_leak_check' && 600 || inputs.timeout-minutes }} steps: - name: Checkout PyTorch - uses: pytorch/pytorch/.github/actions/checkout-pytorch@release/2.4 + uses: pytorch/pytorch/.github/actions/checkout-pytorch@main - name: Linux Test id: linux-test diff --git a/.github/workflows/_linux-test-rg.yml b/.github/workflows/_linux-test-rg.yml index 62d5cf97422faa..6dc2f6c63bf3e3 100644 --- a/.github/workflows/_linux-test-rg.yml +++ b/.github/workflows/_linux-test-rg.yml @@ -68,7 +68,7 @@ jobs: timeout-minutes: ${{ matrix.mem_leak_check == 'mem_leak_check' && 600 || inputs.timeout-minutes }} steps: - name: Checkout PyTorch - uses: pytorch/pytorch/.github/actions/checkout-pytorch@release/2.4 + uses: pytorch/pytorch/.github/actions/checkout-pytorch@main - name: Linux Test id: linux-test diff --git a/.github/workflows/_linux-test.yml b/.github/workflows/_linux-test.yml index 999feaa83f6846..5f3f290dd31daf 100644 --- a/.github/workflows/_linux-test.yml +++ b/.github/workflows/_linux-test.yml @@ -67,7 +67,7 @@ jobs: timeout-minutes: ${{ matrix.mem_leak_check == 'mem_leak_check' && 600 || inputs.timeout-minutes }} steps: - name: Setup SSH (Click me for login details) - uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.4 + uses: pytorch/test-infra/.github/actions/setup-ssh@main if: ${{ !contains(matrix.runner, 'gcp.a100') }} with: github-secret: ${{ secrets.GITHUB_TOKEN }} @@ -76,7 +76,7 @@ jobs: docker exec -it $(docker container ps --format '{{.ID}}') bash - name: Checkout PyTorch - uses: pytorch/pytorch/.github/actions/checkout-pytorch@release/2.4 + uses: pytorch/pytorch/.github/actions/checkout-pytorch@main - name: Setup Linux uses: ./.github/actions/setup-linux @@ -91,7 +91,7 @@ jobs: - name: Calculate docker image id: calculate-docker-image - uses: pytorch/test-infra/.github/actions/calculate-docker-image@release/2.4 + uses: pytorch/test-infra/.github/actions/calculate-docker-image@main with: docker-image-name: ${{ inputs.docker-image }} @@ -105,7 +105,7 @@ jobs: echo "docker pull ghcr.io/pytorch/ci-image:${tag/:/-}" - name: Pull docker image - uses: pytorch/test-infra/.github/actions/pull-docker-image@release/2.4 + uses: pytorch/test-infra/.github/actions/pull-docker-image@main with: docker-image: ${{ steps.calculate-docker-image.outputs.docker-image }} @@ -116,7 +116,7 @@ jobs: - name: Install nvidia driver, nvidia-docker runtime, set GPU_FLAG id: install-nvidia-driver - uses: pytorch/test-infra/.github/actions/setup-nvidia@release/2.4 + uses: pytorch/test-infra/.github/actions/setup-nvidia@main if: ${{ contains(inputs.build-environment, 'cuda') && !contains(matrix.config, 'nogpu') && steps.check_arc_runner.outputs.IN_ARC_RUNNER == 'false' }} - name: Lock NVIDIA A100 40GB Frequency @@ -333,7 +333,7 @@ jobs: path: ./**/core.[1-9]* - name: Teardown Linux - uses: pytorch/test-infra/.github/actions/teardown-linux@release/2.4 + uses: pytorch/test-infra/.github/actions/teardown-linux@main if: always() # NB: We are currently having an intermittent GPU-related issue on G5 runners with diff --git a/.github/workflows/_mac-build.yml b/.github/workflows/_mac-build.yml index 6f394cc9893a84..a27ddaf629b51f 100644 --- a/.github/workflows/_mac-build.yml +++ b/.github/workflows/_mac-build.yml @@ -71,11 +71,11 @@ jobs: test-matrix: ${{ steps.filter.outputs.test-matrix }} steps: - name: Clean up disk space before running MacOS workflow - uses: pytorch/test-infra/.github/actions/check-disk-space@release/2.4 + uses: pytorch/test-infra/.github/actions/check-disk-space@main # [see note: pytorch repo ref] - name: Checkout PyTorch - uses: pytorch/pytorch/.github/actions/checkout-pytorch@release/2.4 + uses: pytorch/pytorch/.github/actions/checkout-pytorch@main - name: Set xcode version env: @@ -87,7 +87,7 @@ jobs: - name: Setup miniconda if: inputs.environment-file == '' - uses: pytorch/test-infra/.github/actions/setup-miniconda@release/2.4 + uses: pytorch/test-infra/.github/actions/setup-miniconda@main with: python-version: ${{ inputs.python-version }} environment-file: .github/requirements/conda-env-${{ runner.os }}-${{ runner.arch }} @@ -97,7 +97,7 @@ jobs: # environment even though the arch is x86-64 - name: Setup miniconda using the provided environment file if: inputs.environment-file != '' - uses: pytorch/test-infra/.github/actions/setup-miniconda@release/2.4 + uses: pytorch/test-infra/.github/actions/setup-miniconda@main with: python-version: ${{ inputs.python-version }} environment-file: ${{ inputs.environment-file }} @@ -207,4 +207,4 @@ jobs: - name: Clean up disk space if: always() continue-on-error: true - uses: pytorch/test-infra/.github/actions/check-disk-space@release/2.4 + uses: pytorch/test-infra/.github/actions/check-disk-space@main diff --git a/.github/workflows/_mac-test-mps.yml b/.github/workflows/_mac-test-mps.yml index e445ab22fb85fc..2c0da2f8afd7c5 100644 --- a/.github/workflows/_mac-test-mps.yml +++ b/.github/workflows/_mac-test-mps.yml @@ -40,7 +40,7 @@ jobs: reenabled-issues: ${{ steps.filter.outputs.reenabled-issues }} steps: - name: Checkout PyTorch - uses: pytorch/pytorch/.github/actions/checkout-pytorch@release/2.4 + uses: pytorch/pytorch/.github/actions/checkout-pytorch@main with: submodules: false @@ -81,7 +81,7 @@ jobs: use-gha: true - name: Setup miniconda - uses: pytorch/test-infra/.github/actions/setup-miniconda@release/2.4 + uses: pytorch/test-infra/.github/actions/setup-miniconda@main with: python-version: ${{ inputs.python-version }} environment-file: .github/requirements/conda-env-${{ runner.os }}-${{ runner.arch }} @@ -159,4 +159,4 @@ jobs: - name: Clean up disk space if: always() continue-on-error: true - uses: pytorch/test-infra/.github/actions/check-disk-space@release/2.4 + uses: pytorch/test-infra/.github/actions/check-disk-space@main diff --git a/.github/workflows/_mac-test.yml b/.github/workflows/_mac-test.yml index 8f313a0212998d..3e82194ff3461c 100644 --- a/.github/workflows/_mac-test.yml +++ b/.github/workflows/_mac-test.yml @@ -74,11 +74,11 @@ jobs: done - name: Clean up disk space before running MacOS workflow - uses: pytorch/test-infra/.github/actions/check-disk-space@release/2.4 + uses: pytorch/test-infra/.github/actions/check-disk-space@main # [see note: pytorch repo ref] - name: Checkout PyTorch - uses: pytorch/pytorch/.github/actions/checkout-pytorch@release/2.4 + uses: pytorch/pytorch/.github/actions/checkout-pytorch@main - name: Download build artifacts uses: ./.github/actions/download-build-artifacts @@ -93,7 +93,7 @@ jobs: use-gha: true - name: Setup miniconda - uses: pytorch/test-infra/.github/actions/setup-miniconda@release/2.4 + uses: pytorch/test-infra/.github/actions/setup-miniconda@main with: python-version: ${{ inputs.python-version }} environment-file: .github/requirements/conda-env-${{ runner.os }}-${{ runner.arch }} @@ -216,4 +216,4 @@ jobs: - name: Clean up disk space if: always() continue-on-error: true - uses: pytorch/test-infra/.github/actions/check-disk-space@release/2.4 + uses: pytorch/test-infra/.github/actions/check-disk-space@main diff --git a/.github/workflows/_rocm-test.yml b/.github/workflows/_rocm-test.yml index f2982777a5beb0..1f2d86273ee149 100644 --- a/.github/workflows/_rocm-test.yml +++ b/.github/workflows/_rocm-test.yml @@ -58,7 +58,7 @@ jobs: steps: # [see note: pytorch repo ref] - name: Checkout PyTorch - uses: pytorch/pytorch/.github/actions/checkout-pytorch@release/2.4 + uses: pytorch/pytorch/.github/actions/checkout-pytorch@main with: no-sudo: true @@ -80,12 +80,12 @@ jobs: - name: Calculate docker image id: calculate-docker-image - uses: pytorch/test-infra/.github/actions/calculate-docker-image@release/2.4 + uses: pytorch/test-infra/.github/actions/calculate-docker-image@main with: docker-image-name: ${{ inputs.docker-image }} - name: Pull docker image - uses: pytorch/test-infra/.github/actions/pull-docker-image@release/2.4 + uses: pytorch/test-infra/.github/actions/pull-docker-image@main with: docker-image: ${{ steps.calculate-docker-image.outputs.docker-image }} diff --git a/.github/workflows/_run_android_tests.yml b/.github/workflows/_run_android_tests.yml index 883aed7c888682..b9b3d0645eac62 100644 --- a/.github/workflows/_run_android_tests.yml +++ b/.github/workflows/_run_android_tests.yml @@ -23,7 +23,7 @@ jobs: keep-going: ${{ steps.filter.outputs.keep-going }} steps: - name: Checkout PyTorch - uses: pytorch/pytorch/.github/actions/checkout-pytorch@release/2.4 + uses: pytorch/pytorch/.github/actions/checkout-pytorch@main with: fetch-depth: 1 submodules: false @@ -54,10 +54,10 @@ jobs: SUPPORT_ABI: '${{ matrix.support_abi }}' steps: - name: Checkout PyTorch - uses: pytorch/pytorch/.github/actions/checkout-pytorch@release/2.4 + uses: pytorch/pytorch/.github/actions/checkout-pytorch@main - name: Setup miniconda - uses: pytorch/test-infra/.github/actions/setup-miniconda@release/2.4 + uses: pytorch/test-infra/.github/actions/setup-miniconda@main with: python-version: 3.8 environment-file: .github/requirements/conda-env-${{ runner.os }}-${{ runner.arch }}.txt diff --git a/.github/workflows/_runner-determinator.yml b/.github/workflows/_runner-determinator.yml index 3d6230aa4632b4..c86f8b840145a9 100644 --- a/.github/workflows/_runner-determinator.yml +++ b/.github/workflows/_runner-determinator.yml @@ -35,7 +35,7 @@ jobs: USERNAME: ${{ inputs.user_name }} steps: - name: Checkout PyTorch - uses: pytorch/pytorch/.github/actions/checkout-pytorch@release/2.4 + uses: pytorch/pytorch/.github/actions/checkout-pytorch@main with: fetch-depth: 1 submodules: true diff --git a/.github/workflows/_win-build.yml b/.github/workflows/_win-build.yml index e32e1bfb5bb86b..bc381c50628d10 100644 --- a/.github/workflows/_win-build.yml +++ b/.github/workflows/_win-build.yml @@ -60,10 +60,10 @@ jobs: git config --global core.fsmonitor false - name: Clean up leftover processes on non-ephemeral Windows runner - uses: pytorch/test-infra/.github/actions/cleanup-runner@release/2.4 + uses: pytorch/test-infra/.github/actions/cleanup-runner@main - name: Setup SSH (Click me for login details) - uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.4 + uses: pytorch/test-infra/.github/actions/setup-ssh@main with: github-secret: ${{ secrets.GITHUB_TOKEN }} instructions: | @@ -78,7 +78,7 @@ jobs: # [see note: pytorch repo ref] - name: Checkout PyTorch - uses: pytorch/pytorch/.github/actions/checkout-pytorch@release/2.4 + uses: pytorch/pytorch/.github/actions/checkout-pytorch@main with: no-sudo: true diff --git a/.github/workflows/_win-test.yml b/.github/workflows/_win-test.yml index 505c2119fb61ac..99d037f0355ce6 100644 --- a/.github/workflows/_win-test.yml +++ b/.github/workflows/_win-test.yml @@ -54,10 +54,10 @@ jobs: git config --global core.fsmonitor false - name: Clean up leftover processes on non-ephemeral Windows runner - uses: pytorch/test-infra/.github/actions/cleanup-runner@release/2.4 + uses: pytorch/test-infra/.github/actions/cleanup-runner@main - name: Setup SSH (Click me for login details) - uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.4 + uses: pytorch/test-infra/.github/actions/setup-ssh@main with: github-secret: ${{ secrets.GITHUB_TOKEN }} instructions: | @@ -73,7 +73,7 @@ jobs: # [see note: pytorch repo ref] - name: Checkout PyTorch - uses: pytorch/pytorch/.github/actions/checkout-pytorch@release/2.4 + uses: pytorch/pytorch/.github/actions/checkout-pytorch@main with: no-sudo: true diff --git a/.github/workflows/_xpu-test.yml b/.github/workflows/_xpu-test.yml index bd57632c506245..d7af711f8adb4b 100644 --- a/.github/workflows/_xpu-test.yml +++ b/.github/workflows/_xpu-test.yml @@ -54,7 +54,7 @@ jobs: steps: # [see note: pytorch repo ref] - name: Checkout PyTorch - uses: pytorch/pytorch/.github/actions/checkout-pytorch@release/2.4 + uses: pytorch/pytorch/.github/actions/checkout-pytorch@main - name: Setup XPU uses: ./.github/actions/setup-xpu @@ -72,12 +72,12 @@ jobs: - name: Calculate docker image id: calculate-docker-image - uses: pytorch/test-infra/.github/actions/calculate-docker-image@release/2.4 + uses: pytorch/test-infra/.github/actions/calculate-docker-image@main with: docker-image-name: ${{ inputs.docker-image }} - name: Pull docker image - uses: pytorch/test-infra/.github/actions/pull-docker-image@release/2.4 + uses: pytorch/test-infra/.github/actions/pull-docker-image@main with: docker-image: ${{ steps.calculate-docker-image.outputs.docker-image }} diff --git a/.github/workflows/build-triton-wheel.yml b/.github/workflows/build-triton-wheel.yml index f2abec75028f8d..ddba8ff8907cc6 100644 --- a/.github/workflows/build-triton-wheel.yml +++ b/.github/workflows/build-triton-wheel.yml @@ -47,12 +47,12 @@ jobs: BUILD_DEVICE: ${{ matrix.device }} steps: - name: Setup SSH (Click me for login details) - uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.4 + uses: pytorch/test-infra/.github/actions/setup-ssh@main with: github-secret: ${{ secrets.GITHUB_TOKEN }} - name: Checkout PyTorch - uses: pytorch/pytorch/.github/actions/checkout-pytorch@release/2.4 + uses: pytorch/pytorch/.github/actions/checkout-pytorch@main with: submodules: false @@ -60,11 +60,13 @@ jobs: uses: ./.github/actions/setup-linux - name: Pull Docker image - uses: pytorch/test-infra/.github/actions/pull-docker-image@release/2.4 + uses: pytorch/test-infra/.github/actions/pull-docker-image@main with: docker-image: ${{ env.DOCKER_IMAGE }} - name: Build Triton wheel + env: + IS_RELEASE_TAG: ${{ startsWith(github.event.ref, 'refs/tags/v') }} run: | set -x mkdir -p "${RUNNER_TEMP}/artifacts/" @@ -105,9 +107,14 @@ jobs: BUILD_ROCM="--build-rocm" fi + RELEASE="" + if [[ "${IS_RELEASE_TAG}" == true ]]; then + RELEASE="--release" + fi + docker exec -t "${container_name}" yum install -y zlib-devel zip docker exec -t "${container_name}" "${PYTHON_EXECUTABLE}" -m pip install -U setuptools==67.4.0 - docker exec -t "${container_name}" "${PYTHON_EXECUTABLE}" /pytorch/.github/scripts/build_triton_wheel.py $BUILD_ROCM + docker exec -t "${container_name}" "${PYTHON_EXECUTABLE}" /pytorch/.github/scripts/build_triton_wheel.py $BUILD_ROCM $RELEASE docker exec -t "${container_name}" chown -R 1000.1000 /artifacts - uses: actions/upload-artifact@v3 @@ -117,7 +124,7 @@ jobs: path: ${{ runner.temp }}/artifacts/* - name: Teardown Linux - uses: pytorch/test-infra/.github/actions/teardown-linux@release/2.4 + uses: pytorch/test-infra/.github/actions/teardown-linux@main if: always() upload-wheel: @@ -202,12 +209,12 @@ jobs: PY_VERS: ${{ matrix.py_vers }} steps: - name: Setup SSH (Click me for login details) - uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.4 + uses: pytorch/test-infra/.github/actions/setup-ssh@main with: github-secret: ${{ secrets.GITHUB_TOKEN }} - name: Checkout PyTorch - uses: pytorch/pytorch/.github/actions/checkout-pytorch@release/2.4 + uses: pytorch/pytorch/.github/actions/checkout-pytorch@main with: submodules: false @@ -215,11 +222,13 @@ jobs: uses: ./.github/actions/setup-linux - name: Pull Docker image - uses: pytorch/test-infra/.github/actions/pull-docker-image@release/2.4 + uses: pytorch/test-infra/.github/actions/pull-docker-image@main with: docker-image: ${{ env.DOCKER_IMAGE }} - name: Build Triton conda package + env: + IS_RELEASE_TAG: ${{ startsWith(github.event.ref, 'refs/tags/v') }} run: | set -x mkdir -p "${RUNNER_TEMP}/artifacts/" @@ -232,8 +241,13 @@ jobs: "${DOCKER_IMAGE}" \ ) + RELEASE="" + if [[ "${IS_RELEASE_TAG}" == true ]]; then + RELEASE="--release" + fi + docker exec -t "${container_name}" yum install -y llvm11 llvm11-devel llvm11-static llvm11-libs zlib-devel - docker exec -t "${container_name}" python /pytorch/.github/scripts/build_triton_wheel.py --build-conda --py-version="${PY_VERS}" + docker exec -t "${container_name}" python /pytorch/.github/scripts/build_triton_wheel.py --build-conda --py-version="${PY_VERS}" $RELEASE docker exec -t "${container_name}" chown -R 1000.1000 /artifacts - uses: actions/upload-artifact@v3 @@ -243,7 +257,7 @@ jobs: path: ${{ runner.temp }}/artifacts/* - name: Teardown Linux - uses: pytorch/test-infra/.github/actions/teardown-linux@release/2.4 + uses: pytorch/test-infra/.github/actions/teardown-linux@main if: always() upload-conda: diff --git a/.github/workflows/check-labels.yml b/.github/workflows/check-labels.yml index d0dc4225320102..d638d588504f2e 100644 --- a/.github/workflows/check-labels.yml +++ b/.github/workflows/check-labels.yml @@ -31,7 +31,7 @@ jobs: runs-on: linux.20_04.4x steps: - name: Checkout PyTorch - uses: pytorch/pytorch/.github/actions/checkout-pytorch@release/2.4 + uses: pytorch/pytorch/.github/actions/checkout-pytorch@main with: submodules: false fetch-depth: 1 diff --git a/.github/workflows/close-nonexistent-disable-issues.yml b/.github/workflows/close-nonexistent-disable-issues.yml index 22d1c29293aff0..12a6facbaabc5c 100644 --- a/.github/workflows/close-nonexistent-disable-issues.yml +++ b/.github/workflows/close-nonexistent-disable-issues.yml @@ -11,7 +11,7 @@ jobs: runs-on: ubuntu-latest steps: - name: Checkout PyTorch - uses: pytorch/pytorch/.github/actions/checkout-pytorch@release/2.4 + uses: pytorch/pytorch/.github/actions/checkout-pytorch@main - name: Run close_nonexistent_disable_issues.py env: diff --git a/.github/workflows/docker-builds.yml b/.github/workflows/docker-builds.yml index 6d46db541294f3..1b3d613829f761 100644 --- a/.github/workflows/docker-builds.yml +++ b/.github/workflows/docker-builds.yml @@ -79,21 +79,21 @@ jobs: # [see note: pytorch repo ref] # deep clone (fetch-depth 0) required for git merge-base - name: Checkout PyTorch - uses: pytorch/pytorch/.github/actions/checkout-pytorch@release/2.4 + uses: pytorch/pytorch/.github/actions/checkout-pytorch@main - name: Setup Linux uses: ./.github/actions/setup-linux - name: Build docker image id: build-docker-image - uses: pytorch/test-infra/.github/actions/calculate-docker-image@release/2.4 + uses: pytorch/test-infra/.github/actions/calculate-docker-image@main with: docker-image-name: ${{ matrix.docker-image-name }} always-rebuild: true push: true - name: Pull docker image - uses: pytorch/test-infra/.github/actions/pull-docker-image@release/2.4 + uses: pytorch/test-infra/.github/actions/pull-docker-image@main with: docker-image: ${{ steps.build-docker-image.outputs.docker-image }} @@ -125,5 +125,5 @@ jobs: if: always() - name: Teardown Linux - uses: pytorch/test-infra/.github/actions/teardown-linux@release/2.4 + uses: pytorch/test-infra/.github/actions/teardown-linux@main if: always() diff --git a/.github/workflows/docker-release.yml b/.github/workflows/docker-release.yml index 01ab7b2b2c434f..351497bee7537c 100644 --- a/.github/workflows/docker-release.yml +++ b/.github/workflows/docker-release.yml @@ -41,7 +41,7 @@ jobs: matrix: ${{ steps.generate-matrix.outputs.matrix }} steps: - name: Checkout PyTorch - uses: pytorch/pytorch/.github/actions/checkout-pytorch@release/2.4 + uses: pytorch/pytorch/.github/actions/checkout-pytorch@main with: fetch-depth: 1 submodules: true @@ -69,7 +69,7 @@ jobs: CUDNN_VERSION: ${{ matrix.cudnn_version }} steps: - name: Setup SSH (Click me for login details) - uses: pytorch/test-infra/.github/actions/setup-ssh@release/2.4 + uses: pytorch/test-infra/.github/actions/setup-ssh@main with: github-secret: ${{ secrets.GITHUB_TOKEN }} # [see note: pytorch repo ref] @@ -147,12 +147,12 @@ jobs: fi - name: Teardown Linux - uses: pytorch/test-infra/.github/actions/teardown-linux@release/2.4 + uses: pytorch/test-infra/.github/actions/teardown-linux@main if: always() validate: needs: build - uses: pytorch/builder/.github/workflows/validate-docker-images.yml@release/2.4 + uses: pytorch/builder/.github/workflows/validate-docker-images.yml@main with: channel: nightly ref: main diff --git a/.github/workflows/generated-linux-aarch64-binary-manywheel-nightly.yml b/.github/workflows/generated-linux-aarch64-binary-manywheel-nightly.yml index b2db998ead6046..a1a7e6fd9537e6 100644 --- a/.github/workflows/generated-linux-aarch64-binary-manywheel-nightly.yml +++ b/.github/workflows/generated-linux-aarch64-binary-manywheel-nightly.yml @@ -48,7 +48,7 @@ jobs: # favor of GPU_ARCH_VERSION DESIRED_CUDA: cpu GPU_ARCH_TYPE: cpu-aarch64 - DOCKER_IMAGE: pytorch/manylinuxaarch64-builder:cpu-aarch64-2.4 + DOCKER_IMAGE: pytorch/manylinuxaarch64-builder:cpu-aarch64-main DESIRED_PYTHON: "3.8" runs_on: linux.arm64.m7g.4xlarge ALPINE_IMAGE: "arm64v8/alpine" @@ -69,7 +69,7 @@ jobs: # favor of GPU_ARCH_VERSION DESIRED_CUDA: cpu GPU_ARCH_TYPE: cpu-aarch64 - DOCKER_IMAGE: pytorch/manylinuxaarch64-builder:cpu-aarch64-2.4 + DOCKER_IMAGE: pytorch/manylinuxaarch64-builder:cpu-aarch64-main DESIRED_PYTHON: "3.8" build_name: manywheel-py3_8-cpu-aarch64 build_environment: linux-aarch64-binary-manywheel @@ -91,7 +91,7 @@ jobs: # favor of GPU_ARCH_VERSION DESIRED_CUDA: cpu GPU_ARCH_TYPE: cpu-aarch64 - DOCKER_IMAGE: pytorch/manylinuxaarch64-builder:cpu-aarch64-2.4 + DOCKER_IMAGE: pytorch/manylinuxaarch64-builder:cpu-aarch64-main DESIRED_PYTHON: "3.8" build_name: manywheel-py3_8-cpu-aarch64 secrets: @@ -111,7 +111,7 @@ jobs: # favor of GPU_ARCH_VERSION DESIRED_CUDA: cu124 GPU_ARCH_TYPE: cuda-aarch64 - DOCKER_IMAGE: pytorch/manylinuxaarch64-builder:cuda12.4-2.4 + DOCKER_IMAGE: pytorch/manylinuxaarch64-builder:cuda12.4-main DESIRED_DEVTOOLSET: cxx11-abi DESIRED_PYTHON: "3.8" runs_on: linux.arm64.m7g.4xlarge @@ -135,7 +135,7 @@ jobs: # favor of GPU_ARCH_VERSION DESIRED_CUDA: cu124 GPU_ARCH_TYPE: cuda-aarch64 - DOCKER_IMAGE: pytorch/manylinuxaarch64-builder:cuda12.4-2.4 + DOCKER_IMAGE: pytorch/manylinuxaarch64-builder:cuda12.4-main DESIRED_DEVTOOLSET: cxx11-abi DESIRED_PYTHON: "3.8" build_name: manywheel-py3_8-cuda-aarch64 @@ -156,7 +156,7 @@ jobs: # favor of GPU_ARCH_VERSION DESIRED_CUDA: cpu GPU_ARCH_TYPE: cpu-aarch64 - DOCKER_IMAGE: pytorch/manylinuxaarch64-builder:cpu-aarch64-2.4 + DOCKER_IMAGE: pytorch/manylinuxaarch64-builder:cpu-aarch64-main DESIRED_PYTHON: "3.9" runs_on: linux.arm64.m7g.4xlarge ALPINE_IMAGE: "arm64v8/alpine" @@ -177,7 +177,7 @@ jobs: # favor of GPU_ARCH_VERSION DESIRED_CUDA: cpu GPU_ARCH_TYPE: cpu-aarch64 - DOCKER_IMAGE: pytorch/manylinuxaarch64-builder:cpu-aarch64-2.4 + DOCKER_IMAGE: pytorch/manylinuxaarch64-builder:cpu-aarch64-main DESIRED_PYTHON: "3.9" build_name: manywheel-py3_9-cpu-aarch64 build_environment: linux-aarch64-binary-manywheel @@ -199,7 +199,7 @@ jobs: # favor of GPU_ARCH_VERSION DESIRED_CUDA: cpu GPU_ARCH_TYPE: cpu-aarch64 - DOCKER_IMAGE: pytorch/manylinuxaarch64-builder:cpu-aarch64-2.4 + DOCKER_IMAGE: pytorch/manylinuxaarch64-builder:cpu-aarch64-main DESIRED_PYTHON: "3.9" build_name: manywheel-py3_9-cpu-aarch64 secrets: @@ -219,7 +219,7 @@ jobs: # favor of GPU_ARCH_VERSION DESIRED_CUDA: cu124 GPU_ARCH_TYPE: cuda-aarch64 - DOCKER_IMAGE: pytorch/manylinuxaarch64-builder:cuda12.4-2.4 + DOCKER_IMAGE: pytorch/manylinuxaarch64-builder:cuda12.4-main DESIRED_DEVTOOLSET: cxx11-abi DESIRED_PYTHON: "3.9" runs_on: linux.arm64.m7g.4xlarge @@ -243,7 +243,7 @@ jobs: # favor of GPU_ARCH_VERSION DESIRED_CUDA: cu124 GPU_ARCH_TYPE: cuda-aarch64 - DOCKER_IMAGE: pytorch/manylinuxaarch64-builder:cuda12.4-2.4 + DOCKER_IMAGE: pytorch/manylinuxaarch64-builder:cuda12.4-main DESIRED_DEVTOOLSET: cxx11-abi DESIRED_PYTHON: "3.9" build_name: manywheel-py3_9-cuda-aarch64 @@ -264,7 +264,7 @@ jobs: # favor of GPU_ARCH_VERSION DESIRED_CUDA: cpu GPU_ARCH_TYPE: cpu-aarch64 - DOCKER_IMAGE: pytorch/manylinuxaarch64-builder:cpu-aarch64-2.4 + DOCKER_IMAGE: pytorch/manylinuxaarch64-builder:cpu-aarch64-main DESIRED_PYTHON: "3.10" runs_on: linux.arm64.m7g.4xlarge ALPINE_IMAGE: "arm64v8/alpine" @@ -285,7 +285,7 @@ jobs: # favor of GPU_ARCH_VERSION DESIRED_CUDA: cpu GPU_ARCH_TYPE: cpu-aarch64 - DOCKER_IMAGE: pytorch/manylinuxaarch64-builder:cpu-aarch64-2.4 + DOCKER_IMAGE: pytorch/manylinuxaarch64-builder:cpu-aarch64-main DESIRED_PYTHON: "3.10" build_name: manywheel-py3_10-cpu-aarch64 build_environment: linux-aarch64-binary-manywheel @@ -307,7 +307,7 @@ jobs: # favor of GPU_ARCH_VERSION DESIRED_CUDA: cpu GPU_ARCH_TYPE: cpu-aarch64 - DOCKER_IMAGE: pytorch/manylinuxaarch64-builder:cpu-aarch64-2.4 + DOCKER_IMAGE: pytorch/manylinuxaarch64-builder:cpu-aarch64-main DESIRED_PYTHON: "3.10" build_name: manywheel-py3_10-cpu-aarch64 secrets: @@ -327,7 +327,7 @@ jobs: # favor of GPU_ARCH_VERSION DESIRED_CUDA: cu124 GPU_ARCH_TYPE: cuda-aarch64 - DOCKER_IMAGE: pytorch/manylinuxaarch64-builder:cuda12.4-2.4 + DOCKER_IMAGE: pytorch/manylinuxaarch64-builder:cuda12.4-main DESIRED_DEVTOOLSET: cxx11-abi DESIRED_PYTHON: "3.10" runs_on: linux.arm64.m7g.4xlarge @@ -351,7 +351,7 @@ jobs: # favor of GPU_ARCH_VERSION DESIRED_CUDA: cu124 GPU_ARCH_TYPE: cuda-aarch64 - DOCKER_IMAGE: pytorch/manylinuxaarch64-builder:cuda12.4-2.4 + DOCKER_IMAGE: pytorch/manylinuxaarch64-builder:cuda12.4-main DESIRED_DEVTOOLSET: cxx11-abi DESIRED_PYTHON: "3.10" build_name: manywheel-py3_10-cuda-aarch64 @@ -372,7 +372,7 @@ jobs: # favor of GPU_ARCH_VERSION DESIRED_CUDA: cpu GPU_ARCH_TYPE: cpu-aarch64 - DOCKER_IMAGE: pytorch/manylinuxaarch64-builder:cpu-aarch64-2.4 + DOCKER_IMAGE: pytorch/manylinuxaarch64-builder:cpu-aarch64-main DESIRED_PYTHON: "3.11" runs_on: linux.arm64.m7g.4xlarge ALPINE_IMAGE: "arm64v8/alpine" @@ -393,7 +393,7 @@ jobs: # favor of GPU_ARCH_VERSION DESIRED_CUDA: cpu GPU_ARCH_TYPE: cpu-aarch64 - DOCKER_IMAGE: pytorch/manylinuxaarch64-builder:cpu-aarch64-2.4 + DOCKER_IMAGE: pytorch/manylinuxaarch64-builder:cpu-aarch64-main DESIRED_PYTHON: "3.11" build_name: manywheel-py3_11-cpu-aarch64 build_environment: linux-aarch64-binary-manywheel @@ -415,7 +415,7 @@ jobs: # favor of GPU_ARCH_VERSION DESIRED_CUDA: cpu GPU_ARCH_TYPE: cpu-aarch64 - DOCKER_IMAGE: pytorch/manylinuxaarch64-builder:cpu-aarch64-2.4 + DOCKER_IMAGE: pytorch/manylinuxaarch64-builder:cpu-aarch64-main DESIRED_PYTHON: "3.11" build_name: manywheel-py3_11-cpu-aarch64 secrets: @@ -435,7 +435,7 @@ jobs: # favor of GPU_ARCH_VERSION DESIRED_CUDA: cu124 GPU_ARCH_TYPE: cuda-aarch64 - DOCKER_IMAGE: pytorch/manylinuxaarch64-builder:cuda12.4-2.4 + DOCKER_IMAGE: pytorch/manylinuxaarch64-builder:cuda12.4-main DESIRED_DEVTOOLSET: cxx11-abi DESIRED_PYTHON: "3.11" runs_on: linux.arm64.m7g.4xlarge @@ -459,7 +459,7 @@ jobs: # favor of GPU_ARCH_VERSION DESIRED_CUDA: cu124 GPU_ARCH_TYPE: cuda-aarch64 - DOCKER_IMAGE: pytorch/manylinuxaarch64-builder:cuda12.4-2.4 + DOCKER_IMAGE: pytorch/manylinuxaarch64-builder:cuda12.4-main DESIRED_DEVTOOLSET: cxx11-abi DESIRED_PYTHON: "3.11" build_name: manywheel-py3_11-cuda-aarch64 @@ -480,7 +480,7 @@ jobs: # favor of GPU_ARCH_VERSION DESIRED_CUDA: cpu GPU_ARCH_TYPE: cpu-aarch64 - DOCKER_IMAGE: pytorch/manylinuxaarch64-builder:cpu-aarch64-2.4 + DOCKER_IMAGE: pytorch/manylinuxaarch64-builder:cpu-aarch64-main DESIRED_PYTHON: "3.12" runs_on: linux.arm64.m7g.4xlarge ALPINE_IMAGE: "arm64v8/alpine" @@ -501,7 +501,7 @@ jobs: # favor of GPU_ARCH_VERSION DESIRED_CUDA: cpu GPU_ARCH_TYPE: cpu-aarch64 - DOCKER_IMAGE: pytorch/manylinuxaarch64-builder:cpu-aarch64-2.4 + DOCKER_IMAGE: pytorch/manylinuxaarch64-builder:cpu-aarch64-main DESIRED_PYTHON: "3.12" build_name: manywheel-py3_12-cpu-aarch64 build_environment: linux-aarch64-binary-manywheel @@ -523,7 +523,7 @@ jobs: # favor of GPU_ARCH_VERSION DESIRED_CUDA: cpu GPU_ARCH_TYPE: cpu-aarch64 - DOCKER_IMAGE: pytorch/manylinuxaarch64-builder:cpu-aarch64-2.4 + DOCKER_IMAGE: pytorch/manylinuxaarch64-builder:cpu-aarch64-main DESIRED_PYTHON: "3.12" build_name: manywheel-py3_12-cpu-aarch64 secrets: @@ -543,7 +543,7 @@ jobs: # favor of GPU_ARCH_VERSION DESIRED_CUDA: cu124 GPU_ARCH_TYPE: cuda-aarch64 - DOCKER_IMAGE: pytorch/manylinuxaarch64-builder:cuda12.4-2.4 + DOCKER_IMAGE: pytorch/manylinuxaarch64-builder:cuda12.4-main DESIRED_DEVTOOLSET: cxx11-abi DESIRED_PYTHON: "3.12" runs_on: linux.arm64.m7g.4xlarge @@ -567,7 +567,7 @@ jobs: # favor of GPU_ARCH_VERSION DESIRED_CUDA: cu124 GPU_ARCH_TYPE: cuda-aarch64 - DOCKER_IMAGE: pytorch/manylinuxaarch64-builder:cuda12.4-2.4 + DOCKER_IMAGE: pytorch/manylinuxaarch64-builder:cuda12.4-main DESIRED_DEVTOOLSET: cxx11-abi DESIRED_PYTHON: "3.12" build_name: manywheel-py3_12-cuda-aarch64 diff --git a/.github/workflows/generated-linux-binary-conda-nightly.yml b/.github/workflows/generated-linux-binary-conda-nightly.yml index 2291bfb47855d2..50a6d986255f74 100644 --- a/.github/workflows/generated-linux-binary-conda-nightly.yml +++ b/.github/workflows/generated-linux-binary-conda-nightly.yml @@ -48,7 +48,7 @@ jobs: # favor of GPU_ARCH_VERSION DESIRED_CUDA: cpu GPU_ARCH_TYPE: cpu - DOCKER_IMAGE: pytorch/conda-builder:cpu-2.4 + DOCKER_IMAGE: pytorch/conda-builder:cpu-main DESIRED_PYTHON: "3.8" build_name: conda-py3_8-cpu build_environment: linux-binary-conda @@ -66,7 +66,7 @@ jobs: # favor of GPU_ARCH_VERSION DESIRED_CUDA: cpu GPU_ARCH_TYPE: cpu - DOCKER_IMAGE: pytorch/conda-builder:cpu-2.4 + DOCKER_IMAGE: pytorch/conda-builder:cpu-main DESIRED_PYTHON: "3.8" build_name: conda-py3_8-cpu build_environment: linux-binary-conda @@ -87,7 +87,7 @@ jobs: # favor of GPU_ARCH_VERSION DESIRED_CUDA: cpu GPU_ARCH_TYPE: cpu - DOCKER_IMAGE: pytorch/conda-builder:cpu-2.4 + DOCKER_IMAGE: pytorch/conda-builder:cpu-main DESIRED_PYTHON: "3.8" build_name: conda-py3_8-cpu secrets: @@ -108,7 +108,7 @@ jobs: DESIRED_CUDA: cu118 GPU_ARCH_VERSION: 11.8 GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/conda-builder:cuda11.8-2.4 + DOCKER_IMAGE: pytorch/conda-builder:cuda11.8-main DESIRED_PYTHON: "3.8" runs_on: linux.24xlarge build_name: conda-py3_8-cuda11_8 @@ -128,7 +128,7 @@ jobs: DESIRED_CUDA: cu118 GPU_ARCH_VERSION: 11.8 GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/conda-builder:cuda11.8-2.4 + DOCKER_IMAGE: pytorch/conda-builder:cuda11.8-main DESIRED_PYTHON: "3.8" build_name: conda-py3_8-cuda11_8 build_environment: linux-binary-conda @@ -150,7 +150,7 @@ jobs: DESIRED_CUDA: cu118 GPU_ARCH_VERSION: 11.8 GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/conda-builder:cuda11.8-2.4 + DOCKER_IMAGE: pytorch/conda-builder:cuda11.8-main DESIRED_PYTHON: "3.8" build_name: conda-py3_8-cuda11_8 secrets: @@ -171,7 +171,7 @@ jobs: DESIRED_CUDA: cu121 GPU_ARCH_VERSION: 12.1 GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/conda-builder:cuda12.1-2.4 + DOCKER_IMAGE: pytorch/conda-builder:cuda12.1-main DESIRED_PYTHON: "3.8" runs_on: linux.24xlarge build_name: conda-py3_8-cuda12_1 @@ -191,7 +191,7 @@ jobs: DESIRED_CUDA: cu121 GPU_ARCH_VERSION: 12.1 GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/conda-builder:cuda12.1-2.4 + DOCKER_IMAGE: pytorch/conda-builder:cuda12.1-main DESIRED_PYTHON: "3.8" build_name: conda-py3_8-cuda12_1 build_environment: linux-binary-conda @@ -213,7 +213,7 @@ jobs: DESIRED_CUDA: cu121 GPU_ARCH_VERSION: 12.1 GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/conda-builder:cuda12.1-2.4 + DOCKER_IMAGE: pytorch/conda-builder:cuda12.1-main DESIRED_PYTHON: "3.8" build_name: conda-py3_8-cuda12_1 secrets: @@ -234,7 +234,7 @@ jobs: DESIRED_CUDA: cu124 GPU_ARCH_VERSION: 12.4 GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/conda-builder:cuda12.4-2.4 + DOCKER_IMAGE: pytorch/conda-builder:cuda12.4-main DESIRED_PYTHON: "3.8" runs_on: linux.24xlarge build_name: conda-py3_8-cuda12_4 @@ -254,7 +254,7 @@ jobs: DESIRED_CUDA: cu124 GPU_ARCH_VERSION: 12.4 GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/conda-builder:cuda12.4-2.4 + DOCKER_IMAGE: pytorch/conda-builder:cuda12.4-main DESIRED_PYTHON: "3.8" build_name: conda-py3_8-cuda12_4 build_environment: linux-binary-conda @@ -276,7 +276,7 @@ jobs: DESIRED_CUDA: cu124 GPU_ARCH_VERSION: 12.4 GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/conda-builder:cuda12.4-2.4 + DOCKER_IMAGE: pytorch/conda-builder:cuda12.4-main DESIRED_PYTHON: "3.8" build_name: conda-py3_8-cuda12_4 secrets: @@ -296,7 +296,7 @@ jobs: # favor of GPU_ARCH_VERSION DESIRED_CUDA: cpu GPU_ARCH_TYPE: cpu - DOCKER_IMAGE: pytorch/conda-builder:cpu-2.4 + DOCKER_IMAGE: pytorch/conda-builder:cpu-main DESIRED_PYTHON: "3.9" build_name: conda-py3_9-cpu build_environment: linux-binary-conda @@ -314,7 +314,7 @@ jobs: # favor of GPU_ARCH_VERSION DESIRED_CUDA: cpu GPU_ARCH_TYPE: cpu - DOCKER_IMAGE: pytorch/conda-builder:cpu-2.4 + DOCKER_IMAGE: pytorch/conda-builder:cpu-main DESIRED_PYTHON: "3.9" build_name: conda-py3_9-cpu build_environment: linux-binary-conda @@ -335,7 +335,7 @@ jobs: # favor of GPU_ARCH_VERSION DESIRED_CUDA: cpu GPU_ARCH_TYPE: cpu - DOCKER_IMAGE: pytorch/conda-builder:cpu-2.4 + DOCKER_IMAGE: pytorch/conda-builder:cpu-main DESIRED_PYTHON: "3.9" build_name: conda-py3_9-cpu secrets: @@ -356,7 +356,7 @@ jobs: DESIRED_CUDA: cu118 GPU_ARCH_VERSION: 11.8 GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/conda-builder:cuda11.8-2.4 + DOCKER_IMAGE: pytorch/conda-builder:cuda11.8-main DESIRED_PYTHON: "3.9" runs_on: linux.24xlarge build_name: conda-py3_9-cuda11_8 @@ -376,7 +376,7 @@ jobs: DESIRED_CUDA: cu118 GPU_ARCH_VERSION: 11.8 GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/conda-builder:cuda11.8-2.4 + DOCKER_IMAGE: pytorch/conda-builder:cuda11.8-main DESIRED_PYTHON: "3.9" build_name: conda-py3_9-cuda11_8 build_environment: linux-binary-conda @@ -398,7 +398,7 @@ jobs: DESIRED_CUDA: cu118 GPU_ARCH_VERSION: 11.8 GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/conda-builder:cuda11.8-2.4 + DOCKER_IMAGE: pytorch/conda-builder:cuda11.8-main DESIRED_PYTHON: "3.9" build_name: conda-py3_9-cuda11_8 secrets: @@ -419,7 +419,7 @@ jobs: DESIRED_CUDA: cu121 GPU_ARCH_VERSION: 12.1 GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/conda-builder:cuda12.1-2.4 + DOCKER_IMAGE: pytorch/conda-builder:cuda12.1-main DESIRED_PYTHON: "3.9" runs_on: linux.24xlarge build_name: conda-py3_9-cuda12_1 @@ -439,7 +439,7 @@ jobs: DESIRED_CUDA: cu121 GPU_ARCH_VERSION: 12.1 GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/conda-builder:cuda12.1-2.4 + DOCKER_IMAGE: pytorch/conda-builder:cuda12.1-main DESIRED_PYTHON: "3.9" build_name: conda-py3_9-cuda12_1 build_environment: linux-binary-conda @@ -461,7 +461,7 @@ jobs: DESIRED_CUDA: cu121 GPU_ARCH_VERSION: 12.1 GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/conda-builder:cuda12.1-2.4 + DOCKER_IMAGE: pytorch/conda-builder:cuda12.1-main DESIRED_PYTHON: "3.9" build_name: conda-py3_9-cuda12_1 secrets: @@ -482,7 +482,7 @@ jobs: DESIRED_CUDA: cu124 GPU_ARCH_VERSION: 12.4 GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/conda-builder:cuda12.4-2.4 + DOCKER_IMAGE: pytorch/conda-builder:cuda12.4-main DESIRED_PYTHON: "3.9" runs_on: linux.24xlarge build_name: conda-py3_9-cuda12_4 @@ -502,7 +502,7 @@ jobs: DESIRED_CUDA: cu124 GPU_ARCH_VERSION: 12.4 GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/conda-builder:cuda12.4-2.4 + DOCKER_IMAGE: pytorch/conda-builder:cuda12.4-main DESIRED_PYTHON: "3.9" build_name: conda-py3_9-cuda12_4 build_environment: linux-binary-conda @@ -524,7 +524,7 @@ jobs: DESIRED_CUDA: cu124 GPU_ARCH_VERSION: 12.4 GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/conda-builder:cuda12.4-2.4 + DOCKER_IMAGE: pytorch/conda-builder:cuda12.4-main DESIRED_PYTHON: "3.9" build_name: conda-py3_9-cuda12_4 secrets: @@ -544,7 +544,7 @@ jobs: # favor of GPU_ARCH_VERSION DESIRED_CUDA: cpu GPU_ARCH_TYPE: cpu - DOCKER_IMAGE: pytorch/conda-builder:cpu-2.4 + DOCKER_IMAGE: pytorch/conda-builder:cpu-main DESIRED_PYTHON: "3.10" build_name: conda-py3_10-cpu build_environment: linux-binary-conda @@ -562,7 +562,7 @@ jobs: # favor of GPU_ARCH_VERSION DESIRED_CUDA: cpu GPU_ARCH_TYPE: cpu - DOCKER_IMAGE: pytorch/conda-builder:cpu-2.4 + DOCKER_IMAGE: pytorch/conda-builder:cpu-main DESIRED_PYTHON: "3.10" build_name: conda-py3_10-cpu build_environment: linux-binary-conda @@ -583,7 +583,7 @@ jobs: # favor of GPU_ARCH_VERSION DESIRED_CUDA: cpu GPU_ARCH_TYPE: cpu - DOCKER_IMAGE: pytorch/conda-builder:cpu-2.4 + DOCKER_IMAGE: pytorch/conda-builder:cpu-main DESIRED_PYTHON: "3.10" build_name: conda-py3_10-cpu secrets: @@ -604,7 +604,7 @@ jobs: DESIRED_CUDA: cu118 GPU_ARCH_VERSION: 11.8 GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/conda-builder:cuda11.8-2.4 + DOCKER_IMAGE: pytorch/conda-builder:cuda11.8-main DESIRED_PYTHON: "3.10" runs_on: linux.24xlarge build_name: conda-py3_10-cuda11_8 @@ -624,7 +624,7 @@ jobs: DESIRED_CUDA: cu118 GPU_ARCH_VERSION: 11.8 GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/conda-builder:cuda11.8-2.4 + DOCKER_IMAGE: pytorch/conda-builder:cuda11.8-main DESIRED_PYTHON: "3.10" build_name: conda-py3_10-cuda11_8 build_environment: linux-binary-conda @@ -646,7 +646,7 @@ jobs: DESIRED_CUDA: cu118 GPU_ARCH_VERSION: 11.8 GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/conda-builder:cuda11.8-2.4 + DOCKER_IMAGE: pytorch/conda-builder:cuda11.8-main DESIRED_PYTHON: "3.10" build_name: conda-py3_10-cuda11_8 secrets: @@ -667,7 +667,7 @@ jobs: DESIRED_CUDA: cu121 GPU_ARCH_VERSION: 12.1 GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/conda-builder:cuda12.1-2.4 + DOCKER_IMAGE: pytorch/conda-builder:cuda12.1-main DESIRED_PYTHON: "3.10" runs_on: linux.24xlarge build_name: conda-py3_10-cuda12_1 @@ -687,7 +687,7 @@ jobs: DESIRED_CUDA: cu121 GPU_ARCH_VERSION: 12.1 GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/conda-builder:cuda12.1-2.4 + DOCKER_IMAGE: pytorch/conda-builder:cuda12.1-main DESIRED_PYTHON: "3.10" build_name: conda-py3_10-cuda12_1 build_environment: linux-binary-conda @@ -709,7 +709,7 @@ jobs: DESIRED_CUDA: cu121 GPU_ARCH_VERSION: 12.1 GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/conda-builder:cuda12.1-2.4 + DOCKER_IMAGE: pytorch/conda-builder:cuda12.1-main DESIRED_PYTHON: "3.10" build_name: conda-py3_10-cuda12_1 secrets: @@ -730,7 +730,7 @@ jobs: DESIRED_CUDA: cu124 GPU_ARCH_VERSION: 12.4 GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/conda-builder:cuda12.4-2.4 + DOCKER_IMAGE: pytorch/conda-builder:cuda12.4-main DESIRED_PYTHON: "3.10" runs_on: linux.24xlarge build_name: conda-py3_10-cuda12_4 @@ -750,7 +750,7 @@ jobs: DESIRED_CUDA: cu124 GPU_ARCH_VERSION: 12.4 GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/conda-builder:cuda12.4-2.4 + DOCKER_IMAGE: pytorch/conda-builder:cuda12.4-main DESIRED_PYTHON: "3.10" build_name: conda-py3_10-cuda12_4 build_environment: linux-binary-conda @@ -772,7 +772,7 @@ jobs: DESIRED_CUDA: cu124 GPU_ARCH_VERSION: 12.4 GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/conda-builder:cuda12.4-2.4 + DOCKER_IMAGE: pytorch/conda-builder:cuda12.4-main DESIRED_PYTHON: "3.10" build_name: conda-py3_10-cuda12_4 secrets: @@ -792,7 +792,7 @@ jobs: # favor of GPU_ARCH_VERSION DESIRED_CUDA: cpu GPU_ARCH_TYPE: cpu - DOCKER_IMAGE: pytorch/conda-builder:cpu-2.4 + DOCKER_IMAGE: pytorch/conda-builder:cpu-main DESIRED_PYTHON: "3.11" build_name: conda-py3_11-cpu build_environment: linux-binary-conda @@ -810,7 +810,7 @@ jobs: # favor of GPU_ARCH_VERSION DESIRED_CUDA: cpu GPU_ARCH_TYPE: cpu - DOCKER_IMAGE: pytorch/conda-builder:cpu-2.4 + DOCKER_IMAGE: pytorch/conda-builder:cpu-main DESIRED_PYTHON: "3.11" build_name: conda-py3_11-cpu build_environment: linux-binary-conda @@ -831,7 +831,7 @@ jobs: # favor of GPU_ARCH_VERSION DESIRED_CUDA: cpu GPU_ARCH_TYPE: cpu - DOCKER_IMAGE: pytorch/conda-builder:cpu-2.4 + DOCKER_IMAGE: pytorch/conda-builder:cpu-main DESIRED_PYTHON: "3.11" build_name: conda-py3_11-cpu secrets: @@ -852,7 +852,7 @@ jobs: DESIRED_CUDA: cu118 GPU_ARCH_VERSION: 11.8 GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/conda-builder:cuda11.8-2.4 + DOCKER_IMAGE: pytorch/conda-builder:cuda11.8-main DESIRED_PYTHON: "3.11" runs_on: linux.24xlarge build_name: conda-py3_11-cuda11_8 @@ -872,7 +872,7 @@ jobs: DESIRED_CUDA: cu118 GPU_ARCH_VERSION: 11.8 GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/conda-builder:cuda11.8-2.4 + DOCKER_IMAGE: pytorch/conda-builder:cuda11.8-main DESIRED_PYTHON: "3.11" build_name: conda-py3_11-cuda11_8 build_environment: linux-binary-conda @@ -894,7 +894,7 @@ jobs: DESIRED_CUDA: cu118 GPU_ARCH_VERSION: 11.8 GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/conda-builder:cuda11.8-2.4 + DOCKER_IMAGE: pytorch/conda-builder:cuda11.8-main DESIRED_PYTHON: "3.11" build_name: conda-py3_11-cuda11_8 secrets: @@ -915,7 +915,7 @@ jobs: DESIRED_CUDA: cu121 GPU_ARCH_VERSION: 12.1 GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/conda-builder:cuda12.1-2.4 + DOCKER_IMAGE: pytorch/conda-builder:cuda12.1-main DESIRED_PYTHON: "3.11" runs_on: linux.24xlarge build_name: conda-py3_11-cuda12_1 @@ -935,7 +935,7 @@ jobs: DESIRED_CUDA: cu121 GPU_ARCH_VERSION: 12.1 GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/conda-builder:cuda12.1-2.4 + DOCKER_IMAGE: pytorch/conda-builder:cuda12.1-main DESIRED_PYTHON: "3.11" build_name: conda-py3_11-cuda12_1 build_environment: linux-binary-conda @@ -957,7 +957,7 @@ jobs: DESIRED_CUDA: cu121 GPU_ARCH_VERSION: 12.1 GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/conda-builder:cuda12.1-2.4 + DOCKER_IMAGE: pytorch/conda-builder:cuda12.1-main DESIRED_PYTHON: "3.11" build_name: conda-py3_11-cuda12_1 secrets: @@ -978,7 +978,7 @@ jobs: DESIRED_CUDA: cu124 GPU_ARCH_VERSION: 12.4 GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/conda-builder:cuda12.4-2.4 + DOCKER_IMAGE: pytorch/conda-builder:cuda12.4-main DESIRED_PYTHON: "3.11" runs_on: linux.24xlarge build_name: conda-py3_11-cuda12_4 @@ -998,7 +998,7 @@ jobs: DESIRED_CUDA: cu124 GPU_ARCH_VERSION: 12.4 GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/conda-builder:cuda12.4-2.4 + DOCKER_IMAGE: pytorch/conda-builder:cuda12.4-main DESIRED_PYTHON: "3.11" build_name: conda-py3_11-cuda12_4 build_environment: linux-binary-conda @@ -1020,7 +1020,7 @@ jobs: DESIRED_CUDA: cu124 GPU_ARCH_VERSION: 12.4 GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/conda-builder:cuda12.4-2.4 + DOCKER_IMAGE: pytorch/conda-builder:cuda12.4-main DESIRED_PYTHON: "3.11" build_name: conda-py3_11-cuda12_4 secrets: @@ -1040,7 +1040,7 @@ jobs: # favor of GPU_ARCH_VERSION DESIRED_CUDA: cpu GPU_ARCH_TYPE: cpu - DOCKER_IMAGE: pytorch/conda-builder:cpu-2.4 + DOCKER_IMAGE: pytorch/conda-builder:cpu-main DESIRED_PYTHON: "3.12" build_name: conda-py3_12-cpu build_environment: linux-binary-conda @@ -1058,7 +1058,7 @@ jobs: # favor of GPU_ARCH_VERSION DESIRED_CUDA: cpu GPU_ARCH_TYPE: cpu - DOCKER_IMAGE: pytorch/conda-builder:cpu-2.4 + DOCKER_IMAGE: pytorch/conda-builder:cpu-main DESIRED_PYTHON: "3.12" build_name: conda-py3_12-cpu build_environment: linux-binary-conda @@ -1079,7 +1079,7 @@ jobs: # favor of GPU_ARCH_VERSION DESIRED_CUDA: cpu GPU_ARCH_TYPE: cpu - DOCKER_IMAGE: pytorch/conda-builder:cpu-2.4 + DOCKER_IMAGE: pytorch/conda-builder:cpu-main DESIRED_PYTHON: "3.12" build_name: conda-py3_12-cpu secrets: @@ -1100,7 +1100,7 @@ jobs: DESIRED_CUDA: cu118 GPU_ARCH_VERSION: 11.8 GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/conda-builder:cuda11.8-2.4 + DOCKER_IMAGE: pytorch/conda-builder:cuda11.8-main DESIRED_PYTHON: "3.12" runs_on: linux.24xlarge build_name: conda-py3_12-cuda11_8 @@ -1120,7 +1120,7 @@ jobs: DESIRED_CUDA: cu118 GPU_ARCH_VERSION: 11.8 GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/conda-builder:cuda11.8-2.4 + DOCKER_IMAGE: pytorch/conda-builder:cuda11.8-main DESIRED_PYTHON: "3.12" build_name: conda-py3_12-cuda11_8 build_environment: linux-binary-conda @@ -1142,7 +1142,7 @@ jobs: DESIRED_CUDA: cu118 GPU_ARCH_VERSION: 11.8 GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/conda-builder:cuda11.8-2.4 + DOCKER_IMAGE: pytorch/conda-builder:cuda11.8-main DESIRED_PYTHON: "3.12" build_name: conda-py3_12-cuda11_8 secrets: @@ -1163,7 +1163,7 @@ jobs: DESIRED_CUDA: cu121 GPU_ARCH_VERSION: 12.1 GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/conda-builder:cuda12.1-2.4 + DOCKER_IMAGE: pytorch/conda-builder:cuda12.1-main DESIRED_PYTHON: "3.12" runs_on: linux.24xlarge build_name: conda-py3_12-cuda12_1 @@ -1183,7 +1183,7 @@ jobs: DESIRED_CUDA: cu121 GPU_ARCH_VERSION: 12.1 GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/conda-builder:cuda12.1-2.4 + DOCKER_IMAGE: pytorch/conda-builder:cuda12.1-main DESIRED_PYTHON: "3.12" build_name: conda-py3_12-cuda12_1 build_environment: linux-binary-conda @@ -1205,7 +1205,7 @@ jobs: DESIRED_CUDA: cu121 GPU_ARCH_VERSION: 12.1 GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/conda-builder:cuda12.1-2.4 + DOCKER_IMAGE: pytorch/conda-builder:cuda12.1-main DESIRED_PYTHON: "3.12" build_name: conda-py3_12-cuda12_1 secrets: @@ -1226,7 +1226,7 @@ jobs: DESIRED_CUDA: cu124 GPU_ARCH_VERSION: 12.4 GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/conda-builder:cuda12.4-2.4 + DOCKER_IMAGE: pytorch/conda-builder:cuda12.4-main DESIRED_PYTHON: "3.12" runs_on: linux.24xlarge build_name: conda-py3_12-cuda12_4 @@ -1246,7 +1246,7 @@ jobs: DESIRED_CUDA: cu124 GPU_ARCH_VERSION: 12.4 GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/conda-builder:cuda12.4-2.4 + DOCKER_IMAGE: pytorch/conda-builder:cuda12.4-main DESIRED_PYTHON: "3.12" build_name: conda-py3_12-cuda12_4 build_environment: linux-binary-conda @@ -1268,7 +1268,7 @@ jobs: DESIRED_CUDA: cu124 GPU_ARCH_VERSION: 12.4 GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/conda-builder:cuda12.4-2.4 + DOCKER_IMAGE: pytorch/conda-builder:cuda12.4-main DESIRED_PYTHON: "3.12" build_name: conda-py3_12-cuda12_4 secrets: diff --git a/.github/workflows/generated-linux-binary-libtorch-cxx11-abi-main.yml b/.github/workflows/generated-linux-binary-libtorch-cxx11-abi-main.yml index c1c7ff017c59f4..5577a5e7d9c3a0 100644 --- a/.github/workflows/generated-linux-binary-libtorch-cxx11-abi-main.yml +++ b/.github/workflows/generated-linux-binary-libtorch-cxx11-abi-main.yml @@ -43,7 +43,7 @@ jobs: # favor of GPU_ARCH_VERSION DESIRED_CUDA: cpu GPU_ARCH_TYPE: cpu - DOCKER_IMAGE: pytorch/libtorch-cxx11-builder:cpu-2.4 + DOCKER_IMAGE: pytorch/libtorch-cxx11-builder:cpu-main LIBTORCH_VARIANT: shared-with-deps DESIRED_DEVTOOLSET: cxx11-abi build_name: libtorch-cpu-shared-with-deps-cxx11-abi @@ -62,7 +62,7 @@ jobs: # favor of GPU_ARCH_VERSION DESIRED_CUDA: cpu GPU_ARCH_TYPE: cpu - DOCKER_IMAGE: pytorch/libtorch-cxx11-builder:cpu-2.4 + DOCKER_IMAGE: pytorch/libtorch-cxx11-builder:cpu-main LIBTORCH_VARIANT: shared-with-deps DESIRED_DEVTOOLSET: cxx11-abi build_name: libtorch-cpu-shared-with-deps-cxx11-abi diff --git a/.github/workflows/generated-linux-binary-libtorch-cxx11-abi-nightly.yml b/.github/workflows/generated-linux-binary-libtorch-cxx11-abi-nightly.yml index 9ba8292f6edca1..d400e82249867b 100644 --- a/.github/workflows/generated-linux-binary-libtorch-cxx11-abi-nightly.yml +++ b/.github/workflows/generated-linux-binary-libtorch-cxx11-abi-nightly.yml @@ -48,7 +48,7 @@ jobs: # favor of GPU_ARCH_VERSION DESIRED_CUDA: cpu GPU_ARCH_TYPE: cpu - DOCKER_IMAGE: pytorch/libtorch-cxx11-builder:cpu-2.4 + DOCKER_IMAGE: pytorch/libtorch-cxx11-builder:cpu-main LIBTORCH_VARIANT: shared-with-deps DESIRED_DEVTOOLSET: cxx11-abi build_name: libtorch-cpu-shared-with-deps-cxx11-abi @@ -67,7 +67,7 @@ jobs: # favor of GPU_ARCH_VERSION DESIRED_CUDA: cpu GPU_ARCH_TYPE: cpu - DOCKER_IMAGE: pytorch/libtorch-cxx11-builder:cpu-2.4 + DOCKER_IMAGE: pytorch/libtorch-cxx11-builder:cpu-main LIBTORCH_VARIANT: shared-with-deps DESIRED_DEVTOOLSET: cxx11-abi build_name: libtorch-cpu-shared-with-deps-cxx11-abi @@ -89,7 +89,7 @@ jobs: # favor of GPU_ARCH_VERSION DESIRED_CUDA: cpu GPU_ARCH_TYPE: cpu - DOCKER_IMAGE: pytorch/libtorch-cxx11-builder:cpu-2.4 + DOCKER_IMAGE: pytorch/libtorch-cxx11-builder:cpu-main LIBTORCH_VARIANT: shared-with-deps DESIRED_DEVTOOLSET: cxx11-abi build_name: libtorch-cpu-shared-with-deps-cxx11-abi @@ -111,7 +111,7 @@ jobs: DESIRED_CUDA: cu118 GPU_ARCH_VERSION: 11.8 GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/libtorch-cxx11-builder:cuda11.8-2.4 + DOCKER_IMAGE: pytorch/libtorch-cxx11-builder:cuda11.8-main LIBTORCH_VARIANT: shared-with-deps DESIRED_DEVTOOLSET: cxx11-abi build_name: libtorch-cuda11_8-shared-with-deps-cxx11-abi @@ -131,7 +131,7 @@ jobs: DESIRED_CUDA: cu118 GPU_ARCH_VERSION: 11.8 GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/libtorch-cxx11-builder:cuda11.8-2.4 + DOCKER_IMAGE: pytorch/libtorch-cxx11-builder:cuda11.8-main LIBTORCH_VARIANT: shared-with-deps DESIRED_DEVTOOLSET: cxx11-abi build_name: libtorch-cuda11_8-shared-with-deps-cxx11-abi @@ -154,7 +154,7 @@ jobs: DESIRED_CUDA: cu118 GPU_ARCH_VERSION: 11.8 GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/libtorch-cxx11-builder:cuda11.8-2.4 + DOCKER_IMAGE: pytorch/libtorch-cxx11-builder:cuda11.8-main LIBTORCH_VARIANT: shared-with-deps DESIRED_DEVTOOLSET: cxx11-abi build_name: libtorch-cuda11_8-shared-with-deps-cxx11-abi @@ -176,7 +176,7 @@ jobs: DESIRED_CUDA: cu121 GPU_ARCH_VERSION: 12.1 GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/libtorch-cxx11-builder:cuda12.1-2.4 + DOCKER_IMAGE: pytorch/libtorch-cxx11-builder:cuda12.1-main LIBTORCH_VARIANT: shared-with-deps DESIRED_DEVTOOLSET: cxx11-abi build_name: libtorch-cuda12_1-shared-with-deps-cxx11-abi @@ -196,7 +196,7 @@ jobs: DESIRED_CUDA: cu121 GPU_ARCH_VERSION: 12.1 GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/libtorch-cxx11-builder:cuda12.1-2.4 + DOCKER_IMAGE: pytorch/libtorch-cxx11-builder:cuda12.1-main LIBTORCH_VARIANT: shared-with-deps DESIRED_DEVTOOLSET: cxx11-abi build_name: libtorch-cuda12_1-shared-with-deps-cxx11-abi @@ -219,7 +219,7 @@ jobs: DESIRED_CUDA: cu121 GPU_ARCH_VERSION: 12.1 GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/libtorch-cxx11-builder:cuda12.1-2.4 + DOCKER_IMAGE: pytorch/libtorch-cxx11-builder:cuda12.1-main LIBTORCH_VARIANT: shared-with-deps DESIRED_DEVTOOLSET: cxx11-abi build_name: libtorch-cuda12_1-shared-with-deps-cxx11-abi @@ -241,7 +241,7 @@ jobs: DESIRED_CUDA: cu124 GPU_ARCH_VERSION: 12.4 GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/libtorch-cxx11-builder:cuda12.4-2.4 + DOCKER_IMAGE: pytorch/libtorch-cxx11-builder:cuda12.4-main LIBTORCH_VARIANT: shared-with-deps DESIRED_DEVTOOLSET: cxx11-abi build_name: libtorch-cuda12_4-shared-with-deps-cxx11-abi @@ -261,7 +261,7 @@ jobs: DESIRED_CUDA: cu124 GPU_ARCH_VERSION: 12.4 GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/libtorch-cxx11-builder:cuda12.4-2.4 + DOCKER_IMAGE: pytorch/libtorch-cxx11-builder:cuda12.4-main LIBTORCH_VARIANT: shared-with-deps DESIRED_DEVTOOLSET: cxx11-abi build_name: libtorch-cuda12_4-shared-with-deps-cxx11-abi @@ -284,7 +284,7 @@ jobs: DESIRED_CUDA: cu124 GPU_ARCH_VERSION: 12.4 GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/libtorch-cxx11-builder:cuda12.4-2.4 + DOCKER_IMAGE: pytorch/libtorch-cxx11-builder:cuda12.4-main LIBTORCH_VARIANT: shared-with-deps DESIRED_DEVTOOLSET: cxx11-abi build_name: libtorch-cuda12_4-shared-with-deps-cxx11-abi @@ -306,7 +306,7 @@ jobs: DESIRED_CUDA: rocm6.0 GPU_ARCH_VERSION: 6.0 GPU_ARCH_TYPE: rocm - DOCKER_IMAGE: pytorch/libtorch-cxx11-builder:rocm6.0-2.4 + DOCKER_IMAGE: pytorch/libtorch-cxx11-builder:rocm6.0-main LIBTORCH_VARIANT: shared-with-deps DESIRED_DEVTOOLSET: cxx11-abi build_name: libtorch-rocm6_0-shared-with-deps-cxx11-abi @@ -328,7 +328,7 @@ jobs: GPU_ARCH_VERSION: 6.0 GPU_ARCH_TYPE: rocm SKIP_ALL_TESTS: 1 - DOCKER_IMAGE: pytorch/libtorch-cxx11-builder:rocm6.0-2.4 + DOCKER_IMAGE: pytorch/libtorch-cxx11-builder:rocm6.0-main LIBTORCH_VARIANT: shared-with-deps DESIRED_DEVTOOLSET: cxx11-abi steps: @@ -342,6 +342,7 @@ jobs: - name: Checkout PyTorch uses: malfet/checkout@silent-checkout with: + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} submodules: recursive path: pytorch quiet-checkout: true @@ -353,7 +354,7 @@ jobs: - name: Checkout pytorch/builder uses: malfet/checkout@silent-checkout with: - ref: release/2.4 + ref: main submodules: recursive repository: pytorch/builder path: builder @@ -369,7 +370,7 @@ jobs: - name: Pull Docker image uses: pytorch/test-infra/.github/actions/pull-docker-image@main with: - docker-image: pytorch/libtorch-cxx11-builder:rocm6.0-2.4 + docker-image: pytorch/libtorch-cxx11-builder:rocm6.0-main - name: Test Pytorch binary uses: ./pytorch/.github/actions/test-pytorch-binary - name: Teardown ROCm @@ -389,7 +390,7 @@ jobs: DESIRED_CUDA: rocm6.0 GPU_ARCH_VERSION: 6.0 GPU_ARCH_TYPE: rocm - DOCKER_IMAGE: pytorch/libtorch-cxx11-builder:rocm6.0-2.4 + DOCKER_IMAGE: pytorch/libtorch-cxx11-builder:rocm6.0-main LIBTORCH_VARIANT: shared-with-deps DESIRED_DEVTOOLSET: cxx11-abi build_name: libtorch-rocm6_0-shared-with-deps-cxx11-abi @@ -411,7 +412,7 @@ jobs: DESIRED_CUDA: rocm6.1 GPU_ARCH_VERSION: 6.1 GPU_ARCH_TYPE: rocm - DOCKER_IMAGE: pytorch/libtorch-cxx11-builder:rocm6.1-2.4 + DOCKER_IMAGE: pytorch/libtorch-cxx11-builder:rocm6.1-main LIBTORCH_VARIANT: shared-with-deps DESIRED_DEVTOOLSET: cxx11-abi build_name: libtorch-rocm6_1-shared-with-deps-cxx11-abi @@ -433,7 +434,7 @@ jobs: GPU_ARCH_VERSION: 6.1 GPU_ARCH_TYPE: rocm SKIP_ALL_TESTS: 1 - DOCKER_IMAGE: pytorch/libtorch-cxx11-builder:rocm6.1-2.4 + DOCKER_IMAGE: pytorch/libtorch-cxx11-builder:rocm6.1-main LIBTORCH_VARIANT: shared-with-deps DESIRED_DEVTOOLSET: cxx11-abi steps: @@ -447,6 +448,7 @@ jobs: - name: Checkout PyTorch uses: malfet/checkout@silent-checkout with: + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} submodules: recursive path: pytorch quiet-checkout: true @@ -458,7 +460,7 @@ jobs: - name: Checkout pytorch/builder uses: malfet/checkout@silent-checkout with: - ref: release/2.4 + ref: main submodules: recursive repository: pytorch/builder path: builder @@ -474,7 +476,7 @@ jobs: - name: Pull Docker image uses: pytorch/test-infra/.github/actions/pull-docker-image@main with: - docker-image: pytorch/libtorch-cxx11-builder:rocm6.1-2.4 + docker-image: pytorch/libtorch-cxx11-builder:rocm6.1-main - name: Test Pytorch binary uses: ./pytorch/.github/actions/test-pytorch-binary - name: Teardown ROCm @@ -494,7 +496,7 @@ jobs: DESIRED_CUDA: rocm6.1 GPU_ARCH_VERSION: 6.1 GPU_ARCH_TYPE: rocm - DOCKER_IMAGE: pytorch/libtorch-cxx11-builder:rocm6.1-2.4 + DOCKER_IMAGE: pytorch/libtorch-cxx11-builder:rocm6.1-main LIBTORCH_VARIANT: shared-with-deps DESIRED_DEVTOOLSET: cxx11-abi build_name: libtorch-rocm6_1-shared-with-deps-cxx11-abi diff --git a/.github/workflows/generated-linux-binary-libtorch-pre-cxx11-main.yml b/.github/workflows/generated-linux-binary-libtorch-pre-cxx11-main.yml index a6efc06fcf00fa..0158860d6f9428 100644 --- a/.github/workflows/generated-linux-binary-libtorch-pre-cxx11-main.yml +++ b/.github/workflows/generated-linux-binary-libtorch-pre-cxx11-main.yml @@ -43,7 +43,7 @@ jobs: # favor of GPU_ARCH_VERSION DESIRED_CUDA: cpu GPU_ARCH_TYPE: cpu - DOCKER_IMAGE: pytorch/manylinux-builder:cpu-2.4 + DOCKER_IMAGE: pytorch/manylinux-builder:cpu-main LIBTORCH_VARIANT: shared-with-deps DESIRED_DEVTOOLSET: pre-cxx11 build_name: libtorch-cpu-shared-with-deps-pre-cxx11 @@ -62,7 +62,7 @@ jobs: # favor of GPU_ARCH_VERSION DESIRED_CUDA: cpu GPU_ARCH_TYPE: cpu - DOCKER_IMAGE: pytorch/manylinux-builder:cpu-2.4 + DOCKER_IMAGE: pytorch/manylinux-builder:cpu-main LIBTORCH_VARIANT: shared-with-deps DESIRED_DEVTOOLSET: pre-cxx11 build_name: libtorch-cpu-shared-with-deps-pre-cxx11 diff --git a/.github/workflows/generated-linux-binary-libtorch-pre-cxx11-nightly.yml b/.github/workflows/generated-linux-binary-libtorch-pre-cxx11-nightly.yml index 73dcb085642267..3205c3c78dad4b 100644 --- a/.github/workflows/generated-linux-binary-libtorch-pre-cxx11-nightly.yml +++ b/.github/workflows/generated-linux-binary-libtorch-pre-cxx11-nightly.yml @@ -48,7 +48,7 @@ jobs: # favor of GPU_ARCH_VERSION DESIRED_CUDA: cpu GPU_ARCH_TYPE: cpu - DOCKER_IMAGE: pytorch/manylinux-builder:cpu-2.4 + DOCKER_IMAGE: pytorch/manylinux-builder:cpu-main LIBTORCH_VARIANT: shared-with-deps DESIRED_DEVTOOLSET: pre-cxx11 build_name: libtorch-cpu-shared-with-deps-pre-cxx11 @@ -67,7 +67,7 @@ jobs: # favor of GPU_ARCH_VERSION DESIRED_CUDA: cpu GPU_ARCH_TYPE: cpu - DOCKER_IMAGE: pytorch/manylinux-builder:cpu-2.4 + DOCKER_IMAGE: pytorch/manylinux-builder:cpu-main LIBTORCH_VARIANT: shared-with-deps DESIRED_DEVTOOLSET: pre-cxx11 build_name: libtorch-cpu-shared-with-deps-pre-cxx11 @@ -89,7 +89,7 @@ jobs: # favor of GPU_ARCH_VERSION DESIRED_CUDA: cpu GPU_ARCH_TYPE: cpu - DOCKER_IMAGE: pytorch/manylinux-builder:cpu-2.4 + DOCKER_IMAGE: pytorch/manylinux-builder:cpu-main LIBTORCH_VARIANT: shared-with-deps DESIRED_DEVTOOLSET: pre-cxx11 build_name: libtorch-cpu-shared-with-deps-pre-cxx11 @@ -111,7 +111,7 @@ jobs: DESIRED_CUDA: cu118 GPU_ARCH_VERSION: 11.8 GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda11.8-2.4 + DOCKER_IMAGE: pytorch/manylinux-builder:cuda11.8-main LIBTORCH_VARIANT: shared-with-deps DESIRED_DEVTOOLSET: pre-cxx11 build_name: libtorch-cuda11_8-shared-with-deps-pre-cxx11 @@ -131,7 +131,7 @@ jobs: DESIRED_CUDA: cu118 GPU_ARCH_VERSION: 11.8 GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda11.8-2.4 + DOCKER_IMAGE: pytorch/manylinux-builder:cuda11.8-main LIBTORCH_VARIANT: shared-with-deps DESIRED_DEVTOOLSET: pre-cxx11 build_name: libtorch-cuda11_8-shared-with-deps-pre-cxx11 @@ -154,7 +154,7 @@ jobs: DESIRED_CUDA: cu118 GPU_ARCH_VERSION: 11.8 GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda11.8-2.4 + DOCKER_IMAGE: pytorch/manylinux-builder:cuda11.8-main LIBTORCH_VARIANT: shared-with-deps DESIRED_DEVTOOLSET: pre-cxx11 build_name: libtorch-cuda11_8-shared-with-deps-pre-cxx11 @@ -176,7 +176,7 @@ jobs: DESIRED_CUDA: cu121 GPU_ARCH_VERSION: 12.1 GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-2.4 + DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main LIBTORCH_VARIANT: shared-with-deps DESIRED_DEVTOOLSET: pre-cxx11 build_name: libtorch-cuda12_1-shared-with-deps-pre-cxx11 @@ -196,7 +196,7 @@ jobs: DESIRED_CUDA: cu121 GPU_ARCH_VERSION: 12.1 GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-2.4 + DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main LIBTORCH_VARIANT: shared-with-deps DESIRED_DEVTOOLSET: pre-cxx11 build_name: libtorch-cuda12_1-shared-with-deps-pre-cxx11 @@ -219,7 +219,7 @@ jobs: DESIRED_CUDA: cu121 GPU_ARCH_VERSION: 12.1 GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-2.4 + DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main LIBTORCH_VARIANT: shared-with-deps DESIRED_DEVTOOLSET: pre-cxx11 build_name: libtorch-cuda12_1-shared-with-deps-pre-cxx11 @@ -241,7 +241,7 @@ jobs: DESIRED_CUDA: cu124 GPU_ARCH_VERSION: 12.4 GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-2.4 + DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-main LIBTORCH_VARIANT: shared-with-deps DESIRED_DEVTOOLSET: pre-cxx11 build_name: libtorch-cuda12_4-shared-with-deps-pre-cxx11 @@ -261,7 +261,7 @@ jobs: DESIRED_CUDA: cu124 GPU_ARCH_VERSION: 12.4 GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-2.4 + DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-main LIBTORCH_VARIANT: shared-with-deps DESIRED_DEVTOOLSET: pre-cxx11 build_name: libtorch-cuda12_4-shared-with-deps-pre-cxx11 @@ -284,7 +284,7 @@ jobs: DESIRED_CUDA: cu124 GPU_ARCH_VERSION: 12.4 GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-2.4 + DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-main LIBTORCH_VARIANT: shared-with-deps DESIRED_DEVTOOLSET: pre-cxx11 build_name: libtorch-cuda12_4-shared-with-deps-pre-cxx11 @@ -306,7 +306,7 @@ jobs: DESIRED_CUDA: rocm6.0 GPU_ARCH_VERSION: 6.0 GPU_ARCH_TYPE: rocm - DOCKER_IMAGE: pytorch/manylinux-builder:rocm6.0-2.4 + DOCKER_IMAGE: pytorch/manylinux-builder:rocm6.0-main LIBTORCH_VARIANT: shared-with-deps DESIRED_DEVTOOLSET: pre-cxx11 build_name: libtorch-rocm6_0-shared-with-deps-pre-cxx11 @@ -328,7 +328,7 @@ jobs: GPU_ARCH_VERSION: 6.0 GPU_ARCH_TYPE: rocm SKIP_ALL_TESTS: 1 - DOCKER_IMAGE: pytorch/manylinux-builder:rocm6.0-2.4 + DOCKER_IMAGE: pytorch/manylinux-builder:rocm6.0-main LIBTORCH_VARIANT: shared-with-deps DESIRED_DEVTOOLSET: pre-cxx11 steps: @@ -342,6 +342,7 @@ jobs: - name: Checkout PyTorch uses: malfet/checkout@silent-checkout with: + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} submodules: recursive path: pytorch quiet-checkout: true @@ -353,7 +354,7 @@ jobs: - name: Checkout pytorch/builder uses: malfet/checkout@silent-checkout with: - ref: release/2.4 + ref: main submodules: recursive repository: pytorch/builder path: builder @@ -369,7 +370,7 @@ jobs: - name: Pull Docker image uses: pytorch/test-infra/.github/actions/pull-docker-image@main with: - docker-image: pytorch/manylinux-builder:rocm6.0-2.4 + docker-image: pytorch/manylinux-builder:rocm6.0-main - name: Test Pytorch binary uses: ./pytorch/.github/actions/test-pytorch-binary - name: Teardown ROCm @@ -389,7 +390,7 @@ jobs: DESIRED_CUDA: rocm6.0 GPU_ARCH_VERSION: 6.0 GPU_ARCH_TYPE: rocm - DOCKER_IMAGE: pytorch/manylinux-builder:rocm6.0-2.4 + DOCKER_IMAGE: pytorch/manylinux-builder:rocm6.0-main LIBTORCH_VARIANT: shared-with-deps DESIRED_DEVTOOLSET: pre-cxx11 build_name: libtorch-rocm6_0-shared-with-deps-pre-cxx11 @@ -411,7 +412,7 @@ jobs: DESIRED_CUDA: rocm6.1 GPU_ARCH_VERSION: 6.1 GPU_ARCH_TYPE: rocm - DOCKER_IMAGE: pytorch/manylinux-builder:rocm6.1-2.4 + DOCKER_IMAGE: pytorch/manylinux-builder:rocm6.1-main LIBTORCH_VARIANT: shared-with-deps DESIRED_DEVTOOLSET: pre-cxx11 build_name: libtorch-rocm6_1-shared-with-deps-pre-cxx11 @@ -433,7 +434,7 @@ jobs: GPU_ARCH_VERSION: 6.1 GPU_ARCH_TYPE: rocm SKIP_ALL_TESTS: 1 - DOCKER_IMAGE: pytorch/manylinux-builder:rocm6.1-2.4 + DOCKER_IMAGE: pytorch/manylinux-builder:rocm6.1-main LIBTORCH_VARIANT: shared-with-deps DESIRED_DEVTOOLSET: pre-cxx11 steps: @@ -447,6 +448,7 @@ jobs: - name: Checkout PyTorch uses: malfet/checkout@silent-checkout with: + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} submodules: recursive path: pytorch quiet-checkout: true @@ -458,7 +460,7 @@ jobs: - name: Checkout pytorch/builder uses: malfet/checkout@silent-checkout with: - ref: release/2.4 + ref: main submodules: recursive repository: pytorch/builder path: builder @@ -474,7 +476,7 @@ jobs: - name: Pull Docker image uses: pytorch/test-infra/.github/actions/pull-docker-image@main with: - docker-image: pytorch/manylinux-builder:rocm6.1-2.4 + docker-image: pytorch/manylinux-builder:rocm6.1-main - name: Test Pytorch binary uses: ./pytorch/.github/actions/test-pytorch-binary - name: Teardown ROCm @@ -494,7 +496,7 @@ jobs: DESIRED_CUDA: rocm6.1 GPU_ARCH_VERSION: 6.1 GPU_ARCH_TYPE: rocm - DOCKER_IMAGE: pytorch/manylinux-builder:rocm6.1-2.4 + DOCKER_IMAGE: pytorch/manylinux-builder:rocm6.1-main LIBTORCH_VARIANT: shared-with-deps DESIRED_DEVTOOLSET: pre-cxx11 build_name: libtorch-rocm6_1-shared-with-deps-pre-cxx11 diff --git a/.github/workflows/generated-linux-binary-manywheel-main.yml b/.github/workflows/generated-linux-binary-manywheel-main.yml index 024fa01d34d77e..053877b1c90eaf 100644 --- a/.github/workflows/generated-linux-binary-manywheel-main.yml +++ b/.github/workflows/generated-linux-binary-manywheel-main.yml @@ -44,7 +44,7 @@ jobs: DESIRED_CUDA: cu118 GPU_ARCH_VERSION: 11.8 GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda11.8-2.4 + DOCKER_IMAGE: pytorch/manylinux-builder:cuda11.8-main DESIRED_PYTHON: "3.8" build_name: manywheel-py3_8-cuda11_8 build_environment: linux-binary-manywheel @@ -64,7 +64,7 @@ jobs: DESIRED_CUDA: cu118 GPU_ARCH_VERSION: 11.8 GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda11.8-2.4 + DOCKER_IMAGE: pytorch/manylinux-builder:cuda11.8-main DESIRED_PYTHON: "3.8" build_name: manywheel-py3_8-cuda11_8 build_environment: linux-binary-manywheel @@ -84,7 +84,7 @@ jobs: DESIRED_CUDA: cu121 GPU_ARCH_VERSION: 12.1 GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-2.4 + DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main DESIRED_PYTHON: "3.8" build_name: manywheel-py3_8-cuda12_1 build_environment: linux-binary-manywheel @@ -104,7 +104,7 @@ jobs: DESIRED_CUDA: cu121 GPU_ARCH_VERSION: 12.1 GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-2.4 + DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main DESIRED_PYTHON: "3.8" build_name: manywheel-py3_8-cuda12_1 build_environment: linux-binary-manywheel @@ -124,7 +124,7 @@ jobs: DESIRED_CUDA: cu124 GPU_ARCH_VERSION: 12.4 GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-2.4 + DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-main DESIRED_PYTHON: "3.8" build_name: manywheel-py3_8-cuda12_4 build_environment: linux-binary-manywheel @@ -144,7 +144,7 @@ jobs: DESIRED_CUDA: cu124 GPU_ARCH_VERSION: 12.4 GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-2.4 + DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-main DESIRED_PYTHON: "3.8" build_name: manywheel-py3_8-cuda12_4 build_environment: linux-binary-manywheel diff --git a/.github/workflows/generated-linux-binary-manywheel-nightly.yml b/.github/workflows/generated-linux-binary-manywheel-nightly.yml index a03e9369407ef5..9d59728bbbbb5e 100644 --- a/.github/workflows/generated-linux-binary-manywheel-nightly.yml +++ b/.github/workflows/generated-linux-binary-manywheel-nightly.yml @@ -48,7 +48,7 @@ jobs: # favor of GPU_ARCH_VERSION DESIRED_CUDA: cpu GPU_ARCH_TYPE: cpu - DOCKER_IMAGE: pytorch/manylinux-builder:cpu-2.4 + DOCKER_IMAGE: pytorch/manylinux-builder:cpu-main DESIRED_PYTHON: "3.8" build_name: manywheel-py3_8-cpu build_environment: linux-binary-manywheel @@ -66,7 +66,7 @@ jobs: # favor of GPU_ARCH_VERSION DESIRED_CUDA: cpu GPU_ARCH_TYPE: cpu - DOCKER_IMAGE: pytorch/manylinux-builder:cpu-2.4 + DOCKER_IMAGE: pytorch/manylinux-builder:cpu-main DESIRED_PYTHON: "3.8" build_name: manywheel-py3_8-cpu build_environment: linux-binary-manywheel @@ -87,7 +87,7 @@ jobs: # favor of GPU_ARCH_VERSION DESIRED_CUDA: cpu GPU_ARCH_TYPE: cpu - DOCKER_IMAGE: pytorch/manylinux-builder:cpu-2.4 + DOCKER_IMAGE: pytorch/manylinux-builder:cpu-main DESIRED_PYTHON: "3.8" build_name: manywheel-py3_8-cpu secrets: @@ -107,7 +107,7 @@ jobs: # favor of GPU_ARCH_VERSION DESIRED_CUDA: cpu-cxx11-abi GPU_ARCH_TYPE: cpu-cxx11-abi - DOCKER_IMAGE: pytorch/manylinuxcxx11-abi-builder:cpu-cxx11-abi-2.4 + DOCKER_IMAGE: pytorch/manylinuxcxx11-abi-builder:cpu-cxx11-abi-main DESIRED_DEVTOOLSET: cxx11-abi DESIRED_PYTHON: "3.8" build_name: manywheel-py3_8-cpu-cxx11-abi @@ -126,7 +126,7 @@ jobs: # favor of GPU_ARCH_VERSION DESIRED_CUDA: cpu-cxx11-abi GPU_ARCH_TYPE: cpu-cxx11-abi - DOCKER_IMAGE: pytorch/manylinuxcxx11-abi-builder:cpu-cxx11-abi-2.4 + DOCKER_IMAGE: pytorch/manylinuxcxx11-abi-builder:cpu-cxx11-abi-main DESIRED_DEVTOOLSET: cxx11-abi DESIRED_PYTHON: "3.8" build_name: manywheel-py3_8-cpu-cxx11-abi @@ -148,7 +148,7 @@ jobs: # favor of GPU_ARCH_VERSION DESIRED_CUDA: cpu-cxx11-abi GPU_ARCH_TYPE: cpu-cxx11-abi - DOCKER_IMAGE: pytorch/manylinuxcxx11-abi-builder:cpu-cxx11-abi-2.4 + DOCKER_IMAGE: pytorch/manylinuxcxx11-abi-builder:cpu-cxx11-abi-main DESIRED_DEVTOOLSET: cxx11-abi DESIRED_PYTHON: "3.8" build_name: manywheel-py3_8-cpu-cxx11-abi @@ -170,7 +170,7 @@ jobs: DESIRED_CUDA: cu118 GPU_ARCH_VERSION: 11.8 GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda11.8-2.4 + DOCKER_IMAGE: pytorch/manylinux-builder:cuda11.8-main DESIRED_PYTHON: "3.8" build_name: manywheel-py3_8-cuda11_8 build_environment: linux-binary-manywheel @@ -190,7 +190,7 @@ jobs: DESIRED_CUDA: cu118 GPU_ARCH_VERSION: 11.8 GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda11.8-2.4 + DOCKER_IMAGE: pytorch/manylinux-builder:cuda11.8-main DESIRED_PYTHON: "3.8" build_name: manywheel-py3_8-cuda11_8 build_environment: linux-binary-manywheel @@ -212,7 +212,7 @@ jobs: DESIRED_CUDA: cu118 GPU_ARCH_VERSION: 11.8 GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda11.8-2.4 + DOCKER_IMAGE: pytorch/manylinux-builder:cuda11.8-main DESIRED_PYTHON: "3.8" build_name: manywheel-py3_8-cuda11_8 secrets: @@ -233,7 +233,7 @@ jobs: DESIRED_CUDA: cu121 GPU_ARCH_VERSION: 12.1 GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-2.4 + DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main DESIRED_PYTHON: "3.8" build_name: manywheel-py3_8-cuda12_1 build_environment: linux-binary-manywheel @@ -253,7 +253,7 @@ jobs: DESIRED_CUDA: cu121 GPU_ARCH_VERSION: 12.1 GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-2.4 + DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main DESIRED_PYTHON: "3.8" build_name: manywheel-py3_8-cuda12_1 build_environment: linux-binary-manywheel @@ -275,7 +275,7 @@ jobs: DESIRED_CUDA: cu121 GPU_ARCH_VERSION: 12.1 GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-2.4 + DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main DESIRED_PYTHON: "3.8" build_name: manywheel-py3_8-cuda12_1 secrets: @@ -296,7 +296,7 @@ jobs: DESIRED_CUDA: cu124 GPU_ARCH_VERSION: 12.4 GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-2.4 + DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-main DESIRED_PYTHON: "3.8" build_name: manywheel-py3_8-cuda12_4 build_environment: linux-binary-manywheel @@ -316,7 +316,7 @@ jobs: DESIRED_CUDA: cu124 GPU_ARCH_VERSION: 12.4 GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-2.4 + DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-main DESIRED_PYTHON: "3.8" build_name: manywheel-py3_8-cuda12_4 build_environment: linux-binary-manywheel @@ -338,7 +338,7 @@ jobs: DESIRED_CUDA: cu124 GPU_ARCH_VERSION: 12.4 GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-2.4 + DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-main DESIRED_PYTHON: "3.8" build_name: manywheel-py3_8-cuda12_4 secrets: @@ -359,7 +359,7 @@ jobs: DESIRED_CUDA: rocm6.0 GPU_ARCH_VERSION: 6.0 GPU_ARCH_TYPE: rocm - DOCKER_IMAGE: pytorch/manylinux-builder:rocm6.0-2.4 + DOCKER_IMAGE: pytorch/manylinux-builder:rocm6.0-main DESIRED_PYTHON: "3.8" build_name: manywheel-py3_8-rocm6_0 build_environment: linux-binary-manywheel @@ -380,7 +380,7 @@ jobs: GPU_ARCH_VERSION: 6.0 GPU_ARCH_TYPE: rocm SKIP_ALL_TESTS: 1 - DOCKER_IMAGE: pytorch/manylinux-builder:rocm6.0-2.4 + DOCKER_IMAGE: pytorch/manylinux-builder:rocm6.0-main DESIRED_PYTHON: "3.8" steps: - name: Setup ROCm @@ -393,6 +393,7 @@ jobs: - name: Checkout PyTorch uses: malfet/checkout@silent-checkout with: + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} submodules: recursive path: pytorch quiet-checkout: true @@ -404,7 +405,7 @@ jobs: - name: Checkout pytorch/builder uses: malfet/checkout@silent-checkout with: - ref: release/2.4 + ref: main submodules: recursive repository: pytorch/builder path: builder @@ -420,7 +421,7 @@ jobs: - name: Pull Docker image uses: pytorch/test-infra/.github/actions/pull-docker-image@main with: - docker-image: pytorch/manylinux-builder:rocm6.0-2.4 + docker-image: pytorch/manylinux-builder:rocm6.0-main - name: Test Pytorch binary uses: ./pytorch/.github/actions/test-pytorch-binary - name: Teardown ROCm @@ -440,7 +441,7 @@ jobs: DESIRED_CUDA: rocm6.0 GPU_ARCH_VERSION: 6.0 GPU_ARCH_TYPE: rocm - DOCKER_IMAGE: pytorch/manylinux-builder:rocm6.0-2.4 + DOCKER_IMAGE: pytorch/manylinux-builder:rocm6.0-main DESIRED_PYTHON: "3.8" build_name: manywheel-py3_8-rocm6_0 secrets: @@ -461,7 +462,7 @@ jobs: DESIRED_CUDA: rocm6.1 GPU_ARCH_VERSION: 6.1 GPU_ARCH_TYPE: rocm - DOCKER_IMAGE: pytorch/manylinux-builder:rocm6.1-2.4 + DOCKER_IMAGE: pytorch/manylinux-builder:rocm6.1-main DESIRED_PYTHON: "3.8" build_name: manywheel-py3_8-rocm6_1 build_environment: linux-binary-manywheel @@ -482,7 +483,7 @@ jobs: GPU_ARCH_VERSION: 6.1 GPU_ARCH_TYPE: rocm SKIP_ALL_TESTS: 1 - DOCKER_IMAGE: pytorch/manylinux-builder:rocm6.1-2.4 + DOCKER_IMAGE: pytorch/manylinux-builder:rocm6.1-main DESIRED_PYTHON: "3.8" steps: - name: Setup ROCm @@ -495,6 +496,7 @@ jobs: - name: Checkout PyTorch uses: malfet/checkout@silent-checkout with: + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} submodules: recursive path: pytorch quiet-checkout: true @@ -506,7 +508,7 @@ jobs: - name: Checkout pytorch/builder uses: malfet/checkout@silent-checkout with: - ref: release/2.4 + ref: main submodules: recursive repository: pytorch/builder path: builder @@ -522,7 +524,7 @@ jobs: - name: Pull Docker image uses: pytorch/test-infra/.github/actions/pull-docker-image@main with: - docker-image: pytorch/manylinux-builder:rocm6.1-2.4 + docker-image: pytorch/manylinux-builder:rocm6.1-main - name: Test Pytorch binary uses: ./pytorch/.github/actions/test-pytorch-binary - name: Teardown ROCm @@ -542,7 +544,7 @@ jobs: DESIRED_CUDA: rocm6.1 GPU_ARCH_VERSION: 6.1 GPU_ARCH_TYPE: rocm - DOCKER_IMAGE: pytorch/manylinux-builder:rocm6.1-2.4 + DOCKER_IMAGE: pytorch/manylinux-builder:rocm6.1-main DESIRED_PYTHON: "3.8" build_name: manywheel-py3_8-rocm6_1 secrets: @@ -562,7 +564,7 @@ jobs: # favor of GPU_ARCH_VERSION DESIRED_CUDA: cpu GPU_ARCH_TYPE: cpu - DOCKER_IMAGE: pytorch/manylinux-builder:cpu-2.4 + DOCKER_IMAGE: pytorch/manylinux-builder:cpu-main DESIRED_PYTHON: "3.9" build_name: manywheel-py3_9-cpu build_environment: linux-binary-manywheel @@ -580,7 +582,7 @@ jobs: # favor of GPU_ARCH_VERSION DESIRED_CUDA: cpu GPU_ARCH_TYPE: cpu - DOCKER_IMAGE: pytorch/manylinux-builder:cpu-2.4 + DOCKER_IMAGE: pytorch/manylinux-builder:cpu-main DESIRED_PYTHON: "3.9" build_name: manywheel-py3_9-cpu build_environment: linux-binary-manywheel @@ -601,7 +603,7 @@ jobs: # favor of GPU_ARCH_VERSION DESIRED_CUDA: cpu GPU_ARCH_TYPE: cpu - DOCKER_IMAGE: pytorch/manylinux-builder:cpu-2.4 + DOCKER_IMAGE: pytorch/manylinux-builder:cpu-main DESIRED_PYTHON: "3.9" build_name: manywheel-py3_9-cpu secrets: @@ -621,7 +623,7 @@ jobs: # favor of GPU_ARCH_VERSION DESIRED_CUDA: cpu-cxx11-abi GPU_ARCH_TYPE: cpu-cxx11-abi - DOCKER_IMAGE: pytorch/manylinuxcxx11-abi-builder:cpu-cxx11-abi-2.4 + DOCKER_IMAGE: pytorch/manylinuxcxx11-abi-builder:cpu-cxx11-abi-main DESIRED_DEVTOOLSET: cxx11-abi DESIRED_PYTHON: "3.9" build_name: manywheel-py3_9-cpu-cxx11-abi @@ -640,7 +642,7 @@ jobs: # favor of GPU_ARCH_VERSION DESIRED_CUDA: cpu-cxx11-abi GPU_ARCH_TYPE: cpu-cxx11-abi - DOCKER_IMAGE: pytorch/manylinuxcxx11-abi-builder:cpu-cxx11-abi-2.4 + DOCKER_IMAGE: pytorch/manylinuxcxx11-abi-builder:cpu-cxx11-abi-main DESIRED_DEVTOOLSET: cxx11-abi DESIRED_PYTHON: "3.9" build_name: manywheel-py3_9-cpu-cxx11-abi @@ -662,7 +664,7 @@ jobs: # favor of GPU_ARCH_VERSION DESIRED_CUDA: cpu-cxx11-abi GPU_ARCH_TYPE: cpu-cxx11-abi - DOCKER_IMAGE: pytorch/manylinuxcxx11-abi-builder:cpu-cxx11-abi-2.4 + DOCKER_IMAGE: pytorch/manylinuxcxx11-abi-builder:cpu-cxx11-abi-main DESIRED_DEVTOOLSET: cxx11-abi DESIRED_PYTHON: "3.9" build_name: manywheel-py3_9-cpu-cxx11-abi @@ -684,7 +686,7 @@ jobs: DESIRED_CUDA: cu118 GPU_ARCH_VERSION: 11.8 GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda11.8-2.4 + DOCKER_IMAGE: pytorch/manylinux-builder:cuda11.8-main DESIRED_PYTHON: "3.9" build_name: manywheel-py3_9-cuda11_8 build_environment: linux-binary-manywheel @@ -704,7 +706,7 @@ jobs: DESIRED_CUDA: cu118 GPU_ARCH_VERSION: 11.8 GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda11.8-2.4 + DOCKER_IMAGE: pytorch/manylinux-builder:cuda11.8-main DESIRED_PYTHON: "3.9" build_name: manywheel-py3_9-cuda11_8 build_environment: linux-binary-manywheel @@ -726,7 +728,7 @@ jobs: DESIRED_CUDA: cu118 GPU_ARCH_VERSION: 11.8 GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda11.8-2.4 + DOCKER_IMAGE: pytorch/manylinux-builder:cuda11.8-main DESIRED_PYTHON: "3.9" build_name: manywheel-py3_9-cuda11_8 secrets: @@ -747,7 +749,7 @@ jobs: DESIRED_CUDA: cu121 GPU_ARCH_VERSION: 12.1 GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-2.4 + DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main DESIRED_PYTHON: "3.9" build_name: manywheel-py3_9-cuda12_1 build_environment: linux-binary-manywheel @@ -767,7 +769,7 @@ jobs: DESIRED_CUDA: cu121 GPU_ARCH_VERSION: 12.1 GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-2.4 + DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main DESIRED_PYTHON: "3.9" build_name: manywheel-py3_9-cuda12_1 build_environment: linux-binary-manywheel @@ -789,7 +791,7 @@ jobs: DESIRED_CUDA: cu121 GPU_ARCH_VERSION: 12.1 GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-2.4 + DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main DESIRED_PYTHON: "3.9" build_name: manywheel-py3_9-cuda12_1 secrets: @@ -810,7 +812,7 @@ jobs: DESIRED_CUDA: cu124 GPU_ARCH_VERSION: 12.4 GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-2.4 + DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-main DESIRED_PYTHON: "3.9" build_name: manywheel-py3_9-cuda12_4 build_environment: linux-binary-manywheel @@ -830,7 +832,7 @@ jobs: DESIRED_CUDA: cu124 GPU_ARCH_VERSION: 12.4 GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-2.4 + DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-main DESIRED_PYTHON: "3.9" build_name: manywheel-py3_9-cuda12_4 build_environment: linux-binary-manywheel @@ -852,7 +854,7 @@ jobs: DESIRED_CUDA: cu124 GPU_ARCH_VERSION: 12.4 GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-2.4 + DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-main DESIRED_PYTHON: "3.9" build_name: manywheel-py3_9-cuda12_4 secrets: @@ -873,7 +875,7 @@ jobs: DESIRED_CUDA: rocm6.0 GPU_ARCH_VERSION: 6.0 GPU_ARCH_TYPE: rocm - DOCKER_IMAGE: pytorch/manylinux-builder:rocm6.0-2.4 + DOCKER_IMAGE: pytorch/manylinux-builder:rocm6.0-main DESIRED_PYTHON: "3.9" build_name: manywheel-py3_9-rocm6_0 build_environment: linux-binary-manywheel @@ -894,7 +896,7 @@ jobs: GPU_ARCH_VERSION: 6.0 GPU_ARCH_TYPE: rocm SKIP_ALL_TESTS: 1 - DOCKER_IMAGE: pytorch/manylinux-builder:rocm6.0-2.4 + DOCKER_IMAGE: pytorch/manylinux-builder:rocm6.0-main DESIRED_PYTHON: "3.9" steps: - name: Setup ROCm @@ -907,6 +909,7 @@ jobs: - name: Checkout PyTorch uses: malfet/checkout@silent-checkout with: + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} submodules: recursive path: pytorch quiet-checkout: true @@ -918,7 +921,7 @@ jobs: - name: Checkout pytorch/builder uses: malfet/checkout@silent-checkout with: - ref: release/2.4 + ref: main submodules: recursive repository: pytorch/builder path: builder @@ -934,7 +937,7 @@ jobs: - name: Pull Docker image uses: pytorch/test-infra/.github/actions/pull-docker-image@main with: - docker-image: pytorch/manylinux-builder:rocm6.0-2.4 + docker-image: pytorch/manylinux-builder:rocm6.0-main - name: Test Pytorch binary uses: ./pytorch/.github/actions/test-pytorch-binary - name: Teardown ROCm @@ -954,7 +957,7 @@ jobs: DESIRED_CUDA: rocm6.0 GPU_ARCH_VERSION: 6.0 GPU_ARCH_TYPE: rocm - DOCKER_IMAGE: pytorch/manylinux-builder:rocm6.0-2.4 + DOCKER_IMAGE: pytorch/manylinux-builder:rocm6.0-main DESIRED_PYTHON: "3.9" build_name: manywheel-py3_9-rocm6_0 secrets: @@ -975,7 +978,7 @@ jobs: DESIRED_CUDA: rocm6.1 GPU_ARCH_VERSION: 6.1 GPU_ARCH_TYPE: rocm - DOCKER_IMAGE: pytorch/manylinux-builder:rocm6.1-2.4 + DOCKER_IMAGE: pytorch/manylinux-builder:rocm6.1-main DESIRED_PYTHON: "3.9" build_name: manywheel-py3_9-rocm6_1 build_environment: linux-binary-manywheel @@ -996,7 +999,7 @@ jobs: GPU_ARCH_VERSION: 6.1 GPU_ARCH_TYPE: rocm SKIP_ALL_TESTS: 1 - DOCKER_IMAGE: pytorch/manylinux-builder:rocm6.1-2.4 + DOCKER_IMAGE: pytorch/manylinux-builder:rocm6.1-main DESIRED_PYTHON: "3.9" steps: - name: Setup ROCm @@ -1009,6 +1012,7 @@ jobs: - name: Checkout PyTorch uses: malfet/checkout@silent-checkout with: + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} submodules: recursive path: pytorch quiet-checkout: true @@ -1020,7 +1024,7 @@ jobs: - name: Checkout pytorch/builder uses: malfet/checkout@silent-checkout with: - ref: release/2.4 + ref: main submodules: recursive repository: pytorch/builder path: builder @@ -1036,7 +1040,7 @@ jobs: - name: Pull Docker image uses: pytorch/test-infra/.github/actions/pull-docker-image@main with: - docker-image: pytorch/manylinux-builder:rocm6.1-2.4 + docker-image: pytorch/manylinux-builder:rocm6.1-main - name: Test Pytorch binary uses: ./pytorch/.github/actions/test-pytorch-binary - name: Teardown ROCm @@ -1056,7 +1060,7 @@ jobs: DESIRED_CUDA: rocm6.1 GPU_ARCH_VERSION: 6.1 GPU_ARCH_TYPE: rocm - DOCKER_IMAGE: pytorch/manylinux-builder:rocm6.1-2.4 + DOCKER_IMAGE: pytorch/manylinux-builder:rocm6.1-main DESIRED_PYTHON: "3.9" build_name: manywheel-py3_9-rocm6_1 secrets: @@ -1076,7 +1080,7 @@ jobs: # favor of GPU_ARCH_VERSION DESIRED_CUDA: cpu GPU_ARCH_TYPE: cpu - DOCKER_IMAGE: pytorch/manylinux-builder:cpu-2.4 + DOCKER_IMAGE: pytorch/manylinux-builder:cpu-main DESIRED_PYTHON: "3.10" build_name: manywheel-py3_10-cpu build_environment: linux-binary-manywheel @@ -1094,7 +1098,7 @@ jobs: # favor of GPU_ARCH_VERSION DESIRED_CUDA: cpu GPU_ARCH_TYPE: cpu - DOCKER_IMAGE: pytorch/manylinux-builder:cpu-2.4 + DOCKER_IMAGE: pytorch/manylinux-builder:cpu-main DESIRED_PYTHON: "3.10" build_name: manywheel-py3_10-cpu build_environment: linux-binary-manywheel @@ -1115,7 +1119,7 @@ jobs: # favor of GPU_ARCH_VERSION DESIRED_CUDA: cpu GPU_ARCH_TYPE: cpu - DOCKER_IMAGE: pytorch/manylinux-builder:cpu-2.4 + DOCKER_IMAGE: pytorch/manylinux-builder:cpu-main DESIRED_PYTHON: "3.10" build_name: manywheel-py3_10-cpu secrets: @@ -1135,7 +1139,7 @@ jobs: # favor of GPU_ARCH_VERSION DESIRED_CUDA: cpu-cxx11-abi GPU_ARCH_TYPE: cpu-cxx11-abi - DOCKER_IMAGE: pytorch/manylinuxcxx11-abi-builder:cpu-cxx11-abi-2.4 + DOCKER_IMAGE: pytorch/manylinuxcxx11-abi-builder:cpu-cxx11-abi-main DESIRED_DEVTOOLSET: cxx11-abi DESIRED_PYTHON: "3.10" build_name: manywheel-py3_10-cpu-cxx11-abi @@ -1154,7 +1158,7 @@ jobs: # favor of GPU_ARCH_VERSION DESIRED_CUDA: cpu-cxx11-abi GPU_ARCH_TYPE: cpu-cxx11-abi - DOCKER_IMAGE: pytorch/manylinuxcxx11-abi-builder:cpu-cxx11-abi-2.4 + DOCKER_IMAGE: pytorch/manylinuxcxx11-abi-builder:cpu-cxx11-abi-main DESIRED_DEVTOOLSET: cxx11-abi DESIRED_PYTHON: "3.10" build_name: manywheel-py3_10-cpu-cxx11-abi @@ -1176,7 +1180,7 @@ jobs: # favor of GPU_ARCH_VERSION DESIRED_CUDA: cpu-cxx11-abi GPU_ARCH_TYPE: cpu-cxx11-abi - DOCKER_IMAGE: pytorch/manylinuxcxx11-abi-builder:cpu-cxx11-abi-2.4 + DOCKER_IMAGE: pytorch/manylinuxcxx11-abi-builder:cpu-cxx11-abi-main DESIRED_DEVTOOLSET: cxx11-abi DESIRED_PYTHON: "3.10" build_name: manywheel-py3_10-cpu-cxx11-abi @@ -1198,7 +1202,7 @@ jobs: DESIRED_CUDA: cu118 GPU_ARCH_VERSION: 11.8 GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda11.8-2.4 + DOCKER_IMAGE: pytorch/manylinux-builder:cuda11.8-main DESIRED_PYTHON: "3.10" build_name: manywheel-py3_10-cuda11_8 build_environment: linux-binary-manywheel @@ -1218,7 +1222,7 @@ jobs: DESIRED_CUDA: cu118 GPU_ARCH_VERSION: 11.8 GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda11.8-2.4 + DOCKER_IMAGE: pytorch/manylinux-builder:cuda11.8-main DESIRED_PYTHON: "3.10" build_name: manywheel-py3_10-cuda11_8 build_environment: linux-binary-manywheel @@ -1240,7 +1244,7 @@ jobs: DESIRED_CUDA: cu118 GPU_ARCH_VERSION: 11.8 GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda11.8-2.4 + DOCKER_IMAGE: pytorch/manylinux-builder:cuda11.8-main DESIRED_PYTHON: "3.10" build_name: manywheel-py3_10-cuda11_8 secrets: @@ -1261,7 +1265,7 @@ jobs: DESIRED_CUDA: cu121 GPU_ARCH_VERSION: 12.1 GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-2.4 + DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main DESIRED_PYTHON: "3.10" build_name: manywheel-py3_10-cuda12_1 build_environment: linux-binary-manywheel @@ -1281,7 +1285,7 @@ jobs: DESIRED_CUDA: cu121 GPU_ARCH_VERSION: 12.1 GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-2.4 + DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main DESIRED_PYTHON: "3.10" build_name: manywheel-py3_10-cuda12_1 build_environment: linux-binary-manywheel @@ -1303,7 +1307,7 @@ jobs: DESIRED_CUDA: cu121 GPU_ARCH_VERSION: 12.1 GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-2.4 + DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main DESIRED_PYTHON: "3.10" build_name: manywheel-py3_10-cuda12_1 secrets: @@ -1324,7 +1328,7 @@ jobs: DESIRED_CUDA: cu124 GPU_ARCH_VERSION: 12.4 GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-2.4 + DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-main DESIRED_PYTHON: "3.10" build_name: manywheel-py3_10-cuda12_4 build_environment: linux-binary-manywheel @@ -1344,7 +1348,7 @@ jobs: DESIRED_CUDA: cu124 GPU_ARCH_VERSION: 12.4 GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-2.4 + DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-main DESIRED_PYTHON: "3.10" build_name: manywheel-py3_10-cuda12_4 build_environment: linux-binary-manywheel @@ -1366,7 +1370,7 @@ jobs: DESIRED_CUDA: cu124 GPU_ARCH_VERSION: 12.4 GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-2.4 + DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-main DESIRED_PYTHON: "3.10" build_name: manywheel-py3_10-cuda12_4 secrets: @@ -1387,7 +1391,7 @@ jobs: DESIRED_CUDA: rocm6.0 GPU_ARCH_VERSION: 6.0 GPU_ARCH_TYPE: rocm - DOCKER_IMAGE: pytorch/manylinux-builder:rocm6.0-2.4 + DOCKER_IMAGE: pytorch/manylinux-builder:rocm6.0-main DESIRED_PYTHON: "3.10" build_name: manywheel-py3_10-rocm6_0 build_environment: linux-binary-manywheel @@ -1408,7 +1412,7 @@ jobs: GPU_ARCH_VERSION: 6.0 GPU_ARCH_TYPE: rocm SKIP_ALL_TESTS: 1 - DOCKER_IMAGE: pytorch/manylinux-builder:rocm6.0-2.4 + DOCKER_IMAGE: pytorch/manylinux-builder:rocm6.0-main DESIRED_PYTHON: "3.10" steps: - name: Setup ROCm @@ -1421,6 +1425,7 @@ jobs: - name: Checkout PyTorch uses: malfet/checkout@silent-checkout with: + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} submodules: recursive path: pytorch quiet-checkout: true @@ -1432,7 +1437,7 @@ jobs: - name: Checkout pytorch/builder uses: malfet/checkout@silent-checkout with: - ref: release/2.4 + ref: main submodules: recursive repository: pytorch/builder path: builder @@ -1448,7 +1453,7 @@ jobs: - name: Pull Docker image uses: pytorch/test-infra/.github/actions/pull-docker-image@main with: - docker-image: pytorch/manylinux-builder:rocm6.0-2.4 + docker-image: pytorch/manylinux-builder:rocm6.0-main - name: Test Pytorch binary uses: ./pytorch/.github/actions/test-pytorch-binary - name: Teardown ROCm @@ -1468,7 +1473,7 @@ jobs: DESIRED_CUDA: rocm6.0 GPU_ARCH_VERSION: 6.0 GPU_ARCH_TYPE: rocm - DOCKER_IMAGE: pytorch/manylinux-builder:rocm6.0-2.4 + DOCKER_IMAGE: pytorch/manylinux-builder:rocm6.0-main DESIRED_PYTHON: "3.10" build_name: manywheel-py3_10-rocm6_0 secrets: @@ -1489,7 +1494,7 @@ jobs: DESIRED_CUDA: rocm6.1 GPU_ARCH_VERSION: 6.1 GPU_ARCH_TYPE: rocm - DOCKER_IMAGE: pytorch/manylinux-builder:rocm6.1-2.4 + DOCKER_IMAGE: pytorch/manylinux-builder:rocm6.1-main DESIRED_PYTHON: "3.10" build_name: manywheel-py3_10-rocm6_1 build_environment: linux-binary-manywheel @@ -1510,7 +1515,7 @@ jobs: GPU_ARCH_VERSION: 6.1 GPU_ARCH_TYPE: rocm SKIP_ALL_TESTS: 1 - DOCKER_IMAGE: pytorch/manylinux-builder:rocm6.1-2.4 + DOCKER_IMAGE: pytorch/manylinux-builder:rocm6.1-main DESIRED_PYTHON: "3.10" steps: - name: Setup ROCm @@ -1523,6 +1528,7 @@ jobs: - name: Checkout PyTorch uses: malfet/checkout@silent-checkout with: + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} submodules: recursive path: pytorch quiet-checkout: true @@ -1534,7 +1540,7 @@ jobs: - name: Checkout pytorch/builder uses: malfet/checkout@silent-checkout with: - ref: release/2.4 + ref: main submodules: recursive repository: pytorch/builder path: builder @@ -1550,7 +1556,7 @@ jobs: - name: Pull Docker image uses: pytorch/test-infra/.github/actions/pull-docker-image@main with: - docker-image: pytorch/manylinux-builder:rocm6.1-2.4 + docker-image: pytorch/manylinux-builder:rocm6.1-main - name: Test Pytorch binary uses: ./pytorch/.github/actions/test-pytorch-binary - name: Teardown ROCm @@ -1570,7 +1576,7 @@ jobs: DESIRED_CUDA: rocm6.1 GPU_ARCH_VERSION: 6.1 GPU_ARCH_TYPE: rocm - DOCKER_IMAGE: pytorch/manylinux-builder:rocm6.1-2.4 + DOCKER_IMAGE: pytorch/manylinux-builder:rocm6.1-main DESIRED_PYTHON: "3.10" build_name: manywheel-py3_10-rocm6_1 secrets: @@ -1590,7 +1596,7 @@ jobs: # favor of GPU_ARCH_VERSION DESIRED_CUDA: cpu GPU_ARCH_TYPE: cpu - DOCKER_IMAGE: pytorch/manylinux-builder:cpu-2.4 + DOCKER_IMAGE: pytorch/manylinux-builder:cpu-main DESIRED_PYTHON: "3.11" build_name: manywheel-py3_11-cpu build_environment: linux-binary-manywheel @@ -1608,7 +1614,7 @@ jobs: # favor of GPU_ARCH_VERSION DESIRED_CUDA: cpu GPU_ARCH_TYPE: cpu - DOCKER_IMAGE: pytorch/manylinux-builder:cpu-2.4 + DOCKER_IMAGE: pytorch/manylinux-builder:cpu-main DESIRED_PYTHON: "3.11" build_name: manywheel-py3_11-cpu build_environment: linux-binary-manywheel @@ -1629,7 +1635,7 @@ jobs: # favor of GPU_ARCH_VERSION DESIRED_CUDA: cpu GPU_ARCH_TYPE: cpu - DOCKER_IMAGE: pytorch/manylinux-builder:cpu-2.4 + DOCKER_IMAGE: pytorch/manylinux-builder:cpu-main DESIRED_PYTHON: "3.11" build_name: manywheel-py3_11-cpu secrets: @@ -1649,7 +1655,7 @@ jobs: # favor of GPU_ARCH_VERSION DESIRED_CUDA: cpu-cxx11-abi GPU_ARCH_TYPE: cpu-cxx11-abi - DOCKER_IMAGE: pytorch/manylinuxcxx11-abi-builder:cpu-cxx11-abi-2.4 + DOCKER_IMAGE: pytorch/manylinuxcxx11-abi-builder:cpu-cxx11-abi-main DESIRED_DEVTOOLSET: cxx11-abi DESIRED_PYTHON: "3.11" build_name: manywheel-py3_11-cpu-cxx11-abi @@ -1668,7 +1674,7 @@ jobs: # favor of GPU_ARCH_VERSION DESIRED_CUDA: cpu-cxx11-abi GPU_ARCH_TYPE: cpu-cxx11-abi - DOCKER_IMAGE: pytorch/manylinuxcxx11-abi-builder:cpu-cxx11-abi-2.4 + DOCKER_IMAGE: pytorch/manylinuxcxx11-abi-builder:cpu-cxx11-abi-main DESIRED_DEVTOOLSET: cxx11-abi DESIRED_PYTHON: "3.11" build_name: manywheel-py3_11-cpu-cxx11-abi @@ -1690,7 +1696,7 @@ jobs: # favor of GPU_ARCH_VERSION DESIRED_CUDA: cpu-cxx11-abi GPU_ARCH_TYPE: cpu-cxx11-abi - DOCKER_IMAGE: pytorch/manylinuxcxx11-abi-builder:cpu-cxx11-abi-2.4 + DOCKER_IMAGE: pytorch/manylinuxcxx11-abi-builder:cpu-cxx11-abi-main DESIRED_DEVTOOLSET: cxx11-abi DESIRED_PYTHON: "3.11" build_name: manywheel-py3_11-cpu-cxx11-abi @@ -1712,7 +1718,7 @@ jobs: DESIRED_CUDA: cu118 GPU_ARCH_VERSION: 11.8 GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda11.8-2.4 + DOCKER_IMAGE: pytorch/manylinux-builder:cuda11.8-main DESIRED_PYTHON: "3.11" build_name: manywheel-py3_11-cuda11_8 build_environment: linux-binary-manywheel @@ -1732,7 +1738,7 @@ jobs: DESIRED_CUDA: cu118 GPU_ARCH_VERSION: 11.8 GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda11.8-2.4 + DOCKER_IMAGE: pytorch/manylinux-builder:cuda11.8-main DESIRED_PYTHON: "3.11" build_name: manywheel-py3_11-cuda11_8 build_environment: linux-binary-manywheel @@ -1754,7 +1760,7 @@ jobs: DESIRED_CUDA: cu118 GPU_ARCH_VERSION: 11.8 GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda11.8-2.4 + DOCKER_IMAGE: pytorch/manylinux-builder:cuda11.8-main DESIRED_PYTHON: "3.11" build_name: manywheel-py3_11-cuda11_8 secrets: @@ -1775,7 +1781,7 @@ jobs: DESIRED_CUDA: cu121 GPU_ARCH_VERSION: 12.1 GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-2.4 + DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main DESIRED_PYTHON: "3.11" build_name: manywheel-py3_11-cuda12_1 build_environment: linux-binary-manywheel @@ -1795,7 +1801,7 @@ jobs: DESIRED_CUDA: cu121 GPU_ARCH_VERSION: 12.1 GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-2.4 + DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main DESIRED_PYTHON: "3.11" build_name: manywheel-py3_11-cuda12_1 build_environment: linux-binary-manywheel @@ -1817,7 +1823,7 @@ jobs: DESIRED_CUDA: cu121 GPU_ARCH_VERSION: 12.1 GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-2.4 + DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main DESIRED_PYTHON: "3.11" build_name: manywheel-py3_11-cuda12_1 secrets: @@ -1838,7 +1844,7 @@ jobs: DESIRED_CUDA: cu124 GPU_ARCH_VERSION: 12.4 GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-2.4 + DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-main DESIRED_PYTHON: "3.11" build_name: manywheel-py3_11-cuda12_4 build_environment: linux-binary-manywheel @@ -1858,7 +1864,7 @@ jobs: DESIRED_CUDA: cu124 GPU_ARCH_VERSION: 12.4 GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-2.4 + DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-main DESIRED_PYTHON: "3.11" build_name: manywheel-py3_11-cuda12_4 build_environment: linux-binary-manywheel @@ -1880,7 +1886,7 @@ jobs: DESIRED_CUDA: cu124 GPU_ARCH_VERSION: 12.4 GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-2.4 + DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-main DESIRED_PYTHON: "3.11" build_name: manywheel-py3_11-cuda12_4 secrets: @@ -1901,7 +1907,7 @@ jobs: DESIRED_CUDA: rocm6.0 GPU_ARCH_VERSION: 6.0 GPU_ARCH_TYPE: rocm - DOCKER_IMAGE: pytorch/manylinux-builder:rocm6.0-2.4 + DOCKER_IMAGE: pytorch/manylinux-builder:rocm6.0-main DESIRED_PYTHON: "3.11" build_name: manywheel-py3_11-rocm6_0 build_environment: linux-binary-manywheel @@ -1922,7 +1928,7 @@ jobs: GPU_ARCH_VERSION: 6.0 GPU_ARCH_TYPE: rocm SKIP_ALL_TESTS: 1 - DOCKER_IMAGE: pytorch/manylinux-builder:rocm6.0-2.4 + DOCKER_IMAGE: pytorch/manylinux-builder:rocm6.0-main DESIRED_PYTHON: "3.11" steps: - name: Setup ROCm @@ -1935,6 +1941,7 @@ jobs: - name: Checkout PyTorch uses: malfet/checkout@silent-checkout with: + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} submodules: recursive path: pytorch quiet-checkout: true @@ -1946,7 +1953,7 @@ jobs: - name: Checkout pytorch/builder uses: malfet/checkout@silent-checkout with: - ref: release/2.4 + ref: main submodules: recursive repository: pytorch/builder path: builder @@ -1962,7 +1969,7 @@ jobs: - name: Pull Docker image uses: pytorch/test-infra/.github/actions/pull-docker-image@main with: - docker-image: pytorch/manylinux-builder:rocm6.0-2.4 + docker-image: pytorch/manylinux-builder:rocm6.0-main - name: Test Pytorch binary uses: ./pytorch/.github/actions/test-pytorch-binary - name: Teardown ROCm @@ -1982,7 +1989,7 @@ jobs: DESIRED_CUDA: rocm6.0 GPU_ARCH_VERSION: 6.0 GPU_ARCH_TYPE: rocm - DOCKER_IMAGE: pytorch/manylinux-builder:rocm6.0-2.4 + DOCKER_IMAGE: pytorch/manylinux-builder:rocm6.0-main DESIRED_PYTHON: "3.11" build_name: manywheel-py3_11-rocm6_0 secrets: @@ -2003,7 +2010,7 @@ jobs: DESIRED_CUDA: rocm6.1 GPU_ARCH_VERSION: 6.1 GPU_ARCH_TYPE: rocm - DOCKER_IMAGE: pytorch/manylinux-builder:rocm6.1-2.4 + DOCKER_IMAGE: pytorch/manylinux-builder:rocm6.1-main DESIRED_PYTHON: "3.11" build_name: manywheel-py3_11-rocm6_1 build_environment: linux-binary-manywheel @@ -2024,7 +2031,7 @@ jobs: GPU_ARCH_VERSION: 6.1 GPU_ARCH_TYPE: rocm SKIP_ALL_TESTS: 1 - DOCKER_IMAGE: pytorch/manylinux-builder:rocm6.1-2.4 + DOCKER_IMAGE: pytorch/manylinux-builder:rocm6.1-main DESIRED_PYTHON: "3.11" steps: - name: Setup ROCm @@ -2037,6 +2044,7 @@ jobs: - name: Checkout PyTorch uses: malfet/checkout@silent-checkout with: + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} submodules: recursive path: pytorch quiet-checkout: true @@ -2048,7 +2056,7 @@ jobs: - name: Checkout pytorch/builder uses: malfet/checkout@silent-checkout with: - ref: release/2.4 + ref: main submodules: recursive repository: pytorch/builder path: builder @@ -2064,7 +2072,7 @@ jobs: - name: Pull Docker image uses: pytorch/test-infra/.github/actions/pull-docker-image@main with: - docker-image: pytorch/manylinux-builder:rocm6.1-2.4 + docker-image: pytorch/manylinux-builder:rocm6.1-main - name: Test Pytorch binary uses: ./pytorch/.github/actions/test-pytorch-binary - name: Teardown ROCm @@ -2084,7 +2092,7 @@ jobs: DESIRED_CUDA: rocm6.1 GPU_ARCH_VERSION: 6.1 GPU_ARCH_TYPE: rocm - DOCKER_IMAGE: pytorch/manylinux-builder:rocm6.1-2.4 + DOCKER_IMAGE: pytorch/manylinux-builder:rocm6.1-main DESIRED_PYTHON: "3.11" build_name: manywheel-py3_11-rocm6_1 secrets: @@ -2104,7 +2112,7 @@ jobs: # favor of GPU_ARCH_VERSION DESIRED_CUDA: cpu GPU_ARCH_TYPE: cpu - DOCKER_IMAGE: pytorch/manylinux-builder:cpu-2.4 + DOCKER_IMAGE: pytorch/manylinux-builder:cpu-main DESIRED_PYTHON: "3.12" build_name: manywheel-py3_12-cpu build_environment: linux-binary-manywheel @@ -2122,7 +2130,7 @@ jobs: # favor of GPU_ARCH_VERSION DESIRED_CUDA: cpu GPU_ARCH_TYPE: cpu - DOCKER_IMAGE: pytorch/manylinux-builder:cpu-2.4 + DOCKER_IMAGE: pytorch/manylinux-builder:cpu-main DESIRED_PYTHON: "3.12" build_name: manywheel-py3_12-cpu build_environment: linux-binary-manywheel @@ -2143,7 +2151,7 @@ jobs: # favor of GPU_ARCH_VERSION DESIRED_CUDA: cpu GPU_ARCH_TYPE: cpu - DOCKER_IMAGE: pytorch/manylinux-builder:cpu-2.4 + DOCKER_IMAGE: pytorch/manylinux-builder:cpu-main DESIRED_PYTHON: "3.12" build_name: manywheel-py3_12-cpu secrets: @@ -2163,7 +2171,7 @@ jobs: # favor of GPU_ARCH_VERSION DESIRED_CUDA: cpu-cxx11-abi GPU_ARCH_TYPE: cpu-cxx11-abi - DOCKER_IMAGE: pytorch/manylinuxcxx11-abi-builder:cpu-cxx11-abi-2.4 + DOCKER_IMAGE: pytorch/manylinuxcxx11-abi-builder:cpu-cxx11-abi-main DESIRED_DEVTOOLSET: cxx11-abi DESIRED_PYTHON: "3.12" build_name: manywheel-py3_12-cpu-cxx11-abi @@ -2182,7 +2190,7 @@ jobs: # favor of GPU_ARCH_VERSION DESIRED_CUDA: cpu-cxx11-abi GPU_ARCH_TYPE: cpu-cxx11-abi - DOCKER_IMAGE: pytorch/manylinuxcxx11-abi-builder:cpu-cxx11-abi-2.4 + DOCKER_IMAGE: pytorch/manylinuxcxx11-abi-builder:cpu-cxx11-abi-main DESIRED_DEVTOOLSET: cxx11-abi DESIRED_PYTHON: "3.12" build_name: manywheel-py3_12-cpu-cxx11-abi @@ -2204,7 +2212,7 @@ jobs: # favor of GPU_ARCH_VERSION DESIRED_CUDA: cpu-cxx11-abi GPU_ARCH_TYPE: cpu-cxx11-abi - DOCKER_IMAGE: pytorch/manylinuxcxx11-abi-builder:cpu-cxx11-abi-2.4 + DOCKER_IMAGE: pytorch/manylinuxcxx11-abi-builder:cpu-cxx11-abi-main DESIRED_DEVTOOLSET: cxx11-abi DESIRED_PYTHON: "3.12" build_name: manywheel-py3_12-cpu-cxx11-abi @@ -2226,7 +2234,7 @@ jobs: DESIRED_CUDA: cu118 GPU_ARCH_VERSION: 11.8 GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda11.8-2.4 + DOCKER_IMAGE: pytorch/manylinux-builder:cuda11.8-main DESIRED_PYTHON: "3.12" build_name: manywheel-py3_12-cuda11_8 build_environment: linux-binary-manywheel @@ -2246,7 +2254,7 @@ jobs: DESIRED_CUDA: cu118 GPU_ARCH_VERSION: 11.8 GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda11.8-2.4 + DOCKER_IMAGE: pytorch/manylinux-builder:cuda11.8-main DESIRED_PYTHON: "3.12" build_name: manywheel-py3_12-cuda11_8 build_environment: linux-binary-manywheel @@ -2268,7 +2276,7 @@ jobs: DESIRED_CUDA: cu118 GPU_ARCH_VERSION: 11.8 GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda11.8-2.4 + DOCKER_IMAGE: pytorch/manylinux-builder:cuda11.8-main DESIRED_PYTHON: "3.12" build_name: manywheel-py3_12-cuda11_8 secrets: @@ -2289,7 +2297,7 @@ jobs: DESIRED_CUDA: cu121 GPU_ARCH_VERSION: 12.1 GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-2.4 + DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main DESIRED_PYTHON: "3.12" build_name: manywheel-py3_12-cuda12_1 build_environment: linux-binary-manywheel @@ -2309,7 +2317,7 @@ jobs: DESIRED_CUDA: cu121 GPU_ARCH_VERSION: 12.1 GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-2.4 + DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main DESIRED_PYTHON: "3.12" build_name: manywheel-py3_12-cuda12_1 build_environment: linux-binary-manywheel @@ -2331,7 +2339,7 @@ jobs: DESIRED_CUDA: cu121 GPU_ARCH_VERSION: 12.1 GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-2.4 + DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main DESIRED_PYTHON: "3.12" build_name: manywheel-py3_12-cuda12_1 secrets: @@ -2352,7 +2360,7 @@ jobs: DESIRED_CUDA: cu124 GPU_ARCH_VERSION: 12.4 GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-2.4 + DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-main DESIRED_PYTHON: "3.12" build_name: manywheel-py3_12-cuda12_4 build_environment: linux-binary-manywheel @@ -2372,7 +2380,7 @@ jobs: DESIRED_CUDA: cu124 GPU_ARCH_VERSION: 12.4 GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-2.4 + DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-main DESIRED_PYTHON: "3.12" build_name: manywheel-py3_12-cuda12_4 build_environment: linux-binary-manywheel @@ -2394,7 +2402,7 @@ jobs: DESIRED_CUDA: cu124 GPU_ARCH_VERSION: 12.4 GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-2.4 + DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-main DESIRED_PYTHON: "3.12" build_name: manywheel-py3_12-cuda12_4 secrets: diff --git a/.github/workflows/generated-linux-s390x-binary-manywheel-nightly.yml b/.github/workflows/generated-linux-s390x-binary-manywheel-nightly.yml index 5b0d9c11916e7a..db0748463da587 100644 --- a/.github/workflows/generated-linux-s390x-binary-manywheel-nightly.yml +++ b/.github/workflows/generated-linux-s390x-binary-manywheel-nightly.yml @@ -48,7 +48,7 @@ jobs: # favor of GPU_ARCH_VERSION DESIRED_CUDA: cpu GPU_ARCH_TYPE: cpu-s390x - DOCKER_IMAGE: pytorch/manylinuxs390x-builder:cpu-s390x-2.4 + DOCKER_IMAGE: pytorch/manylinuxs390x-builder:cpu-s390x-main DESIRED_PYTHON: "3.8" runs_on: linux.s390x ALPINE_IMAGE: "docker.io/s390x/alpine" @@ -69,7 +69,7 @@ jobs: # favor of GPU_ARCH_VERSION DESIRED_CUDA: cpu GPU_ARCH_TYPE: cpu-s390x - DOCKER_IMAGE: pytorch/manylinuxs390x-builder:cpu-s390x-2.4 + DOCKER_IMAGE: pytorch/manylinuxs390x-builder:cpu-s390x-main DESIRED_PYTHON: "3.8" build_name: manywheel-py3_8-cpu-s390x build_environment: linux-s390x-binary-manywheel @@ -91,7 +91,7 @@ jobs: # favor of GPU_ARCH_VERSION DESIRED_CUDA: cpu GPU_ARCH_TYPE: cpu-s390x - DOCKER_IMAGE: pytorch/manylinuxs390x-builder:cpu-s390x-2.4 + DOCKER_IMAGE: pytorch/manylinuxs390x-builder:cpu-s390x-main DESIRED_PYTHON: "3.8" build_name: manywheel-py3_8-cpu-s390x secrets: @@ -111,7 +111,7 @@ jobs: # favor of GPU_ARCH_VERSION DESIRED_CUDA: cpu GPU_ARCH_TYPE: cpu-s390x - DOCKER_IMAGE: pytorch/manylinuxs390x-builder:cpu-s390x-2.4 + DOCKER_IMAGE: pytorch/manylinuxs390x-builder:cpu-s390x-main DESIRED_PYTHON: "3.9" runs_on: linux.s390x ALPINE_IMAGE: "docker.io/s390x/alpine" @@ -132,7 +132,7 @@ jobs: # favor of GPU_ARCH_VERSION DESIRED_CUDA: cpu GPU_ARCH_TYPE: cpu-s390x - DOCKER_IMAGE: pytorch/manylinuxs390x-builder:cpu-s390x-2.4 + DOCKER_IMAGE: pytorch/manylinuxs390x-builder:cpu-s390x-main DESIRED_PYTHON: "3.9" build_name: manywheel-py3_9-cpu-s390x build_environment: linux-s390x-binary-manywheel @@ -154,7 +154,7 @@ jobs: # favor of GPU_ARCH_VERSION DESIRED_CUDA: cpu GPU_ARCH_TYPE: cpu-s390x - DOCKER_IMAGE: pytorch/manylinuxs390x-builder:cpu-s390x-2.4 + DOCKER_IMAGE: pytorch/manylinuxs390x-builder:cpu-s390x-main DESIRED_PYTHON: "3.9" build_name: manywheel-py3_9-cpu-s390x secrets: @@ -174,7 +174,7 @@ jobs: # favor of GPU_ARCH_VERSION DESIRED_CUDA: cpu GPU_ARCH_TYPE: cpu-s390x - DOCKER_IMAGE: pytorch/manylinuxs390x-builder:cpu-s390x-2.4 + DOCKER_IMAGE: pytorch/manylinuxs390x-builder:cpu-s390x-main DESIRED_PYTHON: "3.10" runs_on: linux.s390x ALPINE_IMAGE: "docker.io/s390x/alpine" @@ -195,7 +195,7 @@ jobs: # favor of GPU_ARCH_VERSION DESIRED_CUDA: cpu GPU_ARCH_TYPE: cpu-s390x - DOCKER_IMAGE: pytorch/manylinuxs390x-builder:cpu-s390x-2.4 + DOCKER_IMAGE: pytorch/manylinuxs390x-builder:cpu-s390x-main DESIRED_PYTHON: "3.10" build_name: manywheel-py3_10-cpu-s390x build_environment: linux-s390x-binary-manywheel @@ -217,7 +217,7 @@ jobs: # favor of GPU_ARCH_VERSION DESIRED_CUDA: cpu GPU_ARCH_TYPE: cpu-s390x - DOCKER_IMAGE: pytorch/manylinuxs390x-builder:cpu-s390x-2.4 + DOCKER_IMAGE: pytorch/manylinuxs390x-builder:cpu-s390x-main DESIRED_PYTHON: "3.10" build_name: manywheel-py3_10-cpu-s390x secrets: @@ -237,7 +237,7 @@ jobs: # favor of GPU_ARCH_VERSION DESIRED_CUDA: cpu GPU_ARCH_TYPE: cpu-s390x - DOCKER_IMAGE: pytorch/manylinuxs390x-builder:cpu-s390x-2.4 + DOCKER_IMAGE: pytorch/manylinuxs390x-builder:cpu-s390x-main DESIRED_PYTHON: "3.11" runs_on: linux.s390x ALPINE_IMAGE: "docker.io/s390x/alpine" @@ -258,7 +258,7 @@ jobs: # favor of GPU_ARCH_VERSION DESIRED_CUDA: cpu GPU_ARCH_TYPE: cpu-s390x - DOCKER_IMAGE: pytorch/manylinuxs390x-builder:cpu-s390x-2.4 + DOCKER_IMAGE: pytorch/manylinuxs390x-builder:cpu-s390x-main DESIRED_PYTHON: "3.11" build_name: manywheel-py3_11-cpu-s390x build_environment: linux-s390x-binary-manywheel @@ -280,7 +280,7 @@ jobs: # favor of GPU_ARCH_VERSION DESIRED_CUDA: cpu GPU_ARCH_TYPE: cpu-s390x - DOCKER_IMAGE: pytorch/manylinuxs390x-builder:cpu-s390x-2.4 + DOCKER_IMAGE: pytorch/manylinuxs390x-builder:cpu-s390x-main DESIRED_PYTHON: "3.11" build_name: manywheel-py3_11-cpu-s390x secrets: @@ -300,7 +300,7 @@ jobs: # favor of GPU_ARCH_VERSION DESIRED_CUDA: cpu GPU_ARCH_TYPE: cpu-s390x - DOCKER_IMAGE: pytorch/manylinuxs390x-builder:cpu-s390x-2.4 + DOCKER_IMAGE: pytorch/manylinuxs390x-builder:cpu-s390x-main DESIRED_PYTHON: "3.12" runs_on: linux.s390x ALPINE_IMAGE: "docker.io/s390x/alpine" @@ -321,7 +321,7 @@ jobs: # favor of GPU_ARCH_VERSION DESIRED_CUDA: cpu GPU_ARCH_TYPE: cpu-s390x - DOCKER_IMAGE: pytorch/manylinuxs390x-builder:cpu-s390x-2.4 + DOCKER_IMAGE: pytorch/manylinuxs390x-builder:cpu-s390x-main DESIRED_PYTHON: "3.12" build_name: manywheel-py3_12-cpu-s390x build_environment: linux-s390x-binary-manywheel @@ -343,7 +343,7 @@ jobs: # favor of GPU_ARCH_VERSION DESIRED_CUDA: cpu GPU_ARCH_TYPE: cpu-s390x - DOCKER_IMAGE: pytorch/manylinuxs390x-builder:cpu-s390x-2.4 + DOCKER_IMAGE: pytorch/manylinuxs390x-builder:cpu-s390x-main DESIRED_PYTHON: "3.12" build_name: manywheel-py3_12-cpu-s390x secrets: diff --git a/.github/workflows/generated-macos-arm64-binary-conda-nightly.yml b/.github/workflows/generated-macos-arm64-binary-conda-nightly.yml index bb6be1f4073e6f..52ccb92a1935b1 100644 --- a/.github/workflows/generated-macos-arm64-binary-conda-nightly.yml +++ b/.github/workflows/generated-macos-arm64-binary-conda-nightly.yml @@ -77,6 +77,7 @@ jobs: - name: Checkout PyTorch uses: malfet/checkout@silent-checkout with: + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} submodules: recursive path: pytorch quiet-checkout: true @@ -88,7 +89,7 @@ jobs: - name: Checkout pytorch/builder uses: malfet/checkout@silent-checkout with: - ref: release/2.4 + ref: main submodules: recursive repository: pytorch/builder path: builder @@ -140,7 +141,7 @@ jobs: # favor of GPU_ARCH_VERSION DESIRED_CUDA: cpu GPU_ARCH_TYPE: cpu - DOCKER_IMAGE: pytorch/conda-builder:cpu-2.4 + DOCKER_IMAGE: pytorch/conda-builder:cpu-main DESIRED_PYTHON: "3.8" build_name: conda-py3_8-cpu use_s3: False @@ -194,6 +195,7 @@ jobs: - name: Checkout PyTorch uses: malfet/checkout@silent-checkout with: + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} submodules: recursive path: pytorch quiet-checkout: true @@ -205,7 +207,7 @@ jobs: - name: Checkout pytorch/builder uses: malfet/checkout@silent-checkout with: - ref: release/2.4 + ref: main submodules: recursive repository: pytorch/builder path: builder @@ -257,7 +259,7 @@ jobs: # favor of GPU_ARCH_VERSION DESIRED_CUDA: cpu GPU_ARCH_TYPE: cpu - DOCKER_IMAGE: pytorch/conda-builder:cpu-2.4 + DOCKER_IMAGE: pytorch/conda-builder:cpu-main DESIRED_PYTHON: "3.9" build_name: conda-py3_9-cpu use_s3: False @@ -311,6 +313,7 @@ jobs: - name: Checkout PyTorch uses: malfet/checkout@silent-checkout with: + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} submodules: recursive path: pytorch quiet-checkout: true @@ -322,7 +325,7 @@ jobs: - name: Checkout pytorch/builder uses: malfet/checkout@silent-checkout with: - ref: release/2.4 + ref: main submodules: recursive repository: pytorch/builder path: builder @@ -374,7 +377,7 @@ jobs: # favor of GPU_ARCH_VERSION DESIRED_CUDA: cpu GPU_ARCH_TYPE: cpu - DOCKER_IMAGE: pytorch/conda-builder:cpu-2.4 + DOCKER_IMAGE: pytorch/conda-builder:cpu-main DESIRED_PYTHON: "3.10" build_name: conda-py3_10-cpu use_s3: False @@ -428,6 +431,7 @@ jobs: - name: Checkout PyTorch uses: malfet/checkout@silent-checkout with: + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} submodules: recursive path: pytorch quiet-checkout: true @@ -439,7 +443,7 @@ jobs: - name: Checkout pytorch/builder uses: malfet/checkout@silent-checkout with: - ref: release/2.4 + ref: main submodules: recursive repository: pytorch/builder path: builder @@ -491,7 +495,7 @@ jobs: # favor of GPU_ARCH_VERSION DESIRED_CUDA: cpu GPU_ARCH_TYPE: cpu - DOCKER_IMAGE: pytorch/conda-builder:cpu-2.4 + DOCKER_IMAGE: pytorch/conda-builder:cpu-main DESIRED_PYTHON: "3.11" build_name: conda-py3_11-cpu use_s3: False @@ -545,6 +549,7 @@ jobs: - name: Checkout PyTorch uses: malfet/checkout@silent-checkout with: + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} submodules: recursive path: pytorch quiet-checkout: true @@ -556,7 +561,7 @@ jobs: - name: Checkout pytorch/builder uses: malfet/checkout@silent-checkout with: - ref: release/2.4 + ref: main submodules: recursive repository: pytorch/builder path: builder @@ -608,7 +613,7 @@ jobs: # favor of GPU_ARCH_VERSION DESIRED_CUDA: cpu GPU_ARCH_TYPE: cpu - DOCKER_IMAGE: pytorch/conda-builder:cpu-2.4 + DOCKER_IMAGE: pytorch/conda-builder:cpu-main DESIRED_PYTHON: "3.12" build_name: conda-py3_12-cpu use_s3: False diff --git a/.github/workflows/generated-macos-arm64-binary-libtorch-cxx11-abi-nightly.yml b/.github/workflows/generated-macos-arm64-binary-libtorch-cxx11-abi-nightly.yml index 85f0e6133532cd..7e2e345aefbcfb 100644 --- a/.github/workflows/generated-macos-arm64-binary-libtorch-cxx11-abi-nightly.yml +++ b/.github/workflows/generated-macos-arm64-binary-libtorch-cxx11-abi-nightly.yml @@ -81,6 +81,7 @@ jobs: - name: Checkout PyTorch uses: malfet/checkout@silent-checkout with: + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} submodules: recursive path: pytorch quiet-checkout: true @@ -92,7 +93,7 @@ jobs: - name: Checkout pytorch/builder uses: malfet/checkout@silent-checkout with: - ref: release/2.4 + ref: main submodules: recursive repository: pytorch/builder path: builder @@ -144,7 +145,7 @@ jobs: # favor of GPU_ARCH_VERSION DESIRED_CUDA: cpu GPU_ARCH_TYPE: cpu - DOCKER_IMAGE: pytorch/libtorch-cxx11-builder:cpu-2.4 + DOCKER_IMAGE: pytorch/libtorch-cxx11-builder:cpu-main LIBTORCH_VARIANT: shared-with-deps DESIRED_DEVTOOLSET: cxx11-abi build_name: libtorch-cpu-shared-with-deps-cxx11-abi diff --git a/.github/workflows/generated-macos-arm64-binary-wheel-nightly.yml b/.github/workflows/generated-macos-arm64-binary-wheel-nightly.yml index 2705d95e1f5852..b4910d46ed5e82 100644 --- a/.github/workflows/generated-macos-arm64-binary-wheel-nightly.yml +++ b/.github/workflows/generated-macos-arm64-binary-wheel-nightly.yml @@ -78,6 +78,7 @@ jobs: - name: Checkout PyTorch uses: malfet/checkout@silent-checkout with: + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} submodules: recursive path: pytorch quiet-checkout: true @@ -89,7 +90,7 @@ jobs: - name: Checkout pytorch/builder uses: malfet/checkout@silent-checkout with: - ref: release/2.4 + ref: main submodules: recursive repository: pytorch/builder path: builder @@ -141,7 +142,7 @@ jobs: # favor of GPU_ARCH_VERSION DESIRED_CUDA: cpu GPU_ARCH_TYPE: cpu - DOCKER_IMAGE: pytorch/manylinux-builder:cpu-2.4 + DOCKER_IMAGE: pytorch/manylinux-builder:cpu-main DESIRED_PYTHON: "3.8" build_name: wheel-py3_8-cpu use_s3: False @@ -196,6 +197,7 @@ jobs: - name: Checkout PyTorch uses: malfet/checkout@silent-checkout with: + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} submodules: recursive path: pytorch quiet-checkout: true @@ -207,7 +209,7 @@ jobs: - name: Checkout pytorch/builder uses: malfet/checkout@silent-checkout with: - ref: release/2.4 + ref: main submodules: recursive repository: pytorch/builder path: builder @@ -259,7 +261,7 @@ jobs: # favor of GPU_ARCH_VERSION DESIRED_CUDA: cpu GPU_ARCH_TYPE: cpu - DOCKER_IMAGE: pytorch/manylinux-builder:cpu-2.4 + DOCKER_IMAGE: pytorch/manylinux-builder:cpu-main DESIRED_PYTHON: "3.9" build_name: wheel-py3_9-cpu use_s3: False @@ -314,6 +316,7 @@ jobs: - name: Checkout PyTorch uses: malfet/checkout@silent-checkout with: + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} submodules: recursive path: pytorch quiet-checkout: true @@ -325,7 +328,7 @@ jobs: - name: Checkout pytorch/builder uses: malfet/checkout@silent-checkout with: - ref: release/2.4 + ref: main submodules: recursive repository: pytorch/builder path: builder @@ -377,7 +380,7 @@ jobs: # favor of GPU_ARCH_VERSION DESIRED_CUDA: cpu GPU_ARCH_TYPE: cpu - DOCKER_IMAGE: pytorch/manylinux-builder:cpu-2.4 + DOCKER_IMAGE: pytorch/manylinux-builder:cpu-main DESIRED_PYTHON: "3.10" build_name: wheel-py3_10-cpu use_s3: False @@ -432,6 +435,7 @@ jobs: - name: Checkout PyTorch uses: malfet/checkout@silent-checkout with: + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} submodules: recursive path: pytorch quiet-checkout: true @@ -443,7 +447,7 @@ jobs: - name: Checkout pytorch/builder uses: malfet/checkout@silent-checkout with: - ref: release/2.4 + ref: main submodules: recursive repository: pytorch/builder path: builder @@ -495,7 +499,7 @@ jobs: # favor of GPU_ARCH_VERSION DESIRED_CUDA: cpu GPU_ARCH_TYPE: cpu - DOCKER_IMAGE: pytorch/manylinux-builder:cpu-2.4 + DOCKER_IMAGE: pytorch/manylinux-builder:cpu-main DESIRED_PYTHON: "3.11" build_name: wheel-py3_11-cpu use_s3: False @@ -550,6 +554,7 @@ jobs: - name: Checkout PyTorch uses: malfet/checkout@silent-checkout with: + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} submodules: recursive path: pytorch quiet-checkout: true @@ -561,7 +566,7 @@ jobs: - name: Checkout pytorch/builder uses: malfet/checkout@silent-checkout with: - ref: release/2.4 + ref: main submodules: recursive repository: pytorch/builder path: builder @@ -613,7 +618,7 @@ jobs: # favor of GPU_ARCH_VERSION DESIRED_CUDA: cpu GPU_ARCH_TYPE: cpu - DOCKER_IMAGE: pytorch/manylinux-builder:cpu-2.4 + DOCKER_IMAGE: pytorch/manylinux-builder:cpu-main DESIRED_PYTHON: "3.12" build_name: wheel-py3_12-cpu use_s3: False diff --git a/.github/workflows/generated-windows-binary-conda-nightly.yml b/.github/workflows/generated-windows-binary-conda-nightly.yml index 012a7d3c323e01..c3e4a038896e74 100644 --- a/.github/workflows/generated-windows-binary-conda-nightly.yml +++ b/.github/workflows/generated-windows-binary-conda-nightly.yml @@ -93,6 +93,7 @@ jobs: - name: Checkout PyTorch uses: malfet/checkout@silent-checkout with: + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} submodules: recursive path: pytorch quiet-checkout: true @@ -104,7 +105,7 @@ jobs: - name: Checkout pytorch/builder uses: malfet/checkout@silent-checkout with: - ref: release/2.4 + ref: main submodules: recursive repository: pytorch/builder path: builder @@ -209,6 +210,7 @@ jobs: - name: Checkout PyTorch uses: malfet/checkout@silent-checkout with: + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} submodules: recursive path: pytorch quiet-checkout: true @@ -220,7 +222,7 @@ jobs: - name: Checkout pytorch/builder uses: malfet/checkout@silent-checkout with: - ref: release/2.4 + ref: main submodules: recursive repository: pytorch/builder path: builder @@ -334,6 +336,7 @@ jobs: - name: Checkout PyTorch uses: malfet/checkout@silent-checkout with: + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} submodules: recursive path: pytorch quiet-checkout: true @@ -345,7 +348,7 @@ jobs: - name: Checkout pytorch/builder uses: malfet/checkout@silent-checkout with: - ref: release/2.4 + ref: main submodules: recursive repository: pytorch/builder path: builder @@ -451,6 +454,7 @@ jobs: - name: Checkout PyTorch uses: malfet/checkout@silent-checkout with: + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} submodules: recursive path: pytorch quiet-checkout: true @@ -462,7 +466,7 @@ jobs: - name: Checkout pytorch/builder uses: malfet/checkout@silent-checkout with: - ref: release/2.4 + ref: main submodules: recursive repository: pytorch/builder path: builder @@ -577,6 +581,7 @@ jobs: - name: Checkout PyTorch uses: malfet/checkout@silent-checkout with: + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} submodules: recursive path: pytorch quiet-checkout: true @@ -588,7 +593,7 @@ jobs: - name: Checkout pytorch/builder uses: malfet/checkout@silent-checkout with: - ref: release/2.4 + ref: main submodules: recursive repository: pytorch/builder path: builder @@ -694,6 +699,7 @@ jobs: - name: Checkout PyTorch uses: malfet/checkout@silent-checkout with: + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} submodules: recursive path: pytorch quiet-checkout: true @@ -705,7 +711,7 @@ jobs: - name: Checkout pytorch/builder uses: malfet/checkout@silent-checkout with: - ref: release/2.4 + ref: main submodules: recursive repository: pytorch/builder path: builder @@ -820,6 +826,7 @@ jobs: - name: Checkout PyTorch uses: malfet/checkout@silent-checkout with: + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} submodules: recursive path: pytorch quiet-checkout: true @@ -831,7 +838,7 @@ jobs: - name: Checkout pytorch/builder uses: malfet/checkout@silent-checkout with: - ref: release/2.4 + ref: main submodules: recursive repository: pytorch/builder path: builder @@ -937,6 +944,7 @@ jobs: - name: Checkout PyTorch uses: malfet/checkout@silent-checkout with: + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} submodules: recursive path: pytorch quiet-checkout: true @@ -948,7 +956,7 @@ jobs: - name: Checkout pytorch/builder uses: malfet/checkout@silent-checkout with: - ref: release/2.4 + ref: main submodules: recursive repository: pytorch/builder path: builder @@ -1062,6 +1070,7 @@ jobs: - name: Checkout PyTorch uses: malfet/checkout@silent-checkout with: + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} submodules: recursive path: pytorch quiet-checkout: true @@ -1073,7 +1082,7 @@ jobs: - name: Checkout pytorch/builder uses: malfet/checkout@silent-checkout with: - ref: release/2.4 + ref: main submodules: recursive repository: pytorch/builder path: builder @@ -1178,6 +1187,7 @@ jobs: - name: Checkout PyTorch uses: malfet/checkout@silent-checkout with: + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} submodules: recursive path: pytorch quiet-checkout: true @@ -1189,7 +1199,7 @@ jobs: - name: Checkout pytorch/builder uses: malfet/checkout@silent-checkout with: - ref: release/2.4 + ref: main submodules: recursive repository: pytorch/builder path: builder @@ -1303,6 +1313,7 @@ jobs: - name: Checkout PyTorch uses: malfet/checkout@silent-checkout with: + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} submodules: recursive path: pytorch quiet-checkout: true @@ -1314,7 +1325,7 @@ jobs: - name: Checkout pytorch/builder uses: malfet/checkout@silent-checkout with: - ref: release/2.4 + ref: main submodules: recursive repository: pytorch/builder path: builder @@ -1420,6 +1431,7 @@ jobs: - name: Checkout PyTorch uses: malfet/checkout@silent-checkout with: + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} submodules: recursive path: pytorch quiet-checkout: true @@ -1431,7 +1443,7 @@ jobs: - name: Checkout pytorch/builder uses: malfet/checkout@silent-checkout with: - ref: release/2.4 + ref: main submodules: recursive repository: pytorch/builder path: builder @@ -1546,6 +1558,7 @@ jobs: - name: Checkout PyTorch uses: malfet/checkout@silent-checkout with: + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} submodules: recursive path: pytorch quiet-checkout: true @@ -1557,7 +1570,7 @@ jobs: - name: Checkout pytorch/builder uses: malfet/checkout@silent-checkout with: - ref: release/2.4 + ref: main submodules: recursive repository: pytorch/builder path: builder @@ -1663,6 +1676,7 @@ jobs: - name: Checkout PyTorch uses: malfet/checkout@silent-checkout with: + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} submodules: recursive path: pytorch quiet-checkout: true @@ -1674,7 +1688,7 @@ jobs: - name: Checkout pytorch/builder uses: malfet/checkout@silent-checkout with: - ref: release/2.4 + ref: main submodules: recursive repository: pytorch/builder path: builder @@ -1789,6 +1803,7 @@ jobs: - name: Checkout PyTorch uses: malfet/checkout@silent-checkout with: + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} submodules: recursive path: pytorch quiet-checkout: true @@ -1800,7 +1815,7 @@ jobs: - name: Checkout pytorch/builder uses: malfet/checkout@silent-checkout with: - ref: release/2.4 + ref: main submodules: recursive repository: pytorch/builder path: builder @@ -1906,6 +1921,7 @@ jobs: - name: Checkout PyTorch uses: malfet/checkout@silent-checkout with: + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} submodules: recursive path: pytorch quiet-checkout: true @@ -1917,7 +1933,7 @@ jobs: - name: Checkout pytorch/builder uses: malfet/checkout@silent-checkout with: - ref: release/2.4 + ref: main submodules: recursive repository: pytorch/builder path: builder @@ -2031,6 +2047,7 @@ jobs: - name: Checkout PyTorch uses: malfet/checkout@silent-checkout with: + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} submodules: recursive path: pytorch quiet-checkout: true @@ -2042,7 +2059,7 @@ jobs: - name: Checkout pytorch/builder uses: malfet/checkout@silent-checkout with: - ref: release/2.4 + ref: main submodules: recursive repository: pytorch/builder path: builder @@ -2147,6 +2164,7 @@ jobs: - name: Checkout PyTorch uses: malfet/checkout@silent-checkout with: + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} submodules: recursive path: pytorch quiet-checkout: true @@ -2158,7 +2176,7 @@ jobs: - name: Checkout pytorch/builder uses: malfet/checkout@silent-checkout with: - ref: release/2.4 + ref: main submodules: recursive repository: pytorch/builder path: builder @@ -2272,6 +2290,7 @@ jobs: - name: Checkout PyTorch uses: malfet/checkout@silent-checkout with: + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} submodules: recursive path: pytorch quiet-checkout: true @@ -2283,7 +2302,7 @@ jobs: - name: Checkout pytorch/builder uses: malfet/checkout@silent-checkout with: - ref: release/2.4 + ref: main submodules: recursive repository: pytorch/builder path: builder @@ -2389,6 +2408,7 @@ jobs: - name: Checkout PyTorch uses: malfet/checkout@silent-checkout with: + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} submodules: recursive path: pytorch quiet-checkout: true @@ -2400,7 +2420,7 @@ jobs: - name: Checkout pytorch/builder uses: malfet/checkout@silent-checkout with: - ref: release/2.4 + ref: main submodules: recursive repository: pytorch/builder path: builder @@ -2515,6 +2535,7 @@ jobs: - name: Checkout PyTorch uses: malfet/checkout@silent-checkout with: + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} submodules: recursive path: pytorch quiet-checkout: true @@ -2526,7 +2547,7 @@ jobs: - name: Checkout pytorch/builder uses: malfet/checkout@silent-checkout with: - ref: release/2.4 + ref: main submodules: recursive repository: pytorch/builder path: builder @@ -2632,6 +2653,7 @@ jobs: - name: Checkout PyTorch uses: malfet/checkout@silent-checkout with: + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} submodules: recursive path: pytorch quiet-checkout: true @@ -2643,7 +2665,7 @@ jobs: - name: Checkout pytorch/builder uses: malfet/checkout@silent-checkout with: - ref: release/2.4 + ref: main submodules: recursive repository: pytorch/builder path: builder @@ -2758,6 +2780,7 @@ jobs: - name: Checkout PyTorch uses: malfet/checkout@silent-checkout with: + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} submodules: recursive path: pytorch quiet-checkout: true @@ -2769,7 +2792,7 @@ jobs: - name: Checkout pytorch/builder uses: malfet/checkout@silent-checkout with: - ref: release/2.4 + ref: main submodules: recursive repository: pytorch/builder path: builder @@ -2875,6 +2898,7 @@ jobs: - name: Checkout PyTorch uses: malfet/checkout@silent-checkout with: + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} submodules: recursive path: pytorch quiet-checkout: true @@ -2886,7 +2910,7 @@ jobs: - name: Checkout pytorch/builder uses: malfet/checkout@silent-checkout with: - ref: release/2.4 + ref: main submodules: recursive repository: pytorch/builder path: builder @@ -3000,6 +3024,7 @@ jobs: - name: Checkout PyTorch uses: malfet/checkout@silent-checkout with: + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} submodules: recursive path: pytorch quiet-checkout: true @@ -3011,7 +3036,7 @@ jobs: - name: Checkout pytorch/builder uses: malfet/checkout@silent-checkout with: - ref: release/2.4 + ref: main submodules: recursive repository: pytorch/builder path: builder @@ -3116,6 +3141,7 @@ jobs: - name: Checkout PyTorch uses: malfet/checkout@silent-checkout with: + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} submodules: recursive path: pytorch quiet-checkout: true @@ -3127,7 +3153,7 @@ jobs: - name: Checkout pytorch/builder uses: malfet/checkout@silent-checkout with: - ref: release/2.4 + ref: main submodules: recursive repository: pytorch/builder path: builder @@ -3241,6 +3267,7 @@ jobs: - name: Checkout PyTorch uses: malfet/checkout@silent-checkout with: + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} submodules: recursive path: pytorch quiet-checkout: true @@ -3252,7 +3279,7 @@ jobs: - name: Checkout pytorch/builder uses: malfet/checkout@silent-checkout with: - ref: release/2.4 + ref: main submodules: recursive repository: pytorch/builder path: builder @@ -3358,6 +3385,7 @@ jobs: - name: Checkout PyTorch uses: malfet/checkout@silent-checkout with: + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} submodules: recursive path: pytorch quiet-checkout: true @@ -3369,7 +3397,7 @@ jobs: - name: Checkout pytorch/builder uses: malfet/checkout@silent-checkout with: - ref: release/2.4 + ref: main submodules: recursive repository: pytorch/builder path: builder @@ -3484,6 +3512,7 @@ jobs: - name: Checkout PyTorch uses: malfet/checkout@silent-checkout with: + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} submodules: recursive path: pytorch quiet-checkout: true @@ -3495,7 +3524,7 @@ jobs: - name: Checkout pytorch/builder uses: malfet/checkout@silent-checkout with: - ref: release/2.4 + ref: main submodules: recursive repository: pytorch/builder path: builder @@ -3601,6 +3630,7 @@ jobs: - name: Checkout PyTorch uses: malfet/checkout@silent-checkout with: + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} submodules: recursive path: pytorch quiet-checkout: true @@ -3612,7 +3642,7 @@ jobs: - name: Checkout pytorch/builder uses: malfet/checkout@silent-checkout with: - ref: release/2.4 + ref: main submodules: recursive repository: pytorch/builder path: builder @@ -3727,6 +3757,7 @@ jobs: - name: Checkout PyTorch uses: malfet/checkout@silent-checkout with: + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} submodules: recursive path: pytorch quiet-checkout: true @@ -3738,7 +3769,7 @@ jobs: - name: Checkout pytorch/builder uses: malfet/checkout@silent-checkout with: - ref: release/2.4 + ref: main submodules: recursive repository: pytorch/builder path: builder @@ -3844,6 +3875,7 @@ jobs: - name: Checkout PyTorch uses: malfet/checkout@silent-checkout with: + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} submodules: recursive path: pytorch quiet-checkout: true @@ -3855,7 +3887,7 @@ jobs: - name: Checkout pytorch/builder uses: malfet/checkout@silent-checkout with: - ref: release/2.4 + ref: main submodules: recursive repository: pytorch/builder path: builder @@ -3969,6 +4001,7 @@ jobs: - name: Checkout PyTorch uses: malfet/checkout@silent-checkout with: + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} submodules: recursive path: pytorch quiet-checkout: true @@ -3980,7 +4013,7 @@ jobs: - name: Checkout pytorch/builder uses: malfet/checkout@silent-checkout with: - ref: release/2.4 + ref: main submodules: recursive repository: pytorch/builder path: builder @@ -4085,6 +4118,7 @@ jobs: - name: Checkout PyTorch uses: malfet/checkout@silent-checkout with: + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} submodules: recursive path: pytorch quiet-checkout: true @@ -4096,7 +4130,7 @@ jobs: - name: Checkout pytorch/builder uses: malfet/checkout@silent-checkout with: - ref: release/2.4 + ref: main submodules: recursive repository: pytorch/builder path: builder @@ -4210,6 +4244,7 @@ jobs: - name: Checkout PyTorch uses: malfet/checkout@silent-checkout with: + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} submodules: recursive path: pytorch quiet-checkout: true @@ -4221,7 +4256,7 @@ jobs: - name: Checkout pytorch/builder uses: malfet/checkout@silent-checkout with: - ref: release/2.4 + ref: main submodules: recursive repository: pytorch/builder path: builder @@ -4327,6 +4362,7 @@ jobs: - name: Checkout PyTorch uses: malfet/checkout@silent-checkout with: + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} submodules: recursive path: pytorch quiet-checkout: true @@ -4338,7 +4374,7 @@ jobs: - name: Checkout pytorch/builder uses: malfet/checkout@silent-checkout with: - ref: release/2.4 + ref: main submodules: recursive repository: pytorch/builder path: builder @@ -4453,6 +4489,7 @@ jobs: - name: Checkout PyTorch uses: malfet/checkout@silent-checkout with: + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} submodules: recursive path: pytorch quiet-checkout: true @@ -4464,7 +4501,7 @@ jobs: - name: Checkout pytorch/builder uses: malfet/checkout@silent-checkout with: - ref: release/2.4 + ref: main submodules: recursive repository: pytorch/builder path: builder @@ -4570,6 +4607,7 @@ jobs: - name: Checkout PyTorch uses: malfet/checkout@silent-checkout with: + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} submodules: recursive path: pytorch quiet-checkout: true @@ -4581,7 +4619,7 @@ jobs: - name: Checkout pytorch/builder uses: malfet/checkout@silent-checkout with: - ref: release/2.4 + ref: main submodules: recursive repository: pytorch/builder path: builder @@ -4696,6 +4734,7 @@ jobs: - name: Checkout PyTorch uses: malfet/checkout@silent-checkout with: + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} submodules: recursive path: pytorch quiet-checkout: true @@ -4707,7 +4746,7 @@ jobs: - name: Checkout pytorch/builder uses: malfet/checkout@silent-checkout with: - ref: release/2.4 + ref: main submodules: recursive repository: pytorch/builder path: builder @@ -4813,6 +4852,7 @@ jobs: - name: Checkout PyTorch uses: malfet/checkout@silent-checkout with: + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} submodules: recursive path: pytorch quiet-checkout: true @@ -4824,7 +4864,7 @@ jobs: - name: Checkout pytorch/builder uses: malfet/checkout@silent-checkout with: - ref: release/2.4 + ref: main submodules: recursive repository: pytorch/builder path: builder diff --git a/.github/workflows/generated-windows-binary-libtorch-debug-main.yml b/.github/workflows/generated-windows-binary-libtorch-debug-main.yml index d8beef22c28ae9..8ac413be0d65ea 100644 --- a/.github/workflows/generated-windows-binary-libtorch-debug-main.yml +++ b/.github/workflows/generated-windows-binary-libtorch-debug-main.yml @@ -90,6 +90,7 @@ jobs: - name: Checkout PyTorch uses: malfet/checkout@silent-checkout with: + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} submodules: recursive path: pytorch quiet-checkout: true @@ -101,7 +102,7 @@ jobs: - name: Checkout pytorch/builder uses: malfet/checkout@silent-checkout with: - ref: release/2.4 + ref: main submodules: recursive repository: pytorch/builder path: builder @@ -210,6 +211,7 @@ jobs: - name: Checkout PyTorch uses: malfet/checkout@silent-checkout with: + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} submodules: recursive path: pytorch quiet-checkout: true @@ -221,7 +223,7 @@ jobs: - name: Checkout pytorch/builder uses: malfet/checkout@silent-checkout with: - ref: release/2.4 + ref: main submodules: recursive repository: pytorch/builder path: builder diff --git a/.github/workflows/generated-windows-binary-libtorch-debug-nightly.yml b/.github/workflows/generated-windows-binary-libtorch-debug-nightly.yml index 26cf80f685c3a9..60ba59556926f2 100644 --- a/.github/workflows/generated-windows-binary-libtorch-debug-nightly.yml +++ b/.github/workflows/generated-windows-binary-libtorch-debug-nightly.yml @@ -97,6 +97,7 @@ jobs: - name: Checkout PyTorch uses: malfet/checkout@silent-checkout with: + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} submodules: recursive path: pytorch quiet-checkout: true @@ -108,7 +109,7 @@ jobs: - name: Checkout pytorch/builder uses: malfet/checkout@silent-checkout with: - ref: release/2.4 + ref: main submodules: recursive repository: pytorch/builder path: builder @@ -217,6 +218,7 @@ jobs: - name: Checkout PyTorch uses: malfet/checkout@silent-checkout with: + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} submodules: recursive path: pytorch quiet-checkout: true @@ -228,7 +230,7 @@ jobs: - name: Checkout pytorch/builder uses: malfet/checkout@silent-checkout with: - ref: release/2.4 + ref: main submodules: recursive repository: pytorch/builder path: builder @@ -350,6 +352,7 @@ jobs: - name: Checkout PyTorch uses: malfet/checkout@silent-checkout with: + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} submodules: recursive path: pytorch quiet-checkout: true @@ -361,7 +364,7 @@ jobs: - name: Checkout pytorch/builder uses: malfet/checkout@silent-checkout with: - ref: release/2.4 + ref: main submodules: recursive repository: pytorch/builder path: builder @@ -471,6 +474,7 @@ jobs: - name: Checkout PyTorch uses: malfet/checkout@silent-checkout with: + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} submodules: recursive path: pytorch quiet-checkout: true @@ -482,7 +486,7 @@ jobs: - name: Checkout pytorch/builder uses: malfet/checkout@silent-checkout with: - ref: release/2.4 + ref: main submodules: recursive repository: pytorch/builder path: builder @@ -605,6 +609,7 @@ jobs: - name: Checkout PyTorch uses: malfet/checkout@silent-checkout with: + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} submodules: recursive path: pytorch quiet-checkout: true @@ -616,7 +621,7 @@ jobs: - name: Checkout pytorch/builder uses: malfet/checkout@silent-checkout with: - ref: release/2.4 + ref: main submodules: recursive repository: pytorch/builder path: builder @@ -726,6 +731,7 @@ jobs: - name: Checkout PyTorch uses: malfet/checkout@silent-checkout with: + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} submodules: recursive path: pytorch quiet-checkout: true @@ -737,7 +743,7 @@ jobs: - name: Checkout pytorch/builder uses: malfet/checkout@silent-checkout with: - ref: release/2.4 + ref: main submodules: recursive repository: pytorch/builder path: builder @@ -860,6 +866,7 @@ jobs: - name: Checkout PyTorch uses: malfet/checkout@silent-checkout with: + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} submodules: recursive path: pytorch quiet-checkout: true @@ -871,7 +878,7 @@ jobs: - name: Checkout pytorch/builder uses: malfet/checkout@silent-checkout with: - ref: release/2.4 + ref: main submodules: recursive repository: pytorch/builder path: builder @@ -981,6 +988,7 @@ jobs: - name: Checkout PyTorch uses: malfet/checkout@silent-checkout with: + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} submodules: recursive path: pytorch quiet-checkout: true @@ -992,7 +1000,7 @@ jobs: - name: Checkout pytorch/builder uses: malfet/checkout@silent-checkout with: - ref: release/2.4 + ref: main submodules: recursive repository: pytorch/builder path: builder diff --git a/.github/workflows/generated-windows-binary-libtorch-release-main.yml b/.github/workflows/generated-windows-binary-libtorch-release-main.yml index 7624b6c69e9ec0..ab00cdc8919ea9 100644 --- a/.github/workflows/generated-windows-binary-libtorch-release-main.yml +++ b/.github/workflows/generated-windows-binary-libtorch-release-main.yml @@ -90,6 +90,7 @@ jobs: - name: Checkout PyTorch uses: malfet/checkout@silent-checkout with: + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} submodules: recursive path: pytorch quiet-checkout: true @@ -101,7 +102,7 @@ jobs: - name: Checkout pytorch/builder uses: malfet/checkout@silent-checkout with: - ref: release/2.4 + ref: main submodules: recursive repository: pytorch/builder path: builder @@ -210,6 +211,7 @@ jobs: - name: Checkout PyTorch uses: malfet/checkout@silent-checkout with: + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} submodules: recursive path: pytorch quiet-checkout: true @@ -221,7 +223,7 @@ jobs: - name: Checkout pytorch/builder uses: malfet/checkout@silent-checkout with: - ref: release/2.4 + ref: main submodules: recursive repository: pytorch/builder path: builder diff --git a/.github/workflows/generated-windows-binary-libtorch-release-nightly.yml b/.github/workflows/generated-windows-binary-libtorch-release-nightly.yml index 877d49691bd793..842de97a1fbe99 100644 --- a/.github/workflows/generated-windows-binary-libtorch-release-nightly.yml +++ b/.github/workflows/generated-windows-binary-libtorch-release-nightly.yml @@ -97,6 +97,7 @@ jobs: - name: Checkout PyTorch uses: malfet/checkout@silent-checkout with: + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} submodules: recursive path: pytorch quiet-checkout: true @@ -108,7 +109,7 @@ jobs: - name: Checkout pytorch/builder uses: malfet/checkout@silent-checkout with: - ref: release/2.4 + ref: main submodules: recursive repository: pytorch/builder path: builder @@ -217,6 +218,7 @@ jobs: - name: Checkout PyTorch uses: malfet/checkout@silent-checkout with: + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} submodules: recursive path: pytorch quiet-checkout: true @@ -228,7 +230,7 @@ jobs: - name: Checkout pytorch/builder uses: malfet/checkout@silent-checkout with: - ref: release/2.4 + ref: main submodules: recursive repository: pytorch/builder path: builder @@ -350,6 +352,7 @@ jobs: - name: Checkout PyTorch uses: malfet/checkout@silent-checkout with: + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} submodules: recursive path: pytorch quiet-checkout: true @@ -361,7 +364,7 @@ jobs: - name: Checkout pytorch/builder uses: malfet/checkout@silent-checkout with: - ref: release/2.4 + ref: main submodules: recursive repository: pytorch/builder path: builder @@ -471,6 +474,7 @@ jobs: - name: Checkout PyTorch uses: malfet/checkout@silent-checkout with: + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} submodules: recursive path: pytorch quiet-checkout: true @@ -482,7 +486,7 @@ jobs: - name: Checkout pytorch/builder uses: malfet/checkout@silent-checkout with: - ref: release/2.4 + ref: main submodules: recursive repository: pytorch/builder path: builder @@ -605,6 +609,7 @@ jobs: - name: Checkout PyTorch uses: malfet/checkout@silent-checkout with: + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} submodules: recursive path: pytorch quiet-checkout: true @@ -616,7 +621,7 @@ jobs: - name: Checkout pytorch/builder uses: malfet/checkout@silent-checkout with: - ref: release/2.4 + ref: main submodules: recursive repository: pytorch/builder path: builder @@ -726,6 +731,7 @@ jobs: - name: Checkout PyTorch uses: malfet/checkout@silent-checkout with: + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} submodules: recursive path: pytorch quiet-checkout: true @@ -737,7 +743,7 @@ jobs: - name: Checkout pytorch/builder uses: malfet/checkout@silent-checkout with: - ref: release/2.4 + ref: main submodules: recursive repository: pytorch/builder path: builder @@ -860,6 +866,7 @@ jobs: - name: Checkout PyTorch uses: malfet/checkout@silent-checkout with: + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} submodules: recursive path: pytorch quiet-checkout: true @@ -871,7 +878,7 @@ jobs: - name: Checkout pytorch/builder uses: malfet/checkout@silent-checkout with: - ref: release/2.4 + ref: main submodules: recursive repository: pytorch/builder path: builder @@ -981,6 +988,7 @@ jobs: - name: Checkout PyTorch uses: malfet/checkout@silent-checkout with: + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} submodules: recursive path: pytorch quiet-checkout: true @@ -992,7 +1000,7 @@ jobs: - name: Checkout pytorch/builder uses: malfet/checkout@silent-checkout with: - ref: release/2.4 + ref: main submodules: recursive repository: pytorch/builder path: builder diff --git a/.github/workflows/generated-windows-binary-wheel-nightly.yml b/.github/workflows/generated-windows-binary-wheel-nightly.yml index 42738a0ddb75bb..d06f99bd9a5a96 100644 --- a/.github/workflows/generated-windows-binary-wheel-nightly.yml +++ b/.github/workflows/generated-windows-binary-wheel-nightly.yml @@ -94,6 +94,7 @@ jobs: - name: Checkout PyTorch uses: malfet/checkout@silent-checkout with: + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} submodules: recursive path: pytorch quiet-checkout: true @@ -105,7 +106,7 @@ jobs: - name: Checkout pytorch/builder uses: malfet/checkout@silent-checkout with: - ref: release/2.4 + ref: main submodules: recursive repository: pytorch/builder path: builder @@ -210,6 +211,7 @@ jobs: - name: Checkout PyTorch uses: malfet/checkout@silent-checkout with: + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} submodules: recursive path: pytorch quiet-checkout: true @@ -221,7 +223,7 @@ jobs: - name: Checkout pytorch/builder uses: malfet/checkout@silent-checkout with: - ref: release/2.4 + ref: main submodules: recursive repository: pytorch/builder path: builder @@ -336,6 +338,7 @@ jobs: - name: Checkout PyTorch uses: malfet/checkout@silent-checkout with: + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} submodules: recursive path: pytorch quiet-checkout: true @@ -347,7 +350,7 @@ jobs: - name: Checkout pytorch/builder uses: malfet/checkout@silent-checkout with: - ref: release/2.4 + ref: main submodules: recursive repository: pytorch/builder path: builder @@ -453,6 +456,7 @@ jobs: - name: Checkout PyTorch uses: malfet/checkout@silent-checkout with: + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} submodules: recursive path: pytorch quiet-checkout: true @@ -464,7 +468,7 @@ jobs: - name: Checkout pytorch/builder uses: malfet/checkout@silent-checkout with: - ref: release/2.4 + ref: main submodules: recursive repository: pytorch/builder path: builder @@ -580,6 +584,7 @@ jobs: - name: Checkout PyTorch uses: malfet/checkout@silent-checkout with: + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} submodules: recursive path: pytorch quiet-checkout: true @@ -591,7 +596,7 @@ jobs: - name: Checkout pytorch/builder uses: malfet/checkout@silent-checkout with: - ref: release/2.4 + ref: main submodules: recursive repository: pytorch/builder path: builder @@ -697,6 +702,7 @@ jobs: - name: Checkout PyTorch uses: malfet/checkout@silent-checkout with: + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} submodules: recursive path: pytorch quiet-checkout: true @@ -708,7 +714,7 @@ jobs: - name: Checkout pytorch/builder uses: malfet/checkout@silent-checkout with: - ref: release/2.4 + ref: main submodules: recursive repository: pytorch/builder path: builder @@ -824,6 +830,7 @@ jobs: - name: Checkout PyTorch uses: malfet/checkout@silent-checkout with: + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} submodules: recursive path: pytorch quiet-checkout: true @@ -835,7 +842,7 @@ jobs: - name: Checkout pytorch/builder uses: malfet/checkout@silent-checkout with: - ref: release/2.4 + ref: main submodules: recursive repository: pytorch/builder path: builder @@ -941,6 +948,7 @@ jobs: - name: Checkout PyTorch uses: malfet/checkout@silent-checkout with: + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} submodules: recursive path: pytorch quiet-checkout: true @@ -952,7 +960,7 @@ jobs: - name: Checkout pytorch/builder uses: malfet/checkout@silent-checkout with: - ref: release/2.4 + ref: main submodules: recursive repository: pytorch/builder path: builder @@ -1067,6 +1075,7 @@ jobs: - name: Checkout PyTorch uses: malfet/checkout@silent-checkout with: + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} submodules: recursive path: pytorch quiet-checkout: true @@ -1078,7 +1087,7 @@ jobs: - name: Checkout pytorch/builder uses: malfet/checkout@silent-checkout with: - ref: release/2.4 + ref: main submodules: recursive repository: pytorch/builder path: builder @@ -1183,6 +1192,7 @@ jobs: - name: Checkout PyTorch uses: malfet/checkout@silent-checkout with: + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} submodules: recursive path: pytorch quiet-checkout: true @@ -1194,7 +1204,7 @@ jobs: - name: Checkout pytorch/builder uses: malfet/checkout@silent-checkout with: - ref: release/2.4 + ref: main submodules: recursive repository: pytorch/builder path: builder @@ -1309,6 +1319,7 @@ jobs: - name: Checkout PyTorch uses: malfet/checkout@silent-checkout with: + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} submodules: recursive path: pytorch quiet-checkout: true @@ -1320,7 +1331,7 @@ jobs: - name: Checkout pytorch/builder uses: malfet/checkout@silent-checkout with: - ref: release/2.4 + ref: main submodules: recursive repository: pytorch/builder path: builder @@ -1426,6 +1437,7 @@ jobs: - name: Checkout PyTorch uses: malfet/checkout@silent-checkout with: + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} submodules: recursive path: pytorch quiet-checkout: true @@ -1437,7 +1449,7 @@ jobs: - name: Checkout pytorch/builder uses: malfet/checkout@silent-checkout with: - ref: release/2.4 + ref: main submodules: recursive repository: pytorch/builder path: builder @@ -1553,6 +1565,7 @@ jobs: - name: Checkout PyTorch uses: malfet/checkout@silent-checkout with: + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} submodules: recursive path: pytorch quiet-checkout: true @@ -1564,7 +1577,7 @@ jobs: - name: Checkout pytorch/builder uses: malfet/checkout@silent-checkout with: - ref: release/2.4 + ref: main submodules: recursive repository: pytorch/builder path: builder @@ -1670,6 +1683,7 @@ jobs: - name: Checkout PyTorch uses: malfet/checkout@silent-checkout with: + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} submodules: recursive path: pytorch quiet-checkout: true @@ -1681,7 +1695,7 @@ jobs: - name: Checkout pytorch/builder uses: malfet/checkout@silent-checkout with: - ref: release/2.4 + ref: main submodules: recursive repository: pytorch/builder path: builder @@ -1797,6 +1811,7 @@ jobs: - name: Checkout PyTorch uses: malfet/checkout@silent-checkout with: + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} submodules: recursive path: pytorch quiet-checkout: true @@ -1808,7 +1823,7 @@ jobs: - name: Checkout pytorch/builder uses: malfet/checkout@silent-checkout with: - ref: release/2.4 + ref: main submodules: recursive repository: pytorch/builder path: builder @@ -1914,6 +1929,7 @@ jobs: - name: Checkout PyTorch uses: malfet/checkout@silent-checkout with: + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} submodules: recursive path: pytorch quiet-checkout: true @@ -1925,7 +1941,7 @@ jobs: - name: Checkout pytorch/builder uses: malfet/checkout@silent-checkout with: - ref: release/2.4 + ref: main submodules: recursive repository: pytorch/builder path: builder @@ -2040,6 +2056,7 @@ jobs: - name: Checkout PyTorch uses: malfet/checkout@silent-checkout with: + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} submodules: recursive path: pytorch quiet-checkout: true @@ -2051,7 +2068,7 @@ jobs: - name: Checkout pytorch/builder uses: malfet/checkout@silent-checkout with: - ref: release/2.4 + ref: main submodules: recursive repository: pytorch/builder path: builder @@ -2156,6 +2173,7 @@ jobs: - name: Checkout PyTorch uses: malfet/checkout@silent-checkout with: + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} submodules: recursive path: pytorch quiet-checkout: true @@ -2167,7 +2185,7 @@ jobs: - name: Checkout pytorch/builder uses: malfet/checkout@silent-checkout with: - ref: release/2.4 + ref: main submodules: recursive repository: pytorch/builder path: builder @@ -2282,6 +2300,7 @@ jobs: - name: Checkout PyTorch uses: malfet/checkout@silent-checkout with: + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} submodules: recursive path: pytorch quiet-checkout: true @@ -2293,7 +2312,7 @@ jobs: - name: Checkout pytorch/builder uses: malfet/checkout@silent-checkout with: - ref: release/2.4 + ref: main submodules: recursive repository: pytorch/builder path: builder @@ -2399,6 +2418,7 @@ jobs: - name: Checkout PyTorch uses: malfet/checkout@silent-checkout with: + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} submodules: recursive path: pytorch quiet-checkout: true @@ -2410,7 +2430,7 @@ jobs: - name: Checkout pytorch/builder uses: malfet/checkout@silent-checkout with: - ref: release/2.4 + ref: main submodules: recursive repository: pytorch/builder path: builder @@ -2526,6 +2546,7 @@ jobs: - name: Checkout PyTorch uses: malfet/checkout@silent-checkout with: + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} submodules: recursive path: pytorch quiet-checkout: true @@ -2537,7 +2558,7 @@ jobs: - name: Checkout pytorch/builder uses: malfet/checkout@silent-checkout with: - ref: release/2.4 + ref: main submodules: recursive repository: pytorch/builder path: builder @@ -2643,6 +2664,7 @@ jobs: - name: Checkout PyTorch uses: malfet/checkout@silent-checkout with: + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} submodules: recursive path: pytorch quiet-checkout: true @@ -2654,7 +2676,7 @@ jobs: - name: Checkout pytorch/builder uses: malfet/checkout@silent-checkout with: - ref: release/2.4 + ref: main submodules: recursive repository: pytorch/builder path: builder @@ -2770,6 +2792,7 @@ jobs: - name: Checkout PyTorch uses: malfet/checkout@silent-checkout with: + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} submodules: recursive path: pytorch quiet-checkout: true @@ -2781,7 +2804,7 @@ jobs: - name: Checkout pytorch/builder uses: malfet/checkout@silent-checkout with: - ref: release/2.4 + ref: main submodules: recursive repository: pytorch/builder path: builder @@ -2887,6 +2910,7 @@ jobs: - name: Checkout PyTorch uses: malfet/checkout@silent-checkout with: + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} submodules: recursive path: pytorch quiet-checkout: true @@ -2898,7 +2922,7 @@ jobs: - name: Checkout pytorch/builder uses: malfet/checkout@silent-checkout with: - ref: release/2.4 + ref: main submodules: recursive repository: pytorch/builder path: builder @@ -3013,6 +3037,7 @@ jobs: - name: Checkout PyTorch uses: malfet/checkout@silent-checkout with: + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} submodules: recursive path: pytorch quiet-checkout: true @@ -3024,7 +3049,7 @@ jobs: - name: Checkout pytorch/builder uses: malfet/checkout@silent-checkout with: - ref: release/2.4 + ref: main submodules: recursive repository: pytorch/builder path: builder @@ -3129,6 +3154,7 @@ jobs: - name: Checkout PyTorch uses: malfet/checkout@silent-checkout with: + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} submodules: recursive path: pytorch quiet-checkout: true @@ -3140,7 +3166,7 @@ jobs: - name: Checkout pytorch/builder uses: malfet/checkout@silent-checkout with: - ref: release/2.4 + ref: main submodules: recursive repository: pytorch/builder path: builder @@ -3255,6 +3281,7 @@ jobs: - name: Checkout PyTorch uses: malfet/checkout@silent-checkout with: + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} submodules: recursive path: pytorch quiet-checkout: true @@ -3266,7 +3293,7 @@ jobs: - name: Checkout pytorch/builder uses: malfet/checkout@silent-checkout with: - ref: release/2.4 + ref: main submodules: recursive repository: pytorch/builder path: builder @@ -3372,6 +3399,7 @@ jobs: - name: Checkout PyTorch uses: malfet/checkout@silent-checkout with: + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} submodules: recursive path: pytorch quiet-checkout: true @@ -3383,7 +3411,7 @@ jobs: - name: Checkout pytorch/builder uses: malfet/checkout@silent-checkout with: - ref: release/2.4 + ref: main submodules: recursive repository: pytorch/builder path: builder @@ -3499,6 +3527,7 @@ jobs: - name: Checkout PyTorch uses: malfet/checkout@silent-checkout with: + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} submodules: recursive path: pytorch quiet-checkout: true @@ -3510,7 +3539,7 @@ jobs: - name: Checkout pytorch/builder uses: malfet/checkout@silent-checkout with: - ref: release/2.4 + ref: main submodules: recursive repository: pytorch/builder path: builder @@ -3616,6 +3645,7 @@ jobs: - name: Checkout PyTorch uses: malfet/checkout@silent-checkout with: + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} submodules: recursive path: pytorch quiet-checkout: true @@ -3627,7 +3657,7 @@ jobs: - name: Checkout pytorch/builder uses: malfet/checkout@silent-checkout with: - ref: release/2.4 + ref: main submodules: recursive repository: pytorch/builder path: builder @@ -3743,6 +3773,7 @@ jobs: - name: Checkout PyTorch uses: malfet/checkout@silent-checkout with: + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} submodules: recursive path: pytorch quiet-checkout: true @@ -3754,7 +3785,7 @@ jobs: - name: Checkout pytorch/builder uses: malfet/checkout@silent-checkout with: - ref: release/2.4 + ref: main submodules: recursive repository: pytorch/builder path: builder @@ -3860,6 +3891,7 @@ jobs: - name: Checkout PyTorch uses: malfet/checkout@silent-checkout with: + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} submodules: recursive path: pytorch quiet-checkout: true @@ -3871,7 +3903,7 @@ jobs: - name: Checkout pytorch/builder uses: malfet/checkout@silent-checkout with: - ref: release/2.4 + ref: main submodules: recursive repository: pytorch/builder path: builder @@ -3986,6 +4018,7 @@ jobs: - name: Checkout PyTorch uses: malfet/checkout@silent-checkout with: + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} submodules: recursive path: pytorch quiet-checkout: true @@ -3997,7 +4030,7 @@ jobs: - name: Checkout pytorch/builder uses: malfet/checkout@silent-checkout with: - ref: release/2.4 + ref: main submodules: recursive repository: pytorch/builder path: builder @@ -4102,6 +4135,7 @@ jobs: - name: Checkout PyTorch uses: malfet/checkout@silent-checkout with: + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} submodules: recursive path: pytorch quiet-checkout: true @@ -4113,7 +4147,7 @@ jobs: - name: Checkout pytorch/builder uses: malfet/checkout@silent-checkout with: - ref: release/2.4 + ref: main submodules: recursive repository: pytorch/builder path: builder @@ -4228,6 +4262,7 @@ jobs: - name: Checkout PyTorch uses: malfet/checkout@silent-checkout with: + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} submodules: recursive path: pytorch quiet-checkout: true @@ -4239,7 +4274,7 @@ jobs: - name: Checkout pytorch/builder uses: malfet/checkout@silent-checkout with: - ref: release/2.4 + ref: main submodules: recursive repository: pytorch/builder path: builder @@ -4345,6 +4380,7 @@ jobs: - name: Checkout PyTorch uses: malfet/checkout@silent-checkout with: + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} submodules: recursive path: pytorch quiet-checkout: true @@ -4356,7 +4392,7 @@ jobs: - name: Checkout pytorch/builder uses: malfet/checkout@silent-checkout with: - ref: release/2.4 + ref: main submodules: recursive repository: pytorch/builder path: builder @@ -4472,6 +4508,7 @@ jobs: - name: Checkout PyTorch uses: malfet/checkout@silent-checkout with: + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} submodules: recursive path: pytorch quiet-checkout: true @@ -4483,7 +4520,7 @@ jobs: - name: Checkout pytorch/builder uses: malfet/checkout@silent-checkout with: - ref: release/2.4 + ref: main submodules: recursive repository: pytorch/builder path: builder @@ -4589,6 +4626,7 @@ jobs: - name: Checkout PyTorch uses: malfet/checkout@silent-checkout with: + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} submodules: recursive path: pytorch quiet-checkout: true @@ -4600,7 +4638,7 @@ jobs: - name: Checkout pytorch/builder uses: malfet/checkout@silent-checkout with: - ref: release/2.4 + ref: main submodules: recursive repository: pytorch/builder path: builder @@ -4716,6 +4754,7 @@ jobs: - name: Checkout PyTorch uses: malfet/checkout@silent-checkout with: + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} submodules: recursive path: pytorch quiet-checkout: true @@ -4727,7 +4766,7 @@ jobs: - name: Checkout pytorch/builder uses: malfet/checkout@silent-checkout with: - ref: release/2.4 + ref: main submodules: recursive repository: pytorch/builder path: builder @@ -4833,6 +4872,7 @@ jobs: - name: Checkout PyTorch uses: malfet/checkout@silent-checkout with: + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} submodules: recursive path: pytorch quiet-checkout: true @@ -4844,7 +4884,7 @@ jobs: - name: Checkout pytorch/builder uses: malfet/checkout@silent-checkout with: - ref: release/2.4 + ref: main submodules: recursive repository: pytorch/builder path: builder diff --git a/.github/workflows/lint-bc.yml b/.github/workflows/lint-bc.yml index e961e4810877f7..73d7805082026b 100644 --- a/.github/workflows/lint-bc.yml +++ b/.github/workflows/lint-bc.yml @@ -15,7 +15,7 @@ jobs: runs-on: ubuntu-latest steps: - name: Run BC Lint Action - uses: pytorch/test-infra/.github/actions/bc-lint@release/2.4 + uses: pytorch/test-infra/.github/actions/bc-lint@main with: repo: ${{ github.event.pull_request.head.repo.full_name }} base_sha: ${{ github.event.pull_request.base.sha }} diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 3080bffdcb894d..e0e4d3c20cd84c 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -16,7 +16,7 @@ permissions: read-all # When any other step fails, it's job will be retried once by retryBot. jobs: lintrunner-clang: - uses: pytorch/test-infra/.github/workflows/linux_job.yml@release/2.4 + uses: pytorch/test-infra/.github/workflows/linux_job.yml@main with: timeout: 120 runner: linux.2xlarge @@ -32,7 +32,7 @@ jobs: .github/scripts/lintrunner.sh lintrunner-noclang: - uses: pytorch/test-infra/.github/workflows/linux_job.yml@release/2.4 + uses: pytorch/test-infra/.github/workflows/linux_job.yml@main with: timeout: 120 runner: linux.2xlarge @@ -47,7 +47,7 @@ jobs: .github/scripts/lintrunner.sh quick-checks: - uses: pytorch/test-infra/.github/workflows/linux_job.yml@release/2.4 + uses: pytorch/test-infra/.github/workflows/linux_job.yml@main with: runner: linux.2xlarge docker-image: pytorch-linux-focal-linter @@ -88,7 +88,7 @@ jobs: if: github.event_name == 'pull_request' && !contains(github.event.pull_request.labels.*.name, 'skip-pr-sanity-checks') steps: - name: Checkout PyTorch - uses: pytorch/pytorch/.github/actions/checkout-pytorch@release/2.4 + uses: pytorch/pytorch/.github/actions/checkout-pytorch@main with: submodules: false fetch-depth: -1 @@ -101,7 +101,7 @@ jobs: bash .github/scripts/pr-sanity-check.sh workflow-checks: - uses: pytorch/test-infra/.github/workflows/linux_job.yml@release/2.4 + uses: pytorch/test-infra/.github/workflows/linux_job.yml@main with: runner: linux.2xlarge docker-image: pytorch-linux-focal-linter @@ -112,7 +112,6 @@ jobs: # The generic Linux job chooses to use base env, not the one setup by the image CONDA_ENV=$(conda env list --json | jq -r ".envs | .[-1]") conda activate "${CONDA_ENV}" - export RELEASE_VERSION_TAG="2.4" # Regenerate workflows .github/scripts/generate_ci_workflows.py @@ -138,7 +137,7 @@ jobs: exit $RC toc: - uses: pytorch/test-infra/.github/workflows/linux_job.yml@release/2.4 + uses: pytorch/test-infra/.github/workflows/linux_job.yml@main with: runner: linux.2xlarge docker-image: pytorch-linux-focal-linter @@ -176,7 +175,7 @@ jobs: test-tools: name: Test tools if: ${{ github.repository == 'pytorch/pytorch' }} - uses: pytorch/test-infra/.github/workflows/linux_job.yml@release/2.4 + uses: pytorch/test-infra/.github/workflows/linux_job.yml@main with: runner: linux.2xlarge docker-image: pytorch-linux-focal-linter @@ -197,7 +196,7 @@ jobs: runs-on: linux.20_04.4x steps: - name: Checkout PyTorch - uses: pytorch/pytorch/.github/actions/checkout-pytorch@release/2.4 + uses: pytorch/pytorch/.github/actions/checkout-pytorch@main with: submodules: false fetch-depth: 1 @@ -227,7 +226,7 @@ jobs: # [see note: pytorch repo ref] # deep clone (fetch-depth 0) required, to allow us to use git log - name: Checkout PyTorch - uses: pytorch/pytorch/.github/actions/checkout-pytorch@release/2.4 + uses: pytorch/pytorch/.github/actions/checkout-pytorch@main with: submodules: false fetch-depth: 1 diff --git a/.github/workflows/llm_td_retrieval.yml b/.github/workflows/llm_td_retrieval.yml index d0914a76fba730..047e8ace0049d8 100644 --- a/.github/workflows/llm_td_retrieval.yml +++ b/.github/workflows/llm_td_retrieval.yml @@ -116,5 +116,5 @@ jobs: AWS_REGION: "" - name: Teardown Linux - uses: pytorch/test-infra/.github/actions/teardown-linux@release/2.4 + uses: pytorch/test-infra/.github/actions/teardown-linux@main if: always() diff --git a/.github/workflows/nightly-rockset-uploads.yml b/.github/workflows/nightly-rockset-uploads.yml index e45fbee93e3ca5..18ff29f144675c 100644 --- a/.github/workflows/nightly-rockset-uploads.yml +++ b/.github/workflows/nightly-rockset-uploads.yml @@ -21,7 +21,7 @@ jobs: environment: upload-stats steps: - name: Checkout PyTorch - uses: pytorch/pytorch/.github/actions/checkout-pytorch@release/2.4 + uses: pytorch/pytorch/.github/actions/checkout-pytorch@main with: fetch-depth: 1 submodules: false diff --git a/.github/workflows/nightly.yml b/.github/workflows/nightly.yml index 0b868e8fe929c8..25f71c70e94861 100644 --- a/.github/workflows/nightly.yml +++ b/.github/workflows/nightly.yml @@ -41,7 +41,7 @@ jobs: environment: update-commit-hash steps: - name: update-vision-commit-hash - uses: pytorch/test-infra/.github/actions/update-commit-hash@release/2.4 + uses: pytorch/test-infra/.github/actions/update-commit-hash@main if: ${{ github.event_name == 'schedule' }} with: repo-name: vision @@ -56,7 +56,7 @@ jobs: environment: update-commit-hash steps: - name: update-audio-commit-hash - uses: pytorch/test-infra/.github/actions/update-commit-hash@release/2.4 + uses: pytorch/test-infra/.github/actions/update-commit-hash@main if: ${{ github.event_name == 'schedule' }} with: repo-name: audio @@ -71,7 +71,7 @@ jobs: environment: update-commit-hash steps: - name: update-executorch-commit-hash - uses: pytorch/test-infra/.github/actions/update-commit-hash@release/2.4 + uses: pytorch/test-infra/.github/actions/update-commit-hash@main if: ${{ github.event_name == 'schedule' }} with: repo-name: executorch diff --git a/.github/workflows/target-determination-indexer.yml b/.github/workflows/target-determination-indexer.yml index 9cc26238035029..e8bf91c8d9ee91 100644 --- a/.github/workflows/target-determination-indexer.yml +++ b/.github/workflows/target-determination-indexer.yml @@ -24,7 +24,7 @@ jobs: - name: Calculate docker image id: calculate-docker-image - uses: pytorch/test-infra/.github/actions/calculate-docker-image@release/2.4 + uses: pytorch/test-infra/.github/actions/calculate-docker-image@main with: docker-image-name: pytorch-linux-focal-cuda12.1-cudnn9-py3-gcc9 working-directory: pytorch @@ -39,13 +39,13 @@ jobs: echo "docker pull ghcr.io/pytorch/ci-image:${tag/:/-}" - name: Pull docker image - uses: pytorch/test-infra/.github/actions/pull-docker-image@release/2.4 + uses: pytorch/test-infra/.github/actions/pull-docker-image@main with: docker-image: ${{ steps.calculate-docker-image.outputs.docker-image }} - name: Install nvidia driver, nvidia-docker runtime, set GPU_FLAG id: install-nvidia-driver - uses: pytorch/test-infra/.github/actions/setup-nvidia@release/2.4 + uses: pytorch/test-infra/.github/actions/setup-nvidia@main - name: Clone CodeLlama uses: actions/checkout@v3 @@ -136,7 +136,7 @@ jobs: "s3://target-determinator-assets/indexes/latest/${ZIP_NAME}" - name: Teardown Linux - uses: pytorch/test-infra/.github/actions/teardown-linux@release/2.4 + uses: pytorch/test-infra/.github/actions/teardown-linux@main if: always() concurrency: diff --git a/.github/workflows/target_determination.yml b/.github/workflows/target_determination.yml index 510a3c7f1abe59..cd5e758345b597 100644 --- a/.github/workflows/target_determination.yml +++ b/.github/workflows/target_determination.yml @@ -14,7 +14,7 @@ jobs: # checkout because when we run this action we don't *have* a local # checkout. In other cases you should prefer a local checkout. - name: Checkout PyTorch - uses: pytorch/pytorch/.github/actions/checkout-pytorch@release/2.4 + uses: pytorch/pytorch/.github/actions/checkout-pytorch@main with: submodules: false diff --git a/.github/workflows/update-viablestrict.yml b/.github/workflows/update-viablestrict.yml index 347a9b20d71147..94a712b377484e 100644 --- a/.github/workflows/update-viablestrict.yml +++ b/.github/workflows/update-viablestrict.yml @@ -16,7 +16,7 @@ jobs: environment: ${{ (github.event_name == 'schedule') && 'mergebot' || '' }} steps: - name: Update viable/strict - uses: pytorch/test-infra/.github/actions/update-viablestrict@release/2.4 + uses: pytorch/test-infra/.github/actions/update-viablestrict@main with: repository: pytorch/pytorch stable-branch: viable/strict diff --git a/.github/workflows/update_pytorch_labels.yml b/.github/workflows/update_pytorch_labels.yml index 3c161ae48450f7..db09474fb2120d 100644 --- a/.github/workflows/update_pytorch_labels.yml +++ b/.github/workflows/update_pytorch_labels.yml @@ -17,7 +17,7 @@ jobs: contents: read steps: - name: Checkout PyTorch - uses: pytorch/pytorch/.github/actions/checkout-pytorch@release/2.4 + uses: pytorch/pytorch/.github/actions/checkout-pytorch@main with: fetch-depth: 1 submodules: false diff --git a/.github/workflows/upload-alerts.yml b/.github/workflows/upload-alerts.yml index e16672f5e1ec57..bf370d6ef1b89a 100644 --- a/.github/workflows/upload-alerts.yml +++ b/.github/workflows/upload-alerts.yml @@ -44,7 +44,7 @@ jobs: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID }} AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }} - uses: pytorch/test-infra/.github/actions/upload-alerts@release/2.4 + uses: pytorch/test-infra/.github/actions/upload-alerts@main with: alerts: '${{ steps.alert_creation_step.outputs.script-output }}' organization: "pytorch" diff --git a/.github/workflows/upload-test-stats.yml b/.github/workflows/upload-test-stats.yml index 592fe647bdb91d..3b63f686019f58 100644 --- a/.github/workflows/upload-test-stats.yml +++ b/.github/workflows/upload-test-stats.yml @@ -39,7 +39,7 @@ jobs: run: echo "${TRIGGERING_WORKFLOW}" - name: Checkout PyTorch - uses: pytorch/pytorch/.github/actions/checkout-pytorch@release/2.4 + uses: pytorch/pytorch/.github/actions/checkout-pytorch@main - uses: actions/setup-python@v4 with: diff --git a/.github/workflows/upload-torch-dynamo-perf-stats.yml b/.github/workflows/upload-torch-dynamo-perf-stats.yml index babec6ce5f77c8..14a0f2c8cb65d6 100644 --- a/.github/workflows/upload-torch-dynamo-perf-stats.yml +++ b/.github/workflows/upload-torch-dynamo-perf-stats.yml @@ -29,7 +29,7 @@ jobs: name: Upload dynamo performance stats for ${{ github.event.workflow_run.id }}, attempt ${{ github.event.workflow_run.run_attempt }} steps: - name: Checkout PyTorch - uses: pytorch/pytorch/.github/actions/checkout-pytorch@release/2.4 + uses: pytorch/pytorch/.github/actions/checkout-pytorch@main with: submodules: false fetch-depth: 1 diff --git a/.github/workflows/upload_test_stats_intermediate.yml b/.github/workflows/upload_test_stats_intermediate.yml index e333650c06d372..d560f619db43d3 100644 --- a/.github/workflows/upload_test_stats_intermediate.yml +++ b/.github/workflows/upload_test_stats_intermediate.yml @@ -17,7 +17,7 @@ jobs: environment: upload-stats steps: - name: Checkout PyTorch - uses: pytorch/pytorch/.github/actions/checkout-pytorch@release/2.4 + uses: pytorch/pytorch/.github/actions/checkout-pytorch@main with: fetch-depth: 1 submodules: false diff --git a/.github/workflows/weekly.yml b/.github/workflows/weekly.yml index 8ebd6b1f7f797a..f097b146c21f8d 100644 --- a/.github/workflows/weekly.yml +++ b/.github/workflows/weekly.yml @@ -21,7 +21,7 @@ jobs: fetch-depth: 0 - name: update-xla-commit-hash continue-on-error: true - uses: pytorch/test-infra/.github/actions/update-commit-hash@release/2.4 + uses: pytorch/test-infra/.github/actions/update-commit-hash@main with: repo-name: xla branch: master @@ -30,7 +30,7 @@ jobs: updatebot-token: ${{ secrets.UPDATEBOT_TOKEN }} pytorchbot-token: ${{ secrets.GH_PYTORCHBOT_TOKEN }} - name: update-triton-commit-hash - uses: pytorch/test-infra/.github/actions/update-commit-hash@release/2.4 + uses: pytorch/test-infra/.github/actions/update-commit-hash@main with: repo-owner: openai repo-name: triton diff --git a/tools/stats/import_test_stats.py b/tools/stats/import_test_stats.py index 173551fe7ce474..513edb12fcfe68 100644 --- a/tools/stats/import_test_stats.py +++ b/tools/stats/import_test_stats.py @@ -77,7 +77,7 @@ def is_cached_file_valid() -> bool: def get_slow_tests( dirpath: str, filename: str = SLOW_TESTS_FILE ) -> Optional[Dict[str, float]]: - url = "https://ossci-metrics.s3.amazonaws.com/slow-tests.json?versionId=oKMp2dsjwgbtvuXJrL9fZbQiSJkiw91I" + url = "https://ossci-metrics.s3.amazonaws.com/slow-tests.json" try: return fetch_and_cache(dirpath, filename, url, lambda x: x) except Exception: @@ -117,7 +117,7 @@ def process_disabled_test(the_response: Dict[str, Any]) -> Dict[str, Any]: return disabled_test_from_issues try: - url = "https://ossci-metrics.s3.amazonaws.com/disabled-tests-condensed.json?versionId=0zzD6gFqZ9l2Vs1SYtjHXSuVLz6BaLtE" + url = "https://ossci-metrics.s3.amazonaws.com/disabled-tests-condensed.json" return fetch_and_cache(dirpath, filename, url, process_disabled_test) except Exception: print("Couldn't download test skip set, leaving all tests enabled...") From 8629939a51813def63363ff3bdfe1a6e56c69e18 Mon Sep 17 00:00:00 2001 From: Kiuk Chung Date: Fri, 14 Jun 2024 16:01:12 +0000 Subject: [PATCH 003/171] =?UTF-8?q?[torch/c10]=20Add=20C10=5FUBSAN=5FENABL?= =?UTF-8?q?ED=20macro=20and=20use=20it=20to=20disable=20SymInt=5F=E2=80=A6?= =?UTF-8?q?=20(#127967)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds `C10_UBSAN_ENABLED` macro and use it to disable `SymIntTest::Overflows` (fails under `signed-integer-overflow` UBSAN check). Also cleans up UBSAN guard in `jit/test_misc.cpp` to use `C10_UBSAN_ENABLED` and the existing `C10_ASAN_ENABLED` instead of locally defining `HAS_ASANUBSAN`. > NOTE: This should fix `SymIntTest::Overflows` failing under ubsan in fbcode too... Pull Request resolved: https://github.com/pytorch/pytorch/pull/127967 Approved by: https://github.com/atalman, https://github.com/d4l3k, https://github.com/malfet --- c10/macros/Macros.h | 19 +++++++++++++++++++ c10/test/core/SymInt_test.cpp | 4 ++++ test/cpp/jit/test_misc.cpp | 9 ++------- 3 files changed, 25 insertions(+), 7 deletions(-) diff --git a/c10/macros/Macros.h b/c10/macros/Macros.h index a66933823d80f7..51b54105297082 100644 --- a/c10/macros/Macros.h +++ b/c10/macros/Macros.h @@ -64,6 +64,25 @@ #define C10_ASAN_ENABLED 0 #endif +// Detect undefined-behavior sanitizer (UBSAN) +#undef C10_UBSAN_ENABLED + +// for clang or gcc >= 14 +// NB: gcc 14 adds support for Clang's __has_feature +// https://gcc.gnu.org/gcc-14/changes.html +// gcc < 14 doesn't have a macro for UBSAN +// (e.g. __SANITIZE_UNDEFINED__ does not exist in gcc) +// https://github.com/google/sanitizers/issues/765 +#if defined(__has_feature) +#if ((__has_feature(undefined_behavior_sanitizer))) +#define C10_UBSAN_ENABLED 1 +#endif +#endif + +#if !defined(C10_UBSAN_ENABLED) +#define C10_UBSAN_ENABLED 0 +#endif + // Disable the copy and assignment operator for a class. Note that this will // disable the usage of the class in std containers. #define C10_DISABLE_COPY_AND_ASSIGN(classname) \ diff --git a/c10/test/core/SymInt_test.cpp b/c10/test/core/SymInt_test.cpp index df11d8be786214..8055ec7a325111 100644 --- a/c10/test/core/SymInt_test.cpp +++ b/c10/test/core/SymInt_test.cpp @@ -2,6 +2,7 @@ #include #include +#include using namespace c10; #ifndef C10_MOBILE @@ -22,6 +23,8 @@ TEST(SymIntTest, CheckRange) { EXPECT_FALSE(SymInt::check_range(INT64_MIN)); } +#if !C10_UBSAN_ENABLED +// This test fails signed-integer-overflow UBSAN check TEST(SymIntTest, Overflows) { const auto x = SymInt(INT64_MAX); EXPECT_NE(-(x + 1), 0); @@ -30,5 +33,6 @@ TEST(SymIntTest, Overflows) { EXPECT_NE(-y, 0); EXPECT_NE(0 - y, 0); } +#endif #endif diff --git a/test/cpp/jit/test_misc.cpp b/test/cpp/jit/test_misc.cpp index 7f69e3b07c971e..50be38407913e6 100644 --- a/test/cpp/jit/test_misc.cpp +++ b/test/cpp/jit/test_misc.cpp @@ -6,6 +6,7 @@ #include #include #include +#include #include #include #include @@ -491,13 +492,7 @@ TEST(ControlFlowTest, Basic) { ASSERT_EQ(256, run_binary("while_test", 2, 0)); } -#if defined(__has_feature) -#if __has_feature(address_sanitizer) -#define HAS_ASANUBSAN 1 -#endif -#endif - -#ifndef HAS_ASANUBSAN +#if !(C10_ASAN_ENABLED || C10_UBSAN_ENABLED) // This test fails vptr UBSAN checks TEST(ProtoTest, Basic) { From 732b4e90740f5fe84d40f716f911c6dcbd9845fb Mon Sep 17 00:00:00 2001 From: Aaron Orenstein Date: Thu, 13 Jun 2024 20:30:26 -0700 Subject: [PATCH 004/171] Fix generated vararg types (#128648) In the generated files torchgen is incorrectly generating types on the varargs. The changes all look like this (changing `*size: _int` to `*size: Union[_int, SymInt]`): ``` --- ./torch/_VF.pyi.sav 2024-06-13 20:36:49.189664629 -0700 +++ ./torch/_VF.pyi 2024-06-13 20:36:57.208894614 -0700 @@ -168,17 +168,17 @@ @overload def _efficientzerotensor(size: Sequence[Union[_int, SymInt]], *, dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Optional[DeviceLikeType]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: ... @overload -def _efficientzerotensor(*size: _int, dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Optional[DeviceLikeType]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: ... +def _efficientzerotensor(*size: Union[_int, SymInt], dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Optional[DeviceLikeType]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: ... def _embedding_bag(weight: Tensor, indices: Tensor, offsets: Tensor, scale_grad_by_freq: _bool = False, mode: _int = 0, sparse: _bool = False, per_sample_weights: Optional[Tensor] = None, include_last_offset: _bool = False, padding_idx: _int = -1) -> Tuple[Tensor, Tensor, Tensor, Tensor]: ... def _embedding_bag_forward_only(weight: Tensor, indices: Tensor, offsets: Tensor, scale_grad_by_freq: _bool = False, mode: _int = 0, sparse: _bool = False, per_sample_weights: Optional[Tensor] = None, include_last_offset: _bool = False, padding_idx: _int = -1) -> Tuple[Tensor, Tensor, Tensor, Tensor]: ... @overload ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/128648 Approved by: https://github.com/jamesjwu --- torchgen/api/python.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/torchgen/api/python.py b/torchgen/api/python.py index a67ff7c174cadf..8d3e6f3b3edd46 100644 --- a/torchgen/api/python.py +++ b/torchgen/api/python.py @@ -440,9 +440,13 @@ def signature_str_pyi_vararg(self, *, skip_outputs: bool = False) -> Optional[st if not have_vararg_version: return None + # Below are the major changes in vararg vs. regular pyi signatures # vararg signatures also omit the asterix - schema_formals[0] = "*" + args[0].name + ": _int" + assert isinstance(vararg_type, ListType) + schema_formals[0] = ( + "*" + args[0].name + ": " + argument_type_str_pyi(vararg_type.elem) + ) returns_str = returns_str_pyi(self) # pyi also includes self (with no typing/defaults) for methods From f75f5987aa9a457e8493d910e0f2901e107a0ebf Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Fri, 14 Jun 2024 16:46:03 +0000 Subject: [PATCH 005/171] Revert "Extended Module Tracker (#128508)" This reverts commit 1f46284f9ed5b60981174e689d750b358b19e4c4. Reverted https://github.com/pytorch/pytorch/pull/128508 on behalf of https://github.com/malfet due to Broke lint, see https://github.com/pytorch/pytorch/actions/runs/9515753429/job/26230639980 ([comment](https://github.com/pytorch/pytorch/pull/128508#issuecomment-2168405784)) --- test/distributed/_tools/test_mod_tracker.py | 140 ------------ torch/distributed/_tools/__init__.py | 1 - torch/distributed/_tools/mod_tracker.py | 231 -------------------- 3 files changed, 372 deletions(-) delete mode 100644 test/distributed/_tools/test_mod_tracker.py delete mode 100644 torch/distributed/_tools/mod_tracker.py diff --git a/test/distributed/_tools/test_mod_tracker.py b/test/distributed/_tools/test_mod_tracker.py deleted file mode 100644 index 7625e8456f5a84..00000000000000 --- a/test/distributed/_tools/test_mod_tracker.py +++ /dev/null @@ -1,140 +0,0 @@ -# Owner(s): ["module: unknown"] - -from copy import copy - -import torch -from torch.distributed._tools.mod_tracker import ModTracker -from torch.testing._internal.common_utils import run_tests, TestCase, xfailIfTorchDynamo - - -class TestModTracker(TestCase): - # "https://github.com/pytorch/pytorch/issues/127112 - @xfailIfTorchDynamo - def test_module_hierarchy(self): - seen_fw = [] - seen_bw = [] - - class Foo(torch.nn.Module): - def forward(self, x): - x = x["a"].relu_() - seen_fw.append((copy(tracker.parents), tracker.is_bw)) - x.register_hook( - lambda grad: seen_bw.append((copy(tracker.parents), tracker.is_bw)) - ) - return {"a": torch.mm(x, x)} - - class Mod(torch.nn.Module): - def __init__(self): - super().__init__() - self.a = Foo() - self.b = torch.nn.ModuleDict({"nest": Foo()}) - self.c = torch.nn.ModuleList([Foo()]) - - def forward(self, x): - x = self.c[0](x) - return self.b["nest"](self.a(x)) - - mod = Mod() - - with ModTracker() as tracker: - mod({"a": torch.randn(10, 10, requires_grad=True).clone()})[ - "a" - ].sum().backward() - mod({"a": torch.randn(10, 10, requires_grad=True).clone()})[ - "a" - ].sum().backward() - - self.assertEqual( - seen_fw, - [ - ({"Global", "Mod", "Mod.c.0"}, False), - ({"Global", "Mod", "Mod.a"}, False), - ({"Global", "Mod", "Mod.b.nest"}, False), - ({"Global", "Mod", "Mod.c.0"}, False), - ({"Global", "Mod", "Mod.a"}, False), - ({"Global", "Mod", "Mod.b.nest"}, False), - ], - ) - - self.assertEqual( - seen_bw, - [ - ({"Global", "Mod", "Mod.b.nest"}, True), - ({"Global", "Mod", "Mod.a"}, True), - ({"Global", "Mod", "Mod.c.0"}, True), - ({"Global", "Mod", "Mod.b.nest"}, True), - ({"Global", "Mod", "Mod.a"}, True), - ({"Global", "Mod", "Mod.c.0"}, True), - ], - ) - - def test_bw_detection(self): - mod = torch.nn.Linear(2, 2) - - with ModTracker() as tracker: - mod(torch.rand(2, requires_grad=True)).sum().backward() - self.assertFalse(tracker.is_bw) - self.assertEqual(tracker.parents, {"Global"}) - - @xfailIfTorchDynamo - def test_user_hooks(self): - class Bar(torch.nn.Module): - def __init__(self): - super().__init__() - self.foo = torch.nn.Linear(10, 10) - - def forward(self, x): - return self.foo(x).relu_() - - mt = ModTracker() - test_op = [] - - def hook(mod, hook_name): - mfqn = mt.get_known_fqn(mod) if mod is not None else None - test_op.append((hook_name, mfqn, mfqn in mt.parents, mt.is_bw)) - - mod = Bar() - - mt.register_user_hooks( - lambda m, inp: hook(m, "pre_fw"), - lambda m, inp, op: hook(m, "post_fw"), - lambda m, gop: hook(m, "pre_bw"), - lambda m, ginp: hook(m, "post_bw"), - ) - with mt: - mod(torch.rand(10, 10, requires_grad=True)).sum().backward() - expected_op = [ - ("pre_fw", "Bar", True, False), - ("pre_fw", "Bar.foo", True, False), - ("post_fw", "Bar.foo", True, False), - ("post_fw", "Bar", True, False), - ("pre_bw", "Bar", True, True), - ("pre_bw", "Bar.foo", True, True), - ("post_bw", "Bar", True, True), - ("post_bw", "Bar.foo", True, True), - ] - self.assertEqual(test_op, expected_op) - - with self.assertRaises(AssertionError): - mt.register_user_hooks(lambda x, y: x, None, None, None) - - test_op.clear() - with mt: - loss = mod(torch.rand(10, 10, requires_grad=True)).sum() - del mod - loss.backward() - expected_op = [ - ("pre_fw", "Bar", True, False), - ("pre_fw", "Bar.foo", True, False), - ("post_fw", "Bar.foo", True, False), - ("post_fw", "Bar", True, False), - ("pre_bw", None, False, True), - ("pre_bw", None, False, True), - ("post_bw", None, False, True), - ("post_bw", None, False, True), - ] - self.assertEqual(test_op, expected_op) - - -if __name__ == "__main__": - run_tests() diff --git a/torch/distributed/_tools/__init__.py b/torch/distributed/_tools/__init__.py index b8a8950f9cd0a0..eda274b5724f0f 100644 --- a/torch/distributed/_tools/__init__.py +++ b/torch/distributed/_tools/__init__.py @@ -1,2 +1 @@ from .memory_tracker import MemoryTracker -from .mod_tracker import ModTracker diff --git a/torch/distributed/_tools/mod_tracker.py b/torch/distributed/_tools/mod_tracker.py deleted file mode 100644 index a8f107e0cc438f..00000000000000 --- a/torch/distributed/_tools/mod_tracker.py +++ /dev/null @@ -1,231 +0,0 @@ -import warnings -import weakref -from typing import Callable, Optional, Set - -import torch -from torch.autograd.graph import register_multi_grad_hook -from torch.nn.modules.module import ( - register_module_forward_hook, - register_module_forward_pre_hook, -) -from torch.utils._pytree import tree_flatten - - -__all__ = ["ModTracker"] - - -class ModTracker: - """ - ``ModTracker`` is a context manager that tracks the nn.Module hierarchy during execution - so that other system can query which Module is currently being executed (or its backward is being - executed). - - You can access the ``parents`` attribute on this context manager to get the set of all the - Modules currently being executed via their fqn (fully qualified name, also used as the key within - the state_dict). - You can access the ``is_bw`` attribute to know if you are currently running in backward or not. - - Note that ``parents`` is never empty and always contains the "Global" key. The ``is_bw`` flag - will remain ``True`` after the forward until another Module is executed. If you need it to be - more accurate, please submit an issue requesting this. Adding a map from fqn to the module instance - is possible but not done yet, please submit an issue requesting this if you need it. - - Example usage - - .. code-block:: python - - mod = torch.nn.Linear(2, 2) - - with ModTracker() as tracker: - # Access anything during the forward pass - def my_linear(m1, m2, bias): - print(f"Current modules: {tracker.parents}") - return torch.mm(m1, m2.t()) + bias - torch.nn.functional.linear = my_linear - - mod(torch.rand(2, 2)) - - """ - - parents: Set[str] - """ - A Set containing the fqn for each module currently running their forward - """ - - def __init__(self): - self.parents = {"Global"} - self._known_modules: weakref.WeakKeyDictionary = weakref.WeakKeyDictionary() - self._seen_modules: weakref.WeakSet = weakref.WeakSet() - self._has_callback = False - self._user_pre_fw_hook = None - self._user_post_fw_hook = None - self._user_pre_bw_hook = None - self._user_post_bw_hook = None - - def _maybe_set_engine_callback(self): - # This assumes no concurrent calls to backward - if self._has_callback: - return - - def callback(): - self.parents = {"Global"} - self._has_callback = False - - torch.autograd.Variable._execution_engine.queue_callback(callback) - self._has_callback = True - - @property - def is_bw(self): - """ - A boolean marking if this is currently running during the backward pass or not - """ - return torch._C._current_graph_task_id() != -1 - - def get_known_fqn(self, mod): - """ - Return the fqn for the given module if it is known to the ``ModTracker``, otherwise ``None``. - """ - return self._known_modules.get(mod, None) - - def register_user_hooks( - self, - pre_fw_hook: Optional[Callable] = None, - post_fw_hook: Optional[Callable] = None, - pre_bw_hook: Optional[Callable] = None, - post_bw_hook: Optional[Callable] = None, - ): - """ - Registers user-specified hooks to be called before/after the forward/backward pass for each - module tracked by the ``ModTracker``. One or more can be ``None``. - Args: - pre_fw_hook (Callable, optional): A hook to be called before the forward pass for the - module. It should have the following signature: - pre_fw_hook (module, input) -> None - post_fw_hook (Callable, optional): A hook to be called after the forward pass for the - module. It should have the following signature: - post_fw_hook (module, input, output) -> None - pre_bw_hook (Callable, optional): A multi-grad hook to be called on all the outputs of - the module that require gradients. It should have the following signature: - pre_bw_hook (module, grad_output) -> None - post_bw_hook (Callable, optional): A multi-grad hook to be called on all the inputs of - the module that require gradients. It should have the following signature: - post_bw_hook (module, grad_input) -> None - Raises: - AssertionError: If a new hook is provided when one is already registered. - Note: - If the module is not alive during the backward pass, the pre_bw_hook and post_bw_hook will - will receive None as the module argument. - The module fqn will be present in the ``parents`` attribute when each of the hooks is called. - Hooks are intended to be used as markers only not to modify the inputs/outputs. - """ - - def set_hook(hook, user_hook, hook_name): - if hook is not None and user_hook is not None: - raise AssertionError( - f"Only one {hook_name} can be registered at a time" - f" Clear the existing hook by calling ``clear_user_hooks`` before registering a new one" - ) - return hook - - self._user_pre_fw_hook = set_hook( - pre_fw_hook, self._user_pre_fw_hook, "pre_fw_hook" - ) - self._user_post_fw_hook = set_hook( - post_fw_hook, self._user_post_fw_hook, "post_fw_hook" - ) - self._user_pre_bw_hook = set_hook( - pre_bw_hook, self._user_pre_bw_hook, "pre_bw_hook" - ) - self._user_post_bw_hook = set_hook( - post_bw_hook, self._user_post_bw_hook, "post_bw_hook" - ) - - def clear_user_hooks(self): - """ - Clears the user specified hooks registered with ``register_user_hooks`` - """ - self._user_pre_fw_hook = None - self._user_post_fw_hook = None - self._user_pre_bw_hook = None - self._user_post_bw_hook = None - - def _get_mod_name(self, mod): - if mod not in self._known_modules: - self._known_modules[mod] = type(mod).__name__ - mod_name = self._known_modules[mod] - if mod not in self._seen_modules: - for name, submod in mod.named_children(): - self._known_modules[submod] = f"{mod_name}.{name}" - self._get_mod_name(submod) - self._seen_modules.add(mod) - return mod_name - - def _get_append_fn(self, w_mod, name, is_bw): - def fn(*args): - if is_bw: - self._maybe_set_engine_callback() - if name in self.parents and not self.is_bw: - - def custom_formatwarning(msg, category, filename, lineno, line=None): - return f"{filename}:{lineno}: {category.__name__}: {msg} \n" - - warnings.formatwarning = custom_formatwarning - warnings.warn( - "The module hierarchy tracking maybe be messed up." - " Please file a bug to PyTorch, if it is the case." - ) - self.parents.add(name) - - if self._user_pre_bw_hook is not None and is_bw: - self._user_pre_bw_hook(w_mod(), args) - - return fn - - def _get_pop_fn(self, w_mod, name, is_bw): - def fn(*args): - if self._user_post_bw_hook is not None and is_bw: - self._user_post_bw_hook(w_mod(), args) - - if name in self.parents: - self.parents.remove(name) - elif not is_bw: - # Due to some input/output not requiring gradients, we cannot enforce - # proper nesting in backward - raise RuntimeError( - "The Module hierarchy tracking is wrong. Report a bug to PyTorch" - ) - - return fn - - def _fw_pre_hook(self, mod, input): - name = self._get_mod_name(mod) - w_mod = weakref.ref(mod) - self._get_append_fn(w_mod, name, False)() - if self._user_pre_fw_hook is not None: - self._user_pre_fw_hook(mod, input) - args, _ = tree_flatten(input) - tensors = [a for a in args if isinstance(a, torch.Tensor) and a.requires_grad] - if not self.is_bw and tensors: - register_multi_grad_hook(tensors, self._get_pop_fn(w_mod, name, True)) - - def _fw_post_hook(self, mod, input, output): - name = self._get_mod_name(mod) - w_mod = weakref.ref(mod) - if self._user_post_fw_hook is not None: - self._user_post_fw_hook(mod, input, output) - self._get_pop_fn(w_mod, name, False)() - args, _ = tree_flatten(output) - tensors = [a for a in args if isinstance(a, torch.Tensor) and a.requires_grad] - if not self.is_bw and tensors: - register_multi_grad_hook(tensors, self._get_append_fn(w_mod, name, True)) - - def __enter__(self): - self._fw_pre_handle = register_module_forward_pre_hook(self._fw_pre_hook) - self._fw_post_handle = register_module_forward_hook( - self._fw_post_hook, always_call=True - ) - return self - - def __exit__(self, *args): - self._fw_pre_handle.remove() - self._fw_post_handle.remove() From 4c84af0f5d34a5463462e9b2768509e92c0b429a Mon Sep 17 00:00:00 2001 From: Laith Sakka Date: Thu, 13 Jun 2024 22:12:23 -0700 Subject: [PATCH 006/171] Fix indexing and slicing of ranges in dynamo (#128567) Fix https://github.com/pytorch/pytorch/issues/128520 Dynamo does not handle range()[binary subscript] or range()[trinary_subscript] correctly. Right now it calls the get_item function which basically applies the subscript operation on top of the list of [start, end, step]! which is completely not related to what is expected. in python, range()[complex subscript] is another range, ex: range(1, 10, 2)[1:4:1] is range(3, 9, 2) and range(1, 10, 2)[1:4:1] is range(-9, 9, 2) This diff fix index and slice applications on range. it mimics implementations from (https://github.com/python/cpython/blob/main/Objects/rangeobject.c) Pull Request resolved: https://github.com/pytorch/pytorch/pull/128567 Approved by: https://github.com/anijain2305 --- test/dynamo/test_functions.py | 153 ++++++++++++++++++ ....test_simple_snapshot_custom_non_generator | 0 ...ward.test_simple_snapshot_custom_self_next | 0 torch/_dynamo/variables/lists.py | 114 ++++++++++++- 4 files changed, 266 insertions(+), 1 deletion(-) delete mode 100644 test/dynamo_expected_failures/TestIterDataPipeGraphFastForward.test_simple_snapshot_custom_non_generator delete mode 100644 test/dynamo_expected_failures/TestIterDataPipeGraphFastForward.test_simple_snapshot_custom_self_next diff --git a/test/dynamo/test_functions.py b/test/dynamo/test_functions.py index e2baebf60321bc..a8650ba1adcc15 100644 --- a/test/dynamo/test_functions.py +++ b/test/dynamo/test_functions.py @@ -27,6 +27,8 @@ normalize_gm, ) from torch._dynamo.utils import ifdynstaticdefault, same +from torch._dynamo.variables import ConstantVariable +from torch._dynamo.variables.lists import RangeVariable from torch.nn import functional as F from torch.testing._internal.common_utils import ( @@ -2369,6 +2371,157 @@ def fn(): opt_fn = torch._dynamo.optimize(nopython=True)(fn) self.assertEqual(opt_fn(), fn()) + def gen_random_range_args(self): + args_count = random.randint(1, 3) + args = [random.randint(-10, 10) for _ in range(args_count)] + if args_count == 3 and args[2] == 0: + args[2] = 1 + return args + + def test_range_length(self): + def test(*args, expected=None): + r = range(*args) + range_variable = RangeVariable([ConstantVariable.create(v) for v in args]) + + self.assertEqual(len(r), range_variable.range_length()) + + if expected is not None: + self.assertEqual(len(r), expected) + + test(1, 1, 1, expected=0) + test(1, 0, expected=0) + test(-10, expected=0) + + test(4, expected=4) + test(10, expected=10) + + # step >1 + test(1, 10, 2, expected=5) + + # negative step + test(10, 1, -1, expected=9) + test(10, 1, -3) + + # Fuzz testing + for i in range(100): + args = self.gen_random_range_args() + print("testing :", args) + test(*args) + + def test_indexed_range(self): + def test(range, index, expected=None): + range_variable = RangeVariable( + [ + ConstantVariable.create(v) + for v in [range.start, range.stop, range.step] + ] + ) + + self.assertEqual( + range[index], + range_variable.apply_index(index).as_python_constant(), + ) + + if expected is not None: + self.assertEqual(range[index], expected) + + test(range(10), 1, expected=1) + test(range(10, 20, 2), 1, expected=12) + + # Fuzz testing + for i in range(100): + range_args = self.gen_random_range_args() + r = range(*range_args) + + if len(r) == 0: + continue + + index = random.randint(0, len(r) - 1) + + print("testing:", r, index) + test(r, index) + + def test_sliced_range(self): + def test(range, slice, expected=None): + range_variable = RangeVariable( + [ + ConstantVariable.create(v) + for v in [range.start, range.stop, range.step] + ] + ) + + self.assertEqual( + range[slice], + range_variable.apply_slice(slice).as_python_constant(), + ) + + if expected is not None: + self.assertEqual( + range[slice], + expected, + ) + + test(range(10), slice(1, 10, 2), expected=range(1, 10, 2)) + test(range(10), slice(None, 10, None), expected=range(0, 10)) + test(range(10), slice(-1, 7, None), expected=range(9, 7)) + test(range(10), slice(-1, 7, 2), expected=range(9, 7, 2)) + test(range(1, 10, 2), slice(3, 7, 2), expected=range(7, 11, 4)) + test(range(1, 10, 2), slice(-3, 7, 2), expected=range(5, 11, 4)) + test(range(-1, -5, -3), slice(5, None, -3), expected=range(-4, 2, 9)) + + def rand_slice(): + def flip_coin(): + # 1 out of 10 + return random.randint(1, 10) == 5 + + def r_item(allow_zero=True): + i = random.randint(-10, 10) + if not allow_zero and i == 0: + i = 1 + if flip_coin(): + i = None + return i + + arg_count = random.randint(1, 3) + + if arg_count == 1: + return slice(r_item()) + elif arg_count == 2: + return slice(r_item(), r_item()) + else: + return slice(r_item(), r_item(), r_item(False)) + + # Fuzz testing + for i in range(100): + range_args = self.gen_random_range_args() + r = range(*range_args) + # generate random slice + s = rand_slice() + + print("testing:", r, s) + test(r, s) + + def test_range_with_slice_index(self): + def fn(x): + acc = 1 + for k in range(2)[1::2]: + acc *= acc * k + return x * acc + + opt_fn = torch.compile(fullgraph=True)(fn) + x = torch.ones(1) + self.assertEqual(opt_fn(x), fn(x)) + + def test_range_with_index(self): + def fn(x): + acc = 1 + acc *= acc * range(10, 20, 2)[2] + return x * acc + + opt_fn = torch.compile(fullgraph=True)(fn) + x = torch.ones(1) + self.assertEqual(opt_fn(x), fn(x)) + def test_rand_inlined(self): @torch.compile(backend="eager", dynamic=True) def fn(): diff --git a/test/dynamo_expected_failures/TestIterDataPipeGraphFastForward.test_simple_snapshot_custom_non_generator b/test/dynamo_expected_failures/TestIterDataPipeGraphFastForward.test_simple_snapshot_custom_non_generator deleted file mode 100644 index e69de29bb2d1d6..00000000000000 diff --git a/test/dynamo_expected_failures/TestIterDataPipeGraphFastForward.test_simple_snapshot_custom_self_next b/test/dynamo_expected_failures/TestIterDataPipeGraphFastForward.test_simple_snapshot_custom_self_next deleted file mode 100644 index e69de29bb2d1d6..00000000000000 diff --git a/torch/_dynamo/variables/lists.py b/torch/_dynamo/variables/lists.py index e0fe96dfa336ad..450e2e937f71a7 100644 --- a/torch/_dynamo/variables/lists.py +++ b/torch/_dynamo/variables/lists.py @@ -9,6 +9,7 @@ import torch import torch.fx + from ..._guards import Source from .. import polyfill, variables @@ -176,13 +177,124 @@ def debug_repr(self): def python_type(self): return range + def start(self): + return self.items[0].as_python_constant() + + def stop(self): + return self.items[1].as_python_constant() + + def step(self): + return self.items[2].as_python_constant() + + def range_length(self): + lo = self.start() + hi = self.stop() + step = self.step() + + assert step != 0 + if step > 0 and lo < hi: + return 1 + (hi - 1 - lo) // step + elif step < 0 and lo > hi: + return 1 + (lo - 1 - hi) // (0 - step) + else: + return 0 + + def _get_slice_indices(self, length, slice): + step_is_negative = 0 + + if slice.step is None: + step = 1 + step_is_negative = False + else: + step = slice.step + step_is_negative = slice.step < 0 + + # Find lower and upper bounds for start and stop. + if step_is_negative: + lower = -1 + upper = length + lower + else: + lower = 0 + upper = length + + # Compute start + if slice.start is None: + start = upper if step_is_negative else lower + else: + start = slice.start + + if start < 0: + start += length + if start < lower: + start = lower + else: + if start > upper: + start = upper + + # Compute stop. + if slice.stop is None: + stop = lower if step_is_negative else upper + + else: + stop = slice.stop + + if stop < 0: + stop += length + if stop < lower: + stop = lower + else: + if stop > upper: + stop = upper + + return [start, stop, step] + + def apply_index(self, index): + length = self.range_length() + if index < 0: + index = length + index + + if index < 0 or index >= length: + raise IndexError(f"index {index} is out of range") + + return variables.ConstantVariable.create(self.start() + (index * self.step())) + + def apply_slice(self, slice): + (slice_start, slice_stop, slice_step) = self._get_slice_indices( + self.range_length(), slice + ) + + def compute_item(index): + return self.start() + (index * self.step()) + + sub_step = self.step() * slice_step + sub_start = compute_item(slice_start) + sub_stop = compute_item(slice_stop) + + result = RangeVariable( + [ + variables.ConstantVariable.create(x) + for x in [sub_start, sub_stop, sub_step] + ], + mutable_local=MutableLocal() if self.mutable_local else None, + ) + return result + def as_python_constant(self): return range(*[x.as_python_constant() for x in self.items]) + def getitem_const(self, arg: VariableTracker): + # implementations mimics https://github.com/python/cpython/blob/main/Objects/rangeobject.c + index = arg.as_python_constant() + + if isinstance(index, slice): + return self.apply_slice(index) + else: + return self.apply_index(index) + def as_proxy(self): return self.python_type()(*self._as_proxy()) - def unpack_var_sequence(self, tx): + def unpack_var_sequence(self, tx=None): return [variables.ConstantVariable.create(x) for x in self.as_python_constant()] def reconstruct(self, codegen): From 1fb4effe7ad0ed92ff088aaaad62520f9fe6f0df Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Fri, 14 Jun 2024 17:03:22 +0000 Subject: [PATCH 007/171] [GPT-fast benchmark] Add MLP, gather + gemv, gemv micro benchmark (#128002) Output example: ``` | name | metric | target | actual | |------------------------------|---------------------------|---------|---------| | layer_norm_bfloat16 | memory_bandwidth(GB/s) | 1017 | 1000.01 | | mlp_layer_norm_gelu_bfloat16 | flops_utilization | 0.71 | 0.71 | | gemv_int8 | memory_bandwidth(GB/s) | 990 | 984.06 | | gemv_bfloat16 | memory_bandwidth(GB/s) | 1137 | 1137.92 | | gather_gemv_int8 | memory_bandwidth(GB/s) | 1113 | 1111.09 | | gather_gemv_bfloat16 | memory_bandwidth(GB/s) | 1249 | 1248.15 | ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/128002 Approved by: https://github.com/Chillee --- benchmarks/gpt_fast/benchmark.py | 208 +++++++++++++++++++++++++------ 1 file changed, 169 insertions(+), 39 deletions(-) diff --git a/benchmarks/gpt_fast/benchmark.py b/benchmarks/gpt_fast/benchmark.py index 6e335ee3129224..998878fa969284 100644 --- a/benchmarks/gpt_fast/benchmark.py +++ b/benchmarks/gpt_fast/benchmark.py @@ -2,12 +2,17 @@ import csv import dataclasses import os -import time from generate import run_llama2_7b_bf16, run_llama2_7b_int8, run_mixtral_8x7b_int8 +from triton.testing import do_bench import torch import torch.nn as nn +from torch.utils.flop_counter import FlopCounterMode + +WARMUP_ITER = 5 + +A100_80G_BF16_TFLOPS = 312 @dataclasses.dataclass @@ -18,57 +23,179 @@ class Experiment: actual: float -def do_inference(mod, x, num_samples: int = 5): - total_time = 0 - start = -1 +class SimpleMLP(nn.Module): + def __init__(self, input_dim, hidden_dim, output_dim, dtype): + super().__init__() + self.layers = nn.ModuleList( + [ + nn.Linear(input_dim, hidden_dim, dtype=dtype), + nn.LayerNorm(hidden_dim, dtype=dtype), + nn.Linear(hidden_dim, output_dim, dtype=dtype), + nn.LayerNorm(output_dim, dtype=dtype), + ] + ) + + def forward(self, x): + for layer in self.layers: + x = layer(x) + return x + + +def run_mlp_layer_norm_gelu(): + dtype_flops_utilization_map = { + torch.bfloat16: "0.71", + } + input_shapes = [1024, 4096, 8192, 16384] + intermediate_size = 14336 + results = [] + for dtype, expected_flops_utilization in dtype_flops_utilization_map.items(): + flops_utilization = 0 + for D in input_shapes: + mod = SimpleMLP( + input_dim=D, hidden_dim=intermediate_size, output_dim=D, dtype=dtype + ).to("cuda") + + x = torch.randn(D, device="cuda", dtype=torch.bfloat16) + + with FlopCounterMode(display=False) as mode: + mod(x) + + flops = mode.get_total_flops() + + compiled_mod = torch.compile(mod, dynamic=False) + + for _ in range(WARMUP_ITER): + compiled_mod(x) + + us_per_iter = do_bench(lambda: compiled_mod(x)) * 1000 + flops_utilization += us_per_iter * flops / 1e9 / A100_80G_BF16_TFLOPS + + flops_utilization = flops_utilization / len(input_shapes) + dtype_str = str(dtype).replace("torch.", "") + results.append( + Experiment( + f"mlp_layer_norm_gelu_{dtype_str}", + "flops_utilization", + expected_flops_utilization, + f"{flops_utilization:.02f}", + ) + ) + return results + + +def run_layer_norm(): + dtype_memory_bandwidth_map = { + torch.bfloat16: "1017", + } + input_shapes = [1024, 4096, 8192, 16384] + BS = 4096 + results = [] + for dtype, expected_memory_bandwidth in dtype_memory_bandwidth_map.items(): + memory_bandwidth = 0 + for D in input_shapes: + mod = nn.LayerNorm(D).to("cuda") + + x = torch.randn(BS, D, device="cuda", dtype=dtype) - for i in range(start, num_samples): - torch.cuda.synchronize("cuda") + compiled_mod = torch.compile(mod, dynamic=False) - t0 = time.perf_counter() - mod(x) + for _ in range(WARMUP_ITER): + compiled_mod(x) - if i == -1: - print(f"Compilation time: {time.perf_counter() - t0:.2f} seconds") - continue + us_per_iter = do_bench(lambda: compiled_mod(x)) * 1000 + memory_bandwidth += (1e6 / us_per_iter) * 2 * BS * D * dtype.itemsize / 1e9 + + memory_bandwidth = memory_bandwidth / len(input_shapes) + dtype_str = str(dtype).replace("torch.", "") + results.append( + Experiment( + f"layer_norm_{dtype_str}", + "memory_bandwidth(GB/s)", + expected_memory_bandwidth, + f"{memory_bandwidth:.02f}", + ) + ) + return results + + +@torch._inductor.config.patch(coordinate_descent_tuning=True) +def run_gather_gemv(): + E = 8 + dtype_memory_bandwidth_map = { + torch.int8: "1113", + torch.bfloat16: "1249", + } + input_shapes = [1024, 4096, 8192, 16384] + results = [] + for dtype, expected_memory_bandwidth in dtype_memory_bandwidth_map.items(): + memory_bandwidth = 0 + for D in input_shapes: - torch.cuda.synchronize("cuda") - total_time += time.perf_counter() - t0 + def gather_gemv(W, score_idxs, x): + return W[score_idxs].to(x.dtype) @ x - total_time = total_time / num_samples + W = torch.randn(E, D, D, device="cuda").to(dtype=dtype) + x = torch.randn(D, device="cuda", dtype=torch.bfloat16) + score_idxs = torch.tensor([3, 5], device="cuda") - return total_time + compiled_fn = torch.compile(gather_gemv, dynamic=False) + for _ in range(WARMUP_ITER): + compiled_fn(W, score_idxs, x) -def run_multi_layer_norm(): - class MultiLayerNorm(nn.Module): - def __init__(self, num_layers, normalized_shape, eps=1e-5, bias=True): - super().__init__() - self.num_layers = num_layers - self.norm_layers = nn.ModuleList( - [ - nn.LayerNorm(normalized_shape, eps=eps, bias=bias) - for _ in range(num_layers) - ] + us_per_iter = do_bench(lambda: compiled_fn(W, score_idxs, x)) * 1000 + memory_bandwidth += (1e6 / us_per_iter) * 2 * D * D * dtype.itemsize / 1e9 + + memory_bandwidth = memory_bandwidth / len(input_shapes) + dtype_str = str(dtype).replace("torch.", "") + results.append( + Experiment( + f"gather_gemv_{dtype_str}", + "memory_bandwidth(GB/s)", + expected_memory_bandwidth, + f"{memory_bandwidth:.02f}", ) + ) + return results + + +@torch._inductor.config.patch(coordinate_descent_tuning=True) +def run_gemv(): + dtype_memory_bandwidth_map = { + torch.int8: "990", + torch.bfloat16: "1137", + } + input_shapes = [1024, 4096, 8192, 16384] + results = [] + for dtype, expected_memory_bandwidth in dtype_memory_bandwidth_map.items(): + memory_bandwidth = 0 + for D in input_shapes: + + def gemv(W, x): + return W.to(x.dtype) @ x - def forward(self, x): - for layer_norm in self.norm_layers: - x = layer_norm(x) - return x + W = torch.randn(D, D, device="cuda").to(dtype=dtype) + x = torch.randn(D, device="cuda", dtype=torch.bfloat16) - mod = MultiLayerNorm(num_layers=8, normalized_shape=4096).to("cuda") - mod = torch.compile(mod) - input = torch.randn([512, 1024, 4096], dtype=torch.bfloat16, device="cuda") - inference_time = do_inference(mod, input) + compiled_fn = torch.compile(gemv, dynamic=False) - memory_bandwidth = input.numel() * input.dtype.itemsize / inference_time / 1e9 + for _ in range(WARMUP_ITER): + compiled_fn(W, x) - return [ - Experiment( - "multi_layer_norm", "memory_bandwidth(GB/s)", 92, f"{memory_bandwidth:.02f}" + us_per_iter = do_bench(lambda: compiled_fn(W, x)) * 1000 + memory_bandwidth += (1e6 / us_per_iter) * D * D * dtype.itemsize / 1e9 + + memory_bandwidth = memory_bandwidth / len(input_shapes) + dtype_str = str(dtype).replace("torch.", "") + results.append( + Experiment( + f"gemv_{dtype_str}", + "memory_bandwidth(GB/s)", + expected_memory_bandwidth, + f"{memory_bandwidth:.02f}", + ) ) - ] + return results def output_csv(output_file, headers, row): @@ -100,7 +227,10 @@ def output_csv(output_file, headers, row): run_llama2_7b_int8, run_mixtral_8x7b_int8, # A list of micro-benchmarks. - run_multi_layer_norm, + run_mlp_layer_norm_gelu, + run_layer_norm, + run_gather_gemv, + run_gemv, } From 089e76cca353fe5f90ccdf3a546315530d5caecb Mon Sep 17 00:00:00 2001 From: Aart Bik Date: Fri, 14 Jun 2024 17:05:15 +0000 Subject: [PATCH 008/171] [traced-graph][sparse] remove redundant assert in sparse prop test (#128523) The assertEqualMeta() method already tests that the first argument is a FakeTensor https://github.com/pytorch/pytorch/issues/117188 Pull Request resolved: https://github.com/pytorch/pytorch/pull/128523 Approved by: https://github.com/huydhn --- test/export/test_sparse.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/test/export/test_sparse.py b/test/export/test_sparse.py index e0ed8419a9a180..dc291970648f50 100644 --- a/test/export/test_sparse.py +++ b/test/export/test_sparse.py @@ -172,7 +172,6 @@ def test_sumnet(self, dtype, itype, layout): if i == 0: self.assertEqualMeta(meta, sparse_input) elif i == 1: - self.assertIsInstance(meta, FakeTensor) self.assertEqualMeta(meta, result) else: self.assertEqual(meta, None) @@ -218,7 +217,6 @@ def test_activation_coo(self): for i, node in enumerate(prog.graph.nodes): meta = node.meta.get("val", None) if i <= 2: - self.assertIsInstance(meta, FakeTensor) self.assertEqualMeta(meta, x[i]) elif i <= 5: self.assertEqualMeta(meta, result[i - 3]) From d4807da802cf820df23860ea8de03056216c765e Mon Sep 17 00:00:00 2001 From: cyy Date: Fri, 14 Jun 2024 17:31:21 +0000 Subject: [PATCH 009/171] Various fixes of torch/csrc files (#127252) Fixes #ISSUE_NUMBER Pull Request resolved: https://github.com/pytorch/pytorch/pull/127252 Approved by: https://github.com/r-barnes --- torch/csrc/serialization.cpp | 15 +++++++-------- torch/csrc/utils/tensor_dtypes.cpp | 3 +-- 2 files changed, 8 insertions(+), 10 deletions(-) diff --git a/torch/csrc/serialization.cpp b/torch/csrc/serialization.cpp index 5a6cab19835674..10ba23a656f08c 100644 --- a/torch/csrc/serialization.cpp +++ b/torch/csrc/serialization.cpp @@ -1,5 +1,6 @@ #include #include +#include #include #include @@ -268,32 +269,30 @@ void THPStorage_writeFileRaw( doWrite(fd, data, size_bytes); } else { size_t buffer_size = std::min(numel, (size_t)5000); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays) - std::unique_ptr le_buffer( - new uint8_t[buffer_size * element_size]); + std::vector le_buffer; + le_buffer.resize(buffer_size * element_size); for (size_t i = 0; i < numel; i += buffer_size) { size_t to_convert = std::min(numel - i, buffer_size); - // NOLINTNEXTLINE(bugprone-branch-clone) if (element_size == 2) { torch::utils::THP_encodeInt16Buffer( - (uint8_t*)le_buffer.get(), + le_buffer.data(), (const int16_t*)data + i, torch::utils::THPByteOrder::THP_LITTLE_ENDIAN, to_convert); } else if (element_size == 4) { torch::utils::THP_encodeInt32Buffer( - (uint8_t*)le_buffer.get(), + le_buffer.data(), (const int32_t*)data + i, torch::utils::THPByteOrder::THP_LITTLE_ENDIAN, to_convert); } else if (element_size == 8) { torch::utils::THP_encodeInt64Buffer( - (uint8_t*)le_buffer.get(), + le_buffer.data(), (const int64_t*)data + i, torch::utils::THPByteOrder::THP_LITTLE_ENDIAN, to_convert); } - doWrite(fd, le_buffer.get(), to_convert * element_size); + doWrite(fd, le_buffer.data(), to_convert * element_size); } } } diff --git a/torch/csrc/utils/tensor_dtypes.cpp b/torch/csrc/utils/tensor_dtypes.cpp index 5290392d900f74..e00e41090dfad1 100644 --- a/torch/csrc/utils/tensor_dtypes.cpp +++ b/torch/csrc/utils/tensor_dtypes.cpp @@ -99,8 +99,7 @@ void initializeDtypes() { #define DEFINE_SCALAR_TYPE(_1, n) at::ScalarType::n, - // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays) - at::ScalarType all_scalar_types[] = { + auto all_scalar_types = { AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_SCALAR_TYPE)}; for (at::ScalarType scalarType : all_scalar_types) { From 2357490524de3ef20a1c8791760924eb5e073477 Mon Sep 17 00:00:00 2001 From: Menglu Yu Date: Fri, 14 Jun 2024 17:49:24 +0000 Subject: [PATCH 010/171] [PT2] Enable shape_padding multiplier adjustment (#128346) Summary: Our experiments demonstrate that the current defautl value 1.1 may not be the best multiplier, and we thus enable the adjustment of the value to further improve the QPS. context: https://docs.google.com/document/d/10VjpOJkTv5A4sNX7dD6qT7PyhBxn6LSeLAuaqYtoOto/edit Test Plan: # IG_CTR {F1682138315} Differential Revision: D58373261 Pull Request resolved: https://github.com/pytorch/pytorch/pull/128346 Approved by: https://github.com/jackiexu1992 --- torch/_inductor/fx_passes/pad_mm.py | 13 ++++++++++++- torch/_inductor/fx_passes/split_cat.py | 1 + 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/torch/_inductor/fx_passes/pad_mm.py b/torch/_inductor/fx_passes/pad_mm.py index f7b7977bffc1a5..1095bf6b9214aa 100644 --- a/torch/_inductor/fx_passes/pad_mm.py +++ b/torch/_inductor/fx_passes/pad_mm.py @@ -8,9 +8,11 @@ import torch import torch._inductor.runtime.runtime_utils from torch import Tensor +from torch._dynamo.utils import counters from torch._inductor import utils from torch._subclasses.fake_tensor import FakeTensor from torch.utils._mode_utils import no_dispatch + from ...utils._triton import has_triton from ..pattern_matcher import ( @@ -488,7 +490,16 @@ def write_pad(): # Shape padding introduces additional memory ops. Based on microbenchmarks, 1.1x represents a reasonable # tradeoff between performance improvement from shape padding and overhead from additional memory ops # TODO: Build a learned model which would be better than this heuristic - should_pad = _skip_do_bench_times or ori_time > pad_time * 1.1 + multiplier = 1.1 + if ( + "shape_padding_multiplier" + in torch._inductor.config.post_grad_fusion_options + ): + multiplier = torch._inductor.config.post_grad_fusion_options[ + "shape_padding_multiplier" + ].get("value", 1.1) + counters["inductor"]["shape_padding_multiplier"] += 1 + should_pad = _skip_do_bench_times or ori_time > pad_time * multiplier set_cached_should_pad(key, should_pad) return should_pad diff --git a/torch/_inductor/fx_passes/split_cat.py b/torch/_inductor/fx_passes/split_cat.py index 8a2c571ee612f3..71acbf5cb6739d 100644 --- a/torch/_inductor/fx_passes/split_cat.py +++ b/torch/_inductor/fx_passes/split_cat.py @@ -58,6 +58,7 @@ "normalization_aten_pass", "decompose_mm_pass", "unbind_stack_aten_pass", + "shape_padding_multiplier", ] for pass_name in pre_grad_pass_names: From d7fc871175a7eefb7c348934fe27d6daa710497d Mon Sep 17 00:00:00 2001 From: Peter Bell Date: Fri, 14 Jun 2024 01:23:07 +0100 Subject: [PATCH 011/171] [inductor] Improve superfluous mask handling in triton codegen (#128518) This takes the logic from `filter_masks` and factors it out into `_has_constant_mask`. I also improve support for `persistent_reduction` kernels by making use of the static RBLOCK value and potentially XBLOCK too in the `no_x_dim` case. I then use this helper when generating the `xmask` and `rmask`, so we can generate them as constants meaning triton can optimize them even if they are included. e.g. `compiled_sum(torch.randn(1024, 512, device="cuda"), dim=-1)` before: ```python @triton.jit def triton_(in_ptr0, out_ptr0, xnumel, rnumel): xnumel = 1024 XBLOCK: tl.constexpr = 1 rnumel = 512 RBLOCK: tl.constexpr = 512 xoffset = tl.program_id(0) * XBLOCK xindex = tl.full([1], xoffset, tl.int32) xmask = xindex < xnumel rindex = tl.arange(0, RBLOCK)[:] roffset = 0 rmask = rindex < rnumel r1 = rindex x0 = xindex tmp0 = tl.load(in_ptr0 + (r1 + (512*x0)), rmask & xmask, other=0.0) tmp1 = tl.broadcast_to(tmp0, [RBLOCK]) tmp3 = tl.where(rmask & xmask, tmp1, 0) tmp4 = triton_helpers.promote_to_tensor(tl.sum(tmp3, 0)) tl.store(out_ptr0 + (x0), tmp4, xmask) ``` after: ```python @triton.jit def triton_(in_ptr0, out_ptr0, xnumel, rnumel): xnumel = 1024 XBLOCK: tl.constexpr = 1 rnumel = 512 RBLOCK: tl.constexpr = 512 xoffset = tl.program_id(0) * XBLOCK xindex = tl.full([1], xoffset, tl.int32) xmask = tl.full([RBLOCK], True, tl.int1) rindex = tl.arange(0, RBLOCK)[:] roffset = 0 rmask = tl.full([RBLOCK], True, tl.int1) r1 = rindex x0 = xindex tmp0 = tl.load(in_ptr0 + (r1 + (512*x0)), None) tmp1 = tl.broadcast_to(tmp0, [RBLOCK]) tmp3 = triton_helpers.promote_to_tensor(tl.sum(tmp1, 0)) tl.store(out_ptr0 + (x0), tmp3, None) ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/128518 Approved by: https://github.com/lezcano --- torch/_inductor/codegen/simd.py | 20 +-------- torch/_inductor/codegen/triton.py | 67 ++++++++++++++++++++++++------- 2 files changed, 53 insertions(+), 34 deletions(-) diff --git a/torch/_inductor/codegen/simd.py b/torch/_inductor/codegen/simd.py index f2628ae9c06005..c73e81f904eb37 100644 --- a/torch/_inductor/codegen/simd.py +++ b/torch/_inductor/codegen/simd.py @@ -38,7 +38,7 @@ from ..dependencies import Dep, MemoryDep, StarDep, WeakDep from ..ir import TritonTemplateBuffer from ..optimize_indexing import indexing_dtype_strength_reduction -from ..runtime.hints import ReductionHint, TRITON_MAX_BLOCK +from ..runtime.hints import ReductionHint from ..runtime.runtime_utils import green_text, yellow_text from ..scheduler import BaseSchedulerNode, BaseScheduling, WhyNoFuse from ..utils import ( @@ -685,24 +685,6 @@ def active_range_trees(self, reorder=False): trees[:count] = reversed(trees[:count]) return trees - def filter_masks(self, mask_vars): - for tree in self.range_trees: - # Masks are superfluous if we only have one element - if V.graph.sizevars.statically_known_equals(tree.numel, 1): # type: ignore[arg-type] - mask_vars.discard(f"{tree.prefix}mask") - continue - # Masks are superfluous if numel is a multiple of BLOCK - # (We use the fact that BLOCK is required by triton to be a power of 2) - if tree.prefix.upper() not in TRITON_MAX_BLOCK: - continue - max_block = TRITON_MAX_BLOCK[tree.prefix.upper()] - # Optional optimization: if block divides numel exactly, we will - # never need to do a masked load to handle stragglers at the end. - # It's faster to avoid masking at all. But it is sound to always - # mask. - if V.graph.sizevars.statically_known_multiple_of(tree.numel, max_block): # type: ignore[arg-type] - mask_vars.discard(f"{tree.prefix}mask") - def codegen_indexing(self, expr: sympy.Expr): expr = V.graph.sizevars.simplify_with_ranges(expr, self.var_ranges()) for sym in sorted(expr.free_symbols, key=str): diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py index 254e254ef0e8c9..eab10e2496ad47 100644 --- a/torch/_inductor/codegen/triton.py +++ b/torch/_inductor/codegen/triton.py @@ -61,7 +61,14 @@ SizeArg, TensorArg, ) -from .simd import constant_repr, IterationRangesEntry, pexpr, SIMDKernel, SIMDScheduling +from .simd import ( + constant_repr, + IterationRangesEntry, + IterationRangesRoot, + pexpr, + SIMDKernel, + SIMDScheduling, +) from .triton_utils import config_of, signature_of, signature_to_meta if TYPE_CHECKING: @@ -2266,6 +2273,18 @@ def codegen_kernel(self, name=None): return code.getvalue() + def _get_persistent_RBLOCK(self, rnumel): + rnumel = V.graph.sizevars.simplify(rnumel) + if isinstance(rnumel, (sympy.Integer, int)): + val = int(rnumel) + val = next_power_of_2(val) + else: + val = 128 + while not V.graph.sizevars.statically_known_leq(rnumel, val): + assert val <= 16 * 1024, f"Failed to find static RBLOCK for {rnumel}" + val *= 2 + return val + def codegen_static_numels(self, code): """ We get a small speedup from hard coding numels if they are static. @@ -2290,19 +2309,7 @@ def KERNEL_NAME(in_ptr0, in_ptr1, out_ptr2, xnumel, rnumel, XBLOCK : tl.constexp code.writeline(f"{tree.prefix}numel = {int(simplified_tree_numel)}") if tree.prefix == "r" and self.persistent_reduction: - simplified_tree_numel = V.graph.sizevars.simplify(tree.numel) - if isinstance(simplified_tree_numel, (sympy.Integer, int)): - val = int(simplified_tree_numel) - val = next_power_of_2(val) - else: - val = 128 - while not V.graph.sizevars.statically_known_leq( - simplified_tree_numel, val - ): - assert ( - val <= 16 * 1024 - ), f"Failed to find static RBLOCK for {simplified_tree_numel}" - val *= 2 + val = self._get_persistent_RBLOCK(tree.numel) code.writeline(f"RBLOCK: tl.constexpr = {val}") if tree.prefix == "x" and self.no_x_dim: @@ -2425,6 +2432,31 @@ def iteration_ranges_get_pid(self, entry): return f"{pid}.to({self.index_dtype})" return pid + def _has_constant_mask(self, tree: IterationRangesRoot): + if V.graph.sizevars.statically_known_equals(tree.numel, 1): # type: ignore[arg-type] + return True + # Masks are superfluous if numel is a multiple of BLOCK + # (We use the fact that BLOCK is required by triton to be a power of 2) + if tree.prefix == "r" and self.persistent_reduction: + max_block = self._get_persistent_RBLOCK(tree.numel) + elif tree.prefix == "x" and self.no_x_dim: + max_block = 1 + else: + if tree.prefix.upper() not in TRITON_MAX_BLOCK: + return False + max_block = TRITON_MAX_BLOCK[tree.prefix.upper()] + + # Optional optimization: if block divides numel exactly, we will + # never need to do a masked load to handle stragglers at the end. + # It's faster to avoid masking at all. But it is sound to always + # mask. + return V.graph.sizevars.statically_known_multiple_of(tree.numel, max_block) + + def filter_masks(self, mask_vars): + for tree in self.range_trees: + if self._has_constant_mask(tree): + mask_vars.discard(f"{tree.prefix}mask") + def iteration_ranges_codegen_header(self, entry, code): x = entry.prefix if entry.is_loop: @@ -2444,7 +2476,12 @@ def iteration_ranges_codegen_header(self, entry, code): f"{entry.name} = {line}", ] ) - code.writeline(f"{x}mask = {entry.name} < {x}numel") + + if self._has_constant_mask(entry): + sizes = self.dense_size_str() + code.writeline(f"{x}mask = tl.full({sizes}, True, tl.int1)") + else: + code.writeline(f"{x}mask = {entry.name} < {x}numel") class TritonScheduling(SIMDScheduling): From 2367161e4bbefd45ebd507fdfc523bbf628d32a3 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Fri, 14 Jun 2024 17:57:23 +0000 Subject: [PATCH 012/171] Revert "[ROCm] Unskip scaled_dot_product_attention tests on ROCm (#127966)" This reverts commit c339efaf023b4af056dad4cb2f11c07930ed8af6. Reverted https://github.com/pytorch/pytorch/pull/127966 on behalf of https://github.com/jithunnair-amd due to Broke ROCm CI ([comment](https://github.com/pytorch/pytorch/pull/127966#issuecomment-2168505985)) --- test/functorch/test_ops.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/test/functorch/test_ops.py b/test/functorch/test_ops.py index 30f2a803ffa994..44ff51a7d680d0 100644 --- a/test/functorch/test_ops.py +++ b/test/functorch/test_ops.py @@ -351,6 +351,8 @@ def is_inplace(op, variant): vjp_fail = { xfail("tensor_split"), # data_ptr composite compliance + # https://github.com/pytorch/pytorch/issues/96560 + decorate("nn.functional.scaled_dot_product_attention", decorator=skipIfRocm), } aliasing_ops = { @@ -432,10 +434,7 @@ class TestOperators(TestCase): xfail("view_as_complex"), # query: last dimension must be contiguous # Fused attention kernels require last dim to be contiguous - decorate( - "nn.functional.scaled_dot_product_attention", - decorator=expectedFailureIf(not TEST_WITH_ROCM), - ), # Works on ROCm + xfail("nn.functional.scaled_dot_product_attention"), xfail("torch.ops.aten._flash_attention_forward"), xfail("torch.ops.aten._efficient_attention_forward"), # RuntimeError: Expected contiguous tensor, but got @@ -733,10 +732,7 @@ def maybe_clone_inputs(): xfail("view_as_complex"), # RuntimeError: query: last dimension must be contiguous # The fused attention kernels require the last dim to be contiguous - decorate( - "nn.functional.scaled_dot_product_attention", - decorator=expectedFailureIf(not TEST_WITH_ROCM), - ), # Works on ROCm + xfail("nn.functional.scaled_dot_product_attention"), xfail("torch.ops.aten._flash_attention_forward"), xfail("torch.ops.aten._efficient_attention_forward"), # BUG @@ -994,6 +990,8 @@ def fn(inp, *args, **kwargs): xfail("normal"), # calls random op xfail("normal", "number_mean"), # calls random op xfail("pca_lowrank"), # calls random op + # https://github.com/pytorch/pytorch/issues/96560 + decorate("linalg.pinv", "hermitian", decorator=skipIfRocm), xfail( "quantile", device_type="cpu" ), # Batching rule not implemented for `at::equal` From be0eec9031caddd0860c6b9f8d70efaf91e5a7f6 Mon Sep 17 00:00:00 2001 From: Zhengxu Chen Date: Fri, 14 Jun 2024 17:57:35 +0000 Subject: [PATCH 013/171] [export] Improve static typing in tracer. (#128552) Summary: as title. Test Plan: CI Differential Revision: D58485487 Pull Request resolved: https://github.com/pytorch/pytorch/pull/128552 Approved by: https://github.com/angelayi --- torch/_export/serde/serialize.py | 6 +-- torch/export/_trace.py | 72 +++++++++++++++++++------------- torch/export/exported_program.py | 5 ++- 3 files changed, 49 insertions(+), 34 deletions(-) diff --git a/torch/_export/serde/serialize.py b/torch/_export/serde/serialize.py index ff729ddb3c5ca4..206c8b88cce42c 100644 --- a/torch/_export/serde/serialize.py +++ b/torch/_export/serde/serialize.py @@ -27,8 +27,8 @@ Optional, Set, Tuple, - Union, Type, + Union, ) import sympy @@ -37,6 +37,7 @@ import torch.export.exported_program as ep from torch._export.serde.schema import SchemaVersion from torch._export.verifier import load_verifier +from torch._library.fake_class_registry import FakeScriptObject from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode from torch.fx.experimental import symbolic_shapes from torch.utils import _pytree as pytree @@ -1406,7 +1407,7 @@ class Result: module_call_graph: List[ep.ModuleCallEntry] names_to_symbols: Dict[str, sympy.Symbol] state_dict: Dict[str, Union[torch.Tensor, torch.nn.Parameter]] - constants: Dict[str, Union[torch.Tensor, torch.ScriptObject]] + constants: Dict[str, Union[torch.Tensor, FakeScriptObject, torch.ScriptObject]] example_inputs: Optional[Tuple[Tuple[torch.Tensor, ...], Dict[str, Any]]] def __init__(self): @@ -2247,7 +2248,6 @@ def deserialize( symbol_name_to_range, res.names_to_symbols, ) - model_opset_version: Optional[Dict[str, int]] = exported_program.opset_version return ep.ExportedProgram( root=res.graph_module, diff --git a/torch/export/_trace.py b/torch/export/_trace.py index ee25dbc2e1eade..1542fb1423174b 100644 --- a/torch/export/_trace.py +++ b/torch/export/_trace.py @@ -95,7 +95,7 @@ class ExportDynamoConfig: @dataclasses.dataclass -class ExportedArtifact: +class ATenExportArtifact: gm: torch.fx.GraphModule sig: ExportGraphSignature constants: Dict[ @@ -106,9 +106,14 @@ class ExportedArtifact: torch.ScriptObject, ], ] - out_spec: Optional[TreeSpec] = None # type: ignore[Incompatible types in assignment] - fake_mode: Optional[FakeTensorMode] = None # type: ignore[Incompatible types in assignment] - module_call_specs: Optional[Dict[str, Dict[str, pytree.TreeSpec]]] = None # type: ignore[Incompatible types in assignment] + + +@dataclasses.dataclass(frozen=True) +class ExportArtifact: + aten: ATenExportArtifact + out_spec: TreeSpec + fake_mode: FakeTensorMode + module_call_specs: Dict[str, Dict[str, pytree.TreeSpec]] DEFAULT_EXPORT_DYNAMO_CONFIG = ExportDynamoConfig() @@ -342,7 +347,7 @@ def _get_param_buffer_mapping( def _remap_constants( orig_constant_attrs: ConstantAttrMap, graph_signature: ExportGraphSignature, - constants: Dict[str, Union[torch.Tensor, torch.ScriptObject]], + constants: Dict[str, Union[torch.Tensor, FakeScriptObject, torch.ScriptObject]], ) -> None: """Rewrite the graph signature and constants table to use the FQN from the original module.""" remap_table: Dict[str, List[str]] = {} @@ -552,7 +557,7 @@ def _export_to_aten_ir( transform=lambda x: x, # TODO(zhxchen17) Revisit if this is needed later. pre_dispatch=False, _is_torch_jit_trace=False, -): +) -> ATenExportArtifact: # [NOTE] If the user is exporting under training mode, we want to detect if there is any # state change in the autograd global state and error. If the user is exporting under inference # mode, we don't care. At predispatch level, we don't care about the state change. @@ -719,7 +724,7 @@ def make_argument_spec(i, node) -> ArgumentSpec: constants, ) - return ExportedArtifact( + return ATenExportArtifact( gm, export_graph_signature, constants, @@ -771,7 +776,7 @@ def _rewrite_dynamo_tensor_constants( orig_mod_buffers: Set[torch.Tensor], traced_mod_buffers: Dict[str, torch.Tensor], graph_signature: ExportGraphSignature, - constants: Dict[str, Union[torch.Tensor, torch.ScriptObject]], + constants: Dict[str, Union[torch.Tensor, FakeScriptObject, torch.ScriptObject]], ): """Dynamo erroneously marks tensor attributes on modules as a buffers. @@ -786,13 +791,13 @@ def _rewrite_dynamo_tensor_constants( # Convert it int oa constant in the graph signature, and add its # value to the constants table. spec.kind = InputKind.CONSTANT_TENSOR - constants[spec.target] = value + constants[spec.target] = value # type: ignore[arg-type] def _rewrite_non_persistent_buffers( orig_mod: torch.nn.Module, graph_signature: ExportGraphSignature, - constants: Dict[str, Union[torch.Tensor, torch.ScriptObject]], + constants: Dict[str, Union[torch.Tensor, FakeScriptObject, torch.ScriptObject]], ): """Dynamo erroneously drops the persistent flag on buffers. @@ -805,7 +810,7 @@ def _rewrite_non_persistent_buffers( if spec.target not in state_dict: assert spec.target not in constants spec.persistent = False - constants[spec.target] = orig_mod.get_buffer(spec.target) + constants[spec.target] = orig_mod.get_buffer(spec.target) # type: ignore[arg-type] def _verify_nn_module_stack(graph_module: torch.fx.GraphModule) -> None: @@ -1056,7 +1061,7 @@ def _strict_export( _allow_complex_guards_as_runtime_asserts: bool, _disable_forced_specializations: Optional[bool], _is_torch_jit_trace: bool, -): +) -> ExportArtifact: gm_torch_level = _export_to_torch_ir( mod, args, @@ -1134,6 +1139,7 @@ def _strict_export( # Used to get rid of lint type error. assert out_spec is not None + assert orig_out_spec is not None # aot_export expect the return type to always be a tuple. if out_spec.type not in (list, tuple): @@ -1212,10 +1218,12 @@ def _strict_export( # 5. Rename constants nodes in graph module from buffers to constants _rename_constants_nodes(gm, export_graph_signature) - aten_export_artifact.out_spec = orig_out_spec - aten_export_artifact.fake_mode = dynamo_fake_mode - aten_export_artifact.module_call_specs = gm_torch_level.meta["module_call_specs"] - return aten_export_artifact + return ExportArtifact( + aten=aten_export_artifact, + out_spec=orig_out_spec, + fake_mode=dynamo_fake_mode, + module_call_specs=gm_torch_level.meta["module_call_specs"], + ) def _non_strict_export( @@ -1230,8 +1238,8 @@ def _non_strict_export( _allow_complex_guards_as_runtime_asserts: bool, _disable_forced_specializations: Optional[bool], _is_torch_jit_trace: bool, -): - out_spec = None +) -> ExportArtifact: + out_spec: Optional[TreeSpec] = None module_call_specs: Dict[str, Dict[str, pytree.TreeSpec]] = {} @@ -1347,10 +1355,13 @@ def forward(self, *args, **kwargs): mod, aten_export_artifact.sig, aten_export_artifact.constants ) - aten_export_artifact.out_spec = out_spec - aten_export_artifact.fake_mode = fake_mode - aten_export_artifact.module_call_specs = module_call_specs - return aten_export_artifact + assert out_spec is not None + return ExportArtifact( + aten=aten_export_artifact, + out_spec=out_spec, + fake_mode=fake_mode, + module_call_specs=module_call_specs, + ) @_log_export_wrapper @@ -1452,7 +1463,7 @@ def _export( # Call the appropriate export function based on the strictness of tracing. export_func = _strict_export if strict else _non_strict_export - aten_export_artifact = export_func( + export_artifact = export_func( mod, args, kwargs, @@ -1467,12 +1478,11 @@ def _export( ) # Decompose here for readability. - gm = aten_export_artifact.gm - export_graph_signature = aten_export_artifact.sig - out_spec = aten_export_artifact.out_spec - constants = aten_export_artifact.constants - fake_mode = aten_export_artifact.fake_mode - module_call_specs = aten_export_artifact.module_call_specs + gm = export_artifact.aten.gm + export_graph_signature = export_artifact.aten.sig + out_spec = export_artifact.out_spec + fake_mode = export_artifact.fake_mode + module_call_specs = export_artifact.module_call_specs # Add forward args metadata. gm.meta["forward_arg_names"] = forward_arg_names @@ -1530,6 +1540,8 @@ def _export( _verify_stack_trace(gm) if not _is_torch_jit_trace: _verify_placeholder_names(gm, export_graph_signature) + + assert _EXPORT_MODULE_HIERARCHY is not None exported_program = ExportedProgram( root=gm, graph=gm.graph, @@ -1543,7 +1555,7 @@ def _export( module_call_signatures, ), example_inputs=(args, kwargs), - constants=aten_export_artifact.constants, + constants=export_artifact.aten.constants, ) return exported_program diff --git a/torch/export/exported_program.py b/torch/export/exported_program.py index f781481c589c88..280d3da5ad6823 100644 --- a/torch/export/exported_program.py +++ b/torch/export/exported_program.py @@ -19,6 +19,8 @@ Union, ) +from torch._library.fake_class_registry import FakeScriptObject + from torch.fx.immutable_collections import immutable_dict, immutable_list if TYPE_CHECKING: @@ -217,7 +219,7 @@ def __init__( Dict[str, torch.Tensor] ] = None, # TODO: deprecate this constants: Optional[ - Dict[str, Union[torch.Tensor, torch._C.ScriptObject]] + Dict[str, Union[torch.Tensor, FakeScriptObject, torch._C.ScriptObject]] ] = None, ): # Remove codegen related things from the graph. It should just be a flat graph. @@ -339,6 +341,7 @@ def verifier(self) -> Any: @property @compatibility(is_backward_compatible=False) def dialect(self) -> str: + assert self._verifier is not None return self._verifier.dialect @property From b94c52dd29652d2bee39a8598c9e3f061226ae3e Mon Sep 17 00:00:00 2001 From: Nikita Shulga Date: Fri, 14 Jun 2024 18:23:25 +0000 Subject: [PATCH 014/171] [GHF] Refuse merge to non-default branch (#128710) Unless PR is ghstack one Test plan: ``` % GITHUB_TOKEN=$(gh auth token) python3 -c "from trymerge import GitHubPR; pr=GitHubPR('pytorch', 'pytorch', 128591); print(pr.base_ref(), pr.default_branch())" release/2.4 main ``` Fixes: https://github.com/pytorch/test-infra/issues/5339 Pull Request resolved: https://github.com/pytorch/pytorch/pull/128710 Approved by: https://github.com/seemethere, https://github.com/atalman --- .github/scripts/test_trymerge.py | 3 +++ .github/scripts/trymerge.py | 9 +++++++++ 2 files changed, 12 insertions(+) diff --git a/.github/scripts/test_trymerge.py b/.github/scripts/test_trymerge.py index ec3e69b706f87a..26fc5e7a0eaacb 100755 --- a/.github/scripts/test_trymerge.py +++ b/.github/scripts/test_trymerge.py @@ -180,6 +180,9 @@ def mock_gh_get_info() -> Any: return { "closed": False, "isCrossRepository": False, + "headRefName": "foo", + "baseRefName": "bar", + "baseRepository": {"defaultBranchRef": {"name": "bar"}}, "files": {"nodes": [], "pageInfo": {"hasNextPage": False}}, "changedFiles": 0, } diff --git a/.github/scripts/trymerge.py b/.github/scripts/trymerge.py index 6a6d080a9b3af8..6ece48aadbba9f 100755 --- a/.github/scripts/trymerge.py +++ b/.github/scripts/trymerge.py @@ -2330,6 +2330,15 @@ def handle_exception(e: Exception, title: str = "Merge failed") -> None: dry_run=args.dry_run, ) return + if not pr.is_ghstack_pr() and pr.base_ref() != pr.default_branch(): + gh_post_pr_comment( + org, + project, + args.pr_num, + f"PR targets {pr.base_ref()} rather than {pr.default_branch()}, refusing merge request", + dry_run=args.dry_run, + ) + return if args.check_mergeability: if pr.is_ghstack_pr(): From a6bd154a42a91795c66305d07534510735c3a2a7 Mon Sep 17 00:00:00 2001 From: Colin Peppler Date: Thu, 13 Jun 2024 15:16:04 -0700 Subject: [PATCH 015/171] [inductor] Support mm decomps for matrices with unbacked sizes (#128655) Pull Request resolved: https://github.com/pytorch/pytorch/pull/128655 Approved by: https://github.com/jansel --- test/inductor/test_unbacked_symints.py | 42 ++++++++++++++++++++++++-- torch/_inductor/decomposition.py | 28 +++++++++++------ 2 files changed, 58 insertions(+), 12 deletions(-) diff --git a/test/inductor/test_unbacked_symints.py b/test/inductor/test_unbacked_symints.py index 60ce45317238e1..45ecfcc5fb94f2 100644 --- a/test/inductor/test_unbacked_symints.py +++ b/test/inductor/test_unbacked_symints.py @@ -1,16 +1,16 @@ # Owner(s): ["module: inductor"] +import functools import unittest import torch - from torch._dynamo import config as dynamo_config from torch._inductor import config as inductor_config from torch._inductor.test_case import TestCase as InductorTestCase from torch._inductor.utils import is_big_gpu from torch.testing import make_tensor from torch.testing._internal.common_device_type import instantiate_device_type_tests -from torch.testing._internal.common_utils import IS_LINUX +from torch.testing._internal.common_utils import IS_LINUX, parametrize from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CUDA, skipCUDAIf @@ -214,6 +214,44 @@ def fn(x, y, repeats): torch.testing.assert_close(actual, expected) self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1) + @dynamo_config.patch({"capture_scalar_outputs": True}) + @parametrize( + "torch_fn", [torch.mm, torch.bmm, torch.addmm], name_fn=lambda fn: fn.__name__ + ) + @parametrize("coordinate_descent_tuning", [True, False], name_fn=str) + def test_mm_and_friends(self, device, torch_fn, coordinate_descent_tuning): + if torch_fn == torch.addmm: + torch_fn = functools.partial(torch_fn, torch.ones(1, device=device)) + + def fn(x, w, repeats, is_bmm): + u0 = repeats.item() + torch._check_is_size(u0) + + x_unbacked = x.expand(u0, 32) + w_unbacked = w.expand(32, u0) + if is_bmm: + # Make sure inputs are batched. + x_unbacked = x_unbacked.expand(10, *x_unbacked.shape) + w_unbacked = w_unbacked.expand(10, *w_unbacked.shape) + + return torch_fn(x_unbacked, w_unbacked) + + example_inputs = ( + torch.randn(1, 32, device=device), + torch.randn(32, 1, device=device), + torch.tensor(100, device=device), + torch_fn == torch.bmm, + ) + with inductor_config.patch( + { + # coordinate_descent_tuning has its own path during decomp + "coordinate_descent_tuning": coordinate_descent_tuning, + } + ): + actual = torch.compile(fn, fullgraph=True)(*example_inputs) + expected = fn(*example_inputs) + torch.testing.assert_close(actual, expected) + instantiate_device_type_tests( TestUnbackedSymints, globals(), only_for=(GPU_TYPE, "cpu") diff --git a/torch/_inductor/decomposition.py b/torch/_inductor/decomposition.py index 7cfb2a75ba97ba..ca46c259fff751 100644 --- a/torch/_inductor/decomposition.py +++ b/torch/_inductor/decomposition.py @@ -28,6 +28,7 @@ ELEMENTWISE_TYPE_PROMOTION_KIND, type_to_dtype, ) +from torch.fx.experimental.symbolic_shapes import definitely_true, guard_size_oblivious from . import config, inductor_prims from .utils import ( @@ -202,11 +203,15 @@ def round_dec(x, decimals=0): @pw_cast_for_opmath def bmm(self, batch2): if config.coordinate_descent_tuning: - if self.shape[1] == 1 or batch2.shape[2] == 1: + if guard_size_oblivious(self.shape[1] == 1) or guard_size_oblivious( + batch2.shape[2] == 1 + ): out = (self.unsqueeze(-1) * batch2.unsqueeze(1)).sum(dim=2) return out if self.device.type == "cpu": - if self.size(1) == 1 and batch2.size(-1) == 1: + if guard_size_oblivious(self.size(1) == 1) and guard_size_oblivious( + batch2.size(-1) == 1 + ): counters["inductor"]["decompose_bmm"] += 1 return torch.sum( self.squeeze(1) * batch2.squeeze(-1), dim=1, keepdim=True @@ -218,13 +223,19 @@ def bmm(self, batch2): @pw_cast_for_opmath def addmm(self, mat1, mat2, beta=1, alpha=1): if self.device.type == "cpu": - if mat1.size(0) == 1 and mat2.size(-1) == 1: + if guard_size_oblivious(mat1.size(0) == 1) and guard_size_oblivious( + mat2.size(-1) == 1 + ): counters["inductor"]["decompose_addmm"] += 1 out = torch.sum( mat1.squeeze(0) * mat2.squeeze(-1), dim=0, keepdim=True ).unsqueeze(0) return alpha * out + beta * self - if mat1.size(0) == 1 and mat2.size(0) <= 16 and mat2.size(1) <= 16: + if ( + guard_size_oblivious(mat1.size(0) == 1) + and definitely_true(mat2.size(0) <= 16) + and definitely_true(mat2.size(1) <= 16) + ): counters["inductor"]["decompose_addmm"] += 1 out = (mat1.T * mat2).sum(dim=0, keepdim=True) return alpha * out + beta * self @@ -234,15 +245,12 @@ def addmm(self, mat1, mat2, beta=1, alpha=1): @register_decomposition([aten.mm]) @pw_cast_for_opmath def mm(self, input2): - from torch.fx.experimental.symbolic_shapes import ( - definitely_true, - guard_size_oblivious, - ) - # Our matrix vector multiplies only achieve peak bandwidth with coordinate descent tuning. # todo: Look into why and fix it (hopefully) if config.coordinate_descent_tuning: - if self.shape[0] == 1 or input2.shape[1] == 1: + if guard_size_oblivious(self.shape[0] == 1) or guard_size_oblivious( + input2.shape[1] == 1 + ): return (self.unsqueeze(2) * input2.unsqueeze(0)).sum(dim=1) if self.device.type == "cpu": if ( From 27458cc097fe0113ae67136d27d5437d7d920860 Mon Sep 17 00:00:00 2001 From: Nikita Shulga Date: Fri, 14 Jun 2024 11:23:27 -0700 Subject: [PATCH 016/171] [BE] Refactor repeated code in test_weight_norm (#128726) Pull Request resolved: https://github.com/pytorch/pytorch/pull/128726 Approved by: https://github.com/kit1980 --- test/test_mps.py | 76 +++++++++++++++--------------------------------- 1 file changed, 23 insertions(+), 53 deletions(-) diff --git a/test/test_mps.py b/test/test_mps.py index 0cb479680cf4a0..9ca0cbafebf28a 100644 --- a/test/test_mps.py +++ b/test/test_mps.py @@ -2855,6 +2855,25 @@ def helper(shape, eps=1, momentum=0.1, wts=False, channels_last=False, track_run track_running_stats=track_running_stats, test_module=test_module) def test_weight_norm(self): + def validate_weight_norm_equality(model, cpu_model, x, cpu_x, dim): + cpu_norm = torch.nn.utils.weight_norm(cpu_model, dim=dim) + norm = torch.nn.utils.weight_norm(model, dim=dim) + + cpu_out = cpu_norm(cpu_x) + out = norm(x) + + self.assertEqual(cpu_out, out) + + cpu_grad = torch.randn(cpu_out.shape) + grad = cpu_grad.to('mps') + cpu_out.backward(gradient=cpu_grad) + out.backward(gradient=grad) + + self.assertEqual(cpu_model.weight_g.grad, model.weight_g.grad) + self.assertEqual(cpu_model.weight_v.grad, model.weight_v.grad) + + self.assertEqual(x.grad, cpu_x.grad) + def helper(dim, layer='linear', dtype=torch.float32): # linear layer if layer == 'linear': @@ -2875,24 +2894,7 @@ def helper(dim, layer='linear', dtype=torch.float32): cpu_linear.bias.copy_(cpu_bias) linear.weight.copy_(weight) linear.bias.copy_(bias) - - cpu_norm = torch.nn.utils.weight_norm(cpu_linear, dim=dim) - norm = torch.nn.utils.weight_norm(linear, dim=dim) - - cpu_out = cpu_norm(cpu_x) - out = norm(x) - - self.assertEqual(cpu_out, out) - - cpu_grad = torch.randn(cpu_out.shape) - grad = cpu_grad.to('mps') - cpu_out.backward(gradient=cpu_grad) - out.backward(gradient=grad) - - self.assertEqual(cpu_linear.weight_g.grad, linear.weight_g.grad) - self.assertEqual(cpu_linear.weight_v.grad, linear.weight_v.grad) - - self.assertEqual(x.grad, cpu_x.grad) + validate_weight_norm_equality(linear, cpu_linear, x, cpu_x, dim) # conv layer if layer == 'conv': @@ -2906,25 +2908,9 @@ def helper(dim, layer='linear', dtype=torch.float32): conv.weight.copy_(cpu_conv.weight) conv.bias.copy_(cpu_conv.bias) - cpu_norm = torch.nn.utils.weight_norm(cpu_conv, dim=dim) - norm = torch.nn.utils.weight_norm(conv, dim=dim) - - cpu_out = cpu_norm(cpu_x) - out = norm(x) - - self.assertEqual(cpu_out, out) - - cpu_grad = torch.randn(cpu_out.shape) - grad = cpu_grad.to('mps') - cpu_out.backward(gradient=cpu_grad) - out.backward(gradient=grad) - - self.assertEqual(cpu_conv.weight_g.grad, conv.weight_g.grad) - self.assertEqual(cpu_conv.weight_v.grad, conv.weight_v.grad) - - self.assertEqual(x.grad, cpu_x.grad) + validate_weight_norm_equality(conv, cpu_conv, x, cpu_x, dim) - # conv layer + # conv3d layer if layer == 'conv3d': cpu_x = torch.randn((3, 5, 5, 4), device='cpu', dtype=dtype, requires_grad=True) x = cpu_x.detach().clone().to('mps').requires_grad_() @@ -2936,23 +2922,7 @@ def helper(dim, layer='linear', dtype=torch.float32): conv.weight.copy_(cpu_conv.weight) conv.bias.copy_(cpu_conv.bias) - cpu_norm = torch.nn.utils.weight_norm(cpu_conv, dim=dim) - norm = torch.nn.utils.weight_norm(conv, dim=dim) - - cpu_out = cpu_norm(cpu_x) - out = norm(x) - - self.assertEqual(cpu_out, out) - - cpu_grad = torch.randn(cpu_out.shape) - grad = cpu_grad.to('mps') - cpu_out.backward(gradient=cpu_grad) - out.backward(gradient=grad) - - self.assertEqual(cpu_conv.weight_g.grad, conv.weight_g.grad) - self.assertEqual(cpu_conv.weight_v.grad, conv.weight_v.grad) - - self.assertEqual(x.grad, cpu_x.grad) + validate_weight_norm_equality(conv, cpu_conv, x, cpu_x, dim) helper(0, layer='linear') helper(1, layer='linear') From 9035fff2de29ffd836f432580fa91405041756c1 Mon Sep 17 00:00:00 2001 From: Nikita Shulga Date: Fri, 14 Jun 2024 11:23:30 -0700 Subject: [PATCH 017/171] [BE] Do not test deprecated `torch.nn.utils.weight_norm` (#128727) Test `torch.nn.utils.parametrizations.weight_norm` instead Pull Request resolved: https://github.com/pytorch/pytorch/pull/128727 Approved by: https://github.com/kit1980 ghstack dependencies: #128726 --- test/test_mps.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/test/test_mps.py b/test/test_mps.py index 9ca0cbafebf28a..275013f20effcb 100644 --- a/test/test_mps.py +++ b/test/test_mps.py @@ -2856,8 +2856,8 @@ def helper(shape, eps=1, momentum=0.1, wts=False, channels_last=False, track_run def test_weight_norm(self): def validate_weight_norm_equality(model, cpu_model, x, cpu_x, dim): - cpu_norm = torch.nn.utils.weight_norm(cpu_model, dim=dim) - norm = torch.nn.utils.weight_norm(model, dim=dim) + cpu_norm = torch.nn.utils.parametrizations.weight_norm(cpu_model, dim=dim) + norm = torch.nn.utils.parametrizations.weight_norm(model, dim=dim) cpu_out = cpu_norm(cpu_x) out = norm(x) @@ -2869,8 +2869,8 @@ def validate_weight_norm_equality(model, cpu_model, x, cpu_x, dim): cpu_out.backward(gradient=cpu_grad) out.backward(gradient=grad) - self.assertEqual(cpu_model.weight_g.grad, model.weight_g.grad) - self.assertEqual(cpu_model.weight_v.grad, model.weight_v.grad) + self.assertEqual(cpu_model.parametrizations.weight.original0.grad, model.parametrizations.weight.original0.grad) + self.assertEqual(cpu_model.parametrizations.weight.original1.grad, model.parametrizations.weight.original1.grad) self.assertEqual(x.grad, cpu_x.grad) From d50712e5e35180ff7496d13652f33adbde58dcd3 Mon Sep 17 00:00:00 2001 From: Menglu Yu Date: Fri, 14 Jun 2024 19:45:55 +0000 Subject: [PATCH 018/171] [PT2] add inductor log for unbind_stack_pass (#128684) Summary: Currently, we do not log the pass. To better enable pattern hit inspection, we enable it. Test Plan: see signal Differential Revision: D58571992 Pull Request resolved: https://github.com/pytorch/pytorch/pull/128684 Approved by: https://github.com/dshi7 --- torch/_inductor/fx_passes/split_cat.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torch/_inductor/fx_passes/split_cat.py b/torch/_inductor/fx_passes/split_cat.py index 71acbf5cb6739d..9c7f94d755b3d1 100644 --- a/torch/_inductor/fx_passes/split_cat.py +++ b/torch/_inductor/fx_passes/split_cat.py @@ -603,6 +603,7 @@ def simplify( graph, split_node, next_users, user_inputs_list_new, transform_params_list # type: ignore[arg-type] ) self.erase_old_nodes(graph, split_node, next_users) # type: ignore[arg-type] + counters["inductor"]["unbind_stack_pass"] += 1 def get_user_input_list( self, split_node: torch.fx.Node, next_users: List[torch.fx.Node] From 2e5366fbc04f819c62612e8c56fb786b43c1c67d Mon Sep 17 00:00:00 2001 From: Sanket Jayant Purandare Date: Fri, 14 Jun 2024 19:48:46 +0000 Subject: [PATCH 019/171] Extended Module Tracker (#128508) This is an extension of [ModuleTracker](https://github.com/pytorch/pytorch/blob/main/torch/utils/module_tracker.py) with added features and bug fixes. 1. Allows installing user-defined hooks to be called in pre-fw, post-fw, pre-bw and post-bw hooks of the ``ModTracker``. 2. Adds a function ``get_known_fqn`` that retrieves the fqn of the module as tracked by the ``ModTracker``. 3. Only registers the multi-grad hooks if we are in the forward pass. This is important because, a module's pre-fw and post-fw hooks get called in the backward during AC and we do not want to register multi-grad hooks in this case. 4. Sets the kwarg ``always_call=True`` for post-fw hooks, so that they are called post AC. Pull Request resolved: https://github.com/pytorch/pytorch/pull/128508 Approved by: https://github.com/wanchaol --- test/distributed/_tools/test_mod_tracker.py | 140 ++++++++++++ torch/distributed/_tools/__init__.py | 1 + torch/distributed/_tools/mod_tracker.py | 232 ++++++++++++++++++++ 3 files changed, 373 insertions(+) create mode 100644 test/distributed/_tools/test_mod_tracker.py create mode 100644 torch/distributed/_tools/mod_tracker.py diff --git a/test/distributed/_tools/test_mod_tracker.py b/test/distributed/_tools/test_mod_tracker.py new file mode 100644 index 00000000000000..7625e8456f5a84 --- /dev/null +++ b/test/distributed/_tools/test_mod_tracker.py @@ -0,0 +1,140 @@ +# Owner(s): ["module: unknown"] + +from copy import copy + +import torch +from torch.distributed._tools.mod_tracker import ModTracker +from torch.testing._internal.common_utils import run_tests, TestCase, xfailIfTorchDynamo + + +class TestModTracker(TestCase): + # "https://github.com/pytorch/pytorch/issues/127112 + @xfailIfTorchDynamo + def test_module_hierarchy(self): + seen_fw = [] + seen_bw = [] + + class Foo(torch.nn.Module): + def forward(self, x): + x = x["a"].relu_() + seen_fw.append((copy(tracker.parents), tracker.is_bw)) + x.register_hook( + lambda grad: seen_bw.append((copy(tracker.parents), tracker.is_bw)) + ) + return {"a": torch.mm(x, x)} + + class Mod(torch.nn.Module): + def __init__(self): + super().__init__() + self.a = Foo() + self.b = torch.nn.ModuleDict({"nest": Foo()}) + self.c = torch.nn.ModuleList([Foo()]) + + def forward(self, x): + x = self.c[0](x) + return self.b["nest"](self.a(x)) + + mod = Mod() + + with ModTracker() as tracker: + mod({"a": torch.randn(10, 10, requires_grad=True).clone()})[ + "a" + ].sum().backward() + mod({"a": torch.randn(10, 10, requires_grad=True).clone()})[ + "a" + ].sum().backward() + + self.assertEqual( + seen_fw, + [ + ({"Global", "Mod", "Mod.c.0"}, False), + ({"Global", "Mod", "Mod.a"}, False), + ({"Global", "Mod", "Mod.b.nest"}, False), + ({"Global", "Mod", "Mod.c.0"}, False), + ({"Global", "Mod", "Mod.a"}, False), + ({"Global", "Mod", "Mod.b.nest"}, False), + ], + ) + + self.assertEqual( + seen_bw, + [ + ({"Global", "Mod", "Mod.b.nest"}, True), + ({"Global", "Mod", "Mod.a"}, True), + ({"Global", "Mod", "Mod.c.0"}, True), + ({"Global", "Mod", "Mod.b.nest"}, True), + ({"Global", "Mod", "Mod.a"}, True), + ({"Global", "Mod", "Mod.c.0"}, True), + ], + ) + + def test_bw_detection(self): + mod = torch.nn.Linear(2, 2) + + with ModTracker() as tracker: + mod(torch.rand(2, requires_grad=True)).sum().backward() + self.assertFalse(tracker.is_bw) + self.assertEqual(tracker.parents, {"Global"}) + + @xfailIfTorchDynamo + def test_user_hooks(self): + class Bar(torch.nn.Module): + def __init__(self): + super().__init__() + self.foo = torch.nn.Linear(10, 10) + + def forward(self, x): + return self.foo(x).relu_() + + mt = ModTracker() + test_op = [] + + def hook(mod, hook_name): + mfqn = mt.get_known_fqn(mod) if mod is not None else None + test_op.append((hook_name, mfqn, mfqn in mt.parents, mt.is_bw)) + + mod = Bar() + + mt.register_user_hooks( + lambda m, inp: hook(m, "pre_fw"), + lambda m, inp, op: hook(m, "post_fw"), + lambda m, gop: hook(m, "pre_bw"), + lambda m, ginp: hook(m, "post_bw"), + ) + with mt: + mod(torch.rand(10, 10, requires_grad=True)).sum().backward() + expected_op = [ + ("pre_fw", "Bar", True, False), + ("pre_fw", "Bar.foo", True, False), + ("post_fw", "Bar.foo", True, False), + ("post_fw", "Bar", True, False), + ("pre_bw", "Bar", True, True), + ("pre_bw", "Bar.foo", True, True), + ("post_bw", "Bar", True, True), + ("post_bw", "Bar.foo", True, True), + ] + self.assertEqual(test_op, expected_op) + + with self.assertRaises(AssertionError): + mt.register_user_hooks(lambda x, y: x, None, None, None) + + test_op.clear() + with mt: + loss = mod(torch.rand(10, 10, requires_grad=True)).sum() + del mod + loss.backward() + expected_op = [ + ("pre_fw", "Bar", True, False), + ("pre_fw", "Bar.foo", True, False), + ("post_fw", "Bar.foo", True, False), + ("post_fw", "Bar", True, False), + ("pre_bw", None, False, True), + ("pre_bw", None, False, True), + ("post_bw", None, False, True), + ("post_bw", None, False, True), + ] + self.assertEqual(test_op, expected_op) + + +if __name__ == "__main__": + run_tests() diff --git a/torch/distributed/_tools/__init__.py b/torch/distributed/_tools/__init__.py index eda274b5724f0f..b8a8950f9cd0a0 100644 --- a/torch/distributed/_tools/__init__.py +++ b/torch/distributed/_tools/__init__.py @@ -1 +1,2 @@ from .memory_tracker import MemoryTracker +from .mod_tracker import ModTracker diff --git a/torch/distributed/_tools/mod_tracker.py b/torch/distributed/_tools/mod_tracker.py new file mode 100644 index 00000000000000..07ebc32cdbcc63 --- /dev/null +++ b/torch/distributed/_tools/mod_tracker.py @@ -0,0 +1,232 @@ +# mypy: allow-untyped-defs +import warnings +import weakref +from typing import Callable, Optional, Set + +import torch +from torch.autograd.graph import register_multi_grad_hook +from torch.nn.modules.module import ( + register_module_forward_hook, + register_module_forward_pre_hook, +) +from torch.utils._pytree import tree_flatten + + +__all__ = ["ModTracker"] + + +class ModTracker: + """ + ``ModTracker`` is a context manager that tracks the nn.Module hierarchy during execution + so that other system can query which Module is currently being executed (or its backward is being + executed). + + You can access the ``parents`` attribute on this context manager to get the set of all the + Modules currently being executed via their fqn (fully qualified name, also used as the key within + the state_dict). + You can access the ``is_bw`` attribute to know if you are currently running in backward or not. + + Note that ``parents`` is never empty and always contains the "Global" key. The ``is_bw`` flag + will remain ``True`` after the forward until another Module is executed. If you need it to be + more accurate, please submit an issue requesting this. Adding a map from fqn to the module instance + is possible but not done yet, please submit an issue requesting this if you need it. + + Example usage + + .. code-block:: python + + mod = torch.nn.Linear(2, 2) + + with ModTracker() as tracker: + # Access anything during the forward pass + def my_linear(m1, m2, bias): + print(f"Current modules: {tracker.parents}") + return torch.mm(m1, m2.t()) + bias + torch.nn.functional.linear = my_linear + + mod(torch.rand(2, 2)) + + """ + + parents: Set[str] + """ + A Set containing the fqn for each module currently running their forward + """ + + def __init__(self): + self.parents = {"Global"} + self._known_modules: weakref.WeakKeyDictionary = weakref.WeakKeyDictionary() + self._seen_modules: weakref.WeakSet = weakref.WeakSet() + self._has_callback = False + self._user_pre_fw_hook = None + self._user_post_fw_hook = None + self._user_pre_bw_hook = None + self._user_post_bw_hook = None + + def _maybe_set_engine_callback(self): + # This assumes no concurrent calls to backward + if self._has_callback: + return + + def callback(): + self.parents = {"Global"} + self._has_callback = False + + torch.autograd.Variable._execution_engine.queue_callback(callback) + self._has_callback = True + + @property + def is_bw(self): + """ + A boolean marking if this is currently running during the backward pass or not + """ + return torch._C._current_graph_task_id() != -1 + + def get_known_fqn(self, mod): + """ + Return the fqn for the given module if it is known to the ``ModTracker``, otherwise ``None``. + """ + return self._known_modules.get(mod, None) + + def register_user_hooks( + self, + pre_fw_hook: Optional[Callable] = None, + post_fw_hook: Optional[Callable] = None, + pre_bw_hook: Optional[Callable] = None, + post_bw_hook: Optional[Callable] = None, + ): + """ + Registers user-specified hooks to be called before/after the forward/backward pass for each + module tracked by the ``ModTracker``. One or more can be ``None``. + Args: + pre_fw_hook (Callable, optional): A hook to be called before the forward pass for the + module. It should have the following signature: + pre_fw_hook (module, input) -> None + post_fw_hook (Callable, optional): A hook to be called after the forward pass for the + module. It should have the following signature: + post_fw_hook (module, input, output) -> None + pre_bw_hook (Callable, optional): A multi-grad hook to be called on all the outputs of + the module that require gradients. It should have the following signature: + pre_bw_hook (module, grad_output) -> None + post_bw_hook (Callable, optional): A multi-grad hook to be called on all the inputs of + the module that require gradients. It should have the following signature: + post_bw_hook (module, grad_input) -> None + Raises: + AssertionError: If a new hook is provided when one is already registered. + Note: + If the module is not alive during the backward pass, the pre_bw_hook and post_bw_hook will + will receive None as the module argument. + The module fqn will be present in the ``parents`` attribute when each of the hooks is called. + Hooks are intended to be used as markers only not to modify the inputs/outputs. + """ + + def set_hook(hook, user_hook, hook_name): + if hook is not None and user_hook is not None: + raise AssertionError( + f"Only one {hook_name} can be registered at a time" + f" Clear the existing hook by calling ``clear_user_hooks`` before registering a new one" + ) + return hook + + self._user_pre_fw_hook = set_hook( + pre_fw_hook, self._user_pre_fw_hook, "pre_fw_hook" + ) + self._user_post_fw_hook = set_hook( + post_fw_hook, self._user_post_fw_hook, "post_fw_hook" + ) + self._user_pre_bw_hook = set_hook( + pre_bw_hook, self._user_pre_bw_hook, "pre_bw_hook" + ) + self._user_post_bw_hook = set_hook( + post_bw_hook, self._user_post_bw_hook, "post_bw_hook" + ) + + def clear_user_hooks(self): + """ + Clears the user specified hooks registered with ``register_user_hooks`` + """ + self._user_pre_fw_hook = None + self._user_post_fw_hook = None + self._user_pre_bw_hook = None + self._user_post_bw_hook = None + + def _get_mod_name(self, mod): + if mod not in self._known_modules: + self._known_modules[mod] = type(mod).__name__ + mod_name = self._known_modules[mod] + if mod not in self._seen_modules: + for name, submod in mod.named_children(): + self._known_modules[submod] = f"{mod_name}.{name}" + self._get_mod_name(submod) + self._seen_modules.add(mod) + return mod_name + + def _get_append_fn(self, w_mod, name, is_bw): + def fn(*args): + if is_bw: + self._maybe_set_engine_callback() + if name in self.parents and not self.is_bw: + + def custom_formatwarning(msg, category, filename, lineno, line=None): + return f"{filename}:{lineno}: {category.__name__}: {msg} \n" + + warnings.formatwarning = custom_formatwarning + warnings.warn( + "The module hierarchy tracking maybe be messed up." + " Please file a bug to PyTorch, if it is the case." + ) + self.parents.add(name) + + if self._user_pre_bw_hook is not None and is_bw: + self._user_pre_bw_hook(w_mod(), args) + + return fn + + def _get_pop_fn(self, w_mod, name, is_bw): + def fn(*args): + if self._user_post_bw_hook is not None and is_bw: + self._user_post_bw_hook(w_mod(), args) + + if name in self.parents: + self.parents.remove(name) + elif not is_bw: + # Due to some input/output not requiring gradients, we cannot enforce + # proper nesting in backward + raise RuntimeError( + "The Module hierarchy tracking is wrong. Report a bug to PyTorch" + ) + + return fn + + def _fw_pre_hook(self, mod, input): + name = self._get_mod_name(mod) + w_mod = weakref.ref(mod) + self._get_append_fn(w_mod, name, False)() + if self._user_pre_fw_hook is not None: + self._user_pre_fw_hook(mod, input) + args, _ = tree_flatten(input) + tensors = [a for a in args if isinstance(a, torch.Tensor) and a.requires_grad] + if not self.is_bw and tensors: + register_multi_grad_hook(tensors, self._get_pop_fn(w_mod, name, True)) + + def _fw_post_hook(self, mod, input, output): + name = self._get_mod_name(mod) + w_mod = weakref.ref(mod) + if self._user_post_fw_hook is not None: + self._user_post_fw_hook(mod, input, output) + self._get_pop_fn(w_mod, name, False)() + args, _ = tree_flatten(output) + tensors = [a for a in args if isinstance(a, torch.Tensor) and a.requires_grad] + if not self.is_bw and tensors: + register_multi_grad_hook(tensors, self._get_append_fn(w_mod, name, True)) + + def __enter__(self): + self._fw_pre_handle = register_module_forward_pre_hook(self._fw_pre_hook) + self._fw_post_handle = register_module_forward_hook( + self._fw_post_hook, always_call=True + ) + return self + + def __exit__(self, *args): + self._fw_pre_handle.remove() + self._fw_post_handle.remove() From 9c773321160fbf9b578f4ab2cfc592d76f4f89e8 Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Fri, 14 Jun 2024 09:41:11 -0700 Subject: [PATCH 020/171] [torch.compile][ci] Flaky models in CI (similar to DISABLED_TEST) (#128715) These models are really flaky. I went into the CI machine and ran the model many times, sometime it fails, sometimes it passes. Even Pytorch-eager results change from run to run, so the accuracy comparison is fundamentally broken/non-deterministic. I am hitting these issues more frequently in inlining work. There is nothing wrong with inlining, I think these models are on the edge of already-broken accuracy measurement, and inlining is just pushing it in more broken direction. Pull Request resolved: https://github.com/pytorch/pytorch/pull/128715 Approved by: https://github.com/eellison --- benchmarks/dynamo/check_accuracy.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/benchmarks/dynamo/check_accuracy.py b/benchmarks/dynamo/check_accuracy.py index da82f789a10388..8cbc18658ee6af 100644 --- a/benchmarks/dynamo/check_accuracy.py +++ b/benchmarks/dynamo/check_accuracy.py @@ -6,6 +6,14 @@ import pandas as pd +# Hack to have something similar to DISABLED_TEST. These models are flaky. + +flaky_models = { + "yolov3", + "gluon_inception_v3", +} + + def get_field(csv, model_name: str, field: str): try: return csv.loc[csv["name"] == model_name][field].item() @@ -25,6 +33,13 @@ def check_accuracy(actual_csv, expected_csv, expected_filename): status = "PASS" if expected_accuracy == "pass" else "XFAIL" print(f"{model:34} {status}") continue + elif model in flaky_models: + if accuracy == "pass": + # model passed but marked xfailed + status = "PASS_BUT_FLAKY:" + else: + # model failed but marked passe + status = "FAIL_BUT_FLAKY:" elif accuracy != "pass": status = "FAIL:" failed.append(model) From 1aafb9eb907e811dedb5fcdcd8af0d4649ee1406 Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Fri, 14 Jun 2024 09:41:11 -0700 Subject: [PATCH 021/171] [dynamo][yolov3] Track UnspecializedNNModuleVariable for mutation (#128269) Fixes https://github.com/pytorch/pytorch/issues/101168 Pull Request resolved: https://github.com/pytorch/pytorch/pull/128269 Approved by: https://github.com/jansel ghstack dependencies: #128715 --- .../aot_eager_torchbench_inference.csv | 2 +- .../aot_eager_torchbench_training.csv | 2 +- .../aot_inductor_torchbench_inference.csv | 2 +- .../cpu_inductor_torchbench_freezing_inference.csv | 2 +- .../cpu_inductor_torchbench_inference.csv | 2 +- .../cu124/aot_inductor_torchbench_inference.csv | 2 +- .../dynamic_aot_eager_torchbench_inference.csv | 2 +- .../dynamic_aot_eager_torchbench_training.csv | 2 +- .../dynamic_cpu_inductor_torchbench_inference.csv | 2 +- .../dynamic_inductor_torchbench_inference.csv | 2 +- .../dynamic_inductor_torchbench_training.csv | 2 +- .../dynamo_eager_torchbench_inference.csv | 2 +- .../dynamo_eager_torchbench_training.csv | 2 +- .../inductor_torchbench_inference.csv | 2 +- .../inductor_torchbench_training.csv | 2 +- test/dynamo/test_modules.py | 13 +++++++++++++ torch/_dynamo/output_graph.py | 4 +++- 17 files changed, 31 insertions(+), 16 deletions(-) diff --git a/benchmarks/dynamo/ci_expected_accuracy/aot_eager_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/aot_eager_torchbench_inference.csv index fb98fd4e523944..9863aa7da6a252 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/aot_eager_torchbench_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/aot_eager_torchbench_inference.csv @@ -378,4 +378,4 @@ vision_maskrcnn,pass,17 -yolov3,pass,2 +yolov3,pass,0 diff --git a/benchmarks/dynamo/ci_expected_accuracy/aot_eager_torchbench_training.csv b/benchmarks/dynamo/ci_expected_accuracy/aot_eager_torchbench_training.csv index a62389fad7dadf..4055eda462c5b4 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/aot_eager_torchbench_training.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/aot_eager_torchbench_training.csv @@ -286,4 +286,4 @@ vision_maskrcnn,pass,34 -yolov3,pass,9 +yolov3,pass,8 diff --git a/benchmarks/dynamo/ci_expected_accuracy/aot_inductor_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/aot_inductor_torchbench_inference.csv index 65c905837c2a73..1624d6dc7973f7 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/aot_inductor_torchbench_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/aot_inductor_torchbench_inference.csv @@ -350,4 +350,4 @@ vision_maskrcnn,fail_to_run,0 -yolov3,fail_to_run,0 +yolov3,pass,0 diff --git a/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_torchbench_freezing_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_torchbench_freezing_inference.csv index 3942e3a2f34318..3af215541c1dfb 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_torchbench_freezing_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_torchbench_freezing_inference.csv @@ -338,4 +338,4 @@ vision_maskrcnn,pass,28 -yolov3,pass,2 +yolov3,pass,0 diff --git a/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_torchbench_inference.csv index fcd87f4d245450..a497fb45d7d48c 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_torchbench_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_torchbench_inference.csv @@ -338,4 +338,4 @@ vision_maskrcnn,pass,28 -yolov3,pass,2 +yolov3,pass,0 diff --git a/benchmarks/dynamo/ci_expected_accuracy/cu124/aot_inductor_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/cu124/aot_inductor_torchbench_inference.csv index 65c905837c2a73..1624d6dc7973f7 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/cu124/aot_inductor_torchbench_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/cu124/aot_inductor_torchbench_inference.csv @@ -350,4 +350,4 @@ vision_maskrcnn,fail_to_run,0 -yolov3,fail_to_run,0 +yolov3,pass,0 diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_torchbench_inference.csv index 3507221429fa2c..3aecea06b53009 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_torchbench_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_torchbench_inference.csv @@ -374,4 +374,4 @@ vision_maskrcnn,pass,17 -yolov3,pass,2 +yolov3,pass,0 diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_torchbench_training.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_torchbench_training.csv index c05d5f175c7db4..a9a7e396c0f3bb 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_torchbench_training.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_torchbench_training.csv @@ -282,4 +282,4 @@ vision_maskrcnn,pass,34 -yolov3,pass,9 +yolov3,pass,8 diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_inductor_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_inductor_torchbench_inference.csv index ce271939b18cc9..5ffc870a8dec1f 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_inductor_torchbench_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_inductor_torchbench_inference.csv @@ -298,4 +298,4 @@ vision_maskrcnn,pass,28 -yolov3,pass,2 +yolov3,pass,0 diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_torchbench_inference.csv index 5c2f70cba3d49f..c167ea680d2ca6 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_torchbench_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_torchbench_inference.csv @@ -374,4 +374,4 @@ vision_maskrcnn,pass,17 -yolov3,pass,2 +yolov3,pass,0 diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_torchbench_training.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_torchbench_training.csv index 5d45a95c8f19b2..c25fa947133749 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_torchbench_training.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_torchbench_training.csv @@ -282,4 +282,4 @@ vision_maskrcnn,pass,34 -yolov3,pass,9 +yolov3,pass,8 diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_torchbench_inference.csv index fb98fd4e523944..9863aa7da6a252 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_torchbench_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_torchbench_inference.csv @@ -378,4 +378,4 @@ vision_maskrcnn,pass,17 -yolov3,pass,2 +yolov3,pass,0 diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_torchbench_training.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_torchbench_training.csv index a62389fad7dadf..4055eda462c5b4 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_torchbench_training.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_torchbench_training.csv @@ -286,4 +286,4 @@ vision_maskrcnn,pass,34 -yolov3,pass,9 +yolov3,pass,8 diff --git a/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_inference.csv index 1e82186e1a88e7..74549205d747a1 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_inference.csv @@ -378,4 +378,4 @@ vision_maskrcnn,pass,17 -yolov3,pass,2 +yolov3,pass,0 diff --git a/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_training.csv b/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_training.csv index a62389fad7dadf..4055eda462c5b4 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_training.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_training.csv @@ -286,4 +286,4 @@ vision_maskrcnn,pass,34 -yolov3,pass,9 +yolov3,pass,8 diff --git a/test/dynamo/test_modules.py b/test/dynamo/test_modules.py index dbfef8af438657..4e5d40c26524ac 100644 --- a/test/dynamo/test_modules.py +++ b/test/dynamo/test_modules.py @@ -2512,6 +2512,19 @@ def forward(self, x): self.assertEqual(eager_res, optim_res) self.assertEqual(cnt.frame_count, 1) + def test_module_setattr(self): + models = torch.nn.Sequential(torch.nn.Linear(3, 3)) + models[0].abc = False + + def run(): + models[0].abc = True + x = torch.randn(1, 3) + return models(x) + + run = torch.compile(run, fullgraph=True) + run() + self.assertTrue(models[0].abc) + def test_assign_does_not_exist(self): class MyModule(torch.nn.Module): def forward(self, x): diff --git a/torch/_dynamo/output_graph.py b/torch/_dynamo/output_graph.py index 618558d93a242a..d7dac331c74358 100644 --- a/torch/_dynamo/output_graph.py +++ b/torch/_dynamo/output_graph.py @@ -752,7 +752,9 @@ def register_attr_or_module( **options, ): if is_dynamic_nn_module(target, self.root_tx.export): - return variables.UnspecializedNNModuleVariable(target, **options) + # Instead of returning UnspecializedNNModuleVariable, call + # VariableBuilder so that it is tracked for mutation. + return VariableBuilder(self.current_tx, **options)(target) options = dict(options) assert "source" in options From 4b96575a09b167489539d8916e86cf6c13e9544d Mon Sep 17 00:00:00 2001 From: Simon Fan Date: Fri, 14 Jun 2024 06:44:20 -0700 Subject: [PATCH 022/171] [dynamo][aot autograd] Silently disable default saved tensor hooks during tracing (#123196) FIXES #113263. Same idea as in https://github.com/pytorch/pytorch/pull/113417, but we need a more intrusive C API to silently nop default saved tensor hooks, in order to support user-code that use torch.autograd.disable_saved_tensors_hooks (see test_unpack_hooks_can_be_disabled). We mock the output of get_hooks while leaving push/pop untouched. For compiled autograd, we're firing pack hooks once and unpack hooks twice right now, I'll look into this separately from this issue. Pull Request resolved: https://github.com/pytorch/pytorch/pull/123196 Approved by: https://github.com/soulitzer --- aten/src/ATen/SavedTensorHooks.cpp | 21 +++--- aten/src/ATen/SavedTensorHooks.h | 18 ++++- test/dynamo/test_higher_order_ops.py | 17 ++--- test/dynamo/test_repros.py | 61 +++++++++++++++ ..._nested_checkpoint_kwargs_early_stop_False | 0 ...tensor_inputs_and_outputs_early_stop_False | 0 ...point_reentrant_backwards_early_stop_False | 0 ...ted_checkpoint_same_graph_early_stop_False | 0 test/inductor/test_compiled_autograd.py | 1 + test/test_autograd.py | 28 +++++++ test/test_functionalization_of_rng_ops.py | 4 +- torch/_C/_autograd.pyi | 1 + torch/_dynamo/__init__.py | 1 + torch/_dynamo/utils.py | 10 +++ torch/_dynamo/variables/builder.py | 75 ++++++++++--------- torch/_functorch/aot_autograd.py | 10 ++- torch/csrc/autograd/init.cpp | 3 + .../autograd/python_saved_variable_hooks.cpp | 9 +-- 18 files changed, 193 insertions(+), 66 deletions(-) delete mode 100644 test/dynamo_expected_failures/TestNestedCheckpoint.test_nested_checkpoint_kwargs_early_stop_False delete mode 100644 test/dynamo_expected_failures/TestNestedCheckpoint.test_nested_checkpoint_non_tensor_inputs_and_outputs_early_stop_False delete mode 100644 test/dynamo_expected_failures/TestNestedCheckpoint.test_nested_checkpoint_reentrant_backwards_early_stop_False delete mode 100644 test/dynamo_expected_failures/TestNestedCheckpoint.test_nested_checkpoint_same_graph_early_stop_False diff --git a/aten/src/ATen/SavedTensorHooks.cpp b/aten/src/ATen/SavedTensorHooks.cpp index f2fb0642eb34c5..7aa9b0f02ea365 100644 --- a/aten/src/ATen/SavedTensorHooks.cpp +++ b/aten/src/ATen/SavedTensorHooks.cpp @@ -35,6 +35,12 @@ void SavedTensorDefaultHooks::enable() { tls.disabled_error_message = c10::nullopt; } +/* static */ bool SavedTensorDefaultHooks::set_tracing(bool is_tracing) { + bool prior = tls.is_tracing; + tls.is_tracing = is_tracing; + return prior; +} + const std::optional& SavedTensorDefaultHooks::get_disabled_error_message() { return tls.disabled_error_message; } @@ -59,25 +65,20 @@ void SavedTensorDefaultHooks::push_hooks(PyObject* pack_hook, PyObject* unpack_h tls.stack.emplace(pack_hook, unpack_hook); } -void SavedTensorDefaultHooks::pop_hooks() { +std::pair SavedTensorDefaultHooks::pop_hooks() { // Reference counting is handled by the caller of `pop_hooks` TORCH_INTERNAL_ASSERT(is_initialized && !tls.stack.empty()); + std::pair hooks = tls.stack.top(); tls.stack.pop(); + return hooks; } std::pair SavedTensorDefaultHooks::get_hooks() { - if (!is_initialized || tls.stack.empty()) { + // For tls.is_tracing, see NOTE: [Deferring tensor pack/unpack hooks until runtime] + if (!is_initialized || tls.stack.empty() || tls.is_tracing) { return std::make_pair(nullptr, nullptr); } return tls.stack.top(); } -std::stack> SavedTensorDefaultHooks::get_stack() { - return tls.stack; -} - -void SavedTensorDefaultHooks::set_stack(std::stack> stack_) { - tls.stack = std::move(stack_); -} - } diff --git a/aten/src/ATen/SavedTensorHooks.h b/aten/src/ATen/SavedTensorHooks.h index 6ad46a8334c3f0..b69b9c25e8e6a5 100644 --- a/aten/src/ATen/SavedTensorHooks.h +++ b/aten/src/ATen/SavedTensorHooks.h @@ -22,17 +22,18 @@ struct TORCH_API SavedTensorDefaultHooksTLS { // We did this for efficiency (so we didn't have to keep a separate bool // around) std::optional disabled_error_message; + + // See NOTE: [Deferring tensor pack/unpack hooks until runtime] + bool is_tracing = false; }; } // namespace impl struct TORCH_API SavedTensorDefaultHooks { static void push_hooks(PyObject* pack_hook, PyObject* unpack_hook); - static void pop_hooks(); + static std::pair pop_hooks(); static std::pair get_hooks(); static void lazy_initialize(); - static std::stack> get_stack(); - static void set_stack(std::stack>); static const impl::SavedTensorDefaultHooksTLS& get_tls_state(); static void set_tls_state(const impl::SavedTensorDefaultHooksTLS& tls); @@ -42,11 +43,20 @@ struct TORCH_API SavedTensorDefaultHooks { // hooks, especially if their feature does not work with it. If they are // disabled, then the following will raise an error: // - Attempting to push_hooks - // - calling disable(message) with a non-zero stack (from get_stack) size + // - calling disable(message) with a non-zero stack (hooks) size static void disable(const std::string& error_message); static void enable(); static bool is_enabled(); static const std::optional& get_disabled_error_message(); + + // NOTE: [Deferring tensor pack/unpack hooks until runtime] + // To preserve eager semantics of pack/unpack hooks firing only once per saved + // variable, Dynamo/AOTAutograd need to defer hook firing until runtime. Using + // disable() would loud error at trace time, and pushing a no-op hook would + // fail when the traced code is wrapped in a disable_saved_tensors_hooks ctx. + // To do so, we disable these hooks during tracing. See + // https://github.com/pytorch/pytorch/issues/113263. + static bool set_tracing(bool is_tracing); }; } // namespace at diff --git a/test/dynamo/test_higher_order_ops.py b/test/dynamo/test_higher_order_ops.py index 410317d33a1472..dca6d28d1912dd 100644 --- a/test/dynamo/test_higher_order_ops.py +++ b/test/dynamo/test_higher_order_ops.py @@ -45,7 +45,8 @@ def check_dynamic_shape_capture(): def count_ops(gm, args, freq, op): - assert [node.target for node in gm.graph.nodes].count(op) == freq + actual = [node.target for node in gm.graph.nodes].count(op) + assert actual == freq, f"expected={freq}, actual={actual}" return gm @@ -6049,9 +6050,7 @@ def fn(x, y): y = torch.randn(4, 4, requires_grad=True) fw_compiler = functools.partial(count_ops, freq=1, op=torch.ops.aten.mm.default) - bw_compiler = functools.partial( - count_ops, freq=3, op=torch.ops.aten.mm.default - ) # mm recomputed in the bwd + bw_compiler = functools.partial(count_ops, freq=2, op=torch.ops.aten.mm.default) backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler) self._validate(fn, backend, x, y) @@ -6074,9 +6073,7 @@ def fn(x, y): y = torch.randn(4, 4, requires_grad=True) fw_compiler = functools.partial(count_ops, freq=1, op=torch.ops.aten.mm.default) - bw_compiler = functools.partial( - count_ops, freq=3, op=torch.ops.aten.mm.default - ) # mm recomputed in the bwd + bw_compiler = functools.partial(count_ops, freq=2, op=torch.ops.aten.mm.default) backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler) self._validate(fn, backend, x, y) @@ -6097,8 +6094,9 @@ def fn(x, y): fw_compiler = functools.partial( count_ops, freq=1, op=torch.ops.rngprims.philox_rand.default ) + # philox_rand is passed from fwd bw_compiler = functools.partial( - count_ops, freq=1, op=torch.ops.rngprims.philox_rand.default + count_ops, freq=0, op=torch.ops.rngprims.philox_rand.default ) backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler) self._validate( @@ -6178,8 +6176,9 @@ def fn(x): fw_compiler = functools.partial( count_ops, freq=1, op=torch.ops.aten.sigmoid.default ) + # sigmoid passed from fwd bw_compiler = functools.partial( - count_ops, freq=1, op=torch.ops.aten.sigmoid.default + count_ops, freq=0, op=torch.ops.aten.sigmoid.default ) backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler) self._validate(fn, backend, x) diff --git a/test/dynamo/test_repros.py b/test/dynamo/test_repros.py index 4de748ccf7a152..c30210a398407b 100644 --- a/test/dynamo/test_repros.py +++ b/test/dynamo/test_repros.py @@ -1079,6 +1079,67 @@ def f(x): out_test.sum().backward() self.assertEqual(leaf.grad, leaf_test.grad) + # https://github.com/pytorch/pytorch/issues/113263 + def test_unpack_hooks_dont_run_during_tracing(self): + def f(x, y): + return x * y + + f_compiled = torch.compile(f, backend="aot_eager") + + pack_count = 0 + unpack_count = 0 + + def pack_hook(x): + nonlocal pack_count + pack_count += 1 + return x + + # unpack hook shouldn't run during compilation, while we trace the forward + def unpack_hook(x): + nonlocal unpack_count + unpack_count += 1 + return x + + x = torch.ones(4, requires_grad=True) + y = torch.ones(4, requires_grad=False) + with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook): + out_test = f_compiled(x, y) + self.assertEqual(pack_count, 1) + self.assertEqual(unpack_count, 0) + out_test.sum().backward() + self.assertEqual(pack_count, 1) + self.assertEqual(unpack_count, 1) + + # https://github.com/pytorch/pytorch/issues/113263 + def test_unpack_hooks_can_be_disabled(self): + def f(x, y): + return x * y + + f_compiled = torch.compile(f, backend="aot_eager") + + x = torch.ones(4, requires_grad=True) + y = torch.ones(4, requires_grad=False) + with torch.autograd.graph.disable_saved_tensors_hooks("hooks are disabled"): + out_test = f_compiled(x, y) + out_test.sum().backward() + + # https://github.com/pytorch/pytorch/issues/113263 + def test_disabling_unpack_hooks_within_compiled_region(self): + def g(z): + with torch.autograd.graph.disable_saved_tensors_hooks("hooks are disabled"): + return z + 5 + + def f(x, y): + z = x * y + return g(z) + + f_compiled = torch.compile(f, backend="aot_eager") + + x = torch.ones(4, requires_grad=True) + y = torch.ones(4, requires_grad=False) + out_test = f_compiled(x, y) + out_test.sum().backward() + # See https://github.com/pytorch/pytorch/issues/97745 def test_gan_repro_trying_to_backward_through_the_graph_a_second_time(self): def f(a, b): diff --git a/test/dynamo_expected_failures/TestNestedCheckpoint.test_nested_checkpoint_kwargs_early_stop_False b/test/dynamo_expected_failures/TestNestedCheckpoint.test_nested_checkpoint_kwargs_early_stop_False deleted file mode 100644 index e69de29bb2d1d6..00000000000000 diff --git a/test/dynamo_expected_failures/TestNestedCheckpoint.test_nested_checkpoint_non_tensor_inputs_and_outputs_early_stop_False b/test/dynamo_expected_failures/TestNestedCheckpoint.test_nested_checkpoint_non_tensor_inputs_and_outputs_early_stop_False deleted file mode 100644 index e69de29bb2d1d6..00000000000000 diff --git a/test/dynamo_expected_failures/TestNestedCheckpoint.test_nested_checkpoint_reentrant_backwards_early_stop_False b/test/dynamo_expected_failures/TestNestedCheckpoint.test_nested_checkpoint_reentrant_backwards_early_stop_False deleted file mode 100644 index e69de29bb2d1d6..00000000000000 diff --git a/test/dynamo_expected_failures/TestNestedCheckpoint.test_nested_checkpoint_same_graph_early_stop_False b/test/dynamo_expected_failures/TestNestedCheckpoint.test_nested_checkpoint_same_graph_early_stop_False deleted file mode 100644 index e69de29bb2d1d6..00000000000000 diff --git a/test/inductor/test_compiled_autograd.py b/test/inductor/test_compiled_autograd.py index e09928cf5576a9..a3dfcb59f2fddb 100644 --- a/test/inductor/test_compiled_autograd.py +++ b/test/inductor/test_compiled_autograd.py @@ -2223,6 +2223,7 @@ def wrap_test_class(orig_cls): "test_save_for_backward_inputs_are_namedtuple", # torch._dynamo.exc.Unsupported: 'skip function "test_setitem", # AssertionError: Tensor-likes are not close! "test_grad_nonleaf_register_hook", # IndexError: list index out of range (NB: x.grad = y where both x and y are input tensors) + "test_unpack_hooks_exec_count", # pack/unpack saved tensor hooks firing more than once "test_scalar_grad_mixed_device", # Fake Tensors aren't propagating device properly for 0-dim grads } diff --git a/test/test_autograd.py b/test/test_autograd.py index ce5b4234b8291b..c133ae95b4b3da 100644 --- a/test/test_autograd.py +++ b/test/test_autograd.py @@ -9335,6 +9335,34 @@ def backward(ctx, grad_out): out = Func.apply(a) out.backward() + def test_unpack_hooks_exec_count(self): + def f(x, y): + return x * y + + pack_count = 0 + unpack_count = 0 + + def pack_hook(x): + nonlocal pack_count + pack_count += 1 + return x + + # unpack hook shouldn't run during compilation, while we trace the forward + def unpack_hook(x): + nonlocal unpack_count + unpack_count += 1 + return x + + x = torch.ones(4, requires_grad=True) + y = torch.ones(4, requires_grad=False) + with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook): + out_test = f(x, y) + self.assertEqual(pack_count, 1) + self.assertEqual(unpack_count, 0) + out_test.sum().backward() + self.assertEqual(pack_count, 1) + self.assertEqual(unpack_count, 1) + def test_save_on_cpu_and_checkpoint(self): a = torch.randn(2, 2, requires_grad=True) diff --git a/test/test_functionalization_of_rng_ops.py b/test/test_functionalization_of_rng_ops.py index bba22ff34a0b05..1c9e8e6cecc829 100644 --- a/test/test_functionalization_of_rng_ops.py +++ b/test/test_functionalization_of_rng_ops.py @@ -298,9 +298,9 @@ def fn(x, y): torch.cuda.manual_seed(123) ref = fn(x, y) - # With checkpointing we should recompute dropout in bwd, and should see philox_rand + # With checkpointing we should recompute dropout in bwd, and philox_rand is passed from fwd fwd_compiler = functools.partial(count_philox_rand, freq=1) - bwd_compiler = functools.partial(count_philox_rand, freq=1) + bwd_compiler = functools.partial(count_philox_rand, freq=0) aot_fn = aot_function(fn, fwd_compiler, bwd_compiler) # We cant check accuracy here because rand_like generated different rand numbers than dropout res = aot_fn(x, y) diff --git a/torch/_C/_autograd.pyi b/torch/_C/_autograd.pyi index 05a791725608e8..3a5c85826007de 100644 --- a/torch/_C/_autograd.pyi +++ b/torch/_C/_autograd.pyi @@ -115,6 +115,7 @@ def _profiler_type() -> ActiveProfilerType: ... def _saved_tensors_hooks_enable() -> None: ... def _saved_tensors_hooks_disable(message: str) -> None: ... def _saved_tensors_hooks_get_disabled_error_message() -> Optional[str]: ... +def _saved_tensors_hooks_set_tracing(is_tracing: bool) -> bool: ... class CreationMeta(Enum): DEFAULT = ... diff --git a/torch/_dynamo/__init__.py b/torch/_dynamo/__init__.py index 9690f21ad27a4d..7c3680ca3c05fa 100644 --- a/torch/_dynamo/__init__.py +++ b/torch/_dynamo/__init__.py @@ -85,6 +85,7 @@ def reset() -> None: callback_handler.clear() GenerationTracker.clear() torch._dynamo.utils.warn_once_cache.clear() + torch._C._autograd._saved_tensors_hooks_set_tracing(False) def reset_code_caches() -> None: diff --git a/torch/_dynamo/utils.py b/torch/_dynamo/utils.py index 91620651179e6f..9fa70e0c98d524 100644 --- a/torch/_dynamo/utils.py +++ b/torch/_dynamo/utils.py @@ -2773,3 +2773,13 @@ def strip_color_from_string(text): # This regular expression matches ANSI escape codes ansi_escape = re.compile(r"\x1B[@-_][0-?]*[ -/]*[@-~]") return ansi_escape.sub("", text) + + +@contextlib.contextmanager +def _disable_saved_tensors_hooks_during_tracing(): + # See NOTE: [Deferring tensor pack/unpack hooks until runtime] + try: + prior = torch._C._autograd._saved_tensors_hooks_set_tracing(True) + yield + finally: + torch._C._autograd._saved_tensors_hooks_set_tracing(prior) diff --git a/torch/_dynamo/variables/builder.py b/torch/_dynamo/variables/builder.py index cbd42907f8ba80..2f7b50d36c26e1 100644 --- a/torch/_dynamo/variables/builder.py +++ b/torch/_dynamo/variables/builder.py @@ -1819,40 +1819,47 @@ def _clone_input(value): return value - # with preserve_rng_state(): - if example_value is None: - # only allow_non_graph_fake in this instance because we handle the non-fake - # cases properly below. - example_value = get_fake_value(proxy.node, tx, allow_non_graph_fake=True) - - # Handle recursive calls here - elif maybe_get_fake_mode(example_value) is tx.fake_mode: - pass - - elif isinstance(example_value, torch.Tensor): - if tx.export: - # The legacy behavior for real value cache with subclasses was - # to perform a clone WITHOUT preserving the subclass. It's - # not entirely clear this is what you actually want though. - with torch._C.DisableTorchFunctionSubclass(): - proxy.tracer.real_value_cache[proxy.node] = _clone_input(example_value) - # NB: If we're ignoring subclass, then the expectation is you will - # take the returned TensorVariable and wrap it into a more - # accurate TensorVariable that is able to track subclass-ness; - # otherwise this is wrong! - kwargs = { - "is_tensor": target_cls in (TensorVariable, TensorWithTFOverrideVariable), - } - assert "source" in options and options["source"] is not None - kwargs["source"] = options["source"] - example_value = wrap_to_fake_tensor_and_record(example_value, tx=tx, **kwargs) - if isinstance(example_value, torch.Tensor) and ( - maybe_get_fake_mode(example_value) is not tx.fake_mode - ): - raise InternalTorchDynamoError( - "`example_value` needs to be a `FakeTensor`" - f"wrapped by this instance of Dynamo. Found: {example_value}" - ) + # See NOTE: [Deferring tensor pack/unpack hooks until runtime] + with torch._dynamo.utils._disable_saved_tensors_hooks_during_tracing(): + # with preserve_rng_state(): + if example_value is None: + # only allow_non_graph_fake in this instance because we handle the non-fake + # cases properly below. + example_value = get_fake_value(proxy.node, tx, allow_non_graph_fake=True) + + # Handle recursive calls here + elif maybe_get_fake_mode(example_value) is tx.fake_mode: + pass + + elif isinstance(example_value, torch.Tensor): + if tx.export: + # The legacy behavior for real value cache with subclasses was + # to perform a clone WITHOUT preserving the subclass. It's + # not entirely clear this is what you actually want though. + with torch._C.DisableTorchFunctionSubclass(): + proxy.tracer.real_value_cache[proxy.node] = _clone_input( + example_value + ) + # NB: If we're ignoring subclass, then the expectation is you will + # take the returned TensorVariable and wrap it into a more + # accurate TensorVariable that is able to track subclass-ness; + # otherwise this is wrong! + kwargs = { + "is_tensor": target_cls + in (TensorVariable, TensorWithTFOverrideVariable), + } + assert "source" in options and options["source"] is not None + kwargs["source"] = options["source"] + example_value = wrap_to_fake_tensor_and_record( + example_value, tx=tx, **kwargs + ) + if isinstance(example_value, torch.Tensor) and ( + maybe_get_fake_mode(example_value) is not tx.fake_mode + ): + raise InternalTorchDynamoError( + "`example_value` needs to be a `FakeTensor`" + f"wrapped by this instance of Dynamo. Found: {example_value}" + ) if isinstance(example_value, torch.Tensor): is_parameter = isinstance(example_value, torch.nn.Parameter) diff --git a/torch/_functorch/aot_autograd.py b/torch/_functorch/aot_autograd.py index a94e3937f7314d..30d711e03fb58b 100644 --- a/torch/_functorch/aot_autograd.py +++ b/torch/_functorch/aot_autograd.py @@ -481,9 +481,17 @@ def create_aot_dispatcher_function( enable_python_dispatcher() if shape_env is not None else nullcontext() ) + # See NOTE: [Deferring tensor pack/unpack hooks until runtime] + # If any saved tensor hooks are active, we **don't** want to trace them. + # Instead, we'll let them run at runtime, around the custom autograd.Function + # that we generate in torch.compile. with torch.autograd.set_multithreading_enabled( False - ), preserve_rng_state(), fake_mode, python_dispatcher_mode, PhiloxStateTracker(): + ), preserve_rng_state(), ( + fake_mode + ), ( + python_dispatcher_mode + ), PhiloxStateTracker(), torch._dynamo.utils._disable_saved_tensors_hooks_during_tracing(): def process_inputs(flat_args): def convert(idx, x): diff --git a/torch/csrc/autograd/init.cpp b/torch/csrc/autograd/init.cpp index 9eb1031ff02c05..e6a907ee2f0a40 100644 --- a/torch/csrc/autograd/init.cpp +++ b/torch/csrc/autograd/init.cpp @@ -372,6 +372,9 @@ PyObject* THPAutograd_initExtension(PyObject* _unused, PyObject* unused) { at::SavedTensorDefaultHooks::is_enabled); m.def("_saved_tensors_hooks_enable", at::SavedTensorDefaultHooks::enable); m.def("_saved_tensors_hooks_disable", at::SavedTensorDefaultHooks::disable); + m.def( + "_saved_tensors_hooks_set_tracing", + at::SavedTensorDefaultHooks::set_tracing); m.def( "_saved_tensors_hooks_get_disabled_error_message", at::SavedTensorDefaultHooks::get_disabled_error_message); diff --git a/torch/csrc/autograd/python_saved_variable_hooks.cpp b/torch/csrc/autograd/python_saved_variable_hooks.cpp index ef7ae89dc34928..66b2381156aa7a 100644 --- a/torch/csrc/autograd/python_saved_variable_hooks.cpp +++ b/torch/csrc/autograd/python_saved_variable_hooks.cpp @@ -5,8 +5,7 @@ namespace py = pybind11; -namespace torch { -namespace autograd { +namespace torch::autograd { PySavedVariableHooks::PySavedVariableHooks( py::function& pack_hook, py::function& unpack_hook) @@ -65,14 +64,13 @@ void PyDefaultSavedVariableHooks::push_hooks( } void PyDefaultSavedVariableHooks::pop_hooks() { - auto [pack_hook, unpack_hook] = at::SavedTensorDefaultHooks::get_hooks(); + auto [pack_hook, unpack_hook] = at::SavedTensorDefaultHooks::pop_hooks(); TORCH_INTERNAL_ASSERT(pack_hook != nullptr && unpack_hook != nullptr); if (Py_IsInitialized()) { py::gil_scoped_acquire gil; Py_XDECREF(pack_hook); Py_XDECREF(unpack_hook); } - at::SavedTensorDefaultHooks::pop_hooks(); } std::unique_ptr PyDefaultSavedVariableHooks::get_hooks() { @@ -86,5 +84,4 @@ std::unique_ptr PyDefaultSavedVariableHooks::get_hooks() { return std::make_unique(pack_hook_, unpack_hook_); } -} // namespace autograd -} // namespace torch +} // namespace torch::autograd From 11de50f17ccecdc35acca712caadec1f2d6650c0 Mon Sep 17 00:00:00 2001 From: rzou Date: Fri, 14 Jun 2024 12:02:47 -0700 Subject: [PATCH 023/171] [Dynamo] skip some TorchScript tests (#128731) We don't care about the Dynamo x TorchScript composition, so I'm disabling these tests (so they don't get reported as flaky). Not disabling all of the TorchScript tests yet because they have been useful to catch random bugs. Test Plan: - CI Pull Request resolved: https://github.com/pytorch/pytorch/pull/128731 Approved by: https://github.com/williamwen42 --- test/test_jit.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/test/test_jit.py b/test/test_jit.py index 0e99c3602cd67c..13bdd07be6cd98 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -4967,6 +4967,7 @@ def test(backward=False): test(backward=True) test(backward=True) + @skipIfTorchDynamo("Not a TorchDynamo suitable test") def test_index(self): def consec(size, start=0): numel = torch.tensor(size).prod().item() @@ -6431,6 +6432,7 @@ def divmod_test_iterator(func, num, den): cu = torch.jit.CompilationUnit(dedent(inspect.getsource(func_float_int))) cu.func_float_int(5.3, 0) + @skipIfTorchDynamo("Not a TorchDynamo suitable test") def test_math_ops(self): def checkMathWrap(func_name, num_args=1, is_float=True, **args): if is_float: @@ -6959,6 +6961,7 @@ def test_all_float_list(x): self.assertFalse(test_all_float_list([3.14, 0, 8.9])) + @skipIfTorchDynamo("Not a TorchDynamo suitable test") def test_number_math(self): ops_template = dedent(''' def func(): @@ -7207,6 +7210,7 @@ def test(op, tensor, const, swap_args, template=template): test(op, tensor, const, swap_args) + @skipIfTorchDynamo("Not a TorchDynamo suitable test") def test_tensor_number_math(self): self._test_tensor_number_math() @@ -7643,6 +7647,7 @@ def foo(x: Any): with self.assertRaises(Exception): foo(2) + @skipIfTorchDynamo("Not a TorchDynamo suitable test") def test_isinstance(self): # test isinstance operator for static type checking template = dedent(''' From e6e102cf85d7b8c621d88e2835a5094014eb32e8 Mon Sep 17 00:00:00 2001 From: rzou Date: Fri, 14 Jun 2024 12:30:05 -0700 Subject: [PATCH 024/171] Dynamo testing: add some skips (#128734) The following tests are failing consistently for me locally, so we're going to skip them. They're disabled in CI but it looks like they're just always failing. Pull Request resolved: https://github.com/pytorch/pytorch/pull/128734 Approved by: https://github.com/williamwen42 ghstack dependencies: #128731 --- .../TestTorchTidyProfiler.test_optimizer_parameters_sgd | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 test/dynamo_skips/TestTorchTidyProfiler.test_optimizer_parameters_sgd diff --git a/test/dynamo_skips/TestTorchTidyProfiler.test_optimizer_parameters_sgd b/test/dynamo_skips/TestTorchTidyProfiler.test_optimizer_parameters_sgd new file mode 100644 index 00000000000000..e69de29bb2d1d6 From e9a29aaa4a39403db46219d199fee0163afc2114 Mon Sep 17 00:00:00 2001 From: titaiwangms Date: Fri, 14 Jun 2024 21:20:42 +0000 Subject: [PATCH 025/171] [ONNX] Add upsample trilinear to skip decomp (#128259) (1) Add upsample trilinear vec to skip decomposition (2) Add tests to make sure that torch.export.export still decomposes them Pull Request resolved: https://github.com/pytorch/pytorch/pull/128259 Approved by: https://github.com/justinchuby --- .ci/docker/common/install_onnx.sh | 4 +- test/onnx/test_fx_op_consistency.py | 6 +-- test/onnx/test_fx_to_onnx.py | 10 ++--- test/onnx/test_fx_to_onnx_decomp_skip.py | 43 +++++++++++++++++-- torch/onnx/_internal/fx/decomposition_skip.py | 39 +++++++++++++++++ 5 files changed, 87 insertions(+), 15 deletions(-) diff --git a/.ci/docker/common/install_onnx.sh b/.ci/docker/common/install_onnx.sh index a91c798fcdf28d..1d384233163d8d 100755 --- a/.ci/docker/common/install_onnx.sh +++ b/.ci/docker/common/install_onnx.sh @@ -33,7 +33,9 @@ pip_install coloredlogs packaging pip_install onnxruntime==1.18 pip_install onnx==1.16.0 # pip_install "onnxscript@git+https://github.com/microsoft/onnxscript@3e869ef8ccf19b5ebd21c10d3e9c267c9a9fa729" --no-deps -pip_install onnxscript==0.1.0.dev20240523 --no-deps +pip_install onnxscript==0.1.0.dev20240613 --no-deps +# required by onnxscript +pip_install ml_dtypes # Cache the transformers model to be used later by ONNX tests. We need to run the transformers # package to download the model. By default, the model is cached at ~/.cache/huggingface/hub/ diff --git a/test/onnx/test_fx_op_consistency.py b/test/onnx/test_fx_op_consistency.py index 6d675d4460309b..83a4e3965f8137 100644 --- a/test/onnx/test_fx_op_consistency.py +++ b/test/onnx/test_fx_op_consistency.py @@ -527,8 +527,7 @@ def skip_torchlib_forward_compatibility( ), xfail( "gather", - reason="HandleNegativeAxis(int64_t, int64_t) IsAxisInRange(axis, tensor_rank) was \ - false. axis 0 is not in valid range [-0,-1]" + reason="GatherElements op: Rank of input 'data' needs to be equal to rank of input 'indices'" ), xfail( "geometric", @@ -1517,7 +1516,6 @@ def skip_torchlib_forward_compatibility( "nn.functional.batch_norm", matcher=lambda sample: sample.kwargs.get("training") is True and any(arg is not None for arg in sample.args[2:4]), - model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM, reason="Flaky failure: https://github.com/pytorch/pytorch/issues/115106", ), xfail( @@ -1998,7 +1996,7 @@ class TestOnnxModelOutputConsistency(onnx_test_common._TestONNXRuntime): "nn.functional.hardsigmoid": [1e-3, 5e-3], "nn.functional.hardswish": [1e-3, 5e-3], "nn.functional.hinge_embedding_loss": [4e-1, 3e-3], - "nn.functional.huber_loss": [1e-3, 1e-2], + "nn.functional.huber_loss": [1e-2, 1e-1], "nn.functional.instance_norm": [1e-2, 1e-3], "nn.functional.interpolate": [1e-2, 1e-3], "nn.functional.kl_div": [2e-3, 2e-4], diff --git a/test/onnx/test_fx_to_onnx.py b/test/onnx/test_fx_to_onnx.py index 61cb9e807f7076..0f8741c8e0c91e 100644 --- a/test/onnx/test_fx_to_onnx.py +++ b/test/onnx/test_fx_to_onnx.py @@ -171,13 +171,9 @@ def forward(self, input): torch.argmax(input, dim=1, keepdim=True), ) - # NOTE: KeyError: dim raised in optimizer - with self.assertWarnsOnceRegex( - UserWarning, "ONNXScript optimizer failed. Skipping optimization." - ): - _ = dynamo_export( - ArgminArgmaxModel(), model_input, export_options=self.export_options - ) + _ = dynamo_export( + ArgminArgmaxModel(), model_input, export_options=self.export_options + ) def test_multiple_outputs_op_with_evaluator(self): class TopKModel(torch.nn.Module): diff --git a/test/onnx/test_fx_to_onnx_decomp_skip.py b/test/onnx/test_fx_to_onnx_decomp_skip.py index 16780e604d0717..2f029418f1e98a 100644 --- a/test/onnx/test_fx_to_onnx_decomp_skip.py +++ b/test/onnx/test_fx_to_onnx_decomp_skip.py @@ -18,6 +18,12 @@ def assert_op_in_onnx_model(model: onnx.ModelProto, op_type: str): class TestDynamoExportDecompSkip(pytorch_test_common.ExportTestCase): + def _test_exported_program_forces_decomposition(self, model, input, op_type): + ep = torch.export.export(model, input) + onnx_program = torch.onnx.dynamo_export(ep, *input) + with self.assertRaises(AssertionError): + assert_op_in_onnx_model(onnx_program.model_proto, op_type) + def test_upsample_bilinear2d(self): class TestModel(torch.nn.Module): def __init__(self): @@ -30,6 +36,9 @@ def forward(self, x): onnx_program = torch.onnx.dynamo_export(TestModel(), torch.randn(1, 1, 2, 2)) # If decomposition is skipped, the model will contain a Resize op instead of fine grained subgraph. assert_op_in_onnx_model(onnx_program.model_proto, "Resize") + self._test_exported_program_forces_decomposition( + TestModel(), (torch.randn(1, 1, 2, 2),), "Resize" + ) def test_upsample_bilinear2d_output_size(self): def func(x: torch.Tensor): @@ -39,14 +48,42 @@ def func(x: torch.Tensor): # If decomposition is skipped, the model will contain a Resize op instead of fine grained subgraph. assert_op_in_onnx_model(onnx_program.model_proto, "Resize") - def test_instance_norm(self): + def test_upsample_trilinear3d(self): + class TestModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.upsample = torch.nn.Upsample(scale_factor=2, mode="trilinear") + + def forward(self, x): + return self.upsample(x) + + onnx_program = torch.onnx.dynamo_export(TestModel(), torch.randn(1, 1, 2, 2, 3)) + # If decomposition is skipped, the model will contain a Resize op instead of fine grained subgraph. + assert_op_in_onnx_model(onnx_program.model_proto, "Resize") + self._test_exported_program_forces_decomposition( + TestModel(), (torch.randn(1, 1, 2, 2, 3),), "Resize" + ) + + def test_upsample_trilinear3d_output_size(self): def func(x: torch.Tensor): - return torch.nn.functional.instance_norm(x) + return torch.nn.functional.interpolate(x, size=(4, 4, 4), mode="trilinear") - onnx_program = torch.onnx.dynamo_export(func, torch.randn(1, 1, 2, 2)) + onnx_program = torch.onnx.dynamo_export(func, torch.randn(1, 1, 2, 2, 3)) + # If decomposition is skipped, the model will contain a Resize op instead of fine grained subgraph. + assert_op_in_onnx_model(onnx_program.model_proto, "Resize") + + def test_instance_norm(self): + class TestModel(torch.nn.Module): + def forward(self, x): + return torch.nn.functional.instance_norm(x) + + onnx_program = torch.onnx.dynamo_export(TestModel(), torch.randn(1, 1, 2, 2)) # If decomposition is skipped, the model will contain an InstanceNormalization op # instead of BatchNormalization op w/ training=True. assert_op_in_onnx_model(onnx_program.model_proto, "InstanceNormalization") + self._test_exported_program_forces_decomposition( + TestModel(), (torch.randn(1, 1, 2, 2),), "InstanceNormalization" + ) if __name__ == "__main__": diff --git a/torch/onnx/_internal/fx/decomposition_skip.py b/torch/onnx/_internal/fx/decomposition_skip.py index 646e0765f1907c..cadceecfd146a9 100644 --- a/torch/onnx/_internal/fx/decomposition_skip.py +++ b/torch/onnx/_internal/fx/decomposition_skip.py @@ -121,6 +121,44 @@ def abstract(cls, input, output_size, align_corners, scale_factors): ) +class UpsampleTrilinear3DDecompSkip(DecompSkip): + op_callable = torch._C._nn.upsample_trilinear3d # type: ignore[attr-defined] + onnxscript_function = torchlib_nn.aten_upsample_trilinear3d_vec # type: ignore[attr-defined] + new_op_name = "upsample_trilinear3d" + new_op_schema = "(Tensor self, SymInt[]? output_size, bool align_corners, float[]? scale_factors) -> (Tensor)" + + @classmethod + def register(cls, export_options: torch.onnx.ExportOptions): + if not hasattr(torch.ops, _NEW_OP_NAMESPACE) or not hasattr( + torch.ops.onnx_export, cls.new_op_name + ): + cls.register_custom_op() + torch._C._nn.upsample_trilinear3d = torch.ops.onnx_export.upsample_trilinear3d # type: ignore[attr-defined] + if export_options.onnx_registry is None: + export_options.onnx_registry = torch.onnx.OnnxRegistry() + registry = export_options.onnx_registry + registry.register_op( + function=cls.onnxscript_function, + namespace=_NEW_OP_NAMESPACE, + op_name=cls.new_op_name, + ) + + @classmethod + def unregister(cls): + torch._C._nn.upsample_trilinear3d = cls.op_callable # type: ignore[attr-defined] + + @classmethod + def abstract(cls, input, output_size, align_corners, scale_factors): + osize = decompositions.upsample_compute_output_size( + input.size(), output_size, scale_factors + ) + return torch.empty( + (input.size(0), input.size(1), input.size(2), *osize), + dtype=input.dtype, + device=input.device, + ) + + class InstanceNormDecompSkip(DecompSkip): op_callable = torch.instance_norm # type: ignore[attr-defined] onnxscript_function = torchlib_core.aten_instance_norm # type: ignore[attr-defined] @@ -176,6 +214,7 @@ def abstract( _DEFAULT_SKIP_LIST = [ UpsampleBilinear2DDecompSkip, InstanceNormDecompSkip, + UpsampleTrilinear3DDecompSkip, ] From 65d3ddcb8bd08d3099cd0a976141e29cfefb0240 Mon Sep 17 00:00:00 2001 From: ibartol Date: Fri, 14 Jun 2024 21:24:53 +0000 Subject: [PATCH 026/171] Add GLIBC requirements for libtorch to solve #113124 (#128135) Fixes #113124. ## Description I modified the installing.rst file to address the system requirements and troubleshooting steps for using LibTorch with different GLIBC versions. ### Summary of Changes - Added system requirements specifying the GLIBC version needed for both the cxx11 ABI version and the pre-cxx11 ABI version of LibTorch. - Included a troubleshooting section with instructions on how to check the dependencies of the LibTorch libraries and identify the required GLIBC version using the `ldd lib/libtorch.so` command. ## Checklist - [X] The issue that is being fixed is referred in the description - [X] Only one issue is addressed in this pull request - [X] Labels from the issue that this PR is fixing are added to this pull request - [X] No unnecesary issues are included into this pull request Pull Request resolved: https://github.com/pytorch/pytorch/pull/128135 Approved by: https://github.com/jbschlosser --- docs/cpp/source/installing.rst | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/docs/cpp/source/installing.rst b/docs/cpp/source/installing.rst index ea140df7b00fea..ed80b620f93e1f 100644 --- a/docs/cpp/source/installing.rst +++ b/docs/cpp/source/installing.rst @@ -154,6 +154,19 @@ should now merrily print the tensor (exact output subject to randomness): Also, make sure you specify the correct configuration in the ``cmake --build .`` line above. +System Requirements +------------------- + +To ensure smooth installation and usage of LibTorch, please ensure your system +meets the following requirements: + +1. **GLIBC Version**: + - GLIBC 2.29 or newer for cxx11 ABI version + - GLIBC 2.17 or newer for pre-cxx11 ABI version + +2. **GCC Version**: + - GCC 9 or newer for cxx11 and pre-cxx11 ABI versions + Visual Studio Extension ----------------------- From e9c6e8369c588580b0c49d083d1d2bfde8dcd418 Mon Sep 17 00:00:00 2001 From: angelayi Date: Fri, 14 Jun 2024 21:28:17 +0000 Subject: [PATCH 027/171] Torchbind call method + effects support (#128397) Adds effect token support to torchbind method calls by allowing `with_effects` to take in `torch.ops._higher_order_ops.call_torchbind` as an input. Here is the print from `TORCH_LOGS="aot" python test/export/test_torchbind.py -k test_compile_obj_torchbind_op`: ```python def forward(self, arg0_1: "f32[0]", arg1_1: "f32[2]", arg2_1): # File: /data/users/angelayi/pytorch2/test/export/test_torchbind.py:1266 in f, code: torch.ops._TorchScriptTesting.queue_push(tq, x.cos()) cos: "f32[2]" = torch.ops.aten.cos.default(arg1_1) with_effects = torch._higher_order_ops.effects.with_effects(arg0_1, torch.ops._TorchScriptTesting.queue_push.default, arg2_1, cos); arg0_1 = cos = None getitem: "f32[0]" = with_effects[0]; with_effects = None # File: /data/users/angelayi/pytorch2/test/export/test_torchbind.py:1267 in f, code: torch.ops._TorchScriptTesting.queue_push(tq, x.cos() + 1) cos_1: "f32[2]" = torch.ops.aten.cos.default(arg1_1) add: "f32[2]" = torch.ops.aten.add.Tensor(cos_1, 1); cos_1 = None with_effects_1 = torch._higher_order_ops.effects.with_effects(getitem, torch.ops._TorchScriptTesting.queue_push.default, arg2_1, add); getitem = add = None getitem_2: "f32[0]" = with_effects_1[0]; with_effects_1 = None # File: /data/users/angelayi/pytorch2/test/export/test_torchbind.py:1268 in f, code: torch.ops._TorchScriptTesting.queue_pop(tq) with_effects_2 = torch._higher_order_ops.effects.with_effects(getitem_2, torch.ops._TorchScriptTesting.queue_pop.default, arg2_1); getitem_2 = None getitem_4: "f32[0]" = with_effects_2[0]; with_effects_2 = None # File: /data/users/angelayi/pytorch2/test/export/test_torchbind.py:1269 in f, code: torch.ops._TorchScriptTesting.queue_push(tq, x.sin()) sin: "f32[2]" = torch.ops.aten.sin.default(arg1_1); arg1_1 = None with_effects_3 = torch._higher_order_ops.effects.with_effects(getitem_4, torch.ops._TorchScriptTesting.queue_push.default, arg2_1, sin); getitem_4 = sin = None getitem_6: "f32[0]" = with_effects_3[0]; with_effects_3 = None # File: /data/users/angelayi/pytorch2/test/export/test_torchbind.py:1270 in f, code: return tq.pop(), tq.pop() + tq.size(), tq with_effects_4 = torch._higher_order_ops.effects.with_effects(getitem_6, torch.ops._higher_order_ops.call_torchbind, arg2_1, 'pop'); getitem_6 = None getitem_8: "f32[0]" = with_effects_4[0] getitem_9: "f32[2]" = with_effects_4[1]; with_effects_4 = None with_effects_5 = torch._higher_order_ops.effects.with_effects(getitem_8, torch.ops._higher_order_ops.call_torchbind, arg2_1, 'pop'); getitem_8 = None getitem_10: "f32[0]" = with_effects_5[0] getitem_11: "f32[2]" = with_effects_5[1]; with_effects_5 = None with_effects_6 = torch._higher_order_ops.effects.with_effects(getitem_10, torch.ops._higher_order_ops.call_torchbind, arg2_1, 'size'); getitem_10 = arg2_1 = None getitem_12: "f32[0]" = with_effects_6[0]; with_effects_6 = None add_1: "f32[2]" = torch.ops.aten.add.Tensor(getitem_11, 0); getitem_11 = None return (getitem_12, getitem_9, add_1) ``` In order to support this, this PR makes the following changes: * Adds `FakeScriptObject` to `CustomObjArgument`, which will be put on the `meta["val"]` of nodes representing torchbind objects. * Adds pickle/deepcopy support to FunctionSchema. Pull Request resolved: https://github.com/pytorch/pytorch/pull/128397 Approved by: https://github.com/ydwu4, https://github.com/zou3519 --- test/export/test_torchbind.py | 30 ++++++++------ torch/_export/passes/lift_constants_pass.py | 16 ++++---- torch/_export/serde/serialize.py | 2 +- torch/_higher_order_ops/effects.py | 45 +++++++++++++++------ torch/_higher_order_ops/torchbind.py | 8 ++-- torch/_ops.py | 3 ++ torch/csrc/jit/python/init.cpp | 9 +++++ torch/export/_remove_effect_tokens_pass.py | 36 ++++++++++++++--- torch/export/_trace.py | 4 +- torch/export/graph_signature.py | 3 ++ torch/export/unflatten.py | 4 ++ torch/fx/graph.py | 2 +- torch/fx/proxy.py | 2 +- 13 files changed, 120 insertions(+), 44 deletions(-) diff --git a/test/export/test_torchbind.py b/test/export/test_torchbind.py index 42c87bf4c10ef0..c313a733f4750f 100644 --- a/test/export/test_torchbind.py +++ b/test/export/test_torchbind.py @@ -182,10 +182,12 @@ def forward(self, x, n): self.assertExpectedInline( ep.graph_module.code.strip(), """\ -def forward(self, obj_attr, x, n): - call_torchbind = torch.ops.higher_order.call_torchbind(obj_attr, 'add_tensor', x); obj_attr = None - add = torch.ops.aten.add.Tensor(x, call_torchbind); x = call_torchbind = None - return (add,)""", +def forward(self, token, obj_attr, x, n): + with_effects = torch._higher_order_ops.effects.with_effects(token, torch.ops.higher_order.call_torchbind, obj_attr, 'add_tensor', x); token = obj_attr = None + getitem = with_effects[0] + getitem_1 = with_effects[1]; with_effects = None + add = torch.ops.aten.add.Tensor(x, getitem_1); x = getitem_1 = None + return (getitem, add)""", # noqa: B950 ) def test_method_schema(self): @@ -227,10 +229,12 @@ def forward(self, x): self.assertExpectedInline( ep.graph_module.code.strip(), """\ -def forward(self, obj_attr, x): - call_torchbind = torch.ops.higher_order.call_torchbind(obj_attr, 'add_tensor', x); obj_attr = None - add = torch.ops.aten.add.Tensor(x, call_torchbind); x = call_torchbind = None - return (add,)""", +def forward(self, token, obj_attr, x): + with_effects = torch._higher_order_ops.effects.with_effects(token, torch.ops.higher_order.call_torchbind, obj_attr, 'add_tensor', x); token = obj_attr = None + getitem = with_effects[0] + getitem_1 = with_effects[1]; with_effects = None + add = torch.ops.aten.add.Tensor(x, getitem_1); x = getitem_1 = None + return (getitem, add)""", # noqa: B950 ) @parametrize("pre_dispatch", [True, False]) @@ -293,10 +297,12 @@ def forward(self, x, cc): self.assertExpectedInline( ep.graph_module.code.strip(), """\ -def forward(self, x, cc): - call_torchbind = torch.ops.higher_order.call_torchbind(cc, 'add_tensor', x); cc = None - add = torch.ops.aten.add.Tensor(x, call_torchbind); x = call_torchbind = None - return (add,)""", +def forward(self, token, x, cc): + with_effects = torch._higher_order_ops.effects.with_effects(token, torch.ops.higher_order.call_torchbind, cc, 'add_tensor', x); token = cc = None + getitem = with_effects[0] + getitem_1 = with_effects[1]; with_effects = None + add = torch.ops.aten.add.Tensor(x, getitem_1); x = getitem_1 = None + return (getitem, add)""", # noqa: B950 ) # aot_export_function runs the program twice # in run_functionalized_fw_and_collect_metadata and create_aot_dispatcher_function diff --git a/torch/_export/passes/lift_constants_pass.py b/torch/_export/passes/lift_constants_pass.py index d9cd62ffc928fc..926a68fe9dc034 100644 --- a/torch/_export/passes/lift_constants_pass.py +++ b/torch/_export/passes/lift_constants_pass.py @@ -234,10 +234,12 @@ def lift_constants_pass( elif isinstance(constant_val, FakeScriptObject): class_fqn = constant_val.script_class_name const_placeholder_node.meta["val"] = CustomObjArgument( - constant_fqn, class_fqn + constant_fqn, class_fqn, constant_val ) input_spec_arg = CustomObjArgument( - name=const_placeholder_node.name, class_fqn=class_fqn + name=const_placeholder_node.name, + class_fqn=class_fqn, + fake_val=constant_val, ) else: raise SpecViolationError( @@ -287,17 +289,17 @@ def rewrite_script_object_meta( if "val" not in node.meta: continue - if isinstance(node.meta["val"], torch.ScriptObject): - old_meta = node.meta["val"] + old_meta = node.meta["val"] + + if isinstance(old_meta, torch.ScriptObject): class_fqn = old_meta._type().qualified_name() # type: ignore[attr-defined] new_meta = CustomObjArgument(node.name, class_fqn) constants[node.name] = old_meta node.meta["val"] = new_meta - elif isinstance(node.meta["val"], FakeScriptObject): - old_meta = node.meta["val"] # type: ignore[assignment] + elif isinstance(old_meta, FakeScriptObject): class_fqn = old_meta.script_class_name # type: ignore[attr-defined] - new_meta = CustomObjArgument(node.name, class_fqn) + new_meta = CustomObjArgument(node.name, class_fqn, old_meta) constants[node.name] = old_meta node.meta["val"] = new_meta diff --git a/torch/_export/serde/serialize.py b/torch/_export/serde/serialize.py index 206c8b88cce42c..f4edbb59156f17 100644 --- a/torch/_export/serde/serialize.py +++ b/torch/_export/serde/serialize.py @@ -899,7 +899,7 @@ def serialize_optional_tensor_args(a): return Argument.create( as_custom_obj=CustomObjArgument(custom_obj_name, class_fqn) ) - elif isinstance(arg, torch._ops.OpOverload): + elif isinstance(arg, (torch._ops.OpOverload, torch._ops.HigherOrderOperator)): return Argument.create(as_operator=self.serialize_operator(arg)) else: raise SerializeError(f"Unsupported argument type: {type(arg)}") diff --git a/torch/_higher_order_ops/effects.py b/torch/_higher_order_ops/effects.py index a8da01fe06ecd4..c5a448836b07cd 100644 --- a/torch/_higher_order_ops/effects.py +++ b/torch/_higher_order_ops/effects.py @@ -1,6 +1,6 @@ # mypy: allow-untyped-defs from enum import Enum -from typing import Any, Dict, Optional, Tuple +from typing import Any, Dict, Optional, Tuple, Union import torch import torch.utils._pytree as pytree @@ -12,19 +12,26 @@ ProxyTorchDispatchMode, track_tensor_tree, ) +from .torchbind import call_torchbind class _EffectType(Enum): ORDERED = "Ordered" -SIDE_EFFECTS: Dict[torch._ops.OpOverload, _EffectType] = { +OpType = Union[torch._ops.HigherOrderOperator, torch._ops.OpOverload] + + +SIDE_EFFECTS: Dict[OpType, _EffectType] = { torch.ops.aten._print.default: _EffectType.ORDERED, + call_torchbind: _EffectType.ORDERED, } -def _register_effectful_op(op: torch._ops.OpOverload, effect: _EffectType): - assert isinstance(op, torch._ops.OpOverload) and not has_aliasing(op) +def _register_effectful_op(op: OpType, effect: _EffectType): + assert isinstance( + op, (torch._ops.OpOverload, torch._ops.HigherOrderOperator) + ) and not has_aliasing(op) if op in SIDE_EFFECTS and SIDE_EFFECTS[op] != effect: raise RuntimeError( f"Already registered effect type {SIDE_EFFECTS[op]} to op {op}, " @@ -53,11 +60,11 @@ def __init__(self): def __call__( self, token, - op: torch._ops.OpOverload, + op: OpType, *args: Tuple[Any, ...], **kwargs: Dict[str, Any], ) -> Tuple[Any, ...]: - assert isinstance(op, torch._ops.OpOverload) + assert isinstance(op, (torch._ops.HigherOrderOperator, torch._ops.OpOverload)) assert not has_aliasing(op), "Ops with aliasing is not supported" assert has_effects(op, args, kwargs) assert isinstance(kwargs, dict) @@ -67,7 +74,11 @@ def __call__( with_effects = WithEffects() -def has_aliasing(op: torch._ops.OpOverload): +def has_aliasing(op: OpType): + # NOT FOR PUBLIC USE + if isinstance(op, torch._ops.HigherOrderOperator): + return op not in SIDE_EFFECTS + for arg in op._schema.arguments: if arg.alias_info is not None: return True @@ -84,7 +95,7 @@ def has_effects(op, args, kwargs) -> bool: return False return ( - isinstance(op, torch._ops.OpOverload) + isinstance(op, (torch._ops.HigherOrderOperator, torch._ops.OpOverload)) and not has_aliasing(op) and get_effect_key(op, args, kwargs) is not None ) @@ -163,10 +174,19 @@ def with_effects_proxy( with_effects.fallthrough(DispatchKey.AutogradCUDA) +def _get_schema(op, args) -> torch.FunctionSchema: + if isinstance(op, torch._ops.OpOverload): + return op._schema + elif op == call_torchbind: + return getattr(args[0], args[1]).schema + else: + raise RuntimeError(f"Unable to get schema for op {op}") + + def handle_effects( allow_token_discovery: bool, tokens: Dict[_EffectType, torch.Tensor], - op: torch._ops.OpOverload, + op: OpType, args: Tuple[Any, ...], kwargs: Dict[str, Any], ) -> Any: @@ -207,14 +227,15 @@ def handle_effects( unwrapped_token, op, *unwrapped_args, **unwrapped_kwargs # type: ignore[arg-type] ) - if len(op._schema.returns) == 0: + schema = _get_schema(op, unwrapped_args) + if len(schema.returns) == 0: assert unwrapped_outs[0] is None unwrapped_outs = None # type: ignore[assignment] - elif len(op._schema.returns) == 1: + elif len(schema.returns) == 1: assert len(unwrapped_outs) == 1 unwrapped_outs = unwrapped_outs[0] else: - assert len(unwrapped_outs) == len(op._schema.returns) + assert len(unwrapped_outs) == len(schema.returns) # Add the newly created token into the tokens map for a following call to # use this token. diff --git a/torch/_higher_order_ops/torchbind.py b/torch/_higher_order_ops/torchbind.py index 744e559e65d071..e44ba2851ff55e 100644 --- a/torch/_higher_order_ops/torchbind.py +++ b/torch/_higher_order_ops/torchbind.py @@ -114,6 +114,8 @@ def call_torchbind_fake(mode, *args, **kwargs): @call_torchbind.py_functionalize_impl def call_torchbind_func(ctx, *args, **kwargs): - args = ctx.unwrap_tensors(args) - with ctx.redispatch_to_next(): - return ctx.wrap_tensors(call_torchbind(*args, **kwargs)) + from torch._higher_order_ops.effects import handle_effects + + return handle_effects( + ctx.mode._allow_token_discovery, ctx.mode._tokens, call_torchbind, args, kwargs + ) diff --git a/torch/_ops.py b/torch/_ops.py index ed8c788b8af6f6..bbf911137e661c 100644 --- a/torch/_ops.py +++ b/torch/_ops.py @@ -382,6 +382,9 @@ def wrapper(): def __str__(self): return f"{self.name()}" + # def __repr__(self): + # return f"torch.ops._higher_order_ops.{self._name}" + def name(self): return self._name diff --git a/torch/csrc/jit/python/init.cpp b/torch/csrc/jit/python/init.cpp index 818f09bee7bc26..1bfc6c94a707f4 100644 --- a/torch/csrc/jit/python/init.cpp +++ b/torch/csrc/jit/python/init.cpp @@ -1940,6 +1940,15 @@ void initJITBindings(PyObject* module) { ss << self; return ss.str(); }) + .def(py::pickle( + [](const FunctionSchema& self) { // __getstate__ + std::stringstream ss; + ss << self; + return py::str(ss.str()); + }, + [](py::str schema) { // __setstate__, note: no `self` argument + return parseSchema(schema); + })) .def_property_readonly( "is_mutable", [](FunctionSchema& self) { return self.is_mutable(); }); py::class_(m, "Argument") diff --git a/torch/export/_remove_effect_tokens_pass.py b/torch/export/_remove_effect_tokens_pass.py index 20411dc87cce06..b982cbc6bf7be4 100644 --- a/torch/export/_remove_effect_tokens_pass.py +++ b/torch/export/_remove_effect_tokens_pass.py @@ -3,14 +3,23 @@ from typing import List import torch -from torch._higher_order_ops.effects import with_effects +from torch._higher_order_ops.effects import _get_schema, with_effects from .exported_program import ExportedProgram -from .graph_signature import InputKind, InputSpec, OutputKind, OutputSpec, TokenArgument +from .graph_signature import ( + CustomObjArgument, + InputKind, + InputSpec, + OutputKind, + OutputSpec, + TokenArgument, +) def _remove_effect_tokens_from_graph_helper( ep, num_tokens, input_token_names, output_token_names ): + inputs_to_lifted_custom_objs = ep.graph_signature.inputs_to_lifted_custom_objs + output_node = None with_effect_nodes: List[torch.fx.Node] = [] for module in ep.graph_module.modules(): @@ -40,7 +49,22 @@ def _remove_effect_tokens_from_graph_helper( # Replace with_effects(token, func, args) with just func(args) for node in reversed(with_effect_nodes): func = node.args[1] - assert isinstance(func, torch._ops.OpOverload) + assert isinstance(func, (torch._ops.OpOverload, torch._ops.HigherOrderOperator)) + + if func == torch.ops.higher_order.call_torchbind: + custom_obj_meta = node.args[2].meta["val"] + assert isinstance(custom_obj_meta, CustomObjArgument) + if custom_obj_meta.fake_val: + custom_obj = custom_obj_meta.fake_val + elif node.args[2].name in inputs_to_lifted_custom_objs: + custom_obj = ep.constants[ + inputs_to_lifted_custom_objs[node.args[2].name] + ] + else: + raise RuntimeError(f"Unable to find custom obj for node {node}") + schema = _get_schema(func, (custom_obj,) + node.args[3:]) + else: + schema = _get_schema(func, node.args[2:]) with ep.graph.inserting_before(node): new_node = ep.graph.call_function(func, node.args[2:]) @@ -56,7 +80,7 @@ def _remove_effect_tokens_from_graph_helper( if user.args[1] == 0: ep.graph.erase_node(user) - if len(func._schema.returns) == 1: + if len(schema.returns) == 1: # If the function has 1 return then it will just directly return the # result -- we don't need a getitem. So we can replace all the # getitem(with_effects, 1) with just the note itself. @@ -65,7 +89,7 @@ def _remove_effect_tokens_from_graph_helper( user.replace_all_uses_with(new_node) new_node.meta["val"] = node.meta["val"][1] - elif len(func._schema.returns) > 1: + elif len(schema.returns) > 1: # If the function has more than 1 return then since we got rid of # the 1st return value (the token), we need to bump all the other # getitem calls by 1 down @@ -75,7 +99,7 @@ def _remove_effect_tokens_from_graph_helper( new_node.meta["val"] = node.meta["val"][1:] else: - assert len(func._schema.returns) == 0 + assert len(schema.returns) == 0 assert len(new_node.users) == 0 new_node.meta["val"] = None diff --git a/torch/export/_trace.py b/torch/export/_trace.py index 1542fb1423174b..d5f0851c87694e 100644 --- a/torch/export/_trace.py +++ b/torch/export/_trace.py @@ -617,7 +617,9 @@ def make_argument_spec(i, node) -> ArgumentSpec: elif isinstance(val, torch.ScriptObject): return CustomObjArgument(name=node.name, class_fqn=val._type().qualified_name()) # type: ignore[attr-defined] elif isinstance(val, FakeScriptObject): - return CustomObjArgument(name=node.name, class_fqn=val.script_class_name) + return CustomObjArgument( + name=node.name, class_fqn=val.script_class_name, fake_val=val + ) elif isinstance(val, (int, bool, str, float, type(None))): return ConstantArgument(name=node.name, value=val) else: diff --git a/torch/export/graph_signature.py b/torch/export/graph_signature.py index ce62e879394151..acc9d705c6c1c5 100644 --- a/torch/export/graph_signature.py +++ b/torch/export/graph_signature.py @@ -3,6 +3,8 @@ from enum import auto, Enum from typing import Collection, Dict, List, Mapping, Optional, Set, Tuple, Union +from torch._library.fake_class_registry import FakeScriptObject + __all__ = [ "ConstantArgument", @@ -37,6 +39,7 @@ class SymIntArgument: class CustomObjArgument: name: str class_fqn: str + fake_val: Optional[FakeScriptObject] = None @dataclasses.dataclass diff --git a/torch/export/unflatten.py b/torch/export/unflatten.py index 11075058a0e941..7b6ed6f1b5a974 100644 --- a/torch/export/unflatten.py +++ b/torch/export/unflatten.py @@ -23,6 +23,9 @@ from torch.fx._symbolic_trace import is_fx_tracing from torch.utils._pytree import GetAttrKey, SequenceKey +from ._remove_effect_tokens_pass import _remove_effect_tokens + + __all__ = ["InterpreterModule", "UnflattenedModule", "unflatten", "FlatArgsAdapter"] @@ -485,6 +488,7 @@ def unflatten( An instance of :class:`UnflattenedModule`, which has the same module hierarchy as the original eager module pre-export. """ + module = _remove_effect_tokens(module) return UnflattenedModule(module, flat_args_adapter) diff --git a/torch/fx/graph.py b/torch/fx/graph.py index 9c4840d2b2a91e..83aeb19b8d1112 100644 --- a/torch/fx/graph.py +++ b/torch/fx/graph.py @@ -476,7 +476,7 @@ def _get_repr(arg: Any) -> str: qualified_name = _get_qualified_name(type(arg)) global_name = add_global(qualified_name, type(arg)) return f"{global_name}{repr(tuple(arg))}" - elif isinstance(arg, torch._ops.OpOverload): + elif isinstance(arg, (torch._ops.OpOverload, torch._ops.HigherOrderOperator)): qualified_name = _get_qualified_name(arg) global_name = add_global(qualified_name, arg) return f"{global_name}" diff --git a/torch/fx/proxy.py b/torch/fx/proxy.py index f732b21080ddb4..649ba4cb648756 100644 --- a/torch/fx/proxy.py +++ b/torch/fx/proxy.py @@ -282,7 +282,7 @@ def no_node(arg): elif isinstance(a, range): return range(self.create_arg(a.start), self.create_arg(a.stop), self.create_arg(a.step)) - elif isinstance(a, torch._ops.OpOverload): + elif isinstance(a, (torch._ops.OpOverload, torch._ops.HigherOrderOperator)): return a if isinstance(a, Proxy): From f103247a14233e0dfdd311536830e63e025093f5 Mon Sep 17 00:00:00 2001 From: Isuru Fernando Date: Fri, 14 Jun 2024 17:41:18 +0000 Subject: [PATCH 028/171] Run all samples for torchinductor tests (#128343) Pull Request resolved: https://github.com/pytorch/pytorch/pull/128343 Approved by: https://github.com/lezcano --- test/inductor/test_torchinductor_opinfo.py | 179 ++++++++++++++++----- 1 file changed, 135 insertions(+), 44 deletions(-) diff --git a/test/inductor/test_torchinductor_opinfo.py b/test/inductor/test_torchinductor_opinfo.py index 1d9c733a73029f..7998a3aff58d6a 100644 --- a/test/inductor/test_torchinductor_opinfo.py +++ b/test/inductor/test_torchinductor_opinfo.py @@ -406,44 +406,132 @@ def wrapper_noop_set_seed(op, *args, **kwargs): } -# Always test with all sample for following ops -inductor_all_samples = { - "arange", - "diagonal", - "diagonal_copy", - "diagonal_scatter", - "softmax.with_dtype", - "index_add", - "index_copy", - "index_reduce.prod", - "index_reduce.mean", - "index_reduce.amax", - "index_reduce.amin", - "scatter_reduce.sum", - "select_scatter", - "squeeze", - "unfold", - "unsqueeze", - "sum", - "amax", - "amin", - "all", - "T", - "H", - "isinf", - "isposinf", - "isneginf", - "nan_to_num", - "mT", - "mH", - "rsub", - "triu", - "cummax", - "cummin", - "nextafter", - "gather", - "_chunk_cat", - "constant_pad_nd", +# Test with one sample only for following ops +inductor_one_sample = { + "_segment_reduce.lengths": {f16}, + "_segment_reduce.offsets": {f16}, + "addmv": {f16}, + "argsort": {b8, f16, f32, f64, i32, i64}, + "as_strided.partial_views": {f16}, + "clamp_max": {b8}, + "clamp_min": {b8}, + "corrcoef": {f16}, + "diff": {f16}, + "einsum": {f16, i32}, + "gradient": {f16}, + "histogram": {f32, f64}, + "histogramdd": {f32, f64}, + "index_put": {f16, f32, f64}, + "linalg.eig": {f32, f64}, + "linspace": {f16, i32, i64}, + "linspace.tensor_overload": {f16, f32, f64, i32, i64}, + "logspace": {f16}, + "logspace.tensor_overload": {f16, f32, f64, i32, i64}, + "masked_logsumexp": {i64}, + "max.binary": {b8}, + "max_pool2d_with_indices_backward": {f16, f32, f64}, + "maximum": {b8}, + "min.binary": {b8}, + "minimum": {b8}, + "ne": {b8}, + "new_empty_strided": {f16}, + "nn.functional.adaptive_avg_pool3d": {f16}, + "nn.functional.adaptive_max_pool1d": {f16, f32}, + "nn.functional.adaptive_max_pool2d": {f16, f32}, + "nn.functional.bilinear": {f16}, + "nn.functional.conv_transpose1d": {f16}, + "nn.functional.conv_transpose2d": {f16}, + "nn.functional.conv_transpose3d": {f16}, + "nn.functional.cosine_similarity": {f16}, + "nn.functional.cross_entropy": {f16, f32, f64}, + "nn.functional.gaussian_nll_loss": {f16}, + "nn.functional.grid_sample": {f32, f64}, + "nn.functional.interpolate.area": {f16}, + "nn.functional.max_pool2d": {f16, f32, f64, i32, i64}, + "nn.functional.nll_loss": {f16, f32, f64}, + "normal": {f16, f32, f64}, + "put": {f16, f32, f64}, + "rot90": {b8, f16, f32, f64, i32, i64}, + "scatter": {b8, i64}, + "take": {b8, f16, f32, f64, i32, i64}, + ("__rdiv__", "cuda"): {f16}, + ("__rmod__", "cuda"): {f16, i64}, + ("__rmul__", "cuda"): {f16}, + ("__rpow__", "cuda"): {f16}, + ("addcdiv", "cuda"): {f16}, + ("addcmul", "cuda"): {f16}, + ("atan2", "cuda"): {f16}, + ("cumsum", "cuda"): {f16}, + ("cumulative_trapezoid", "cuda"): {f16}, + ("dist", "cuda"): {f16}, + ("div.no_rounding_mode", "cuda"): {f16}, + ("fmod", "cuda"): {f16}, + ("grid_sampler_2d", "cuda"): {f16}, + ("index_fill", "cuda"): {f16, f32, f64}, + ("ldexp", "cuda"): {f16}, + ("lerp", "cuda"): {f16}, + ("linalg.householder_product", "cuda"): {f32}, + ("linalg.matrix_norm", "cuda"): {f16}, + ("linalg.vector_norm", "cuda"): {f16}, + ("logspace", "cuda"): {i32, i64}, + ("masked.cumsum", "cuda"): {f16}, + ("masked.logsumexp", "cuda"): {f16}, + ("masked.mean", "cuda"): {b8}, + ("masked.normalize", "cuda"): {f16}, + ("masked.prod", "cuda"): {f16}, + ("masked.std", "cuda"): {f16}, + ("masked.var", "cuda"): {f16}, + ("mul", "cuda"): {f16}, + ("nn.functional.alpha_dropout", "cuda"): {f16, f32, f64}, + ("nn.functional.avg_pool1d", "cuda"): {f16, f32, f64}, + ("nn.functional.avg_pool2d", "cuda"): {f16, f32, f64}, + ("nn.functional.avg_pool3d", "cuda"): {f16, f32, f64}, + ("nn.functional.binary_cross_entropy", "cuda"): {f16}, + ("nn.functional.binary_cross_entropy_with_logits", "cuda"): {f16}, + ("nn.functional.conv2d", "cuda"): {f16}, + ("nn.functional.cosine_embedding_loss", "cuda"): {f16}, + ("nn.functional.dropout2d", "cuda"): {f16, f32, f64}, + ("nn.functional.dropout3d", "cuda"): {f16, f32, f64}, + ("nn.functional.dropout", "cuda"): {f16, f32, f64}, + ("nn.functional.feature_alpha_dropout.with_train", "cuda"): {f16, f32, f64}, + ("nn.functional.fractional_max_pool2d", "cuda"): {f16, f32, f64}, + ("nn.functional.fractional_max_pool3d", "cuda"): {f16, f32, f64}, + ("nn.functional.grid_sample", "cuda"): {f16}, + ("nn.functional.group_norm", "cuda"): {f16}, + ("nn.functional.hinge_embedding_loss", "cuda"): {f16}, + ("nn.functional.interpolate.bicubic", "cuda"): {f16}, + ("nn.functional.interpolate.bilinear", "cuda"): {f16}, + ("nn.functional.interpolate.trilinear", "cuda"): {f16}, + ("nn.functional.kl_div", "cuda"): {f16}, + ("nn.functional.margin_ranking_loss", "cuda"): {f16}, + ("nn.functional.max_pool1d", "cuda"): {f16, f32, f64}, + ("nn.functional.max_pool3d", "cuda"): {f16}, + ("nn.functional.mse_loss", "cuda"): {f16}, + ("nn.functional.multi_margin_loss", "cuda"): {f16}, + ("nn.functional.multilabel_margin_loss", "cuda"): {f16}, + ("nn.functional.multilabel_soft_margin_loss", "cuda"): {f16}, + ("nn.functional.normalize", "cuda"): {f16}, + ("nn.functional.pad.replicate", "cuda"): {f16, f32, f64}, + ("nn.functional.pad.reflect", "cuda"): {f16}, + ("nn.functional.pairwise_distance", "cuda"): {f16}, + ("nn.functional.poisson_nll_loss", "cuda"): {f16}, + ("nn.functional.rms_norm", "cuda"): {f16}, + ("norm", "cuda"): {f16}, + ("pow", "cuda"): {f16}, + ("prod", "cuda"): {f16}, + ("scatter_reduce.amax", "cuda"): {f16, f32, f64}, + ("scatter_reduce.amin", "cuda"): {f16, f32, f64}, + ("scatter_reduce.mean", "cuda"): {f16, f32, f64}, + ("special.xlog1py", "cuda"): {f16}, + ("std", "cuda"): {f16}, + ("std_mean", "cuda"): {f16}, + ("svd_lowrank", "cuda"): {f32, f64}, + ("trapezoid", "cuda"): {f16}, + ("trapz", "cuda"): {f16}, + ("true_divide", "cuda"): {f16}, + ("var", "cuda"): {f16}, + ("var_mean", "cuda"): {f16}, + ("xlogy", "cuda"): {f16}, } @@ -489,10 +577,14 @@ def tearDown(self): ) @collection_decorator def test_comprehensive(self, device, dtype, op): + device_type = torch.device(device).type + + assert device_type in (GPU_TYPE, "cpu") + torch._dynamo.reset() with torch.no_grad(): # TODO: should we move empty_cache to the common device interface - if device == "cuda": + if device_type == "cuda": torch.cuda.empty_cache() op_name = op.name if op.variant_test_name: @@ -509,10 +601,6 @@ def test_comprehensive(self, device, dtype, op): if dtype not in allowed_dtypes: raise unittest.SkipTest("Skipped!") - device_type = torch.device(device).type - - assert device_type in (GPU_TYPE, "cpu") - # with open("test_output.txt", "a") as f: # print(f"CONSIDERING OP {op_name} on {device_type} with {dtype} | # {inductor_skips[device_type].get(op_name, set())}", flush=True, file=f) @@ -557,7 +645,10 @@ def fn(*args, **kwargs): ) samples = op.sample_inputs(device, dtype, requires_grad=requires_grad) - if op_name not in inductor_all_samples and not ALL_SAMPLES: + if ( + dtype in inductor_one_sample.get(op_name, {}) + or dtype in inductor_one_sample.get((op_name, device_type), {}) + ) and not ALL_SAMPLES: if isinstance(samples, (list, tuple)): samples = [samples[0]] else: From bca2cf00edcaaedb3f02cd66deaccfd8947e2e74 Mon Sep 17 00:00:00 2001 From: titaiwangms Date: Fri, 14 Jun 2024 21:56:49 +0000 Subject: [PATCH 029/171] [ONNX] Add dynamic axes support to torchscript exporter with dynamo=True (#128371) This PR enables specific axe to be dynamic with calling torch.export.export and torch.export.Dim. Features: (1) Turn dynamic_axes to dynamic_shapes (2) Dim constraints remain the same (see test case with hitting constraints). This might give different user experience, since we didn't have any constraints in torchscript-onnx exporting. (3) If input_names is used in dynamic_axes, ValueError will be raised, as input_names is currently not supported. Pull Request resolved: https://github.com/pytorch/pytorch/pull/128371 Approved by: https://github.com/justinchuby --- test/onnx/dynamo/test_exporter_api.py | 195 +++++++++++++++++++++++--- torch/onnx/utils.py | 82 +++++++++-- 2 files changed, 246 insertions(+), 31 deletions(-) diff --git a/test/onnx/dynamo/test_exporter_api.py b/test/onnx/dynamo/test_exporter_api.py index 30bfd27483b996..18340165560a21 100644 --- a/test/onnx/dynamo/test_exporter_api.py +++ b/test/onnx/dynamo/test_exporter_api.py @@ -33,6 +33,11 @@ def forward(self, x, b): return (y, z) +class SampleModelForDynamicShapes(torch.nn.Module): + def forward(self, x, b): + return x.relu(), b.sigmoid() + + class _LargeModel(torch.nn.Module): def __init__(self): super().__init__() @@ -230,8 +235,15 @@ def test_serialize_succeeds_when_model_greater_than_2gb(self): class TestONNXExportWithDynamo(common_utils.TestCase): def test_args_normalization_with_no_kwargs(self): + exported_program = torch.export.export( + SampleModelTwoInputs(), + ( + torch.randn(1, 1, 2), + torch.randn(1, 1, 2), + ), + ) onnx_program_from_new_exporter = torch.onnx.dynamo_export( - SampleModelTwoInputs(), torch.randn(1, 1, 2), torch.randn(1, 1, 2) + exported_program, torch.randn(1, 1, 2), torch.randn(1, 1, 2) ) onnx_program_from_old_exporter = torch.onnx.export( SampleModelTwoInputs(), @@ -243,9 +255,25 @@ def test_args_normalization_with_no_kwargs(self): onnx_program_from_old_exporter.model_proto, ) + def test_args_is_tensor_not_tuple(self): + exported_program = torch.export.export(SampleModel(), (torch.randn(1, 1, 2),)) + onnx_program_from_new_exporter = torch.onnx.dynamo_export( + exported_program, torch.randn(1, 1, 2) + ) + onnx_program_from_old_exporter = torch.onnx.export( + SampleModel(), torch.randn(1, 1, 2), dynamo=True + ) + self.assertEqual( + onnx_program_from_new_exporter.model_proto, + onnx_program_from_old_exporter.model_proto, + ) + def test_args_normalization_with_kwargs(self): + exported_program = torch.export.export( + SampleModelTwoInputs(), (torch.randn(1, 1, 2),), {"b": torch.randn(1, 1, 2)} + ) onnx_program_from_new_exporter = torch.onnx.dynamo_export( - SampleModelTwoInputs(), torch.randn(1, 1, 2), b=torch.randn(1, 1, 2) + exported_program, torch.randn(1, 1, 2), b=torch.randn(1, 1, 2) ) onnx_program_from_old_exporter = torch.onnx.export( SampleModelTwoInputs(), @@ -258,8 +286,11 @@ def test_args_normalization_with_kwargs(self): ) def test_args_normalization_with_empty_dict_at_the_tail(self): + exported_program = torch.export.export( + SampleModelTwoInputs(), (torch.randn(1, 1, 2),), {"b": torch.randn(1, 1, 2)} + ) onnx_program_from_new_exporter = torch.onnx.dynamo_export( - SampleModelTwoInputs(), torch.randn(1, 1, 2), b=torch.randn(1, 1, 2) + exported_program, torch.randn(1, 1, 2), b=torch.randn(1, 1, 2) ) onnx_program_from_old_exporter = torch.onnx.export( SampleModelTwoInputs(), @@ -271,17 +302,111 @@ def test_args_normalization_with_empty_dict_at_the_tail(self): onnx_program_from_old_exporter.model_proto, ) - def test_dynamic_axes_enable_dynamic_shape(self): + def test_dynamic_axes_enable_dynamic_shapes_with_fully_specified_axes(self): + exported_program = torch.export.export( + SampleModelForDynamicShapes(), + ( + torch.randn(2, 2, 3), + torch.randn(2, 2, 3), + ), + dynamic_shapes={ + "x": { + 0: torch.export.Dim("customx_dim_0"), + 1: torch.export.Dim("customx_dim_1"), + 2: torch.export.Dim("customx_dim_2"), + }, + "b": { + 0: torch.export.Dim("customb_dim_0"), + 1: torch.export.Dim("customb_dim_1"), + 2: torch.export.Dim("customb_dim_2"), + }, + }, + ) onnx_program_from_new_exporter = torch.onnx.dynamo_export( - SampleModelTwoInputs(), - torch.randn(1, 1, 2), - b=torch.randn(1, 1, 2), - export_options=ExportOptions(dynamic_shapes=True), + exported_program, + torch.randn(2, 2, 3), + b=torch.randn(2, 2, 3), ) onnx_program_from_old_exporter = torch.onnx.export( - SampleModelTwoInputs(), - (torch.randn(1, 1, 2), {"b": torch.randn(1, 1, 2)}, {}), - dynamic_axes={"b": [0, 1, 2]}, + SampleModelForDynamicShapes(), + (torch.randn(2, 2, 3), {"b": torch.randn(2, 2, 3)}, {}), + dynamic_axes={ + "x": {0: "customx_dim_0", 1: "customx_dim_1", 2: "customx_dim_2"}, + "b": {0: "customb_dim_0", 1: "customb_dim_1", 2: "customb_dim_2"}, + }, + dynamo=True, + ) + self.assertEqual( + onnx_program_from_new_exporter.model_proto, + onnx_program_from_old_exporter.model_proto, + ) + + def test_dynamic_axes_enable_dynamic_shapes_with_default_axe_names(self): + exported_program = torch.export.export( + SampleModelForDynamicShapes(), + ( + torch.randn(2, 2, 3), + torch.randn(2, 2, 3), + ), + dynamic_shapes={ + "x": { + 0: torch.export.Dim("customx_dim_0"), + 1: torch.export.Dim("customx_dim_1"), + 2: torch.export.Dim("customx_dim_2"), + }, + "b": { + 0: torch.export.Dim("customb_dim_0"), + 1: torch.export.Dim("customb_dim_1"), + 2: torch.export.Dim("customb_dim_2"), + }, + }, + ) + onnx_program_from_new_exporter = torch.onnx.dynamo_export( + exported_program, + torch.randn(2, 2, 3), + b=torch.randn(2, 2, 3), + ) + onnx_program_from_old_exporter = torch.onnx.export( + SampleModelForDynamicShapes(), + (torch.randn(2, 2, 3), {"b": torch.randn(2, 2, 3)}, {}), + dynamic_axes={ + "x": [0, 1, 2], + "b": [0, 1, 2], + }, + dynamo=True, + ) + self.assertEqual( + onnx_program_from_new_exporter.model_proto, + onnx_program_from_old_exporter.model_proto, + ) + + def test_dynamic_axes_supports_partial_dynamic_shapes(self): + exported_program = torch.export.export( + SampleModelForDynamicShapes(), + ( + torch.randn(2, 2, 3), + torch.randn(2, 2, 3), + ), + dynamic_shapes={ + "x": None, + "b": { + 0: torch.export.Dim("customb_dim_0"), + 1: torch.export.Dim("customb_dim_1"), + 2: torch.export.Dim("customb_dim_2"), + }, + }, + ) + onnx_program_from_new_exporter = torch.onnx.dynamo_export( + exported_program, + torch.randn(2, 2, 3), + b=torch.randn(2, 2, 3), + ) + onnx_program_from_old_exporter = torch.onnx.export( + SampleModelForDynamicShapes(), + (torch.randn(2, 2, 3), {"b": torch.randn(2, 2, 3)}, {}), + dynamic_axes={ + "b": [0, 1, 2], + }, dynamo=True, ) self.assertEqual( @@ -303,16 +428,37 @@ def test_raises_unrelated_parameters_warning(self): dynamo=True, ) - def test_raises_unsupported_specific_dynamic_axes_warning(self): - message = ( - "Specified dynamic axes is not supported for dynamo export at the moment." - ) + def test_input_names_are_not_yet_supported_in_dynamic_axes(self): + with self.assertRaisesRegex( + ValueError, + "Assinging new input names is not supported yet. Please use model forward signature " + "to specify input names in dynamix_axes.", + ): + _ = torch.onnx.export( + SampleModelForDynamicShapes(), + ( + torch.randn(2, 2, 3), + torch.randn(2, 2, 3), + ), + input_names=["input"], + dynamic_axes={"input": [0, 1]}, + dynamo=True, + ) - with self.assertWarnsOnceRegex(UserWarning, message): + def test_dynamic_shapes_hit_constraints_in_dynamo(self): + # SampleModelTwoInputs has constraints becuse of add of two inputs, + # so the two input shapes are related. + with self.assertRaisesRegex( + torch._dynamo.exc.UserError, + "Constraints violated", + ): _ = torch.onnx.export( - SampleModel(), - (torch.randn(1, 1, 2),), - dynamic_axes={"input": [0, 1, 2]}, + SampleModelTwoInputs(), + (torch.randn(2, 2, 3), torch.randn(2, 2, 3)), + dynamic_axes={ + "x": {0: "x_dim_0", 1: "x_dim_1", 2: "x_dim_2"}, + "b": {0: "b_dim_0", 1: "b_dim_1", 2: "b_dim_2"}, + }, dynamo=True, ) @@ -323,6 +469,17 @@ def test_saved_f_exists_after_export(self): ) self.assertTrue(os.path.exists(path)) + def test_raises_error_when_input_is_script_module(self): + class ScriptModule(torch.jit.ScriptModule): + def forward(self, x): + return x + + with self.assertRaisesRegex( + TypeError, + "Dynamo export does not support ScriptModule or ScriptFunction.", + ): + _ = torch.onnx.export(ScriptModule(), torch.randn(1, 1, 2), dynamo=True) + if __name__ == "__main__": common_utils.run_tests() diff --git a/torch/onnx/utils.py b/torch/onnx/utils.py index 94a57786a4bd77..0d02fabd1beb5b 100644 --- a/torch/onnx/utils.py +++ b/torch/onnx/utils.py @@ -512,6 +512,10 @@ def forward(self, x): """ if dynamo: + if isinstance(model, (torch.jit.ScriptModule, torch.jit.ScriptFunction)): + raise TypeError( + "Dynamo export does not support ScriptModule or ScriptFunction." + ) # Unsupported parameters for dynamo export # TODO: These are not supported AT THE TIME warnings.warn( @@ -519,7 +523,6 @@ def forward(self, x): "do_constant_folding, keep_initializers_as_inputs, custom_opsets, export_modules_as_functions, and " "autograd_inlining are not supported for dynamo export at the moment." ) - # TODO: check args normalization args = _decide_input_format(model, args) kwargs = {} if args is not None and isinstance(args[-1], dict): @@ -527,18 +530,14 @@ def forward(self, x): args = args[:-1] # TODO: refactor this when we have migrated ExportedProgam and # needs users to specify dynamic_axes - if dynamic_axes is None or not isinstance(dynamic_axes, dict): - dynamic_shapes = False - else: - dynamic_shapes = True - warnings.warn( - "Specified dynamic axes is not supported for dynamo export at the moment." - ) - # TODO: expose more ExportOptions? - export_options = torch.onnx.ExportOptions(dynamic_shapes=dynamic_shapes) - onnx_program = torch.onnx.dynamo_export( - model, *args, **kwargs, export_options=export_options + dynamic_shapes = _from_dynamic_axes_to_dynamic_shapes( + model, dynamic_axes, input_names ) + exported_program = torch.export.export( + model, args=args, kwargs=kwargs, dynamic_shapes=dynamic_shapes # type: ignore[arg-type] + ) + # TODO: expose ExportOptions? + onnx_program = torch.onnx.dynamo_export(exported_program, *args, **kwargs) if f is not None: onnx_program.save(f) return onnx_program @@ -915,6 +914,65 @@ def _decide_input_format(model, args): return args +@_beartype.beartype +def _from_dynamic_axes_to_dynamic_shapes( + model, + dynamic_axes: Optional[ + Union[Mapping[str, Mapping[int, str]], Mapping[str, Sequence[int]]] + ] = None, + input_names: Optional[Sequence[str]] = None, +) -> Optional[Dict[str, Any]]: + """ + + dynamic_axes examples: + (1) dynamic_axes = {"x": {0: "my_custom_axis_name_1"}, "y": {1: "my_custom_axis_name_2"}} + (2) dynamic_axes = {"x": [0], "y": [1]} + + these will be converted to dynamic_shapes respectively: + (1) dynamic_shapes = {"x": {0: Dim("my_custom_axis_name_1")}, "y": {1: Dim("my_custom_axis_name_2")}} + (2) dynamic_shapes = {"x": {0: Dim("x_dim_0")}, "y": {1: Dim("y_dim_1")}} # auto-generated dim names + + """ + if dynamic_axes is None: + return None + + if input_names is None: + input_names_set = set() + else: + input_names_set = set(input_names) + + dynamic_shapes: Dict[str, Optional[Any]] = {} + for input_name, axes in dynamic_axes.items(): + if input_name in input_names_set: + raise ValueError( + "Assinging new input names is not supported yet. Please use model forward signature " + "to specify input names in dynamix_axes." + ) + if isinstance(axes, dict): + dynamic_shapes[input_name] = { + k: torch.export.Dim(v) for k, v in axes.items() + } + elif isinstance(axes, list): + dynamic_shapes[input_name] = { + k: torch.export.Dim(f"{input_name}_dim_{k}") for k in axes + } + else: + raise TypeError( + f"dynamic_axes value must be either a dict or a list, but got {type(axes)}" + ) + # torch.export.export needs static dim to present in dynamic_shapes + # for all input tensors, so we need to add them with None + try: + sig = _signature(model) + except ValueError as e: + warnings.warn(f"{e}, skipping auto filling None on static axes...") + return dynamic_shapes + for input_name in sig.parameters.keys(): + if input_name not in dynamic_shapes: + dynamic_shapes[input_name] = None + return dynamic_shapes + + @_beartype.beartype def _trace(func, args, operator_export_type, return_outs=False): # Special case for common case of passing a single Tensor From d3a4d9e4feea238b33b56c8f60d8e0a4dabefee4 Mon Sep 17 00:00:00 2001 From: Catherine Lee Date: Fri, 14 Jun 2024 22:23:00 +0000 Subject: [PATCH 030/171] Update cu124 dynamo benchmark expected values (#128737) Missed one in https://github.com/pytorch/pytorch/pull/128589 Pull Request resolved: https://github.com/pytorch/pytorch/pull/128737 Approved by: https://github.com/Skylion007 --- .../ci_expected_accuracy/cu124/inductor_torchbench_training.csv | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/benchmarks/dynamo/ci_expected_accuracy/cu124/inductor_torchbench_training.csv b/benchmarks/dynamo/ci_expected_accuracy/cu124/inductor_torchbench_training.csv index cfc52442664401..a62389fad7dadf 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/cu124/inductor_torchbench_training.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/cu124/inductor_torchbench_training.csv @@ -98,7 +98,7 @@ hf_Bert_large,pass,6 -hf_BigBird,pass,52 +hf_BigBird,pass,49 From fd27138c4a86bd763a6b8128d940a7c98f951603 Mon Sep 17 00:00:00 2001 From: Huy Do Date: Fri, 14 Jun 2024 22:54:21 +0000 Subject: [PATCH 031/171] Update DALLE2_pytorch expected accuracy result on CPU (#128718) I suspect that the issue shows up because of the new version of https://pypi.org/project/pyarrow/16.1.0/#history released yesterday. The package is a dependency of DALLE2_pytorch https://github.com/pytorch/benchmark/blob/main/torchbenchmark/models/DALLE2_pytorch/install.py#L22. I'll just update the expected accuracy result on CPU benchmark because the model fails to run there anyway. Pull Request resolved: https://github.com/pytorch/pytorch/pull/128718 Approved by: https://github.com/malfet --- .../cpu_inductor_torchbench_freezing_inference.csv | 2 +- .../ci_expected_accuracy/cpu_inductor_torchbench_inference.csv | 2 +- .../dynamic_cpu_inductor_torchbench_inference.csv | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_torchbench_freezing_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_torchbench_freezing_inference.csv index 3af215541c1dfb..73c46e578eec2b 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_torchbench_freezing_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_torchbench_freezing_inference.csv @@ -10,7 +10,7 @@ Background_Matting,pass_due_to_skip,0 -DALLE2_pytorch,model_fail_to_load,0 +DALLE2_pytorch,eager_fail_to_run,0 diff --git a/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_torchbench_inference.csv index a497fb45d7d48c..f2aafef9db9fa3 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_torchbench_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_torchbench_inference.csv @@ -10,7 +10,7 @@ Background_Matting,pass_due_to_skip,0 -DALLE2_pytorch,model_fail_to_load,0 +DALLE2_pytorch,eager_fail_to_run,0 diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_inductor_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_inductor_torchbench_inference.csv index 5ffc870a8dec1f..da37b6c9c9e809 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_inductor_torchbench_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_inductor_torchbench_inference.csv @@ -10,7 +10,7 @@ Background_Matting,pass_due_to_skip,0 -DALLE2_pytorch,model_fail_to_load,0 +DALLE2_pytorch,eager_fail_to_run,0 From 4abecd7102811e232496ed392b0ff432013398db Mon Sep 17 00:00:00 2001 From: Yang Chen Date: Fri, 14 Jun 2024 02:06:01 -0700 Subject: [PATCH 032/171] [AOTI] fixed performance issue for AOTI_TORCH_CHECK (#128402) We introduced AOTI_TORCH_CHECK in #119220 to resolve slow-compilation time issues. Unfortunately, it caused perf regressions for CPU , as described in issue #126665. After some investigation, it turned out the slow compilation was caused by the use of the builtin function __builtin_expect provided by gcc/clang. Moreover, nuking __builtin_expect doesn't seem to cause any performance penalty, even though its purpose is to improve performance by providing the compiler with branch prediction information. abs latency numbers using the script shared by #126665: before the fix after the fix T5Small 1019.055694 917.875027 T5ForConditionalGeneration 1009.825196 916.369239 Pull Request resolved: https://github.com/pytorch/pytorch/pull/128402 Approved by: https://github.com/desertfire --- torch/csrc/inductor/aoti_torch/c/shim.h | 32 ++++++++++++++----------- 1 file changed, 18 insertions(+), 14 deletions(-) diff --git a/torch/csrc/inductor/aoti_torch/c/shim.h b/torch/csrc/inductor/aoti_torch/c/shim.h index 65fbbd9fc23d6f..5450293afb97e9 100644 --- a/torch/csrc/inductor/aoti_torch/c/shim.h +++ b/torch/csrc/inductor/aoti_torch/c/shim.h @@ -566,21 +566,25 @@ AOTI_TORCH_EXPORT void aoti_torch_check( const char* msg); #ifdef STRIP_ERROR_MESSAGES -#define AOTI_TORCH_CHECK(cond, ...) \ - aoti_torch_check( \ - cond, \ - __func__, \ - __FILE__, \ - static_cast(__LINE__), \ - TORCH_CHECK_MSG(cond, "", __VA_ARGS__)); +#define AOTI_TORCH_CHECK(cond, ...) \ + if (!(cond)) { \ + aoti_torch_check( \ + false, \ + __func__, \ + __FILE__, \ + static_cast(__LINE__), \ + TORCH_CHECK_MSG(cond, "", __VA_ARGS__)); \ + } #else -#define AOTI_TORCH_CHECK(cond, ...) \ - aoti_torch_check( \ - cond, \ - __func__, \ - __FILE__, \ - static_cast(__LINE__), \ - TORCH_CHECK_MSG(cond, "", ##__VA_ARGS__)); +#define AOTI_TORCH_CHECK(cond, ...) \ + if (!(cond)) { \ + aoti_torch_check( \ + false, \ + __func__, \ + __FILE__, \ + static_cast(__LINE__), \ + TORCH_CHECK_MSG(cond, "", ##__VA_ARGS__)); \ + } #endif #ifdef __cplusplus From 52d4442a00ca0c41a1f8abd0b48c2d71b07fba0b Mon Sep 17 00:00:00 2001 From: Tristan Rice Date: Fri, 14 Jun 2024 23:08:27 +0000 Subject: [PATCH 033/171] [c10d] Socket, TCPStore: add better logging (#128673) This adds better logging of errors to the socket and TCPStore classes. All socket operations should now include the local and remote addresses and we actually log errors from the TCPStoreBackend::run as well as TCPStoreBackendUV which were previously INFO messages and not actually logged. It also overhauls test_wait in test_store.py as it had a race condition causing it to be flaky. Test plan: ``` python test/distributed/test_store.py ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/128673 Approved by: https://github.com/c-p-i-o --- test/distributed/test_store.py | 61 ++++++---- torch/csrc/distributed/c10d/TCPStore.cpp | 21 +++- .../csrc/distributed/c10d/TCPStoreBackend.cpp | 108 ++++++++++-------- .../distributed/c10d/TCPStoreLibUvBackend.cpp | 12 +- torch/csrc/distributed/c10d/socket.cpp | 79 +++++++++++-- torch/csrc/distributed/c10d/socket.h | 2 + 6 files changed, 194 insertions(+), 89 deletions(-) diff --git a/test/distributed/test_store.py b/test/distributed/test_store.py index cd126cc0d3580a..b426347ebedef1 100644 --- a/test/distributed/test_store.py +++ b/test/distributed/test_store.py @@ -16,10 +16,7 @@ import torch.distributed.rpc as rpc from torch.distributed import DistError, DistNetworkError, DistStoreError from torch.testing._internal.common_distributed import MultiThreadedTestCase -from torch.testing._internal.common_utils import ( - instantiate_parametrized_tests, - parametrize, -) +from torch.testing._internal.common_utils import instantiate_parametrized_tests if not dist.is_available(): print("torch.distributed not available, skipping tests", file=sys.stderr) @@ -841,19 +838,11 @@ def test_extended_methods_fallbacks(self): class TestMultiThreadedWait(MultiThreadedTestCase): - # TODO (xilunwu): Use less hacky means of instantiating stores. - # Note, stores accumulate values per test. - stores = [ - dist.FileStore(tempfile.NamedTemporaryFile(delete=False).name, 1), - dist.HashStore(), - dist.PrefixStore( - "pre", dist.FileStore(tempfile.NamedTemporaryFile(delete=False).name, 1) - ), - create_tcp_store(use_libuv=False), - create_tcp_store(use_libuv=True), - dist.PrefixStore("pre", create_tcp_store(use_libuv=False)), - dist.PrefixStore("pre", create_tcp_store(use_libuv=True)), - ] + file_store = dist.FileStore(tempfile.NamedTemporaryFile(delete=False).name, 1) + hash_store = dist.HashStore() + + tcp_store = create_tcp_store(use_libuv=False) + tcp_store_uv = create_tcp_store(use_libuv=True) @property def world_size(self): @@ -863,10 +852,7 @@ def setUp(self): super().setUp() self._spawn_threads() - # Iterates over self.stores, keep 7 in sync with len(self.stores). - @parametrize("i", range(7)) - def test_wait(self, i): - store = self.stores[i] + def _test_wait(self, store): store.set_timeout(timedelta(seconds=2)) if dist.get_rank() == 0: store.wait(["key1"]) @@ -874,6 +860,39 @@ def test_wait(self, i): if dist.get_rank() == 1: store.set("key1", "value1") + def test_wait_hash_store(self): + self._test_wait(self.hash_store) + + def test_wait_file_store(self): + self._test_wait(self.file_store) + + def test_wait_prefix_file_store(self): + store = dist.PrefixStore("pre", self.file_store) + self._test_wait(store) + + def _test_wait_tcp_store(self, master_store): + store = ( + master_store + if dist.get_rank() == 0 + else dist.TCPStore( + host_name=master_store.host, + port=master_store.port, + is_master=False, + wait_for_workers=False, + use_libuv=False, + ) + ) + self._test_wait(store) + + prefix_store = dist.PrefixStore("pre", store) + self._test_wait(prefix_store) + + def test_wait_tcp_store(self): + self._test_wait_tcp_store(self.tcp_store) + + def test_wait_tcp_store_uv(self): + self._test_wait_tcp_store(self.tcp_store_uv) + instantiate_parametrized_tests(TestMultiThreadedWait) diff --git a/torch/csrc/distributed/c10d/TCPStore.cpp b/torch/csrc/distributed/c10d/TCPStore.cpp index a716bf6667559e..fe24c31f9068bd 100644 --- a/torch/csrc/distributed/c10d/TCPStore.cpp +++ b/torch/csrc/distributed/c10d/TCPStore.cpp @@ -155,16 +155,31 @@ class TCPClient { const TCPStoreOptions& opts); void sendRaw(uint8_t* data, size_t lenght) { - tcputil::sendBytes(socket_.handle(), data, lenght); + try { + tcputil::sendBytes(socket_.handle(), data, lenght); + } catch (const std::exception& e) { + C10D_WARNING("sendBytes failed on {}: {}", socket_.repr(), e.what()); + throw; + } } std::vector receiveBits() { - return tcputil::recvVector(socket_.handle()); + try { + return tcputil::recvVector(socket_.handle()); + } catch (const std::exception& e) { + C10D_WARNING("recvVector failed on {}: {}", socket_.repr(), e.what()); + throw; + } } template T receiveValue() { - return tcputil::recvValue(socket_.handle()); + try { + return tcputil::recvValue(socket_.handle()); + } catch (const std::exception& e) { + C10D_WARNING("recvValue failed on {}: {}", socket_.repr(), e.what()); + throw; + } } template bool receiveValueWithTimeout(T& t, std::chrono::milliseconds timeout) { diff --git a/torch/csrc/distributed/c10d/TCPStoreBackend.cpp b/torch/csrc/distributed/c10d/TCPStoreBackend.cpp index 7018c1d65115af..2cb20fee2adc95 100644 --- a/torch/csrc/distributed/c10d/TCPStoreBackend.cpp +++ b/torch/csrc/distributed/c10d/TCPStoreBackend.cpp @@ -17,6 +17,7 @@ #include #include +#include #ifdef _WIN32 #include @@ -546,61 +547,70 @@ void TCPStoreMasterDaemon::run() { } #else void TCPStoreMasterDaemon::run() { - c10::setThreadName("pt_tcpstore"); - - std::vector fds; - tcputil::addPollfd(fds, storeListenSocket_.handle(), POLLIN); - // Although we haven't found any documentation or literature describing this, - // we've seen cases that, under certain circumstances, the read end of the - // pipe won't receive POLLHUP when the write end is closed. However, under - // the same circumstances, writing to the pipe will guarantee POLLIN to be - // received on the read end. - // - // For more reliable termination, the main thread will write a byte to the - // pipe before closing it, and the background thread will poll for both - // POLLIN and POLLHUP. - tcputil::addPollfd(fds, controlPipeFd_[0], POLLIN | POLLHUP); - - // receive the queries - bool finished = false; - while (!finished) { - for (const auto i : c10::irange(sockets_.size())) { - fds[i].revents = 0; - } + try { + c10::setThreadName("pt_tcpstore"); + + std::vector fds; + tcputil::addPollfd(fds, storeListenSocket_.handle(), POLLIN); + // Although we haven't found any documentation or literature describing + // this, we've seen cases that, under certain circumstances, the read end of + // the pipe won't receive POLLHUP when the write end is closed. However, + // under the same circumstances, writing to the pipe will guarantee POLLIN + // to be received on the read end. + // + // For more reliable termination, the main thread will write a byte to the + // pipe before closing it, and the background thread will poll for both + // POLLIN and POLLHUP. + tcputil::addPollfd(fds, controlPipeFd_[0], POLLIN | POLLHUP); + + // receive the queries + bool finished = false; + while (!finished) { + for (const auto i : c10::irange(sockets_.size())) { + fds[i].revents = 0; + } - SYSCHECK_ERR_RETURN_NEG1(::poll(fds.data(), fds.size(), -1)); + SYSCHECK_ERR_RETURN_NEG1(::poll(fds.data(), fds.size(), -1)); - // TCPStore's listening socket has an event and it should now be able to - // accept new connections. - if (fds[0].revents != 0) { - if (fds[0].revents ^ POLLIN) { - C10_THROW_ERROR( - DistStoreError, - "Unexpected poll revent on the master's listening socket: " + - std::to_string(fds[0].revents)); + // TCPStore's listening socket has an event and it should now be able to + // accept new connections. + if (fds[0].revents != 0) { + if (fds[0].revents ^ POLLIN) { + C10_THROW_ERROR( + DistStoreError, + "Unexpected poll revent on the master's listening socket: " + + std::to_string(fds[0].revents)); + } + Socket socket = storeListenSocket_.accept(); + int rawSocket = socket.handle(); + sockets_.emplace_back(std::move(socket)); + tcputil::addPollfd(fds, rawSocket, POLLIN); + // all clients are miscellaneous before getting its validation query + addMiscellaneousSocket(rawSocket); } - Socket socket = storeListenSocket_.accept(); - int rawSocket = socket.handle(); - sockets_.emplace_back(std::move(socket)); - tcputil::addPollfd(fds, rawSocket, POLLIN); - // all clients are miscellaneous before getting its validation query - addMiscellaneousSocket(rawSocket); - } - // The pipe receives an event which tells us to shutdown the daemon - if (fds[1].revents != 0) { - // The main thread will write a byte to the pipe then close it before - // joining the background thread - if (fds[1].revents & ~(POLLIN | POLLHUP)) { - C10_THROW_ERROR( - DistStoreError, - "Unexpected poll revent on the control pipe's reading fd: " + - std::to_string(fds[1].revents)); + // The pipe receives an event which tells us to shutdown the daemon + if (fds[1].revents != 0) { + // The main thread will write a byte to the pipe then close it before + // joining the background thread + if (fds[1].revents & ~(POLLIN | POLLHUP)) { + C10_THROW_ERROR( + DistStoreError, + "Unexpected poll revent on the control pipe's reading fd: " + + std::to_string(fds[1].revents)); + } + finished = true; + break; } - finished = true; - break; + queryFds(fds); } - queryFds(fds); + } catch (const std::exception& ex) { + C10D_ERROR( + "TCPStoreMasterDaemon::run() failed with exception: ", ex.what()); + throw; + } catch (...) { + C10D_ERROR("TCPStoreMasterDaemon::run() failed with unknown exception"); + throw; } } #endif diff --git a/torch/csrc/distributed/c10d/TCPStoreLibUvBackend.cpp b/torch/csrc/distributed/c10d/TCPStoreLibUvBackend.cpp index 8dfc154ba6826d..5890ff5d95ae5e 100644 --- a/torch/csrc/distributed/c10d/TCPStoreLibUvBackend.cpp +++ b/torch/csrc/distributed/c10d/TCPStoreLibUvBackend.cpp @@ -125,7 +125,7 @@ class UvTcpSocket : public UvHandle { try { uv_socket->processBuf(buf, nread); } catch (std::exception& ex) { - C10D_INFO("Error processing client message: {}", ex.what()); + C10D_WARNING("Error processing client message: {}", ex.what()); uv_socket->close(); } } @@ -139,7 +139,7 @@ class UvTcpSocket : public UvHandle { void startRead() { int res = uv_read_start((uv_stream_t*)&client, alloc_buffer, read_callback); if (res) { - C10D_INFO( + C10D_WARNING( "Failed to setup read callback. client:{} code:{} name:{} desc:{}.", (void*)this, res, @@ -357,7 +357,7 @@ class WriterPayload : public c10::intrusive_ptr_target { auto handle = wp->handle; if (status) { - C10D_INFO( + C10D_WARNING( "Write to client failed. code:{} name:{} desc:{}.", status, uv_err_name(status), @@ -387,7 +387,7 @@ class WriterPayload : public c10::intrusive_ptr_target { &req, (uv_stream_t*)handle->unsafeGetHandle(), &buf, 1, write_done); if (res) { - C10D_INFO( + C10D_WARNING( "Write setup to client failed. code:{} name:{} desc:{}.", res, uv_err_name(res), @@ -994,7 +994,7 @@ void LibUVStoreDaemon::onConnect(int status) { tcpServer->accept(client); client->startRead(); } catch (std::exception& e) { - C10D_INFO("Failed to accept client due to {}", e.what()); + C10D_WARNING("Failed to accept client due to {}", e.what()); client->close(); } } @@ -1111,7 +1111,7 @@ void LibUVStoreDaemon::run() { void LibUVStoreDaemon::stop() { int res = uv_async_send(&exit_handle); if (res) { - C10D_INFO( + C10D_WARNING( "uv_async_send failed with:{} errn:{} desc:{}\n", res, uv_err_name(res), diff --git a/torch/csrc/distributed/c10d/socket.cpp b/torch/csrc/distributed/c10d/socket.cpp index 6cbaa018762eb9..a7020e24964243 100644 --- a/torch/csrc/distributed/c10d/socket.cpp +++ b/torch/csrc/distributed/c10d/socket.cpp @@ -17,6 +17,7 @@ #include #include #else +#include #include #include #include @@ -36,6 +37,7 @@ C10_DIAGNOSTIC_POP() #include #include +#include namespace c10d::detail { namespace { @@ -136,7 +138,10 @@ class SocketImpl { static constexpr Handle invalid_socket = -1; #endif - explicit SocketImpl(Handle hnd) noexcept : hnd_{hnd} {} + explicit SocketImpl( + Handle hnd, + c10::optional<::addrinfo> remote = c10::nullopt) noexcept + : hnd_{hnd}, remote_(remote) {} SocketImpl(const SocketImpl& other) = delete; @@ -174,12 +179,17 @@ class SocketImpl { return hnd_; } + const c10::optional<::addrinfo>& remote() const noexcept { + return remote_; + } + bool waitForInput(std::chrono::milliseconds timeout); private: bool setSocketFlag(int level, int optname, bool value) noexcept; Handle hnd_; + const c10::optional<::addrinfo> remote_; }; } // namespace c10d::detail @@ -207,7 +217,27 @@ struct formatter<::addrinfo> { NI_MAXSERV, NI_NUMERICSERV); if (r != 0) { - return fmt::format_to(ctx.out(), "?UNKNOWN?"); + // if we can't resolve the hostname, display the IP address + if (addr.ai_family == AF_INET) { + struct sockaddr_in* psai = (struct sockaddr_in*)addr.ai_addr; + char ip[INET_ADDRSTRLEN]; + if (inet_ntop(addr.ai_family, &(psai->sin_addr), ip, INET_ADDRSTRLEN) != + NULL) { + return fmt::format_to(ctx.out(), "{}:{}", ip, psai->sin_port); + } + } else if (addr.ai_family == AF_INET6) { + struct sockaddr_in6* psai = (struct sockaddr_in6*)addr.ai_addr; + char ip[INET6_ADDRSTRLEN]; + if (inet_ntop( + addr.ai_family, &(psai->sin6_addr), ip, INET6_ADDRSTRLEN) != + NULL) { + return fmt::format_to(ctx.out(), "[{}]:{}", ip, psai->sin6_port); + } + } + C10_THROW_ERROR( + DistNetworkError, + fmt::format( + "failed to format addr, unknown family={}", addr.ai_family)); } if (addr.ai_addr->sa_family == AF_INET) { @@ -234,7 +264,9 @@ struct formatter { ::socklen_t addr_len = sizeof(addr_s); - if (::getsockname(socket.handle(), addr_ptr, &addr_len) != 0) { + auto fd = socket.handle(); + + if (::getsockname(fd, addr_ptr, &addr_len) != 0) { return fmt::format_to(ctx.out(), "?UNKNOWN?"); } @@ -242,7 +274,15 @@ struct formatter { addr.ai_addr = addr_ptr; addr.ai_addrlen = addr_len; - return fmt::format_to(ctx.out(), "{}", addr); + auto remote = socket.remote(); + std::string remoteStr = remote ? fmt::format("{}", *remote) : "none"; + + return fmt::format_to( + ctx.out(), + "SocketImpl(fd={}, addr={}, remote={})", + fd, + addr, + remoteStr); } }; @@ -297,7 +337,7 @@ std::unique_ptr SocketImpl::accept() const { *this, addr); - auto impl = std::make_unique(hnd); + auto impl = std::make_unique(hnd, addr); // Make sure that we do not "leak" our file descriptors to child processes. impl->closeOnExec(); @@ -413,22 +453,34 @@ bool SocketImpl::waitForInput(std::chrono::milliseconds timeout) { int res = pollFd(&pfd, 1, static_cast(timeout.count())); if (res > 0) { return true; + } else if (res == 0) { + C10D_WARNING( + "waitForInput: poll for socket {} returned 0, likely a timeout", + *this); + continue; } - std::error_code err = getSocketError(); + std::error_code err = getSocketError(); if (err == std::errc::operation_in_progress) { bool timedout = Clock::now() >= deadline; if (timedout) { return false; } C10D_WARNING( - "pollFB for socket {} returned operation_in_progress before a timeout", - hnd_); + "waitForInput: poll for socket {} returned operation_in_progress before a timeout", + *this); } else if (err != std::errc::interrupted) { - C10D_WARNING("While waitForInput, poolFD failed with {}.", err); + C10D_WARNING( + "waitForInput: poll for socket {} failed with res={}, err={}.", + *this, + res, + err); return false; } } while (Clock::now() < deadline); + + C10D_WARNING( + "waitForInput: socket {} timed out after {}ms", *this, timeout.count()); return false; } @@ -848,7 +900,7 @@ SocketConnectOp::ConnectResult SocketConnectOp::tryConnect( return ConnectResult::Error; } - socket_ = std::make_unique(hnd); + socket_ = std::make_unique(hnd, addr); socket_->enableNonBlocking(); @@ -1033,4 +1085,11 @@ bool Socket::waitForInput(std::chrono::milliseconds timeout) { return impl_->waitForInput(timeout); } +std::string Socket::repr() const { + if (impl_) { + return fmt::format("{}", *impl_); + } + return "Socket(no-impl)"; +} + } // namespace c10d::detail diff --git a/torch/csrc/distributed/c10d/socket.h b/torch/csrc/distributed/c10d/socket.h index 52832722304cf6..9ca74f3143fea2 100644 --- a/torch/csrc/distributed/c10d/socket.h +++ b/torch/csrc/distributed/c10d/socket.h @@ -82,6 +82,8 @@ class Socket { bool waitForInput(std::chrono::milliseconds timeout); + std::string repr() const; + private: explicit Socket(std::unique_ptr&& impl) noexcept; From bd72e28314d8d63bb347becb8309f5ac7761c6b5 Mon Sep 17 00:00:00 2001 From: cyy Date: Fri, 14 Jun 2024 23:21:01 +0000 Subject: [PATCH 034/171] [1/N] Change #include to #include (#128301) Fixes #ISSUE_NUMBER Pull Request resolved: https://github.com/pytorch/pytorch/pull/128301 Approved by: https://github.com/ezyang --- aten/src/ATen/CPUGeneratorImpl.h | 2 +- aten/src/ATen/InferSize.h | 2 +- aten/src/ATen/SavedTensorHooks.cpp | 2 +- aten/src/ATen/SavedTensorHooks.h | 2 +- aten/src/ATen/TensorIndexing.h | 2 +- aten/src/ATen/native/BatchLinearAlgebra.h | 2 +- aten/src/ATen/record_function.h | 2 +- c10/core/ConstantSymNodeImpl.h | 6 +- c10/core/ScalarTypeToTypeMeta.h | 2 +- c10/core/SymBool.h | 4 +- c10/core/SymInt.h | 4 +- c10/core/SymIntArrayRef.h | 4 +- c10/core/SymNodeImpl.h | 12 +-- c10/core/SymbolicShapeMeta.cpp | 4 +- c10/core/TensorImpl.cpp | 2 +- c10/core/TensorImpl.h | 14 ++-- c10/core/TensorOptions.h | 32 +++---- c10/core/UndefinedTensorImpl.cpp | 2 +- c10/core/impl/InlineDeviceGuard.h | 4 +- c10/core/impl/InlineStreamGuard.h | 4 +- c10/core/impl/PyObjectSlot.h | 8 +- c10/core/impl/TorchDispatchModeTLS.cpp | 12 +-- c10/cuda/CUDACachingAllocator.cpp | 6 +- c10/cuda/CUDAFunctions.cpp | 2 +- c10/cuda/CUDAGuard.h | 4 +- c10/cuda/impl/CUDAGuardImpl.h | 4 +- c10/test/core/DeviceGuard_test.cpp | 5 +- c10/test/core/SymInt_test.cpp | 2 +- c10/test/core/impl/InlineDeviceGuard_test.cpp | 16 ++-- c10/test/core/impl/InlineStreamGuard_test.cpp | 16 ++-- c10/test/util/optional_test.cpp | 32 +++---- c10/util/Backtrace.cpp | 10 +-- c10/util/OptionalArrayRef.h | 18 ++-- c10/xpu/test/impl/XPUStreamTest.cpp | 2 +- torch/csrc/Module.cpp | 8 +- torch/csrc/Storage.cpp | 2 +- .../csrc/api/include/torch/expanding_array.h | 2 +- torch/csrc/api/include/torch/fft.h | 84 +++++++++---------- torch/csrc/api/include/torch/nested.h | 12 +-- .../include/torch/nn/functional/activation.h | 6 +- .../include/torch/nn/functional/embedding.h | 6 +- .../api/include/torch/nn/functional/loss.h | 4 +- .../torch/nn/functional/normalization.h | 6 +- .../api/include/torch/nn/functional/pooling.h | 16 ++-- .../include/torch/nn/functional/upsampling.h | 28 +++---- .../api/include/torch/nn/modules/batchnorm.h | 4 +- .../csrc/api/include/torch/nn/modules/conv.h | 6 +- .../api/include/torch/nn/modules/pooling.h | 6 +- .../csrc/api/include/torch/nn/modules/utils.h | 2 +- .../api/include/torch/nn/options/activation.h | 10 +-- .../api/include/torch/nn/options/embedding.h | 24 +++--- .../csrc/api/include/torch/nn/options/loss.h | 4 +- .../include/torch/nn/options/normalization.h | 2 +- .../api/include/torch/nn/options/pooling.h | 8 +- .../api/include/torch/nn/options/upsampling.h | 14 ++-- .../api/include/torch/nn/options/vision.h | 2 +- .../api/include/torch/nn/utils/clip_grad.h | 4 +- .../torch/nn/utils/convert_parameters.h | 2 +- torch/csrc/api/include/torch/optim/lbfgs.h | 10 +-- .../csrc/api/include/torch/optim/optimizer.h | 10 +-- .../include/torch/serialize/input-archive.h | 10 +-- torch/csrc/api/include/torch/types.h | 4 +- torch/csrc/api/src/jit.cpp | 2 +- torch/csrc/api/src/nn/modules/activation.cpp | 8 +- torch/csrc/api/src/nn/modules/conv.cpp | 2 +- torch/csrc/api/src/nn/modules/embedding.cpp | 10 +-- torch/csrc/api/src/nn/modules/pooling.cpp | 20 ++--- torch/csrc/api/src/nn/modules/upsampling.cpp | 4 +- torch/csrc/api/src/optim/lbfgs.cpp | 16 ++-- .../csrc/api/src/serialize/input-archive.cpp | 8 +- torch/csrc/autograd/FunctionsManual.cpp | 22 ++--- torch/csrc/autograd/FunctionsManual.h | 2 +- torch/csrc/autograd/TraceTypeManual.cpp | 2 +- torch/csrc/autograd/VariableTypeManual.cpp | 2 +- torch/csrc/autograd/VariableTypeUtils.h | 4 +- torch/csrc/autograd/autograd.h | 4 +- .../autograd_not_implemented_fallback.cpp | 2 +- torch/csrc/autograd/engine.cpp | 6 +- torch/csrc/autograd/function.h | 4 +- .../csrc/autograd/functions/accumulate_grad.h | 2 +- torch/csrc/autograd/functions/comm.cpp | 2 +- torch/csrc/autograd/functions/comm.h | 6 +- torch/csrc/autograd/init.cpp | 8 +- torch/csrc/autograd/input_buffer.cpp | 6 +- torch/csrc/autograd/input_buffer.h | 2 +- torch/csrc/autograd/profiler_legacy.cpp | 2 +- torch/csrc/autograd/profiler_legacy.h | 6 +- torch/csrc/autograd/profiler_python.cpp | 10 +-- torch/csrc/autograd/python_function.cpp | 2 +- torch/csrc/autograd/python_function.h | 2 +- torch/csrc/autograd/python_variable.cpp | 8 +- .../autograd/python_variable_indexing.cpp | 4 +- torch/csrc/autograd/record_function_ops.h | 4 +- .../autograd/utils/grad_layout_contract.h | 2 +- .../csrc/autograd/utils/python_arg_parsing.h | 2 +- torch/csrc/autograd/variable.h | 12 +-- torch/csrc/cuda/comm.cpp | 2 +- torch/csrc/cuda/comm.h | 8 +- torch/csrc/cuda/memory_snapshot.h | 2 +- torch/csrc/cuda/nccl.h | 2 +- torch/csrc/cuda/python_nccl.cpp | 2 +- .../autograd/engine/dist_engine.cpp | 2 +- torch/csrc/distributed/c10d/NCCLUtils.cpp | 8 +- torch/csrc/distributed/c10d/NCCLUtils.hpp | 16 ++-- .../distributed/c10d/ProcessGroupCudaP2P.hpp | 4 +- .../distributed/c10d/ProcessGroupGloo.cpp | 4 +- .../distributed/c10d/ProcessGroupGloo.hpp | 2 +- .../csrc/distributed/c10d/ProcessGroupMPI.cpp | 6 +- .../csrc/distributed/c10d/ProcessGroupMPI.hpp | 6 +- .../distributed/c10d/ProcessGroupNCCL.cpp | 30 +++---- .../distributed/c10d/ProcessGroupNCCL.hpp | 8 +- torch/csrc/distributed/c10d/TCPStore.cpp | 4 +- torch/csrc/distributed/c10d/TCPStore.hpp | 6 +- torch/csrc/distributed/c10d/TraceUtils.h | 4 +- torch/csrc/distributed/c10d/Types.hpp | 2 +- torch/csrc/distributed/c10d/Utils.hpp | 2 +- torch/csrc/distributed/c10d/Work.hpp | 2 +- torch/csrc/distributed/c10d/init.cpp | 4 +- .../csrc/distributed/c10d/intra_node_comm.hpp | 4 +- torch/csrc/distributed/c10d/logger.cpp | 2 +- torch/csrc/distributed/c10d/reducer.cpp | 8 +- torch/csrc/distributed/c10d/reducer.hpp | 6 +- torch/csrc/distributed/c10d/reducer_cuda.cpp | 4 +- torch/csrc/distributed/c10d/reducer_timer.hpp | 2 +- torch/csrc/distributed/c10d/sequence_num.cpp | 10 +-- torch/csrc/distributed/c10d/sequence_num.hpp | 2 +- .../rpc/profiler/remote_profiler_manager.cpp | 4 +- .../rpc/profiler/remote_profiler_manager.h | 2 +- torch/csrc/distributed/rpc/py_rref.cpp | 2 +- .../csrc/distributed/rpc/python_functions.cpp | 4 +- .../rpc/request_callback_no_python.cpp | 2 +- torch/csrc/distributed/rpc/rref_impl.h | 2 +- torch/csrc/distributed/rpc/script_call.h | 2 +- .../csrc/distributed/rpc/tensorpipe_cuda.cpp | 2 +- .../csrc/distributed/rpc/tensorpipe_utils.cpp | 4 +- .../csrc/dynamo/python_compiled_autograd.cpp | 2 +- torch/csrc/functorch/init.cpp | 20 ++--- torch/csrc/inductor/aoti_torch/utils.h | 26 +++--- torch/csrc/jit/api/compilation_unit.h | 6 +- torch/csrc/jit/api/function_impl.h | 2 +- torch/csrc/jit/api/module.cpp | 4 +- torch/csrc/jit/api/module.h | 8 +- torch/csrc/jit/api/object.cpp | 2 +- torch/csrc/jit/api/object.h | 6 +- torch/csrc/jit/codegen/fuser/compiler.cpp | 2 +- .../jit/codegen/fuser/cpu/fused_kernel.cpp | 4 +- torch/csrc/jit/codegen/fuser/executor.cpp | 12 +-- torch/csrc/jit/codegen/fuser/kernel_spec.h | 4 +- .../csrc/jit/codegen/onednn/graph_helper.cpp | 2 +- .../jit/codegen/onednn/graph_rewriter.cpp | 2 +- .../jit/codegen/onednn/prepare_binary.cpp | 2 +- torch/csrc/jit/cuda/cuda.h | 4 +- torch/csrc/jit/frontend/builtin_functions.cpp | 2 +- .../frontend/canonicalize_modified_loop.cpp | 2 +- .../jit/frontend/concrete_module_type.cpp | 10 +-- .../jit/frontend/function_schema_parser.cpp | 4 +- torch/csrc/jit/frontend/ir_emitter.cpp | 30 +++---- .../csrc/jit/frontend/parse_string_literal.h | 8 +- torch/csrc/jit/frontend/parser.cpp | 4 +- torch/csrc/jit/frontend/schema_matching.cpp | 28 +++---- torch/csrc/jit/frontend/schema_matching.h | 8 +- .../csrc/jit/frontend/schema_type_parser.cpp | 6 +- .../csrc/jit/frontend/script_type_parser.cpp | 20 ++--- torch/csrc/jit/frontend/source_range.cpp | 2 +- torch/csrc/jit/frontend/source_range.h | 10 +-- torch/csrc/jit/frontend/sugared_value.cpp | 2 +- torch/csrc/jit/frontend/sugared_value.h | 14 ++-- torch/csrc/jit/frontend/tracer.cpp | 4 +- torch/csrc/jit/ir/alias_analysis.cpp | 14 ++-- torch/csrc/jit/ir/constants.cpp | 12 +-- torch/csrc/jit/ir/constants.h | 14 ++-- torch/csrc/jit/ir/ir.cpp | 12 +-- torch/csrc/jit/ir/ir.h | 18 ++-- torch/csrc/jit/ir/scope.h | 2 +- .../mobile/compatibility/backport_manager.cpp | 8 +- .../compatibility/runtime_compatibility.h | 2 +- torch/csrc/jit/mobile/flatbuffer_loader.cpp | 6 +- torch/csrc/jit/mobile/flatbuffer_loader.h | 12 +-- torch/csrc/jit/mobile/frame.h | 2 +- torch/csrc/jit/mobile/function.cpp | 4 +- torch/csrc/jit/mobile/import.cpp | 6 +- torch/csrc/jit/mobile/import.h | 6 +- torch/csrc/jit/mobile/import_data.h | 6 +- .../mobile/model_tracer/MobileModelRunner.h | 2 +- .../jit/mobile/model_tracer/TracerRunner.cpp | 8 +- torch/csrc/jit/mobile/module.cpp | 6 +- torch/csrc/jit/mobile/promoted_prim_ops.cpp | 2 +- .../operator_upgraders/upgraders_entry.cpp | 2 +- torch/csrc/jit/operator_upgraders/utils.cpp | 4 +- torch/csrc/jit/operator_upgraders/utils.h | 2 +- torch/csrc/jit/passes/autocast.cpp | 8 +- torch/csrc/jit/passes/canonicalize.cpp | 8 +- .../passes/canonicalize_graph_fuser_ops.cpp | 4 +- .../csrc/jit/passes/constant_propagation.cpp | 14 ++-- .../jit/passes/create_autodiff_subgraphs.cpp | 8 +- .../csrc/jit/passes/device_type_analysis.cpp | 6 +- torch/csrc/jit/passes/dtype_analysis.cpp | 6 +- torch/csrc/jit/passes/erase_number_types.cpp | 2 +- torch/csrc/jit/passes/freeze_module.cpp | 8 +- .../csrc/jit/passes/frozen_ops_to_mkldnn.cpp | 2 +- torch/csrc/jit/passes/graph_fuser.cpp | 4 +- .../csrc/jit/passes/graph_rewrite_helper.cpp | 2 +- .../jit/passes/inline_autodiff_subgraphs.cpp | 2 +- .../jit/passes/integer_value_refinement.cpp | 4 +- torch/csrc/jit/passes/onnx/constant_fold.cpp | 56 ++++++------- torch/csrc/jit/passes/onnx/constant_fold.h | 2 +- torch/csrc/jit/passes/onnx/constant_map.cpp | 20 ++--- .../jit/passes/onnx/function_extraction.cpp | 16 ++-- .../jit/passes/onnx/list_model_parameters.cpp | 2 +- .../pattern_conversion/pattern_conversion.cpp | 2 +- .../pattern_encapsulation.cpp | 2 +- torch/csrc/jit/passes/onnx/peephole.cpp | 8 +- .../jit/passes/onnx/scalar_type_analysis.cpp | 16 ++-- .../jit/passes/onnx/shape_type_inference.cpp | 26 +++--- .../passes/onnx/unpack_quantized_weights.cpp | 4 +- .../csrc/jit/passes/peephole_dict_idioms.cpp | 16 ++-- .../csrc/jit/passes/peephole_list_idioms.cpp | 8 +- torch/csrc/jit/passes/quantization/helper.cpp | 12 +-- torch/csrc/jit/passes/quantization/helper.h | 2 +- .../passes/quantization/insert_observers.cpp | 4 +- .../quantization/insert_quant_dequant.cpp | 8 +- torch/csrc/jit/passes/remove_mutation.h | 4 +- .../passes/replacement_of_old_operators.cpp | 2 +- torch/csrc/jit/passes/shape_analysis.cpp | 42 +++++----- .../jit/passes/symbolic_shape_analysis.cpp | 34 ++++---- .../csrc/jit/passes/symbolic_shape_cache.cpp | 4 +- .../passes/symbolic_shape_runtime_fusion.cpp | 2 +- torch/csrc/jit/passes/tensorexpr_fuser.cpp | 4 +- .../passes/utils/check_alias_annotation.cpp | 6 +- torch/csrc/jit/passes/utils/memory_dag.h | 2 +- .../csrc/jit/passes/utils/subgraph_utils.cpp | 2 +- torch/csrc/jit/python/init.cpp | 10 +-- torch/csrc/jit/python/module_python.h | 4 +- torch/csrc/jit/python/pybind_utils.cpp | 4 +- torch/csrc/jit/python/pybind_utils.h | 14 ++-- torch/csrc/jit/python/python_ir.cpp | 6 +- torch/csrc/jit/python/python_ivalue.h | 2 +- torch/csrc/jit/python/python_list.h | 4 +- .../csrc/jit/python/python_sugared_value.cpp | 18 ++-- torch/csrc/jit/python/python_sugared_value.h | 2 +- torch/csrc/jit/python/python_tree_views.cpp | 4 +- torch/csrc/jit/python/script_init.cpp | 12 +-- torch/csrc/jit/runtime/autodiff.cpp | 4 +- .../jit/runtime/decomposition_registry.cpp | 6 +- torch/csrc/jit/runtime/graph_executor.h | 2 +- torch/csrc/jit/runtime/graph_executor_impl.h | 2 +- torch/csrc/jit/runtime/interpreter.cpp | 6 +- torch/csrc/jit/runtime/interpreter.h | 6 +- torch/csrc/jit/runtime/jit_exception.h | 6 +- torch/csrc/jit/runtime/operator.h | 2 +- .../runtime/profiling_graph_executor_impl.cpp | 10 +-- torch/csrc/jit/runtime/register_ops_utils.h | 2 +- torch/csrc/jit/runtime/register_prim_ops.cpp | 8 +- .../jit/runtime/register_prim_ops_fulljit.cpp | 16 ++-- .../csrc/jit/runtime/register_special_ops.cpp | 8 +- .../runtime/simple_graph_executor_impl.cpp | 2 +- torch/csrc/jit/runtime/static/fusion.cpp | 4 +- torch/csrc/jit/runtime/static/impl.cpp | 8 +- torch/csrc/jit/runtime/static/ops.cpp | 68 +++++++-------- torch/csrc/jit/runtime/static/ops.h | 24 +++--- torch/csrc/jit/runtime/symbolic_script.cpp | 4 +- torch/csrc/jit/runtime/symbolic_script.h | 2 +- .../jit/runtime/symbolic_shape_registry.cpp | 6 +- .../callstack_debug_info_serialization.cpp | 2 +- torch/csrc/jit/serialization/export.cpp | 2 +- .../jit/serialization/export_bytecode.cpp | 2 +- .../csrc/jit/serialization/export_module.cpp | 4 +- .../serialization/flatbuffer_serializer.cpp | 4 +- torch/csrc/jit/serialization/import.h | 18 ++-- .../csrc/jit/serialization/import_source.cpp | 4 +- torch/csrc/jit/serialization/import_source.h | 4 +- torch/csrc/jit/serialization/pickle.cpp | 6 +- torch/csrc/jit/serialization/pickler.cpp | 2 +- torch/csrc/jit/serialization/python_print.cpp | 2 +- .../source_range_serialization.cpp | 4 +- torch/csrc/jit/tensorexpr/codegen.cpp | 2 +- torch/csrc/jit/tensorexpr/eval.cpp | 2 +- torch/csrc/jit/tensorexpr/expr.h | 14 ++-- .../jit/tensorexpr/external_functions.cpp | 16 ++-- .../csrc/jit/tensorexpr/external_functions.h | 4 +- torch/csrc/jit/tensorexpr/graph_opt.cpp | 4 +- torch/csrc/jit/tensorexpr/ir.h | 2 +- torch/csrc/jit/tensorexpr/ir_simplifier.cpp | 12 +-- torch/csrc/jit/tensorexpr/kernel.cpp | 16 ++-- torch/csrc/jit/tensorexpr/llvm_codegen.cpp | 6 +- torch/csrc/jit/tensorexpr/llvm_codegen.h | 14 ++-- torch/csrc/jit/tensorexpr/llvm_jit.h | 2 +- .../csrc/jit/tensorexpr/operators/conv2d.cpp | 4 +- torch/csrc/jit/tensorexpr/operators/misc.cpp | 2 +- .../csrc/jit/tensorexpr/operators/pointwise.h | 2 +- .../jit/tensorexpr/operators/quantization.cpp | 4 +- .../csrc/jit/tensorexpr/operators/softmax.cpp | 12 +-- torch/csrc/jit/tensorexpr/tensor.cpp | 24 +++--- torch/csrc/jit/tensorexpr/tensor.h | 8 +- torch/csrc/jit/testing/file_check.cpp | 10 +-- torch/csrc/lazy/backend/backend_device.cpp | 12 +-- torch/csrc/lazy/backend/backend_device.h | 2 +- torch/csrc/lazy/core/ir_builder.h | 6 +- torch/csrc/lazy/core/ir_dump_util.cpp | 6 +- torch/csrc/lazy/core/lazy_graph_executor.cpp | 2 +- torch/csrc/lazy/core/shape.cpp | 4 +- torch/csrc/lazy/core/shape.h | 4 +- torch/csrc/lazy/core/shape_inference.h | 2 +- torch/csrc/lazy/core/tensor.cpp | 8 +- torch/csrc/lazy/core/unique.h | 2 +- torch/csrc/lazy/core/util.h | 4 +- torch/csrc/lazy/python/python_util.cpp | 4 +- torch/csrc/lazy/python/python_util.h | 2 +- torch/csrc/lazy/ts_backend/ir_builder.h | 2 +- .../lazy/ts_backend/ts_eager_fallback.cpp | 2 +- .../lazy/ts_backend/ts_native_functions.cpp | 8 +- torch/csrc/profiler/collection.cpp | 6 +- torch/csrc/profiler/collection.h | 2 +- torch/csrc/profiler/python/init.cpp | 2 +- torch/csrc/profiler/unwind/unwind.cpp | 4 +- torch/csrc/profiler/unwind/unwind.h | 2 +- torch/csrc/profiler/unwind/unwind_error.h | 2 +- torch/csrc/profiler/util.h | 2 +- torch/csrc/tensor/python_tensor.cpp | 2 +- torch/csrc/utils/nested.cpp | 2 +- torch/csrc/utils/python_arg_parser.cpp | 6 +- torch/csrc/utils/python_arg_parser.h | 28 +++---- torch/csrc/utils/python_dispatch.cpp | 24 +++--- torch/csrc/utils/python_raii.h | 6 +- torch/csrc/utils/python_symnode.h | 2 +- torch/csrc/utils/schema_info.cpp | 4 +- torch/csrc/utils/tensor_new.cpp | 20 ++--- torch/csrc/utils/torch_dispatch_mode.h | 2 +- torch/custom_class_detail.h | 2 +- torch/library.h | 14 ++-- 330 files changed, 1208 insertions(+), 1207 deletions(-) diff --git a/aten/src/ATen/CPUGeneratorImpl.h b/aten/src/ATen/CPUGeneratorImpl.h index 34dd33a475b917..e15ca23d6bf748 100644 --- a/aten/src/ATen/CPUGeneratorImpl.h +++ b/aten/src/ATen/CPUGeneratorImpl.h @@ -3,7 +3,7 @@ #include #include #include -#include +#include namespace at { diff --git a/aten/src/ATen/InferSize.h b/aten/src/ATen/InferSize.h index 411cf12d513414..4bf820312d2f68 100644 --- a/aten/src/ATen/InferSize.h +++ b/aten/src/ATen/InferSize.h @@ -4,7 +4,7 @@ #include #include #include -#include +#include #include #include diff --git a/aten/src/ATen/SavedTensorHooks.cpp b/aten/src/ATen/SavedTensorHooks.cpp index 7aa9b0f02ea365..6837348c932101 100644 --- a/aten/src/ATen/SavedTensorHooks.cpp +++ b/aten/src/ATen/SavedTensorHooks.cpp @@ -32,7 +32,7 @@ void SavedTensorDefaultHooks::disable(const std::string& message) { } void SavedTensorDefaultHooks::enable() { - tls.disabled_error_message = c10::nullopt; + tls.disabled_error_message = std::nullopt; } /* static */ bool SavedTensorDefaultHooks::set_tracing(bool is_tracing) { diff --git a/aten/src/ATen/SavedTensorHooks.h b/aten/src/ATen/SavedTensorHooks.h index b69b9c25e8e6a5..9cf1ea37c35390 100644 --- a/aten/src/ATen/SavedTensorHooks.h +++ b/aten/src/ATen/SavedTensorHooks.h @@ -1,8 +1,8 @@ #pragma once #include -#include #include +#include #include #include diff --git a/aten/src/ATen/TensorIndexing.h b/aten/src/ATen/TensorIndexing.h index eb36c0e02fa4db..1fe9e7ebdcb012 100644 --- a/aten/src/ATen/TensorIndexing.h +++ b/aten/src/ATen/TensorIndexing.h @@ -5,8 +5,8 @@ #include #include #include -#include #include +#include #ifndef AT_PER_OPERATOR_HEADERS #include diff --git a/aten/src/ATen/native/BatchLinearAlgebra.h b/aten/src/ATen/native/BatchLinearAlgebra.h index c8402640aa08ac..58d46aacd47314 100644 --- a/aten/src/ATen/native/BatchLinearAlgebra.h +++ b/aten/src/ATen/native/BatchLinearAlgebra.h @@ -1,6 +1,6 @@ #pragma once -#include +#include #include #include #include diff --git a/aten/src/ATen/record_function.h b/aten/src/ATen/record_function.h index 014260fb220f89..63fbcb55e96d2b 100644 --- a/aten/src/ATen/record_function.h +++ b/aten/src/ATen/record_function.h @@ -3,8 +3,8 @@ #include #include #include -#include #include +#include #include #include diff --git a/c10/core/ConstantSymNodeImpl.h b/c10/core/ConstantSymNodeImpl.h index 3c0fb66f7469fe..791a81cace4176 100644 --- a/c10/core/ConstantSymNodeImpl.h +++ b/c10/core/ConstantSymNodeImpl.h @@ -3,8 +3,8 @@ #include #include #include -#include #include +#include #include #include @@ -73,14 +73,14 @@ class C10_API ConstantSymNodeImpl : public SymNodeImpl { if constexpr (is_int_()) { return ::std::get(value_); } else { - return c10::nullopt; + return std::nullopt; } } std::optional constant_bool() override { if constexpr (is_bool_()) { return ::std::get(value_); } else { - return c10::nullopt; + return std::nullopt; } } bool is_constant() override { diff --git a/c10/core/ScalarTypeToTypeMeta.h b/c10/core/ScalarTypeToTypeMeta.h index d2694c96221eb4..5e9e1a936af5af 100644 --- a/c10/core/ScalarTypeToTypeMeta.h +++ b/c10/core/ScalarTypeToTypeMeta.h @@ -30,7 +30,7 @@ inline ScalarType typeMetaToScalarType(caffe2::TypeMeta dtype) { inline optional optTypeMetaToScalarType( optional type_meta) { if (!type_meta.has_value()) { - return c10::nullopt; + return std::nullopt; } return type_meta->toScalarType(); } diff --git a/c10/core/SymBool.h b/c10/core/SymBool.h index 9f9f141293a375..06ce32c1a7160b 100644 --- a/c10/core/SymBool.h +++ b/c10/core/SymBool.h @@ -3,9 +3,9 @@ #include #include #include -#include #include #include +#include #include #include @@ -68,7 +68,7 @@ class C10_API SymBool { std::optional maybe_as_bool() const { if (!is_heap_allocated()) { - return c10::make_optional(data_); + return std::make_optional(data_); } return toSymNodeImplUnowned()->constant_bool(); } diff --git a/c10/core/SymInt.h b/c10/core/SymInt.h index 025c351334a016..eef34aac24ca6d 100644 --- a/c10/core/SymInt.h +++ b/c10/core/SymInt.h @@ -5,7 +5,7 @@ #include #include #include -#include +#include #include #include @@ -231,7 +231,7 @@ class C10_API SymInt { std::optional maybe_as_int() const { if (!is_heap_allocated()) { - return c10::make_optional(data_); + return std::make_optional(data_); } auto* node = toSymNodeImplUnowned(); if (auto c = node->constant_int()) { diff --git a/c10/core/SymIntArrayRef.h b/c10/core/SymIntArrayRef.h index 760f4ba4e79a21..ce7253c60ec59f 100644 --- a/c10/core/SymIntArrayRef.h +++ b/c10/core/SymIntArrayRef.h @@ -3,8 +3,8 @@ #include #include #include -#include #include +#include namespace c10 { using SymIntArrayRef = ArrayRef; @@ -23,7 +23,7 @@ inline std::optional asIntArrayRefSlowOpt( c10::SymIntArrayRef ar) { for (const c10::SymInt& sci : ar) { if (sci.is_heap_allocated()) { - return c10::nullopt; + return std::nullopt; } } diff --git a/c10/core/SymNodeImpl.h b/c10/core/SymNodeImpl.h index bb92b09775b7b4..39e4bbbc2c6cd7 100644 --- a/c10/core/SymNodeImpl.h +++ b/c10/core/SymNodeImpl.h @@ -3,9 +3,9 @@ #include #include #include -#include #include #include +#include #include #include @@ -207,19 +207,19 @@ class C10_API SymNodeImpl : public c10::intrusive_ptr_target { TORCH_CHECK(false, "NYI"); }; virtual std::optional nested_int() { - return c10::nullopt; + return std::nullopt; } virtual std::optional nested_int_coeff() { - return c10::nullopt; + return std::nullopt; } virtual std::optional constant_int() { - return c10::nullopt; + return std::nullopt; } virtual std::optional constant_bool() { - return c10::nullopt; + return std::nullopt; } virtual std::optional maybe_as_int() { - return c10::nullopt; + return std::nullopt; } virtual bool is_constant() { return false; diff --git a/c10/core/SymbolicShapeMeta.cpp b/c10/core/SymbolicShapeMeta.cpp index 62b03d36ec71c9..b59a95a4a2faf4 100644 --- a/c10/core/SymbolicShapeMeta.cpp +++ b/c10/core/SymbolicShapeMeta.cpp @@ -56,7 +56,7 @@ normalize_sym_sizes_strides(SymIntArrayRef sizes, SymIntArrayRef strides) { // Couldn't find. Tell the caller to do the normal computation // Alternately, if everything is hinted, we want the normal computation // too - return c10::nullopt; + return std::nullopt; } // Populate the SymNode array std::vector size_nodes; @@ -69,7 +69,7 @@ normalize_sym_sizes_strides(SymIntArrayRef sizes, SymIntArrayRef strides) { for (const auto& s : strides) { stride_nodes.emplace_back(s.wrap_node(base)); } - return c10::make_optional( + return std::make_optional( std::tuple, std::vector>( std::move(base), std::move(size_nodes), std::move(stride_nodes))); } diff --git a/c10/core/TensorImpl.cpp b/c10/core/TensorImpl.cpp index 516a61f0200462..130292aaa70d6a 100644 --- a/c10/core/TensorImpl.cpp +++ b/c10/core/TensorImpl.cpp @@ -8,9 +8,9 @@ #include #include #include -#include #include #include +#include #include diff --git a/c10/core/TensorImpl.h b/c10/core/TensorImpl.h index 877c1c09543cb5..67543614c021bc 100644 --- a/c10/core/TensorImpl.h +++ b/c10/core/TensorImpl.h @@ -24,12 +24,12 @@ #include #include #include -#include #include #include #include #include #include +#include #include #include @@ -233,8 +233,8 @@ struct C10_API ExtraMeta { std::unique_ptr symbolic_shape_meta_ = nullptr; std::unique_ptr named_tensor_meta_ = nullptr; intrusive_ptr backend_meta_ = nullptr; - std::optional custom_data_ptr_error_msg_ = c10::nullopt; - std::optional custom_storage_error_msg_ = c10::nullopt; + std::optional custom_data_ptr_error_msg_ = std::nullopt; + std::optional custom_storage_error_msg_ = std::nullopt; ExtraMeta() = default; ExtraMeta(const ExtraMeta& other) { @@ -260,8 +260,8 @@ struct C10_API ExtraMeta { std::unique_ptr symbolic_shape_meta, std::unique_ptr named_tensor_meta, intrusive_ptr backend_meta, - std::optional custom_data_ptr_error_msg = c10::nullopt, - std::optional custom_storage_access_error_msg = c10::nullopt) + std::optional custom_data_ptr_error_msg = std::nullopt, + std::optional custom_storage_access_error_msg = std::nullopt) : symbolic_shape_meta_(std::move(symbolic_shape_meta)), named_tensor_meta_(std::move(named_tensor_meta)), backend_meta_(std::move(backend_meta)), @@ -1737,7 +1737,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { void set_sizes_and_strides( c10::SymIntArrayRef sizes, c10::SymIntArrayRef strides, - std::optional storage_offset = c10::nullopt); + std::optional storage_offset = std::nullopt); // This is renamed to avoid breaking overload BC void generic_set_sizes_contiguous(c10::SymIntArrayRef sizes); void generic_set_sizes_contiguous(c10::IntArrayRef sizes) { @@ -1834,7 +1834,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { void set_sizes_and_strides( IntArrayRef new_size, IntArrayRef new_stride, - std::optional storage_offset = c10::nullopt) { + std::optional storage_offset = std::nullopt) { TORCH_CHECK( allow_tensor_metadata_change(), "set_sizes_and_strides ", diff --git a/c10/core/TensorOptions.h b/c10/core/TensorOptions.h index d99005d3d28f85..9c23c767ffc5ef 100644 --- a/c10/core/TensorOptions.h +++ b/c10/core/TensorOptions.h @@ -13,7 +13,7 @@ #include #include #include -#include +#include #include #include @@ -284,10 +284,10 @@ struct C10_API TensorOptions { return has_device_; } - /// Returns the device of the `TensorOptions`, or `c10::nullopt` if + /// Returns the device of the `TensorOptions`, or `std::nullopt` if /// device is not specified. std::optional device_opt() const noexcept { - return has_device_ ? c10::make_optional(device_) : c10::nullopt; + return has_device_ ? std::make_optional(device_) : std::nullopt; } /// Returns the device index of the `TensorOptions`. @@ -305,10 +305,10 @@ struct C10_API TensorOptions { return has_dtype_; } - /// Returns the dtype of the `TensorOptions`, or `c10::nullopt` if + /// Returns the dtype of the `TensorOptions`, or `std::nullopt` if /// device is not specified. std::optional dtype_opt() const noexcept { - return has_dtype_ ? c10::make_optional(dtype_) : c10::nullopt; + return has_dtype_ ? std::make_optional(dtype_) : std::nullopt; } /// Returns the layout of the `TensorOptions`. @@ -321,10 +321,10 @@ struct C10_API TensorOptions { return has_layout_; } - /// Returns the layout of the `TensorOptions`, or `c10::nullopt` if + /// Returns the layout of the `TensorOptions`, or `std::nullopt` if /// layout is not specified. std::optional layout_opt() const noexcept { - return has_layout_ ? c10::make_optional(layout_) : c10::nullopt; + return has_layout_ ? std::make_optional(layout_) : std::nullopt; } /// Returns the `requires_grad` property of the `TensorOptions`. @@ -338,10 +338,10 @@ struct C10_API TensorOptions { } /// Returns the `requires_grad` property of the `TensorOptions`, or - /// `c10::nullopt` if `requires_grad` is not specified. + /// `std::nullopt` if `requires_grad` is not specified. std::optional requires_grad_opt() const noexcept { - return has_requires_grad_ ? c10::make_optional(requires_grad_) - : c10::nullopt; + return has_requires_grad_ ? std::make_optional(requires_grad_) + : std::nullopt; } /// Returns the `pinned_memory` property of the `TensorOptions`. @@ -378,10 +378,10 @@ struct C10_API TensorOptions { } /// Returns the `pinned_memory` property of the `TensorOptions`, or - /// `c10::nullopt` if `pinned_memory` is not specified. + /// `std::nullopt` if `pinned_memory` is not specified. std::optional pinned_memory_opt() const noexcept { - return has_pinned_memory_ ? c10::make_optional(pinned_memory_) - : c10::nullopt; + return has_pinned_memory_ ? std::make_optional(pinned_memory_) + : std::nullopt; } /// Returns whether the `memory_layout` is specified @@ -393,10 +393,10 @@ struct C10_API TensorOptions { // behavior of memory_format varies from function to function. /// Returns the `memory_layout` property of `TensorOptions, or - /// `c10::nullopt` if `memory_format` is not specified. + /// `std::nullopt` if `memory_format` is not specified. std::optional memory_format_opt() const noexcept { - return has_memory_format_ ? c10::make_optional(memory_format_) - : c10::nullopt; + return has_memory_format_ ? std::make_optional(memory_format_) + : std::nullopt; } // Resolves the ATen backend specified by the current construction axes. diff --git a/c10/core/UndefinedTensorImpl.cpp b/c10/core/UndefinedTensorImpl.cpp index 1b16a5d5b9fd7e..2a715d78bdb767 100644 --- a/c10/core/UndefinedTensorImpl.cpp +++ b/c10/core/UndefinedTensorImpl.cpp @@ -5,7 +5,7 @@ namespace c10 { // should this use the globalContext? Can it get a context passed in somehow? UndefinedTensorImpl::UndefinedTensorImpl() - : TensorImpl(DispatchKey::Undefined, caffe2::TypeMeta(), c10::nullopt) { + : TensorImpl(DispatchKey::Undefined, caffe2::TypeMeta(), std::nullopt) { set_storage_access_should_throw(); // TODO: accessing the sizes on an undefined tensor is not meaningful // and should error too, but empirically it does not! diff --git a/c10/core/impl/InlineDeviceGuard.h b/c10/core/impl/InlineDeviceGuard.h index 3e9f91eff61700..a70c194efccf8c 100644 --- a/c10/core/impl/InlineDeviceGuard.h +++ b/c10/core/impl/InlineDeviceGuard.h @@ -404,7 +404,7 @@ class InlineOptionalDeviceGuard { /// Returns the device that was set immediately prior to initialization of /// the, guard, or nullopt if the guard is uninitialized. optional original_device() const { - return guard_.has_value() ? make_optional(guard_->original_device()) + return guard_.has_value() ? std::make_optional(guard_->original_device()) : nullopt; } @@ -412,7 +412,7 @@ class InlineOptionalDeviceGuard { /// either from construction, or via set_device, if the guard is initialized, /// or nullopt if the guard is uninitialized. optional current_device() const { - return guard_.has_value() ? make_optional(guard_->current_device()) + return guard_.has_value() ? std::make_optional(guard_->current_device()) : nullopt; } diff --git a/c10/core/impl/InlineStreamGuard.h b/c10/core/impl/InlineStreamGuard.h index b99e7db72addc6..5ac913c4ff7fff 100644 --- a/c10/core/impl/InlineStreamGuard.h +++ b/c10/core/impl/InlineStreamGuard.h @@ -173,7 +173,7 @@ class InlineOptionalStreamGuard { /// Returns the stream that was set at the time the guard was most recently /// initialized, or nullopt if the guard is uninitialized. optional original_stream() const { - return guard_.has_value() ? make_optional(guard_->original_stream()) + return guard_.has_value() ? std::make_optional(guard_->original_stream()) : nullopt; } @@ -181,7 +181,7 @@ class InlineOptionalStreamGuard { /// either from construction, or via reset_stream, if the guard is /// initialized, or nullopt if the guard is uninitialized. optional current_stream() const { - return guard_.has_value() ? make_optional(guard_->current_stream()) + return guard_.has_value() ? std::make_optional(guard_->current_stream()) : nullopt; } diff --git a/c10/core/impl/PyObjectSlot.h b/c10/core/impl/PyObjectSlot.h index 518b0e63e49217..8f2833b5c7da41 100644 --- a/c10/core/impl/PyObjectSlot.h +++ b/c10/core/impl/PyObjectSlot.h @@ -2,8 +2,8 @@ #include #include -#include #include +#include #include @@ -106,13 +106,13 @@ struct C10_API PyObjectSlot { // after we query here. The only time when we can conclude a tensor // is definitely uninitialized is when we have just allocated it and // it cannot have escaped to other threads yet - return c10::nullopt; + return std::nullopt; } else if (interpreter == self_interpreter) { // NB: pyobj_ could still be null! if (!ignore_hermetic_tls && c10::impl::HermeticPyObjectTLS::get_state()) { - return c10::nullopt; + return std::nullopt; } else { - return c10::make_optional(_unchecked_untagged_pyobj()); + return std::make_optional(_unchecked_untagged_pyobj()); } } else { TORCH_CHECK( diff --git a/c10/core/impl/TorchDispatchModeTLS.cpp b/c10/core/impl/TorchDispatchModeTLS.cpp index f1847cb005b4ce..c9a3274ed896c3 100644 --- a/c10/core/impl/TorchDispatchModeTLS.cpp +++ b/c10/core/impl/TorchDispatchModeTLS.cpp @@ -16,7 +16,7 @@ bool TorchDispatchModeTLS::any_modes_set(bool skip_infra_modes) { if (!skip_infra_modes) { for (const auto i : c10::irange( static_cast(TorchDispatchModeKey::NUM_MODE_KEYS))) { - if (torchDispatchModeState.infra_modes_[i] != c10::nullopt) { + if (torchDispatchModeState.infra_modes_[i] != std::nullopt) { return true; } } @@ -48,7 +48,7 @@ const std::shared_ptr TorchDispatchModeTLS:: if (torchDispatchModeState.infra_modes_[i].has_value()) { // NOLINTNEXTLINE(bugprone-unchecked-optional-access) out = std::move(torchDispatchModeState.infra_modes_[i].value()); - torchDispatchModeState.infra_modes_[i] = c10::nullopt; + torchDispatchModeState.infra_modes_[i] = std::nullopt; break; } } @@ -70,7 +70,7 @@ const std:: if (torchDispatchModeState.infra_modes_[i].has_value()) { // NOLINTNEXTLINE(bugprone-unchecked-optional-access) auto out_mode = torchDispatchModeState.infra_modes_[i].value(); - torchDispatchModeState.infra_modes_[i] = c10::nullopt; + torchDispatchModeState.infra_modes_[i] = std::nullopt; if (!any_modes_set()) { c10::impl::tls_set_dispatch_key_included(DispatchKey::Python, false); c10::impl::tls_set_dispatch_key_included( @@ -114,7 +114,7 @@ int64_t TorchDispatchModeTLS::stack_len() { int64_t infra_modes_len = 0; for (const auto i : c10::irange(static_cast(TorchDispatchModeKey::NUM_MODE_KEYS))) { - if (torchDispatchModeState.infra_modes_[i] != c10::nullopt) { + if (torchDispatchModeState.infra_modes_[i] != std::nullopt) { infra_modes_len += 1; } } @@ -131,7 +131,7 @@ void TorchDispatchModeTLS::set_mode( TorchDispatchModeKey mode_key) { TORCH_CHECK( torchDispatchModeState.infra_modes_[static_cast(mode_key)] == - c10::nullopt, + std::nullopt, "trying to set the current ", to_string(mode_key), ", but one already exists"); @@ -149,7 +149,7 @@ const std::optional> TorchDispatchModeTLS::unset_mode(TorchDispatchModeKey mode_key) { auto out = torchDispatchModeState.infra_modes_[static_cast(mode_key)]; torchDispatchModeState.infra_modes_[static_cast(mode_key)] = - c10::nullopt; + std::nullopt; if (out.has_value() && !any_modes_set()) { c10::impl::tls_set_dispatch_key_included(DispatchKey::Python, false); c10::impl::tls_set_dispatch_key_included( diff --git a/c10/cuda/CUDACachingAllocator.cpp b/c10/cuda/CUDACachingAllocator.cpp index 11bea6056e9d85..e4535292cebaac 100644 --- a/c10/cuda/CUDACachingAllocator.cpp +++ b/c10/cuda/CUDACachingAllocator.cpp @@ -411,7 +411,7 @@ struct ExpandableSegment { return rangeFromHandles(begin, end); } while (end > handles_.size()) { - handles_.emplace_back(c10::nullopt); + handles_.emplace_back(std::nullopt); } for (auto i : c10::irange(begin, end)) { TORCH_INTERNAL_ASSERT(!handles_.at(i)); @@ -426,7 +426,7 @@ struct ExpandableSegment { if (status == CUDA_ERROR_OUT_OF_MEMORY) { for (auto j : c10::irange(begin, i)) { auto h = handles_.at(j).value(); - handles_.at(j) = c10::nullopt; + handles_.at(j) = std::nullopt; C10_CUDA_DRIVER_CHECK(DriverAPI::get()->cuMemRelease_(h)); } trimHandles(); @@ -507,7 +507,7 @@ struct ExpandableSegment { C10_CUDA_CHECK(cudaStreamSynchronize(stream_)); for (auto i : c10::irange(begin, end)) { CUmemGenericAllocationHandle h = handles_.at(i).value(); - handles_.at(i) = c10::nullopt; + handles_.at(i) = std::nullopt; C10_CUDA_DRIVER_CHECK(DriverAPI::get()->cuMemUnmap_( ptr_ + segment_size_ * i, segment_size_)); C10_CUDA_DRIVER_CHECK(DriverAPI::get()->cuMemRelease_(h)); diff --git a/c10/cuda/CUDAFunctions.cpp b/c10/cuda/CUDAFunctions.cpp index 2b53eb4d7c7cb7..8d88000b89db94 100644 --- a/c10/cuda/CUDAFunctions.cpp +++ b/c10/cuda/CUDAFunctions.cpp @@ -166,7 +166,7 @@ std::optional getDeviceIndexWithPrimaryContext() { return device_index; } } - return c10::nullopt; + return std::nullopt; } namespace _internal { diff --git a/c10/cuda/CUDAGuard.h b/c10/cuda/CUDAGuard.h index 254522893d5e08..65f5c5d191b7fb 100644 --- a/c10/cuda/CUDAGuard.h +++ b/c10/cuda/CUDAGuard.h @@ -242,7 +242,7 @@ struct OptionalCUDAStreamGuard { optional original_stream() const { auto r = guard_.original_stream(); if (r.has_value()) { - return make_optional(CUDAStream(CUDAStream::UNCHECKED, r.value())); + return std::make_optional(CUDAStream(CUDAStream::UNCHECKED, r.value())); } else { return nullopt; } @@ -254,7 +254,7 @@ struct OptionalCUDAStreamGuard { optional current_stream() const { auto r = guard_.current_stream(); if (r.has_value()) { - return make_optional(CUDAStream(CUDAStream::UNCHECKED, r.value())); + return std::make_optional(CUDAStream(CUDAStream::UNCHECKED, r.value())); } else { return nullopt; } diff --git a/c10/cuda/impl/CUDAGuardImpl.h b/c10/cuda/impl/CUDAGuardImpl.h index ec50c8152b33e2..1ef2fcb2c08f4d 100644 --- a/c10/cuda/impl/CUDAGuardImpl.h +++ b/c10/cuda/impl/CUDAGuardImpl.h @@ -14,9 +14,9 @@ #include #include #include -#include #include #include +#include namespace c10::cuda::impl { @@ -45,7 +45,7 @@ struct CUDAGuardImpl final : public c10::impl::DeviceGuardImplInterface { const auto err = C10_CUDA_ERROR_HANDLED(c10::cuda::GetDevice(&device)); C10_CUDA_CHECK_WARN(err); if (err != cudaSuccess) { - return c10::nullopt; + return std::nullopt; } return Device(DeviceType::CUDA, device); } diff --git a/c10/test/core/DeviceGuard_test.cpp b/c10/test/core/DeviceGuard_test.cpp index 63049ae7b555a2..0869ea1168d167 100644 --- a/c10/test/core/DeviceGuard_test.cpp +++ b/c10/test/core/DeviceGuard_test.cpp @@ -36,6 +36,7 @@ TEST(OptionalDeviceGuard, ResetDeviceDifferentDeviceType) { g.reset_device(Device(DeviceType::HIP, 2), &hip_impl); ASSERT_EQ(FakeGuardImpl::getDeviceIndex(), 0); ASSERT_EQ(FakeGuardImpl::getDeviceIndex(), 2); - ASSERT_EQ(g.current_device(), make_optional(Device(DeviceType::HIP, 2))); - ASSERT_EQ(g.original_device(), make_optional(Device(DeviceType::HIP, 0))); + ASSERT_EQ(g.current_device(), std::make_optional(Device(DeviceType::HIP, 2))); + ASSERT_EQ( + g.original_device(), std::make_optional(Device(DeviceType::HIP, 0))); } diff --git a/c10/test/core/SymInt_test.cpp b/c10/test/core/SymInt_test.cpp index 8055ec7a325111..7cefa1e4a771bf 100644 --- a/c10/test/core/SymInt_test.cpp +++ b/c10/test/core/SymInt_test.cpp @@ -8,7 +8,7 @@ using namespace c10; #ifndef C10_MOBILE static void check(int64_t value) { const auto i = SymInt(value); - EXPECT_EQ(i.maybe_as_int(), c10::make_optional(value)); + EXPECT_EQ(i.maybe_as_int(), std::make_optional(value)); } TEST(SymIntTest, ConcreteInts) { diff --git a/c10/test/core/impl/InlineDeviceGuard_test.cpp b/c10/test/core/impl/InlineDeviceGuard_test.cpp index 69db93e307bfe8..2b4ad0c5b2381f 100644 --- a/c10/test/core/impl/InlineDeviceGuard_test.cpp +++ b/c10/test/core/impl/InlineDeviceGuard_test.cpp @@ -170,12 +170,12 @@ TEST(InlineOptionalDeviceGuard, SetDevice) { MaybeTestGuard g; DeviceIndex i = 1; g.set_device(dev(i)); - ASSERT_EQ(g.original_device(), make_optional(dev(init_i))); - ASSERT_EQ(g.current_device(), make_optional(dev(i))); + ASSERT_EQ(g.original_device(), std::make_optional(dev(init_i))); + ASSERT_EQ(g.current_device(), std::make_optional(dev(i))); ASSERT_EQ(TestGuardImpl::getDeviceIndex(), i); g.set_device(dev(i)); - ASSERT_EQ(g.original_device(), make_optional(dev(init_i))); - ASSERT_EQ(g.current_device(), make_optional(dev(i))); + ASSERT_EQ(g.original_device(), std::make_optional(dev(init_i))); + ASSERT_EQ(g.current_device(), std::make_optional(dev(i))); ASSERT_EQ(TestGuardImpl::getDeviceIndex(), i); } @@ -185,11 +185,11 @@ TEST(InlineOptionalDeviceGuard, SetIndex) { DeviceIndex i = 1; MaybeTestGuard g; g.set_index(i); - ASSERT_EQ(g.original_device(), make_optional(dev(init_i))); - ASSERT_EQ(g.current_device(), make_optional(dev(i))); + ASSERT_EQ(g.original_device(), std::make_optional(dev(init_i))); + ASSERT_EQ(g.current_device(), std::make_optional(dev(i))); ASSERT_EQ(TestGuardImpl::getDeviceIndex(), i); g.set_index(i); - ASSERT_EQ(g.original_device(), make_optional(dev(init_i))); - ASSERT_EQ(g.current_device(), make_optional(dev(i))); + ASSERT_EQ(g.original_device(), std::make_optional(dev(init_i))); + ASSERT_EQ(g.current_device(), std::make_optional(dev(i))); ASSERT_EQ(TestGuardImpl::getDeviceIndex(), i); } diff --git a/c10/test/core/impl/InlineStreamGuard_test.cpp b/c10/test/core/impl/InlineStreamGuard_test.cpp index 692504cebd1ccc..06c4b96ef913ef 100644 --- a/c10/test/core/impl/InlineStreamGuard_test.cpp +++ b/c10/test/core/impl/InlineStreamGuard_test.cpp @@ -109,8 +109,8 @@ TEST(InlineOptionalStreamGuard, Constructor) { ASSERT_EQ(TestGuardImpl::getDeviceIndex(), 1); ASSERT_EQ(TestGuardImpl::getCurrentStreamIdFor(1), 2); ASSERT_EQ(TestGuardImpl::getCurrentStreamIdFor(0), 0); - ASSERT_EQ(g.original_stream(), make_optional(stream(0, 0))); - ASSERT_EQ(g.current_stream(), make_optional(stream(1, 2))); + ASSERT_EQ(g.original_stream(), std::make_optional(stream(0, 0))); + ASSERT_EQ(g.current_stream(), std::make_optional(stream(1, 2))); } ASSERT_EQ(TestGuardImpl::getDeviceIndex(), 0); ASSERT_EQ(TestGuardImpl::getCurrentStreamIdFor(1), 0); @@ -120,8 +120,8 @@ TEST(InlineOptionalStreamGuard, Constructor) { ASSERT_EQ(TestGuardImpl::getDeviceIndex(), 1); ASSERT_EQ(TestGuardImpl::getCurrentStreamIdFor(1), 2); ASSERT_EQ(TestGuardImpl::getCurrentStreamIdFor(0), 0); - ASSERT_EQ(g.original_stream(), make_optional(stream(0, 0))); - ASSERT_EQ(g.current_stream(), make_optional(stream(1, 2))); + ASSERT_EQ(g.original_stream(), std::make_optional(stream(0, 0))); + ASSERT_EQ(g.current_stream(), std::make_optional(stream(1, 2))); } ASSERT_EQ(TestGuardImpl::getDeviceIndex(), 0); ASSERT_EQ(TestGuardImpl::getCurrentStreamIdFor(1), 0); @@ -146,8 +146,8 @@ TEST(InlineOptionalStreamGuard, ResetStreamSameDevice) { ASSERT_EQ(TestGuardImpl::getDeviceIndex(), 1); ASSERT_EQ(TestGuardImpl::getCurrentStreamIdFor(1), 3); ASSERT_EQ(TestGuardImpl::getCurrentStreamIdFor(0), 0); - ASSERT_EQ(g.original_stream(), make_optional(stream(0, 0))); - ASSERT_EQ(g.current_stream(), make_optional(stream(1, 3))); + ASSERT_EQ(g.original_stream(), std::make_optional(stream(0, 0))); + ASSERT_EQ(g.current_stream(), std::make_optional(stream(1, 3))); } ASSERT_EQ(TestGuardImpl::getDeviceIndex(), 0); ASSERT_EQ(TestGuardImpl::getCurrentStreamIdFor(1), 0); @@ -164,8 +164,8 @@ TEST(InlineOptionalStreamGuard, ResetStreamDifferentDevice) { ASSERT_EQ(TestGuardImpl::getCurrentStreamIdFor(2), 3); ASSERT_EQ(TestGuardImpl::getCurrentStreamIdFor(1), 0); ASSERT_EQ(TestGuardImpl::getCurrentStreamIdFor(0), 0); - ASSERT_EQ(g.original_stream(), make_optional(stream(0, 0))); - ASSERT_EQ(g.current_stream(), make_optional(stream(2, 3))); + ASSERT_EQ(g.original_stream(), std::make_optional(stream(0, 0))); + ASSERT_EQ(g.current_stream(), std::make_optional(stream(2, 3))); } ASSERT_EQ(TestGuardImpl::getDeviceIndex(), 0); ASSERT_EQ(TestGuardImpl::getCurrentStreamIdFor(2), 0); diff --git a/c10/test/util/optional_test.cpp b/c10/test/util/optional_test.cpp index aa4c5a527ce667..e9496d9dc2887e 100644 --- a/c10/test/util/optional_test.cpp +++ b/c10/test/util/optional_test.cpp @@ -1,4 +1,4 @@ -#include +#include #include #include @@ -67,7 +67,7 @@ TYPED_TEST(OptionalTest, Empty) { EXPECT_FALSE(empty.has_value()); // NOLINTNEXTLINE(bugprone-unchecked-optional-access,hicpp-avoid-goto,cppcoreguidelines-avoid-goto) - EXPECT_THROW(empty.value(), c10::bad_optional_access); + EXPECT_THROW(empty.value(), std::bad_optional_access); } TYPED_TEST(OptionalTest, Initialized) { @@ -111,32 +111,32 @@ TEST_P(SelfCompareTest, SelfCompare) { INSTANTIATE_TEST_SUITE_P( nullopt, SelfCompareTest, - testing::Values(c10::nullopt)); + testing::Values(std::nullopt)); INSTANTIATE_TEST_SUITE_P( int, SelfCompareTest, - testing::Values(c10::make_optional(2))); + testing::Values(std::make_optional(2))); TEST(OptionalTest, Nullopt) { std::optional x = 2; - EXPECT_THAT(c10::nullopt, Not(Eq(x))); - EXPECT_THAT(x, Not(Eq(c10::nullopt))); + EXPECT_THAT(std::nullopt, Not(Eq(x))); + EXPECT_THAT(x, Not(Eq(std::nullopt))); - EXPECT_THAT(x, Ne(c10::nullopt)); - EXPECT_THAT(c10::nullopt, Ne(x)); + EXPECT_THAT(x, Ne(std::nullopt)); + EXPECT_THAT(std::nullopt, Ne(x)); - EXPECT_THAT(x, Not(Lt(c10::nullopt))); - EXPECT_THAT(c10::nullopt, Lt(x)); + EXPECT_THAT(x, Not(Lt(std::nullopt))); + EXPECT_THAT(std::nullopt, Lt(x)); - EXPECT_THAT(x, Not(Le(c10::nullopt))); - EXPECT_THAT(c10::nullopt, Le(x)); + EXPECT_THAT(x, Not(Le(std::nullopt))); + EXPECT_THAT(std::nullopt, Le(x)); - EXPECT_THAT(x, Gt(c10::nullopt)); - EXPECT_THAT(c10::nullopt, Not(Gt(x))); + EXPECT_THAT(x, Gt(std::nullopt)); + EXPECT_THAT(std::nullopt, Not(Gt(x))); - EXPECT_THAT(x, Ge(c10::nullopt)); - EXPECT_THAT(c10::nullopt, Not(Ge(x))); + EXPECT_THAT(x, Ge(std::nullopt)); + EXPECT_THAT(std::nullopt, Not(Ge(x))); } // Ensure comparisons work... diff --git a/c10/util/Backtrace.cpp b/c10/util/Backtrace.cpp index 7d0fedbb335a29..d461267000befc 100644 --- a/c10/util/Backtrace.cpp +++ b/c10/util/Backtrace.cpp @@ -1,7 +1,7 @@ #include -#include #include #include +#include #include #include @@ -150,19 +150,19 @@ std::optional parse_frame_information( auto function_name_start = frame_string.find('('); if (function_name_start == std::string::npos) { - return c10::nullopt; + return std::nullopt; } function_name_start += 1; auto offset_start = frame_string.find('+', function_name_start); if (offset_start == std::string::npos) { - return c10::nullopt; + return std::nullopt; } offset_start += 1; const auto offset_end = frame_string.find(')', offset_start); if (offset_end == std::string::npos) { - return c10::nullopt; + return std::nullopt; } frame.object_file = frame_string.substr(0, function_name_start - 1); @@ -186,7 +186,7 @@ std::optional parse_frame_information( skip >> frame.offset_into_function; #else #warning Unknown standard library, backtraces may have incomplete debug information - return c10::nullopt; + return std::nullopt; #endif // defined(__GLIBCXX__) // Some system-level functions don't have sufficient debug information, so diff --git a/c10/util/OptionalArrayRef.h b/c10/util/OptionalArrayRef.h index 98237bba92f56d..ae4f4f1f2c67bd 100644 --- a/c10/util/OptionalArrayRef.h +++ b/c10/util/OptionalArrayRef.h @@ -12,9 +12,9 @@ #pragma once #include -#include #include #include +#include #include #include @@ -27,16 +27,16 @@ class OptionalArrayRef final { constexpr OptionalArrayRef() noexcept = default; - constexpr OptionalArrayRef(nullopt_t) noexcept {} + constexpr OptionalArrayRef(std::nullopt_t) noexcept {} OptionalArrayRef(const OptionalArrayRef& other) = default; OptionalArrayRef(OptionalArrayRef&& other) noexcept = default; - constexpr OptionalArrayRef(const optional>& other) noexcept + constexpr OptionalArrayRef(const std::optional>& other) noexcept : wrapped_opt_array_ref(other) {} - constexpr OptionalArrayRef(optional>&& other) noexcept + constexpr OptionalArrayRef(std::optional>&& other) noexcept : wrapped_opt_array_ref(std::move(other)) {} constexpr OptionalArrayRef(const T& value) noexcept @@ -89,8 +89,8 @@ class OptionalArrayRef final { // Assignment - constexpr OptionalArrayRef& operator=(nullopt_t) noexcept { - wrapped_opt_array_ref = c10::nullopt; + constexpr OptionalArrayRef& operator=(std::nullopt_t) noexcept { + wrapped_opt_array_ref = std::nullopt; return *this; } @@ -99,13 +99,13 @@ class OptionalArrayRef final { OptionalArrayRef& operator=(OptionalArrayRef&& other) noexcept = default; constexpr OptionalArrayRef& operator=( - const optional>& other) noexcept { + const std::optional>& other) noexcept { wrapped_opt_array_ref = other; return *this; } constexpr OptionalArrayRef& operator=( - optional>&& other) noexcept { + std::optional>&& other) noexcept { wrapped_opt_array_ref = std::move(other); return *this; } @@ -213,7 +213,7 @@ class OptionalArrayRef final { } private: - optional> wrapped_opt_array_ref; + std::optional> wrapped_opt_array_ref; }; using OptionalIntArrayRef = OptionalArrayRef; diff --git a/c10/xpu/test/impl/XPUStreamTest.cpp b/c10/xpu/test/impl/XPUStreamTest.cpp index 01a1dbb62621b2..6cbe3ae6721587 100644 --- a/c10/xpu/test/impl/XPUStreamTest.cpp +++ b/c10/xpu/test/impl/XPUStreamTest.cpp @@ -1,9 +1,9 @@ #include -#include #include #include #include +#include #include #include diff --git a/torch/csrc/Module.cpp b/torch/csrc/Module.cpp index 00a2c0bbe30267..96788c5d79f376 100644 --- a/torch/csrc/Module.cpp +++ b/torch/csrc/Module.cpp @@ -1,8 +1,8 @@ #include -#include #include #include #include +#include #ifndef _MSC_VER #include @@ -1817,7 +1817,7 @@ Call this whenever a new thread is created in order to propagate values from transposed_, output_padding_, std::move(groups_), - c10::nullopt); + std::nullopt); }, py::arg("input"), py::arg("weight"), @@ -1842,7 +1842,7 @@ Call this whenever a new thread is created in order to propagate values from at::SymIntArrayRef output_padding_, c10::SymInt groups_, std::optional> bias_sizes_opt) { - c10::OptionalArrayRef ref = c10::nullopt; + c10::OptionalArrayRef ref = std::nullopt; if (bias_sizes_opt) { ref = (*bias_sizes_opt); } @@ -2031,7 +2031,7 @@ Call this whenever a new thread is created in order to propagate values from py_module.def( "_get_accelerator", - [](std::optional check = c10::nullopt) { + [](std::optional check = std::nullopt) { return c10::Device( at::getAccelerator(check.value_or(false)) .value_or(c10::DeviceType::CPU), diff --git a/torch/csrc/Storage.cpp b/torch/csrc/Storage.cpp index aa5584abd39e4c..77520b6f1cdb1f 100644 --- a/torch/csrc/Storage.cpp +++ b/torch/csrc/Storage.cpp @@ -153,7 +153,7 @@ static bool THPStorage_isPreservable(THPStorage* self) { if (storage.unsafeGetStorageImpl()->pyobj_slot()->check_pyobj( getPyInterpreter(), /*ignore_hermetic_tls=*/true) != - c10::make_optional((PyObject*)self)) { + std::make_optional((PyObject*)self)) { return false; } if (storage.use_count() <= 1) { diff --git a/torch/csrc/api/include/torch/expanding_array.h b/torch/csrc/api/include/torch/expanding_array.h index f0901b06af68cb..62c12d2e0ac8b4 100644 --- a/torch/csrc/api/include/torch/expanding_array.h +++ b/torch/csrc/api/include/torch/expanding_array.h @@ -2,8 +2,8 @@ #include #include -#include #include +#include #include #include diff --git a/torch/csrc/api/include/torch/fft.h b/torch/csrc/api/include/torch/fft.h index d9a3430a7a2496..ef6d9b1bc23620 100644 --- a/torch/csrc/api/include/torch/fft.h +++ b/torch/csrc/api/include/torch/fft.h @@ -15,9 +15,9 @@ namespace fft { /// ``` inline Tensor fft( const Tensor& self, - std::optional n = c10::nullopt, + std::optional n = std::nullopt, int64_t dim = -1, - std::optional norm = c10::nullopt) { + std::optional norm = std::nullopt) { return torch::fft_fft_symint(self, n, dim, norm); } @@ -31,9 +31,9 @@ inline Tensor fft( /// ``` inline Tensor ifft( const Tensor& self, - std::optional n = c10::nullopt, + std::optional n = std::nullopt, int64_t dim = -1, - std::optional norm = c10::nullopt) { + std::optional norm = std::nullopt) { return torch::fft_ifft_symint(self, n, dim, norm); } @@ -47,9 +47,9 @@ inline Tensor ifft( /// ``` inline Tensor fft2( const Tensor& self, - OptionalIntArrayRef s = c10::nullopt, + OptionalIntArrayRef s = std::nullopt, IntArrayRef dim = {-2, -1}, - std::optional norm = c10::nullopt) { + std::optional norm = std::nullopt) { return torch::fft_fft2(self, s, dim, norm); } @@ -63,9 +63,9 @@ inline Tensor fft2( /// ``` inline Tensor ifft2( const Tensor& self, - at::OptionalIntArrayRef s = c10::nullopt, + at::OptionalIntArrayRef s = std::nullopt, IntArrayRef dim = {-2, -1}, - std::optional norm = c10::nullopt) { + std::optional norm = std::nullopt) { return torch::fft_ifft2(self, s, dim, norm); } @@ -79,9 +79,9 @@ inline Tensor ifft2( /// ``` inline Tensor fftn( const Tensor& self, - at::OptionalIntArrayRef s = c10::nullopt, - at::OptionalIntArrayRef dim = c10::nullopt, - std::optional norm = c10::nullopt) { + at::OptionalIntArrayRef s = std::nullopt, + at::OptionalIntArrayRef dim = std::nullopt, + std::optional norm = std::nullopt) { return torch::fft_fftn(self, s, dim, norm); } @@ -95,9 +95,9 @@ inline Tensor fftn( /// ``` inline Tensor ifftn( const Tensor& self, - at::OptionalIntArrayRef s = c10::nullopt, - at::OptionalIntArrayRef dim = c10::nullopt, - std::optional norm = c10::nullopt) { + at::OptionalIntArrayRef s = std::nullopt, + at::OptionalIntArrayRef dim = std::nullopt, + std::optional norm = std::nullopt) { return torch::fft_ifftn(self, s, dim, norm); } @@ -112,9 +112,9 @@ inline Tensor ifftn( /// ``` inline Tensor rfft( const Tensor& self, - std::optional n = c10::nullopt, + std::optional n = std::nullopt, int64_t dim = -1, - std::optional norm = c10::nullopt) { + std::optional norm = std::nullopt) { return torch::fft_rfft_symint(self, n, dim, norm); } @@ -131,9 +131,9 @@ inline Tensor rfft( /// ``` inline Tensor irfft( const Tensor& self, - std::optional n = c10::nullopt, + std::optional n = std::nullopt, int64_t dim = -1, - std::optional norm = c10::nullopt) { + std::optional norm = std::nullopt) { return torch::fft_irfft_symint(self, n, dim, norm); } @@ -147,9 +147,9 @@ inline Tensor irfft( /// ``` inline Tensor rfft2( const Tensor& self, - at::OptionalIntArrayRef s = c10::nullopt, + at::OptionalIntArrayRef s = std::nullopt, IntArrayRef dim = {-2, -1}, - std::optional norm = c10::nullopt) { + std::optional norm = std::nullopt) { return torch::fft_rfft2(self, s, dim, norm); } @@ -163,9 +163,9 @@ inline Tensor rfft2( /// ``` inline Tensor irfft2( const Tensor& self, - at::OptionalIntArrayRef s = c10::nullopt, + at::OptionalIntArrayRef s = std::nullopt, IntArrayRef dim = {-2, -1}, - std::optional norm = c10::nullopt) { + std::optional norm = std::nullopt) { return torch::fft_irfft2(self, s, dim, norm); } @@ -179,9 +179,9 @@ inline Tensor irfft2( /// ``` inline Tensor rfftn( const Tensor& self, - at::OptionalIntArrayRef s = c10::nullopt, - at::OptionalIntArrayRef dim = c10::nullopt, - std::optional norm = c10::nullopt) { + at::OptionalIntArrayRef s = std::nullopt, + at::OptionalIntArrayRef dim = std::nullopt, + std::optional norm = std::nullopt) { return torch::fft_rfftn(self, s, dim, norm); } @@ -195,9 +195,9 @@ inline Tensor rfftn( /// ``` inline Tensor irfftn( const Tensor& self, - at::OptionalIntArrayRef s = c10::nullopt, - at::OptionalIntArrayRef dim = c10::nullopt, - std::optional norm = c10::nullopt) { + at::OptionalIntArrayRef s = std::nullopt, + at::OptionalIntArrayRef dim = std::nullopt, + std::optional norm = std::nullopt) { return torch::fft_irfftn(self, s, dim, norm); } @@ -215,9 +215,9 @@ inline Tensor irfftn( /// ``` inline Tensor hfft( const Tensor& self, - std::optional n = c10::nullopt, + std::optional n = std::nullopt, int64_t dim = -1, - std::optional norm = c10::nullopt) { + std::optional norm = std::nullopt) { return torch::fft_hfft_symint(self, n, dim, norm); } @@ -234,9 +234,9 @@ inline Tensor hfft( /// ``` inline Tensor ihfft( const Tensor& self, - std::optional n = c10::nullopt, + std::optional n = std::nullopt, int64_t dim = -1, - std::optional norm = c10::nullopt) { + std::optional norm = std::nullopt) { return torch::fft_ihfft_symint(self, n, dim, norm); } @@ -253,9 +253,9 @@ inline Tensor ihfft( /// ``` inline Tensor hfft2( const Tensor& self, - at::OptionalIntArrayRef s = c10::nullopt, + at::OptionalIntArrayRef s = std::nullopt, IntArrayRef dim = {-2, -1}, - std::optional norm = c10::nullopt) { + std::optional norm = std::nullopt) { return torch::fft_hfft2(self, s, dim, norm); } @@ -273,9 +273,9 @@ inline Tensor hfft2( /// ``` inline Tensor ihfft2( const Tensor& self, - at::OptionalIntArrayRef s = c10::nullopt, + at::OptionalIntArrayRef s = std::nullopt, IntArrayRef dim = {-2, -1}, - std::optional norm = c10::nullopt) { + std::optional norm = std::nullopt) { return torch::fft_ihfft2(self, s, dim, norm); } @@ -292,9 +292,9 @@ inline Tensor ihfft2( /// ``` inline Tensor hfftn( const Tensor& self, - at::OptionalIntArrayRef s = c10::nullopt, + at::OptionalIntArrayRef s = std::nullopt, IntArrayRef dim = {-2, -1}, - std::optional norm = c10::nullopt) { + std::optional norm = std::nullopt) { return torch::fft_hfftn(self, s, dim, norm); } @@ -312,9 +312,9 @@ inline Tensor hfftn( /// ``` inline Tensor ihfftn( const Tensor& self, - at::OptionalIntArrayRef s = c10::nullopt, + at::OptionalIntArrayRef s = std::nullopt, IntArrayRef dim = {-2, -1}, - std::optional norm = c10::nullopt) { + std::optional norm = std::nullopt) { return torch::fft_ihfftn(self, s, dim, norm); } @@ -364,7 +364,7 @@ inline Tensor rfftfreq(int64_t n, const TensorOptions& options) { /// ``` inline Tensor fftshift( const Tensor& x, - at::OptionalIntArrayRef dim = c10::nullopt) { + at::OptionalIntArrayRef dim = std::nullopt) { return torch::fft_fftshift(x, dim); } @@ -381,7 +381,7 @@ inline Tensor fftshift( /// ``` inline Tensor ifftshift( const Tensor& x, - at::OptionalIntArrayRef dim = c10::nullopt) { + at::OptionalIntArrayRef dim = std::nullopt) { return torch::fft_ifftshift(x, dim); } diff --git a/torch/csrc/api/include/torch/nested.h b/torch/csrc/api/include/torch/nested.h index 780aab42304723..2e4365e0031cc0 100644 --- a/torch/csrc/api/include/torch/nested.h +++ b/torch/csrc/api/include/torch/nested.h @@ -26,7 +26,7 @@ inline at::Tensor nested_tensor( auto out = at::_nested_tensor_from_tensor_list( nested_tensor_data, c10::typeMetaToScalarType(options.dtype()), - c10::nullopt, + std::nullopt, options.device(), options.pinned_memory()); if (options.has_requires_grad() && options.requires_grad()) { @@ -55,7 +55,7 @@ inline at::Tensor nested_tensor( auto out = at::_nested_tensor_from_tensor_list( tensor_list, c10::typeMetaToScalarType(options.dtype()), - c10::nullopt, + std::nullopt, options.device(), options.pinned_memory()); if (options.has_requires_grad() && options.requires_grad()) { @@ -72,10 +72,10 @@ inline at::Tensor nested_tensor( /// ``` inline at::Tensor as_nested_tensor( at::TensorList list, - std::optional dtype = c10::nullopt, - std::optional device = c10::nullopt) { + std::optional dtype = std::nullopt, + std::optional device = std::nullopt) { return at::_nested_tensor_from_tensor_list( - list, dtype, c10::nullopt, device, c10::nullopt); + list, dtype, std::nullopt, device, std::nullopt); } /// Nested to padded tensor @@ -87,7 +87,7 @@ inline at::Tensor as_nested_tensor( inline at::Tensor to_padded_tensor( const at::Tensor& self, double padding, - at::OptionalIntArrayRef output_size = c10::nullopt) { + at::OptionalIntArrayRef output_size = std::nullopt) { return at::nested_to_padded_tensor(self, padding, output_size); } diff --git a/torch/csrc/api/include/torch/nn/functional/activation.h b/torch/csrc/api/include/torch/nn/functional/activation.h index 89e596f71d143d..5ae6fcc317602a 100644 --- a/torch/csrc/api/include/torch/nn/functional/activation.h +++ b/torch/csrc/api/include/torch/nn/functional/activation.h @@ -236,7 +236,7 @@ inline Tensor softmax( std::optional dtype) { Tensor ret; - if (dtype == c10::nullopt) { + if (dtype == std::nullopt) { ret = input.softmax(dim); } else { ret = input.softmax(dim, dtype); @@ -273,7 +273,7 @@ inline Tensor softmin( std::optional dtype) { Tensor ret; - if (dtype == c10::nullopt) { + if (dtype == std::nullopt) { ret = (-input).softmax(dim); } else { ret = (-input).softmax(dim, dtype); @@ -310,7 +310,7 @@ inline Tensor log_softmax( std::optional dtype) { Tensor ret; - if (dtype == c10::nullopt) { + if (dtype == std::nullopt) { ret = input.log_softmax(dim); } else { ret = input.log_softmax(dim, dtype); diff --git a/torch/csrc/api/include/torch/nn/functional/embedding.h b/torch/csrc/api/include/torch/nn/functional/embedding.h index b06b0a3dc1e851..602268ab2eba30 100644 --- a/torch/csrc/api/include/torch/nn/functional/embedding.h +++ b/torch/csrc/api/include/torch/nn/functional/embedding.h @@ -31,7 +31,7 @@ inline Tensor embedding( bool sparse) { auto input_ = input; - if (padding_idx != c10::nullopt) { + if (padding_idx != std::nullopt) { if (*padding_idx > 0) { TORCH_CHECK( *padding_idx < weight.size(0), @@ -46,7 +46,7 @@ inline Tensor embedding( padding_idx = -1; } - if (max_norm != c10::nullopt) { + if (max_norm != std::nullopt) { input_ = input_.contiguous(); // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions) _no_grad_embedding_renorm_(weight, input_, *max_norm, norm_type); @@ -149,7 +149,7 @@ inline Tensor embedding_bag( TORCH_CHECK(false, "mode has to be one of sum, mean or max"); } - if (max_norm != c10::nullopt) { + if (max_norm != std::nullopt) { // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions) _no_grad_embedding_renorm_(weight, input_, *max_norm, norm_type); } diff --git a/torch/csrc/api/include/torch/nn/functional/loss.h b/torch/csrc/api/include/torch/nn/functional/loss.h index d1e285d0a0f19a..6a425e606caf26 100644 --- a/torch/csrc/api/include/torch/nn/functional/loss.h +++ b/torch/csrc/api/include/torch/nn/functional/loss.h @@ -346,7 +346,7 @@ inline Tensor smooth_l1_loss( const Tensor& input, const Tensor& target, SmoothL1LossFuncOptions::reduction_t reduction, - std::optional beta_opt = c10::nullopt) { + std::optional beta_opt = std::nullopt) { if (target.sizes() != input.sizes()) { TORCH_WARN( "Using a target size (", @@ -405,7 +405,7 @@ inline Tensor smooth_l1_loss( const SmoothL1LossFuncOptions& options, double beta) { TORCH_CHECK( - options.beta() == c10::nullopt, + options.beta() == std::nullopt, "expected beta not to be provided in 'options', but got ", options.beta().value()); return detail::smooth_l1_loss(input, target, options.reduction(), beta); diff --git a/torch/csrc/api/include/torch/nn/functional/normalization.h b/torch/csrc/api/include/torch/nn/functional/normalization.h index 53bd61839f7451..965cfcd9ac83fa 100644 --- a/torch/csrc/api/include/torch/nn/functional/normalization.h +++ b/torch/csrc/api/include/torch/nn/functional/normalization.h @@ -17,7 +17,7 @@ inline Tensor normalize( int64_t dim, double eps, std::optional out) { - if (out == c10::nullopt) { + if (out == std::nullopt) { auto denom = input.norm(p, dim, true).clamp_min(eps).expand_as(input); return input / denom; } else { @@ -115,7 +115,7 @@ inline Tensor local_response_norm( /*padding=*/0, /*ceil_mode=*/false, /*count_include_pad=*/true, - /*divisor_override=*/c10::nullopt) + /*divisor_override=*/std::nullopt) .squeeze(1); } else { auto sizes = input.sizes(); @@ -132,7 +132,7 @@ inline Tensor local_response_norm( /*padding=*/0, /*ceil_mode=*/false, /*count_include_pad=*/true, - /*divisor_override=*/c10::nullopt) + /*divisor_override=*/std::nullopt) .squeeze(1); div = div.view(sizes); } diff --git a/torch/csrc/api/include/torch/nn/functional/pooling.h b/torch/csrc/api/include/torch/nn/functional/pooling.h index be3009f62201a2..798467c0e0a681 100644 --- a/torch/csrc/api/include/torch/nn/functional/pooling.h +++ b/torch/csrc/api/include/torch/nn/functional/pooling.h @@ -820,15 +820,15 @@ inline std::tuple fractional_max_pool2d_with_indices( const std::optional>& output_size, const std::optional>& output_ratio, const Tensor& _random_samples) { - if (output_size == c10::nullopt && output_ratio == c10::nullopt) { + if (output_size == std::nullopt && output_ratio == std::nullopt) { TORCH_CHECK( false, "fractional_max_pool2d requires specifying either ", "an output_size or an output_ratio"); } std::optional> output_size_ = output_size; - if (output_size_ == c10::nullopt) { - TORCH_INTERNAL_ASSERT(output_ratio != c10::nullopt); + if (output_size_ == std::nullopt) { + TORCH_INTERNAL_ASSERT(output_ratio != std::nullopt); output_size_ = { (int64_t)(static_cast(input.size(-2)) * (*output_ratio.value())[0]), @@ -913,7 +913,7 @@ inline std::tuple fractional_max_pool3d_with_indices( const std::optional>& output_size, const std::optional>& output_ratio, const Tensor& _random_samples) { - if (output_size == c10::nullopt && output_ratio == c10::nullopt) { + if (output_size == std::nullopt && output_ratio == std::nullopt) { TORCH_CHECK( false, "fractional_max_pool3d requires specifying either ", @@ -921,8 +921,8 @@ inline std::tuple fractional_max_pool3d_with_indices( } std::optional> output_size_ = output_size; - if (output_size_ == c10::nullopt) { - TORCH_INTERNAL_ASSERT(output_ratio != c10::nullopt); + if (output_size_ == std::nullopt) { + TORCH_INTERNAL_ASSERT(output_ratio != std::nullopt); output_size_ = { (int64_t)(static_cast(input.size(-3)) * (*output_ratio.value())[0]), @@ -1066,7 +1066,7 @@ inline Tensor lp_pool2d( /*padding=*/0, ceil_mode, /*count_include_pad=*/true, - /*divisor_override=*/c10::nullopt); + /*divisor_override=*/std::nullopt); return (torch::sign(out) * relu(torch::abs(out))) .mul(kw * kh) @@ -1116,7 +1116,7 @@ inline Tensor lp_pool3d( /*padding=*/0, ceil_mode, /*count_include_pad=*/true, - /*divisor_override=*/c10::nullopt); + /*divisor_override=*/std::nullopt); return (torch::sign(out) * relu(torch::abs(out))) .mul(kd * kw * kh) diff --git a/torch/csrc/api/include/torch/nn/functional/upsampling.h b/torch/csrc/api/include/torch/nn/functional/upsampling.h index 38c5c51f9a475e..75707ef091a783 100644 --- a/torch/csrc/api/include/torch/nn/functional/upsampling.h +++ b/torch/csrc/api/include/torch/nn/functional/upsampling.h @@ -19,13 +19,13 @@ inline std::vector _interp_output_size( std::optional>, std::optional> closed_over_args) { auto [input, size, scale_factor, recompute_scale_factor] = closed_over_args; - if (size == c10::nullopt && scale_factor == c10::nullopt) { + if (size == std::nullopt && scale_factor == std::nullopt) { TORCH_CHECK(false, "either size or scale_factor should be defined"); } - if (size != c10::nullopt && scale_factor != c10::nullopt) { + if (size != std::nullopt && scale_factor != std::nullopt) { TORCH_CHECK(false, "only one of size or scale_factor should be defined"); } - if (scale_factor != c10::nullopt) { + if (scale_factor != std::nullopt) { if (static_cast(scale_factor.value().size()) != dim) { TORCH_CHECK( false, @@ -36,14 +36,14 @@ inline std::vector _interp_output_size( torch::ArrayRef(*scale_factor)); } } - if (size != c10::nullopt) { + if (size != std::nullopt) { return *size; } - TORCH_INTERNAL_ASSERT(scale_factor != c10::nullopt); + TORCH_INTERNAL_ASSERT(scale_factor != std::nullopt); auto scale_factors = *scale_factor; - if (recompute_scale_factor == c10::nullopt) { + if (recompute_scale_factor == std::nullopt) { // only warn when the scales have floating values since // the result for ints is the same with/without recompute_scale_factor bool is_float_scale_factor = false; @@ -83,14 +83,14 @@ inline Tensor interpolate( bool antialias) { if (std::holds_alternative(mode) || std::get_if(&mode)) { - if (align_corners != c10::nullopt) { + if (align_corners != std::nullopt) { TORCH_CHECK( false, "align_corners option can only be set with the " "interpolating modes: linear | bilinear | bicubic | trilinear"); } } else { - if (align_corners == c10::nullopt) { + if (align_corners == std::nullopt) { TORCH_WARN( "Default upsampling behavior when mode=", enumtype::get_enum_name(mode), @@ -114,8 +114,8 @@ inline Tensor interpolate( auto scale_factor_len = input.dim() - 2; std::vector> scale_factor_list( - scale_factor_len, c10::nullopt); - if (scale_factor != c10::nullopt && !recompute_scale_factor.value_or(false)) { + scale_factor_len, std::nullopt); + if (scale_factor != std::nullopt && !recompute_scale_factor.value_or(false)) { auto _scale_factor_repeated = *scale_factor; scale_factor_list = {}; for (const auto& elem : _scale_factor_repeated) { @@ -181,7 +181,7 @@ inline Tensor interpolate( input, _interp_output_size(3, std::move(closed_over_args))); } else if (input.dim() == 3 && std::get_if(&mode)) { TORCH_CHECK( - align_corners != c10::nullopt, "align_corners should be specified."); + align_corners != std::nullopt, "align_corners should be specified."); return torch::upsample_linear1d( input, _interp_output_size(1, std::move(closed_over_args)), @@ -195,7 +195,7 @@ inline Tensor interpolate( TORCH_CHECK(false, "Got 4D input, but linear mode needs 3D input"); } else if (input.dim() == 4 && std::get_if(&mode)) { TORCH_CHECK( - align_corners != c10::nullopt, "align_corners should be specified."); + align_corners != std::nullopt, "align_corners should be specified."); if (antialias) { return torch::_upsample_bilinear2d_aa( input, @@ -218,7 +218,7 @@ inline Tensor interpolate( TORCH_CHECK(false, "Got 5D input, but bilinear mode needs 4D input"); } else if (input.dim() == 5 && std::get_if(&mode)) { TORCH_CHECK( - align_corners != c10::nullopt, "align_corners should be specified."); + align_corners != std::nullopt, "align_corners should be specified."); return torch::upsample_trilinear3d( input, _interp_output_size(3, std::move(closed_over_args)), @@ -228,7 +228,7 @@ inline Tensor interpolate( scale_factor_list.at(2)); } else if (input.dim() == 4 && std::get_if(&mode)) { TORCH_CHECK( - align_corners != c10::nullopt, "align_corners should be specified."); + align_corners != std::nullopt, "align_corners should be specified."); if (antialias) { return torch::_upsample_bicubic2d_aa( input, diff --git a/torch/csrc/api/include/torch/nn/modules/batchnorm.h b/torch/csrc/api/include/torch/nn/modules/batchnorm.h index ec76c6b4a6fbc6..0f5e32746936eb 100644 --- a/torch/csrc/api/include/torch/nn/modules/batchnorm.h +++ b/torch/csrc/api/include/torch/nn/modules/batchnorm.h @@ -106,7 +106,7 @@ class BatchNormImplBase : public NormImplBase { this->_check_input_dim(input); // NOLINTNEXTLINE(cppcoreguidelines-init-variables) double exponential_average_factor; - if (this->options.momentum() == c10::nullopt) { + if (this->options.momentum() == std::nullopt) { exponential_average_factor = 0.0; } else { exponential_average_factor = this->options.momentum().value(); @@ -116,7 +116,7 @@ class BatchNormImplBase : public NormImplBase { if (this->num_batches_tracked.defined()) { this->num_batches_tracked += 1; if (this->options.momentum() == - c10::nullopt) { // use cumulative moving average + std::nullopt) { // use cumulative moving average exponential_average_factor = 1.0 / this->num_batches_tracked.template item(); } else { // use exponential moving average diff --git a/torch/csrc/api/include/torch/nn/modules/conv.h b/torch/csrc/api/include/torch/nn/modules/conv.h index 9c55254ddb9103..e44fd44b954abe 100644 --- a/torch/csrc/api/include/torch/nn/modules/conv.h +++ b/torch/csrc/api/include/torch/nn/modules/conv.h @@ -350,7 +350,7 @@ class TORCH_API ConvTranspose1dImpl explicit ConvTranspose1dImpl(ConvTranspose1dOptions options_); Tensor forward( const Tensor& input, - const std::optional& output_size = c10::nullopt); + const std::optional& output_size = std::nullopt); protected: FORWARD_HAS_DEFAULT_ARGS({1, AnyValue(std::optional())}) @@ -392,7 +392,7 @@ class TORCH_API ConvTranspose2dImpl explicit ConvTranspose2dImpl(ConvTranspose2dOptions options_); Tensor forward( const Tensor& input, - const std::optional& output_size = c10::nullopt); + const std::optional& output_size = std::nullopt); protected: FORWARD_HAS_DEFAULT_ARGS({1, AnyValue(std::optional())}) @@ -434,7 +434,7 @@ class TORCH_API ConvTranspose3dImpl explicit ConvTranspose3dImpl(ConvTranspose3dOptions options_); Tensor forward( const Tensor& input, - const std::optional& output_size = c10::nullopt); + const std::optional& output_size = std::nullopt); protected: FORWARD_HAS_DEFAULT_ARGS({1, AnyValue(std::optional())}) diff --git a/torch/csrc/api/include/torch/nn/modules/pooling.h b/torch/csrc/api/include/torch/nn/modules/pooling.h index 6bcdca463b1ba9..0fac60edbcde40 100644 --- a/torch/csrc/api/include/torch/nn/modules/pooling.h +++ b/torch/csrc/api/include/torch/nn/modules/pooling.h @@ -507,7 +507,7 @@ class TORCH_API MaxUnpool1dImpl : public MaxUnpoolImpl<1, MaxUnpool1dImpl> { Tensor forward( const Tensor& input, const Tensor& indices, - const std::optional>& output_size = c10::nullopt); + const std::optional>& output_size = std::nullopt); protected: FORWARD_HAS_DEFAULT_ARGS({2, AnyValue(std::optional>())}) @@ -539,7 +539,7 @@ class TORCH_API MaxUnpool2dImpl : public MaxUnpoolImpl<2, MaxUnpool2dImpl> { Tensor forward( const Tensor& input, const Tensor& indices, - const std::optional>& output_size = c10::nullopt); + const std::optional>& output_size = std::nullopt); protected: FORWARD_HAS_DEFAULT_ARGS({2, AnyValue(std::optional>())}) @@ -571,7 +571,7 @@ class TORCH_API MaxUnpool3dImpl : public MaxUnpoolImpl<3, MaxUnpool3dImpl> { Tensor forward( const Tensor& input, const Tensor& indices, - const std::optional>& output_size = c10::nullopt); + const std::optional>& output_size = std::nullopt); protected: FORWARD_HAS_DEFAULT_ARGS({2, AnyValue(std::optional>())}) diff --git a/torch/csrc/api/include/torch/nn/modules/utils.h b/torch/csrc/api/include/torch/nn/modules/utils.h index 869027a241492d..6eaa0c1fb2c73e 100644 --- a/torch/csrc/api/include/torch/nn/modules/utils.h +++ b/torch/csrc/api/include/torch/nn/modules/utils.h @@ -1,8 +1,8 @@ #pragma once #include -#include #include +#include #include diff --git a/torch/csrc/api/include/torch/nn/options/activation.h b/torch/csrc/api/include/torch/nn/options/activation.h index 165212e0e860cd..ac6cbc4ea4deab 100644 --- a/torch/csrc/api/include/torch/nn/options/activation.h +++ b/torch/csrc/api/include/torch/nn/options/activation.h @@ -252,7 +252,7 @@ struct TORCH_API SoftmaxFuncOptions { /// If specified, the input tensor is casted to `dtype` before the operation /// is performed. This is useful for preventing data type overflows. Default: /// None. - TORCH_ARG(std::optional, dtype) = c10::nullopt; + TORCH_ARG(std::optional, dtype) = std::nullopt; }; } // namespace functional @@ -293,7 +293,7 @@ struct TORCH_API SoftminFuncOptions { /// If specified, the input tensor is casted to `dtype` before the operation /// is performed. This is useful for preventing data type overflows. Default: /// None. - TORCH_ARG(std::optional, dtype) = c10::nullopt; + TORCH_ARG(std::optional, dtype) = std::nullopt; }; } // namespace functional @@ -334,7 +334,7 @@ struct TORCH_API LogSoftmaxFuncOptions { /// If specified, the input tensor is casted to `dtype` before the operation /// is performed. This is useful for preventing data type overflows. Default: /// None. - TORCH_ARG(std::optional, dtype) = c10::nullopt; + TORCH_ARG(std::optional, dtype) = std::nullopt; }; } // namespace functional @@ -640,10 +640,10 @@ struct TORCH_API MultiheadAttentionOptions { /// add a new batch of zeros to the key and value sequences at dim=1. TORCH_ARG(bool, add_zero_attn) = false; - /// total number of features in key. Default: c10::nullopt. + /// total number of features in key. Default: std::nullopt. TORCH_ARG(int64_t, kdim); - /// total number of features in key. Default: c10::nullopt. + /// total number of features in key. Default: std::nullopt. TORCH_ARG(int64_t, vdim); }; diff --git a/torch/csrc/api/include/torch/nn/options/embedding.h b/torch/csrc/api/include/torch/nn/options/embedding.h index 20eacf90733552..a3d2fdb72f54da 100644 --- a/torch/csrc/api/include/torch/nn/options/embedding.h +++ b/torch/csrc/api/include/torch/nn/options/embedding.h @@ -28,10 +28,10 @@ struct TORCH_API EmbeddingOptions { /// Embedding, the embedding vector at `padding_idx` will default to all /// zeros, but can be updated to another value to be used as the padding /// vector. - TORCH_ARG(std::optional, padding_idx) = c10::nullopt; + TORCH_ARG(std::optional, padding_idx) = std::nullopt; /// If given, each embedding vector with norm larger than `max_norm` is /// renormalized to have norm `max_norm`. - TORCH_ARG(std::optional, max_norm) = c10::nullopt; + TORCH_ARG(std::optional, max_norm) = std::nullopt; /// The p of the p-norm to compute for the `max_norm` option. Default ``2``. TORCH_ARG(double, norm_type) = 2.; /// If given, this will scale gradients by the inverse of frequency of the @@ -55,10 +55,10 @@ struct TORCH_API EmbeddingFromPretrainedOptions { /// If specified, the entries at `padding_idx` do not contribute to the /// gradient; therefore, the embedding vector at `padding_idx` is not updated /// during training, i.e. it remains as a fixed "pad". - TORCH_ARG(std::optional, padding_idx) = c10::nullopt; + TORCH_ARG(std::optional, padding_idx) = std::nullopt; /// If given, each embedding vector with norm larger than `max_norm` is /// renormalized to have norm `max_norm`. - TORCH_ARG(std::optional, max_norm) = c10::nullopt; + TORCH_ARG(std::optional, max_norm) = std::nullopt; /// The p of the p-norm to compute for the `max_norm` option. Default ``2``. TORCH_ARG(double, norm_type) = 2.; /// If given, this will scale gradients by the inverse of frequency of the @@ -84,10 +84,10 @@ struct TORCH_API EmbeddingFuncOptions { /// If specified, the entries at `padding_idx` do not contribute to the /// gradient; therefore, the embedding vector at `padding_idx` is not updated /// during training, i.e. it remains as a fixed "pad". - TORCH_ARG(std::optional, padding_idx) = c10::nullopt; + TORCH_ARG(std::optional, padding_idx) = std::nullopt; /// If given, each embedding vector with norm larger than `max_norm` is /// renormalized to have norm `max_norm`. - TORCH_ARG(std::optional, max_norm) = c10::nullopt; + TORCH_ARG(std::optional, max_norm) = std::nullopt; /// The p of the p-norm to compute for the `max_norm` option. Default ``2``. TORCH_ARG(double, norm_type) = 2.; /// If given, this will scale gradients by the inverse of frequency of the @@ -120,7 +120,7 @@ struct TORCH_API EmbeddingBagOptions { TORCH_ARG(int64_t, embedding_dim); /// If given, each embedding vector with norm larger than `max_norm` is /// renormalized to have norm `max_norm`. - TORCH_ARG(std::optional, max_norm) = c10::nullopt; + TORCH_ARG(std::optional, max_norm) = std::nullopt; /// The p of the p-norm to compute for the `max_norm` option. Default ``2``. TORCH_ARG(double, norm_type) = 2.; /// If given, this will scale gradients by the inverse of frequency of the @@ -148,7 +148,7 @@ struct TORCH_API EmbeddingBagOptions { /// zeros, but can be updated to another value to be used as the padding /// vector. Note that the embedding vector at `padding_idx` is excluded from /// the reduction. - TORCH_ARG(std::optional, padding_idx) = c10::nullopt; + TORCH_ARG(std::optional, padding_idx) = std::nullopt; }; // ============================================================================ @@ -161,7 +161,7 @@ struct TORCH_API EmbeddingBagFromPretrainedOptions { TORCH_ARG(bool, freeze) = true; /// If given, each embedding vector with norm larger than `max_norm` is /// renormalized to have norm `max_norm`. - TORCH_ARG(std::optional, max_norm) = c10::nullopt; + TORCH_ARG(std::optional, max_norm) = std::nullopt; /// The p of the p-norm to compute for the `max_norm` option. Default ``2``. TORCH_ARG(double, norm_type) = 2.; /// If given, this will scale gradients by the inverse of frequency of the @@ -184,7 +184,7 @@ struct TORCH_API EmbeddingBagFromPretrainedOptions { /// gradient; therefore, the embedding vector at padding_idx is not updated /// during training, i.e. it remains as a fixed "pad". Note that the embedding /// vector at `padding_idx` is excluded from the reduction. - TORCH_ARG(std::optional, padding_idx) = c10::nullopt; + TORCH_ARG(std::optional, padding_idx) = std::nullopt; }; // ============================================================================ @@ -205,7 +205,7 @@ struct TORCH_API EmbeddingBagFuncOptions { TORCH_ARG(torch::Tensor, offsets) = Tensor(); /// If given, each embedding vector with norm larger than `max_norm` is /// renormalized to have norm `max_norm`. - TORCH_ARG(std::optional, max_norm) = c10::nullopt; + TORCH_ARG(std::optional, max_norm) = std::nullopt; /// The p of the p-norm to compute for the `max_norm` option. Default ``2``. TORCH_ARG(double, norm_type) = 2.; /// If given, this will scale gradients by the inverse of frequency of the @@ -233,7 +233,7 @@ struct TORCH_API EmbeddingBagFuncOptions { /// gradient; therefore, the embedding vector at padding_idx is not updated /// during training, i.e. it remains as a fixed "pad". Note that the embedding /// vector at `padding_idx` is excluded from the reduction. - TORCH_ARG(std::optional, padding_idx) = c10::nullopt; + TORCH_ARG(std::optional, padding_idx) = std::nullopt; }; } // namespace functional diff --git a/torch/csrc/api/include/torch/nn/options/loss.h b/torch/csrc/api/include/torch/nn/options/loss.h index f1fc7a4d411156..5a6e7aa3ab20be 100644 --- a/torch/csrc/api/include/torch/nn/options/loss.h +++ b/torch/csrc/api/include/torch/nn/options/loss.h @@ -451,7 +451,7 @@ struct TORCH_API TripletMarginWithDistanceLossOptions { /// closeness of two tensors. If not specified, `F::pairwise_distance` will /// be used. Default: nullopt TORCH_ARG(std::optional, distance_function) = - c10::nullopt; + std::nullopt; /// Specifies a nonnegative margin representing the minimum difference /// between the positive and negative distances required for the loss to be 0. /// Larger margins penalize cases where the negative examples are not distance @@ -548,7 +548,7 @@ struct TORCH_API SmoothL1LossOptions { /// Specifies the threshold at which to change between L1 and L2 loss. /// If beta is not specified, a value of 1.0 will be used. /// Default: nullopt - TORCH_ARG(std::optional, beta) = c10::nullopt; + TORCH_ARG(std::optional, beta) = std::nullopt; }; namespace functional { diff --git a/torch/csrc/api/include/torch/nn/options/normalization.h b/torch/csrc/api/include/torch/nn/options/normalization.h index a1e5b1a0aeab1c..4b6dcd6ffe0c27 100644 --- a/torch/csrc/api/include/torch/nn/options/normalization.h +++ b/torch/csrc/api/include/torch/nn/options/normalization.h @@ -133,7 +133,7 @@ struct TORCH_API NormalizeFuncOptions { TORCH_ARG(double, eps) = 1e-12; /// the output tensor. If `out` is used, this /// operation won't be differentiable. - TORCH_ARG(std::optional, out) = c10::nullopt; + TORCH_ARG(std::optional, out) = std::nullopt; }; } // namespace functional diff --git a/torch/csrc/api/include/torch/nn/options/pooling.h b/torch/csrc/api/include/torch/nn/options/pooling.h index 8f6cee99bff6ae..75408890e7cd12 100644 --- a/torch/csrc/api/include/torch/nn/options/pooling.h +++ b/torch/csrc/api/include/torch/nn/options/pooling.h @@ -32,7 +32,7 @@ struct AvgPoolOptions { /// if specified, it will be used as divisor, otherwise size of the pooling /// region will be used. - TORCH_ARG(std::optional, divisor_override) = c10::nullopt; + TORCH_ARG(std::optional, divisor_override) = std::nullopt; }; /// `AvgPoolOptions` specialized for the `AvgPool1d` module. @@ -401,7 +401,7 @@ struct MaxUnpoolFuncOptions { TORCH_ARG(ExpandingArray, padding) = 0; /// the targeted output size - TORCH_ARG(std::optional>, output_size) = c10::nullopt; + TORCH_ARG(std::optional>, output_size) = std::nullopt; }; /// `MaxUnpoolFuncOptions` specialized for @@ -450,12 +450,12 @@ struct FractionalMaxPoolOptions { TORCH_ARG(ExpandingArray, kernel_size); /// the target output size of the image - TORCH_ARG(std::optional>, output_size) = c10::nullopt; + TORCH_ARG(std::optional>, output_size) = std::nullopt; /// If one wants to have an output size as a ratio of the input size, this /// option can be given. This has to be a number or tuple in the range (0, 1) using ExpandingArrayDouble = torch::ExpandingArray; - TORCH_ARG(std::optional, output_ratio) = c10::nullopt; + TORCH_ARG(std::optional, output_ratio) = std::nullopt; TORCH_ARG(torch::Tensor, _random_samples) = Tensor(); }; diff --git a/torch/csrc/api/include/torch/nn/options/upsampling.h b/torch/csrc/api/include/torch/nn/options/upsampling.h index 21df2b89998de5..df8eb194180acc 100644 --- a/torch/csrc/api/include/torch/nn/options/upsampling.h +++ b/torch/csrc/api/include/torch/nn/options/upsampling.h @@ -20,10 +20,10 @@ namespace nn { /// ``` struct TORCH_API UpsampleOptions { /// output spatial sizes. - TORCH_ARG(std::optional>, size) = c10::nullopt; + TORCH_ARG(std::optional>, size) = std::nullopt; /// multiplier for spatial size. - TORCH_ARG(std::optional>, scale_factor) = c10::nullopt; + TORCH_ARG(std::optional>, scale_factor) = std::nullopt; /// the upsampling algorithm: one of "nearest", "linear", "bilinear", /// "bicubic" and "trilinear". Default: "nearest" @@ -40,7 +40,7 @@ struct TORCH_API UpsampleOptions { /// aligned, and thus preserving the values at those pixels. This only has /// effect when :attr:`mode` is "linear", "bilinear", "bicubic", or /// "trilinear". Default: "False" - TORCH_ARG(std::optional, align_corners) = c10::nullopt; + TORCH_ARG(std::optional, align_corners) = std::nullopt; }; namespace functional { @@ -65,10 +65,10 @@ struct TORCH_API InterpolateFuncOptions { mode_t; /// output spatial sizes. - TORCH_ARG(std::optional>, size) = c10::nullopt; + TORCH_ARG(std::optional>, size) = std::nullopt; /// multiplier for spatial size. - TORCH_ARG(std::optional>, scale_factor) = c10::nullopt; + TORCH_ARG(std::optional>, scale_factor) = std::nullopt; /// the upsampling algorithm: one of "nearest", "linear", "bilinear", /// "bicubic", "trilinear", "area", "nearest-exact". Default: "nearest" @@ -83,7 +83,7 @@ struct TORCH_API InterpolateFuncOptions { /// this operation *independent* of input size when `scale_factor` is /// kept the same. It is *required* when interpolating mode is "linear", /// "bilinear", "bicubic" or "trilinear". Default: "False" - TORCH_ARG(std::optional, align_corners) = c10::nullopt; + TORCH_ARG(std::optional, align_corners) = std::nullopt; /// recompute the scale_factor for use in the /// interpolation calculation. When `scale_factor` is passed as a parameter, @@ -95,7 +95,7 @@ struct TORCH_API InterpolateFuncOptions { /// used in the interpolation computation. Note that when `scale_factor` is /// floating-point, the recomputed scale_factor may differ from the one passed /// in due to rounding and precision issues. - TORCH_ARG(std::optional, recompute_scale_factor) = c10::nullopt; + TORCH_ARG(std::optional, recompute_scale_factor) = std::nullopt; /// flag to apply anti-aliasing. Using anti-alias /// option together with :attr:`align_corners` equals "False", interpolation diff --git a/torch/csrc/api/include/torch/nn/options/vision.h b/torch/csrc/api/include/torch/nn/options/vision.h index c012b40d21f695..a5204f0dffb624 100644 --- a/torch/csrc/api/include/torch/nn/options/vision.h +++ b/torch/csrc/api/include/torch/nn/options/vision.h @@ -28,7 +28,7 @@ struct TORCH_API GridSampleFuncOptions { /// padding mode for outside grid values. Default: Zeros TORCH_ARG(padding_mode_t, padding_mode) = torch::kZeros; /// Specifies perspective to pixel as point. Default: false - TORCH_ARG(std::optional, align_corners) = c10::nullopt; + TORCH_ARG(std::optional, align_corners) = std::nullopt; }; } // namespace functional diff --git a/torch/csrc/api/include/torch/nn/utils/clip_grad.h b/torch/csrc/api/include/torch/nn/utils/clip_grad.h index fbb533662c7be3..8a2a569c03335c 100644 --- a/torch/csrc/api/include/torch/nn/utils/clip_grad.h +++ b/torch/csrc/api/include/torch/nn/utils/clip_grad.h @@ -64,7 +64,7 @@ inline double clip_grad_norm_( // synchronizing the CPU and the gradients' device until the very end to // preserve async execution on the device. When checking for finite-ness, this // optional ensures we only sync once. - std::optional total_norm = c10::nullopt; + std::optional total_norm = std::nullopt; if (error_if_nonfinite) { total_norm = total_norm_tensor.item().toDouble(); TORCH_CHECK( @@ -79,7 +79,7 @@ inline double clip_grad_norm_( auto clip_coef = max_norm / (total_norm_tensor + 1e-6); auto clip_coef_clamped = - torch::clamp(clip_coef, c10::nullopt /* min */, 1.0 /* max */); + torch::clamp(clip_coef, std::nullopt /* min */, 1.0 /* max */); for (auto& param : params_with_grad) { param.grad().data().mul_(clip_coef_clamped); } diff --git a/torch/csrc/api/include/torch/nn/utils/convert_parameters.h b/torch/csrc/api/include/torch/nn/utils/convert_parameters.h index 6f62d483c4d8b8..b8bfee33473f2a 100644 --- a/torch/csrc/api/include/torch/nn/utils/convert_parameters.h +++ b/torch/csrc/api/include/torch/nn/utils/convert_parameters.h @@ -15,7 +15,7 @@ inline std::optional _check_param_device( const torch::Tensor& param, std::optional old_param_device) { // Meet the first parameter - if (old_param_device == c10::nullopt) { + if (old_param_device == std::nullopt) { old_param_device = param.is_cuda() ? param.get_device() : -1; } else { bool warn = false; diff --git a/torch/csrc/api/include/torch/optim/lbfgs.h b/torch/csrc/api/include/torch/optim/lbfgs.h index 001b0cd33f2596..0832afff5f8f20 100644 --- a/torch/csrc/api/include/torch/optim/lbfgs.h +++ b/torch/csrc/api/include/torch/optim/lbfgs.h @@ -17,11 +17,11 @@ struct TORCH_API LBFGSOptions : public OptimizerCloneableOptions { LBFGSOptions(double lr = 1); TORCH_ARG(double, lr) = 1; TORCH_ARG(int64_t, max_iter) = 20; - TORCH_ARG(std::optional, max_eval) = c10::nullopt; + TORCH_ARG(std::optional, max_eval) = std::nullopt; TORCH_ARG(double, tolerance_grad) = 1e-7; TORCH_ARG(double, tolerance_change) = 1e-9; TORCH_ARG(int64_t, history_size) = 100; - TORCH_ARG(std::optional, line_search_fn) = c10::nullopt; + TORCH_ARG(std::optional, line_search_fn) = std::nullopt; public: void serialize(torch::serialize::InputArchive& archive) override; @@ -45,7 +45,7 @@ struct TORCH_API LBFGSParamState TORCH_ARG(std::deque, old_dirs); TORCH_ARG(std::deque, old_stps); TORCH_ARG(std::deque, ro); - TORCH_ARG(std::optional>, al) = c10::nullopt; + TORCH_ARG(std::optional>, al) = std::nullopt; public: void serialize(torch::serialize::InputArchive& archive) override; @@ -66,13 +66,13 @@ class TORCH_API LBFGS : public Optimizer { TORCH_CHECK( param_groups_.size() == 1, "LBFGS doesn't support per-parameter options (parameter groups)"); - if (defaults.max_eval() == c10::nullopt) { + if (defaults.max_eval() == std::nullopt) { auto max_eval_val = (defaults.max_iter() * 5) / 4; static_cast(param_groups_[0].options()) .max_eval(max_eval_val); static_cast(*defaults_.get()).max_eval(max_eval_val); } - _numel_cache = c10::nullopt; + _numel_cache = std::nullopt; } explicit LBFGS(std::vector params, LBFGSOptions defaults = {}) : LBFGS({OptimizerParamGroup(std::move(params))}, defaults) {} diff --git a/torch/csrc/api/include/torch/optim/optimizer.h b/torch/csrc/api/include/torch/optim/optimizer.h index 1f448e4fffd61c..dd5bd600ff3e79 100644 --- a/torch/csrc/api/include/torch/optim/optimizer.h +++ b/torch/csrc/api/include/torch/optim/optimizer.h @@ -186,22 +186,22 @@ class TORCH_API Optimizer { }; /* How do we decide whether to serialize undefined tensors or - c10::nullopt values into the output archive? + std::nullopt values into the output archive? Answer: we strictly follow the behavior of Python API. To be more specific: For optimizer options: a) For undefined tensor: currently no tensor is used as an options argument in -Python API, so we don't need to worry about it now. b) For c10::nullopt value: -we serialize c10::nullopt values into the output archive, to follow the exact +Python API, so we don't need to worry about it now. b) For std::nullopt value: +we serialize std::nullopt values into the output archive, to follow the exact same behavior as Python API. For optimizer param state: a) For undefined tensor: in param state, undefined tensor in C++ impl is equivalent to missing key in Python impl. Since we don't serialize missing keys in Python API, we skip undefined tensors when serializing the param state. b) -For c10::nullopt value: in param state, c10::nullopt value in C++ impl is +For std::nullopt value: in param state, std::nullopt value in C++ impl is equivalent to missing key in Python impl. Since we don't serialize missing keys -in Python API, we skip c10::nullopt values when serializing the param state. */ +in Python API, we skip std::nullopt values when serializing the param state. */ /// Serializes an `Optimizer` into an `OutputArchive`. TORCH_API serialize::OutputArchive& operator<<( diff --git a/torch/csrc/api/include/torch/serialize/input-archive.h b/torch/csrc/api/include/torch/serialize/input-archive.h index f77b34aad0bd43..3650cfcfea23f9 100644 --- a/torch/csrc/api/include/torch/serialize/input-archive.h +++ b/torch/csrc/api/include/torch/serialize/input-archive.h @@ -1,10 +1,10 @@ #pragma once #include -#include #include #include #include +#include #include #include @@ -76,27 +76,27 @@ class TORCH_API InputArchive final { /// is not specified, the module is loaded to the original device. void load_from( const std::string& filename, - std::optional device = c10::nullopt); + std::optional device = std::nullopt); /// Loads the `InputArchive` from a serialized representation stored in the /// given `stream`. Storage are remapped using device option. If device /// is not specified, the module is loaded to the original device. void load_from( std::istream& stream, - std::optional device = c10::nullopt); + std::optional device = std::nullopt); // Loads given the specified flat array. void load_from( const char* data, size_t size, - std::optional device = c10::nullopt); + std::optional device = std::nullopt); // Loads given the specified read and size functions. void load_from( const std::function& read_func, const std::function& size_func, - std::optional device = c10::nullopt); + std::optional device = std::nullopt); // Returns the vector of keys in the input archive. std::vector keys(); diff --git a/torch/csrc/api/include/torch/types.h b/torch/csrc/api/include/torch/types.h index 8a23cd122b8d1d..febda7ac6bb852 100644 --- a/torch/csrc/api/include/torch/types.h +++ b/torch/csrc/api/include/torch/types.h @@ -2,7 +2,7 @@ #include -#include +#include #include #include @@ -38,7 +38,7 @@ namespace torch { // the `func()` function defined in `at::` namespace is always hidden. using namespace at; // NOLINT -using c10::nullopt; +using std::nullopt; using std::optional; using Dtype = at::ScalarType; diff --git a/torch/csrc/api/src/jit.cpp b/torch/csrc/api/src/jit.cpp index 16d9d0040a6592..07064dbdc9e786 100644 --- a/torch/csrc/api/src/jit.cpp +++ b/torch/csrc/api/src/jit.cpp @@ -11,7 +11,7 @@ namespace jit { std::shared_ptr compile(const std::string& source) { auto module = std::make_shared(); - module->define(c10::nullopt, source, nativeResolver(), nullptr); + module->define(std::nullopt, source, nativeResolver(), nullptr); return module; } diff --git a/torch/csrc/api/src/nn/modules/activation.cpp b/torch/csrc/api/src/nn/modules/activation.cpp index 56218ad091de5d..518072d0653f12 100644 --- a/torch/csrc/api/src/nn/modules/activation.cpp +++ b/torch/csrc/api/src/nn/modules/activation.cpp @@ -130,7 +130,7 @@ void SoftmaxImpl::pretty_print(std::ostream& stream) const { } Tensor SoftmaxImpl::forward(const Tensor& input) { - return F::detail::softmax(input, options.dim(), c10::nullopt); + return F::detail::softmax(input, options.dim(), std::nullopt); } // ============================================================================ @@ -144,7 +144,7 @@ void SoftminImpl::pretty_print(std::ostream& stream) const { } Tensor SoftminImpl::forward(const Tensor& input) { - return F::detail::softmin(input, options.dim(), c10::nullopt); + return F::detail::softmin(input, options.dim(), std::nullopt); } // ============================================================================ @@ -159,7 +159,7 @@ void LogSoftmaxImpl::pretty_print(std::ostream& stream) const { } Tensor LogSoftmaxImpl::forward(const Tensor& input) { - return F::detail::log_softmax(input, options.dim(), c10::nullopt); + return F::detail::log_softmax(input, options.dim(), std::nullopt); } // ============================================================================ @@ -174,7 +174,7 @@ Tensor Softmax2dImpl::forward(const Tensor& input) { TORCH_CHECK( input.dim() == 4 || input.dim() == 3, "Softmax2d requires a 3D or 4D tensor as input"); - return F::detail::softmax(input, /*dim=*/-3, c10::nullopt); + return F::detail::softmax(input, /*dim=*/-3, std::nullopt); } // ============================================================================ diff --git a/torch/csrc/api/src/nn/modules/conv.cpp b/torch/csrc/api/src/nn/modules/conv.cpp index 197c3cf0725cd0..26e52df637f852 100644 --- a/torch/csrc/api/src/nn/modules/conv.cpp +++ b/torch/csrc/api/src/nn/modules/conv.cpp @@ -176,7 +176,7 @@ std::vector ConvTransposeNdImpl::_output_padding( std::vector ret; std::optional output_size_ = output_size; - if (output_size_ == c10::nullopt) { + if (output_size_ == std::nullopt) { ret = at::IntArrayRef(this->options.output_padding()).vec(); } else { auto k = input.dim() - 2; diff --git a/torch/csrc/api/src/nn/modules/embedding.cpp b/torch/csrc/api/src/nn/modules/embedding.cpp index 553a93875e1784..4c6683d1f36b58 100644 --- a/torch/csrc/api/src/nn/modules/embedding.cpp +++ b/torch/csrc/api/src/nn/modules/embedding.cpp @@ -20,7 +20,7 @@ EmbeddingImpl::EmbeddingImpl(EmbeddingOptions options_) } void EmbeddingImpl::reset() { - if (options.padding_idx() != c10::nullopt) { + if (options.padding_idx() != std::nullopt) { if (*options.padding_idx() > 0) { TORCH_CHECK( *options.padding_idx() < options.num_embeddings(), @@ -50,7 +50,7 @@ void EmbeddingImpl::reset() { void EmbeddingImpl::reset_parameters() { torch::nn::init::normal_(weight); - if (options.padding_idx() != c10::nullopt) { + if (options.padding_idx() != std::nullopt) { torch::NoGradGuard no_grad; weight[*options.padding_idx()].fill_(0); } @@ -59,10 +59,10 @@ void EmbeddingImpl::reset_parameters() { void EmbeddingImpl::pretty_print(std::ostream& stream) const { stream << "torch::nn::Embedding(num_embeddings=" << options.num_embeddings() << ", embedding_dim=" << options.embedding_dim(); - if (options.padding_idx() != c10::nullopt) { + if (options.padding_idx() != std::nullopt) { stream << ", padding_idx=" << *options.padding_idx(); } - if (options.max_norm() != c10::nullopt) { + if (options.max_norm() != std::nullopt) { stream << ", max_norm=" << *options.max_norm(); } if (options.norm_type() != 2) { @@ -154,7 +154,7 @@ void EmbeddingBagImpl::pretty_print(std::ostream& stream) const { stream << "torch::nn::EmbeddingBag(num_embeddings=" << options.num_embeddings() << ", embedding_dim=" << options.embedding_dim(); - if (options.max_norm() != c10::nullopt) { + if (options.max_norm() != std::nullopt) { stream << ", max_norm=" << *options.max_norm(); } if (options.norm_type() != 2) { diff --git a/torch/csrc/api/src/nn/modules/pooling.cpp b/torch/csrc/api/src/nn/modules/pooling.cpp index 0b11b914dcc1c7..a02d8cd712aa09 100644 --- a/torch/csrc/api/src/nn/modules/pooling.cpp +++ b/torch/csrc/api/src/nn/modules/pooling.cpp @@ -281,19 +281,19 @@ FractionalMaxPool2dImpl::FractionalMaxPool2dImpl( void FractionalMaxPool2dImpl::reset() { _random_samples = register_buffer("_random_samples", options._random_samples()); - if (options.output_size() == c10::nullopt && - options.output_ratio() == c10::nullopt) { + if (options.output_size() == std::nullopt && + options.output_ratio() == std::nullopt) { TORCH_CHECK( false, "FractionalMaxPool2d requires specifying either ", "an output size, or a pooling ratio"); } - if (options.output_size() != c10::nullopt && - options.output_ratio() != c10::nullopt) { + if (options.output_size() != std::nullopt && + options.output_ratio() != std::nullopt) { TORCH_CHECK( false, "only one of output_size and output_ratio may be specified"); } - if (options.output_ratio() != c10::nullopt) { + if (options.output_ratio() != std::nullopt) { at::ArrayRef output_ratio = at::ArrayRef(options.output_ratio().value()); if (!(0 < output_ratio[0] && output_ratio[0] < 1 && 0 < output_ratio[1] && @@ -340,19 +340,19 @@ FractionalMaxPool3dImpl::FractionalMaxPool3dImpl( void FractionalMaxPool3dImpl::reset() { _random_samples = register_buffer("_random_samples", options._random_samples()); - if (options.output_size() == c10::nullopt && - options.output_ratio() == c10::nullopt) { + if (options.output_size() == std::nullopt && + options.output_ratio() == std::nullopt) { TORCH_CHECK( false, "FractionalMaxPool3d requires specifying either ", "an output size, or a pooling ratio"); } - if (options.output_size() != c10::nullopt && - options.output_ratio() != c10::nullopt) { + if (options.output_size() != std::nullopt && + options.output_ratio() != std::nullopt) { TORCH_CHECK( false, "only one of output_size and output_ratio may be specified"); } - if (options.output_ratio() != c10::nullopt) { + if (options.output_ratio() != std::nullopt) { at::ArrayRef output_ratio = at::ArrayRef(options.output_ratio().value()); if (!(0 < output_ratio[0] && output_ratio[0] < 1 && 0 < output_ratio[1] && diff --git a/torch/csrc/api/src/nn/modules/upsampling.cpp b/torch/csrc/api/src/nn/modules/upsampling.cpp index 8e7bb2fe33cd84..378d5aadb92031 100644 --- a/torch/csrc/api/src/nn/modules/upsampling.cpp +++ b/torch/csrc/api/src/nn/modules/upsampling.cpp @@ -15,7 +15,7 @@ void UpsampleImpl::reset() {} void UpsampleImpl::pretty_print(std::ostream& stream) const { stream << "torch::nn::Upsample("; - if (options.scale_factor() != c10::nullopt) { + if (options.scale_factor() != std::nullopt) { stream << "scale_factor=" << at::ArrayRef(*options.scale_factor()); } else { stream << "size=" << at::ArrayRef(*options.size()); @@ -43,7 +43,7 @@ Tensor UpsampleImpl::forward(const Tensor& input) { options.scale_factor(), mode, options.align_corners(), - c10::nullopt, + std::nullopt, false); } diff --git a/torch/csrc/api/src/optim/lbfgs.cpp b/torch/csrc/api/src/optim/lbfgs.cpp index 10739be6238697..dbf17f718614a0 100644 --- a/torch/csrc/api/src/optim/lbfgs.cpp +++ b/torch/csrc/api/src/optim/lbfgs.cpp @@ -68,7 +68,7 @@ bool if_container_equal(T lhs, T rhs) { bool operator==(const LBFGSParamState& lhs, const LBFGSParamState& rhs) { auto isNull = [](const std::optional>& val) { - return val == c10::nullopt; + return val == std::nullopt; }; return (lhs.func_evals() == rhs.func_evals()) && (lhs.n_iter() == rhs.n_iter()) && (lhs.t() == rhs.t()) && @@ -97,7 +97,7 @@ void LBFGSParamState::serialize( _TORCH_OPTIM_SERIALIZE_TORCH_ARG_DEQUE(old_stps); _TORCH_OPTIM_SERIALIZE_TORCH_ARG_DEQUE(ro); // Python version only serializes state vars if explicitly defined - if (al() != c10::nullopt) { + if (al() != std::nullopt) { _TORCH_OPTIM_SERIALIZE_TORCH_ARG(al); } } @@ -131,7 +131,7 @@ Tensor LBFGS::_gather_flat_grad() { } int64_t LBFGS::_numel() { - if (_numel_cache == c10::nullopt) { + if (_numel_cache == std::nullopt) { auto res = 0; for (const auto& p : param_groups_.at(0).params()) { res += p.numel(); @@ -194,12 +194,12 @@ static double _cubic_interpolate( double x2, double f2, double g2, - std::optional> bounds = c10::nullopt) { + std::optional> bounds = std::nullopt) { // ported from https://github.com/torch/optim/blob/master/polyinterp.lua // Compute bounds of interpolation area // NOLINTNEXTLINE(cppcoreguidelines-init-variables) double xmin_bound, xmax_bound; - if (bounds != c10::nullopt) { + if (bounds != std::nullopt) { std::tie(xmin_bound, xmax_bound) = *bounds; } else { std::tie(xmin_bound, xmax_bound) = @@ -509,7 +509,7 @@ Tensor LBFGS::step(LossClosure closure) { // multiplied by the gradient int64_t num_old = static_cast(old_dirs.size()); - if (state.al() == c10::nullopt) { + if (state.al() == std::nullopt) { state.al(std::vector(history_size)); } auto& al = state.al(); @@ -557,7 +557,7 @@ Tensor LBFGS::step(LossClosure closure) { // optional line search: user function auto ls_func_evals = 0; - if (line_search_fn != c10::nullopt) { + if (line_search_fn != std::nullopt) { TORCH_CHECK( *line_search_fn == "strong_wolfe", "only 'strong_wolfe' is supported"); @@ -627,7 +627,7 @@ void LBFGS::load(serialize::InputArchive& archive) { TORCH_WARN( "Your serialized LBFGS optimizer is still using the old serialization format. " "The func_evals and n_iter value in state will be set to 0, ro will be set to an empty deque " - "and al will be set to c10::nullopt because the old LBFGS optimizer didn't save these values." + "and al will be set to std::nullopt because the old LBFGS optimizer didn't save these values." "You should re-save your LBFGS optimizer to use the new serialization format."); Tensor d, t, H_diag, prev_flat_grad, prev_loss; std::deque old_dirs, old_stps; diff --git a/torch/csrc/api/src/serialize/input-archive.cpp b/torch/csrc/api/src/serialize/input-archive.cpp index 852f4eab1b52b1..8644b6193e0be8 100644 --- a/torch/csrc/api/src/serialize/input-archive.cpp +++ b/torch/csrc/api/src/serialize/input-archive.cpp @@ -93,20 +93,20 @@ void InputArchive::read(const std::string& key, InputArchive& archive) { void InputArchive::load_from( const std::string& filename, - std::optional device /*= c10::nullopt*/) { + std::optional device /*= std::nullopt*/) { module_ = torch::jit::load(filename, std::move(device)); } void InputArchive::load_from( std::istream& stream, - std::optional device /*= c10::nullopt*/) { + std::optional device /*= std::nullopt*/) { module_ = torch::jit::load(stream, std::move(device)); } void InputArchive::load_from( const char* data, size_t size, - std::optional device /*= c10::nullopt*/) { + std::optional device /*= std::nullopt*/) { using caffe2::serialize::ReadAdapterInterface; class OurAdapter : public ReadAdapterInterface { public: @@ -136,7 +136,7 @@ void InputArchive::load_from( void InputArchive::load_from( const std::function& read_func, const std::function& size_func, - std::optional device /*= c10::nullopt*/) { + std::optional device /*= std::nullopt*/) { using caffe2::serialize::ReadAdapterInterface; class OurAdapter : public ReadAdapterInterface { public: diff --git a/torch/csrc/autograd/FunctionsManual.cpp b/torch/csrc/autograd/FunctionsManual.cpp index 9d897c667c906f..7ca1a172096817 100644 --- a/torch/csrc/autograd/FunctionsManual.cpp +++ b/torch/csrc/autograd/FunctionsManual.cpp @@ -630,7 +630,7 @@ Tensor div_tensor_self_backward( T other, ScalarType self_st) { return div_tensor_self_backward( - grad, std::move(other), self_st, c10::nullopt); + grad, std::move(other), self_st, std::nullopt); } template Tensor div_tensor_self_backward(const Tensor&, Tensor, ScalarType); template Tensor div_tensor_self_backward(const Tensor&, Scalar, ScalarType); @@ -652,7 +652,7 @@ Tensor div_tensor_other_backward( const Tensor& grad, const Tensor& self, const Tensor& other) { - return div_tensor_other_backward(grad, self, other, c10::nullopt); + return div_tensor_other_backward(grad, self, other, std::nullopt); } Tensor permute_backwards(const Tensor& grad, IntArrayRef fwd_dims) { @@ -1282,12 +1282,12 @@ Tensor convolution_jvp( at::SymIntArrayRef output_padding, const c10::SymInt& groups) { auto bias_t_opt = - bias_t.defined() ? std::optional(bias_t) : c10::nullopt; + bias_t.defined() ? std::optional(bias_t) : std::nullopt; return ( at::convolution_symint( input_t, weight_p, - c10::nullopt, + std::nullopt, stride, padding, dilation, @@ -1324,12 +1324,12 @@ Tensor _convolution_jvp( bool cudnn_enabled, bool allow_tf32) { auto bias_t_opt = - bias_t.defined() ? std::optional(bias_t) : c10::nullopt; + bias_t.defined() ? std::optional(bias_t) : std::nullopt; return ( at::_convolution_symint( input_t, weight_p, - c10::nullopt, + std::nullopt, stride, padding, dilation, @@ -6193,7 +6193,7 @@ Tensor batch_norm_jvp( std::optional result_p = weight_p.defined() ? std::optional((input_p - mean_p) * invstd_p) - : c10::nullopt; + : std::nullopt; return _affine_jvp( result_p, result_t, @@ -6232,7 +6232,7 @@ Tensor layer_norm_jvp( std::optional result_p = weight_p.defined() ? std::optional((input_p - mean_p) * invstd_p) - : c10::nullopt; + : std::nullopt; return _affine_jvp( result_p, result_t, @@ -6273,7 +6273,7 @@ Tensor group_norm_jvp( /*eps=*/0) .view(input_shape); - std::optional result_p = c10::nullopt; + std::optional result_p = std::nullopt; if (weight_p.defined()) { std::vector view_size(input_t_reshaped.dim(), 1); view_size[1] = input_t_reshaped.size(1); @@ -6706,7 +6706,7 @@ std::tuple _cudnn_convolution_backward( grad_output, self, weight, - c10::nullopt, + std::nullopt, stride, padding, dilation, @@ -6956,7 +6956,7 @@ Tensor to_sparse_backward( if (self_layout == c10::kStrided) { return grad.to_dense(); } else { - OptionalIntArrayRef blocksize = c10::nullopt; + OptionalIntArrayRef blocksize = std::nullopt; if (self_blocksize.has_value()) { blocksize = c10::asIntArrayRefSlowOpt(*self_blocksize); } diff --git a/torch/csrc/autograd/FunctionsManual.h b/torch/csrc/autograd/FunctionsManual.h index dedff70be1ba34..3c461dd88ee56a 100644 --- a/torch/csrc/autograd/FunctionsManual.h +++ b/torch/csrc/autograd/FunctionsManual.h @@ -39,7 +39,7 @@ TORCH_API inline std::optional wrap_opt_if( const Tensor& t, const bool cond) { using OptTensor = std::optional; - return cond ? OptTensor(t) : static_cast(c10::nullopt); + return cond ? OptTensor(t) : static_cast(std::nullopt); } TORCH_API Tensor diff --git a/torch/csrc/autograd/TraceTypeManual.cpp b/torch/csrc/autograd/TraceTypeManual.cpp index 46e4014d8dd139..1473058a3a53df 100644 --- a/torch/csrc/autograd/TraceTypeManual.cpp +++ b/torch/csrc/autograd/TraceTypeManual.cpp @@ -1,11 +1,11 @@ #include #include #include -#include #include #include #include #include +#include using namespace at; diff --git a/torch/csrc/autograd/VariableTypeManual.cpp b/torch/csrc/autograd/VariableTypeManual.cpp index 20f66694677e82..92096dca9a6989 100644 --- a/torch/csrc/autograd/VariableTypeManual.cpp +++ b/torch/csrc/autograd/VariableTypeManual.cpp @@ -2,7 +2,6 @@ #include #include #include -#include #include #include #include @@ -11,6 +10,7 @@ #include #include #include +#include #include diff --git a/torch/csrc/autograd/VariableTypeUtils.h b/torch/csrc/autograd/VariableTypeUtils.h index d5fe8a70dae177..3b598898f80c4a 100644 --- a/torch/csrc/autograd/VariableTypeUtils.h +++ b/torch/csrc/autograd/VariableTypeUtils.h @@ -217,7 +217,7 @@ inline at::Tensor as_view( tensor, diff_view_meta->get_backward_view().chain( base, tensor, std::move(view_func), std::move(rev_view_func)), - c10::nullopt, + std::nullopt, /*shared_view_info*/ true, creation_meta, allow_tensor_metadata_change); @@ -225,7 +225,7 @@ inline at::Tensor as_view( return make_variable_differentiable_view( tensor, ViewInfo(base, std::move(view_func), std::move(rev_view_func)), - c10::nullopt, + std::nullopt, /*shared_view_info*/ true, creation_meta, allow_tensor_metadata_change); diff --git a/torch/csrc/autograd/autograd.h b/torch/csrc/autograd/autograd.h index 94ee179225a4ca..bd5d4a462102b2 100644 --- a/torch/csrc/autograd/autograd.h +++ b/torch/csrc/autograd/autograd.h @@ -47,7 +47,7 @@ namespace torch::autograd { TORCH_API void backward( const variable_list& tensors, const variable_list& grad_tensors = {}, - std::optional retain_graph = c10::nullopt, + std::optional retain_graph = std::nullopt, bool create_graph = false, const variable_list& inputs = {}); @@ -81,7 +81,7 @@ TORCH_API variable_list grad( const variable_list& outputs, const variable_list& inputs, const variable_list& grad_outputs = {}, - std::optional retain_graph = c10::nullopt, + std::optional retain_graph = std::nullopt, bool create_graph = false, bool allow_unused = false); diff --git a/torch/csrc/autograd/autograd_not_implemented_fallback.cpp b/torch/csrc/autograd/autograd_not_implemented_fallback.cpp index eff2a27c105f36..f922c3fc763260 100644 --- a/torch/csrc/autograd/autograd_not_implemented_fallback.cpp +++ b/torch/csrc/autograd/autograd_not_implemented_fallback.cpp @@ -345,7 +345,7 @@ static void autogradNotImplementedFallbackImpl( [&](size_t idx, size_t _, const at::Tensor& t) { storage_saved.push_back( t.has_storage() ? std::optional(t.storage()) - : c10::nullopt); + : std::nullopt); impl_saved.push_back(t.getIntrusivePtr()); }, &stack_args_copy, diff --git a/torch/csrc/autograd/engine.cpp b/torch/csrc/autograd/engine.cpp index cfacbf0e3be7fe..cb9f5caca0eef1 100644 --- a/torch/csrc/autograd/engine.cpp +++ b/torch/csrc/autograd/engine.cpp @@ -735,10 +735,10 @@ void GraphTask::exec_post_processing() { for (const auto& leaf_stream : leaf_streams) { // stash_current_cuda/privateuse1_streams() stashed streams for all device // IDs that already had a CUDA/privateuse1 context before the GraphTask - // executed. For inactive devices, it stashed a c10::nullopt. I don't + // executed. For inactive devices, it stashed a std::nullopt. I don't // expect GraphTask's backward pass ran leaf nodes on any new devices, so // the stashed streams should be enough. If leaf_stream.device_index() - // happens to be for a new device, operator* on the c10::nullopt should + // happens to be for a new device, operator* on the std::nullopt should // throw an error. const auto caller_current_stream = // NOLINTNEXTLINE(bugprone-unchecked-optional-access) @@ -1554,7 +1554,7 @@ void GraphTask::stash_current_streams() { idx)) { caller_current_streams_[idx] = guard.getStream({accelerator, idx}); } else { - caller_current_streams_[idx] = c10::nullopt; + caller_current_streams_[idx] = std::nullopt; } } } diff --git a/torch/csrc/autograd/function.h b/torch/csrc/autograd/function.h index c8c3538a061f17..4f7f53c90ec1ed 100644 --- a/torch/csrc/autograd/function.h +++ b/torch/csrc/autograd/function.h @@ -242,14 +242,14 @@ struct TORCH_API Node : std::enable_shared_from_this { std::optional stream() { auto opt_device_type = at::getAccelerator(); if (!opt_device_type.has_value()) { - return c10::nullopt; + return std::nullopt; } for (const auto& metadata : input_metadata_) { if (metadata.device().type() == opt_device_type.value()) return metadata.stream(); } - return c10::nullopt; + return std::nullopt; } void clear_input_metadata() { diff --git a/torch/csrc/autograd/functions/accumulate_grad.h b/torch/csrc/autograd/functions/accumulate_grad.h index 2efde9d5f2f2e6..99597a73762ff7 100644 --- a/torch/csrc/autograd/functions/accumulate_grad.h +++ b/torch/csrc/autograd/functions/accumulate_grad.h @@ -224,7 +224,7 @@ struct TORCH_API AccumulateGrad : public Node { // variable_grad += new_grad; // } else { // result = at::empty_strided(variable.sizes(), variable.strides(), - // variable.options().memory_format(c10::nullopt)); + // variable.options().memory_format(std::nullopt)); // update_grad(at::native::add_out(result, variable_grad, // new_grad, 1.0); // } diff --git a/torch/csrc/autograd/functions/comm.cpp b/torch/csrc/autograd/functions/comm.cpp index 1aed18cb79a5ee..5093f51e7eff88 100644 --- a/torch/csrc/autograd/functions/comm.cpp +++ b/torch/csrc/autograd/functions/comm.cpp @@ -105,7 +105,7 @@ variable_list Gather::apply(variable_list&& inputs) { std::move(source_devices), std::move(input_sizes), dim_, - /*streams=*/c10::nullopt, + /*streams=*/std::nullopt, /*unsqueeze_scalars=*/unsqueeze_scalars); grad_fn->set_next_edges(collect_next_edges(inputs)); } diff --git a/torch/csrc/autograd/functions/comm.h b/torch/csrc/autograd/functions/comm.h index 0924cd030fcef8..2730827a1eb3c4 100644 --- a/torch/csrc/autograd/functions/comm.h +++ b/torch/csrc/autograd/functions/comm.h @@ -6,7 +6,7 @@ #include #include -#include +#include #include #include @@ -17,10 +17,10 @@ namespace autograd { struct TORCH_CUDA_CU_API Scatter : public Node { explicit Scatter( std::vector devices, - std::optional> chunk_sizes = c10::nullopt, + std::optional> chunk_sizes = std::nullopt, int64_t dim = 0, std::optional>> streams = - c10::nullopt, + std::nullopt, bool unsqueeze_scalars = false); ~Scatter() override; diff --git a/torch/csrc/autograd/init.cpp b/torch/csrc/autograd/init.cpp index e6a907ee2f0a40..b22199ee1ad696 100644 --- a/torch/csrc/autograd/init.cpp +++ b/torch/csrc/autograd/init.cpp @@ -1084,7 +1084,7 @@ static PyObject* push_on_torch_dispatch_stack( using c10::impl::TorchDispatchModeKey; // When we push a mode onto the mode stack, we need to // check if it's an "infra" mode, by checking its _mode_key attribute. - std::optional mode_key = c10::nullopt; + std::optional mode_key = std::nullopt; py::object maybe_mode_key_obj = PyObject_FastGetAttrString(arg, "_mode_key"); if (maybe_mode_key_obj) { @@ -1108,7 +1108,7 @@ static PyObject* pop_torch_dispatch_stack( PyObject* _unused, PyObject* maybe_mode_key) { HANDLE_TH_ERRORS - std::optional mode_key = c10::nullopt; + std::optional mode_key = std::nullopt; PyObject* r = nullptr; if (maybe_mode_key != Py_None) { mode_key = py::cast(maybe_mode_key); @@ -1174,7 +1174,7 @@ static PyObject* get_dispatch_mode(PyObject* _unused, PyObject* arg) { auto mode_key = py::cast(arg); auto maybe_mode = c10::impl::TorchDispatchModeTLS::get_mode(mode_key); - if (maybe_mode == c10::nullopt) { + if (maybe_mode == std::nullopt) { Py_RETURN_NONE; } // NOLINTNEXTLINE(bugprone-unchecked-optional-access) @@ -1190,7 +1190,7 @@ static PyObject* unset_dispatch_mode(PyObject* _unused, PyObject* arg) { auto mode_key = py::cast(arg); const auto maybe_mode = c10::impl::TorchDispatchModeTLS::unset_mode(mode_key); - if (maybe_mode == c10::nullopt) { + if (maybe_mode == std::nullopt) { Py_RETURN_NONE; } // NOLINTNEXTLINE(bugprone-unchecked-optional-access) diff --git a/torch/csrc/autograd/input_buffer.cpp b/torch/csrc/autograd/input_buffer.cpp index 6c12bbadc5d2d5..f2b08e364318a4 100644 --- a/torch/csrc/autograd/input_buffer.cpp +++ b/torch/csrc/autograd/input_buffer.cpp @@ -11,7 +11,7 @@ #include #include #include -#include +#include #include #include @@ -159,7 +159,7 @@ void InputBuffer::add( // Accumulation happens on the var device's default stream. TORCH_INTERNAL_ASSERT(device_of(var)); - std::optional opt_accumulate_stream = c10::nullopt; + std::optional opt_accumulate_stream = std::nullopt; const auto device_type = device_of(var).value().type(); // NOLINTNEXTLINE(bugprone-unchecked-optional-access) if (device_of(var)->is_cuda() || device_of(var)->is_privateuseone()) { @@ -179,7 +179,7 @@ void InputBuffer::add( record_stream_any_impl(var, *opt_accumulate_stream); } } else { - std::optional opt_sync_stream = c10::nullopt; + std::optional opt_sync_stream = std::nullopt; const auto guard = c10::impl::VirtualGuardImpl{device_type}; if (on_consumer && !on_producer) { // (3a) diff --git a/torch/csrc/autograd/input_buffer.h b/torch/csrc/autograd/input_buffer.h index 7e471ef528bb03..e445ef897fc1aa 100644 --- a/torch/csrc/autograd/input_buffer.h +++ b/torch/csrc/autograd/input_buffer.h @@ -9,8 +9,8 @@ #include #include -#include #include +#include namespace torch::autograd { diff --git a/torch/csrc/autograd/profiler_legacy.cpp b/torch/csrc/autograd/profiler_legacy.cpp index b9387479667e86..53a24eaa150dbe 100644 --- a/torch/csrc/autograd/profiler_legacy.cpp +++ b/torch/csrc/autograd/profiler_legacy.cpp @@ -122,7 +122,7 @@ using torch::profiler::impl::ProfilerStateBase; struct ProfilerLegacyThreadLocalState : public ProfilerStateBase { explicit ProfilerLegacyThreadLocalState( const torch::profiler::impl::ProfilerConfig& config) - : ProfilerStateBase(config), remoteProfiledEvents_{c10::nullopt} {} + : ProfilerStateBase(config), remoteProfiledEvents_{std::nullopt} {} ~ProfilerLegacyThreadLocalState() override = default; static ProfilerLegacyThreadLocalState* getTLS() { diff --git a/torch/csrc/autograd/profiler_legacy.h b/torch/csrc/autograd/profiler_legacy.h index 9bd88b0b3dc51e..59198129b2b278 100644 --- a/torch/csrc/autograd/profiler_legacy.h +++ b/torch/csrc/autograd/profiler_legacy.h @@ -336,7 +336,7 @@ TORCH_API void enableProfilerLegacy( using thread_event_lists = std::vector>; TORCH_API thread_event_lists disableProfilerLegacy( std::optional profilerDisableOptions = - c10::nullopt); + std::nullopt); // adds profiledEvents to the current thread local recorded events. Each event // will be marked with node ID given by fromNodeId. @@ -377,9 +377,9 @@ struct TORCH_API TLSLegacyProfilerGuard { explicit TLSLegacyProfilerGuard( const torch::profiler::impl::ProfilerConfig& cfg, std::optional> - resultCallback = c10::nullopt, + resultCallback = std::nullopt, std::optional profilerDisableOptions = - c10::nullopt) + std::nullopt) : cb_(std::move(resultCallback)), profilerDisableOptions_(profilerDisableOptions) { enableProfilerLegacy(cfg); diff --git a/torch/csrc/autograd/profiler_python.cpp b/torch/csrc/autograd/profiler_python.cpp index 5fcc7b86a2fab8..e930faa1fdebe4 100644 --- a/torch/csrc/autograd/profiler_python.cpp +++ b/torch/csrc/autograd/profiler_python.cpp @@ -18,7 +18,6 @@ #include #include #include -#include #include #include #include @@ -29,6 +28,7 @@ #include #include #include +#include namespace py = pybind11; @@ -349,7 +349,7 @@ TensorMetadata toTensorMetadata(PyObject* self) { std::optional ValueCache::recordIfTensor(py::handle p) { return THPVariable_CheckExact(p.ptr()) ? std::optional{toTensorMetadata(p.ptr())} - : c10::nullopt; + : std::nullopt; } std::vector> ValueCache::unpackTensorMap( @@ -379,7 +379,7 @@ void ValueCache::store(const PyCallKey& key, no_ephemeral_t) { template <> ExtraFields::args_t ValueCache::load( const PyCallKey& key) const { - return {std::get(state_).at(key), c10::nullopt}; + return {std::get(state_).at(key), std::nullopt}; } template <> @@ -419,7 +419,7 @@ ExtraFields::args_t ValueCache::load( return { /*frame_state_=*/std::get(state_).at(*cache.location_), /*module_info_=*/std::move(info), - /*optimizer_info_=*/c10::nullopt}; + /*optimizer_info_=*/std::nullopt}; } template <> @@ -465,7 +465,7 @@ ExtraFields::args_t ValueCache::load< return { // NOLINTNEXTLINE(bugprone-unchecked-optional-access) /*frame_state_=*/std::get(state_).at(*cache.location_), - /*module_info_=*/c10::nullopt, + /*module_info_=*/std::nullopt, /*optimizer_info_=*/std::move(info)}; } diff --git a/torch/csrc/autograd/python_function.cpp b/torch/csrc/autograd/python_function.cpp index 0227229d1f7fb9..a5ba07b2cdb53a 100644 --- a/torch/csrc/autograd/python_function.cpp +++ b/torch/csrc/autograd/python_function.cpp @@ -778,7 +778,7 @@ static void _get_tensors_to_save( for (const auto i : c10::irange(num_saved)) { PyObject* obj = PyTuple_GET_ITEM(self->to_save, i); if (obj == Py_None) { - tensors_to_save.emplace_back(c10::nullopt); + tensors_to_save.emplace_back(std::nullopt); continue; } else if (THPVariable_Check(obj)) { const auto& tensor = THPVariable_Unpack(obj); diff --git a/torch/csrc/autograd/python_function.h b/torch/csrc/autograd/python_function.h index c2744f365476f0..0bf3c8bbab70b7 100644 --- a/torch/csrc/autograd/python_function.h +++ b/torch/csrc/autograd/python_function.h @@ -10,7 +10,7 @@ #include #include -#include +#include #include #include diff --git a/torch/csrc/autograd/python_variable.cpp b/torch/csrc/autograd/python_variable.cpp index 65f4b0efd3c188..94596c32a705e3 100644 --- a/torch/csrc/autograd/python_variable.cpp +++ b/torch/csrc/autograd/python_variable.cpp @@ -347,7 +347,7 @@ bool isResurrectable(THPVariable* self) { // Check if this is hermetic. If it is, no resurrection. if (tensor.unsafeGetTensorImpl()->pyobj_slot()->check_pyobj( getPyInterpreter(), /*ignore_hermetic_tls=*/false) != - c10::make_optional((PyObject*)self)) { + std::make_optional((PyObject*)self)) { return false; } return true; @@ -455,7 +455,7 @@ static int THPVariable_clear(THPVariable* self) { if (!self->cdata.unsafeIsBorrowed() && tensor.unsafeGetTensorImpl()->pyobj_slot()->check_pyobj( getPyInterpreter(), /*ignore_hermetic_tls=*/false) == - c10::make_optional((PyObject*)self)) { + std::make_optional((PyObject*)self)) { // TODO: empirically, on OS X this assert appears to be untrue // In test_py_tensors_multi_async_call - ProcessGroupRpcTestWithSpawn // distributed/rpc/test_process_group_agent.py @@ -587,14 +587,14 @@ static PyObject* view_func_impl( auto& view_func = view_info.view_fn(); // Determine new SymInt / tensor state as needed. - std::optional> new_symints = c10::nullopt; + std::optional> new_symints = std::nullopt; if (symint_visitor_fn != Py_None) { new_symints = map_py_func( py::cast(symint_visitor_fn), view_func.get_symints()); } - std::optional> new_tensors = c10::nullopt; + std::optional> new_tensors = std::nullopt; if (tensor_visitor_fn != Py_None) { new_tensors = map_py_func( py::cast(tensor_visitor_fn), diff --git a/torch/csrc/autograd/python_variable_indexing.cpp b/torch/csrc/autograd/python_variable_indexing.cpp index fdcafd6cd70910..e9b40b0dc8f75c 100644 --- a/torch/csrc/autograd/python_variable_indexing.cpp +++ b/torch/csrc/autograd/python_variable_indexing.cpp @@ -100,7 +100,7 @@ static inline Variable sequenceToVariable( c10::TensorOptions options, PyObject* seq) { return torch::utils::indexing_tensor_from_data( - options, kLong, c10::nullopt, seq); + options, kLong, std::nullopt, seq); } inline Variable valueToTensor( @@ -201,7 +201,7 @@ static inline Variable applySlicing( // as null may need to be changed after we reach a better solution for // nested tensor size std::optional result_sizes = result.is_nested() - ? std::optional(c10::nullopt) + ? std::optional(std::nullopt) : std::optional(result.sym_sizes()); result = at::indexing::handleDimInMultiDimIndexing( /*prev_dim_result=*/result, diff --git a/torch/csrc/autograd/record_function_ops.h b/torch/csrc/autograd/record_function_ops.h index a145523c1bf8a5..a84d47c5b4829f 100644 --- a/torch/csrc/autograd/record_function_ops.h +++ b/torch/csrc/autograd/record_function_ops.h @@ -1,7 +1,7 @@ #pragma once #include -#include #include +#include namespace torch::autograd::profiler { @@ -17,7 +17,7 @@ struct PythonRecordFunction : public torch::CustomClassHolder { // callbacks. TORCH_API c10::intrusive_ptr record_function_enter_new( const std::string& name, - const std::optional& args = c10::nullopt); + const std::optional& args = std::nullopt); // Schedules RecordFunction's end callbacks to be run on completion of a future. TORCH_API c10::intrusive_ptr _call_end_callbacks_on_fut_new( diff --git a/torch/csrc/autograd/utils/grad_layout_contract.h b/torch/csrc/autograd/utils/grad_layout_contract.h index 1dad10663dd70b..7189e02047251d 100644 --- a/torch/csrc/autograd/utils/grad_layout_contract.h +++ b/torch/csrc/autograd/utils/grad_layout_contract.h @@ -67,7 +67,7 @@ inline at::Tensor clone_obey_contract( .new_empty_strided_symint( variable.sym_sizes(), variable.sym_strides(), - variable.options().memory_format(c10::nullopt)) + variable.options().memory_format(std::nullopt)) .copy_(new_grad)); } else { // (2) diff --git a/torch/csrc/autograd/utils/python_arg_parsing.h b/torch/csrc/autograd/utils/python_arg_parsing.h index 326221e44d147a..e3fd671fb57cf9 100644 --- a/torch/csrc/autograd/utils/python_arg_parsing.h +++ b/torch/csrc/autograd/utils/python_arg_parsing.h @@ -31,7 +31,7 @@ parse_to_conversion(PythonArgs& r, bool allow_copy) { if (!allow_copy && !r.isNone(2)) throw std::runtime_error(".to() does not accept copy argument"); return std::make_tuple( - c10::nullopt, + std::nullopt, r.scalartype(0), r.toBool(1), r.toBool(2), diff --git a/torch/csrc/autograd/variable.h b/torch/csrc/autograd/variable.h index d60f37085f3808..2ce91146dc8d06 100644 --- a/torch/csrc/autograd/variable.h +++ b/torch/csrc/autograd/variable.h @@ -351,8 +351,8 @@ struct TORCH_API ViewFunc { /// Returns a clone of this ViewFunc, optionally with the specified saved /// state. virtual std::unique_ptr clone_and_set( - std::optional> = c10::nullopt, - std::optional> = c10::nullopt) const = 0; + std::optional> = std::nullopt, + std::optional> = std::nullopt) const = 0; protected: /// Sets the values of any SymInts in the saved state. The input vector size @@ -382,8 +382,8 @@ struct ChainedViewFunc : public ViewFunc { } virtual at::Tensor operator()(const at::Tensor&) const override; virtual std::unique_ptr clone_and_set( - std::optional> = c10::nullopt, - std::optional> = c10::nullopt) const override; + std::optional> = std::nullopt, + std::optional> = std::nullopt) const override; private: std::unique_ptr first; @@ -398,8 +398,8 @@ struct ErroringViewFunc : public ViewFunc { TORCH_CHECK(false, error_msg); } virtual std::unique_ptr clone_and_set( - std::optional> = c10::nullopt, - std::optional> = c10::nullopt) const override { + std::optional> = std::nullopt, + std::optional> = std::nullopt) const override { return std::make_unique(error_msg); } diff --git a/torch/csrc/cuda/comm.cpp b/torch/csrc/cuda/comm.cpp index d8f968eae5f5cb..52331909fe1dc2 100644 --- a/torch/csrc/cuda/comm.cpp +++ b/torch/csrc/cuda/comm.cpp @@ -11,9 +11,9 @@ #include #include #include -#include #include #include +#include #include #include diff --git a/torch/csrc/cuda/comm.h b/torch/csrc/cuda/comm.h index 92009a1c40ada5..860629bcf2e9a3 100644 --- a/torch/csrc/cuda/comm.h +++ b/torch/csrc/cuda/comm.h @@ -3,8 +3,8 @@ #include #include #include -#include #include +#include #include #include @@ -29,15 +29,15 @@ TORCH_CUDA_CU_API std::vector& scatter_out( std::vector& out_tensors, int64_t dim = 0, const std::optional>>& - streams = c10::nullopt); + streams = std::nullopt); TORCH_CUDA_CU_API std::vector scatter( const at::Tensor& tensor, at::IntArrayRef devices, - const std::optional>& chunk_sizes = c10::nullopt, + const std::optional>& chunk_sizes = std::nullopt, int64_t dim = 0, const std::optional>>& - streams = c10::nullopt); + streams = std::nullopt); TORCH_CUDA_CU_API at::Tensor& gather_out( at::TensorList tensors, diff --git a/torch/csrc/cuda/memory_snapshot.h b/torch/csrc/cuda/memory_snapshot.h index eb22767a78f905..fe5699af416012 100644 --- a/torch/csrc/cuda/memory_snapshot.h +++ b/torch/csrc/cuda/memory_snapshot.h @@ -1,8 +1,8 @@ #pragma once -#include #include #include +#include #include namespace torch::cuda { diff --git a/torch/csrc/cuda/nccl.h b/torch/csrc/cuda/nccl.h index 37d1be15cbd701..6561ccb6e76c1a 100644 --- a/torch/csrc/cuda/nccl.h +++ b/torch/csrc/cuda/nccl.h @@ -2,9 +2,9 @@ #include #include -#include #include +#include #include // NCCL BFloat16 is enabled only for CUDA 11+ and NCCL versions 2.10+, or for diff --git a/torch/csrc/cuda/python_nccl.cpp b/torch/csrc/cuda/python_nccl.cpp index 5060f9289a9e14..f62311efbd9361 100644 --- a/torch/csrc/cuda/python_nccl.cpp +++ b/torch/csrc/cuda/python_nccl.cpp @@ -60,7 +60,7 @@ static std::vector> unpack_streams( PyObject* obj, size_t size) { if (obj == Py_None) { - return std::vector>(size, c10::nullopt); + return std::vector>(size, std::nullopt); } auto streams = THPUtils_PySequence_to_CUDAStreamList(obj); if (streams.size() != size) { diff --git a/torch/csrc/distributed/autograd/engine/dist_engine.cpp b/torch/csrc/distributed/autograd/engine/dist_engine.cpp index 062a15da4964c0..d37e695c77194a 100644 --- a/torch/csrc/distributed/autograd/engine/dist_engine.cpp +++ b/torch/csrc/distributed/autograd/engine/dist_engine.cpp @@ -98,7 +98,7 @@ void DistEngine::globalCpuThread( InputBuffer::variables(std::move(task.inputs_))]() mutable { InputBuffer inputs(variables.size()); for (const auto i : c10::irange(variables.size())) { - inputs.add(i, std::move(variables[i]), c10::nullopt, c10::nullopt); + inputs.add(i, std::move(variables[i]), std::nullopt, std::nullopt); } execute_graph_task_until_ready_queue_empty( /*node_task*/ NodeTask(graphTask, graphRoot, std::move(inputs)), diff --git a/torch/csrc/distributed/c10d/NCCLUtils.cpp b/torch/csrc/distributed/c10d/NCCLUtils.cpp index 6507fe6abc2a2b..98af0d51a3d050 100644 --- a/torch/csrc/distributed/c10d/NCCLUtils.cpp +++ b/torch/csrc/distributed/c10d/NCCLUtils.cpp @@ -18,7 +18,7 @@ namespace c10d { ncclComm_t NCCLComm::getNcclComm() { std::unique_lock lock(mutex_); if (aborted_) { - auto commFailureMsg = commFailureReason_ != c10::nullopt + auto commFailureMsg = commFailureReason_ != std::nullopt ? c10::str(" Original reason for failure was: ", *commFailureReason_) : ""; TORCH_CHECK_WITH( @@ -76,7 +76,7 @@ std::shared_ptr NCCLComm::split( C10D_NCCL_CHECK( ncclCommSplit( source->ncclComm_, color_id, rank, &(comm->ncclComm_), &config), - c10::nullopt); + std::nullopt); ++source->ncclCommSplitCounter_; comm->rank_ = rank; return comm; @@ -186,11 +186,11 @@ std::string ncclGetErrorWithVersion(ncclResult_t error) { // thrown in the NCCL codebase. std::string getNcclErrorDetailStr( ncclResult_t error, - std::optional processGroupFailureReason /* = c10::nullopt */ + std::optional processGroupFailureReason /* = std::nullopt */ ) { // Prioritize failure reason provided by PG NCCL first, as it can abort // communicators when it encounters collective timeouts, etc. - if (processGroupFailureReason != c10::nullopt) { + if (processGroupFailureReason != std::nullopt) { return *processGroupFailureReason; } std::string interpret; diff --git a/torch/csrc/distributed/c10d/NCCLUtils.hpp b/torch/csrc/distributed/c10d/NCCLUtils.hpp index 9ce25b55dc133a..06568f6ce7d2f1 100644 --- a/torch/csrc/distributed/c10d/NCCLUtils.hpp +++ b/torch/csrc/distributed/c10d/NCCLUtils.hpp @@ -11,8 +11,8 @@ #include #include -#include #include +#include #if defined(NCCL_MAJOR) && (NCCL_MAJOR == 2) && defined(NCCL_MINOR) && \ (NCCL_MINOR >= 14) @@ -183,7 +183,7 @@ bool shouldBroadcastNCCLUniqueID(bool isSendRecvSelf); // thrown in the NCCL codebase. TORCH_API std::string getNcclErrorDetailStr( ncclResult_t error, - std::optional processGroupFailureReason = c10::nullopt); + std::optional processGroupFailureReason = std::nullopt); // Write NCCL debug info to local disk or any storage users define. // There are some constrains we set for the debug info writer: @@ -221,7 +221,7 @@ class NCCLComm { : ncclComm_(ncclComm), aborted_(false), ncclAsyncErr_(ncclSuccess), - commFailureReason_(c10::nullopt), + commFailureReason_(std::nullopt), initialized_(false) {} NCCLComm() : NCCLComm(nullptr) {} @@ -249,7 +249,7 @@ class NCCLComm { auto comm = std::make_shared(); C10D_NCCL_CHECK( ncclCommInitRank(&(comm->ncclComm_), numRanks, commId, rank), - c10::nullopt); + std::nullopt); comm->ncclId_ = commId; comm->rank_ = rank; comm->initialized_ = true; @@ -271,12 +271,12 @@ class NCCLComm { C10D_NCCL_CHECK_NONBLOCKING( ncclCommInitRankConfig( &(comm->ncclComm_), numRanks, commId, rank, &config), - c10::nullopt); + std::nullopt); } else { C10D_NCCL_CHECK( ncclCommInitRankConfig( &(comm->ncclComm_), numRanks, commId, rank, &config), - c10::nullopt); + std::nullopt); // under blocking mode, comm is initialized after NCCL CHECK isInitialized = true; } @@ -301,7 +301,7 @@ class NCCLComm { LOG(INFO) << "Communicator was aborted before trying to dump its state."; return dump; } - C10D_NCCL_CHECK(::ncclCommDump(ncclComm_, dump), c10::nullopt); + C10D_NCCL_CHECK(::ncclCommDump(ncclComm_, dump), std::nullopt); return dump; } #endif @@ -336,7 +336,7 @@ class NCCLComm { } void ncclCommAbort( - std::optional commFailureReason = c10::nullopt) { + std::optional commFailureReason = std::nullopt) { std::unique_lock lock(mutex_); #ifdef ENABLE_NCCL_ERROR_CHECKING if (aborted_) { diff --git a/torch/csrc/distributed/c10d/ProcessGroupCudaP2P.hpp b/torch/csrc/distributed/c10d/ProcessGroupCudaP2P.hpp index cff4ad09b70648..23ee93b91d7a64 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupCudaP2P.hpp +++ b/torch/csrc/distributed/c10d/ProcessGroupCudaP2P.hpp @@ -128,7 +128,7 @@ class TORCH_API ProcessGroupCudaP2P : public Backend { const BarrierOptions& opts = BarrierOptions()) override; c10::intrusive_ptr intra_node_barrier( - c10::optional> ranks = c10::nullopt); + c10::optional> ranks = std::nullopt); at::Tensor get_p2p_buffer( size_t rank, @@ -136,7 +136,7 @@ class TORCH_API ProcessGroupCudaP2P : public Backend { c10::ScalarType dtype, int64_t storage_offest = 0); - void shutdown(c10::optional reason = c10::nullopt); + void shutdown(c10::optional reason = std::nullopt); private: c10::intrusive_ptr nccl_backend_; diff --git a/torch/csrc/distributed/c10d/ProcessGroupGloo.cpp b/torch/csrc/distributed/c10d/ProcessGroupGloo.cpp index cba0249829e68b..a6ed8fd26a161e 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupGloo.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupGloo.cpp @@ -2425,7 +2425,7 @@ class AsyncScatterWork : public ProcessGroupGloo::AsyncWork { seq, "gloo:scatter", !inputs.empty() ? std::optional>(inputs[0]) - : c10::nullopt), + : std::nullopt), context(context), outputs(outputs), inputs(inputs), @@ -2888,7 +2888,7 @@ class AsyncBarrierWork : public ProcessGroupGloo::AsyncWork { OpType::BARRIER, seq, "gloo:barrier", - c10::nullopt), + std::nullopt), context(context), priorWork(std::move(priorWork)), tag(tag) {} diff --git a/torch/csrc/distributed/c10d/ProcessGroupGloo.hpp b/torch/csrc/distributed/c10d/ProcessGroupGloo.hpp index 87c87b8f1ae9bd..9f1e63d58adf2d 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupGloo.hpp +++ b/torch/csrc/distributed/c10d/ProcessGroupGloo.hpp @@ -74,7 +74,7 @@ class TORCH_API ProcessGroupGloo : public Backend { uint64_t seq, const char* profilingTitle = nullptr, const std::optional>& inputTensors = - c10::nullopt); + std::nullopt); ~AsyncWork() override = default; diff --git a/torch/csrc/distributed/c10d/ProcessGroupMPI.cpp b/torch/csrc/distributed/c10d/ProcessGroupMPI.cpp index 6d02f89f6005b8..91e9f938f1dd3e 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupMPI.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupMPI.cpp @@ -673,7 +673,7 @@ c10::intrusive_ptr ProcessGroupMPI::scatter( "mpi:scatter", !inputTensors.empty() ? std::optional>(inputTensors[0]) - : c10::nullopt); + : std::nullopt); } else { auto entry = std::make_unique( nullptr, &outputTensors, std::move(runFunc)); @@ -682,7 +682,7 @@ c10::intrusive_ptr ProcessGroupMPI::scatter( "mpi:scatter", !inputTensors.empty() ? std::optional>(inputTensors[0]) - : c10::nullopt); + : std::nullopt); } } @@ -932,7 +932,7 @@ c10::intrusive_ptr ProcessGroupMPI::barrier(const BarrierOptions& opts) { }; auto entry = std::make_unique(nullptr, nullptr, std::move(runFunc)); - return enqueue(std::move(entry), "mpi:barrier", c10::nullopt); + return enqueue(std::move(entry), "mpi:barrier", std::nullopt); } c10::intrusive_ptr ProcessGroupMPI::_allgather_base( diff --git a/torch/csrc/distributed/c10d/ProcessGroupMPI.hpp b/torch/csrc/distributed/c10d/ProcessGroupMPI.hpp index 6e52e680e5c201..5eb06b7395570e 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupMPI.hpp +++ b/torch/csrc/distributed/c10d/ProcessGroupMPI.hpp @@ -87,7 +87,7 @@ class TORCH_API ProcessGroupMPI : public Backend { std::vector outputTensors, const char* profilingTitle = nullptr, const std::optional>& inputTensors = - c10::nullopt) + std::nullopt) : Work(-1, OpType::UNKNOWN, profilingTitle, inputTensors), outputTensors_(std::move(outputTensors)), future_(c10::make_intrusive( @@ -115,7 +115,7 @@ class TORCH_API ProcessGroupMPI : public Backend { std::vector outputTensors, const char* profilingTitle = nullptr, const std::optional>& inputTensors = - c10::nullopt); + std::nullopt); ~AsyncWork() override; @@ -244,7 +244,7 @@ class TORCH_API ProcessGroupMPI : public Backend { std::unique_ptr entry, const char* profilingTitle = nullptr, const std::optional>& inputTensors = - c10::nullopt); + std::nullopt); bool stop_; diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp index e7699b55245147..af940f53bf24c8 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp @@ -19,7 +19,6 @@ #include #include #include -#include #include #include #include @@ -32,6 +31,7 @@ #include #include #include +#include namespace c10d { @@ -376,7 +376,7 @@ std::string dump_nccl_trace( bool includeStackTraces, bool onlyActive) { return NCCLTraceBuffer::get()->dump( - c10::nullopt, includeCollectives, includeStackTraces, onlyActive); + std::nullopt, includeCollectives, includeStackTraces, onlyActive); } #endif @@ -393,7 +393,7 @@ std::optional)>>& get_cpp_trace_dumper() { static std::optional< std::function)>> - dumper(c10::nullopt); + dumper(std::nullopt); return dumper; } @@ -658,7 +658,7 @@ void ProcessGroupNCCL::WorkNCCL::synchronizeInternal( if (blockingWait_) { while (!isCompleted()) { bool timedOut = checkTimeout( - timeout == kNoTimeout ? c10::nullopt : c10::make_optional(timeout)); + timeout == kNoTimeout ? std::nullopt : std::make_optional(timeout)); // Explicitly abort ncclComms here before throwing this timed out // exception to users. // If throwing timed out excepiton without aborting nccl communicators @@ -1245,7 +1245,7 @@ void ProcessGroupNCCL::heartbeatMonitor() { : heartbeatTimeoutInSec_ * 1000; auto lastTimePollStore = std::chrono::steady_clock::now(); auto lastTimeHeartBeatCheck = std::chrono::steady_clock::now(); - std::optional dumpPipe = c10::nullopt; + std::optional dumpPipe = std::nullopt; if (uid_ == 0) { // DumpPipe is one per-trainer process, and its convenient to name them // after 'global' ranks in the system, So we assume processgroup (uid)==0 is @@ -1881,7 +1881,7 @@ std::exception_ptr ProcessGroupNCCL::checkForNCCLErrorsInternal( // Prioritize commFailureReason over checkForNcclError() result if // commFailureReason is set. auto commFailureReason = ncclComm->getNcclCommFailureReason(); - if (commFailureReason != c10::nullopt) { + if (commFailureReason != std::nullopt) { return std::make_exception_ptr(C10_BUILD_ERROR( DistBackendError, c10::str( @@ -2050,7 +2050,7 @@ std::shared_ptr ProcessGroupNCCL::getNCCLComm( bool singleP2POp = isP2POp(opType, batchP2P); // For point-to-point communication, lower rank of the two will get unique id. if (rank_ == 0 || (singleP2POp && p2pRank == 0)) { - C10D_NCCL_CHECK(ncclGetUniqueId(&ncclID), c10::nullopt); + C10D_NCCL_CHECK(ncclGetUniqueId(&ncclID), std::nullopt); } if (shouldBroadcastNCCLUniqueID(isSendRecvSelf)) { @@ -2086,7 +2086,7 @@ std::shared_ptr ProcessGroupNCCL::getNCCLComm( for (const auto i : c10::irange(ncclActiveGroupCounter_)) { (void)i; // comms have not been initiated yet, so can only check in blocking-way - C10D_NCCL_CHECK(ncclGroupEnd(), c10::nullopt); + C10D_NCCL_CHECK(ncclGroupEnd(), std::nullopt); } // GPU world size and GPU rank @@ -2182,7 +2182,7 @@ std::shared_ptr ProcessGroupNCCL::getNCCLComm( // See [Group Start/End Note] for (const auto i : c10::irange(ncclActiveGroupCounter_)) { (void)i; - C10D_NCCL_CHECK(ncclGroupStart(), c10::nullopt); + C10D_NCCL_CHECK(ncclGroupStart(), std::nullopt); } ncclStreams_.emplace(deviceKey, std::move(streamVal)); @@ -2334,7 +2334,7 @@ c10::intrusive_ptr ProcessGroupNCCL::initWork( seqCollective_, profilingTitle, profilingTitle != nullptr ? std::optional>(inputs) - : c10::nullopt, + : std::nullopt, desyncDebug_, enableTiming_.load(), dist_debug_level_); @@ -4190,23 +4190,23 @@ c10::intrusive_ptr ProcessGroupNCCL::recv( } void ProcessGroupNCCL::groupStart() { - C10D_NCCL_CHECK(ncclGroupStart(), c10::nullopt); + C10D_NCCL_CHECK(ncclGroupStart(), std::nullopt); ++ncclActiveGroupCounter_; } void ProcessGroupNCCL::groupEnd() { - C10D_NCCL_CHECK(ncclGroupEnd(), c10::nullopt); + C10D_NCCL_CHECK(ncclGroupEnd(), std::nullopt); --ncclActiveGroupCounter_; } void ProcessGroupNCCL::groupEndNonblocking(std::shared_ptr comm) { #ifndef NCCL_HAS_COMM_NONBLOCKING - C10D_NCCL_CHECK(ncclGroupEnd(), c10::nullopt); + C10D_NCCL_CHECK(ncclGroupEnd(), std::nullopt); #else if (!nccl_use_nonblocking()) { - C10D_NCCL_CHECK(ncclGroupEnd(), c10::nullopt); + C10D_NCCL_CHECK(ncclGroupEnd(), std::nullopt); } else { - C10D_NCCL_CHECK_TIMEOUT_GROUPEND(ncclGroupEnd(), comm, c10::nullopt); + C10D_NCCL_CHECK_TIMEOUT_GROUPEND(ncclGroupEnd(), comm, std::nullopt); } #endif --ncclActiveGroupCounter_; diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp index faaabe411bfccb..763ef9829618f1 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp +++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp @@ -254,7 +254,7 @@ class TORCH_API ProcessGroupNCCL : public Backend { OpType opType, uint64_t seq, const char* profilingTitle = nullptr, - const std::optional>& inputs = c10::nullopt, + const std::optional>& inputs = std::nullopt, bool desyncDebug = false, bool enableTiming = false, DebugLevel distDebugLevel = DebugLevel::Off); @@ -311,7 +311,7 @@ class TORCH_API ProcessGroupNCCL : public Backend { // and False otherwise. // In case of timeout, set exception on the WorkNCCL object. bool checkTimeout( - std::optional timeout = c10::nullopt); + std::optional timeout = std::nullopt); std::vector result() override; @@ -662,9 +662,9 @@ class TORCH_API ProcessGroupNCCL : public Backend { // Provides an API to abort the ProcessGroup (similar to ncclCommAbort) // instead of relying on ProcessGroupNCCL destructor. // return true if abort is successful, otherwise false - bool abort(std::optional abortReason = c10::nullopt); + bool abort(std::optional abortReason = std::nullopt); - void shutdown(std::optional reason = c10::nullopt); + void shutdown(std::optional reason = std::nullopt); void eagerConnectSingleDevice(at::Device device) override; diff --git a/torch/csrc/distributed/c10d/TCPStore.cpp b/torch/csrc/distributed/c10d/TCPStore.cpp index fe24c31f9068bd..2de969d135e8f8 100644 --- a/torch/csrc/distributed/c10d/TCPStore.cpp +++ b/torch/csrc/distributed/c10d/TCPStore.cpp @@ -293,7 +293,7 @@ TCPStore::TCPStore( masterPort, isServer, numWorkers ? std::optional(*numWorkers) - : c10::nullopt, + : std::nullopt, waitWorkers, timeout}} {} @@ -376,7 +376,7 @@ TCPStore::~TCPStore() = default; void TCPStore::waitForWorkers() { detail::timing_guard tguard(clientCounters_["waitForWorkers"]); - if (numWorkers_ == c10::nullopt) { + if (numWorkers_ == std::nullopt) { return; } diff --git a/torch/csrc/distributed/c10d/TCPStore.hpp b/torch/csrc/distributed/c10d/TCPStore.hpp index 25783f2d2acea9..9fd29b1c844cc6 100644 --- a/torch/csrc/distributed/c10d/TCPStore.hpp +++ b/torch/csrc/distributed/c10d/TCPStore.hpp @@ -49,7 +49,7 @@ struct TCPStoreOptions { std::uint16_t port = kDefaultPort; bool isServer = false; - std::optional numWorkers = c10::nullopt; + std::optional numWorkers = std::nullopt; bool waitWorkers = true; std::chrono::milliseconds timeout = Store::kDefaultTimeout; @@ -60,7 +60,7 @@ struct TCPStoreOptions { // If specified, and if isServer is true, the underlying TCPServer will take // over the bound socket associated to this fd. This option is useful to avoid // port assignment races in certain scenarios. - std::optional masterListenFd = c10::nullopt; + std::optional masterListenFd = std::nullopt; // A boolean value indicating whether to use the experimental libUV backend. bool useLibUV = true; @@ -73,7 +73,7 @@ class TORCH_API TCPStore : public Store { [[deprecated("Use TCPStore(host, opts) instead.")]] explicit TCPStore( const std::string& masterAddr, std::uint16_t masterPort, - std::optional numWorkers = c10::nullopt, + std::optional numWorkers = std::nullopt, bool isServer = false, const std::chrono::milliseconds& timeout = kDefaultTimeout, bool waitWorkers = true); diff --git a/torch/csrc/distributed/c10d/TraceUtils.h b/torch/csrc/distributed/c10d/TraceUtils.h index de623d77fe9e0e..9ff71f9d41b848 100644 --- a/torch/csrc/distributed/c10d/TraceUtils.h +++ b/torch/csrc/distributed/c10d/TraceUtils.h @@ -516,7 +516,7 @@ struct NCCLTraceBuffer { std::chrono::milliseconds timeout_ms, bool isP2P) { if (!enabled_) { - return c10::nullopt; + return std::nullopt; } auto traceback = torch::CapturedTraceback::gather(true, true, capture_cpp_stack_); @@ -621,7 +621,7 @@ struct NCCLTraceBuffer { bool can_compute_duration = false; Event* startEvent = nullptr; Event* endEvent = nullptr; - std::optional duration = c10::nullopt; + std::optional duration = std::nullopt; std::unique_lock guard(mutex_); diff --git a/torch/csrc/distributed/c10d/Types.hpp b/torch/csrc/distributed/c10d/Types.hpp index 669957a7267358..7cdb9f62ebbb85 100644 --- a/torch/csrc/distributed/c10d/Types.hpp +++ b/torch/csrc/distributed/c10d/Types.hpp @@ -121,7 +121,7 @@ struct BroadcastOptions { struct AllreduceOptions { ReduceOp reduceOp = ReduceOp::SUM; std::chrono::milliseconds timeout = kUnsetTimeout; - std::optional sparseIndices = c10::nullopt; + std::optional sparseIndices = std::nullopt; }; struct AllreduceCoalescedOptions : AllreduceOptions {}; diff --git a/torch/csrc/distributed/c10d/Utils.hpp b/torch/csrc/distributed/c10d/Utils.hpp index a03337e975148b..b77a914da4e677 100644 --- a/torch/csrc/distributed/c10d/Utils.hpp +++ b/torch/csrc/distributed/c10d/Utils.hpp @@ -440,7 +440,7 @@ inline at::Tensor newLikeFlat( sizes.insert(sizes.end(), t.sizes().begin(), t.sizes().end()); strides.insert(strides.end(), t.strides().begin(), t.strides().end()); return at::empty_strided( - sizes, strides, t.options().memory_format(c10::nullopt)); + sizes, strides, t.options().memory_format(std::nullopt)); } inline at::Tensor newLikeFlat(std::vector& tensors) { diff --git a/torch/csrc/distributed/c10d/Work.hpp b/torch/csrc/distributed/c10d/Work.hpp index d29b838321176d..c10e5007b9f544 100644 --- a/torch/csrc/distributed/c10d/Work.hpp +++ b/torch/csrc/distributed/c10d/Work.hpp @@ -51,7 +51,7 @@ class TORCH_API Work : public torch::CustomClassHolder { OpType opType = OpType::UNKNOWN, const char* profilingTitle = nullptr, const std::optional>& inputTensors = - c10::nullopt); + std::nullopt); ~Work() override; diff --git a/torch/csrc/distributed/c10d/init.cpp b/torch/csrc/distributed/c10d/init.cpp index 6f1b28886b989b..5145c969a95b00 100644 --- a/torch/csrc/distributed/c10d/init.cpp +++ b/torch/csrc/distributed/c10d/init.cpp @@ -1415,7 +1415,7 @@ Example:: bool multiTenant, std::optional masterListenFd, bool useLibUV) { - std::optional numWorkers = c10::nullopt; + std::optional numWorkers = std::nullopt; if (worldSize.has_value() && worldSize.value() > -1) { numWorkers = static_cast(worldSize.value()); } @@ -2648,7 +2648,7 @@ options :class:`~torch.distributed.ProcessGroupNCCL.Options`). py::arg("store"), py::arg("rank"), py::arg("world_size"), - py::arg("buffer_size") = c10::nullopt) + py::arg("buffer_size") = std::nullopt) .def("barrier", &IntraNodeComm::barrier, py::arg("ranks") = py::none()); #ifdef NCCL_HAS_COMM_CTA_CGA diff --git a/torch/csrc/distributed/c10d/intra_node_comm.hpp b/torch/csrc/distributed/c10d/intra_node_comm.hpp index 5d7e2d426d30a1..b4d70f580da5cb 100644 --- a/torch/csrc/distributed/c10d/intra_node_comm.hpp +++ b/torch/csrc/distributed/c10d/intra_node_comm.hpp @@ -33,7 +33,7 @@ class TORCH_API IntraNodeComm : public c10::intrusive_ptr_target { c10::intrusive_ptr store, size_t rank, size_t worldSize, - std::optional bufferSize = c10::nullopt); + std::optional bufferSize = std::nullopt); ~IntraNodeComm() override; @@ -65,7 +65,7 @@ class TORCH_API IntraNodeComm : public c10::intrusive_ptr_target { /** * Perform a barrier among the specified ranks. */ - void barrier(std::optional> ranks = c10::nullopt); + void barrier(std::optional> ranks = std::nullopt); at::Tensor getBuffer( size_t rank, diff --git a/torch/csrc/distributed/c10d/logger.cpp b/torch/csrc/distributed/c10d/logger.cpp index 711039bf485954..48f8786842f01f 100644 --- a/torch/csrc/distributed/c10d/logger.cpp +++ b/torch/csrc/distributed/c10d/logger.cpp @@ -234,7 +234,7 @@ void Logger::set_event_time( Timer& timer, Timer::Event event) { auto timestamp = timer.getTimestamp(event); - if (timestamp != c10::nullopt) { + if (timestamp != std::nullopt) { // TODO: should we set this as human-readable time instead of unixtime? event_time = *timestamp; } diff --git a/torch/csrc/distributed/c10d/reducer.cpp b/torch/csrc/distributed/c10d/reducer.cpp index 6a2812ab24b9cd..6c5f7a79ff9fbf 100644 --- a/torch/csrc/distributed/c10d/reducer.cpp +++ b/torch/csrc/distributed/c10d/reducer.cpp @@ -61,7 +61,7 @@ class CpuTimer : public Timer { // calculate the valid avg_time. // In this case, skip calculating the avg_time and return. if (end_time < start_time) { - return c10::nullopt; + return std::nullopt; } return end_time - start_time; } @@ -499,7 +499,7 @@ std::vector Reducer::get_grad_buckets( bucket.lengths, bucket.sizes_vec, variables_for_bucket, - c10::nullopt); + std::nullopt); } return gradBuckets; } @@ -1655,9 +1655,9 @@ void Reducer::finalize_backward() { } } - if (installed_futures_ != c10::nullopt) { + if (installed_futures_ != std::nullopt) { c10::collectAll(*installed_futures_)->wait(); - installed_futures_ = c10::nullopt; + installed_futures_ = std::nullopt; } // See Note [Skip allreducing local_used_maps_dev] diff --git a/torch/csrc/distributed/c10d/reducer.hpp b/torch/csrc/distributed/c10d/reducer.hpp index 1f72b0eb37b9f6..aa3c40ae95bbf2 100644 --- a/torch/csrc/distributed/c10d/reducer.hpp +++ b/torch/csrc/distributed/c10d/reducer.hpp @@ -262,9 +262,9 @@ class TORCH_API Reducer { // List of futures installed by Reducer::install_futures that should be // awaited at the end of backwards pass. std::optional>> - installed_futures_{c10::nullopt}; + installed_futures_{std::nullopt}; // Mixed precision parameter dtype for bucket type checking. - std::optional mixed_precision_param_dtype_{c10::nullopt}; + std::optional mixed_precision_param_dtype_{std::nullopt}; // Work handle for allreduce on local_used_map_ c10::intrusive_ptr local_used_work_; @@ -389,7 +389,7 @@ class TORCH_API Reducer { bool expect_sparse_gradient = false; // Sparse indices tensor - std::optional sparse_tensor_indices = c10::nullopt; + std::optional sparse_tensor_indices = std::nullopt; // TODO(@pietern) // Memory copies from gradient tensors into the bucket are potentially diff --git a/torch/csrc/distributed/c10d/reducer_cuda.cpp b/torch/csrc/distributed/c10d/reducer_cuda.cpp index 84bff02072b606..a158e44fc047c0 100644 --- a/torch/csrc/distributed/c10d/reducer_cuda.cpp +++ b/torch/csrc/distributed/c10d/reducer_cuda.cpp @@ -59,7 +59,7 @@ class CudaTimer : public Timer { // If it is never recorded/created, skip synchronize and calculation. // Otherwise it will throw cuda errors. if (!start_event.isCreated() || !end_event.isCreated()) { - return c10::nullopt; + return std::nullopt; } // set_runtime_stats_and_log is called at the beginning of forward call, // when it is cheap to synchronize the cuda events of previous iteration, @@ -74,7 +74,7 @@ class CudaTimer : public Timer { // calculate the valid avg_time. // In this case, skip calculating the avg_time and return. if (milliseconds < 0) { - return c10::nullopt; + return std::nullopt; } return int64_t(milliseconds * kMilliSecondToNanosSecond); } diff --git a/torch/csrc/distributed/c10d/reducer_timer.hpp b/torch/csrc/distributed/c10d/reducer_timer.hpp index f9b9f11c8c9632..dbea3958db43da 100644 --- a/torch/csrc/distributed/c10d/reducer_timer.hpp +++ b/torch/csrc/distributed/c10d/reducer_timer.hpp @@ -47,7 +47,7 @@ class TORCH_API Timer { std::optional getTimestamp(Event event) { auto time = getTimeRef(event); if (time == kUnsetTime) { - return c10::nullopt; + return std::nullopt; } else { return time; } diff --git a/torch/csrc/distributed/c10d/sequence_num.cpp b/torch/csrc/distributed/c10d/sequence_num.cpp index fd76247199f618..3807d629d830c5 100644 --- a/torch/csrc/distributed/c10d/sequence_num.cpp +++ b/torch/csrc/distributed/c10d/sequence_num.cpp @@ -10,7 +10,7 @@ SequenceNum::SequenceNum(const uint64_t num) : num_(num) {} SequenceNum::SequenceNum(const SequenceNum& other) { if (!other.isSet()) { - num_ = c10::nullopt; + num_ = std::nullopt; } else { num_ = other.get(); } @@ -23,7 +23,7 @@ uint64_t SequenceNum::get() const { void SequenceNum::increment() { std::lock_guard lock(lock_); - TORCH_CHECK(num_ != c10::nullopt); + TORCH_CHECK(num_ != std::nullopt); num_ = ++(*num_); } @@ -32,7 +32,7 @@ void SequenceNum::increment() { uint64_t SequenceNum::getAndIncrement() { uint64_t curVal = 0; std::lock_guard lock(lock_); - TORCH_CHECK(num_ != c10::nullopt); + TORCH_CHECK(num_ != std::nullopt); curVal = *num_; num_ = ++(*num_); return curVal; @@ -45,13 +45,13 @@ void SequenceNum::set(const uint64_t num) { bool SequenceNum::isSet() const { std::lock_guard lock(lock_); - return num_ != c10::nullopt; + return num_ != std::nullopt; } SequenceNum& SequenceNum::operator=(const SequenceNum& other) { std::lock_guard lock(lock_); if (!other.isSet()) { - num_ = c10::nullopt; + num_ = std::nullopt; } else { num_ = other.get(); } diff --git a/torch/csrc/distributed/c10d/sequence_num.hpp b/torch/csrc/distributed/c10d/sequence_num.hpp index ce31f4b5527282..38bd4cb5ed9d38 100644 --- a/torch/csrc/distributed/c10d/sequence_num.hpp +++ b/torch/csrc/distributed/c10d/sequence_num.hpp @@ -1,9 +1,9 @@ #pragma once #include -#include #include #include +#include #include namespace c10d { diff --git a/torch/csrc/distributed/rpc/profiler/remote_profiler_manager.cpp b/torch/csrc/distributed/rpc/profiler/remote_profiler_manager.cpp index 3a37e7b02a5f05..eb45679873f039 100644 --- a/torch/csrc/distributed/rpc/profiler/remote_profiler_manager.cpp +++ b/torch/csrc/distributed/rpc/profiler/remote_profiler_manager.cpp @@ -10,7 +10,7 @@ namespace rpc { const std::string REMOTE_PROFILING_KEY_PREFIX = "#remote_op: "; constexpr int kAutoIncrementBits = 48; /*static */ thread_local std::optional - RemoteProfilerManager::currentThreadLocalKey_ = c10::nullopt; + RemoteProfilerManager::currentThreadLocalKey_ = std::nullopt; /*static */ RemoteProfilerManager& RemoteProfilerManager::getInstance() { static RemoteProfilerManager* handler = new RemoteProfilerManager(); return *handler; @@ -32,7 +32,7 @@ bool RemoteProfilerManager::isCurrentKeySet() const { } void RemoteProfilerManager::unsetCurrentKey() { - currentThreadLocalKey_ = c10::nullopt; + currentThreadLocalKey_ = std::nullopt; } void RemoteProfilerManager::eraseKey(const ProfilingId& globallyUniqueId) { diff --git a/torch/csrc/distributed/rpc/profiler/remote_profiler_manager.h b/torch/csrc/distributed/rpc/profiler/remote_profiler_manager.h index c6f8b353806b5b..2889120b67ca69 100644 --- a/torch/csrc/distributed/rpc/profiler/remote_profiler_manager.h +++ b/torch/csrc/distributed/rpc/profiler/remote_profiler_manager.h @@ -1,8 +1,8 @@ #pragma once -#include #include #include #include +#include #include namespace torch { diff --git a/torch/csrc/distributed/rpc/py_rref.cpp b/torch/csrc/distributed/rpc/py_rref.cpp index ed7847a1f5faa2..887f25b6c16dd4 100644 --- a/torch/csrc/distributed/rpc/py_rref.cpp +++ b/torch/csrc/distributed/rpc/py_rref.cpp @@ -119,7 +119,7 @@ TypePtr tryInferTypeWithTypeHint( /////////////////////////// PyRRef ////////////////////////////////// PyRRef::PyRRef(c10::intrusive_ptr rref) - : rref_(std::move(rref)), profilingFuture_(c10::nullopt) { + : rref_(std::move(rref)), profilingFuture_(std::nullopt) { TORCH_CHECK(rref_, "PyRRef must not wrap nullptr"); C10_LOG_API_USAGE_ONCE("torch.distributed.rref"); } diff --git a/torch/csrc/distributed/rpc/python_functions.cpp b/torch/csrc/distributed/rpc/python_functions.cpp index 57acbc0370252f..51ee554abda743 100644 --- a/torch/csrc/distributed/rpc/python_functions.cpp +++ b/torch/csrc/distributed/rpc/python_functions.cpp @@ -261,7 +261,7 @@ c10::intrusive_ptr pyRpcTorchscript( functionSchema, argsTuple.cast(), kwargsDict.cast(), - c10::nullopt); + std::nullopt); } DCHECK(!PyGILState_Check()); c10::intrusive_ptr fut = rpcTorchscript( @@ -408,7 +408,7 @@ PyRRef pyRemoteTorchscript( // Acquire GIL for py::args and py::kwargs processing. py::gil_scoped_acquire ag; stack = torch::jit::createStackForSchema( - functionSchema, args, kwargs, c10::nullopt); + functionSchema, args, kwargs, std::nullopt); } DCHECK(!PyGILState_Check()); auto rrefPtr = remoteTorchscript( diff --git a/torch/csrc/distributed/rpc/request_callback_no_python.cpp b/torch/csrc/distributed/rpc/request_callback_no_python.cpp index fb73cf2abf483e..3b6b04047c4e0c 100644 --- a/torch/csrc/distributed/rpc/request_callback_no_python.cpp +++ b/torch/csrc/distributed/rpc/request_callback_no_python.cpp @@ -440,7 +440,7 @@ c10::intrusive_ptr RequestCallbackNoPython:: true /* cleanup TLS state */, false /* consolidate events */); { TLSLegacyProfilerGuard g( - profilingConfig, c10::nullopt, requestThreadOptions); + profilingConfig, std::nullopt, requestThreadOptions); TORCH_INTERNAL_ASSERT( profilerEnabled(), "Expected profiler to be enabled!"); // Kick off processing for nested work and get Future result in diff --git a/torch/csrc/distributed/rpc/rref_impl.h b/torch/csrc/distributed/rpc/rref_impl.h index d6da3f2ea455f0..507d6bc846587c 100644 --- a/torch/csrc/distributed/rpc/rref_impl.h +++ b/torch/csrc/distributed/rpc/rref_impl.h @@ -3,10 +3,10 @@ #include #include #include -#include #include #include #include +#include #include diff --git a/torch/csrc/distributed/rpc/script_call.h b/torch/csrc/distributed/rpc/script_call.h index dacded5cc1e62a..5db4adf95f85ba 100644 --- a/torch/csrc/distributed/rpc/script_call.h +++ b/torch/csrc/distributed/rpc/script_call.h @@ -1,10 +1,10 @@ #pragma once -#include #include #include #include #include +#include #include namespace torch { diff --git a/torch/csrc/distributed/rpc/tensorpipe_cuda.cpp b/torch/csrc/distributed/rpc/tensorpipe_cuda.cpp index 50cc97785f61da..8259efeee1f9b3 100644 --- a/torch/csrc/distributed/rpc/tensorpipe_cuda.cpp +++ b/torch/csrc/distributed/rpc/tensorpipe_cuda.cpp @@ -94,7 +94,7 @@ class TensorpipeCudaConverter : public TensorpipeDeviceTypeConverter { message.tensors.push_back(std::move(tensor)); - return c10::nullopt; + return std::nullopt; } at::DataPtr allocateTensorForReceiving( diff --git a/torch/csrc/distributed/rpc/tensorpipe_utils.cpp b/torch/csrc/distributed/rpc/tensorpipe_utils.cpp index 929ae30f8a6d4d..9d38b5538d554a 100644 --- a/torch/csrc/distributed/rpc/tensorpipe_utils.cpp +++ b/torch/csrc/distributed/rpc/tensorpipe_utils.cpp @@ -59,7 +59,7 @@ class TensorpipeCpuConverter : public TensorpipeDeviceTypeConverter { message.tensors.push_back(std::move(tensor)); - return c10::make_optional(std::move(storageData)); + return std::make_optional(std::move(storageData)); } else { tensorpipe::CpuBuffer buffer; buffer.ptr = static_cast(storage.mutable_data()); @@ -70,7 +70,7 @@ class TensorpipeCpuConverter : public TensorpipeDeviceTypeConverter { message.tensors.push_back(std::move(tensor)); - return c10::nullopt; + return std::nullopt; } } diff --git a/torch/csrc/dynamo/python_compiled_autograd.cpp b/torch/csrc/dynamo/python_compiled_autograd.cpp index 2e5cb3bfab02e1..7913caad5449fc 100644 --- a/torch/csrc/dynamo/python_compiled_autograd.cpp +++ b/torch/csrc/dynamo/python_compiled_autograd.cpp @@ -591,7 +591,7 @@ CacheNode* _compiled_autograd_impl( if (next.is_valid() && output.defined()) { input_buffers.lookup(next.function.get()) .add( - next.input_nr, std::move(output), c10::nullopt, c10::nullopt); + next.input_nr, std::move(output), std::nullopt, std::nullopt); } } } diff --git a/torch/csrc/functorch/init.cpp b/torch/csrc/functorch/init.cpp index 53da5a634746c2..b54a7285f63549 100644 --- a/torch/csrc/functorch/init.cpp +++ b/torch/csrc/functorch/init.cpp @@ -242,7 +242,7 @@ int64_t _grad_increment_nesting() { // See NOTE [grad and vjp interaction with no_grad] bool prev_grad_mode = c10::GradMode::is_enabled(); return initAndPushDynamicLayer( - TransformType::Grad, c10::nullopt, c10::nullopt, prev_grad_mode); + TransformType::Grad, std::nullopt, std::nullopt, prev_grad_mode); } int64_t _grad_decrement_nesting() { @@ -257,9 +257,9 @@ int64_t _jvp_increment_nesting() { c10::AutogradState::get_tls_state().get_fw_grad_mode(); return initAndPushDynamicLayer( TransformType::Jvp, - c10::nullopt, - c10::nullopt, - c10::nullopt, + std::nullopt, + std::nullopt, + std::nullopt, prev_fwd_grad_mode); } @@ -287,10 +287,10 @@ int64_t _vmap_decrement_nesting() { int64_t _func_increment_nesting(bool reapply_views) { return initAndPushDynamicLayer( TransformType::Functionalize, - c10::nullopt, - c10::nullopt, - c10::nullopt, - c10::nullopt, + std::nullopt, + std::nullopt, + std::nullopt, + std::nullopt, /*functionalize_add_back_views=*/reapply_views); } @@ -528,7 +528,7 @@ void initFuncTorchBindings(PyObject* module) { "get_interpreter_stack", []() -> std::optional> { const auto& stack = getDynamicLayerStack(); if (stack.empty()) { - return c10::nullopt; + return std::nullopt; } std::vector result; result.reserve(stack.size()); @@ -540,7 +540,7 @@ void initFuncTorchBindings(PyObject* module) { m.def("peek_interpreter_stack", []() -> std::optional { const auto& stack = getDynamicLayerStack(); if (stack.empty()) { - return c10::nullopt; + return std::nullopt; } auto result = stack.back().interpreter(); return result; diff --git a/torch/csrc/inductor/aoti_torch/utils.h b/torch/csrc/inductor/aoti_torch/utils.h index 6e7bd355c57c31..eca21f6bf348c4 100644 --- a/torch/csrc/inductor/aoti_torch/utils.h +++ b/torch/csrc/inductor/aoti_torch/utils.h @@ -7,9 +7,9 @@ #include #include #include -#include #include #include +#include #define AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE(...) \ try { \ @@ -66,41 +66,41 @@ inline void assert_inf_and_nan( // utility functions to convert a pointer to an optional value template inline std::optional pointer_to_optional(T* ptr) { - return ptr ? c10::make_optional(*ptr) : c10::nullopt; + return ptr ? std::make_optional(*ptr) : std::nullopt; } template >> inline std::optional pointer_to_optional(U* ptr) { - return ptr ? c10::make_optional(T(*ptr)) : c10::nullopt; + return ptr ? std::make_optional(T(*ptr)) : std::nullopt; } template <> inline std::optional pointer_to_optional(AtenTensorHandle* ptr) { - return ptr ? c10::make_optional(*tensor_handle_to_tensor_pointer(*ptr)) - : c10::nullopt; + return ptr ? std::make_optional(*tensor_handle_to_tensor_pointer(*ptr)) + : std::nullopt; } template <> inline std::optional pointer_to_optional( const AtenTensorHandle* ptr) { - return ptr ? c10::make_optional(*tensor_handle_to_tensor_pointer(*ptr)) - : c10::nullopt; + return ptr ? std::make_optional(*tensor_handle_to_tensor_pointer(*ptr)) + : std::nullopt; } template <> inline std::optional pointer_to_optional( AtenGeneratorHandle* ptr) { - return ptr ? c10::make_optional(*generator_handle_to_generator_pointer(*ptr)) - : c10::nullopt; + return ptr ? std::make_optional(*generator_handle_to_generator_pointer(*ptr)) + : std::nullopt; } inline std::optional pointer_to_optional_device( int32_t* device_type, int32_t device_index) { - return device_type ? c10::make_optional(c10::Device( + return device_type ? std::make_optional(c10::Device( static_cast(*device_type), static_cast(device_index))) - : c10::nullopt; + : std::nullopt; } // utility functions to convert a pointer to a list @@ -180,8 +180,8 @@ inline std::optional> pointer_to_optional_list( U** ptr, int64_t len) { return ptr - ? c10::make_optional>(pointer_to_list(*ptr, len)) - : c10::nullopt; + ? std::make_optional>(pointer_to_list(*ptr, len)) + : std::nullopt; } } // namespace torch::aot_inductor diff --git a/torch/csrc/jit/api/compilation_unit.h b/torch/csrc/jit/api/compilation_unit.h index 8e28ef4717b934..d1c2c829d660c3 100644 --- a/torch/csrc/jit/api/compilation_unit.h +++ b/torch/csrc/jit/api/compilation_unit.h @@ -12,7 +12,7 @@ #include #include #include -#include +#include #include #include @@ -97,7 +97,7 @@ struct TORCH_API CompilationUnit { const Self* self, // see [name mangling] bool shouldMangle = false, - std::optional operator_set_version = c10::nullopt); + std::optional operator_set_version = std::nullopt); void define_hooks( const std::optional& prefix, @@ -293,7 +293,7 @@ struct TORCH_API CompilationUnit { const std::unordered_map& function_table, bool shouldMangle = false, FunctionType type = FunctionType::Method, - std::optional version = c10::nullopt) const; + std::optional version = std::nullopt) const; // Define a property on \p self. struct PropertyPair; diff --git a/torch/csrc/jit/api/function_impl.h b/torch/csrc/jit/api/function_impl.h index 6ed8cb36199ef2..01e7a3c98e3024 100644 --- a/torch/csrc/jit/api/function_impl.h +++ b/torch/csrc/jit/api/function_impl.h @@ -13,7 +13,7 @@ struct TORCH_API GraphFunction : public Function { std::shared_ptr graph, std::function function_creator, std::optional executor_execution_mode = - c10::nullopt) + std::nullopt) : name_(std::move(name)), graph_(std::move(graph)), executor_execution_mode_(executor_execution_mode), diff --git a/torch/csrc/jit/api/module.cpp b/torch/csrc/jit/api/module.cpp index 45b99eb8e47aa6..ae878376bab318 100644 --- a/torch/csrc/jit/api/module.cpp +++ b/torch/csrc/jit/api/module.cpp @@ -158,11 +158,11 @@ void Module::to(at::Device device, at::ScalarType dtype, bool non_blocking) { } void Module::to(at::ScalarType dtype, bool non_blocking) { - to_impl(/*device=*/c10::nullopt, dtype, non_blocking); + to_impl(/*device=*/std::nullopt, dtype, non_blocking); } void Module::to(at::Device device, bool non_blocking) { - to_impl(device, /*dtype=*/c10::nullopt, non_blocking); + to_impl(device, /*dtype=*/std::nullopt, non_blocking); } static void module_state_to( diff --git a/torch/csrc/jit/api/module.h b/torch/csrc/jit/api/module.h index 92b9c96c3a6ecf..9b2648737b0ce0 100644 --- a/torch/csrc/jit/api/module.h +++ b/torch/csrc/jit/api/module.h @@ -15,8 +15,8 @@ #include #include #include -#include #include +#include #include #include @@ -238,7 +238,7 @@ struct TORCH_API Module : public Object { Module copy() const; - Module deepcopy(std::optional device = c10::nullopt) const; + Module deepcopy(std::optional device = std::nullopt) const; // Clones both the underlying `ClassType` and the module instance(data), this // function creates a new `ClassType` and returns a new instance that has the @@ -334,7 +334,7 @@ struct TORCH_API Module : public Object { TORCH_API Module freeze( const Module& module, const std::optional>& preserved_attrs = - c10::nullopt, + std::nullopt, bool optimize_numerics = true); // C++ equivalent api of `torch.jit.optimize_for_inference`. See documentation @@ -552,7 +552,7 @@ struct slot_list_impl { : module_(std::move(module)), recurse_(recurse), return_module_(return_module), - size_(c10::nullopt) { + size_(std::nullopt) { if (!recurse && !return_module && Policy::all_slots) { size_ = module_.num_slots(); } diff --git a/torch/csrc/jit/api/object.cpp b/torch/csrc/jit/api/object.cpp index b707e767727650..f95d576d6c8cb2 100644 --- a/torch/csrc/jit/api/object.cpp +++ b/torch/csrc/jit/api/object.cpp @@ -20,7 +20,7 @@ std::optional Object::find_method(const std::string& basename) const { return Method(_ivalue(), fn); } } - return c10::nullopt; + return std::nullopt; } void Object::define(const std::string& src, const ResolverPtr& resolver) { diff --git a/torch/csrc/jit/api/object.h b/torch/csrc/jit/api/object.h index 164f6e2ac073af..2c0f7e3b164f05 100644 --- a/torch/csrc/jit/api/object.h +++ b/torch/csrc/jit/api/object.h @@ -2,8 +2,8 @@ #include #include -#include #include +#include #include @@ -129,7 +129,7 @@ struct TORCH_API Object { const Property get_property(const std::string& name) const { for (const auto& prop : type()->properties()) { if (prop.name == name) { - std::optional setter = c10::nullopt; + std::optional setter = std::nullopt; if (prop.setter) { setter = Method(_ivalue(), prop.setter); } @@ -142,7 +142,7 @@ struct TORCH_API Object { const std::vector get_properties() const { return c10::fmap(type()->properties(), [&](ClassType::Property prop) { - std::optional setter = c10::nullopt; + std::optional setter = std::nullopt; if (prop.setter) { setter = Method(_ivalue(), prop.setter); } diff --git a/torch/csrc/jit/codegen/fuser/compiler.cpp b/torch/csrc/jit/codegen/fuser/compiler.cpp index b4bc3e8f4727e3..7e03b576d12184 100644 --- a/torch/csrc/jit/codegen/fuser/compiler.cpp +++ b/torch/csrc/jit/codegen/fuser/compiler.cpp @@ -231,7 +231,7 @@ std::shared_ptr compileKernel( size_t input_index = 0; for (const auto& p : graph->inputs()) { if (p->type()->isSubtypeOf(*FloatType::get())) { - flat_inputs.emplace_back(p, c10::nullopt); + flat_inputs.emplace_back(p, std::nullopt); } if (!p->type()->isSubtypeOf(*TensorType::get())) { continue; diff --git a/torch/csrc/jit/codegen/fuser/cpu/fused_kernel.cpp b/torch/csrc/jit/codegen/fuser/cpu/fused_kernel.cpp index 5f692d50e6b54e..db9d57a679cb15 100644 --- a/torch/csrc/jit/codegen/fuser/cpu/fused_kernel.cpp +++ b/torch/csrc/jit/codegen/fuser/cpu/fused_kernel.cpp @@ -3,9 +3,9 @@ #include #include #include -#include #include #include +#include #include #include @@ -65,7 +65,7 @@ std::optional exec(const std::wstring& cmd) { std::unique_ptr pipe( _wpopen(cmd.c_str(), L"r"), _pclose); if (!pipe) { - return c10::nullopt; + return std::nullopt; } while (fgetws(buffer.data(), static_cast(buffer.size()), pipe.get()) != nullptr) { diff --git a/torch/csrc/jit/codegen/fuser/executor.cpp b/torch/csrc/jit/codegen/fuser/executor.cpp index 8abb99283ffc75..411dbe62a2e157 100644 --- a/torch/csrc/jit/codegen/fuser/executor.cpp +++ b/torch/csrc/jit/codegen/fuser/executor.cpp @@ -4,7 +4,6 @@ #include #include #include -#include #include #include #include @@ -12,6 +11,7 @@ #include #include #include +#include #include #include // TODO: remove, debugging only @@ -44,7 +44,7 @@ static std::optional> getMapSize( try { map_size = at::infer_size(map_size, arg.sizes()); } catch (...) { - return c10::nullopt; + return std::nullopt; } } else { auto tensor_sizes = arg.sizes().vec(); @@ -52,13 +52,13 @@ static std::optional> getMapSize( const auto dim = at::maybe_wrap_dim(chunk_desc.dim(), tensor_sizes.size()); if (tensor_sizes[dim] % num_chunks != 0) { - return c10::nullopt; + return std::nullopt; } tensor_sizes[dim] /= num_chunks; try { map_size = at::infer_size(map_size, tensor_sizes); } catch (...) { - return c10::nullopt; + return std::nullopt; } } } @@ -83,12 +83,12 @@ static std::optional> canRunKernel( if (!map_size) { map_size = getMapSize(spec, args, broadcast_group); if (!map_size) - return c10::nullopt; + return std::nullopt; } else { const auto group_map_size = getMapSize(spec, args, broadcast_group); // Note: this checks that group_map_size is defined AND equal to map_size if (map_size != group_map_size) - return c10::nullopt; + return std::nullopt; } } diff --git a/torch/csrc/jit/codegen/fuser/kernel_spec.h b/torch/csrc/jit/codegen/fuser/kernel_spec.h index 2fc52f2d76f0f2..eacdbc7ec3f336 100644 --- a/torch/csrc/jit/codegen/fuser/kernel_spec.h +++ b/torch/csrc/jit/codegen/fuser/kernel_spec.h @@ -2,13 +2,13 @@ #include #include -#include #include #include #include #include #include #include +#include #include #include @@ -122,7 +122,7 @@ struct TORCH_API KernelSpec { std::lock_guard guard{mutex_}; const auto it = kernels_.find(arg_spec); if (it == kernels_.end()) - return c10::nullopt; + return std::nullopt; return it->second; } void cacheKernel(const ArgSpec& arg_spec, std::shared_ptr kernel) diff --git a/torch/csrc/jit/codegen/onednn/graph_helper.cpp b/torch/csrc/jit/codegen/onednn/graph_helper.cpp index 16484dd4653c8f..30f32f5994c1d4 100644 --- a/torch/csrc/jit/codegen/onednn/graph_helper.cpp +++ b/torch/csrc/jit/codegen/onednn/graph_helper.cpp @@ -26,7 +26,7 @@ static std::optional getDimensions(Value* v) { if (v->type()->isSubtypeOf(TensorType::get())) { return v->type()->cast()->sizes().size(); } else { - return c10::nullopt; + return std::nullopt; } } diff --git a/torch/csrc/jit/codegen/onednn/graph_rewriter.cpp b/torch/csrc/jit/codegen/onednn/graph_rewriter.cpp index dfbfe467e9765b..71e74501656913 100644 --- a/torch/csrc/jit/codegen/onednn/graph_rewriter.cpp +++ b/torch/csrc/jit/codegen/onednn/graph_rewriter.cpp @@ -132,7 +132,7 @@ std::optional GraphRewriter::tryMerge(Node* consumer, Node* producer) { bool canMerge = llgaHelper_.shouldMerge(producer, consumer) && aliasDb_.moveBeforeTopologicallyValid(producer, consumer); if (!canMerge) { - return c10::nullopt; + return std::nullopt; } llgaHelper_.mergeNodeIntoSubgraph(producer, consumer, aliasDb_); return consumer; diff --git a/torch/csrc/jit/codegen/onednn/prepare_binary.cpp b/torch/csrc/jit/codegen/onednn/prepare_binary.cpp index a4f6d268694e36..d09b5777f97347 100644 --- a/torch/csrc/jit/codegen/onednn/prepare_binary.cpp +++ b/torch/csrc/jit/codegen/onednn/prepare_binary.cpp @@ -69,7 +69,7 @@ static void handleBinaryOpInputs(Node* node) { auto second_input_typeptr = node->input(1)->type()->expect(); std::optional second_input_type = second_input_typeptr->scalarType(); - if (second_input_type != c10::nullopt) { + if (second_input_type != std::nullopt) { // dtype of the second tensor might not be available in the IR auto dtypeOfSecondInput = second_input_type.value(); if (dtypeOfFirstInput != dtypeOfSecondInput) { diff --git a/torch/csrc/jit/cuda/cuda.h b/torch/csrc/jit/cuda/cuda.h index 80b2e2a82f788f..edac94a7357bf8 100644 --- a/torch/csrc/jit/cuda/cuda.h +++ b/torch/csrc/jit/cuda/cuda.h @@ -15,7 +15,7 @@ class CUDAStream final : public CustomClassHolder { public: // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) CUDAStream( - std::optional device = c10::nullopt, + std::optional device = std::nullopt, int64_t priority = 0) { c10::DeviceIndex device_index = device.has_value() ? device->index() : c10::cuda::current_device(); @@ -157,7 +157,7 @@ TORCH_LIBRARY(cuda, m) { auto stream_class = m.class_("Stream").def( torch::init, int64_t>(), "", - {torch::arg("device") = c10::nullopt, torch::arg("priority") = 0}); + {torch::arg("device") = std::nullopt, torch::arg("priority") = 0}); auto event_class = m.class_("Event").def( torch::init(), "", diff --git a/torch/csrc/jit/frontend/builtin_functions.cpp b/torch/csrc/jit/frontend/builtin_functions.cpp index c1c1d87176b759..2b3bdc42e4cc1e 100644 --- a/torch/csrc/jit/frontend/builtin_functions.cpp +++ b/torch/csrc/jit/frontend/builtin_functions.cpp @@ -121,7 +121,7 @@ struct BuiltinFunctionRegistry { void loadSource(const std::string& source, const std::string& the_namespace) { std::shared_ptr cu = std::make_shared(); modules.emplace_back(cu); - cu->define(c10::nullopt, source, nativeResolver(), /*self=*/nullptr); + cu->define(std::nullopt, source, nativeResolver(), /*self=*/nullptr); for (auto& method : cu->get_functions()) { builtins_by_name_[Symbol::fromQualString( the_namespace + "::" + method->name())] diff --git a/torch/csrc/jit/frontend/canonicalize_modified_loop.cpp b/torch/csrc/jit/frontend/canonicalize_modified_loop.cpp index 943551e80692f1..f2ef8b0e953c4b 100644 --- a/torch/csrc/jit/frontend/canonicalize_modified_loop.cpp +++ b/torch/csrc/jit/frontend/canonicalize_modified_loop.cpp @@ -28,7 +28,7 @@ static void canonicalizeModifiedLoop(Node* n) { g->insertConstant(std::numeric_limits::max())); auto inp_condition = toIValue(loop.inputCond()); - if (inp_condition == c10::nullopt || inp_condition->toBool() == false) { + if (inp_condition == std::nullopt || inp_condition->toBool() == false) { condition = g->insert(aten::__and__, {condition, loop.inputCond()}); } loop.replaceInputCondition(condition); diff --git a/torch/csrc/jit/frontend/concrete_module_type.cpp b/torch/csrc/jit/frontend/concrete_module_type.cpp index c15116ac3e2446..cfdef51afc31c1 100644 --- a/torch/csrc/jit/frontend/concrete_module_type.cpp +++ b/torch/csrc/jit/frontend/concrete_module_type.cpp @@ -151,7 +151,7 @@ TypePtr ConcreteModuleType::getJitType() const { std::optional ConcreteModuleType::getPyClass() const { if (!data_.pyClass_) { - return c10::nullopt; + return std::nullopt; } return data_.pyClass_; } @@ -162,7 +162,7 @@ std::optional> ConcreteModuleType::findOverloads( if (it != data_.overloads_.end()) { return it->second; } - return c10::nullopt; + return std::nullopt; } std::optional ConcreteModuleType::findFunctionAttribute( @@ -171,7 +171,7 @@ std::optional ConcreteModuleType::findFunctionAttribute( if (it != data_.functionAttributes_.end()) { return it->second.function_->function(); } - return c10::nullopt; + return std::nullopt; } std::optional ConcreteModuleType::findBuiltinFunction( @@ -180,7 +180,7 @@ std::optional ConcreteModuleType::findBuiltinFunction( if (it != data_.builtinFunctions_.end()) { return it->second; } - return c10::nullopt; + return std::nullopt; } std::optional ConcreteModuleType::findFailedAttribute( @@ -189,7 +189,7 @@ std::optional ConcreteModuleType::findFailedAttribute( if (it != data_.failedAttributes_.end()) { return it->second; } - return c10::nullopt; + return std::nullopt; } bool ConcreteModuleType::isIgnoredAttribute(const std::string& name) const { diff --git a/torch/csrc/jit/frontend/function_schema_parser.cpp b/torch/csrc/jit/frontend/function_schema_parser.cpp index ba86a891d31dd0..00ccce8567fb61 100644 --- a/torch/csrc/jit/frontend/function_schema_parser.cpp +++ b/torch/csrc/jit/frontend/function_schema_parser.cpp @@ -2,10 +2,10 @@ #include #include -#include #include #include #include +#include #include #include @@ -25,7 +25,7 @@ struct SchemaParser { explicit SchemaParser(const std::string& str, bool allow_typevars) : L(std::make_shared( c10::string_view(str), - c10::nullopt, + std::nullopt, 0, nullptr, Source::DONT_COPY)), diff --git a/torch/csrc/jit/frontend/ir_emitter.cpp b/torch/csrc/jit/frontend/ir_emitter.cpp index 350305b83567c8..788483aef224ff 100644 --- a/torch/csrc/jit/frontend/ir_emitter.cpp +++ b/torch/csrc/jit/frontend/ir_emitter.cpp @@ -32,8 +32,8 @@ #include -#include #include +#include #include #include @@ -292,7 +292,7 @@ struct Environment { if (msg != runner->error_messages.end()) { return msg->second(); } else { - return c10::nullopt; + return std::nullopt; } } @@ -1267,7 +1267,7 @@ struct to_ir { {}); auto refinements = RefinementSet(findIsNoneRefinements( cond_op.lhs(), lhs_val, cond_op.rhs(), rhs_val, expr.kind())); - return CondValue(cond_value, refinements, c10::nullopt); + return CondValue(cond_value, refinements, std::nullopt); } } break; default: { @@ -1294,7 +1294,7 @@ struct to_ir { } } auto expr_out = emitToBool(expr.range(), emitExpr(expr)); - std::optional static_if = c10::nullopt; + std::optional static_if = std::nullopt; auto kind = expr_out->node()->kind(); if (kind == aten::is_scripting) { static_if = true; @@ -2291,7 +2291,7 @@ struct to_ir { Value* result = graph->insertNode(graph->createIsInstance(lhs_val, rhs_types)) ->output(); - return CondValue(result, std::move(refinement), c10::nullopt); + return CondValue(result, std::move(refinement), std::nullopt); } void emitIf(const If& stmt) { @@ -2752,7 +2752,7 @@ struct to_ir { getAugOp(stmt, lhs->type()), /*args=*/{lhs, rhs}, /*kwargs=*/{}, - /*self=*/c10::nullopt); + /*self=*/std::nullopt); } } @@ -2968,7 +2968,7 @@ struct to_ir { auto outputs = rhs_output->asTuple( rhs_loc, method, - starred_unpack ? c10::nullopt : std::optional{n_binders}); + starred_unpack ? std::nullopt : std::optional{n_binders}); if (outputs.size() < n_binders) { throw ErrorReport(tl) << "need " << (starred_unpack ? "at least " : "") << n_binders @@ -4796,11 +4796,11 @@ struct to_ir { tuple_args.reserve(3); start ? tuple_args.emplace_back(start) - : tuple_args.emplace_back(c10::nullopt); + : tuple_args.emplace_back(std::nullopt); end ? tuple_args.emplace_back(end) - : tuple_args.emplace_back(c10::nullopt); + : tuple_args.emplace_back(std::nullopt); step ? tuple_args.emplace_back(step) - : tuple_args.emplace_back(c10::nullopt); + : tuple_args.emplace_back(std::nullopt); return emitTupleSlice(loc, args[0], tuple_args); } @@ -4886,7 +4886,7 @@ struct to_ir { }; std::vector dims(subscript_exprs.size()); std::vector> exprs( - subscript_exprs.size(), c10::nullopt); + subscript_exprs.size(), std::nullopt); auto handle_indexing = [&](const Expr& subscript_expr, int expr_idx, @@ -5231,7 +5231,7 @@ struct to_ir { val_range, "begin", emitExpr(Expr(slice.start().get()))); tuple_args.emplace_back(begin); } else { - tuple_args.emplace_back(c10::nullopt); + tuple_args.emplace_back(std::nullopt); } if (slice.end().present()) { @@ -5239,7 +5239,7 @@ struct to_ir { NamedValue(val_range, "end", emitExpr(Expr(slice.end().get()))); tuple_args.emplace_back(end); } else { - tuple_args.emplace_back(c10::nullopt); + tuple_args.emplace_back(std::nullopt); } if (slice.step().present()) { @@ -5247,7 +5247,7 @@ struct to_ir { NamedValue(val_range, "step", emitExpr(Expr(slice.step().get()))); tuple_args.emplace_back(step); } else { - tuple_args.emplace_back(c10::nullopt); + tuple_args.emplace_back(std::nullopt); } auto tupleSliceValue = emitTupleSlice(val_range, s_tuple_val, tuple_args); @@ -5327,7 +5327,7 @@ struct FunctionResolver : public Resolver { CompilationUnit::CompilationUnit(const std::string& source) : CompilationUnit() { // calles the define with native resolver to generate the graph for functions - define(c10::nullopt, source, nativeResolver(), nullptr); + define(std::nullopt, source, nativeResolver(), nullptr); } // This pair represents a pair of functions (getter and setter) obtained from diff --git a/torch/csrc/jit/frontend/parse_string_literal.h b/torch/csrc/jit/frontend/parse_string_literal.h index 5b924864bebd8a..13bbbf89cc343f 100644 --- a/torch/csrc/jit/frontend/parse_string_literal.h +++ b/torch/csrc/jit/frontend/parse_string_literal.h @@ -1,7 +1,7 @@ #pragma once -#include #include #include +#include namespace torch { namespace jit { @@ -15,17 +15,17 @@ inline bool isCharCount(char c, const std::string& str, size_t start, int len) { inline std::optional parseOctal(const std::string& str, size_t pos) { //\xxx where x are 0-7 if (pos + 3 >= str.size()) - return c10::nullopt; + return std::nullopt; size_t c = 0; for (size_t i = 1, b = 64; i < 4; ++i, b /= 8) { // NOLINTNEXTLINE(bugprone-signed-char-misuse) int d = str[pos + i]; if (d < '0' || d > '7') - return c10::nullopt; + return std::nullopt; c += b * (d - '0'); } if (c >= 256) - return c10::nullopt; + return std::nullopt; return c; } diff --git a/torch/csrc/jit/frontend/parser.cpp b/torch/csrc/jit/frontend/parser.cpp index ae2c98028e0717..5bf6144d8c7d5b 100644 --- a/torch/csrc/jit/frontend/parser.cpp +++ b/torch/csrc/jit/frontend/parser.cpp @@ -1,10 +1,10 @@ #include -#include #include #include #include #include +#include namespace torch::jit { @@ -241,7 +241,7 @@ struct ParserImpl { return create_compound('=', r, {}); // no reduction } break; default: - return c10::nullopt; + return std::nullopt; } } TreeRef parseTrinary( diff --git a/torch/csrc/jit/frontend/schema_matching.cpp b/torch/csrc/jit/frontend/schema_matching.cpp index 87ec9992141d89..a91f204a404cfd 100644 --- a/torch/csrc/jit/frontend/schema_matching.cpp +++ b/torch/csrc/jit/frontend/schema_matching.cpp @@ -3,7 +3,6 @@ #include #include #include -#include #include #include #include @@ -13,6 +12,7 @@ #include #include #include +#include namespace torch::jit { @@ -261,7 +261,7 @@ std::optional findInputWithName( return i; } } - return c10::nullopt; + return std::nullopt; } /// Creates a list with the provided values if each value's type can be matched @@ -364,7 +364,7 @@ static std::optional tryMatchSchema( std::ostream* failure_messages, bool allow_conversions) { if (isBlockListedSchema(schema)) { - return c10::nullopt; + return std::nullopt; } auto err = [&]() -> std::ostream& { @@ -392,7 +392,7 @@ static std::optional tryMatchSchema( std::optional actual_named_value; if (arg.name() == "self" && self) { actual_named_value = self; - self = c10::nullopt; + self = std::nullopt; } else if (!arg.kwarg_only() && used_args < args.size()) { // Try to convert all the remaining non-kwarg arguments (used_args) to a // list. Allow zeros(IntArrayRef sizes) to work with zeros(1, 2) or @@ -417,7 +417,7 @@ static std::optional tryMatchSchema( allow_conversions, type_env); if (!list) { - return c10::nullopt; + return std::nullopt; } used_args = args.size(); positional_inputs.push_back(list); @@ -437,7 +437,7 @@ static std::optional tryMatchSchema( err() << "Argument " << nv.name() << " specified twice in schema, submit a bug report!\n"; } - return c10::nullopt; + return std::nullopt; } used_kwarg[*kwarg_idx] = true; actual_named_value = nv; @@ -450,7 +450,7 @@ static std::optional tryMatchSchema( err() << "Argument " << schema.arguments()[schema_i].name() << " not provided.\n"; } - return c10::nullopt; + return std::nullopt; } // Make sure the actual_named_value found matches the type of arg @@ -464,16 +464,16 @@ static std::optional tryMatchSchema( allow_conversions, type_env); if (!positional) { - return c10::nullopt; + return std::nullopt; } positional_inputs.push_back(positional); } // check for unused self argument - if (self != c10::nullopt) { + if (self != std::nullopt) { if (failure_messages) { err() << "Provided self argument not used in schema.\n"; } - return c10::nullopt; + return std::nullopt; } if (schema.is_vararg()) { @@ -488,7 +488,7 @@ static std::optional tryMatchSchema( err() << "Expected at most " << used_args << " arguments " << "but found " << args.size() << " positional arguments.\n"; } - return c10::nullopt; + return std::nullopt; } // check for unused kwargs for (const auto i : c10::irange(kwargs.size())) { @@ -501,7 +501,7 @@ static std::optional tryMatchSchema( err() << "Keyword argument " << nv.name() << " specified twice.\n"; } } - return c10::nullopt; + return std::nullopt; } } @@ -518,7 +518,7 @@ static std::optional tryMatchSchema( std::all_of(returns.begin(), returns.end(), [&](const Argument& r) { return r.name().length() > 0; }); - c10::OptNameList return_field_names = c10::nullopt; + c10::OptNameList return_field_names = std::nullopt; if (return_has_field_names) { return_field_names = fmap(returns, [&](const Argument& r) { return r.name(); }); @@ -633,7 +633,7 @@ static Value* packOutputs( if (field_names) { auto types = fmap(values, [](Value* v) { return v->type(); }); named_tuple = - TupleType::createNamed(c10::nullopt, field_names.value(), types); + TupleType::createNamed(std::nullopt, field_names.value(), types); } return g.insertNode(g.createTuple(values, named_tuple))->output(); } diff --git a/torch/csrc/jit/frontend/schema_matching.h b/torch/csrc/jit/frontend/schema_matching.h index 0c69df521df6a2..8a24863cbe71d0 100644 --- a/torch/csrc/jit/frontend/schema_matching.h +++ b/torch/csrc/jit/frontend/schema_matching.h @@ -10,7 +10,7 @@ namespace jit { // Try to match a list of inputs and keyword 'attributes' to this // schema. Return the flat list of positional inputs to the call or -// `c10::nullopt` on failure (`failure_messages` contains a good error +// `std::nullopt` on failure (`failure_messages` contains a good error // report in this case) struct MatchedSchema { @@ -28,7 +28,7 @@ TORCH_API MatchedSchema matchSchema( Graph& graph, at::ArrayRef args, at::ArrayRef kwargs, - const std::optional& self = c10::nullopt); + const std::optional& self = std::nullopt); TORCH_API std::pair matchSchemas( const std::vector& schemas, @@ -36,7 +36,7 @@ TORCH_API std::pair matchSchemas( Graph& graph, at::ArrayRef args, at::ArrayRef kwargs, - const std::optional& self = c10::nullopt, + const std::optional& self = std::nullopt, bool render_errors = false); TORCH_API bool convertibleToList( @@ -51,7 +51,7 @@ TORCH_API Value* emitBuiltinCall( Symbol name, at::ArrayRef args, at::ArrayRef kwargs, - const std::optional& self = c10::nullopt); + const std::optional& self = std::nullopt); TORCH_API std::optional findInputWithName( const std::string& name, diff --git a/torch/csrc/jit/frontend/schema_type_parser.cpp b/torch/csrc/jit/frontend/schema_type_parser.cpp index 2adacb976a042f..f7bc4a04cb6ce5 100644 --- a/torch/csrc/jit/frontend/schema_type_parser.cpp +++ b/torch/csrc/jit/frontend/schema_type_parser.cpp @@ -155,7 +155,7 @@ std::optional SchemaTypeParser::parseAliasAnnotation() { Symbol::fromQualString("alias::$" + std::to_string(next_id++))); alias_info.setIsWrite(true); } else { - return c10::nullopt; + return std::nullopt; } return alias_info; @@ -172,7 +172,7 @@ std::optional SchemaTypeParser::parseTensorDType( if (type != type_map.end()) { return type->second; } - return c10::nullopt; + return std::nullopt; } std::optional SchemaTypeParser::tryToParseDeviceType() { @@ -297,7 +297,7 @@ TypePtr SchemaTypeParser::parseRefinedTensor() { // Parsing ranks, supports mix of sized and unsized ranks, or, just strided // ranks if (L.cur().kind == '*') { - dims.emplace_back(c10::nullopt); + dims.emplace_back(std::nullopt); L.next(); if (L.cur().kind == ':') { throw ErrorReport(L.cur()) << "Strides for unsized ranks not supported"; diff --git a/torch/csrc/jit/frontend/script_type_parser.cpp b/torch/csrc/jit/frontend/script_type_parser.cpp index 9295a3ed4007ab..db21737f4c4ba1 100644 --- a/torch/csrc/jit/frontend/script_type_parser.cpp +++ b/torch/csrc/jit/frontend/script_type_parser.cpp @@ -137,10 +137,10 @@ std::optional> ScriptTypeParser::parseBroadcastList( } if (expr.kind() != TK_SUBSCRIPT) - return c10::nullopt; + return std::nullopt; auto subscript = Subscript(expr); if (subscript.value().kind() != TK_VAR) - return c10::nullopt; + return std::nullopt; auto var = Var(subscript.value()); auto subscript_exprs = subscript.subscript_exprs(); @@ -151,10 +151,10 @@ std::optional> ScriptTypeParser::parseBroadcastList( TypePtr opt_type = OptionalType::create(broadcast_list->first); return std::pair(opt_type, broadcast_list->second); } else { - return c10::nullopt; + return std::nullopt; } } else if (var.name().name().find("BroadcastingList") != 0) { - return c10::nullopt; + return std::nullopt; } if (subscript_exprs.size() != 1) @@ -352,7 +352,7 @@ std::vector ScriptTypeParser::evaluateDefaults( CompilationUnit cu; cu.define( - c10::nullopt, + std::nullopt, /*properties=*/{}, /*propResolvers=*/{}, {def}, @@ -407,7 +407,7 @@ std::vector ScriptTypeParser::parseArgsFromDecl( auto decl_arg = *it; TypePtr type; - std::optional N = c10::nullopt; + std::optional N = std::nullopt; if (!decl_arg.type().present()) { // If this param doesn't have a type, default to "tensor" type = TensorType::getInferred(); @@ -421,7 +421,7 @@ std::vector ScriptTypeParser::parseArgsFromDecl( type = parseTypeFromExpr(decl_arg.type().get()); } } - std::optional default_value = c10::nullopt; + std::optional default_value = std::nullopt; if (decl_arg.defaultValue().present()) { default_value = *defaults_it++; } @@ -431,7 +431,7 @@ std::vector ScriptTypeParser::parseArgsFromDecl( N, default_value, decl_arg.kwarg_only(), - /*alias_info=*/c10::nullopt); + /*alias_info=*/std::nullopt); retval.push_back(arg); } return retval; @@ -455,8 +455,8 @@ std::vector ScriptTypeParser::parseReturnFromDecl(const Decl& decl) { return {Argument( "", parsed_type, - /*N =*/c10::nullopt, - /*default_value =*/c10::nullopt, + /*N =*/std::nullopt, + /*default_value =*/std::nullopt, /*kwarg_only =*/false)}; } FunctionSchema ScriptTypeParser::parseSchemaFromDef( diff --git a/torch/csrc/jit/frontend/source_range.cpp b/torch/csrc/jit/frontend/source_range.cpp index 20ffbfd4601e36..b1dfecbbf6418c 100644 --- a/torch/csrc/jit/frontend/source_range.cpp +++ b/torch/csrc/jit/frontend/source_range.cpp @@ -154,7 +154,7 @@ size_t SourceRangeHasher::operator()(const torch::jit::SourceRange& key) const { std::optional Source::findSourceRangeThatGenerated( const SourceRange& range) { if (!gen_ranges_) { - return c10::nullopt; + return std::nullopt; } return gen_ranges_->findSourceRangeThatGenerated(range); } diff --git a/torch/csrc/jit/frontend/source_range.h b/torch/csrc/jit/frontend/source_range.h index 1f8715ad009691..a8f22a800b022f 100644 --- a/torch/csrc/jit/frontend/source_range.h +++ b/torch/csrc/jit/frontend/source_range.h @@ -1,6 +1,6 @@ #pragma once #include -#include +#include #include #include @@ -190,7 +190,7 @@ struct TORCH_API Source { explicit Source( c10::string_view text_view, - std::optional filename = c10::nullopt, + std::optional filename = std::nullopt, size_t starting_line_no = 0, std::shared_ptr gen_ranges = nullptr, CopiesString copies_str = COPIES_STRING) @@ -210,7 +210,7 @@ struct TORCH_API Source { explicit Source( StringCordView str, - std::optional filename = c10::nullopt, + std::optional filename = std::nullopt, size_t starting_line_no = 0, std::shared_ptr gen_ranges = nullptr) : text_view_(std::move(str)), @@ -360,7 +360,7 @@ struct TORCH_API SourceRange { std::optional> file_line_col() const { if (!source_view_ || !source()->filename()) { - return c10::nullopt; + return std::nullopt; } auto lineno = source_view_->lineno_for_offset(start_); @@ -383,7 +383,7 @@ struct TORCH_API SourceRange { std::optional findSourceRangeThatGenerated() const { if (!source_view_) { - return c10::nullopt; + return std::nullopt; } return source_view_->findSourceRangeThatGenerated(*this); } diff --git a/torch/csrc/jit/frontend/sugared_value.cpp b/torch/csrc/jit/frontend/sugared_value.cpp index 4b65903529d23a..94a11b21b1f22c 100644 --- a/torch/csrc/jit/frontend/sugared_value.cpp +++ b/torch/csrc/jit/frontend/sugared_value.cpp @@ -658,7 +658,7 @@ void IterableTree::addChild( // iterables run for the minimum length of all its leaves unroll_length_ = std::min(*child_len, *unroll_length_); } else { - unroll_length_ = c10::nullopt; + unroll_length_ = std::nullopt; } } children_.push_back(iter_value); diff --git a/torch/csrc/jit/frontend/sugared_value.h b/torch/csrc/jit/frontend/sugared_value.h index 97b092cad3ce7a..1ca59ced6e68b8 100644 --- a/torch/csrc/jit/frontend/sugared_value.h +++ b/torch/csrc/jit/frontend/sugared_value.h @@ -1,7 +1,7 @@ #pragma once -#include #include #include +#include #include #include @@ -122,13 +122,13 @@ struct TORCH_API SugaredValue // to support containers of Heterogenous types, like Module Containers & // Tuples virtual std::optional staticLen() { - return c10::nullopt; + return std::nullopt; } // When iterating over this SugaredValue, should we emit the for loop as an // unrolled loop. bool shouldEmitUnrolled() { - return staticLen() != c10::nullopt; + return staticLen() != std::nullopt; } // return length of this thing, if not then it can't be iterated. @@ -323,7 +323,7 @@ struct TORCH_API BuiltinModule : public SugaredValue { } auto sym = Symbol::fromQualString(name + "::" + field); - return std::make_shared(sym, c10::nullopt); + return std::make_shared(sym, std::nullopt); } private: @@ -506,7 +506,7 @@ struct TORCH_API PrintValue : public SugaredValue { // is a noop when the input is a subtype of 'type' struct TORCH_API CastValue : public BuiltinFunction { CastValue(TypePtr type, c10::Symbol method) - : BuiltinFunction(method, c10::nullopt), type_(std::move(type)) {} + : BuiltinFunction(method, std::nullopt), type_(std::move(type)) {} std::shared_ptr call( const SourceRange& loc, GraphFunction& m, @@ -638,7 +638,7 @@ struct TORCH_API RangeValue : SugaredValue { const SourceRange& loc, GraphFunction& m, std::vector input, - std::optional static_len = c10::nullopt); + std::optional static_len = std::nullopt); std::string kind() const override { return "range"; @@ -730,7 +730,7 @@ struct TORCH_API IterableTree : SugaredValue { TypePtr type_hint = nullptr) override; private: - std::optional unroll_length_ = c10::nullopt; + std::optional unroll_length_ = std::nullopt; std::vector children_; }; diff --git a/torch/csrc/jit/frontend/tracer.cpp b/torch/csrc/jit/frontend/tracer.cpp index 9616e0f83dfbe2..a90d5bb897f454 100644 --- a/torch/csrc/jit/frontend/tracer.cpp +++ b/torch/csrc/jit/frontend/tracer.cpp @@ -818,8 +818,8 @@ void addInputs(Node* n, const char* name, std::optional value) { n, name, value.has_value() - ? c10::make_optional(value->guard_int(__FILE__, __LINE__)) - : c10::nullopt); + ? std::make_optional(value->guard_int(__FILE__, __LINE__)) + : std::nullopt); } void addInputs( diff --git a/torch/csrc/jit/ir/alias_analysis.cpp b/torch/csrc/jit/ir/alias_analysis.cpp index f9b2ed5dd7ce9d..6f674f30b90fca 100644 --- a/torch/csrc/jit/ir/alias_analysis.cpp +++ b/torch/csrc/jit/ir/alias_analysis.cpp @@ -105,7 +105,7 @@ class MutableTypePtrHelper { } } if (mutable_types.empty()) { - return c10::nullopt; + return std::nullopt; } return mutable_types; } @@ -121,7 +121,7 @@ class MutableTypePtrHelper { return {AliasTypeSet{ FutureType::create(*toSingleType(*maybe_mut_types))}}; } - return c10::nullopt; + return std::nullopt; } case TypeKind::AwaitType: { if (auto maybe_mut_types = mapTypeToAliasTypeSet( @@ -129,7 +129,7 @@ class MutableTypePtrHelper { return { AliasTypeSet{AwaitType::create(*toSingleType(*maybe_mut_types))}}; } - return c10::nullopt; + return std::nullopt; } case TypeKind::TupleType: { std::vector mutable_types; @@ -142,12 +142,12 @@ class MutableTypePtrHelper { } } if (mutable_types.empty()) { - return c10::nullopt; + return std::nullopt; } return {AliasTypeSet{TupleType::create(mutable_types)}}; } default: - return c10::nullopt; + return std::nullopt; } } ska::flat_hash_map* mutable_type_cache_; @@ -1896,7 +1896,7 @@ bool AliasDb::mayAliasWildcard(const at::ArrayRef vs) const { std::optional AliasDb::tryGetOrCreateWildcard(const TypePtr& type) { auto maybe_mut_types = mapTypeToAliasTypeSetPtr(type); if (!maybe_mut_types) { - return c10::nullopt; + return std::nullopt; } auto mut_type = toSingleType(*maybe_mut_types); auto existing_wildcard = wildcardIndex_.find(*mut_type); @@ -1970,7 +1970,7 @@ std::optional AliasDb::setWildcard(const Value* v) { std::optional maybe_wildcardElement = tryGetOrCreateWildcard(v->type()); if (!maybe_wildcardElement) { - return c10::nullopt; + return std::nullopt; } // Ensure that we create a corresponding Element for `v` still, as it is an // invariant that all mutable values have an Element diff --git a/torch/csrc/jit/ir/constants.cpp b/torch/csrc/jit/ir/constants.cpp index ef697a5af76800..a0f8c8760a130b 100644 --- a/torch/csrc/jit/ir/constants.cpp +++ b/torch/csrc/jit/ir/constants.cpp @@ -69,7 +69,7 @@ std::optional tryInsertConstant( at::Tensor ref = val.toTensor(); if (!insertableTensor(val.toTensor())) { n->destroy(); - return c10::nullopt; + return std::nullopt; } if (!ref.defined()) { n->destroy(); @@ -99,7 +99,7 @@ std::optional tryInsertConstant( n->output()->setType(val.type()); } else { n->destroy(); - return c10::nullopt; + return std::nullopt; } } else if (val.isString()) { n->s_(attr::value, val.toStringRef()); @@ -125,7 +125,7 @@ std::optional tryInsertConstant( n->output()->setType(val.type()); } else { n->destroy(); - return c10::nullopt; + return std::nullopt; }; } else if (val.isObject()) { const auto& ref = val.toObjectRef(); @@ -137,14 +137,14 @@ std::optional tryInsertConstant( n->output()->setType(val.type()); } else { n->destroy(); - return c10::nullopt; + return std::nullopt; } } else if ((val.isGenericDict() && insertableIValue(val)) || (val.isEnum())) { n->ival_(attr::value, val); n->output()->setType(val.type()); } else { n->destroy(); - return c10::nullopt; + return std::nullopt; } if (loc) n->setSourceRange(*loc); @@ -155,7 +155,7 @@ std::optional tryInsertConstant( std::optional toIValue(const Value* v) { if (v->node()->kind() != prim::Constant || v->type()->cast()) { - return c10::nullopt; + return std::nullopt; } const Node* node = v->node(); const TypePtr& type = v->type(); diff --git a/torch/csrc/jit/ir/constants.h b/torch/csrc/jit/ir/constants.h index 118da1e932d9c1..160dad5eab4c61 100644 --- a/torch/csrc/jit/ir/constants.h +++ b/torch/csrc/jit/ir/constants.h @@ -25,27 +25,27 @@ struct TORCH_API constant_not_supported_error : public std::runtime_error { TORCH_API Value* insertConstant( Graph& g, const IValue& val, - std::optional loc = c10::nullopt, - std::optional scope = c10::nullopt); + std::optional loc = std::nullopt, + std::optional scope = std::nullopt); // note: prefer g.insertConsant(val, loc) which does exactly the same thing // this function is only declared/defined here because its implementation is // closely related to the implementation of prim::Constant that is also in // constants.cpp. // -// returns a c10::nullopt if the IValue kind cannot be inserted as a constant +// returns a std::nullopt if the IValue kind cannot be inserted as a constant TORCH_API std::optional tryInsertConstant( Graph& g, const IValue& val, - std::optional loc = c10::nullopt, - std::optional scope = c10::nullopt); + std::optional loc = std::nullopt, + std::optional scope = std::nullopt); //////////////////////////////////////////////////////////////////////////////// // Helper for retrieving constants //////////////////////////////////////////////////////////////////////////////// // attempt to convert a (possibly constant) Value* into an interpreter value -// (IValue). returns c10::nullopt if the Value* was not constant +// (IValue). returns std::nullopt if the Value* was not constant TORCH_API std::optional toIValue(const Value* v); // if a value is a constant then try to turn into type T using the @@ -55,7 +55,7 @@ std::optional constant_as(const Value* v) { if (auto ivalue = toIValue(v)) { return ivalue->to(); } - return c10::nullopt; + return std::nullopt; } } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/ir/ir.cpp b/torch/csrc/jit/ir/ir.cpp index a6b0116d7fb63f..3b449ea7ea21f7 100644 --- a/torch/csrc/jit/ir/ir.cpp +++ b/torch/csrc/jit/ir/ir.cpp @@ -412,7 +412,7 @@ std::ostream& operator<<(std::ostream& out, const Graph& g) { static void checkSameDevice(const Node* node) { bool has_device = false; - std::optional device = c10::nullopt; + std::optional device = std::nullopt; auto checkValue = [&](const Value* v) { if (TensorTypePtr type = v->type()->cast()) { if (type->device() && !has_device) { @@ -1297,7 +1297,7 @@ Node::Node(Graph* graph_, NodeKind kind_) graph_(graph_), owning_block_(nullptr), scope_(graph_->current_scope_), - callstack_(c10::nullopt), + callstack_(std::nullopt), op_(nullptr), topo_position_(0) { graph_->all_nodes.emplace(this); @@ -2101,11 +2101,11 @@ std::vector inlineCallTo( std::unordered_map new_callstack_entries; - std::optional module_instance_info = c10::nullopt; + std::optional module_instance_info = std::nullopt; if (to_replace->kind() == prim::CallMethod) { auto class_type_ptr = to_replace->input(0)->type()->cast(); if (to_replace->input(0)->node()->kind() == prim::GetAttr) { - module_instance_info = c10::make_optional(ModuleInstanceInfo( + module_instance_info = std::make_optional(ModuleInstanceInfo( class_type_ptr, to_replace->input(0)->node()->s(attr::name))); } else if ( !to_replace->owningGraph()->inputs().empty() && @@ -2113,11 +2113,11 @@ std::vector inlineCallTo( // This CallMethod must correspond to method of the same object // to which this graph belongs. module_instance_info = - c10::make_optional(ModuleInstanceInfo(class_type_ptr, "SELF")); + std::make_optional(ModuleInstanceInfo(class_type_ptr, "SELF")); } else { // Not sure if it is possible to come here ever. // TODO: Remove this else. Or add assert - module_instance_info = c10::make_optional( + module_instance_info = std::make_optional( ModuleInstanceInfo(class_type_ptr, "INSTANCE_NAME_UNKNOWN")); } } diff --git a/torch/csrc/jit/ir/ir.h b/torch/csrc/jit/ir/ir.h index 859da3cb3cae99..3db67b2f9798ce 100644 --- a/torch/csrc/jit/ir/ir.h +++ b/torch/csrc/jit/ir/ir.h @@ -20,7 +20,7 @@ #include #include #include -#include +#include #include #include @@ -348,7 +348,7 @@ struct TORCH_API Node { // is changed, we need to rely on this name // to retrieve old schemas to successfully apply upgraders // for this operator. - std::optional historic_schema_name_ = c10::nullopt; + std::optional historic_schema_name_ = std::nullopt; protected: Node(Graph* graph_, NodeKind kind_); // defined after graph @@ -534,7 +534,7 @@ struct TORCH_API Node { if (auto v = get(name)) { return v->template to(); } - return c10::nullopt; + return std::nullopt; } // Returns true if the value of input name is statically known @@ -1368,8 +1368,8 @@ struct Graph : std::enable_shared_from_this { // Insert constant IValue into the graph. TORCH_API Value* insertConstant( const IValue& val, - std::optional loc = c10::nullopt, - std::optional scope = c10::nullopt); + std::optional loc = std::nullopt, + std::optional scope = std::nullopt); // Schema-driven insert: // This inserts a node into the graph with inputs determined from args and @@ -1733,14 +1733,14 @@ struct OperatorMap { std::optional find(const Operator& op) { const auto it = map.find(Symbol::fromQualString(op.schema().name())); if (it == map.end()) { - return c10::nullopt; + return std::nullopt; } for (auto vit = it->second.begin(); vit != it->second.end(); ++vit) { if (vit->first->schema() == op.schema()) { return vit->second; } } - return c10::nullopt; + return std::nullopt; } // TODO: return iterator @@ -1809,14 +1809,14 @@ struct FunctionSchemaMap { std::optional find(const FunctionSchema& schema) const { const auto it = map.find(Symbol::fromQualString(schema.name())); if (it == map.end()) { - return c10::nullopt; + return std::nullopt; } for (auto vit = it->second.begin(); vit != it->second.end(); ++vit) { if (vit->first == schema) { return vit->second; } } - return c10::nullopt; + return std::nullopt; } // TODO: return iterator diff --git a/torch/csrc/jit/ir/scope.h b/torch/csrc/jit/ir/scope.h index 54498030322384..5eafeb0fc4aac2 100644 --- a/torch/csrc/jit/ir/scope.h +++ b/torch/csrc/jit/ir/scope.h @@ -1,10 +1,10 @@ #pragma once #include #include -#include #include #include #include +#include #include namespace torch { diff --git a/torch/csrc/jit/mobile/compatibility/backport_manager.cpp b/torch/csrc/jit/mobile/compatibility/backport_manager.cpp index f0dd562cc1cd27..3b8319ad8f90ee 100644 --- a/torch/csrc/jit/mobile/compatibility/backport_manager.cpp +++ b/torch/csrc/jit/mobile/compatibility/backport_manager.cpp @@ -408,7 +408,7 @@ std::stringstream backport_v6_to_v5(std::stringstream& input_model_stream) { } // Loading the TS module is required for this backport, because bytecode needs // to be re-emitted (refer to the comments below) - Module torch_script = torch::jit::load(rai, c10::nullopt, extra_files); + Module torch_script = torch::jit::load(rai, std::nullopt, extra_files); // The RAII guard to change the flag, emitBytecodeDefaultInputs, to true, so // that TS stores the default argument values in the constant table, and emits @@ -476,7 +476,7 @@ std::stringstream backport_v7_to_v6(std::stringstream& input_model_stream) { } // Loading the TS module is required for this backport, because bytecode needs // to be re-emitted (refer to the comments below) - Module torch_script = torch::jit::load(rai, c10::nullopt, extra_files); + Module torch_script = torch::jit::load(rai, std::nullopt, extra_files); // The RAII guard to change the flag, emit_default_input_instructions, to // false to keep the same behavior in bytecode version 6. Change the flag, @@ -502,7 +502,7 @@ std::stringstream backport_v7_to_v6(std::stringstream& input_model_stream) { std::stringstream backport_v9_to_v8(std::stringstream& input_model_stream) { ExtraFilesMap extra_files; Module torch_script = - torch::jit::load(input_model_stream, c10::nullopt, extra_files); + torch::jit::load(input_model_stream, std::nullopt, extra_files); std::stringstream intermediate_model_stream; // TODO(@pavithran) : Check if debug info is available and use load/save while // backporting hardcode debaug info to be false untill supported. @@ -540,7 +540,7 @@ std::stringstream backport_v8_to_v7(std::stringstream& input_model_stream) { extra_files.emplace(record.substr(found + 1), ""); } } - Module torch_script = torch::jit::load(rai, c10::nullopt, extra_files); + Module torch_script = torch::jit::load(rai, std::nullopt, extra_files); std::stringstream intermediate_model_stream; { BytecodeEmitModeGuard argNumGuard( diff --git a/torch/csrc/jit/mobile/compatibility/runtime_compatibility.h b/torch/csrc/jit/mobile/compatibility/runtime_compatibility.h index 2e65f1f38bd8d2..d89165bb1d2950 100644 --- a/torch/csrc/jit/mobile/compatibility/runtime_compatibility.h +++ b/torch/csrc/jit/mobile/compatibility/runtime_compatibility.h @@ -1,7 +1,7 @@ #pragma once #include -#include +#include #include #include diff --git a/torch/csrc/jit/mobile/flatbuffer_loader.cpp b/torch/csrc/jit/mobile/flatbuffer_loader.cpp index bca40735891394..2094d4a87a1719 100644 --- a/torch/csrc/jit/mobile/flatbuffer_loader.cpp +++ b/torch/csrc/jit/mobile/flatbuffer_loader.cpp @@ -19,7 +19,6 @@ #include #include #include -#include #include #include #include @@ -35,6 +34,7 @@ #include #include #include +#include #ifndef DISABLE_UPGRADER #include @@ -364,7 +364,7 @@ std::unique_ptr FlatbufferLoader::parseFunction( (operator_version < caffe2::serialize::kProducedFileFormatVersion); for (const auto* op : *method->operators()) { - std::optional num_args = c10::nullopt; + std::optional num_args = std::nullopt; if (op->num_args_serialized() > -1) { num_args = op->num_args_serialized(); } @@ -399,7 +399,7 @@ std::unique_ptr FlatbufferLoader::parseFunction( auto arg = c10::Argument( arg_tb->name()->str(), std::move(type_ptr), - c10::nullopt /*N*/, + std::nullopt /*N*/, std::move(default_value)); args.emplace_back(std::move(arg)); } diff --git a/torch/csrc/jit/mobile/flatbuffer_loader.h b/torch/csrc/jit/mobile/flatbuffer_loader.h index 9ac9636f3f14be..62b2c795bf84d4 100644 --- a/torch/csrc/jit/mobile/flatbuffer_loader.h +++ b/torch/csrc/jit/mobile/flatbuffer_loader.h @@ -9,8 +9,8 @@ #include #include #include -#include #include +#include /** * Defines the public API for loading flatbuffer-serialized mobile modules. @@ -58,7 +58,7 @@ using ExtraFilesMap = std::unordered_map; TORCH_API mobile::Module parse_and_initialize_mobile_module( void* data, size_t size, // of `data`, in bytes. - std::optional device = c10::nullopt, + std::optional device = std::nullopt, ExtraFilesMap* extra_files = nullptr, bool should_copy_tensor_memory = false); @@ -74,7 +74,7 @@ TORCH_API mobile::Module parse_and_initialize_mobile_module( TORCH_API mobile::Module parse_and_initialize_mobile_module( std::shared_ptr data, size_t size, // of `data`, in bytes. - std::optional device = c10::nullopt, + std::optional device = std::nullopt, ExtraFilesMap* extra_files = nullptr); // Parse a mobile::Module from raw bytes, also returning JIT-related metadata. @@ -87,7 +87,7 @@ TORCH_API mobile::Module parse_and_initialize_mobile_module_for_jit( size_t size, // of `data`, in bytes. ExtraFilesMap& jit_sources, std::vector& jit_constants, - std::optional device = c10::nullopt, + std::optional device = std::nullopt, ExtraFilesMap* extra_files = nullptr); // Load a mobile::Module from a filepath. @@ -100,7 +100,7 @@ TORCH_API mobile::Module parse_and_initialize_mobile_module_for_jit( // directly. TORCH_API mobile::Module load_mobile_module_from_file( const std::string& filename, - std::optional device = c10::nullopt, + std::optional device = std::nullopt, ExtraFilesMap* extra_files = nullptr); TORCH_API uint64_t get_bytecode_version(std::istream& in); @@ -114,7 +114,7 @@ TORCH_API mobile::ModuleInfo get_module_info_from_flatbuffer( // its entirity to a buffer TORCH_API mobile::Module load_mobile_module_from_stream_with_copy( std::istream& in, - std::optional device = c10::nullopt, + std::optional device = std::nullopt, ExtraFilesMap* extra_files = nullptr); TORCH_API mobile::Module parse_flatbuffer_no_object( diff --git a/torch/csrc/jit/mobile/frame.h b/torch/csrc/jit/mobile/frame.h index 45c51fef0085e5..4ad3817af624ec 100644 --- a/torch/csrc/jit/mobile/frame.h +++ b/torch/csrc/jit/mobile/frame.h @@ -2,8 +2,8 @@ #include -#include #include +#include namespace torch { namespace jit { diff --git a/torch/csrc/jit/mobile/function.cpp b/torch/csrc/jit/mobile/function.cpp index 36f19fb1fac41c..9c3626e361da48 100644 --- a/torch/csrc/jit/mobile/function.cpp +++ b/torch/csrc/jit/mobile/function.cpp @@ -72,7 +72,7 @@ bool Function::initialize_operators(bool should_check_operators) { const auto& opname = code_.op_names_[i]; int num_args = code_.operator_input_sizes_[i]; std::optional num_specified_args = - num_args < 0 ? c10::nullopt : std::optional(num_args); + num_args < 0 ? std::nullopt : std::optional(num_args); auto func = makeOperatorFunction(opname, num_specified_args); if (!func.has_value()) { unsupported_op_names.insert(operator_str(opname)); @@ -189,7 +189,7 @@ std::optional> makeOperatorFunction( TORCH_CHECK(false, "arguments are missing for operator ", opname); } } else { - return c10::nullopt; + return std::nullopt; } } } diff --git a/torch/csrc/jit/mobile/import.cpp b/torch/csrc/jit/mobile/import.cpp index da7b87bae6110d..1fa2fe47904b56 100644 --- a/torch/csrc/jit/mobile/import.cpp +++ b/torch/csrc/jit/mobile/import.cpp @@ -5,7 +5,6 @@ #include #include #include -#include #include #include #include @@ -23,6 +22,7 @@ #include #include #include +#include #include #include @@ -267,7 +267,7 @@ void BytecodeDeserializer::parseFunctionSchema( args.emplace_back( name, std::move(type), - c10::nullopt /*N*/, + std::nullopt /*N*/, std::move(default_value)); } tryRegisterMethod(args, *function); @@ -704,7 +704,7 @@ void _load_extra_only_for_mobile( // TODO: the current flatbuffers implementation will always load the // whole module including the extra files. Ideally it should be // possible to just get the extra files given data - load_mobile_module_from_file(filename, c10::nullopt, &extra_files); + load_mobile_module_from_file(filename, std::nullopt, &extra_files); break; } default: { diff --git a/torch/csrc/jit/mobile/import.h b/torch/csrc/jit/mobile/import.h index 77a801e62571df..73ebe18976d60c 100644 --- a/torch/csrc/jit/mobile/import.h +++ b/torch/csrc/jit/mobile/import.h @@ -45,15 +45,15 @@ TORCH_API mobile::Module _load_for_mobile( TORCH_API mobile::Module _load_for_mobile( std::istream& in, - std::optional device = c10::nullopt); + std::optional device = std::nullopt); TORCH_API mobile::Module _load_for_mobile( const std::string& filename, - std::optional device = c10::nullopt); + std::optional device = std::nullopt); TORCH_API mobile::Module _load_for_mobile( std::unique_ptr rai, - std::optional device = c10::nullopt); + std::optional device = std::nullopt); /** * Load only the contents of the "extra/" files whose names are diff --git a/torch/csrc/jit/mobile/import_data.h b/torch/csrc/jit/mobile/import_data.h index 25e1fd81341c18..d2d2fa7f998e22 100644 --- a/torch/csrc/jit/mobile/import_data.h +++ b/torch/csrc/jit/mobile/import_data.h @@ -2,8 +2,8 @@ #include #include -#include #include +#include #include #include @@ -19,7 +19,7 @@ namespace jit { */ TORCH_API std::map _load_parameters( std::istream& in, - std::optional device = c10::nullopt); + std::optional device = std::nullopt); /** * Loads named parameters from the serialized data in @p filename. @@ -28,7 +28,7 @@ TORCH_API std::map _load_parameters( */ TORCH_API std::map _load_parameters( const std::string& filename, - std::optional device = c10::nullopt); + std::optional device = std::nullopt); // NOTE: Please prefer using _load_parameters over using the function below. TORCH_API std::map mobile_module_to_parameter_map( diff --git a/torch/csrc/jit/mobile/model_tracer/MobileModelRunner.h b/torch/csrc/jit/mobile/model_tracer/MobileModelRunner.h index b6abe86c0fdca7..813d7be7e7a2a9 100644 --- a/torch/csrc/jit/mobile/model_tracer/MobileModelRunner.h +++ b/torch/csrc/jit/mobile/model_tracer/MobileModelRunner.h @@ -104,7 +104,7 @@ class MobileModelRunner { */ bool has_new_style_bundled_inputs() const { return module_->find_method("get_bundled_inputs_functions_and_info") != - c10::nullopt; + std::nullopt; } /** diff --git a/torch/csrc/jit/mobile/model_tracer/TracerRunner.cpp b/torch/csrc/jit/mobile/model_tracer/TracerRunner.cpp index 585747c14d8210..3687f84f703971 100644 --- a/torch/csrc/jit/mobile/model_tracer/TracerRunner.cpp +++ b/torch/csrc/jit/mobile/model_tracer/TracerRunner.cpp @@ -117,10 +117,10 @@ void call_dependent_methods(std::set& root_ops) { if (is_training && has_batchnorm) { at::batch_norm( at::ones({2, 2}), - c10::nullopt, - c10::nullopt, - c10::nullopt, - c10::nullopt, + std::nullopt, + std::nullopt, + std::nullopt, + std::nullopt, true, 0.1, 0.1, diff --git a/torch/csrc/jit/mobile/module.cpp b/torch/csrc/jit/mobile/module.cpp index 23dfe9ff367852..bcf4e5e1f6ba7c 100644 --- a/torch/csrc/jit/mobile/module.cpp +++ b/torch/csrc/jit/mobile/module.cpp @@ -90,10 +90,10 @@ void Module::unsafeCopyMethod( std::optional Module::find_method(const std::string& basename) const { for (const auto& fn : cu_->methods()) { if (fn->name() == basename) { - return c10::make_optional(Method(this, fn.get())); + return std::make_optional(Method(this, fn.get())); } } - return c10::nullopt; + return std::nullopt; } namespace { @@ -324,7 +324,7 @@ static std::optional print_type(const c10::Type& t) { if (auto dyn = t.castRaw()) { return dyn->fallback()->annotation_str(); } - return c10::nullopt; + return std::nullopt; } TORCH_API ModuleInfo get_module_info(const mobile::Module& module) { diff --git a/torch/csrc/jit/mobile/promoted_prim_ops.cpp b/torch/csrc/jit/mobile/promoted_prim_ops.cpp index 8e49749042424c..1d9d6fb3abcfaf 100644 --- a/torch/csrc/jit/mobile/promoted_prim_ops.cpp +++ b/torch/csrc/jit/mobile/promoted_prim_ops.cpp @@ -118,7 +118,7 @@ void toPrimDType(Stack& stack) { pop(stack, non_blocking, copy); std::optional scalarType = pop(stack).toOptional(); - std::optional device = c10::nullopt; + std::optional device = std::nullopt; at::Tensor self = pop(stack).toTensor(); push(stack, to_dispatch(self, device, scalarType, non_blocking, copy)); } diff --git a/torch/csrc/jit/operator_upgraders/upgraders_entry.cpp b/torch/csrc/jit/operator_upgraders/upgraders_entry.cpp index 3185b0eaf123ca..21889a84b44070 100644 --- a/torch/csrc/jit/operator_upgraders/upgraders_entry.cpp +++ b/torch/csrc/jit/operator_upgraders/upgraders_entry.cpp @@ -122,7 +122,7 @@ std::shared_ptr create_upgrader_graph( const std::string& upgrader_name, const std::string& upgrader_body) { auto cu = std::make_shared(); - cu->define(c10::nullopt, upgrader_body, nativeResolver(), nullptr); + cu->define(std::nullopt, upgrader_body, nativeResolver(), nullptr); Function& jitFunc = cu->get_function(upgrader_name); GraphFunction& graphFunction = toGraphFunction(jitFunc); return graphFunction.graph(); diff --git a/torch/csrc/jit/operator_upgraders/utils.cpp b/torch/csrc/jit/operator_upgraders/utils.cpp index fef7b92c83c95a..98819b08d640b4 100644 --- a/torch/csrc/jit/operator_upgraders/utils.cpp +++ b/torch/csrc/jit/operator_upgraders/utils.cpp @@ -1,9 +1,9 @@ #include -#include #include #include #include +#include #include #include #include @@ -27,7 +27,7 @@ std::optional findUpgrader( if (pos != upgraders_for_schema.end()) { return *pos; } - return c10::nullopt; + return std::nullopt; } bool isOpCurrentBasedOnUpgraderEntries( diff --git a/torch/csrc/jit/operator_upgraders/utils.h b/torch/csrc/jit/operator_upgraders/utils.h index a30b8c1182b9cf..95e794261e6b97 100644 --- a/torch/csrc/jit/operator_upgraders/utils.h +++ b/torch/csrc/jit/operator_upgraders/utils.h @@ -1,8 +1,8 @@ #pragma once #include -#include #include #include +#include #include #include diff --git a/torch/csrc/jit/passes/autocast.cpp b/torch/csrc/jit/passes/autocast.cpp index 635162e0495319..bbd56744afb7d8 100644 --- a/torch/csrc/jit/passes/autocast.cpp +++ b/torch/csrc/jit/passes/autocast.cpp @@ -4,10 +4,10 @@ #include #include #include -#include #include #include #include +#include #include #include @@ -65,7 +65,7 @@ std::optional parseAutocast( const AutocastContext& context) { if (!isAutocastNode(value)) { // Not an autocast... - return c10::nullopt; + return std::nullopt; } if (value->node()->kind() == prim::CreateObject) { AutocastScope scope; @@ -135,7 +135,7 @@ std::optional parseAutocast( AT_ERROR("Unsupported autocast syntax"); } - return c10::nullopt; + return std::nullopt; } void castTensorInputs( @@ -269,7 +269,7 @@ void updateAutocastEnabledCheck(Node* node, bool is_jit_enabled) { void handleBlock(Block* block, AutocastContext initial_state) { std::stack autocast_stack; - std::optional incompatible_amp = c10::nullopt; + std::optional incompatible_amp = std::nullopt; // The current autocast enabled/disabled state auto current_state = [&] { diff --git a/torch/csrc/jit/passes/canonicalize.cpp b/torch/csrc/jit/passes/canonicalize.cpp index 20a883a8d06fdd..2aa6aff76bc1d7 100644 --- a/torch/csrc/jit/passes/canonicalize.cpp +++ b/torch/csrc/jit/passes/canonicalize.cpp @@ -144,7 +144,7 @@ bool isBeforeOrAfter(const Use& a, const Use& b, bool checking_before) { std::optional firstOrLastUse(Value* v, bool find_first) { if (v->uses().empty()) { - return c10::nullopt; + return std::nullopt; } Use extreme_use = v->uses()[0]; for (size_t i = 1; i < v->uses().size(); ++i) { @@ -176,12 +176,12 @@ static std::vector sort_indexes(at::ArrayRef values) { // if neither has any uses, use original ordering. Since the // only values that jitter are ones added by the compiler and are guaranteed // to have uses, original ordering is fine. - if (first_uses[i1] == c10::nullopt && first_uses[i2] == c10::nullopt) { + if (first_uses[i1] == std::nullopt && first_uses[i2] == std::nullopt) { return i1 < i2; } - if (first_uses[i1] == c10::nullopt) { + if (first_uses[i1] == std::nullopt) { return false; - } else if (first_uses[i2] == c10::nullopt) { + } else if (first_uses[i2] == std::nullopt) { return true; } diff --git a/torch/csrc/jit/passes/canonicalize_graph_fuser_ops.cpp b/torch/csrc/jit/passes/canonicalize_graph_fuser_ops.cpp index 72d419eeb9c163..b3e190445b8fe3 100644 --- a/torch/csrc/jit/passes/canonicalize_graph_fuser_ops.cpp +++ b/torch/csrc/jit/passes/canonicalize_graph_fuser_ops.cpp @@ -26,14 +26,14 @@ static std::optional> getChunkOutputs(Node* chunk) { // number of chunks if (static_cast(list_use.user->outputs().size()) != chunk->get(attr::chunks).value()) { - return c10::nullopt; + return std::nullopt; } auto unpack_outputs = list_use.user->outputs(); for (const auto i : c10::irange(unpack_outputs.size())) { outputs.emplace_back(unpack_outputs[i], i); } } else { - return c10::nullopt; + return std::nullopt; } } return outputs; diff --git a/torch/csrc/jit/passes/constant_propagation.cpp b/torch/csrc/jit/passes/constant_propagation.cpp index 6334cd75faa903..5ec8b561cba80a 100644 --- a/torch/csrc/jit/passes/constant_propagation.cpp +++ b/torch/csrc/jit/passes/constant_propagation.cpp @@ -28,14 +28,14 @@ std::optional> runNodeIfInputsAreConstant( if (auto ival = toIValue(input)) { stack.push_back(*ival); } else { - return c10::nullopt; + return std::nullopt; } } switch (n->kind()) { case prim::ListUnpack: { if (stack.back().toList().size() != n->outputs().size()) { - return c10::nullopt; + return std::nullopt; } listUnpack(stack, n->outputs().size()); } break; @@ -78,14 +78,14 @@ std::optional> runNodeIfInputsAreConstant( // vararg schemas require the number of inputs at the top of the stack // but this is broken in other places in constant prop, so disable it // for now - return c10::nullopt; + return std::nullopt; } try { auto op = n->getOperation(); op(stack); } catch (...) { - return c10::nullopt; + return std::nullopt; } } break; } @@ -95,13 +95,13 @@ std::optional> runNodeIfInputsAreConstant( const at::Tensor& t = v.toTensor(); if (t.defined() && t.requires_grad()) { // requires grad tensors cannot be constants - return c10::nullopt; + return std::nullopt; } } // Weak form of const propagation if (ignore_custom_classes) { if (v.isCustomClass()) { - return c10::nullopt; + return std::nullopt; } } // see [Constant Object Weak CompilationUnit Reference] @@ -123,7 +123,7 @@ std::optional> runNodeIfInputsAreConstant( } if (v.isObject()) { if (!v.toObject()->is_weak_compilation_ref()) { - return c10::nullopt; + return std::nullopt; } } } diff --git a/torch/csrc/jit/passes/create_autodiff_subgraphs.cpp b/torch/csrc/jit/passes/create_autodiff_subgraphs.cpp index c5fe65537669a8..46eca6f2b221f5 100644 --- a/torch/csrc/jit/passes/create_autodiff_subgraphs.cpp +++ b/torch/csrc/jit/passes/create_autodiff_subgraphs.cpp @@ -287,7 +287,7 @@ class SubgraphSlicer { aliasDb_.moveBeforeTopologicallyValid(producer, consumer); if (!canMerge) { - return c10::nullopt; + return std::nullopt; } SubgraphUtils::mergeNodeIntoSubgraphAndUpdateAliasing( @@ -305,11 +305,11 @@ class SubgraphSlicer { std::optional getProfileNodeRequiresGrad(Node* n) { TORCH_INTERNAL_ASSERT(n->kind() == prim::profile); if (!n->hasAttribute(attr::profiled_type)) { - return c10::nullopt; + return std::nullopt; } auto& type = n->ty(attr::profiled_type); if (type->castRaw() == nullptr) { - return c10::nullopt; + return std::nullopt; } return type->expectRef().requiresGrad(); } @@ -403,7 +403,7 @@ std::optional findRequiresGradForOutput( } } - return c10::nullopt; + return std::nullopt; } void AddRequiresGradToDifferentiableGraph( diff --git a/torch/csrc/jit/passes/device_type_analysis.cpp b/torch/csrc/jit/passes/device_type_analysis.cpp index 7670292696ae69..c9c9188d37dc5e 100644 --- a/torch/csrc/jit/passes/device_type_analysis.cpp +++ b/torch/csrc/jit/passes/device_type_analysis.cpp @@ -2,12 +2,12 @@ #include #include #include -#include #include #include #include #include #include +#include #include namespace torch { @@ -88,7 +88,7 @@ bool propWithNoDevice(Node* n) { } if (input_num == n->inputs().size()) { // No tensor found - return setReturnsToDevice(n, c10::nullopt); + return setReturnsToDevice(n, std::nullopt); } auto tensor_type = n->inputs()[input_num]->type()->expect(); @@ -108,7 +108,7 @@ bool propWithNoDevice(Node* n) { only_seen_cpu_zerodim = false; } else { // Bail on the type not match case - return setReturnsToDevice(n, c10::nullopt); + return setReturnsToDevice(n, std::nullopt); } } } diff --git a/torch/csrc/jit/passes/dtype_analysis.cpp b/torch/csrc/jit/passes/dtype_analysis.cpp index f63ea6f3419489..2311cb791a449c 100644 --- a/torch/csrc/jit/passes/dtype_analysis.cpp +++ b/torch/csrc/jit/passes/dtype_analysis.cpp @@ -3,13 +3,13 @@ #include #include #include -#include #include #include #include #include #include #include +#include #ifndef AT_PER_OPERATOR_HEADERS #include @@ -102,7 +102,7 @@ static bool canBeInferredWithMetaTensor(Node* n) { std::optional inferWithMetaTensor(Node* n) { GRAPH_DEBUG("inferWithMetaTensor", getHeader(n)); if (!canBeInferredWithMetaTensor(n)) { - return c10::nullopt; + return std::nullopt; } Operation op = n->getOperation(); try { @@ -116,7 +116,7 @@ std::optional inferWithMetaTensor(Node* n) { } catch (...) { GRAPH_DEBUG("caught exception with Metatensor run!"); }; - return c10::nullopt; + return std::nullopt; } bool setDtype( diff --git a/torch/csrc/jit/passes/erase_number_types.cpp b/torch/csrc/jit/passes/erase_number_types.cpp index 540f1a7e13fb84..ccafee9aa4ae43 100644 --- a/torch/csrc/jit/passes/erase_number_types.cpp +++ b/torch/csrc/jit/passes/erase_number_types.cpp @@ -41,7 +41,7 @@ void EraseNumberTypesOnBlock(Block* block) { WithInsertPoint guard(*it); Value* r = block->owningGraph()->insertConstant( - scalar_to_tensor(s), c10::nullopt, it->scope()); + scalar_to_tensor(s), std::nullopt, it->scope()); r->copyMetadata(it->output()); it->output()->replaceAllUsesWith(r); it.destroyCurrent(); diff --git a/torch/csrc/jit/passes/freeze_module.cpp b/torch/csrc/jit/passes/freeze_module.cpp index 4d67d5d2178134..23bc873addc714 100644 --- a/torch/csrc/jit/passes/freeze_module.cpp +++ b/torch/csrc/jit/passes/freeze_module.cpp @@ -170,7 +170,7 @@ class AttributePropagator { std::optional resolveName(const std::string& name) { auto sub_names = splitName(name); if (sub_names.empty()) { - return c10::nullopt; + return std::nullopt; } auto& attr_name = sub_names.back(); auto cur_module = module_; @@ -189,7 +189,7 @@ class AttributePropagator { } } if (!found) { - return c10::nullopt; + return std::nullopt; } } @@ -207,7 +207,7 @@ class AttributePropagator { return std::make_pair(std::move(cur_module), std::move(attr_name)); } - return c10::nullopt; + return std::nullopt; } bool _loadModulePath(Value* input, std::shared_ptr& graph) { @@ -230,7 +230,7 @@ class AttributePropagator { std::shared_ptr& graph) { bool success = _loadModulePath(input, graph); if (!success) { - return c10::nullopt; + return std::nullopt; } return names_; } diff --git a/torch/csrc/jit/passes/frozen_ops_to_mkldnn.cpp b/torch/csrc/jit/passes/frozen_ops_to_mkldnn.cpp index c28e99a445258a..b508cd905c586b 100644 --- a/torch/csrc/jit/passes/frozen_ops_to_mkldnn.cpp +++ b/torch/csrc/jit/passes/frozen_ops_to_mkldnn.cpp @@ -1105,7 +1105,7 @@ class MKLDNNSubgraphSlicer { aliasDb_.moveAfterTopologicallyValid(consumer, producer); if (!canMerge) { - return c10::nullopt; + return std::nullopt; } SubgraphUtils::mergeNodeIntoSubgraphAndUpdateAliasing( diff --git a/torch/csrc/jit/passes/graph_fuser.cpp b/torch/csrc/jit/passes/graph_fuser.cpp index 98487830726216..5136615cd2e441 100644 --- a/torch/csrc/jit/passes/graph_fuser.cpp +++ b/torch/csrc/jit/passes/graph_fuser.cpp @@ -494,7 +494,7 @@ struct GraphFuser { AT_ASSERT(group->kind() == prim::FusionGroup); auto it = std::find(group->inputs().begin(), group->inputs().end(), input); if (it == group->inputs().end()) { - return c10::nullopt; + return std::nullopt; } size_t input_index = it - group->inputs().begin(); auto& subgraph = getSubgraph(group); @@ -505,7 +505,7 @@ struct GraphFuser { AT_ASSERT(subgraph_input->uses().size() == 1); return node; } - return c10::nullopt; + return std::nullopt; } void fuseChunkByReusingExistingFusedChunk( diff --git a/torch/csrc/jit/passes/graph_rewrite_helper.cpp b/torch/csrc/jit/passes/graph_rewrite_helper.cpp index edb9f5b9589a06..430dbb3fd1c851 100644 --- a/torch/csrc/jit/passes/graph_rewrite_helper.cpp +++ b/torch/csrc/jit/passes/graph_rewrite_helper.cpp @@ -287,7 +287,7 @@ bool isClampFusable( vmap.find("output_max") != vmap.end(), "Expected to find output_max as well given " "output_min exist in pattern graph."); - // If output_min/max are not constant, we get c10::nullopt. + // If output_min/max are not constant, we get std::nullopt. auto output_min = graph_rewrite_helper::getIValue("output_min", match_vmap, vmap); auto output_max = diff --git a/torch/csrc/jit/passes/inline_autodiff_subgraphs.cpp b/torch/csrc/jit/passes/inline_autodiff_subgraphs.cpp index f8d63e87f07b7e..226826e946098a 100644 --- a/torch/csrc/jit/passes/inline_autodiff_subgraphs.cpp +++ b/torch/csrc/jit/passes/inline_autodiff_subgraphs.cpp @@ -68,7 +68,7 @@ graph_node_list::iterator scanNode(Node* node, size_t threshold) { // so the profiles will have outdated requires_grad=False. // conservatively update them to maybe requiring grad, bc we might create // autodiff graphs when the tensors maybe require grad - UpdateDifferentiableGraphRequiresGrad(subgraph, c10::nullopt); + UpdateDifferentiableGraphRequiresGrad(subgraph, std::nullopt); SubgraphUtils::unmergeSubgraph(node); return next_node; } diff --git a/torch/csrc/jit/passes/integer_value_refinement.cpp b/torch/csrc/jit/passes/integer_value_refinement.cpp index 16a329b3b11f34..cf9b577f927b28 100644 --- a/torch/csrc/jit/passes/integer_value_refinement.cpp +++ b/torch/csrc/jit/passes/integer_value_refinement.cpp @@ -93,7 +93,7 @@ struct IntegerValueRefiner { auto other_output = other_if_block->outputs().at(i); auto other_const_value = other_output->type()->cast() ? constant_as(other_output) - : c10::nullopt; + : std::nullopt; if (!other_const_value || block_output->node()->kind() == prim::Constant) { continue; @@ -211,7 +211,7 @@ struct IntegerValueRefiner { return maybe_refinement->second; } } - return c10::nullopt; + return std::nullopt; } std::shared_ptr graph_; diff --git a/torch/csrc/jit/passes/onnx/constant_fold.cpp b/torch/csrc/jit/passes/onnx/constant_fold.cpp index 4eeba79aae90c4..61d97057c5b429 100644 --- a/torch/csrc/jit/passes/onnx/constant_fold.cpp +++ b/torch/csrc/jit/passes/onnx/constant_fold.cpp @@ -5,9 +5,9 @@ #include #include -#include #include #include +#include namespace torch { namespace jit { @@ -72,15 +72,15 @@ std::optional runTorchSlice_opset9( TORCH_WARN( "Constant folding - Invalid number of inputs found for opset 9 " "onnx::Slice op. Constant folding not applied."); - return c10::nullopt; + return std::nullopt; } if (!(node->hasAttributeS("starts") && node->hasAttributeS("ends"))) { - return c10::nullopt; + return std::nullopt; } auto startsAttr = node->is(attr::starts); auto endsAttr = node->is(attr::ends); if (startsAttr.size() != endsAttr.size()) { - return c10::nullopt; + return std::nullopt; } std::vector axesAttr; if (node->hasAttributeS("axes")) { @@ -98,7 +98,7 @@ std::optional runTorchSlice_opset9( handleNegativeStartEndIndex(start, end, axis, updated_val.sizes()); int64_t length = end - start; if (length < 0 || start > updated_val.sizes()[axis] - length) - return c10::nullopt; + return std::nullopt; updated_val = at::narrow(updated_val, axis, start, length); } return std::optional(updated_val); @@ -114,7 +114,7 @@ std::optional runTorchSlice_opset10( TORCH_WARN( "Constant folding - Invalid number of inputs found for opset opset >= 10 onnx::Slice op. " "Constant folding not applied."); - return c10::nullopt; + return std::nullopt; } // Checking validity of 'starts' and 'ends' input if (inputTensorValues[1].sizes().size() != 1 || @@ -122,12 +122,12 @@ std::optional runTorchSlice_opset10( TORCH_WARN( "Constant folding - Invalid 'starts' or 'ends' inputs found for opset >= 10 onnx::Slice op. " "Constant folding not applied."); - return c10::nullopt; + return std::nullopt; } if (inputTensorValues[1].sizes()[0] != inputTensorValues[2].sizes()[0]) { // Number of elements of 'starts' and 'ends' 1-D input tensors should be the // same - return c10::nullopt; + return std::nullopt; } // Checking 'axes' input, if available. std::vector axes; @@ -136,7 +136,7 @@ std::optional runTorchSlice_opset10( TORCH_WARN( "Constant folding - Invalid 'axes' input found for opset >= 10 onnx::Slice op. " "Constant folding not applied."); - return c10::nullopt; + return std::nullopt; } if (inputTensorValues[3].sizes()[0] != inputTensorValues[1].sizes()[0]) { // Number of elements of 'axes' and 'ends' 1-D input tensors should be the @@ -144,7 +144,7 @@ std::optional runTorchSlice_opset10( TORCH_WARN( "Constant folding - Invalid 'axes' or 'ends' inputs found for opset >= 10 onnx::Slice op. " "Constant folding not applied."); - return c10::nullopt; + return std::nullopt; } auto axes_a = inputTensorValues[3].accessor(); axes.resize(inputTensorValues[3].sizes()[0]); @@ -162,7 +162,7 @@ std::optional runTorchSlice_opset10( TORCH_WARN( "Constant folding - Invalid 'steps' input found for opset >= 10 onnx::Slice op. " "Constant folding not applied."); - return c10::nullopt; + return std::nullopt; } if (inputTensorValues[4].sizes()[0] != inputTensorValues[1].sizes()[0]) { // Number of elements of 'steps' and 'ends' 1-D input tensors should be @@ -170,7 +170,7 @@ std::optional runTorchSlice_opset10( TORCH_WARN( "Constant folding - Invalid 'steps' or 'ends' inputs found for opset >= 10 onnx::Slice op. " "Constant folding not applied."); - return c10::nullopt; + return std::nullopt; } auto steps_a = inputTensorValues[4].accessor(); for (const auto i : c10::irange(inputTensorValues[4].sizes()[0])) { @@ -179,7 +179,7 @@ std::optional runTorchSlice_opset10( TORCH_WARN( "Constant folding - Only steps=1 can be constant folded for opset >= 10 onnx::Slice op. " "Constant folding not applied."); - return c10::nullopt; + return std::nullopt; } } } @@ -192,7 +192,7 @@ std::optional runTorchSlice_opset10( handleNegativeStartEndIndex(start, end, axis, updated_val.sizes()); int64_t length = end - start; if (length < 0 || start > updated_val.sizes()[axis] - length) - return c10::nullopt; + return std::nullopt; updated_val = at::narrow(updated_val, axis, start, length); } return std::optional(updated_val); @@ -272,11 +272,11 @@ std::optional runTorchBackendForOnnx( } else { TORCH_WARN( "Constant folding - unsupported opset version. Constant folding not applied."); - return c10::nullopt; + return std::nullopt; } } else if (node->kind() == onnx::Concat) { if (!node->hasAttributeS("axis")) { - return c10::nullopt; + return std::nullopt; } updated_val = at::cat(at::TensorList(inputTensorValues), node->i(attr::axis)); @@ -310,7 +310,7 @@ std::optional runTorchBackendForOnnx( TORCH_WARN( "Constant folding - Invalid 'axes' inputs found for opset 13 onnx::Unsqueeze op. " "Constant folding not applied."); - return c10::nullopt; + return std::nullopt; } auto axes_a = inputTensorValues[1].accessor(); std::vector axes; @@ -332,7 +332,7 @@ std::optional runTorchBackendForOnnx( } else if (opset_version >= ONNX_OPSET_9) { assert(inputTensorValues.size() == 1); if (!node->hasAttributeS("axes")) { - return c10::nullopt; + return std::nullopt; } updated_val = inputTensorValues[0]; std::vector axesAttr = node->is(attr::axes); @@ -345,7 +345,7 @@ std::optional runTorchBackendForOnnx( TORCH_WARN( "Constant folding - unsupported opset version. " "Constant folding not applied."); - return c10::nullopt; + return std::nullopt; } } else if (node->kind() == onnx::Squeeze) { assert(inputTensorValues.size() == 2 || inputTensorValues.size() == 1); @@ -359,7 +359,7 @@ std::optional runTorchBackendForOnnx( TORCH_WARN( "Constant folding - Invalid 'axes' inputs found for opset 13 onnx::Squeeze op. " "Constant folding not applied."); - return c10::nullopt; + return std::nullopt; } auto axes_a = inputTensorValues[1].accessor(); std::vector axes; @@ -389,12 +389,12 @@ std::optional runTorchBackendForOnnx( TORCH_WARN( "Constant folding - unsupported opset version. " "Constant folding not applied."); - return c10::nullopt; + return std::nullopt; } } else if (node->kind() == onnx::Transpose) { assert(inputTensorValues.size() == 1); if (!node->hasAttributeS("perm")) { - return c10::nullopt; + return std::nullopt; } updated_val = inputTensorValues[0].permute(node->is(attr::perm)); return std::optional(updated_val); @@ -405,7 +405,7 @@ std::optional runTorchBackendForOnnx( ONNXTypeToATenType(node->i(attr::to)).value()); return std::optional(updated_val); } - return c10::nullopt; + return std::nullopt; } else if (node->kind() == onnx::Reshape) { assert(inputTensorValues.size() == 2); updated_val = inputTensorValues[0]; @@ -441,10 +441,10 @@ std::optional runTorchBackendForOnnx( } else if (node->kind() == onnx::ReduceL1 || node->kind() == onnx::ReduceL2) { assert(inputTensorValues.size() == 1); if (!node->hasAttributeS("axes")) { - return c10::nullopt; + return std::nullopt; } if (!node->hasAttributeS("keepdims")) { - return c10::nullopt; + return std::nullopt; } int p = node->kind() == onnx::ReduceL1 ? 1 : 2; updated_val = at::norm( @@ -485,7 +485,7 @@ std::optional runTorchBackendForOnnx( // at::index_select only supports indices with rank <= 1. // See https://pytorch.org/docs/main/generated/torch.index_select.html if (q > 1) { - return c10::nullopt; + return std::nullopt; } // If the device of indices tensor is not the same with it of the input // tensor, move it to the device of the input tensor @@ -539,7 +539,7 @@ std::optional runTorchBackendForOnnx( updated_val = at::softmax(inputTensorValues[0], axis); return std::optional(updated_val); } else { - return c10::nullopt; + return std::nullopt; } } @@ -652,7 +652,7 @@ void ConstantFoldONNX(Block* b, ParamMap& paramsDict, int opset_version) { } auto updatedValWrapped = onnx_constant_fold::runTorchBackendForOnnx( node, inputTensorValues, opset_version); - if (updatedValWrapped == c10::nullopt) { + if (updatedValWrapped == std::nullopt) { // Constant folding is not supported for this op. Skip it. continue; } diff --git a/torch/csrc/jit/passes/onnx/constant_fold.h b/torch/csrc/jit/passes/onnx/constant_fold.h index 201c3def32685a..d25ebee32a787e 100644 --- a/torch/csrc/jit/passes/onnx/constant_fold.h +++ b/torch/csrc/jit/passes/onnx/constant_fold.h @@ -2,8 +2,8 @@ #include -#include #include +#include namespace torch { namespace jit { diff --git a/torch/csrc/jit/passes/onnx/constant_map.cpp b/torch/csrc/jit/passes/onnx/constant_map.cpp index f9c96d0430df02..99c801dcf77367 100644 --- a/torch/csrc/jit/passes/onnx/constant_map.cpp +++ b/torch/csrc/jit/passes/onnx/constant_map.cpp @@ -34,14 +34,14 @@ bool ConstantValueMap::HasRank(const std::string& tensorName) { std::optional ConstantValueMap::GetRank(const std::string& tensorName) { if (!HasRank(tensorName)) { - return c10::nullopt; + return std::nullopt; } return ConstantValueMap::getInstance().rankMap[tensorName]; } void ConstantValueMap::SetAllGraphInputsStatic(bool all_static) { ConstantValueMap::getInstance().allGraphInputsStatic = - c10::make_optional(all_static); + std::make_optional(all_static); } std::optional ConstantValueMap::GetAllGraphInputsStatic() { @@ -71,7 +71,7 @@ bool ConstantValueMap::HasShape(const std::string& tensorName) { std::optional ConstantValueMap::GetShape( const std::string& tensorName) { if (!HasShape(tensorName)) { - return c10::nullopt; + return std::nullopt; } return ConstantValueMap::getInstance().shapeMap[tensorName]; } @@ -90,7 +90,7 @@ bool ConstantValueMap::HasValue(const std::string& tensorName) { std::optional ConstantValueMap::GetValue( const std::string& tensorName) { if (!HasValue(tensorName)) { - return c10::nullopt; + return std::nullopt; } return ConstantValueMap::getInstance().tensorValueMap[tensorName]; } @@ -121,7 +121,7 @@ std::optional> ConstantValueMap::GetShapeInto1DInt64Vector( return shape_value; } } - return c10::nullopt; + return std::nullopt; } std::optional> ConstantValueMap:: @@ -152,7 +152,7 @@ std::optional> ConstantValueMap:: } } } - return c10::nullopt; + return std::nullopt; } // accessor for 1DInt64 case. @@ -183,7 +183,7 @@ bool ConstantValueMap::HasTypeReliable(const std::string& tensorName) { std::optional ConstantValueMap::GetTypeReliable( const std::string& tensorName) { if (!HasTypeReliable(tensorName)) { - return c10::nullopt; + return std::nullopt; } return ConstantValueMap::getInstance().typeReliableMap[tensorName]; } @@ -202,7 +202,7 @@ bool ConstantValueMap::HasUseInferredType(const std::string& tensorName) { std::optional ConstantValueMap::GetUseInferredType( const std::string& tensorName) { if (!HasUseInferredType(tensorName)) { - return c10::nullopt; + return std::nullopt; } return ConstantValueMap::getInstance().useInferredTypeMap[tensorName]; } @@ -221,7 +221,7 @@ bool ConstantValueMap::HasShapeValue(const std::string& tensorName) { std::optional ConstantValueMap::GetShapeValue( const std::string& tensorName) { if (!HasShapeValue(tensorName)) { - return c10::nullopt; + return std::nullopt; } return ConstantValueMap::getInstance().shapeValueMap[tensorName]; } @@ -284,7 +284,7 @@ void ConstantValueMap::ClearMaps() { ConstantValueMap::getInstance().inferredShapeData.clear(); ConstantValueMap::getInstance().symbolDimMap.clear(); ConstantValueMap::getInstance().dimSymbolMap.clear(); - ConstantValueMap::getInstance().allGraphInputsStatic = c10::nullopt; + ConstantValueMap::getInstance().allGraphInputsStatic = std::nullopt; ConstantValueMap::getInstance().allGraphInputsReliableComputed = false; } diff --git a/torch/csrc/jit/passes/onnx/function_extraction.cpp b/torch/csrc/jit/passes/onnx/function_extraction.cpp index c545c7aba823a1..febf412e5d1224 100644 --- a/torch/csrc/jit/passes/onnx/function_extraction.cpp +++ b/torch/csrc/jit/passes/onnx/function_extraction.cpp @@ -225,16 +225,16 @@ std::optional FunctionExtractor::FunctionContext::FindAttrName( auto v_it = scope_ctxs_[scope_key_]->env_to_subgraph_.find(ref_n->outputs().at(0)); if (v_it == scope_ctxs_[scope_key_]->env_to_subgraph_.end()) { - return c10::nullopt; + return std::nullopt; } auto* n_in_def = v_it->second->node(); auto n_attr_it = node_attr_to_name_.find(n_in_def); if (n_attr_it == node_attr_to_name_.end()) { - return c10::nullopt; + return std::nullopt; } auto name_it = n_attr_it->second.find(attr.toUnqualString()); if (name_it == n_attr_it->second.end()) { - return c10::nullopt; + return std::nullopt; } return name_it->second; } @@ -301,7 +301,7 @@ std::optional FunctionExtractor::FindCommonAncestor( ScopePtr a, ScopePtr b) { if (!IsValidScope(a) || !IsValidScope(b)) { - return c10::nullopt; + return std::nullopt; } auto diff = @@ -327,20 +327,20 @@ std::optional FunctionExtractor::FindCommonAncestor( } } - return c10::nullopt; + return std::nullopt; } std::optional FunctionExtractor::FindCommonAncestor( const scope_list& scopes) { if (scopes.empty()) { - return c10::nullopt; + return std::nullopt; } std::optional common_ancestor = scopes.at(0); for (const auto& scope : scopes) { common_ancestor = FindCommonAncestor(common_ancestor.value(), scope); if (!common_ancestor.has_value()) { - return c10::nullopt; + return std::nullopt; } } @@ -410,7 +410,7 @@ std::optional FunctionExtractor::InferScope(Node* n) { } } - return c10::nullopt; + return std::nullopt; } std::shared_ptr FunctionExtractor::ConstructFuncGraph( diff --git a/torch/csrc/jit/passes/onnx/list_model_parameters.cpp b/torch/csrc/jit/passes/onnx/list_model_parameters.cpp index b28de0fdee4cd5..6a1e3b08f3b9a8 100644 --- a/torch/csrc/jit/passes/onnx/list_model_parameters.cpp +++ b/torch/csrc/jit/passes/onnx/list_model_parameters.cpp @@ -52,7 +52,7 @@ std::deque findSubModuleAttr( Value* addParamAsArgument(Function* function, std::string& name, IValue& attr) { auto schema = function->getSchema(); auto args = schema.arguments(); - args.emplace_back(name, nullptr, c10::nullopt, attr); + args.emplace_back(name, nullptr, std::nullopt, attr); auto new_schema = FunctionSchema( schema.name(), schema.overload_name(), diff --git a/torch/csrc/jit/passes/onnx/pattern_conversion/pattern_conversion.cpp b/torch/csrc/jit/passes/onnx/pattern_conversion/pattern_conversion.cpp index 6c064b70ae614f..cd975d0375fcbb 100644 --- a/torch/csrc/jit/passes/onnx/pattern_conversion/pattern_conversion.cpp +++ b/torch/csrc/jit/passes/onnx/pattern_conversion/pattern_conversion.cpp @@ -46,7 +46,7 @@ Value* ConvertSliceToIndex(Node* slice, Value* size, Node* insertBefore) { aten::slice, {index, graph->insertConstant( - scalar_to_tensor(at::Scalar(0)), c10::nullopt, slice->scope()), + scalar_to_tensor(at::Scalar(0)), std::nullopt, slice->scope()), start, end, step}); diff --git a/torch/csrc/jit/passes/onnx/pattern_conversion/pattern_encapsulation.cpp b/torch/csrc/jit/passes/onnx/pattern_conversion/pattern_encapsulation.cpp index 6110954990455b..7a98567a529bee 100644 --- a/torch/csrc/jit/passes/onnx/pattern_conversion/pattern_encapsulation.cpp +++ b/torch/csrc/jit/passes/onnx/pattern_conversion/pattern_encapsulation.cpp @@ -84,7 +84,7 @@ std::optional EncapsulatePatternIntoSubblock(Node* n) { return EncapsulateInplaceIndexPutForONNX(n); } } - return c10::nullopt; + return std::nullopt; } } // namespace jit diff --git a/torch/csrc/jit/passes/onnx/peephole.cpp b/torch/csrc/jit/passes/onnx/peephole.cpp index b468e739a03f3d..18c31ea656610d 100644 --- a/torch/csrc/jit/passes/onnx/peephole.cpp +++ b/torch/csrc/jit/passes/onnx/peephole.cpp @@ -16,7 +16,7 @@ #include #endif -#include +#include #if defined(_MSC_VER) #include @@ -105,14 +105,14 @@ std::optional fusibleExpandTo( at::IntArrayRef from, at::IntArrayRef to) { if (from.size() > to.size()) { - return c10::nullopt; + return std::nullopt; } for (const auto i : c10::irange(from.size())) { auto fdim = from[from.size() - 1 - i]; auto tdim = to[to.size() - 1 - i]; if (fdim != 1 && fdim != tdim) { - return c10::nullopt; + return std::nullopt; } } @@ -168,7 +168,7 @@ void fuseBroadcast(Block* b) { .sizes() .concrete_sizes() .value()); // to - if (axis == c10::nullopt) { + if (axis == std::nullopt) { continue; } diff --git a/torch/csrc/jit/passes/onnx/scalar_type_analysis.cpp b/torch/csrc/jit/passes/onnx/scalar_type_analysis.cpp index 427e5771a9f0f7..009566499275b4 100644 --- a/torch/csrc/jit/passes/onnx/scalar_type_analysis.cpp +++ b/torch/csrc/jit/passes/onnx/scalar_type_analysis.cpp @@ -100,7 +100,7 @@ static bool IsImplicitCastSupported(const NodeKind& nodeKind) { static std::optional PromoteScalarTypes( const std::vector& types) { if (types.empty()) { - return c10::nullopt; + return std::nullopt; } auto st = types[0]; for (const auto i : c10::irange(1, types.size())) { @@ -131,9 +131,9 @@ static std::optional PromoteScalarTypesWithCategory( return 0; }; - if (c10::nullopt == typeFromScalar) { + if (std::nullopt == typeFromScalar) { return typeFromTensor; - } else if (c10::nullopt == typeFromTensor) { + } else if (std::nullopt == typeFromTensor) { return typeFromScalar; } @@ -155,7 +155,7 @@ static std::optional InferExpectedScalarType(const Node* n) { if (auto* tensor_type = input->type()->castRaw()) { return tensor_type->scalarType(); } - return c10::nullopt; + return std::nullopt; }; auto emplace_type_from_scalar = [&typesFromTensors, &typesFromScalars](at::ScalarType scalar_type) { @@ -252,7 +252,7 @@ static std::optional InferExpectedScalarType(const Node* n) { } }); - std::optional st = c10::nullopt; + std::optional st = std::nullopt; const auto output_st = get_scalar_type(n->output()); if (IsComparisonOp(n->kind())) { @@ -313,7 +313,7 @@ static void UpdateScalarTypeForInputs( for (auto input : n->inputs()) { auto input_tensor_type = input->type()->cast(); auto input_scalar_type = - input_tensor_type ? input_tensor_type->scalarType() : c10::nullopt; + input_tensor_type ? input_tensor_type->scalarType() : std::nullopt; // We skip the 'condition' input (i.e., the first input) in case of // onnx:Where operator. @@ -393,7 +393,7 @@ static void RecoverScalarTypeForOutput( static void LowPrecisionCastNodeForStandardOps(Node* n, int opset_version) { TORCH_INTERNAL_ASSERT(n->outputs().size() == 1); if (n->output()->type()->cast() == nullptr || - n->output()->type()->cast()->scalarType() == c10::nullopt) { + n->output()->type()->cast()->scalarType() == std::nullopt) { // skip LowPrecisionCast if op output type is null. return; } @@ -401,7 +401,7 @@ static void LowPrecisionCastNodeForStandardOps(Node* n, int opset_version) { n->output()->type()->cast()->scalarType().value(); for (size_t i = 0; i < n->inputs().size(); ++i) { if (n->input(i)->type()->cast() == nullptr || - n->input(i)->type()->cast()->scalarType() == c10::nullopt) { + n->input(i)->type()->cast()->scalarType() == std::nullopt) { // skip LowPrecisionCast if any op input type node is null. return; } diff --git a/torch/csrc/jit/passes/onnx/shape_type_inference.cpp b/torch/csrc/jit/passes/onnx/shape_type_inference.cpp index 65d065adeb2b57..3691f0bf7b09ce 100644 --- a/torch/csrc/jit/passes/onnx/shape_type_inference.cpp +++ b/torch/csrc/jit/passes/onnx/shape_type_inference.cpp @@ -98,7 +98,7 @@ c10::ShapeSymbol ONNXDimToShapeSymbol( if (dim.has_dim_value()) { return c10::ShapeSymbol::fromStaticSize(dim.dim_value()); } - std::optional sym = c10::nullopt; + std::optional sym = std::nullopt; if (dim.has_dim_param()) { // If this param is already known, assign the same Symbol. GRAPH_UPDATE("Got dim_param:", dim.dim_param()); @@ -267,7 +267,7 @@ Value* CloneValueFromListConstruct( // is preserved. If the elemtype is Int, insert a onnx::Concat node into // the graph. TypePtr elem = v->type()->castRaw()->getElementType(); - std::optional scalar_type = c10::nullopt; + std::optional scalar_type = std::nullopt; if (elem->cast()) { scalar_type = at::kLong; if (isValidToTransformToONNXConcatNode(v->node())) { @@ -332,7 +332,7 @@ Node* CloneNodeToGraph( // Try to lookup input value and insert it into the graph. // If the input value is unknown, set it to graph input in the new // graph, and copy over metadata, such as datatype and shape. - ::std::optional val = ::c10::nullopt; + ::std::optional val = ::std::nullopt; auto v0 = params_dict.find(v->debugName()); if (v0 != params_dict.end()) { val = v0->second.toTensor(); @@ -420,13 +420,13 @@ void ConvertGraphToONNXProto( std::optional ComputeConstantFolding(Node* n, int opset_version) { if (n->inputs().empty()) { - return c10::nullopt; + return std::nullopt; } std::vector inputTensorValues; for (auto i : c10::irange(n->inputs().size())) { if (TensorTypePtr input_type = n->input(i)->type()->cast()) { if (!ConstantValueMap::HasValue(n->input(i)->debugName())) { - return c10::nullopt; + return std::nullopt; } auto tensor_value = ConstantValueMap::GetValue(n->input(i)->debugName()).value(); @@ -434,7 +434,7 @@ std::optional ComputeConstantFolding(Node* n, int opset_version) { } } if (inputTensorValues.size() < n->inputs().size()) { - return c10::nullopt; + return std::nullopt; } try { return onnx_constant_fold::runTorchBackendForOnnx( @@ -443,7 +443,7 @@ std::optional ComputeConstantFolding(Node* n, int opset_version) { auto ex_str = std::string(ex.what()); ex_str = ex_str.substr(0, ex_str.find('\n')); TORCH_WARN("Constant folding in symbolic shape inference fails: ", ex_str); - return c10::nullopt; + return std::nullopt; } } @@ -500,7 +500,7 @@ std::optional<::c10::SymbolicShape> ComputeShapeFromReshape( std::numeric_limits::max() / input_shape.static_size()) { TORCH_WARN( "ComputeShapeFromReshape(), shape_ratio overflows, skip shape inference."); - return c10::nullopt; + return std::nullopt; } else { shape_ratio *= static_cast(input_shape.static_size()); } @@ -523,7 +523,7 @@ std::optional<::c10::SymbolicShape> ComputeShapeFromReshape( } else { auto value = target_shape.value(); if (sym_map.find(value) == sym_map.end()) { - return c10::nullopt; + return std::nullopt; } sym_map[value]--; if (sym_map[value] == 0) { @@ -535,7 +535,7 @@ std::optional<::c10::SymbolicShape> ComputeShapeFromReshape( // sym_map is used to match shape symbols between the input and shape. // If there is a mismatch, the output shape cannot be estimated. if (!sym_map.empty()) { - return c10::nullopt; + return std::nullopt; } TORCH_INTERNAL_ASSERT( @@ -565,7 +565,7 @@ std::optional<::c10::SymbolicShape> ComputeShapeFromExpand( const std::vector& reshape) { for (const auto& it : reshape) { if (it < 0) { - return c10::nullopt; + return std::nullopt; } } std::vector<::c10::ShapeSymbol> final_shape; @@ -607,7 +607,7 @@ std::optional<::c10::SymbolicShape> ComputeShapeFromTile( "ONNX Tile input shapes do not match."); for (const auto& it : reshape) { if (it < 0) { - return c10::nullopt; + return std::nullopt; } } std::vector<::c10::ShapeSymbol> final_shape; @@ -688,7 +688,7 @@ std::optional> GetValueFromListConstructNode( } return lc_node->inputs().size() == shape_size.size() ? std::optional>(shape_size) - : c10::nullopt; + : std::nullopt; } void SetShapeValueFromListConstructNode(Node* lc_node) { diff --git a/torch/csrc/jit/passes/onnx/unpack_quantized_weights.cpp b/torch/csrc/jit/passes/onnx/unpack_quantized_weights.cpp index 7390bea56e77b0..d889295dca19e2 100644 --- a/torch/csrc/jit/passes/onnx/unpack_quantized_weights.cpp +++ b/torch/csrc/jit/passes/onnx/unpack_quantized_weights.cpp @@ -655,11 +655,11 @@ void UnpackQuantizedTensorInputs(std::shared_ptr& graph) { auto input_scale = graph->insertInput(index + 1, input_name + "_scale") ->setType(TensorType::create( - at::kDouble, at::kCPU, 0, /*requires_grad=*/c10::nullopt)); + at::kDouble, at::kCPU, 0, /*requires_grad=*/std::nullopt)); auto input_zero_point = graph->insertInput(index + 2, input_name + "_zero_point") ->setType(TensorType::create( - at::kLong, at::kCPU, 0, /*requires_grad=*/c10::nullopt)); + at::kLong, at::kCPU, 0, /*requires_grad=*/std::nullopt)); std::vector converted{input_value, input_scale, input_zero_point}; auto input_tuple = graph->prependNode(graph->createTuple(converted))->output(); diff --git a/torch/csrc/jit/passes/peephole_dict_idioms.cpp b/torch/csrc/jit/passes/peephole_dict_idioms.cpp index d3a5cfa36261b0..171b787d17b048 100644 --- a/torch/csrc/jit/passes/peephole_dict_idioms.cpp +++ b/torch/csrc/jit/passes/peephole_dict_idioms.cpp @@ -34,7 +34,7 @@ class DictNodeImpl : public DictNodeImplBase { auto key_opt = toIValue(dict_creation_node->input(i)); // Key is not constant if we cannot convert to IValue - if (key_opt == c10::nullopt) { + if (key_opt == std::nullopt) { has_non_const_key_ = true; continue; } @@ -129,7 +129,7 @@ class DictNode { if (impl_ && impl_->contains(key)) { return impl_->get(key); } - return c10::nullopt; + return std::nullopt; } private: @@ -185,14 +185,14 @@ class PeepholeOptimizeDictIdiomsImpl { const DictNode& dict_node = getDictNode(dict_creation_node); auto key_opt = toIValue(key); // Key is not constant if we cannot convert to IValue - if (key_opt == c10::nullopt) { - return c10::nullopt; + if (key_opt == std::nullopt) { + return std::nullopt; } IValue key_ival = *key_opt; if (dict_node.canOptimize()) { return dict_node.getOrNullopt(key_ival); } - return c10::nullopt; + return std::nullopt; } std::optional computeLen(Node* dict_creation_node) { @@ -200,13 +200,13 @@ class PeepholeOptimizeDictIdiomsImpl { if (dict_node.canOptimize()) { return static_cast(dict_node.size()); } - return c10::nullopt; + return std::nullopt; } bool optimizeLen(Node* len_node, Node* creation_node) { if (creation_node->kind() == prim::DictConstruct) { auto len = computeLen(creation_node); - if (len != c10::nullopt) { + if (len != std::nullopt) { WithInsertPoint guard(len_node); len_node->output()->replaceAllUsesWith(graph_->insertConstant(len)); return true; @@ -219,7 +219,7 @@ class PeepholeOptimizeDictIdiomsImpl { if (creation_node->kind() == prim::DictConstruct) { auto key = getitem_node->input(1); auto value = getValueFromDict(creation_node, key); - if (value != c10::nullopt) { + if (value != std::nullopt) { getitem_node->output()->replaceAllUsesWith(*value); return true; } diff --git a/torch/csrc/jit/passes/peephole_list_idioms.cpp b/torch/csrc/jit/passes/peephole_list_idioms.cpp index 9c106e13edf1f8..f644fe4f1de1c8 100644 --- a/torch/csrc/jit/passes/peephole_list_idioms.cpp +++ b/torch/csrc/jit/passes/peephole_list_idioms.cpp @@ -21,7 +21,7 @@ static std::optional normalizeIndex(int64_t index, size_t len) { if (index >= 0 && index < static_cast(len)) { return index; } else { - return c10::nullopt; + return std::nullopt; } } @@ -136,7 +136,7 @@ struct ListLenRefiner { return maybe_refinement->second; } } - return c10::nullopt; + return std::nullopt; } std::shared_ptr graph_; @@ -199,8 +199,8 @@ struct PeepholeOptimizeListIdiomsImpl { auto step_val = toIValue(slice_node->input(3)); // All args must be constant to apply this optimization. - if (start_val == c10::nullopt || end_val == c10::nullopt || - step_val == c10::nullopt) { + if (start_val == std::nullopt || end_val == std::nullopt || + step_val == std::nullopt) { return false; } diff --git a/torch/csrc/jit/passes/quantization/helper.cpp b/torch/csrc/jit/passes/quantization/helper.cpp index 8a74ec01086a58..7eea68eb106546 100644 --- a/torch/csrc/jit/passes/quantization/helper.cpp +++ b/torch/csrc/jit/passes/quantization/helper.cpp @@ -325,7 +325,7 @@ std::optional getClampScalarInputUse(Value* v) { } } } - return c10::nullopt; + return std::nullopt; } void cloneMethod( @@ -503,7 +503,7 @@ std::optional> getFixedQParams(Node* n) { if (isAtenFunc(n, fixed_qparam_funcs)) { return _fixed_qparams_map.at(n->kind()); } - return c10::nullopt; + return std::nullopt; } bool userDefinedCallFunction(Node* n) { @@ -534,13 +534,13 @@ bool nodeQuantizable(Node* n, QuantType quant_type) { bool useQuantizable(const Use& use, QuantType quant_type) { if (quant_type == QuantType::STATIC) { for (const auto& func_input : _observe_inputs_aten_func) { - if (matchAtenFuncToUse(use, func_input.func_name, c10::nullopt)) { + if (matchAtenFuncToUse(use, func_input.func_name, std::nullopt)) { return use.offset == static_cast(func_input.arg_index); } } for (const auto& func_input : _observe_inputs_call_func) { - if (matchCallFuncToUse(use, func_input.func_name, c10::nullopt)) { + if (matchCallFuncToUse(use, func_input.func_name, std::nullopt)) { return use.offset == static_cast(func_input.arg_index); } } @@ -653,7 +653,7 @@ std::optional getInvokedModuleOpt( if (m.attr(p).isModule()) { m = m.attr(p).toModule(); } else { - return c10::nullopt; + return std::nullopt; } } return m; @@ -691,7 +691,7 @@ std::optional getModuleName(Value* value) { if (type && type->name()) { return removeTorchMangle(type->name()->qualifiedName()); } - return c10::nullopt; + return std::nullopt; } static bool is_module( diff --git a/torch/csrc/jit/passes/quantization/helper.h b/torch/csrc/jit/passes/quantization/helper.h index 680e3c7ca43d52..21efbff7aa6941 100644 --- a/torch/csrc/jit/passes/quantization/helper.h +++ b/torch/csrc/jit/passes/quantization/helper.h @@ -150,7 +150,7 @@ TORCH_API Module getInvokedModule(Module& module, Node* n, Value* self); // Given an CallMethod node, get the module instance corresponding // to the instance Value if the instance is a module, otherwise return -// c10::nullopt +// std::nullopt std::optional getInvokedModuleOpt( const Module& module, Node* n, diff --git a/torch/csrc/jit/passes/quantization/insert_observers.cpp b/torch/csrc/jit/passes/quantization/insert_observers.cpp index 145448210958ac..f906efacceca7b 100644 --- a/torch/csrc/jit/passes/quantization/insert_observers.cpp +++ b/torch/csrc/jit/passes/quantization/insert_observers.cpp @@ -49,7 +49,7 @@ void fillQConfigMap( const QConfigDict& qconfig_dict, ModuleQConfigMap& map, const std::string& key = "", - const std::optional& parent_qconfig = c10::nullopt) { + const std::optional& parent_qconfig = std::nullopt) { std::optional qconfig; if (qconfig_dict.find(key) != qconfig_dict.end()) { GRAPH_DEBUG("Got module config for key:", key); @@ -1414,7 +1414,7 @@ InsertObserversHelper::insertObserversFor( if (!isObserved(v, block_observed_values)) { block_output_observers.emplace_back(getObserverFor(v)); } else { - block_output_observers.emplace_back(c10::nullopt); + block_output_observers.emplace_back(std::nullopt); } } } diff --git a/torch/csrc/jit/passes/quantization/insert_quant_dequant.cpp b/torch/csrc/jit/passes/quantization/insert_quant_dequant.cpp index 92fb2fc79bcc91..3d24834261d2a0 100644 --- a/torch/csrc/jit/passes/quantization/insert_quant_dequant.cpp +++ b/torch/csrc/jit/passes/quantization/insert_quant_dequant.cpp @@ -234,7 +234,7 @@ std::optional findObserverName(Value* v) { return module_instance->node()->s(attr::name); } } - return c10::nullopt; + return std::nullopt; } bool isPlaceholderObserver(Value* observer) { @@ -268,7 +268,7 @@ std::optional getEmbeddingBagObsName( auto op_name = observer_module.attr("custom_op").toStringRef(); return isPlaceholderObserver(observer) ? std::move(op_name) : ""; } - return c10::nullopt; + return std::nullopt; } bool isEmbeddingBagOp( @@ -792,7 +792,7 @@ class InsertQuantDeQuantHelper { const std::vector& inputs, bool is_scalar = false, const std::optional>& qparams_opt = - c10::nullopt); + std::nullopt); bool isQuantized(Value* v) { return quantized_values_.count(v) != 0; @@ -1269,7 +1269,7 @@ std::optional> getDequantizedInputs(Value* output) { return inputs; } } - return c10::nullopt; + return std::nullopt; } void InsertQuantDeQuantHelper::propagateQuantizationOps(Block* block) { diff --git a/torch/csrc/jit/passes/remove_mutation.h b/torch/csrc/jit/passes/remove_mutation.h index be8fc12b11f3d7..1242555358f771 100644 --- a/torch/csrc/jit/passes/remove_mutation.h +++ b/torch/csrc/jit/passes/remove_mutation.h @@ -11,7 +11,7 @@ namespace jit { struct TORCH_API MutationRemover { MutationRemover( std::shared_ptr graph, - std::optional> mutation_filter = c10::nullopt) + std::optional> mutation_filter = std::nullopt) : mutation_filter_(mutation_filter), aliasDb_(nullptr), graph_(std::move(graph)) {} @@ -71,7 +71,7 @@ TORCH_API bool RemoveListMutation(const std::shared_ptr& graph); // return true if graph is modified TORCH_API bool RemoveTensorMutation( const std::shared_ptr& graph, - std::optional> mutation_filter = c10::nullopt); + std::optional> mutation_filter = std::nullopt); // Replaces in-place aten activation ops with their functional equivalence TORCH_API bool InplaceToFunctionalActivation( diff --git a/torch/csrc/jit/passes/replacement_of_old_operators.cpp b/torch/csrc/jit/passes/replacement_of_old_operators.cpp index 38255ad1418771..2d3b3a2aba7fc5 100644 --- a/torch/csrc/jit/passes/replacement_of_old_operators.cpp +++ b/torch/csrc/jit/passes/replacement_of_old_operators.cpp @@ -30,7 +30,7 @@ struct OldOpsReplacerWithUpgraders { Node* node = graph_it.next(); while (node) { // load the schema name for this op - std::optional schema_name = c10::nullopt; + std::optional schema_name = std::nullopt; if (auto op_schema = node->maybeSchema()) { schema_name = getFullSchemaName(*op_schema); } else { diff --git a/torch/csrc/jit/passes/shape_analysis.cpp b/torch/csrc/jit/passes/shape_analysis.cpp index abc7bb6411dbae..7290e1936128c7 100644 --- a/torch/csrc/jit/passes/shape_analysis.cpp +++ b/torch/csrc/jit/passes/shape_analysis.cpp @@ -151,7 +151,7 @@ bool containsTensorType(const TypePtr& t) { } // for each node in the schema with type Tensor, extract the T type -// returns c10::nullopt if any Tensor in the schema does not have a known +// returns std::nullopt if any Tensor in the schema does not have a known // shape ignores non-tensor in the list of inputs std::optional> gatherTensorTypes( Node* node, @@ -160,26 +160,26 @@ std::optional> gatherTensorTypes( auto schema_opt = node->maybeSchema(); if (!schema_opt) { - return c10::nullopt; + return std::nullopt; } auto& schema = *schema_opt; auto& args = schema.arguments(); // can't handle varargs primitives because we don't know what should be a // Tensor if (schema.is_vararg()) { - return c10::nullopt; + return std::nullopt; } for (const auto i : c10::irange(args.size())) { if (args[i].type()->isSubtypeOf(*ListType::ofTensors())) { - return c10::nullopt; + return std::nullopt; } else if (args[i].type()->isSubtypeOf(*TensorType::get())) { if (auto type = node->input(i)->type()->cast()) { if (complete && !type->isComplete()) { - return c10::nullopt; + return std::nullopt; } tensor_types.push_back(type); } else { - return c10::nullopt; + return std::nullopt; } } else /* non-tensor type */ { continue; @@ -217,7 +217,7 @@ std::optional getPromotedTypeForArithmeticOp(Node* node) { auto dtt = node->inputs()[i]->type()->expect(); auto inputDtype = dtt->scalarType(); if (!dtt || !inputDtype) { - return c10::nullopt; + return std::nullopt; } if (dtt->dim() && *dtt->dim() > 0) { dimmed = unionScalarTypes(dimmed, *inputDtype); @@ -552,7 +552,7 @@ class ShapePropagator : public PropertyPropBase { tryScalarTypeFromJitType(*input_base_type); if (auto grad_index = node->schema().argumentIndexWithName("dtype")) { auto inp = toIValue(node->inputs().at(*grad_index)); - if (inp == c10::nullopt) { + if (inp == std::nullopt) { return; } else if (!inp->isNone()) { default_type = inp->toScalarType(); @@ -562,14 +562,14 @@ class ShapePropagator : public PropertyPropBase { at::Device default_device = at::kCPU; if (auto device_index = node->schema().argumentIndexWithName("device")) { auto inp = toIValue(node->inputs().at(*device_index)); - if (inp == c10::nullopt) { + if (inp == std::nullopt) { return; } else if (!inp->isNone()) { default_device = inp->toDevice(); } } node->output()->setType(TensorType::create( - default_type, default_device, dims, /*requires_grad=*/c10::nullopt)); + default_type, default_device, dims, /*requires_grad=*/std::nullopt)); } // returns whether any such values were found @@ -612,10 +612,10 @@ class ShapePropagator : public PropertyPropBase { if (typ->isSubtypeOf(*IntType::get()) || typ->isSubtypeOf(*BoolType::get())) { node->output()->setType(TensorType::create( - at::kLong, at::kCPU, 0, /*requires_grad=*/c10::nullopt)); + at::kLong, at::kCPU, 0, /*requires_grad=*/std::nullopt)); } else if (node->input()->type()->isSubtypeOf(*FloatType::get())) { node->output()->setType(TensorType::create( - at::kDouble, at::kCPU, 0, /*requires_grad=*/c10::nullopt)); + at::kDouble, at::kCPU, 0, /*requires_grad=*/std::nullopt)); } return; } @@ -750,7 +750,7 @@ class ShapePropagator : public PropertyPropBase { if (input_node->kind() == prim::ListConstruct) { return input_node->inputs().size(); } - return c10::nullopt; + return std::nullopt; } // is it ok to try to run the op @@ -778,7 +778,7 @@ class ShapePropagator : public PropertyPropBase { auto max_dims = any_type->dim(); for (auto& type : tensor_types) { if (!max_dims || !type->dim()) { - max_dims = c10::nullopt; + max_dims = std::nullopt; } else { max_dims = std::max(*max_dims, *type->dim()); } @@ -787,7 +787,7 @@ class ShapePropagator : public PropertyPropBase { t, any_type->device(), max_dims, - /*requires_grad=*/c10::nullopt); + /*requires_grad=*/std::nullopt); }; using type_vec_t = std::vector; @@ -1245,7 +1245,7 @@ class ShapePropagator : public PropertyPropBase { int64_t num_reduced_dim = 0, bool upcast_integer = false, std::optional opt_dtype = - c10::nullopt) -> type_vec_t { + std::nullopt) -> type_vec_t { if (auto type = node->input(0)->type()->cast()) { if (!type->scalarType() || !type->dim()) { return {}; @@ -1418,7 +1418,7 @@ class ShapePropagator : public PropertyPropBase { : maybe_dtype_option->toScalarType()); return {TensorType::create( - dtype, device, dim, /*requires_grad=*/c10::nullopt)}; + dtype, device, dim, /*requires_grad=*/std::nullopt)}; }; static const auto factory_like_with_ndim = [](Node* node, @@ -1448,7 +1448,7 @@ class ShapePropagator : public PropertyPropBase { } return {TensorType::create( - in_type, in_dev, dim, /*requires_grad=*/c10::nullopt)}; + in_type, in_dev, dim, /*requires_grad=*/std::nullopt)}; }; // Requirements: @@ -1748,7 +1748,7 @@ class ShapePropagator : public PropertyPropBase { if (auto dtype_index = node->schema().argumentIndexWithName("dtype")) { auto inp = toIValue(node->inputs().at(*dtype_index)); - if (inp == c10::nullopt) { + if (inp == std::nullopt) { return nullptr; } if (!inp->isNone()) { @@ -1758,7 +1758,7 @@ class ShapePropagator : public PropertyPropBase { if (auto device_index = node->schema().argumentIndexWithName("device")) { auto inp = toIValue(node->inputs().at(*device_index)); - if (inp == c10::nullopt) { + if (inp == std::nullopt) { return nullptr; } if (!inp->isNone()) { @@ -1769,7 +1769,7 @@ class ShapePropagator : public PropertyPropBase { default_type, default_device, type->dim(), - /*requires_grad=*/c10::nullopt)); + /*requires_grad=*/std::nullopt)); } } return nullptr; diff --git a/torch/csrc/jit/passes/symbolic_shape_analysis.cpp b/torch/csrc/jit/passes/symbolic_shape_analysis.cpp index 951c093cefe55a..6ac9576a8e2bc5 100644 --- a/torch/csrc/jit/passes/symbolic_shape_analysis.cpp +++ b/torch/csrc/jit/passes/symbolic_shape_analysis.cpp @@ -61,7 +61,7 @@ namespace jit { // %y.2: Tensor(5, SS(-1), (New Symbolic Shape)) = aten::view(%y, %2) // // x.view([5, y.size(0), inp]) -// will have inputs equal to [5, SS(-1), c10::nullopt] +// will have inputs equal to [5, SS(-1), std::nullopt] struct ShapeArg : public std:: @@ -73,17 +73,17 @@ struct ShapeArg } ShapeArg(int64_t int_value) { - this->first = c10::nullopt; + this->first = std::nullopt; this->second = int_value; } ShapeArg(c10::ShapeSymbol ss) { if (ss.is_static()) { - this->first = c10::nullopt; + this->first = std::nullopt; this->second = ss.value(); } else { this->first = ss; - this->second = c10::nullopt; + this->second = std::nullopt; } } @@ -97,8 +97,8 @@ struct ShapeArg private: ShapeArg() { - this->first = c10::nullopt; - this->second = c10::nullopt; + this->first = std::nullopt; + this->second = std::nullopt; } }; @@ -215,7 +215,7 @@ std::optional normIndex(int64_t index, size_t len) { if (index >= 0 && index < static_cast(len)) { return index; } else { - return c10::nullopt; + return std::nullopt; } } @@ -608,7 +608,7 @@ struct SymbolicShapeOpAnalyzer { std::optional> run( std::vector& inputs) { if (!shape_compute_graph_) { - return c10::nullopt; + return std::nullopt; } inputs_ = inputs; substituteConstantInputs(); @@ -788,7 +788,7 @@ c10::SymbolicShape combine_bounds( c10::SymbolicShape& upper_bound) { // TODO: At some point we might want to add support for dynamic dims TORCH_INTERNAL_ASSERT(lower_bound.rank() == upper_bound.rank()); - if (lower_bound.rank() == c10::nullopt) { + if (lower_bound.rank() == std::nullopt) { return c10::SymbolicShape(); } std::vector merged_shapes; @@ -837,14 +837,14 @@ struct SymbolicShapeGraphAnalyzer { return use.user->kind() == aten::cat; })) { GRAPH_DEBUG("Non cat list use ", getHeader(curr)); - return c10::nullopt; + return std::nullopt; } continue; } if (!partial_evaluated_graphs.count(curr)) { GRAPH_DEBUG("No graph ", getHeader(curr)); - return c10::nullopt; + return std::nullopt; } auto outputs = curr->outputs(); @@ -852,13 +852,13 @@ struct SymbolicShapeGraphAnalyzer { auto tt = v->type()->cast(); if (!tt) { GRAPH_DEBUG("Non tensor node", getHeader(curr)); - return c10::nullopt; + return std::nullopt; } auto symbolic_sizes = tt->symbolic_sizes(); // TODO: dont require # of dimensions of tensors set ? if (!symbolic_sizes.rank()) { GRAPH_DEBUG("No rank on output ", getHeader(curr)); - return c10::nullopt; + return std::nullopt; } } auto partial_eval_graph = partial_evaluated_graphs[curr]; @@ -1133,11 +1133,11 @@ calculateSymbolicShapesOnOp( const FunctionSchema* schema, const std::vector& inputs) { auto bounded_graphs = boundedGraphsForSchema(*schema); - auto has_shape_compute = shapeComputeGraphForSchema(*schema) != c10::nullopt; - if (!has_shape_compute && bounded_graphs == c10::nullopt) { + auto has_shape_compute = shapeComputeGraphForSchema(*schema) != std::nullopt; + if (!has_shape_compute && bounded_graphs == std::nullopt) { // Avoid doing all this work for functions that don't have a // supported schema - return c10::nullopt; + return std::nullopt; } if (auto cached_ret_vec = get_cached_shape_function(schema, inputs)) { @@ -1172,7 +1172,7 @@ calculateSymbolicShapesOnOp( cache_shape_function(schema, inputs, merged_res); return merged_res; } - return c10::nullopt; + return std::nullopt; } auto op_analyzer = SymbolicShapeOpAnalyzer(schema); diff --git a/torch/csrc/jit/passes/symbolic_shape_cache.cpp b/torch/csrc/jit/passes/symbolic_shape_cache.cpp index 4a742b3f5f6351..d01d11983a622c 100644 --- a/torch/csrc/jit/passes/symbolic_shape_cache.cpp +++ b/torch/csrc/jit/passes/symbolic_shape_cache.cpp @@ -120,7 +120,7 @@ get_cached_shape_function( get_cache_key(schema, arg_vec, ss_map, /* deep_copy */ false); auto cached_ret_vec = shapeCache.Get(cache_key); if (cached_ret_vec == nullptr) { - return c10::nullopt; + return std::nullopt; } // Decanonicalize the return values auto inverse_ss_map = std::unordered_map(); @@ -148,7 +148,7 @@ void CanonicalizedSymbolicShape::init( std::unordered_map& ss_map) { auto sizes = orig_shape.sizes(); if (!sizes) { - values_ = c10::nullopt; + values_ = std::nullopt; return; } values_ = std::vector(); diff --git a/torch/csrc/jit/passes/symbolic_shape_runtime_fusion.cpp b/torch/csrc/jit/passes/symbolic_shape_runtime_fusion.cpp index 9c213f2480d51d..3cf23732a9ad65 100644 --- a/torch/csrc/jit/passes/symbolic_shape_runtime_fusion.cpp +++ b/torch/csrc/jit/passes/symbolic_shape_runtime_fusion.cpp @@ -190,7 +190,7 @@ TryGeneralizeInputDimensionsToSymbolicShapes( } auto tt = v->type()->expectRef(); if (!tt.sizes().isComplete() || !tt.strides().isComplete()) { - return c10::nullopt; + return std::nullopt; } input_striding.push_back(summarizeInputStrides(tt)); std::vector shape_vec = *tt.symbolic_sizes().sizes(); diff --git a/torch/csrc/jit/passes/tensorexpr_fuser.cpp b/torch/csrc/jit/passes/tensorexpr_fuser.cpp index c9b9b974600dc4..684f47f4efb93a 100644 --- a/torch/csrc/jit/passes/tensorexpr_fuser.cpp +++ b/torch/csrc/jit/passes/tensorexpr_fuser.cpp @@ -782,7 +782,7 @@ class TensorExprFuser { std::optional tryMerge(Node* fusion_group, Node* to_merge) { if (!canMerge(fusion_group, to_merge)) { - return c10::nullopt; + return std::nullopt; } std::vector nodes_to_merge = {to_merge}; @@ -799,7 +799,7 @@ class TensorExprFuser { GRAPH_UPDATE("Trying to move node next to fusion group: ", getHeader(n)); if (!aliasDb_->moveBeforeTopologicallyValid(n, move_point)) { GRAPH_UPDATE("Failed to move because of AliasDB checks!"); - return c10::nullopt; + return std::nullopt; } move_point = n; } diff --git a/torch/csrc/jit/passes/utils/check_alias_annotation.cpp b/torch/csrc/jit/passes/utils/check_alias_annotation.cpp index 4c081200715a71..6082058952ce9e 100644 --- a/torch/csrc/jit/passes/utils/check_alias_annotation.cpp +++ b/torch/csrc/jit/passes/utils/check_alias_annotation.cpp @@ -196,7 +196,7 @@ std::optional toIValueProp(const Value* v) { genericList.push_back(*elem); } else { // One of the list elements isn't constant. - return c10::nullopt; + return std::nullopt; } } @@ -213,7 +213,7 @@ std::optional toIValueProp(const Value* v) { return IValue( fmap(genericList, [](const IValue& v) { return v.toTensor(); })); } else { - return c10::nullopt; + return std::nullopt; } } @@ -222,7 +222,7 @@ std::optional toIValueProp(const Value* v) { return maybe_stack->at(0); } } - return c10::nullopt; + return std::nullopt; } // batch_norm and instance_norm have incorrect annotations, because diff --git a/torch/csrc/jit/passes/utils/memory_dag.h b/torch/csrc/jit/passes/utils/memory_dag.h index da5584f9d4bd35..1d2292fe90c5ba 100644 --- a/torch/csrc/jit/passes/utils/memory_dag.h +++ b/torch/csrc/jit/passes/utils/memory_dag.h @@ -2,12 +2,12 @@ #include #include -#include #include #include #include #include #include +#include #include #include #include diff --git a/torch/csrc/jit/passes/utils/subgraph_utils.cpp b/torch/csrc/jit/passes/utils/subgraph_utils.cpp index 377621c04b6dbf..f4dfc4ce99c940 100644 --- a/torch/csrc/jit/passes/utils/subgraph_utils.cpp +++ b/torch/csrc/jit/passes/utils/subgraph_utils.cpp @@ -429,7 +429,7 @@ Node* createSingletonSubgraphAndUpdateAliasing( Symbol subgraphKind, AliasDb& db) { return executeSubgraphMergeAndUpdateAliasing( - to_merge, c10::nullopt, db, [&]() { + to_merge, std::nullopt, db, [&]() { return createSingletonSubgraph(to_merge, subgraphKind); }); } diff --git a/torch/csrc/jit/python/init.cpp b/torch/csrc/jit/python/init.cpp index 1bfc6c94a707f4..862aaba7d7dc14 100644 --- a/torch/csrc/jit/python/init.cpp +++ b/torch/csrc/jit/python/init.cpp @@ -157,7 +157,7 @@ std::optional toTypeInferredIValueOptional(py::handle input) { try { return toTypeInferredIValue(input); } catch (const c10::Error& e) { - return c10::nullopt; + return std::nullopt; } } } // anonymous namespace @@ -219,7 +219,7 @@ void initJITBindings(PyObject* module) { "_jit_shape_compute_graph_for_node", [](Node* n) -> std::optional> { if (!n->maybeSchema()) { - return c10::nullopt; + return std::nullopt; } return shapeComputeGraphForSchema(n->schema()); }) @@ -227,7 +227,7 @@ void initJITBindings(PyObject* module) { "_jit_decomposition_graph_for_node", [](Node* n) -> std::optional> { if (!n->maybeSchema()) { - return c10::nullopt; + return std::nullopt; } return GetDecomposition(n->schema()); }) @@ -1165,7 +1165,7 @@ void initJITBindings(PyObject* module) { c10::kCPU, std::vector{1}, std::vector{1}, - c10::nullopt)); + std::nullopt)); } } }) @@ -1680,7 +1680,7 @@ void initJITBindings(PyObject* module) { [op, symbol, allow_numbers_as_tensors]( c10::DispatchKey dk_, py::args args, py::kwargs kwargs) { std::optional dk = - c10::make_optional(dk_); + std::make_optional(dk_); ToIValueAllowNumbersAsTensors g(allow_numbers_as_tensors); return _get_operation_for_overload_or_packet( {op}, symbol, args, kwargs, /*is_overload*/ true, dk); diff --git a/torch/csrc/jit/python/module_python.h b/torch/csrc/jit/python/module_python.h index 5c7fbbb42d6cfc..b1ddf6f37c6786 100644 --- a/torch/csrc/jit/python/module_python.h +++ b/torch/csrc/jit/python/module_python.h @@ -14,7 +14,7 @@ inline std::optional as_module(py::handle obj) { if (py::isinstance(obj, ScriptModule)) { return py::cast(obj.attr("_c")); } - return c10::nullopt; + return std::nullopt; } inline std::optional as_object(py::handle obj) { @@ -29,7 +29,7 @@ inline std::optional as_object(py::handle obj) { if (py::isinstance(obj, RecursiveScriptClass)) { return py::cast(obj.attr("_c")); } - return c10::nullopt; + return std::nullopt; } } // namespace torch::jit diff --git a/torch/csrc/jit/python/pybind_utils.cpp b/torch/csrc/jit/python/pybind_utils.cpp index a731640223c096..2dbcfee423ae7a 100644 --- a/torch/csrc/jit/python/pybind_utils.cpp +++ b/torch/csrc/jit/python/pybind_utils.cpp @@ -754,7 +754,7 @@ std::pair, Stack> getOpWithStack( std::shared_ptr op = operations.at(0); // Create a stack full of the arguments and keyword arguments. stack = createStackForSchema( - op->schema(), std::move(args), kwargs, c10::nullopt); + op->schema(), std::move(args), kwargs, std::nullopt); return std::make_pair(std::move(op), std::move(stack)); } else { @@ -762,7 +762,7 @@ std::pair, Stack> getOpWithStack( std::shared_ptr found_op = nullptr; for (const auto& op : operations) { try { - stack = createStackForSchema(op->schema(), args, kwargs, c10::nullopt); + stack = createStackForSchema(op->schema(), args, kwargs, std::nullopt); found_op = op; break; } catch (schema_match_error& error) { diff --git a/torch/csrc/jit/python/pybind_utils.h b/torch/csrc/jit/python/pybind_utils.h index 23fda5b0d784ec..cd8a7335167d4a 100644 --- a/torch/csrc/jit/python/pybind_utils.h +++ b/torch/csrc/jit/python/pybind_utils.h @@ -36,8 +36,8 @@ #include #endif #include -#include #include +#include #include #include @@ -62,7 +62,7 @@ void clear_registered_instances(void* ptr); TORCH_PYTHON_API IValue toIValue( py::handle obj, const TypePtr& type, - std::optional N = c10::nullopt); + std::optional N = std::nullopt); TORCH_PYTHON_API py::object toPyObject(IValue ivalue); @@ -111,7 +111,7 @@ struct VISIBILITY_HIDDEN PythonFutureWrapper explicit PythonFutureWrapper( c10::intrusive_ptr fut, - std::optional unwrap_func = c10::nullopt) + std::optional unwrap_func = std::nullopt) : fut(std::move(fut)), unwrap_func(std::move(unwrap_func)) {} explicit PythonFutureWrapper(const PythonFutureWrapper&) = delete; @@ -1205,7 +1205,7 @@ inline std::optional maybeTorchFunctionDispatch( /*module_name=*/qualname.prefix().c_str())); } - return c10::nullopt; + return std::nullopt; } inline py::object invokeScriptFunctionFromPython( @@ -1219,7 +1219,7 @@ inline py::object invokeScriptFunctionFromPython( callee, args, kwargs, - /*self=*/c10::nullopt, + /*self=*/std::nullopt, [&](Graph& graph, const MatchedSchema& match) { return graph.insertFunctionCall(&callee, match); }); @@ -1255,7 +1255,7 @@ TORCH_PYTHON_API py::object invokeOperatorFromPython( const std::vector>& operations, py::args args, const py::kwargs& kwargs, - std::optional dk = c10::nullopt); + std::optional dk = std::nullopt); TORCH_PYTHON_API std::optional _maybe_handle_torch_function( const std::string& ns, @@ -1276,6 +1276,6 @@ TORCH_PYTHON_API py::object _get_operation_for_overload_or_packet( py::args args, const py::kwargs& kwargs, bool is_overload, - std::optional dk = c10::nullopt); + std::optional dk = std::nullopt); } // namespace torch::jit diff --git a/torch/csrc/jit/python/python_ir.cpp b/torch/csrc/jit/python/python_ir.cpp index 79957999f543dc..c80208b9d00df8 100644 --- a/torch/csrc/jit/python/python_ir.cpp +++ b/torch/csrc/jit/python/python_ir.cpp @@ -138,17 +138,17 @@ std::optional ConcretePythonOp::autogradFunction() const { auto r = py::getattr(obj, "__self__", py::none()); if (r.is_none()) - return c10::nullopt; + return std::nullopt; auto apply = py::getattr(r, "apply", py::none()); if (apply.is_none()) - return c10::nullopt; + return std::nullopt; auto c = PyObject_RichCompareBool(apply.ptr(), obj.ptr(), Py_NE); if (PyErr_Occurred()) throw py::error_already_set(); if (c) - return c10::nullopt; + return std::nullopt; return THPObjectPtr(r.release().ptr()); } diff --git a/torch/csrc/jit/python/python_ivalue.h b/torch/csrc/jit/python/python_ivalue.h index 4cdc8e430b9a81..6d0bf1afc3b06f 100644 --- a/torch/csrc/jit/python/python_ivalue.h +++ b/torch/csrc/jit/python/python_ivalue.h @@ -31,7 +31,7 @@ struct C10_EXPORT ConcretePyObjectHolder final : PyObjectHolder { return torch::jit::tryToInferType(py_obj_); } - IValue toIValue(const TypePtr& type, std::optional N = c10::nullopt) + IValue toIValue(const TypePtr& type, std::optional N = std::nullopt) override { pybind11::gil_scoped_acquire ag; return torch::jit::toIValue(py_obj_, type, N); diff --git a/torch/csrc/jit/python/python_list.h b/torch/csrc/jit/python/python_list.h index b5bb88b3aeb20d..f73cb5048529bd 100644 --- a/torch/csrc/jit/python/python_list.h +++ b/torch/csrc/jit/python/python_list.h @@ -4,10 +4,10 @@ #include #include #include -#include #include #include #include +#include #include namespace torch::jit { @@ -175,7 +175,7 @@ class ScriptList final { // Remove and return the element at the specified index from the list. If no // index is passed, the last element is removed and returned. - IValue pop(std::optional idx = c10::nullopt) { + IValue pop(std::optional idx = std::nullopt) { IValue ret; if (idx) { diff --git a/torch/csrc/jit/python/python_sugared_value.cpp b/torch/csrc/jit/python/python_sugared_value.cpp index d6f014759c05e9..c5d48f5cbe7474 100644 --- a/torch/csrc/jit/python/python_sugared_value.cpp +++ b/torch/csrc/jit/python/python_sugared_value.cpp @@ -28,7 +28,7 @@ std::optional as_function(const py::object& obj) { if (py::isinstance(obj)) { return py::cast(obj); } - return c10::nullopt; + return std::nullopt; } FunctionSchema PythonValue::getSchema( @@ -66,8 +66,8 @@ FunctionSchema PythonValue::getSchema( args.emplace_back( /*name=*/*names_it, /*type=*/TensorType::get(), - /*N=*/c10::nullopt, - /*default_value=*/c10::nullopt, + /*N=*/std::nullopt, + /*default_value=*/std::nullopt, /*kwarg_only=*/false); } @@ -95,8 +95,8 @@ FunctionSchema PythonValue::getSchema( args.emplace_back( /*name=*/*names_it, /*type=*/std::move(*types_it), - /*N=*/c10::nullopt, - /*default_value=*/c10::nullopt, + /*N=*/std::nullopt, + /*default_value=*/std::nullopt, /*kwarg_only=*/false); } rets.push_back(Argument("0", std::move(ret_type), {}, {}, false)); @@ -240,10 +240,10 @@ std::shared_ptr CUDAPythonModuleValue::attr( // these APIs. if (field == "current_device" || field == "set_device") { return std::make_shared( - Symbol::cuda("_" + field), c10::nullopt); + Symbol::cuda("_" + field), std::nullopt); } else { return std::make_shared( - Symbol::cuda(field), c10::nullopt); + Symbol::cuda(field), std::nullopt); } } @@ -673,7 +673,7 @@ std::shared_ptr ModuleValue::tryGetAttr( if (const auto fnAttr = concreteType_->findFunctionAttribute(field)) { return std::make_shared(*fnAttr); } else if (const auto builtin = concreteType_->findBuiltinFunction(field)) { - return std::make_shared(*builtin, /*self=*/c10::nullopt); + return std::make_shared(*builtin, /*self=*/std::nullopt); } // 5. Check if it's an attribute of the original Python class that this @@ -1263,7 +1263,7 @@ std::shared_ptr toSugaredValue( py::module::import("torch.jit._builtins").attr("_find_builtin")(obj); if (!builtin_name.is_none()) { return std::make_shared( - Symbol::fromQualString(py::str(builtin_name)), c10::nullopt); + Symbol::fromQualString(py::str(builtin_name)), std::nullopt); } if (py::cast(py::module::import("torch._jit_internal") diff --git a/torch/csrc/jit/python/python_sugared_value.h b/torch/csrc/jit/python/python_sugared_value.h index cb397796c9f55e..508d95c8c538d0 100644 --- a/torch/csrc/jit/python/python_sugared_value.h +++ b/torch/csrc/jit/python/python_sugared_value.h @@ -32,7 +32,7 @@ std::optional as_function(const py::object& obj); struct VISIBILITY_HIDDEN PythonValue : public SugaredValue { PythonValue( py::object the_self, - std::optional rcb = c10::nullopt, + std::optional rcb = std::nullopt, Value* module_self = nullptr) : self(std::move(the_self)), rcb(std::move(rcb)), diff --git a/torch/csrc/jit/python/python_tree_views.cpp b/torch/csrc/jit/python/python_tree_views.cpp index 50d18b908107ee..0cd93887471e31 100644 --- a/torch/csrc/jit/python/python_tree_views.cpp +++ b/torch/csrc/jit/python/python_tree_views.cpp @@ -14,7 +14,7 @@ namespace torch::jit { std::optional maybeConvertToString(const py::object& obj) { if (obj.is_none()) { - return c10::nullopt; + return std::nullopt; } std::stringstream ss; ss << py::str(obj); @@ -180,7 +180,7 @@ void initTreeViewBindings(PyObject* module) { return std::optional(property.setter().get().name()); } - return std::optional(c10::nullopt); + return std::optional(std::nullopt); }); py::class_(m, "ClassDef") diff --git a/torch/csrc/jit/python/script_init.cpp b/torch/csrc/jit/python/script_init.cpp index c46762a88615bb..565f0b16363855 100644 --- a/torch/csrc/jit/python/script_init.cpp +++ b/torch/csrc/jit/python/script_init.cpp @@ -220,7 +220,7 @@ std::optional tryCalculateDefaultParam( return toIValue(def_value, arg.type()); } } catch (...) { - return c10::nullopt; + return std::nullopt; } } @@ -702,13 +702,13 @@ void pyCompilationUnitDefine( const ResolutionCallback* rcb, const uint32_t _frames_up) { if (rcb && *rcb) { - cu.define(c10::nullopt, src, pythonResolver(*rcb), nullptr); + cu.define(std::nullopt, src, pythonResolver(*rcb), nullptr); } else { py::object py_default_rcb = py::module::import("torch._jit_internal") .attr("createResolutionCallbackFromFrame")(_frames_up); auto default_rcb = py_default_rcb.cast(); - cu.define(c10::nullopt, src, pythonResolver(default_rcb), nullptr); + cu.define(std::nullopt, src, pythonResolver(default_rcb), nullptr); } } @@ -1315,7 +1315,7 @@ void initJitScriptBindings(PyObject* module) { "find_method", [](mobile::Module& m, const std::string& method_name) { auto method = m.find_method(method_name); - return method != c10::nullopt; + return method != std::nullopt; }, py::arg("method_name")) .def( @@ -1372,7 +1372,7 @@ void initJitScriptBindings(PyObject* module) { return std::optional( StrongFunctionPtr(std::move(self), fn)); } else { - return std::optional(c10::nullopt); + return std::optional(std::nullopt); } }) .def( @@ -2124,7 +2124,7 @@ void initJitScriptBindings(PyObject* module) { m.def( "_get_graph_executor_optimize", - [](std::optional new_setting = c10::nullopt) { + [](std::optional new_setting = std::nullopt) { bool old_value = getGraphExecutorOptimize(); if (new_setting) { setGraphExecutorOptimize(*new_setting); diff --git a/torch/csrc/jit/runtime/autodiff.cpp b/torch/csrc/jit/runtime/autodiff.cpp index 3987521f658f97..047a35e417fff8 100644 --- a/torch/csrc/jit/runtime/autodiff.cpp +++ b/torch/csrc/jit/runtime/autodiff.cpp @@ -134,11 +134,11 @@ static std::optional> build_script_grad( auto graph = node->owningGraph(); auto maybe_schema = node->maybeSchema(); if (!maybe_schema) { - return c10::nullopt; + return std::nullopt; } auto compiled_graphs = gradientInfoForSchema(*maybe_schema); if (!compiled_graphs) { - return c10::nullopt; + return std::nullopt; } // Use forward graph to replace node in grad_desc.f value_list new_outputs; diff --git a/torch/csrc/jit/runtime/decomposition_registry.cpp b/torch/csrc/jit/runtime/decomposition_registry.cpp index de205ed834c3bc..989a48bf06ab22 100644 --- a/torch/csrc/jit/runtime/decomposition_registry.cpp +++ b/torch/csrc/jit/runtime/decomposition_registry.cpp @@ -63,7 +63,7 @@ void loadDecompositionFunctions() { [&](const std::string& name) -> std::shared_ptr { return src; }, 1); compilation_unit->define( - c10::nullopt, GetSerializedDecompositions(), resolver, nullptr); + std::nullopt, GetSerializedDecompositions(), resolver, nullptr); loadModule(*compilation_unit); } @@ -117,7 +117,7 @@ std::optional> GetDecomposition( } GRAPH_DEBUG("Could not find schema: ", schema); - return c10::nullopt; + return std::nullopt; } std::optional GetDecompositionFunction( @@ -127,7 +127,7 @@ std::optional GetDecompositionFunction( GRAPH_DEBUG("Trying to find schema: ", schema); if (cache_it == schema_to_function.end()) { GRAPH_DEBUG("Could not find schema: ", schema); - return c10::nullopt; + return std::nullopt; } auto& func = toGraphFunction(*cache_it->second); // Simple Executor: diff --git a/torch/csrc/jit/runtime/graph_executor.h b/torch/csrc/jit/runtime/graph_executor.h index fce8d4a02e66c1..971e45e818ca6d 100644 --- a/torch/csrc/jit/runtime/graph_executor.h +++ b/torch/csrc/jit/runtime/graph_executor.h @@ -87,7 +87,7 @@ struct TORCH_API GraphExecutor { // current global fusion strategy settings. const ExecutionPlan& getPlanFor( Stack& inputs, - std::optional remaining_bailout_depth = c10::nullopt); + std::optional remaining_bailout_depth = std::nullopt); GraphExecutorState getDebugState(); void debugFlushCompilationCache(); diff --git a/torch/csrc/jit/runtime/graph_executor_impl.h b/torch/csrc/jit/runtime/graph_executor_impl.h index 22a563f00be289..70069ac1907b0f 100644 --- a/torch/csrc/jit/runtime/graph_executor_impl.h +++ b/torch/csrc/jit/runtime/graph_executor_impl.h @@ -78,7 +78,7 @@ struct GraphExecutorImplBase { virtual const ExecutionPlan& getPlanFor( Stack& stack, - std::optional remaining_bailout_depth = c10::nullopt) = 0; + std::optional remaining_bailout_depth = std::nullopt) = 0; virtual GraphExecutorState getDebugState() = 0; virtual ~GraphExecutorImplBase() = default; diff --git a/torch/csrc/jit/runtime/interpreter.cpp b/torch/csrc/jit/runtime/interpreter.cpp index 18231173dd70e0..0f6eb900e361df 100644 --- a/torch/csrc/jit/runtime/interpreter.cpp +++ b/torch/csrc/jit/runtime/interpreter.cpp @@ -169,7 +169,7 @@ struct InterpreterStateImpl : c10::intrusive_ptr_target { } void enterFrame(const Code& code, size_t base_pointer) { - frames.emplace_back(Frame{code.pImpl, 0, base_pointer, c10::nullopt}); + frames.emplace_back(Frame{code.pImpl, 0, base_pointer, std::nullopt}); registers.resize(registers.size() + code.pImpl->register_size_); } @@ -181,7 +181,7 @@ struct InterpreterStateImpl : c10::intrusive_ptr_target { void callFunction( Function& f, Stack& stack, - std::optional bailOut = c10::nullopt, + std::optional bailOut = std::nullopt, bool next = true) { bool newFrame = f.call(stack, bailOut, [&](const Code& code) { enterFrame(code, stack.size() - code.num_inputs()); @@ -1244,7 +1244,7 @@ void InterpreterContinuation::operator()() { auto prev_dist_id = DistAutogradContainer::currentContextId(); DistAutogradContainer::forceCurrentContextId(dist_autograd_context_id_); #endif - if (tls_state_ != c10::nullopt) { + if (tls_state_ != std::nullopt) { at::ThreadLocalStateGuard g(*tls_state_); state.runAsync(stack); } else { diff --git a/torch/csrc/jit/runtime/interpreter.h b/torch/csrc/jit/runtime/interpreter.h index a28b1eb93526b5..ffafd3ab096a9b 100644 --- a/torch/csrc/jit/runtime/interpreter.h +++ b/torch/csrc/jit/runtime/interpreter.h @@ -1,6 +1,6 @@ #pragma once -#include #include +#include #include #include @@ -124,7 +124,7 @@ struct InterpreterContinuation { InterpreterState state_, Stack stack_, int64_t dist_autograd_context_id = 0, - std::optional tls_state = c10::nullopt) + std::optional tls_state = std::nullopt) : state(std::move(state_)), stack(std::move(stack_)), tls_state_(std::move(tls_state)) @@ -140,7 +140,7 @@ struct InterpreterContinuation { private: InterpreterState state; Stack stack; - std::optional tls_state_ = c10::nullopt; + std::optional tls_state_ = std::nullopt; #ifdef USE_DISTRIBUTED int64_t dist_autograd_context_id_; #endif diff --git a/torch/csrc/jit/runtime/jit_exception.h b/torch/csrc/jit/runtime/jit_exception.h index 34c3ebd6fca849..cb4f572a8bd3c0 100644 --- a/torch/csrc/jit/runtime/jit_exception.h +++ b/torch/csrc/jit/runtime/jit_exception.h @@ -2,8 +2,8 @@ #include -#include #include +#include #include namespace torch::jit { @@ -11,8 +11,8 @@ namespace torch::jit { struct TORCH_API JITException : public std::runtime_error { explicit JITException( const std::string& msg, - std::optional python_class_name = c10::nullopt, - std::optional original_msg = c10::nullopt); + std::optional python_class_name = std::nullopt, + std::optional original_msg = std::nullopt); std::optional getPythonClassName() const { return python_class_name_; diff --git a/torch/csrc/jit/runtime/operator.h b/torch/csrc/jit/runtime/operator.h index dbc2638457c056..2e609f18ecc074 100644 --- a/torch/csrc/jit/runtime/operator.h +++ b/torch/csrc/jit/runtime/operator.h @@ -322,7 +322,7 @@ std::optional OperatorGenerator( torch::detail::SelectiveStr schema_str, Func&& op, AliasAnalysisKind alias_analysis) { - return c10::nullopt; + return std::nullopt; } template diff --git a/torch/csrc/jit/runtime/profiling_graph_executor_impl.cpp b/torch/csrc/jit/runtime/profiling_graph_executor_impl.cpp index 48c7a1959ab220..54ec8e8441fa7e 100644 --- a/torch/csrc/jit/runtime/profiling_graph_executor_impl.cpp +++ b/torch/csrc/jit/runtime/profiling_graph_executor_impl.cpp @@ -1,6 +1,5 @@ #include -#include #include #include #include @@ -37,6 +36,7 @@ #include #include #include +#include C10_DEFINE_bool( torch_jit_enable_new_executor, @@ -118,11 +118,11 @@ static FusionStrategy getInitialStrategy() { } // defer initial value so that we can load in gflags -static std::optional fusion_strategy = c10::nullopt; +static std::optional fusion_strategy = std::nullopt; FusionStrategy getFusionStrategy() { std::lock_guard guard(fusion_strategy_lock); - if (fusion_strategy == c10::nullopt) { + if (fusion_strategy == std::nullopt) { fusion_strategy = getInitialStrategy(); } return *fusion_strategy; @@ -130,7 +130,7 @@ FusionStrategy getFusionStrategy() { FusionStrategy setFusionStrategy(FusionStrategy& strategy) { std::lock_guard guard(fusion_strategy_lock); - if (fusion_strategy == c10::nullopt) { + if (fusion_strategy == std::nullopt) { fusion_strategy = getInitialStrategy(); } FusionStrategy old_strategy = *fusion_strategy; @@ -320,7 +320,7 @@ static bool guardDifferentiableGraph(Node* dnode) { // we inline the differentiable graph as a fallback // ideally we would set this up for re-profiling UpdateDifferentiableGraphRequiresGrad( - dnode->g(attr::Subgraph), c10::nullopt); + dnode->g(attr::Subgraph), std::nullopt); SubgraphUtils::unmergeSubgraph(dnode); return false; } diff --git a/torch/csrc/jit/runtime/register_ops_utils.h b/torch/csrc/jit/runtime/register_ops_utils.h index 3386bc3e4a4918..ebdc5ba205cd56 100644 --- a/torch/csrc/jit/runtime/register_ops_utils.h +++ b/torch/csrc/jit/runtime/register_ops_utils.h @@ -878,6 +878,6 @@ struct OperatorGeneratorArgs { TORCH_API at::Generator make_generator_for_device( c10::Device device, - std::optional seed = c10::nullopt); + std::optional seed = std::nullopt); } // namespace torch::jit diff --git a/torch/csrc/jit/runtime/register_prim_ops.cpp b/torch/csrc/jit/runtime/register_prim_ops.cpp index bb9c08465c0ae9..f6eccede28bab1 100644 --- a/torch/csrc/jit/runtime/register_prim_ops.cpp +++ b/torch/csrc/jit/runtime/register_prim_ops.cpp @@ -1,6 +1,5 @@ #include #include -#include #include #include #include @@ -8,6 +7,7 @@ #include #include #include +#include #include #include @@ -1807,7 +1807,7 @@ static const std::vector stringOpGenArgs{ std::string::size_type prev_pos = 0; std::string::size_type pos = 0; c10::List splits; - if (ivalue == c10::nullopt) { + if (ivalue == std::nullopt) { // if separator is not specified, // a different splitting algorithm is applied as Python splits = splitNoneSeparator(string); @@ -2463,8 +2463,8 @@ static const std::vector opGenArgs1{ // NOLINTNEXTLINE(cppcoreguidelines-init-variables) bool copy; pop(stack, self, non_blocking, copy); - std::optional device = c10::nullopt; - std::optional scalarType = c10::nullopt; + std::optional device = std::nullopt; + std::optional scalarType = std::nullopt; push( stack, to_dispatch(self, device, scalarType, non_blocking, copy)); }, diff --git a/torch/csrc/jit/runtime/register_prim_ops_fulljit.cpp b/torch/csrc/jit/runtime/register_prim_ops_fulljit.cpp index 4359b852b6a38a..035a5d35c4630f 100644 --- a/torch/csrc/jit/runtime/register_prim_ops_fulljit.cpp +++ b/torch/csrc/jit/runtime/register_prim_ops_fulljit.cpp @@ -430,13 +430,13 @@ at::Tensor interpolate( std::optional align_corners, std::optional recompute_scale_factor) { if ((mode == "nearest" || mode == "area")) { - if (align_corners != c10::nullopt) { + if (align_corners != std::nullopt) { throw std::runtime_error( "align_corners option can only be set with the " "interpolating modes: linear | bilinear | bicubic | trilinear"); } } else { - if (align_corners == c10::nullopt) { + if (align_corners == std::nullopt) { TORCH_WARN( "Default upsampling behavior when mode=", mode, @@ -451,7 +451,7 @@ at::Tensor interpolate( double scale_factors_2 = -1.0; double scale_factors_3 = -1.0; - if (!scale_factors.isNone() && recompute_scale_factor == c10::nullopt) { + if (!scale_factors.isNone() && recompute_scale_factor == std::nullopt) { recompute_scale_factor = true; bool warn_recompute_scale_factor = false; @@ -510,7 +510,7 @@ at::Tensor interpolate( return at::upsample_nearest1d( input, _output_size(input, 1, size, scale_factors), - c10::make_optional(scale_factors_1)); + std::make_optional(scale_factors_1)); if (input_dim == dim2d && mode == "nearest") return at::upsample_nearest2d( input, @@ -538,7 +538,7 @@ at::Tensor interpolate( input, _output_size(input, 1, size, scale_factors), *align_corners, - c10::make_optional(scale_factors_1)); + std::make_optional(scale_factors_1)); if (input_dim == dim1d && mode == "bilinear") throw std::runtime_error("Got 3D input, but bilinear mode needs 4D input"); if (input_dim == dim1d && mode == "bicubic") @@ -646,7 +646,7 @@ void upsample_nearest_op(Stack& stack) { pop(stack, input, size, scale_factor_int); IValue scale_factor_double = convert_scale_factor_to_double(scale_factor_int); at::Tensor res = interpolate( - input, size, scale_factor_double, "nearest", c10::nullopt, c10::nullopt); + input, size, scale_factor_double, "nearest", std::nullopt, std::nullopt); push(stack, std::move(res)); } @@ -664,7 +664,7 @@ void upsample_op(Stack& stack) { scale_factor_double, mode, align_corners.toOptional(), - c10::nullopt); + std::nullopt); push(stack, std::move(res)); } @@ -675,7 +675,7 @@ void upsample_bilinear_op(Stack& stack) { pop(stack, input, size, scale_factor_int); IValue scale_factor_double = convert_scale_factor_to_double(scale_factor_int); at::Tensor res = interpolate( - input, size, scale_factor_double, "bilinear", true, c10::nullopt); + input, size, scale_factor_double, "bilinear", true, std::nullopt); push(stack, std::move(res)); } diff --git a/torch/csrc/jit/runtime/register_special_ops.cpp b/torch/csrc/jit/runtime/register_special_ops.cpp index 5b8c70c404ae98..63fdee6de8042c 100644 --- a/torch/csrc/jit/runtime/register_special_ops.cpp +++ b/torch/csrc/jit/runtime/register_special_ops.cpp @@ -301,9 +301,9 @@ RegisterOperators reg({ at::native::scalar_tensor( scalar_val, typeMetaToScalarType(c10::get_default_dtype()), - c10::nullopt /* layout */, + std::nullopt /* layout */, at::kCPU, - c10::nullopt /* pin_memory*/)) + std::nullopt /* pin_memory*/)) DEFINE_TORCH_TENSOR_OP( int, int64_t, @@ -314,9 +314,9 @@ RegisterOperators reg({ at::native::scalar_tensor( scalar_val, typeMetaToScalarType(c10::get_default_complex_dtype()), - c10::nullopt /* layout */, + std::nullopt /* layout */, at::kCPU, - c10::nullopt /* pin_memory */)) + std::nullopt /* pin_memory */)) // reference python implementation: internal_new_from_data in // tensor_new.cpp diff --git a/torch/csrc/jit/runtime/simple_graph_executor_impl.cpp b/torch/csrc/jit/runtime/simple_graph_executor_impl.cpp index c1dbbddc6d337a..fd908b48ee043f 100644 --- a/torch/csrc/jit/runtime/simple_graph_executor_impl.cpp +++ b/torch/csrc/jit/runtime/simple_graph_executor_impl.cpp @@ -1,8 +1,8 @@ #include -#include #include #include +#include namespace torch::jit { diff --git a/torch/csrc/jit/runtime/static/fusion.cpp b/torch/csrc/jit/runtime/static/fusion.cpp index ffac37efc9b765..86925200b7f463 100644 --- a/torch/csrc/jit/runtime/static/fusion.cpp +++ b/torch/csrc/jit/runtime/static/fusion.cpp @@ -173,7 +173,7 @@ static std::optional tryMerge( Node* to_merge, AliasDb* aliasDb) { if (!canMerge(fusion_group, to_merge, aliasDb)) { - return c10::nullopt; + return std::nullopt; } std::vector nodes_to_merge = {to_merge}; @@ -190,7 +190,7 @@ static std::optional tryMerge( GRAPH_UPDATE("Trying to move node next to fusion group: ", getHeader(n)); if (!aliasDb->moveBeforeTopologicallyValid(n, move_point)) { GRAPH_UPDATE("Failed to move because of AliasDb checks!"); - return c10::nullopt; + return std::nullopt; } move_point = n; } diff --git a/torch/csrc/jit/runtime/static/impl.cpp b/torch/csrc/jit/runtime/static/impl.cpp index 9dc31446d1e1c7..0c989efcad7577 100644 --- a/torch/csrc/jit/runtime/static/impl.cpp +++ b/torch/csrc/jit/runtime/static/impl.cpp @@ -320,7 +320,7 @@ std::pair, std::optional> PrepareForStaticModule( const StaticModuleOptions& opts, std::vector sample_inputs) { PrepareGraphForStaticModule(graph, opts, std::move(sample_inputs)); - return std::make_pair(graph, c10::nullopt); + return std::make_pair(graph, std::nullopt); } } // namespace @@ -573,7 +573,7 @@ StaticModule::StaticModule( const auto num_schema_args = schema_->arguments().size(); DCHECK(num_schema_args > 0); if (removeSelfFromGraphInput(graph_)) { - module_ = c10::nullopt; + module_ = std::nullopt; num_inputs_ = num_schema_args - 1; } } @@ -1251,7 +1251,7 @@ bool BlockRunner::fast_check_and_correct_overlap_with( auto& tensor = tensor_ival.toTensor(); if (planner_->overlapWithInternalBuffer(tensor.data_ptr())) { DLOG(INFO) << "Detected alias for node: " << PrintNode(n.node()); - tensor_ival = at::native::clone(tensor, c10::nullopt); + tensor_ival = at::native::clone(tensor, std::nullopt); n.set_outputs_memory_overlap_detected(); return true; } @@ -2218,7 +2218,7 @@ bool ProcessedNode::check_and_correct_overlap_with( auto& tensor = output_ival.toTensor(); if (!checkNoMemoryOverlap(input, tensor)) { DLOG(INFO) << "Detected alias for node: " << PrintNode(node()); - output_ival = at::native::clone(tensor, c10::nullopt); + output_ival = at::native::clone(tensor, std::nullopt); set_outputs_memory_overlap_detected(); return true; } diff --git a/torch/csrc/jit/runtime/static/ops.cpp b/torch/csrc/jit/runtime/static/ops.cpp index b1b8a081c4ce63..35a74c0bac089b 100644 --- a/torch/csrc/jit/runtime/static/ops.cpp +++ b/torch/csrc/jit/runtime/static/ops.cpp @@ -75,7 +75,7 @@ static void repeat_out( } // return an empty tensor if one of the repeat dimensions is zero - at::native::resize_(result, target_size, c10::nullopt); + at::native::resize_(result, target_size, std::nullopt); if (zero_tensor) { return; } @@ -101,7 +101,7 @@ at::Tensor& reshape_copy_out( const auto& shape = infer_size ? at::infer_size_dv(proposed_shape, self.numel()) : proposed_shape; - at::native::resize_(out, shape, c10::nullopt); + at::native::resize_(out, shape, std::nullopt); auto self_contig = self.expect_contiguous(); @@ -214,7 +214,7 @@ at::Tensor& to_copy_out( at::native::resize_impl_cpu_( out.unsafeGetTensorImpl(), self.sizes(), self.strides()); } else { - at::native::resize_(out, self.sizes(), c10::nullopt); + at::native::resize_(out, self.sizes(), std::nullopt); } auto is_unsupported_dtype = [](ScalarType t) { #define TORCH_OPS_UNSUPPORTED_TYPE(_, type) \ @@ -233,7 +233,7 @@ at::Tensor& to_copy_out( // expensive. if (self.is_contiguous() && !non_blocking && // Did the user request us to make a copy that isn't contiguous? - (memory_format == c10::nullopt || + (memory_format == std::nullopt || memory_format == c10::MemoryFormat::Preserve || memory_format == c10::MemoryFormat::Contiguous) && // CopyKernel.cpp handles this case specially, so let's not mess @@ -303,7 +303,7 @@ static Tensor& c2_argmin_out( out_dims.push_back(in_dims[i]); next_size *= in_dims[i]; } - at::native::resize_(output, out_dims, c10::nullopt); + at::native::resize_(output, out_dims, std::nullopt); const auto n = in_dims[dim_]; @@ -370,7 +370,7 @@ static at::Tensor& dequantize_copy_out(Tensor& out, const Tensor& self) { if (C10_UNLIKELY(!self.is_quantized())) { // fallback to dequantize_cpu equivalent case: make sure out is at::kFloat DCHECK(out.scalar_type() == kFloat); - return at::native::to_copy_out(out, self, false, false, c10::nullopt); + return at::native::to_copy_out(out, self, false, false, std::nullopt); } return get_qtensorimpl(self)->quantizer()->dequantize_out(out, self); } @@ -658,11 +658,11 @@ REGISTER_OPERATOR_FUNCTOR( out_t, at::cpu::clamp(in0_t, clamp_min, clamp_max), in3_s, - c10::nullopt, - c10::nullopt); + std::nullopt, + std::nullopt); return; } - at::native::resize_(out_t, in0_t.sizes(), c10::nullopt); + at::native::resize_(out_t, in0_t.sizes(), std::nullopt); auto output_size = in0_t.numel(); @@ -700,7 +700,7 @@ REGISTER_OPERATOR_FUNCTOR(aten::clamp, aten_clamp, [](Node* n) -> SROperator { at::cpu::clamp_out(out_t, in0_t, in1_s, in2_s); return; } - at::native::resize_(out_t, in0_t.sizes(), c10::nullopt); + at::native::resize_(out_t, in0_t.sizes(), std::nullopt); auto output_size = in0_t.numel(); auto min = in1_s.has_value() ? in1_s->toFloat() : -std::numeric_limits::infinity(); @@ -830,7 +830,7 @@ void varStackFastOut( ? std::array{num_inputs, 1} : std::array{1, num_inputs}; - at::native::resize_(out, output_size, c10::nullopt); + at::native::resize_(out, output_size, std::nullopt); AT_DISPATCH_ALL_TYPES(out.scalar_type(), "varStackFastOut", [&]() { auto* out_data = out.mutable_data_ptr(); @@ -952,7 +952,7 @@ REGISTER_OPERATOR_FUNCTOR(aten::relu, aten_relu, [](Node* n) -> SROperator { at::cpu::threshold_out(out_t, in0_t, 0, 0); return; } - at::native::resize_(out_t, in0_t.sizes(), c10::nullopt); + at::native::resize_(out_t, in0_t.sizes(), std::nullopt); int64_t nn = in0_t.numel(); te->call({out_t.data_ptr(), in0_t.data_ptr(), &nn}); }; @@ -975,7 +975,7 @@ REGISTER_OPERATOR_FUNCTOR(aten::tanh, aten_tanh, [](Node* n) -> SROperator { at::cpu::tanh_out(out_t, in0_t); return; } - at::native::resize_(out_t, in0_t.sizes(), c10::nullopt); + at::native::resize_(out_t, in0_t.sizes(), std::nullopt); int64_t nn = in0_t.numel(); te->call({out_t.data_ptr(), in0_t.data_ptr(), &nn}); }; @@ -1036,7 +1036,7 @@ REGISTER_OPERATOR_FUNCTOR( at::cpu::sigmoid_out(out_t, in0_t); return; } - at::native::resize_(out_t, in0_t.sizes(), c10::nullopt); + at::native::resize_(out_t, in0_t.sizes(), std::nullopt); int64_t nn = in0_t.numel(); te->call({out_t.data_ptr(), in0_t.data_ptr(), &nn}); }; @@ -1048,12 +1048,12 @@ REGISTER_OPERATOR_FUNCTOR(aten::logit, aten_logit, [](Node* n) -> SROperator { LogAndDumpSchema(n); return nullptr; } - std::optional clamp = c10::nullopt; + std::optional clamp = std::nullopt; if (n->inputs()[1]->node()->kind() == prim::Constant) { auto clamp_d = toIValue(n->inputs()[1])->toOptional(); clamp = clamp_d - ? c10::make_optional(static_cast(clamp_d.value())) - : c10::nullopt; + ? std::make_optional(static_cast(clamp_d.value())) + : std::nullopt; } auto te = clamp ? createLogit() : nullptr; float clamp_value = clamp ? *clamp : 0.0f; @@ -1070,7 +1070,7 @@ REGISTER_OPERATOR_FUNCTOR(aten::logit, aten_logit, [](Node* n) -> SROperator { at::native::logit_out(in0_t, in1_d, out_t); return; } - at::native::resize_(out_t, in0_t.sizes(), c10::nullopt); + at::native::resize_(out_t, in0_t.sizes(), std::nullopt); int64_t nn = in0_t.numel(); float c = clamp_value; te->call({out_t.data_ptr(), in0_t.data_ptr(), &nn, &c}); @@ -1454,7 +1454,7 @@ C10_ALWAYS_INLINE void to_copy_functor_impl( if (memory_format == c10::MemoryFormat::Preserve) { if (self.is_non_overlapping_and_dense()) { - memory_format = c10::nullopt; + memory_format = std::nullopt; copy_strides = true; } else { memory_format = self.suggest_memory_format(); @@ -1485,7 +1485,7 @@ C10_ALWAYS_INLINE void to_copy_functor_impl( args->dtype, args->layout, self.device(), - c10::nullopt, + std::nullopt, memory_format); } else { if (has_memory_format) { @@ -1905,7 +1905,7 @@ REGISTER_OPERATOR_FUNCTOR(aten::div, aten_div, [](Node* n) -> SROperator { return [te = createDiv()](ProcessedNode* p_node) { const auto& in0_t = p_node->Input(0).toTensor(); - std::optional rounding_mode = c10::nullopt; + std::optional rounding_mode = std::nullopt; if (p_node->num_inputs() > 2) { rounding_mode = p_node->Input(2).toOptional(); } @@ -2112,14 +2112,14 @@ REGISTER_OPERATOR_FUNCTOR(aten::layer_norm, aten_layer_norm, [](Node* n) -> SROp if (p_node->Output(0).isNone()) { p_node->Output(0) = at::native::empty_like( *X, - c10::nullopt /* dtype */, - c10::nullopt /* layout */, - c10::nullopt /* device */, - c10::nullopt /* pin_memory */, + std::nullopt /* dtype */, + std::nullopt /* layout */, + std::nullopt /* device */, + std::nullopt /* pin_memory */, at::MemoryFormat::Contiguous); } else { at::native::resize_( - p_node->Output(0).toTensor(), X->sizes(), c10::nullopt); + p_node->Output(0).toTensor(), X->sizes(), std::nullopt); } at::Tensor& output = p_node->Output(0).toTensor(); at::native::layer_norm_cpu_out(output, *X, *gamma, *beta, eps, M, N); @@ -2231,12 +2231,12 @@ REGISTER_OPERATOR_FUNCTOR(quantized::linear, quantized_linear, [](Node* n) -> SR p_node->Output(0) = at::native::empty_affine_quantized( {0}, c10::kQUInt8, - c10::nullopt, + std::nullopt, c10::kCPU, false, output_scale, output_zero_point, - c10::nullopt); + std::nullopt); } auto& out_t = p_node->Output(0).toTensor(); fastResizeToZero(out_t); @@ -2277,12 +2277,12 @@ REGISTER_OPERATOR_FUNCTOR( p_node->Output(0) = at::native::empty_affine_quantized( {0}, c10::kQUInt8, - c10::nullopt, + std::nullopt, c10::kCPU, false, output_scale, output_zero_point, - c10::nullopt); + std::nullopt); } auto& out_t = p_node->Output(0).toTensor(); fastResizeToZero(out_t); @@ -2463,7 +2463,7 @@ REGISTER_OPERATOR_FUNCTOR(aten::full_like, aten_full_like, [](Node* n) -> SROper in0_t, dtype, layout, device, pin_memory, memory_format); } auto& out_t = p_node->Output(0).toTensor(); - at::native::resize_(out_t, in0_t.sizes(), c10::nullopt); + at::native::resize_(out_t, in0_t.sizes(), std::nullopt); at::native::fill_out(out_t, in1_s); }; }); @@ -2528,7 +2528,7 @@ REGISTER_OPERATOR_FUNCTOR(aten::zeros, aten_zeros, [](Node* n) -> SROperator { const auto layout = p_node->Input(2).toOptional(); if (!hasTensorWithOptions(p_node->Output(0), dtype, layout)) { p_node->Output(0) = at::compositeexplicitautograd::zeros( - size, dtype, layout, c10::nullopt, c10::nullopt); + size, dtype, layout, std::nullopt, std::nullopt); return; } auto& out_t = p_node->Output(0).toTensor(); @@ -2709,7 +2709,7 @@ unsigned char abs_if_signed(unsigned char val) { // Computes f(x) = sign(x) * ln(|1 + x|) for each x in the input tensor void signed_log1p_out(at::Tensor& out, const at::Tensor& input) { - at::native::resize_(out, input.sizes(), c10::nullopt); + at::native::resize_(out, input.sizes(), std::nullopt); const auto input_contig = input.expect_contiguous(); auto output_contig = out.expect_contiguous(); @@ -2750,7 +2750,7 @@ REGISTER_OPERATOR_FUNCTOR( signed_log1p_out(out, input); return; } - at::native::resize_(out, input.sizes(), c10::nullopt); + at::native::resize_(out, input.sizes(), std::nullopt); int64_t nn = input.numel(); te->call({out.data_ptr(), input.data_ptr(), &nn}); }; diff --git a/torch/csrc/jit/runtime/static/ops.h b/torch/csrc/jit/runtime/static/ops.h index 362837e7ce78f0..623340daec068b 100644 --- a/torch/csrc/jit/runtime/static/ops.h +++ b/torch/csrc/jit/runtime/static/ops.h @@ -57,8 +57,8 @@ inline at::Tensor create_empty_from(const at::Tensor& t) { c10::typeMetaToScalarType(t.dtype()), t.layout(), t.device(), - c10::nullopt, - c10::nullopt); + std::nullopt, + std::nullopt); } inline at::Tensor create_empty_from( @@ -69,20 +69,20 @@ inline at::Tensor create_empty_from( c10::typeMetaToScalarType(t.dtype()), t.layout(), t.device(), - c10::nullopt, - c10::nullopt); + std::nullopt, + std::nullopt); } inline at::Tensor create_empty(c10::ScalarType dtype) { return at::detail::empty_cpu( - {0}, dtype, c10::nullopt, c10::nullopt, c10::nullopt, c10::nullopt); + {0}, dtype, std::nullopt, std::nullopt, std::nullopt, std::nullopt); } inline at::Tensor create_empty_from( const at::Tensor& t, c10::ScalarType dtype) { return at::detail::empty_cpu( - {0}, dtype, t.layout(), t.device(), c10::nullopt, c10::nullopt); + {0}, dtype, t.layout(), t.device(), std::nullopt, std::nullopt); } inline at::Tensor create_empty_from(const at::Tensor& t, c10::Layout layout) { @@ -91,8 +91,8 @@ inline at::Tensor create_empty_from(const at::Tensor& t, c10::Layout layout) { c10::typeMetaToScalarType(t.dtype()), layout, t.device(), - c10::nullopt, - c10::nullopt); + std::nullopt, + std::nullopt); } inline at::Tensor create_empty_from(const at::Tensor& t, c10::Device device) { @@ -101,8 +101,8 @@ inline at::Tensor create_empty_from(const at::Tensor& t, c10::Device device) { c10::typeMetaToScalarType(t.dtype()), t.layout(), device, - c10::nullopt, - c10::nullopt); + std::nullopt, + std::nullopt); } inline at::Tensor create_empty_from( @@ -113,7 +113,7 @@ inline at::Tensor create_empty_from( c10::typeMetaToScalarType(t.dtype()), t.layout(), t.device(), - c10::nullopt, + std::nullopt, memory_format); } @@ -122,7 +122,7 @@ inline at::Tensor create_empty_from( c10::ScalarType dtype, c10::MemoryFormat memory_format) { return at::detail::empty_cpu( - {0}, dtype, t.layout(), t.device(), c10::nullopt, memory_format); + {0}, dtype, t.layout(), t.device(), std::nullopt, memory_format); } inline bool checkResizedDataPtr(at::Tensor& t) { diff --git a/torch/csrc/jit/runtime/symbolic_script.cpp b/torch/csrc/jit/runtime/symbolic_script.cpp index 6aa65c528a42b8..92d901e43a5d21 100644 --- a/torch/csrc/jit/runtime/symbolic_script.cpp +++ b/torch/csrc/jit/runtime/symbolic_script.cpp @@ -1609,7 +1609,7 @@ static void loadModule(const CompilationUnit& module) { static void loadFunctions() { for (const std::string& str : functions) { - compilation_unit.define(c10::nullopt, str, nativeResolver(), nullptr); + compilation_unit.define(std::nullopt, str, nativeResolver(), nullptr); } loadModule(compilation_unit); } @@ -1635,7 +1635,7 @@ std::optional gradientInfoForSchema( return sym_script_it->second; } } - return c10::nullopt; + return std::nullopt; } bool hasGradientInfoForSchema(const FunctionSchema& schema) { diff --git a/torch/csrc/jit/runtime/symbolic_script.h b/torch/csrc/jit/runtime/symbolic_script.h index 271bf66916f3d6..0715f0deeb1208 100644 --- a/torch/csrc/jit/runtime/symbolic_script.h +++ b/torch/csrc/jit/runtime/symbolic_script.h @@ -2,9 +2,9 @@ // This file is temporary until native_functions.yaml and derivatives.yaml are // merged. Ideally this should all go into native_functions.yaml -#include #include #include +#include namespace torch::jit { struct GradientPair { diff --git a/torch/csrc/jit/runtime/symbolic_shape_registry.cpp b/torch/csrc/jit/runtime/symbolic_shape_registry.cpp index ddea031aba73c8..f8cfca26c702a6 100644 --- a/torch/csrc/jit/runtime/symbolic_shape_registry.cpp +++ b/torch/csrc/jit/runtime/symbolic_shape_registry.cpp @@ -365,7 +365,7 @@ void loadFunctions() { [&](const std::string& name) -> std::shared_ptr { return src; }, 1); compilation_unit->define( - c10::nullopt, shape_compute_functions, resolver, nullptr); + std::nullopt, shape_compute_functions, resolver, nullptr); loadModule(*compilation_unit); } catch (...) { // Reset the cache and compilation unit so that we don't get weird errors @@ -391,7 +391,7 @@ std::optional> shapeComputeGraphForSchema( } GRAPH_DEBUG("Could not find schema: ", schema); - return c10::nullopt; + return std::nullopt; } TORCH_API std::optional boundedGraphsForSchema( @@ -406,7 +406,7 @@ TORCH_API std::optional boundedGraphsForSchema( return cache_it->second; } - return c10::nullopt; + return std::nullopt; } void RegisterShapeComputeGraphForSchema( diff --git a/torch/csrc/jit/serialization/callstack_debug_info_serialization.cpp b/torch/csrc/jit/serialization/callstack_debug_info_serialization.cpp index 4a326285b29740..2bc464a0de172c 100644 --- a/torch/csrc/jit/serialization/callstack_debug_info_serialization.cpp +++ b/torch/csrc/jit/serialization/callstack_debug_info_serialization.cpp @@ -173,7 +173,7 @@ std::optional InlinedCallStackDeserializer:: const c10::IValue& iv, const std::shared_ptr& cu) { if (iv.isNone()) { - return c10::nullopt; + return std::nullopt; } auto tup = iv.toTuple(); auto it = cached_module_instance_info_.find(tup); diff --git a/torch/csrc/jit/serialization/export.cpp b/torch/csrc/jit/serialization/export.cpp index 6ef9bdbf4abfac..2cfe34cd4abd2a 100644 --- a/torch/csrc/jit/serialization/export.cpp +++ b/torch/csrc/jit/serialization/export.cpp @@ -5,7 +5,6 @@ #include #include #include -#include #include #include #include @@ -21,6 +20,7 @@ #include #include #include +#include C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wnewline-eof") #include diff --git a/torch/csrc/jit/serialization/export_bytecode.cpp b/torch/csrc/jit/serialization/export_bytecode.cpp index 9f194cd0ad31b7..4b895f9d657b44 100644 --- a/torch/csrc/jit/serialization/export_bytecode.cpp +++ b/torch/csrc/jit/serialization/export_bytecode.cpp @@ -166,7 +166,7 @@ mobile::Code compileGraphToMobileCode( // and is not allowed. For an operator with num_args = -1, it means the // number of arguments is not available for this operator, we don't do any // backward compatibility adaptation at runtime. - std::optional num_args = c10::nullopt; + std::optional num_args = std::nullopt; auto it = op_to_specified_args.find(unique_name); if (it != op_to_specified_args.end()) { num_args = it->second; diff --git a/torch/csrc/jit/serialization/export_module.cpp b/torch/csrc/jit/serialization/export_module.cpp index 5bd7714c4e8d20..779e63a8436092 100644 --- a/torch/csrc/jit/serialization/export_module.cpp +++ b/torch/csrc/jit/serialization/export_module.cpp @@ -259,7 +259,7 @@ std::pair getFunctionTuple( if (namedType && namedType->name()) { return type_name_uniquer_.getUniqueName(namedType).qualifiedName(); } - return c10::nullopt; + return std::nullopt; }; auto makeArgTuple = [&](const std::vector& args) { @@ -765,7 +765,7 @@ std::optional type_printer( if (namedType && namedType->name()) { return type_name_uniquer.getUniqueName(namedType).qualifiedName(); } - return c10::nullopt; + return std::nullopt; } } // namespace diff --git a/torch/csrc/jit/serialization/flatbuffer_serializer.cpp b/torch/csrc/jit/serialization/flatbuffer_serializer.cpp index 5a47fe900f3fdc..e1ad60afa5c387 100644 --- a/torch/csrc/jit/serialization/flatbuffer_serializer.cpp +++ b/torch/csrc/jit/serialization/flatbuffer_serializer.cpp @@ -69,7 +69,7 @@ auto print_type(const c10::Type& t) -> std::optional { if (auto dyn = t.castRaw()) { return dyn->fallback()->annotation_str(); } - return c10::nullopt; + return std::nullopt; } class FlatbufferSerializer { @@ -306,7 +306,7 @@ flatbuffers::Offset FlatbufferSerializer:: if (auto dyn = t.castRaw()) { return dyn->fallback()->annotation_str(); } - return c10::nullopt; + return std::nullopt; }; flatbuffers::Offset schema_offset = 0; diff --git a/torch/csrc/jit/serialization/import.h b/torch/csrc/jit/serialization/import.h index b090a1c80a3cd4..2da1e639ee80a7 100644 --- a/torch/csrc/jit/serialization/import.h +++ b/torch/csrc/jit/serialization/import.h @@ -21,19 +21,19 @@ class DeserializationStorageContext; TORCH_API Module import_ir_module( std::shared_ptr cu, const std::string& filename, - std::optional device = c10::nullopt, + std::optional device = std::nullopt, bool load_debug_files = true); TORCH_API Module import_ir_module( std::shared_ptr cu, std::istream& in, - std::optional device = c10::nullopt, + std::optional device = std::nullopt, bool load_debug_files = true); TORCH_API Module import_ir_module( std::shared_ptr cu, std::unique_ptr rai, - std::optional device = c10::nullopt, + std::optional device = std::nullopt, bool load_debug_files = true); TORCH_API Module import_ir_module( @@ -80,7 +80,7 @@ TORCH_API Module import_ir_module( /// `torch::jit::ExportModule` in C++. TORCH_API Module load( std::istream& in, - std::optional device = c10::nullopt, + std::optional device = std::nullopt, bool load_debug_files = true); TORCH_API Module load( @@ -96,7 +96,7 @@ TORCH_API Module load( /// Python or `torch::jit::ExportModule` in C++. TORCH_API Module load( const std::string& filename, - std::optional device = c10::nullopt, + std::optional device = std::nullopt, bool load_debug_files = true); TORCH_API Module load( @@ -112,7 +112,7 @@ TORCH_API Module load( /// Python or `torch::jit::ExportModule` in C++. TORCH_API Module load( std::shared_ptr rai, - std::optional device = c10::nullopt, + std::optional device = std::nullopt, bool load_debug_files = true); TORCH_API Module load( @@ -131,17 +131,17 @@ TORCH_API Module parse_and_initialize_jit_module( std::shared_ptr data, size_t size, ExtraFilesMap& extra_files, - std::optional device = c10::nullopt); + std::optional device = std::nullopt); TORCH_API Module load_jit_module_from_file( const std::string& filename, ExtraFilesMap& extra_files, - std::optional device = c10::nullopt); + std::optional device = std::nullopt); TORCH_API Module load_jit_module_from_stream( std::istream& in, ExtraFilesMap& extra_files, - std::optional device = c10::nullopt); + std::optional device = std::nullopt); TORCH_API Module parse_and_initialize_jit_module( std::shared_ptr data, diff --git a/torch/csrc/jit/serialization/import_source.cpp b/torch/csrc/jit/serialization/import_source.cpp index f67c2a22e9eb13..017ae5bd3da7cf 100644 --- a/torch/csrc/jit/serialization/import_source.cpp +++ b/torch/csrc/jit/serialization/import_source.cpp @@ -372,7 +372,7 @@ std::optional SourceImporterImpl:: if (replacements.count(demangled_classname)) { auto lhs = Var(assign.lhs()); if (!assign.type().present() || assign.type().get().kind() != TK_VAR) { - return c10::nullopt; + return std::nullopt; } auto type = Var(assign.type().get()); @@ -389,7 +389,7 @@ std::optional SourceImporterImpl:: assign.range(), assign.lhs_list(), assign.rhs(), maybe_typename); } } - return c10::nullopt; + return std::nullopt; } void SourceImporterImpl::importClass( diff --git a/torch/csrc/jit/serialization/import_source.h b/torch/csrc/jit/serialization/import_source.h index 9b364f379b4091..a86a1f91926df7 100644 --- a/torch/csrc/jit/serialization/import_source.h +++ b/torch/csrc/jit/serialization/import_source.h @@ -2,7 +2,6 @@ #include #include -#include #include #include #include @@ -13,6 +12,7 @@ #include #include #include +#include #include #include #include @@ -66,7 +66,7 @@ struct SourceImporterImpl : public Resolver, std::shared_ptr cu_; std::unordered_map> env_; SourceLoader source_loader_; - std::optional version_ = c10::nullopt; + std::optional version_ = std::nullopt; std::unordered_set loaded_sources_; // named types and functions loaded from a file but not yet defined because // their type has not been requested yet. diff --git a/torch/csrc/jit/serialization/pickle.cpp b/torch/csrc/jit/serialization/pickle.cpp index be36a4e2d8dd5e..c05bf330e7af3c 100644 --- a/torch/csrc/jit/serialization/pickle.cpp +++ b/torch/csrc/jit/serialization/pickle.cpp @@ -92,9 +92,9 @@ IValue pickle_load(const std::vector& data) { "data", /*pickle_prefix=*/"", /*tensor_prefix=*/"", - /*type_resolver=*/c10::nullopt, - /*obj_loader=*/c10::nullopt, - /*device=*/c10::nullopt, + /*type_resolver=*/std::nullopt, + /*obj_loader=*/std::nullopt, + /*device=*/std::nullopt, reader); #else AT_ERROR( diff --git a/torch/csrc/jit/serialization/pickler.cpp b/torch/csrc/jit/serialization/pickler.cpp index 173ab5c13e5da4..04d3fc9a435614 100644 --- a/torch/csrc/jit/serialization/pickler.cpp +++ b/torch/csrc/jit/serialization/pickler.cpp @@ -605,7 +605,7 @@ std::optional type_printer(const c10::Type& type) { if (auto dyn = type.castRaw()) { return dyn->fallback()->annotation_str(type_printer); } - return c10::nullopt; + return std::nullopt; } } // namespace diff --git a/torch/csrc/jit/serialization/python_print.cpp b/torch/csrc/jit/serialization/python_print.cpp index f1b0865032c392..2292f11fd555ea 100644 --- a/torch/csrc/jit/serialization/python_print.cpp +++ b/torch/csrc/jit/serialization/python_print.cpp @@ -1725,7 +1725,7 @@ static std::optional printType( if (namedType && namedType->name()) { return type_name_uniquer.getUniqueName(namedType).qualifiedName(); } - return c10::nullopt; + return std::nullopt; } void jitModuleToPythonCodeAndConstants( diff --git a/torch/csrc/jit/serialization/source_range_serialization.cpp b/torch/csrc/jit/serialization/source_range_serialization.cpp index 118becd20dc7c6..6892493312b002 100644 --- a/torch/csrc/jit/serialization/source_range_serialization.cpp +++ b/torch/csrc/jit/serialization/source_range_serialization.cpp @@ -68,7 +68,7 @@ std::shared_ptr SourceRangeDeserializer::deserialize_source( const auto& textIndex = tup_elems[0].toIntList(); int64_t fnameIndex = tup_elems[1].toInt(); int64_t starting_line_no_ = tup_elems[2].toInt(); - std::optional filename = c10::nullopt; + std::optional filename = std::nullopt; TORCH_CHECK( (uint64_t)fnameIndex < text_table_.size(), @@ -248,7 +248,7 @@ std::optional ConcreteSourceRangeUnpickler:: return (entry - 1)->range; } - return c10::nullopt; + return std::nullopt; } TORCH_API void setShouldUseFormatWithStringTable( diff --git a/torch/csrc/jit/tensorexpr/codegen.cpp b/torch/csrc/jit/tensorexpr/codegen.cpp index e1464d0efc3ec0..1ba4d54c4d29ca 100644 --- a/torch/csrc/jit/tensorexpr/codegen.cpp +++ b/torch/csrc/jit/tensorexpr/codegen.cpp @@ -99,7 +99,7 @@ static std::optional bufSize(BufPtr buf) { size_t size = elementSize(buf->dtype().scalar_type()) * buf->dtype().lanes(); for (auto& d : buf->dims()) { if (!d->isConstant()) { - return c10::nullopt; + return std::nullopt; } size = size * (*intValue(d)); } diff --git a/torch/csrc/jit/tensorexpr/eval.cpp b/torch/csrc/jit/tensorexpr/eval.cpp index 5666097f2dd45b..ceab479dc87946 100644 --- a/torch/csrc/jit/tensorexpr/eval.cpp +++ b/torch/csrc/jit/tensorexpr/eval.cpp @@ -1305,7 +1305,7 @@ std::optional evalInt(ExprPtr e) { return ExprEval(cast(ExprHandle(e))) .value(); } catch (std::runtime_error& err) { - return c10::nullopt; + return std::nullopt; } } diff --git a/torch/csrc/jit/tensorexpr/expr.h b/torch/csrc/jit/tensorexpr/expr.h index 8c8de89975750c..c410c902ea4e4e 100644 --- a/torch/csrc/jit/tensorexpr/expr.h +++ b/torch/csrc/jit/tensorexpr/expr.h @@ -6,11 +6,11 @@ #pragma once #include -#include #include #include #include #include +#include #include @@ -207,10 +207,10 @@ class TORCH_API Buf : public ExprNode { const std::string& name_hint, const std::vector& dims, Dtype dtype, - std::optional initializer = c10::nullopt, - std::optional> strides = c10::nullopt, - std::optional qscale = c10::nullopt, - std::optional qzero = c10::nullopt); + std::optional initializer = std::nullopt, + std::optional> strides = std::nullopt, + std::optional qscale = std::nullopt, + std::optional qzero = std::nullopt); // TODO: unique_name VarPtr base_handle() const { @@ -232,7 +232,7 @@ class TORCH_API Buf : public ExprNode { const std::vector& dims, Dtype dtype, ExprPtr initializer = nullptr, - std::optional> strides = c10::nullopt, + std::optional> strides = std::nullopt, ExprPtr qscale = nullptr, ExprPtr qzero = nullptr) : Buf(alloc(name_hint, kHandle), @@ -248,7 +248,7 @@ class TORCH_API Buf : public ExprNode { std::vector dims, Dtype dtype, ExprPtr initializer = nullptr, - std::optional> strides = c10::nullopt, + std::optional> strides = std::nullopt, ExprPtr qscale = nullptr, ExprPtr qzero = nullptr); diff --git a/torch/csrc/jit/tensorexpr/external_functions.cpp b/torch/csrc/jit/tensorexpr/external_functions.cpp index a3146ccfaff550..decfe0bceb3215 100644 --- a/torch/csrc/jit/tensorexpr/external_functions.cpp +++ b/torch/csrc/jit/tensorexpr/external_functions.cpp @@ -123,7 +123,7 @@ std::vector constructTensors( } } else { // handle quantized - std::vector> qdata(bufs_num, c10::nullopt); + std::vector> qdata(bufs_num, std::nullopt); for (const auto& qd : *qdataArg) { qdata[qd.first] = qd.second; } @@ -233,7 +233,7 @@ std::vector constructTensors2( } } else { // handle quantized - std::vector> qdata(bufs_in_num, c10::nullopt); + std::vector> qdata(bufs_in_num, std::nullopt); for (const auto& qd : *qdataArg) { qdata[qd.first - bufs_out_num] = qd.second; } @@ -993,10 +993,10 @@ void nnc_aten_upsample_nearest2d( x, (output_size_h != -1) ? std::optional({output_size_h, output_size_w}) - : c10::nullopt, + : std::nullopt, (scale_factor_h != -1.f) ? std::optional>( {scale_factor_h, scale_factor_w}) - : c10::nullopt); + : std::nullopt); memcpy(buf_data[0], r.const_data_ptr(), r.element_size() * r.numel()); } @@ -1043,10 +1043,10 @@ void nnc_aten_upsample_nearest2d_out( x, (output_size_h != -1) ? std::optional({output_size_h, output_size_w}) - : c10::nullopt, + : std::nullopt, (scale_factor_h != -1.f) ? std::optional>( {scale_factor_h, scale_factor_w}) - : c10::nullopt); + : std::nullopt); buf_data[0] = r.data_ptr(); c10::raw::intrusive_ptr::incref(r.getIntrusivePtr().get()); buf_data[bufs_in_num + bufs_out_num] = r.getIntrusivePtr().get(); @@ -1089,7 +1089,7 @@ void nnc_aten_quantize_per_tensor_out( buf_dims, buf_strides, buf_dtypes, - c10::nullopt, + std::nullopt, bufs_out_num); // NOLINTNEXTLINE(facebook-hte-LocalUncheckedArrayBounds) at::Tensor x = tensors[1]; @@ -1214,7 +1214,7 @@ void nnc_aten_conv1d_out( buf_dims, buf_strides, buf_dtypes, - c10::nullopt, + std::nullopt, bufs_out_num); at::Tensor r; diff --git a/torch/csrc/jit/tensorexpr/external_functions.h b/torch/csrc/jit/tensorexpr/external_functions.h index 1fd90a3f056b8a..9dc859d2247158 100644 --- a/torch/csrc/jit/tensorexpr/external_functions.h +++ b/torch/csrc/jit/tensorexpr/external_functions.h @@ -75,7 +75,7 @@ std::vector constructTensors( int64_t* buf_strides, int8_t* buf_dtypes, std::optional>> qdataArg = - c10::nullopt); + std::nullopt); std::vector constructTensors2( int64_t bufs_in_num, @@ -85,7 +85,7 @@ std::vector constructTensors2( int64_t* buf_strides, int8_t* buf_dtypes, std::optional>> qdataArg = - c10::nullopt, + std::nullopt, size_t bufs_out_num = 0); #ifdef C10_MOBILE diff --git a/torch/csrc/jit/tensorexpr/graph_opt.cpp b/torch/csrc/jit/tensorexpr/graph_opt.cpp index 01511b2b4d8c5c..0699dfd63da543 100644 --- a/torch/csrc/jit/tensorexpr/graph_opt.cpp +++ b/torch/csrc/jit/tensorexpr/graph_opt.cpp @@ -351,7 +351,7 @@ static std::optional inferScalarType(Node* n) { if (tt->scalarType() && *tt->scalarType() != scalar_type) { GRAPH_DEBUG( "Inputs of ", n, " have different scalar types, cannot fixup!"); - return c10::nullopt; + return std::nullopt; } } } @@ -369,7 +369,7 @@ static std::optional inferDevice(Node* n) { } if (tt->device() && *tt->device() != device) { GRAPH_DEBUG("Inputs of ", n, " have different devices, cannot fixup!"); - return c10::nullopt; + return std::nullopt; } } } diff --git a/torch/csrc/jit/tensorexpr/ir.h b/torch/csrc/jit/tensorexpr/ir.h index 89c3f96aba6e3a..90c5400472514a 100644 --- a/torch/csrc/jit/tensorexpr/ir.h +++ b/torch/csrc/jit/tensorexpr/ir.h @@ -367,7 +367,7 @@ inline std::optional intValue(const ExprPtr& e) { } AT_FORALL_INT_TYPES(TYPE_CASE); #undef TYPE_CASE - return c10::nullopt; + return std::nullopt; } inline std::optional intValue(const ExprHandle& e) { diff --git a/torch/csrc/jit/tensorexpr/ir_simplifier.cpp b/torch/csrc/jit/tensorexpr/ir_simplifier.cpp index afb7aefdda652f..b69d167dba535f 100644 --- a/torch/csrc/jit/tensorexpr/ir_simplifier.cpp +++ b/torch/csrc/jit/tensorexpr/ir_simplifier.cpp @@ -1885,7 +1885,7 @@ static std::optional isModRound(TermPtr e) { if (!mod) { mod = to(m); } else { - return c10::nullopt; + return std::nullopt; } } else { // Take care of special cases before multiplying the scalar and variable. @@ -1911,14 +1911,14 @@ static std::optional isModRound(TermPtr e) { if (!mod) { // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) - return c10::nullopt; + return std::nullopt; } mod_divisor = IRSimplifier::simplify(mod->rhs()); other = mod->lhs(); if (!(div = to
(other))) { - return c10::nullopt; + return std::nullopt; } divisor = IRSimplifier::simplify(div->rhs()); @@ -1953,16 +1953,16 @@ static std::optional isModRound(TermPtr e) { // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) denom = IRSimplifier::simplify(alloc
(other, c)); } else { - return c10::nullopt; + return std::nullopt; } } else { - return c10::nullopt; + return std::nullopt; } } // Deny cases in which divisor=1. Such cases are considered as Mods. if (divisor->isConstant() && immediateEquals(divisor, 1)) { - return c10::nullopt; + return std::nullopt; } if (!scalar) { diff --git a/torch/csrc/jit/tensorexpr/kernel.cpp b/torch/csrc/jit/tensorexpr/kernel.cpp index d18a3d65f21ed0..81c171d5671175 100644 --- a/torch/csrc/jit/tensorexpr/kernel.cpp +++ b/torch/csrc/jit/tensorexpr/kernel.cpp @@ -129,12 +129,12 @@ bool& getOptConditionals() { std::optional pickDeviceType( const at::ArrayRef& inputs) { - std::optional device = c10::nullopt; + std::optional device = std::nullopt; for (auto const& input : inputs) { auto tt = input->type()->cast(); if (tt && tt->device()) { if (device && *device != *tt->device()) { - return c10::nullopt; + return std::nullopt; } device = *tt->device(); } @@ -144,7 +144,7 @@ std::optional pickDeviceType( static std::optional pickDeviceType( const std::shared_ptr& graph) { - std::optional device = c10::nullopt; + std::optional device = std::nullopt; for (auto const& node : graph->nodes()) { for (auto const& input : node->inputs()) { if (auto tt = input->type()->cast()) { @@ -184,10 +184,10 @@ static std::optional getTensorInfoJit(torch::jit::Value* v) { c10::ScalarType dtype = c10::ScalarType::Float; if (!it) { - return c10::nullopt; + return std::nullopt; } if (!it->isComplete()) { - return c10::nullopt; + return std::nullopt; } if (it->scalarType()) { // TODO: ideally we should be strict here and return nullopt if the dtype is @@ -197,7 +197,7 @@ static std::optional getTensorInfoJit(torch::jit::Value* v) { } auto concrete_sizes = it->sizes().concrete_sizes(); if (!concrete_sizes) { - return c10::nullopt; + return std::nullopt; } return TensorInfo{*concrete_sizes, dtype}; } @@ -712,7 +712,7 @@ static std::optional tripCount(ForPtr loop) { if (auto val = to(tc.node())) { return val->value(); } - return c10::nullopt; + return std::nullopt; } // Prune innermost loops until iterations satisfies a minimum grain size. @@ -1314,7 +1314,7 @@ Tensor TensorExprKernel::convertSymbolicOutputToCorrectStrides( BufPtr buf = bufs_.at(v); TORCH_INTERNAL_ASSERT(buf != nullptr); TORCH_INTERNAL_ASSERT(tt != nullptr); - TORCH_INTERNAL_ASSERT(tt->symbolic_sizes().rank() != c10::nullopt); + TORCH_INTERNAL_ASSERT(tt->symbolic_sizes().rank() != std::nullopt); auto stride_desc = getSymbolicStrideDesc(v); TORCH_INTERNAL_ASSERT(stride_desc.size() == 1); diff --git a/torch/csrc/jit/tensorexpr/llvm_codegen.cpp b/torch/csrc/jit/tensorexpr/llvm_codegen.cpp index dec03637847e29..1cae1fe9b2dc22 100644 --- a/torch/csrc/jit/tensorexpr/llvm_codegen.cpp +++ b/torch/csrc/jit/tensorexpr/llvm_codegen.cpp @@ -85,15 +85,15 @@ C10_DEFINE_bool( namespace torch::jit::tensorexpr { std::optional& LLVMTargetTriple() { - static std::optional triple = c10::nullopt; + static std::optional triple = std::nullopt; return triple; } std::optional& LLVMTargetCPU() { - static std::optional cpu = c10::nullopt; + static std::optional cpu = std::nullopt; return cpu; } std::optional& LLVMTargetAttrs() { - static std::optional attrs = c10::nullopt; + static std::optional attrs = std::nullopt; return attrs; } bool& LLVMAOTWorkflow() { diff --git a/torch/csrc/jit/tensorexpr/llvm_codegen.h b/torch/csrc/jit/tensorexpr/llvm_codegen.h index 74271fa879f3de..1d96b4dd0467e3 100644 --- a/torch/csrc/jit/tensorexpr/llvm_codegen.h +++ b/torch/csrc/jit/tensorexpr/llvm_codegen.h @@ -7,7 +7,7 @@ #include #include -#include +#include #include #include @@ -27,9 +27,9 @@ class TORCH_API LLVMCodeGen : public CodeGen { at::Device device = at::kCPU, const std::string& kernel_func_name = "func", Dtype dtype = kInt, - std::optional triple = c10::nullopt, - std::optional cpu = c10::nullopt, - std::optional attrs = c10::nullopt); + std::optional triple = std::nullopt, + std::optional cpu = std::nullopt, + std::optional attrs = std::nullopt); explicit LLVMCodeGen(StmtPtr stmt); LLVMCodeGen() = delete; @@ -126,9 +126,9 @@ struct TORCH_API LLVMCodeGenBuilder { at::Device device_ = at::kCPU; std::string kernelFuncName_ = "func"; Dtype dtype_ = kInt; - std::optional triple_ = c10::nullopt; - std::optional cpu_ = c10::nullopt; - std::optional attrs_ = c10::nullopt; + std::optional triple_ = std::nullopt; + std::optional cpu_ = std::nullopt; + std::optional attrs_ = std::nullopt; }; TORCH_API std::optional& LLVMTargetTriple(); diff --git a/torch/csrc/jit/tensorexpr/llvm_jit.h b/torch/csrc/jit/tensorexpr/llvm_jit.h index 98238e0043885f..beadbdd5e537e7 100644 --- a/torch/csrc/jit/tensorexpr/llvm_jit.h +++ b/torch/csrc/jit/tensorexpr/llvm_jit.h @@ -3,8 +3,8 @@ #ifdef TORCH_ENABLE_LLVM #include #include -#include #include +#include C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wsuggest-override") #include diff --git a/torch/csrc/jit/tensorexpr/operators/conv2d.cpp b/torch/csrc/jit/tensorexpr/operators/conv2d.cpp index bdf313f0ad0515..bfce006d55177e 100644 --- a/torch/csrc/jit/tensorexpr/operators/conv2d.cpp +++ b/torch/csrc/jit/tensorexpr/operators/conv2d.cpp @@ -51,7 +51,7 @@ Tensor conv2d_depthwise_static( Tensor conv = Reduce( "conv2d_depthwise", {N, K, OH, OW}, - c10::nullopt, // TODO + std::nullopt, // TODO Sum(), [&](const std::vector& v) { return init_func(v); }, [&](const std::vector& v) { @@ -123,7 +123,7 @@ Tensor conv2d_depthwise_dynamic( return Reduce( "conv2d_depthwise", {N, K, OH, OW}, - c10::nullopt, // TODO + std::nullopt, // TODO Sum(), [&](const std::vector& v) { return init_func(v); }, [&](const std::vector& v) { diff --git a/torch/csrc/jit/tensorexpr/operators/misc.cpp b/torch/csrc/jit/tensorexpr/operators/misc.cpp index 938cab6ffd8830..6ff6dd733885be 100644 --- a/torch/csrc/jit/tensorexpr/operators/misc.cpp +++ b/torch/csrc/jit/tensorexpr/operators/misc.cpp @@ -165,7 +165,7 @@ std::optional getTensorInfo(BufHandle b) { for (auto dim : b.dims()) { auto val = intValue(dim.node()); if (!val) { - return c10::nullopt; + return std::nullopt; } dims.push_back(*val); } diff --git a/torch/csrc/jit/tensorexpr/operators/pointwise.h b/torch/csrc/jit/tensorexpr/operators/pointwise.h index 0ce10424b3d30a..589674117c1a00 100644 --- a/torch/csrc/jit/tensorexpr/operators/pointwise.h +++ b/torch/csrc/jit/tensorexpr/operators/pointwise.h @@ -9,7 +9,7 @@ namespace tensorexpr { TORCH_API Tensor computeSign( const std::vector& inputs, const std::vector& outputShape, - std::optional> outputStrides = c10::nullopt); + std::optional> outputStrides = std::nullopt); Tensor computeOneOperand( const std::string& name, diff --git a/torch/csrc/jit/tensorexpr/operators/quantization.cpp b/torch/csrc/jit/tensorexpr/operators/quantization.cpp index 66c0688538a1d7..204a4c2211f7a9 100644 --- a/torch/csrc/jit/tensorexpr/operators/quantization.cpp +++ b/torch/csrc/jit/tensorexpr/operators/quantization.cpp @@ -171,7 +171,7 @@ Tensor computeQuantizePerTensor( ExprHandleVectorToExprVector(outputShape), dtype, nullptr, - c10::nullopt, + std::nullopt, qscale.node(), qzero.node()); return Tensor(buf, vars, e.node()); @@ -731,7 +731,7 @@ Tensor computeUpsampleNearest2d( "upsample_nearest2d", outputShape, Dtype(*outputType), - c10::nullopt, // initializer + std::nullopt, // initializer fmap(strides, [&](ExprPtr stride) { return ExprHandle(stride); }), ExprHandle(A.node()->qscale()), ExprHandle(A.node()->qzero())); diff --git a/torch/csrc/jit/tensorexpr/operators/softmax.cpp b/torch/csrc/jit/tensorexpr/operators/softmax.cpp index 9bd82afd177d46..f73e06086d3d9e 100644 --- a/torch/csrc/jit/tensorexpr/operators/softmax.cpp +++ b/torch/csrc/jit/tensorexpr/operators/softmax.cpp @@ -103,7 +103,7 @@ Tensor computeSoftmax( auto max = Reduce( "aten_softmax_max", non_softmax_dims, - c10::nullopt, + std::nullopt, Maximum(dtype), [&](ParameterList& indices) { return tensorOrConstant( @@ -113,7 +113,7 @@ Tensor computeSoftmax( auto e = Compute( "aten_softmax_exp", outputShape, - c10::nullopt, + std::nullopt, [&](ParameterList& indices) { auto inp = tensorOrConstant( inputs[0], convert_indices_to_expr_handle(indices)); @@ -122,7 +122,7 @@ Tensor computeSoftmax( auto sum = Reduce( "aten_softmax_sum", non_softmax_dims, - c10::nullopt, + std::nullopt, Sum(), [&](ParameterList& indices) { return e.load(move_softmax_dim_index_to_pos(indices)); @@ -130,7 +130,7 @@ Tensor computeSoftmax( {outputShape[softmax_dim]}); if (!log_softmax) { auto result = Compute( - "aten_softmax", outputShape, c10::nullopt, [&](ParameterList& indices) { + "aten_softmax", outputShape, std::nullopt, [&](ParameterList& indices) { return e.load(indices) / sum.load(remove_softmax_dim_index(indices)); }); return Tensor( @@ -142,12 +142,12 @@ Tensor computeSoftmax( auto log_sum = Compute( "aten_softmax_log_sum", non_softmax_dims, - c10::nullopt, + std::nullopt, [&](ParameterList& indices) { return log(sum.load(indices)); }); auto result = Compute( "aten_log_softmax", outputShape, - c10::nullopt, + std::nullopt, [&](ParameterList& indices) { auto inp = tensorOrConstant( inputs[0], convert_indices_to_expr_handle(indices)); diff --git a/torch/csrc/jit/tensorexpr/tensor.cpp b/torch/csrc/jit/tensorexpr/tensor.cpp index 5bc734bb80b838..5a9af09f9d87eb 100644 --- a/torch/csrc/jit/tensorexpr/tensor.cpp +++ b/torch/csrc/jit/tensorexpr/tensor.cpp @@ -103,14 +103,14 @@ Tensor Compute( const std::function&)>& body_func) { std::vector args = create_index_vars(dims); ExprHandle body = body_func(args); - BufHandle buf = Buf::make(name, dims, body.dtype(), c10::nullopt, strides); + BufHandle buf = Buf::make(name, dims, body.dtype(), std::nullopt, strides); return Tensor(buf, args, body); } Tensor Compute( const std::string& name, const std::vector& dims, const std::function&)>& body_func) { - return Compute(name, dims, c10::nullopt, body_func); + return Compute(name, dims, std::nullopt, body_func); } Tensor Compute( @@ -124,14 +124,14 @@ Tensor Compute( std::vector args = create_index_vars(dims); ExprHandle body = body_func(args[0]); - BufHandle buf = Buf::make(name, dims, body.dtype(), c10::nullopt, strides); + BufHandle buf = Buf::make(name, dims, body.dtype(), std::nullopt, strides); return Tensor(buf, args, body); } Tensor Compute( const std::string& name, const std::vector& dims, const std::function& body_func) { - return Compute(name, dims, c10::nullopt, body_func); + return Compute(name, dims, std::nullopt, body_func); } Tensor Compute( @@ -145,7 +145,7 @@ Tensor Compute( } std::vector args = create_index_vars(dims); ExprHandle body = body_func(args[0], args[1]); - BufHandle buf = Buf::make(name, dims, body.dtype(), c10::nullopt, strides); + BufHandle buf = Buf::make(name, dims, body.dtype(), std::nullopt, strides); return Tensor(buf, args, body); } Tensor Compute( @@ -153,7 +153,7 @@ Tensor Compute( const std::vector& dims, const std::function& body_func) { - return Compute(name, dims, c10::nullopt, body_func); + return Compute(name, dims, std::nullopt, body_func); } Tensor Compute( @@ -168,7 +168,7 @@ Tensor Compute( } std::vector args = create_index_vars(dims); ExprHandle body = body_func(args[0], args[1], args[2]); - BufHandle buf = Buf::make(name, dims, body.dtype(), c10::nullopt, strides); + BufHandle buf = Buf::make(name, dims, body.dtype(), std::nullopt, strides); return Tensor(buf, args, body); } Tensor Compute( @@ -177,7 +177,7 @@ Tensor Compute( const std::function< ExprHandle(const VarHandle&, const VarHandle&, const VarHandle&)>& body_func) { - return Compute(name, dims, c10::nullopt, body_func); + return Compute(name, dims, std::nullopt, body_func); } Tensor Compute( @@ -194,7 +194,7 @@ Tensor Compute( } std::vector args = create_index_vars(dims); ExprHandle body = body_func(args[0], args[1], args[2], args[3]); - BufHandle buf = Buf::make(name, dims, body.dtype(), c10::nullopt, strides); + BufHandle buf = Buf::make(name, dims, body.dtype(), std::nullopt, strides); return Tensor(buf, args, body); } Tensor Compute( @@ -205,7 +205,7 @@ Tensor Compute( const VarHandle&, const VarHandle&, const VarHandle&)>& body_func) { - return Compute(name, dims, c10::nullopt, body_func); + return Compute(name, dims, std::nullopt, body_func); } Tensor Reduce( @@ -229,7 +229,7 @@ Tensor Reduce( const Reducer& reducer, const BufHandle& buffer, const std::vector& reduce_dims) { - return Reduce(name, dims, c10::nullopt, reducer, buffer, reduce_dims); + return Reduce(name, dims, std::nullopt, reducer, buffer, reduce_dims); } Tensor Reduce( @@ -253,7 +253,7 @@ Tensor Reduce( const Reducer& reducer, Tensor tensor, const std::vector& reduce_dims) { - return Reduce(name, dims, c10::nullopt, reducer, tensor, reduce_dims); + return Reduce(name, dims, std::nullopt, reducer, tensor, reduce_dims); } } // namespace torch::jit::tensorexpr diff --git a/torch/csrc/jit/tensorexpr/tensor.h b/torch/csrc/jit/tensorexpr/tensor.h index 7b589d0974b37b..3fb55152b70d64 100644 --- a/torch/csrc/jit/tensorexpr/tensor.h +++ b/torch/csrc/jit/tensorexpr/tensor.h @@ -161,7 +161,7 @@ Tensor Reduce( if (reduce_vars.empty()) { ExprHandle body = Reducer::getReduceBody(body_func, vars); BufHandle func_result = Buf::make( - func_name, dims, body.dtype(), c10::nullopt, std::move(strides)); + func_name, dims, body.dtype(), std::nullopt, std::move(strides)); return Tensor(std::move(func_result), vars, std::move(body)); } @@ -206,7 +206,7 @@ Tensor Reduce( return Reduce( func_name, dims, - c10::nullopt, + std::nullopt, reducer, init_func, body_func, @@ -238,7 +238,7 @@ Tensor Reduce( const BodyFunc& body_func, const std::vector& reduce_dims) { return Reduce( - func_name, dims, c10::nullopt, reducer, body_func, reduce_dims); + func_name, dims, std::nullopt, reducer, body_func, reduce_dims); } // Overload which allows inline lambda functions for the body_func. @@ -259,7 +259,7 @@ Tensor Reduce( const Reducer& reducer, const BodyFunc&& body_func, const std::vector& reduce_dims) { - return Reduce(func_name, dims, c10::nullopt, reducer, body_func, reduce_dims); + return Reduce(func_name, dims, std::nullopt, reducer, body_func, reduce_dims); } TORCH_API Tensor Reduce( diff --git a/torch/csrc/jit/testing/file_check.cpp b/torch/csrc/jit/testing/file_check.cpp index ec0011f40d775c..027eb2aa0acf6b 100644 --- a/torch/csrc/jit/testing/file_check.cpp +++ b/torch/csrc/jit/testing/file_check.cpp @@ -10,13 +10,13 @@ // API modified from llvm::FileCheck #include -#include #include #include #include #include #include #include +#include #include #include @@ -43,13 +43,13 @@ struct Check { Check( CheckType type, std::string str, - std::optional count = c10::nullopt) + std::optional count = std::nullopt) : type_(type), count_(count), search_str_(std::move(str)) {} Check( CheckType type, c10::string_view str, - std::optional count = c10::nullopt) + std::optional count = std::nullopt) : Check(type, std::string(str.begin(), str.end()), count) {} CheckType type_; @@ -234,7 +234,7 @@ struct FileCheckImpl { TORCH_API void addCheck( CheckType type, const std::string& s, - std::optional count = c10::nullopt) { + std::optional count = std::nullopt) { addCheck(Check(type, s, count)); } @@ -264,7 +264,7 @@ struct FileCheckImpl { } size_t end_check_string = suffix_pos + check_suffix.size(); CheckType type = check_pair.first; - std::optional count = c10::nullopt; + std::optional count = std::nullopt; auto end_line = source->text_str().find("\n", end_check_string); bool exactly = false; if (type == CHECK_COUNT) { diff --git a/torch/csrc/lazy/backend/backend_device.cpp b/torch/csrc/lazy/backend/backend_device.cpp index 6d146ca0881ceb..3eac703be175fc 100644 --- a/torch/csrc/lazy/backend/backend_device.cpp +++ b/torch/csrc/lazy/backend/backend_device.cpp @@ -2,10 +2,10 @@ #include #include -#include #include #include #include +#include namespace torch { namespace lazy { @@ -60,7 +60,7 @@ std::optional GetBackendDevice(at::ITensorListRef tensors) { return lt->GetDevice(); } } - return c10::nullopt; + return std::nullopt; } std::optional GetBackendDevice(at::TensorList tensors) { @@ -71,19 +71,19 @@ std::optional GetBackendDevice(const at::Tensor& tensor) { if (auto lt = TryGetLtcTensor(tensor)) { return lt->GetDevice(); } - return c10::nullopt; + return std::nullopt; } std::optional GetBackendDevice( const std::optional& device) { if (device) { - return c10::make_optional(atenDeviceToBackendDevice(*device)); + return std::make_optional(atenDeviceToBackendDevice(*device)); } - return c10::nullopt; + return std::nullopt; } std::optional GetBackendDevice() { - return c10::nullopt; + return std::nullopt; } } // namespace lazy diff --git a/torch/csrc/lazy/backend/backend_device.h b/torch/csrc/lazy/backend/backend_device.h index e80c800a2ecead..fdfc2ac15d9a89 100644 --- a/torch/csrc/lazy/backend/backend_device.h +++ b/torch/csrc/lazy/backend/backend_device.h @@ -7,7 +7,7 @@ #include #include #include -#include +#include namespace c10 { struct Device; diff --git a/torch/csrc/lazy/core/ir_builder.h b/torch/csrc/lazy/core/ir_builder.h index 981e1667772944..570dc942e6a68a 100644 --- a/torch/csrc/lazy/core/ir_builder.h +++ b/torch/csrc/lazy/core/ir_builder.h @@ -1,12 +1,12 @@ #pragma once #include -#include #include #include #include #include #include +#include #include // This file is part of the backend interface. So, ops shouldn't be added or @@ -61,7 +61,7 @@ struct IrBuilder { virtual NodePtr MakeCast( const Value& input0, const at::ScalarType& dtype, - const std::optional& stype = c10::nullopt) const = 0; + const std::optional& stype = std::nullopt) const = 0; virtual NodePtr MakeTensorList(const OpList& inputs) const = 0; virtual NodePtr MakeGeneric( const OpKind& op, @@ -96,7 +96,7 @@ static inline NodePtr MakeExpand( static inline NodePtr MakeCast( const Value& input0, const at::ScalarType& dtype, - const std::optional& stype = c10::nullopt) { + const std::optional& stype = std::nullopt) { return getIrBuilder()->MakeCast(input0, dtype, stype); } static inline NodePtr MakeTensorList(const OpList& inputs) { diff --git a/torch/csrc/lazy/core/ir_dump_util.cpp b/torch/csrc/lazy/core/ir_dump_util.cpp index a4fb11761a67ce..d81d810a54e98f 100644 --- a/torch/csrc/lazy/core/ir_dump_util.cpp +++ b/torch/csrc/lazy/core/ir_dump_util.cpp @@ -1,10 +1,10 @@ #include -#include #include #include #include #include +#include #include #include @@ -37,7 +37,7 @@ std::optional ParseAttrTag( // @lint-ignore-every CLANGTIDY facebook-hte-StdRegexIsAwful if (!std::regex_search( node_string.begin() + pos, node_string.end(), match, tag_regex)) { - return c10::nullopt; + return std::nullopt; } std::string::size_type vpos = match[1].second - node_string.begin() + 1; @@ -102,7 +102,7 @@ std::optional GetRootNodeId( const std::unordered_map& roots_ids) { auto it = roots_ids.find(node); if (it == roots_ids.end()) { - return c10::nullopt; + return std::nullopt; } return it->second; } diff --git a/torch/csrc/lazy/core/lazy_graph_executor.cpp b/torch/csrc/lazy/core/lazy_graph_executor.cpp index 569cd5ee5e0a18..b01b5ead3434b3 100644 --- a/torch/csrc/lazy/core/lazy_graph_executor.cpp +++ b/torch/csrc/lazy/core/lazy_graph_executor.cpp @@ -695,7 +695,7 @@ std::vector LazyGraphExecutor::SetTensorData( // resets the ir_value. We have already done the resetting as part // of ExtractIRAndPrepareTensorData to overlap with previous execution. tensor->data()->handle = handle; - tensor->data()->tensor_data = c10::nullopt; + tensor->data()->tensor_data = std::nullopt; } tensors_data.emplace_back(std::move(handle)); } diff --git a/torch/csrc/lazy/core/shape.cpp b/torch/csrc/lazy/core/shape.cpp index 939e2745ed3938..bf49cfacb99f61 100644 --- a/torch/csrc/lazy/core/shape.cpp +++ b/torch/csrc/lazy/core/shape.cpp @@ -78,7 +78,7 @@ static c10::SymbolicShape get_symbolic_shape(at::Tensor& tensor) { std::vector> symbolic_dims; for (size_t i = 0; i < sizes.size(); i++) { if (is_symbolic->at(i)) { - symbolic_dims.emplace_back(c10::nullopt); + symbolic_dims.emplace_back(std::nullopt); } else { symbolic_dims.emplace_back(sizes.at(i)); } @@ -114,7 +114,7 @@ void applySymbolicShapesOnLT( auto res_symbolic = jit::calculateSymbolicShapesOnOp(&schema, converted_args); if (!res_symbolic) { for (auto& result_shape : result_shapes) { - result_shape = result_shape.with_symbolic_dims(c10::nullopt); + result_shape = result_shape.with_symbolic_dims(std::nullopt); } } else { TORCH_INTERNAL_ASSERT( diff --git a/torch/csrc/lazy/core/shape.h b/torch/csrc/lazy/core/shape.h index 63566619fd1493..99e4a892bc589f 100644 --- a/torch/csrc/lazy/core/shape.h +++ b/torch/csrc/lazy/core/shape.h @@ -19,7 +19,7 @@ class TORCH_API Shape { Shape( at::ScalarType scalar_type, c10::ArrayRef sizes, - std::optional> is_symbolic = c10::nullopt); + std::optional> is_symbolic = std::nullopt); std::string to_string() const; @@ -64,7 +64,7 @@ class TORCH_API Shape { // Stores which dimmensions are symbolic // If nullopt, either it hasn't been initialized or the symbolic // dimmensions are not calculatable - std::optional> is_symbolic_ = c10::nullopt; + std::optional> is_symbolic_ = std::nullopt; }; TORCH_API std::ostream& operator<<(std::ostream& out, const Shape& shape); diff --git a/torch/csrc/lazy/core/shape_inference.h b/torch/csrc/lazy/core/shape_inference.h index 77eeaaa563187f..76ddea597a784a 100644 --- a/torch/csrc/lazy/core/shape_inference.h +++ b/torch/csrc/lazy/core/shape_inference.h @@ -6,11 +6,11 @@ #include #include #include -#include #include #include #include #include +#include #include namespace torch { diff --git a/torch/csrc/lazy/core/tensor.cpp b/torch/csrc/lazy/core/tensor.cpp index ba0571f87df4d3..972af7dafc8baa 100644 --- a/torch/csrc/lazy/core/tensor.cpp +++ b/torch/csrc/lazy/core/tensor.cpp @@ -143,13 +143,13 @@ void LazyTensor::SetDataHandle(BackendDataPtr handle, bool sync) { // trimming. AssignIrValue(Value()); if (sync) { - data()->tensor_data = c10::nullopt; + data()->tensor_data = std::nullopt; } } void LazyTensor::SetIrValue(Value ir_value) { data()->handle = nullptr; - data()->tensor_data = c10::nullopt; + data()->tensor_data = std::nullopt; AssignIrValue(std::move(ir_value)); TryLimitGraphSize(); } @@ -158,7 +158,7 @@ void LazyTensor::SetInPlaceIrValue(Value ir_value) { auto tensor_shape = shape(); if (tensor_shape.Get().scalar_type() != ir_value.shape().scalar_type()) { ir_value = - MakeCast(ir_value, tensor_shape.Get().scalar_type(), c10::nullopt); + MakeCast(ir_value, tensor_shape.Get().scalar_type(), std::nullopt); } SetIrValue(std::move(ir_value)); } @@ -253,7 +253,7 @@ at::Tensor LazyTensor::ToTensor(bool detached) { if (data()->ir_value || data()->handle != nullptr) { // If we have other authoritive sources, just drop our reference and // transfer it to the caller. - data()->tensor_data = c10::nullopt; + data()->tensor_data = std::nullopt; } else { // Otherwise we need to make a copy to prevent the caller changing our // version. diff --git a/torch/csrc/lazy/core/unique.h b/torch/csrc/lazy/core/unique.h index fc09c8d71d7d8d..3088da160860b7 100644 --- a/torch/csrc/lazy/core/unique.h +++ b/torch/csrc/lazy/core/unique.h @@ -5,7 +5,7 @@ #pragma once -#include +#include #include #include diff --git a/torch/csrc/lazy/core/util.h b/torch/csrc/lazy/core/util.h index e535e5365f2277..bfd68b73355dfc 100644 --- a/torch/csrc/lazy/core/util.h +++ b/torch/csrc/lazy/core/util.h @@ -9,8 +9,8 @@ #include #include -#include #include +#include namespace torch { namespace lazy { @@ -114,7 +114,7 @@ std::optional> ToOptionalVector( if (arrayRef) { return arrayRef->vec(); } - return c10::nullopt; + return std::nullopt; } template diff --git a/torch/csrc/lazy/python/python_util.cpp b/torch/csrc/lazy/python/python_util.cpp index 90d9797e3fd357..1ae663c519f562 100644 --- a/torch/csrc/lazy/python/python_util.cpp +++ b/torch/csrc/lazy/python/python_util.cpp @@ -13,12 +13,12 @@ namespace lazy { std::optional GetPythonFrameTop() { if (!Py_IsInitialized()) { - return c10::nullopt; + return std::nullopt; } pybind11::gil_scoped_acquire gil; PyFrameObject* frame = PyEval_GetFrame(); if (frame == nullptr) { - return c10::nullopt; + return std::nullopt; } SourceLocation loc; auto code = THPCodeObjectPtr(PyFrame_GetCode(frame)); diff --git a/torch/csrc/lazy/python/python_util.h b/torch/csrc/lazy/python/python_util.h index 456aafa8809716..271c694ee35ddc 100644 --- a/torch/csrc/lazy/python/python_util.h +++ b/torch/csrc/lazy/python/python_util.h @@ -1,7 +1,7 @@ #pragma once -#include #include #include +#include #include namespace torch { diff --git a/torch/csrc/lazy/ts_backend/ir_builder.h b/torch/csrc/lazy/ts_backend/ir_builder.h index c5382923744345..9fff33135a5c87 100644 --- a/torch/csrc/lazy/ts_backend/ir_builder.h +++ b/torch/csrc/lazy/ts_backend/ir_builder.h @@ -34,7 +34,7 @@ struct TorchScriptIrBuilder : IrBuilder { const Value& input0, const at::ScalarType& dtype, const std::optional& stype = - c10::nullopt) const override { + std::nullopt) const override { return ReuseOrMakeNode(input0, dtype, stype); } NodePtr MakeTensorList(const OpList& inputs) const override { diff --git a/torch/csrc/lazy/ts_backend/ts_eager_fallback.cpp b/torch/csrc/lazy/ts_backend/ts_eager_fallback.cpp index 42acc2c5df10a2..a00ec260e5a145 100644 --- a/torch/csrc/lazy/ts_backend/ts_eager_fallback.cpp +++ b/torch/csrc/lazy/ts_backend/ts_eager_fallback.cpp @@ -137,7 +137,7 @@ std::optional compute_target_device( } } } - return c10::nullopt; + return std::nullopt; } } // namespace diff --git a/torch/csrc/lazy/ts_backend/ts_native_functions.cpp b/torch/csrc/lazy/ts_backend/ts_native_functions.cpp index 78ae6a6f6e2e55..55d0b7f5a46543 100644 --- a/torch/csrc/lazy/ts_backend/ts_native_functions.cpp +++ b/torch/csrc/lazy/ts_backend/ts_native_functions.cpp @@ -39,10 +39,10 @@ at::Tensor CreateLtcTensor( std::optional GetLtcDevice( const std::optional& device) { if (!device) { - return c10::nullopt; + return std::nullopt; } if (device->type() != at::kLazy) { - return c10::nullopt; + return std::nullopt; } return torch::lazy::atenDeviceToBackendDevice(*device); } @@ -235,7 +235,7 @@ at::Tensor LazyNativeFunctions::_to_copy( // captured IR, or we will try to convert an eager tensor back to a lazy one // inside the torchscript executor lazy:0 -> lazy:1 is handled in case3, so // we can safely drop the device argument - device = c10::nullopt; + device = std::nullopt; torch::lazy::NodePtr node = torch::lazy::ReuseNode( lazy_self->GetIrValue(), @@ -307,7 +307,7 @@ at::Tensor LazyNativeFunctions::empty_strided_symint( std::optional pin_memory) { TORCH_LAZY_FN_COUNTER("lazy::"); at::Tensor t = - empty_symint(sym_size, dtype, layout, device, pin_memory, c10::nullopt); + empty_symint(sym_size, dtype, layout, device, pin_memory, std::nullopt); auto size = C10_AS_INTARRAYREF_SLOW(sym_size); auto stride = C10_AS_INTARRAYREF_SLOW(sym_stride); return t.as_strided(size, stride, /*storage_offset=*/0); diff --git a/torch/csrc/profiler/collection.cpp b/torch/csrc/profiler/collection.cpp index e5daea953c57dd..687e8bf28787a8 100644 --- a/torch/csrc/profiler/collection.cpp +++ b/torch/csrc/profiler/collection.cpp @@ -200,7 +200,7 @@ auto InputOutputEncoder::getIValueGenerator(const IOType& io_type) { if (io_type == tagToIOType(tag)) { out.emplace_back(std::move(input)); } else { - out.emplace_back(c10::nullopt); + out.emplace_back(std::nullopt); } }; @@ -223,7 +223,7 @@ auto InputOutputEncoder::getIValueGenerator(const IOType& io_type) { arg.emplace_back(decode_tensor()); } if (found_undefined) { - push_value(*tag_it, c10::nullopt); + push_value(*tag_it, std::nullopt); } else { push_value(Tag::TensorListBegin, std::move(arg)); } @@ -236,7 +236,7 @@ auto InputOutputEncoder::getIValueGenerator(const IOType& io_type) { case Tag::UndefinedTensor: case Tag::Other: - push_value(*tag_it, c10::nullopt); + push_value(*tag_it, std::nullopt); break; case Tag::TERMINATOR: diff --git a/torch/csrc/profiler/collection.h b/torch/csrc/profiler/collection.h index 1c0a780370a9f0..71cb0c02bccc81 100644 --- a/torch/csrc/profiler/collection.h +++ b/torch/csrc/profiler/collection.h @@ -91,7 +91,7 @@ using op_input_t = std::variant< TensorMetadata, std::vector, c10::IValue, - c10::nullopt_t>; + std::nullopt_t>; // ============================================================================ // == ExtraFields ============================================================= diff --git a/torch/csrc/profiler/python/init.cpp b/torch/csrc/profiler/python/init.cpp index e5cae40c84e315..25f93a2663dfb5 100644 --- a/torch/csrc/profiler/python/init.cpp +++ b/torch/csrc/profiler/python/init.cpp @@ -458,7 +458,7 @@ void initPythonBindings(PyObject* module) { [&](const c10::IValue& v) { out.append(torch::jit::toPyObject(v)); }, - [&](const c10::nullopt_t&) { out.append(py::none()); }, + [&](const std::nullopt_t&) { out.append(py::none()); }, [&](const auto& v) { out.append(py::cast(v)); }), input); } diff --git a/torch/csrc/profiler/unwind/unwind.cpp b/torch/csrc/profiler/unwind/unwind.cpp index 74d7877edadf14..8a3c4487ab7763 100644 --- a/torch/csrc/profiler/unwind/unwind.cpp +++ b/torch/csrc/profiler/unwind/unwind.cpp @@ -290,12 +290,12 @@ std::vector unwind() { std::optional> libraryFor(void* addr) { if (!addr) { - return c10::nullopt; + return std::nullopt; } std::shared_lock lock(cache_mutex_); const LibraryInfo* library_info = unwind_cache.findLibraryFor((uint64_t)addr); if (!library_info) { - return c10::nullopt; + return std::nullopt; } return std::make_pair( library_info->name(), (uint64_t)addr - library_info->load_bias()); diff --git a/torch/csrc/profiler/unwind/unwind.h b/torch/csrc/profiler/unwind/unwind.h index 1c302dfca445ff..bf93b88fa63dcb 100644 --- a/torch/csrc/profiler/unwind/unwind.h +++ b/torch/csrc/profiler/unwind/unwind.h @@ -1,7 +1,7 @@ #pragma once #include -#include #include +#include #include #include diff --git a/torch/csrc/profiler/unwind/unwind_error.h b/torch/csrc/profiler/unwind/unwind_error.h index ae3630057f6a4c..cca8f8d12187b8 100644 --- a/torch/csrc/profiler/unwind/unwind_error.h +++ b/torch/csrc/profiler/unwind/unwind_error.h @@ -1,6 +1,6 @@ #pragma once -#include #include +#include #include namespace torch::unwind { diff --git a/torch/csrc/profiler/util.h b/torch/csrc/profiler/util.h index b06a479e70cc5f..1a607909c45220 100644 --- a/torch/csrc/profiler/util.h +++ b/torch/csrc/profiler/util.h @@ -9,10 +9,10 @@ #include #include -#include #include #include #include +#include // TODO: replace with pytorch/rfcs#43 when it is ready. #define SOFT_ASSERT(cond, ...) \ diff --git a/torch/csrc/tensor/python_tensor.cpp b/torch/csrc/tensor/python_tensor.cpp index 8d18180ed91955..6960034626d568 100644 --- a/torch/csrc/tensor/python_tensor.cpp +++ b/torch/csrc/tensor/python_tensor.cpp @@ -449,7 +449,7 @@ void py_set_default_dtype(PyObject* obj) { THPDtype_Check(obj), "invalid dtype object: only floating-point types are supported as the default type"); auto scalar_type = ((THPDtype*)obj)->scalar_type; - set_default_tensor_type(/*backend=*/c10::nullopt, scalar_type); + set_default_tensor_type(/*backend=*/std::nullopt, scalar_type); } c10::DispatchKey get_default_dispatch_key() { diff --git a/torch/csrc/utils/nested.cpp b/torch/csrc/utils/nested.cpp index 29ccf312851ea1..360abda078df57 100644 --- a/torch/csrc/utils/nested.cpp +++ b/torch/csrc/utils/nested.cpp @@ -82,7 +82,7 @@ at::Tensor nested_tensor_ctor( final_device = new_list[0].device(); } auto out = at::_nested_tensor_from_tensor_list( - new_list, final_dtype, c10::nullopt, final_device, pin_memory); + new_list, final_dtype, std::nullopt, final_device, pin_memory); out.requires_grad_(args_requires_grad); return out; } diff --git a/torch/csrc/utils/python_arg_parser.cpp b/torch/csrc/utils/python_arg_parser.cpp index 9aa80427929df0..a1a1638f9120be 100644 --- a/torch/csrc/utils/python_arg_parser.cpp +++ b/torch/csrc/utils/python_arg_parser.cpp @@ -268,7 +268,7 @@ static py::object dispatch_on_subclass( bool is_torch_function, const char* torch_function_name_str, std::optional maybe_mode_key = - c10::nullopt) { + std::nullopt) { py::object ret; for (auto& arg : overloaded_args) { py::object torch_function = @@ -1005,11 +1005,11 @@ std::string FunctionParameter::type_name() const { static inline std::optional parse_as_integer(const std::string& s) { if (s.empty()) - return c10::nullopt; + return std::nullopt; char* str_end = nullptr; long ans = strtol(s.c_str(), &str_end, 0); // *str_end == 0 if the entire string was parsed as an integer. - return (*str_end == 0) ? std::optional(ans) : c10::nullopt; + return (*str_end == 0) ? std::optional(ans) : std::nullopt; } /* diff --git a/torch/csrc/utils/python_arg_parser.h b/torch/csrc/utils/python_arg_parser.h index 8966131f9825f2..85a4d52bc16df8 100644 --- a/torch/csrc/utils/python_arg_parser.h +++ b/torch/csrc/utils/python_arg_parser.h @@ -399,7 +399,7 @@ inline std::optional PythonArgs::optionalTensor(int i) { if (t.defined()) { return t; } else { - return c10::nullopt; + return std::nullopt; } } @@ -435,7 +435,7 @@ inline at::Scalar PythonArgs::scalarWithDefault( inline std::optional PythonArgs::scalarOptional(int i) { if (!args[i]) - return c10::nullopt; + return std::nullopt; return scalar_slow(i); } @@ -771,7 +771,7 @@ inline at::ScalarType PythonArgs::scalartype(int i) { inline std::optional PythonArgs::scalartypeOptional(int i) { if (!args[i]) - return c10::nullopt; + return std::nullopt; return scalartype(i); } @@ -796,7 +796,7 @@ inline at::Layout PythonArgs::layoutWithDefault( inline std::optional PythonArgs::layoutOptional(int i) { if (!args[i]) - return c10::nullopt; + return std::nullopt; return layout(i); } @@ -837,7 +837,7 @@ inline at::Device PythonArgs::deviceWithDefault( inline std::optional PythonArgs::deviceOptional(int i) { if (!args[i]) - return c10::nullopt; + return std::nullopt; return device(i); } @@ -863,7 +863,7 @@ inline std::vector parseDimnameList(PyObject* arg) { inline std::optional> PythonArgs:: toDimnameListOptional(int i) { if (!args[i]) - return c10::nullopt; + return std::nullopt; return parseDimnameList(args[i]); } @@ -890,7 +890,7 @@ inline at::MemoryFormat PythonArgs::memoryformat(int i) { inline std::optional PythonArgs::memoryformatOptional(int i) { if (!args[i]) - return c10::nullopt; + return std::nullopt; return memoryformat(i); } @@ -918,7 +918,7 @@ inline std::string PythonArgs::stringWithDefault( inline std::optional PythonArgs::stringOptional(int i) { if (!args[i]) - return c10::nullopt; + return std::nullopt; return THPUtils_unpackString(args[i]); } @@ -936,7 +936,7 @@ inline c10::string_view PythonArgs::stringViewWithDefault( inline std::optional PythonArgs::stringViewOptional(int i) { if (!args[i]) - return c10::nullopt; + return std::nullopt; return THPUtils_unpackStringView(args[i]); } @@ -990,26 +990,26 @@ inline int64_t PythonArgs::toInt64WithDefault(int i, int64_t default_int) { inline std::optional PythonArgs::toInt64Optional(int i) { if (!args[i]) - return c10::nullopt; + return std::nullopt; return toInt64(i); } inline std::optional PythonArgs::toSymIntOptional(int i) { if (!args[i]) - return c10::nullopt; + return std::nullopt; return toSymInt(i); } inline std::optional PythonArgs::toBoolOptional(int i) { if (!args[i]) { - return c10::nullopt; + return std::nullopt; } return toBool(i); } inline std::optional PythonArgs::toDoubleOptional(int i) { if (!args[i]) { - return c10::nullopt; + return std::nullopt; } return toDouble(i); } @@ -1071,7 +1071,7 @@ inline bool PythonArgs::isNone(int i) { inline std::optional PythonArgs::generator(int i) { if (!args[i]) - return c10::nullopt; + return std::nullopt; return reinterpret_cast(args[i])->cdata; } diff --git a/torch/csrc/utils/python_dispatch.cpp b/torch/csrc/utils/python_dispatch.cpp index ec0af99842d2e5..2d18978018a1b4 100644 --- a/torch/csrc/utils/python_dispatch.cpp +++ b/torch/csrc/utils/python_dispatch.cpp @@ -65,8 +65,8 @@ static c10::AliasAnalysisKind parseAliasAnalysisKind(const std::string& k) { template inline torch::CppFunction dispatch_str(const char* key, Func&& raw_f) { auto mb_key = std::string(key).empty() - ? c10::nullopt - : c10::make_optional(c10::parseDispatchKey(key)); + ? std::nullopt + : std::make_optional(c10::parseDispatchKey(key)); if (mb_key) { return torch::dispatch(*mb_key, std::forward(raw_f)); } else { @@ -217,7 +217,7 @@ static py::object ophandle_call_boxed( handle.schema(), std::move(args), kwargs, - /*self=*/c10::nullopt); + /*self=*/std::nullopt); { pybind11::gil_scoped_release no_gil_guard; handle.callBoxed(stack); @@ -264,7 +264,7 @@ void initDispatchBindings(PyObject* module) { handle.schema(), std::move(args), kwargs, - /*self=*/c10::nullopt); + /*self=*/std::nullopt); { pybind11::gil_scoped_release no_gil_guard; handle.redispatchBoxed(keyset, &stack); @@ -477,8 +477,8 @@ void initDispatchBindings(PyObject* module) { parseKind(kind), std::move(name), std::string(dispatch).empty() - ? c10::nullopt - : c10::make_optional(c10::parseDispatchKey(dispatch)), + ? std::nullopt + : std::make_optional(c10::parseDispatchKey(dispatch)), "/dev/null", // temporary workaround linenum); END_HANDLE_TH_ERRORS_PYBIND @@ -814,8 +814,8 @@ void initDispatchBindings(PyObject* module) { "_dispatch_print_registrations_for_dispatch_key", [](const char* dispatch_key = "") { auto k = std::string(dispatch_key).empty() - ? c10::nullopt - : c10::make_optional(c10::parseDispatchKey(dispatch_key)); + ? std::nullopt + : std::make_optional(c10::parseDispatchKey(dispatch_key)); auto op_names = c10::Dispatcher::singleton().getRegistrationsForDispatchKey(k); for (auto& op : op_names) { @@ -830,7 +830,7 @@ void initDispatchBindings(PyObject* module) { try { return c10::parseDispatchKey(dispatch_key); } catch (const c10::Error& err) { - return c10::nullopt; + return std::nullopt; } }); @@ -838,8 +838,8 @@ void initDispatchBindings(PyObject* module) { "_dispatch_get_registrations_for_dispatch_key", [](const char* dispatch_key = "") { auto k = std::string(dispatch_key).empty() - ? c10::nullopt - : c10::make_optional(c10::parseDispatchKey(dispatch_key)); + ? std::nullopt + : std::make_optional(c10::parseDispatchKey(dispatch_key)); auto op_names = c10::Dispatcher::singleton().getRegistrationsForDispatchKey(k); std::vector names; @@ -888,7 +888,7 @@ void initDispatchBindings(PyObject* module) { "Expected device_type string to not have a device index; got ", device_type); return c10::toString( - c10::computeDispatchKey(c10::nullopt, c10::nullopt, device)); + c10::computeDispatchKey(std::nullopt, std::nullopt, device)); }); m.def("_are_functorch_transforms_active", []() { diff --git a/torch/csrc/utils/python_raii.h b/torch/csrc/utils/python_raii.h index bc7b5c263e0d91..af63d1efba5458 100644 --- a/torch/csrc/utils/python_raii.h +++ b/torch/csrc/utils/python_raii.h @@ -1,5 +1,5 @@ -#include #include +#include #include namespace torch::impl { @@ -17,7 +17,7 @@ struct RAIIContextManager { } void exit() { - guard_ = c10::nullopt; + guard_ = std::nullopt; } private: @@ -50,7 +50,7 @@ struct DeprecatedRAIIContextManager { void enter() {} void exit() { - guard_ = c10::nullopt; + guard_ = std::nullopt; } private: diff --git a/torch/csrc/utils/python_symnode.h b/torch/csrc/utils/python_symnode.h index 15738b1a67e16c..e82c30a8c98f75 100644 --- a/torch/csrc/utils/python_symnode.h +++ b/torch/csrc/utils/python_symnode.h @@ -144,7 +144,7 @@ class PythonSymNodeImpl : public c10::SymNodeImpl { py::gil_scoped_acquire acquire; const auto& r = getPyObj().attr("maybe_as_int")(); if (r.is_none()) { - return c10::nullopt; + return std::nullopt; } else { return r.cast(); } diff --git a/torch/csrc/utils/schema_info.cpp b/torch/csrc/utils/schema_info.cpp index 0caa5b254d279f..61eecc7cf0079e 100644 --- a/torch/csrc/utils/schema_info.cpp +++ b/torch/csrc/utils/schema_info.cpp @@ -8,7 +8,7 @@ void SchemaInfo::addArgumentValue( const at::IValue& value) { std::optional index = schema_.argumentIndexWithName(name); TORCH_INTERNAL_ASSERT( - index != c10::nullopt, "Schema has no argument named ", name); + index != std::nullopt, "Schema has no argument named ", name); value_map_[name] = value; alias_maps_current_ = false; } @@ -102,7 +102,7 @@ bool SchemaInfo::is_mutable(const c10::SchemaArgument& argument) { } bool SchemaInfo::has_argument(c10::string_view name) { - return schema_.argumentIndexWithName(name) != c10::nullopt; + return schema_.argumentIndexWithName(name) != std::nullopt; } bool SchemaInfo::is_mutable(c10::string_view name) { diff --git a/torch/csrc/utils/tensor_new.cpp b/torch/csrc/utils/tensor_new.cpp index 4fd398d1a8fafd..e66c99bc4d4939 100644 --- a/torch/csrc/utils/tensor_new.cpp +++ b/torch/csrc/utils/tensor_new.cpp @@ -28,8 +28,8 @@ #include #include #include -#include #include +#include #include #include @@ -53,7 +53,7 @@ thread_local bool kOnlyLiftCPUTensors = false; TensorOptions build_options( c10::TensorOptions options, at::ScalarType scalar_type, - const std::optional& device = c10::nullopt) { + const std::optional& device = std::nullopt) { options = options.dtype(scalar_type); if (device.has_value()) { return options.device(device); @@ -1257,7 +1257,7 @@ void _validate_sparse_coo_tensor_args( Tensor values = internal_new_from_data( options, scalar_type, - c10::nullopt, + std::nullopt, r.pyobject(1), /*copy_variables=*/false, /*copy_numpy=*/true, @@ -1266,7 +1266,7 @@ void _validate_sparse_coo_tensor_args( Tensor indices = internal_new_from_data( values.options(), kLong, - c10::nullopt, + std::nullopt, r.pyobject(0), /*copy_variables=*/false, /*copy_numpy=*/true, @@ -1298,7 +1298,7 @@ void _validate_sparse_compressed_tensor_args( Tensor values = internal_new_from_data( options, scalar_type, - c10::nullopt, + std::nullopt, r.pyobject(ARG_VALUES), /*copy_variables=*/false, /*copy_numpy=*/true, @@ -1307,7 +1307,7 @@ void _validate_sparse_compressed_tensor_args( Tensor compressed_indices = internal_new_from_data( values.options(), kInt, - c10::nullopt, + std::nullopt, r.pyobject(ARG_COMPRESSED_INDICES), /*copy_variables=*/false, /*copy_numpy=*/true, @@ -1315,7 +1315,7 @@ void _validate_sparse_compressed_tensor_args( Tensor plain_indices = internal_new_from_data( values.options(), kInt, - c10::nullopt, + std::nullopt, r.pyobject(ARG_PLAIN_INDICES), /*copy_variables=*/false, /*copy_numpy=*/true, @@ -1369,7 +1369,7 @@ void _validate_sparse_compressed_tensor_args_template( Tensor values = internal_new_from_data( options, scalar_type, - c10::nullopt, + std::nullopt, r.pyobject(ARG_VALUES), /*copy_variables=*/false, /*copy_numpy=*/true, @@ -1378,7 +1378,7 @@ void _validate_sparse_compressed_tensor_args_template( Tensor compressed_indices = internal_new_from_data( values.options(), kInt, - c10::nullopt, + std::nullopt, r.pyobject(ARG_COMPRESSED_INDICES), /*copy_variables=*/false, /*copy_numpy=*/true, @@ -1386,7 +1386,7 @@ void _validate_sparse_compressed_tensor_args_template( Tensor plain_indices = internal_new_from_data( values.options(), kInt, - c10::nullopt, + std::nullopt, r.pyobject(ARG_PLAIN_INDICES), /*copy_variables=*/false, /*copy_numpy=*/true, diff --git a/torch/csrc/utils/torch_dispatch_mode.h b/torch/csrc/utils/torch_dispatch_mode.h index 8ca4511435737e..2eb8ba7a1cbbbb 100644 --- a/torch/csrc/utils/torch_dispatch_mode.h +++ b/torch/csrc/utils/torch_dispatch_mode.h @@ -19,7 +19,7 @@ struct StashTorchDispatchModeGuard { } ~StashTorchDispatchModeGuard() { - if (saved_mode_key_ != c10::nullopt) { + if (saved_mode_key_ != std::nullopt) { c10::impl::TorchDispatchModeTLS::set_mode( saved_mode_, saved_mode_key_.value()); } else { diff --git a/torch/custom_class_detail.h b/torch/custom_class_detail.h index e27721c349864c..135c49ac76a925 100644 --- a/torch/custom_class_detail.h +++ b/torch/custom_class_detail.h @@ -47,7 +47,7 @@ struct arg { // Explicit constructor. explicit arg(std::string name) - : name_(std::move(name)), value_(c10::nullopt) {} + : name_(std::move(name)), value_(std::nullopt) {} // Assignment operator. This enables the pybind-like syntax of // torch::arg("name") = value. arg& operator=(const c10::IValue& rhs) { diff --git a/torch/library.h b/torch/library.h index c860f4c2034444..d75e6b01982120 100644 --- a/torch/library.h +++ b/torch/library.h @@ -215,7 +215,7 @@ class TORCH_API CppFunction final { static CppFunction makeFromBoxedKernel(c10::BoxedKernel kernel) { return CppFunction( c10::KernelFunction::makeFromBoxedKernel(std::move(kernel)), - /* cpp_signature */ c10::nullopt, // not known for boxed functions + /* cpp_signature */ std::nullopt, // not known for boxed functions /* schema */ nullptr); } @@ -337,7 +337,7 @@ template inline CppFunction dispatch(c10::DispatchKey k, Func&& raw_f) { CppFunction f(std::forward(raw_f)); if (k == c10::DispatchKey::CatchAll) { - f.dispatch_key_ = c10::nullopt; + f.dispatch_key_ = std::nullopt; } else { f.dispatch_key_ = k; } @@ -930,7 +930,7 @@ class TorchLibraryInit final { torch::Library::DEF, \ &TORCH_LIBRARY_init_##ns, \ #ns, \ - c10::nullopt, \ + std::nullopt, \ __FILE__, \ __LINE__); \ void TORCH_LIBRARY_init_##ns(torch::Library& m) @@ -960,7 +960,7 @@ class TorchLibraryInit final { torch::Library::FRAGMENT, \ &C10_CONCATENATE(TORCH_LIBRARY_FRAGMENT_init_##ns##_, uid), \ #ns, \ - c10::nullopt, \ + std::nullopt, \ __FILE__, \ __LINE__); \ void C10_CONCATENATE( \ @@ -1024,7 +1024,7 @@ class TorchLibraryInit final { ? &C10_CONCATENATE(TORCH_LIBRARY_IMPL_init_##ns##_##k##_, uid) \ : [](torch::Library&) -> void {}), \ #ns, \ - c10::make_optional(c10::DispatchKey::k), \ + std::make_optional(c10::DispatchKey::k), \ __FILE__, \ __LINE__); \ void C10_CONCATENATE( \ @@ -1039,13 +1039,13 @@ class TorchLibraryInit final { /// \private #define MAKE_TORCH_LIBRARY(ns) \ - torch::Library(torch::Library::DEF, #ns, c10::nullopt, __FILE__, __LINE__) + torch::Library(torch::Library::DEF, #ns, std::nullopt, __FILE__, __LINE__) /// \private #define MAKE_TORCH_LIBRARY_IMPL(ns, k) \ torch::Library( \ torch::Library::IMPL, \ #ns, \ - c10::make_optional(c10::DispatchKey::k), \ + std::make_optional(c10::DispatchKey::k), \ __FILE__, \ __LINE__) From 0492ec460a9fcc58c11f806279f813bd12eacbe0 Mon Sep 17 00:00:00 2001 From: PaliC Date: Fri, 14 Jun 2024 14:03:29 -0700 Subject: [PATCH 035/171] [BE] Remove external testing of torch::deploy (#127952) As we don't expect external users of torch::deploy as the library is no longer supported, we will remove external testing. Pull Request resolved: https://github.com/pytorch/pytorch/pull/127952 Approved by: https://github.com/malfet --- .ci/pytorch/common_utils.sh | 22 ---------------------- .ci/pytorch/test.sh | 3 --- .github/workflows/pull.yml | 1 - 3 files changed, 26 deletions(-) diff --git a/.ci/pytorch/common_utils.sh b/.ci/pytorch/common_utils.sh index 51297f7bfff886..91c2d1b5dd3bd7 100644 --- a/.ci/pytorch/common_utils.sh +++ b/.ci/pytorch/common_utils.sh @@ -188,28 +188,6 @@ function clone_pytorch_xla() { fi } -function checkout_install_torchdeploy() { - local commit - commit=$(get_pinned_commit multipy) - pushd .. - git clone --recurse-submodules https://github.com/pytorch/multipy.git - pushd multipy - git checkout "${commit}" - python multipy/runtime/example/generate_examples.py - BUILD_CUDA_TESTS=1 pip install -e . - popd - popd -} - -function test_torch_deploy(){ - pushd .. - pushd multipy - ./multipy/runtime/build/test_deploy - ./multipy/runtime/build/test_deploy_gpu - popd - popd -} - function checkout_install_torchbench() { local commit commit=$(get_pinned_commit torchbench) diff --git a/.ci/pytorch/test.sh b/.ci/pytorch/test.sh index d708f97704100a..4a38ebefa6cb0f 100755 --- a/.ci/pytorch/test.sh +++ b/.ci/pytorch/test.sh @@ -1242,9 +1242,6 @@ elif [[ "$TEST_CONFIG" == distributed ]]; then if [[ "${SHARD_NUMBER}" == 1 ]]; then test_rpc fi -elif [[ "$TEST_CONFIG" == deploy ]]; then - checkout_install_torchdeploy - test_torch_deploy elif [[ "${TEST_CONFIG}" == *inductor_distributed* ]]; then test_inductor_distributed elif [[ "${TEST_CONFIG}" == *inductor-halide* ]]; then diff --git a/.github/workflows/pull.yml b/.github/workflows/pull.yml index b435f1fe0791d9..dc74571852e98c 100644 --- a/.github/workflows/pull.yml +++ b/.github/workflows/pull.yml @@ -270,7 +270,6 @@ jobs: { config: "default", shard: 3, num_shards: 5, runner: "linux.4xlarge.nvidia.gpu" }, { config: "default", shard: 4, num_shards: 5, runner: "linux.4xlarge.nvidia.gpu" }, { config: "default", shard: 5, num_shards: 5, runner: "linux.4xlarge.nvidia.gpu" }, - { config: "deploy", shard: 1, num_shards: 1, runner: "linux.4xlarge.nvidia.gpu" }, ]} linux-focal-cuda12_1-py3_10-gcc9-test: From 574a2cbcb7365202a860c601ccf2c177110caece Mon Sep 17 00:00:00 2001 From: dilililiwhy Date: Sat, 15 Jun 2024 00:07:40 +0000 Subject: [PATCH 036/171] Enable UFMT on common_device_type.py and common_dtype.py (#128490) Part of: https://github.com/pytorch/pytorch/issues/123062 Ran lintrunner on: > torch/testing/_internal/common_device_type.py > torch/testing/_internal/common_dtype.py Detail: ``` $ lintrunner -a --take UFMT --all-files ok No lint issues. Successfully applied all patches. ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/128490 Approved by: https://github.com/ezyang, https://github.com/XuehaiPan --- .lintrunner.toml | 2 - torch/testing/_internal/common_device_type.py | 643 ++++++++++++------ torch/testing/_internal/common_dtype.py | 81 ++- 3 files changed, 506 insertions(+), 220 deletions(-) diff --git a/.lintrunner.toml b/.lintrunner.toml index 5e24d8317a0872..07e64fce799d87 100644 --- a/.lintrunner.toml +++ b/.lintrunner.toml @@ -1755,9 +1755,7 @@ exclude_patterns = [ 'torch/testing/_internal/codegen/__init__.py', 'torch/testing/_internal/codegen/random_topo_test.py', 'torch/testing/_internal/common_cuda.py', - 'torch/testing/_internal/common_device_type.py', 'torch/testing/_internal/common_distributed.py', - 'torch/testing/_internal/common_dtype.py', 'torch/testing/_internal/common_jit.py', 'torch/testing/_internal/common_methods_invocations.py', 'torch/testing/_internal/common_modules.py', diff --git a/torch/testing/_internal/common_device_type.py b/torch/testing/_internal/common_device_type.py index 2e2a379a501e4d..5ab4901155bf7d 100644 --- a/torch/testing/_internal/common_device_type.py +++ b/torch/testing/_internal/common_device_type.py @@ -3,29 +3,55 @@ import copy import gc import inspect +import os import runpy import sys import threading +import unittest from collections import namedtuple from enum import Enum -from functools import wraps, partial -from typing import List, Any, ClassVar, Optional, Sequence, Tuple, Union, Dict, Set -import unittest -import os +from functools import partial, wraps +from typing import Any, ClassVar, Dict, List, Optional, Sequence, Set, Tuple, Union + import torch -from torch.testing._internal.common_utils import TestCase, TEST_WITH_ROCM, TEST_MKL, \ - skipCUDANonDefaultStreamIf, TEST_WITH_ASAN, TEST_WITH_UBSAN, TEST_WITH_TSAN, \ - IS_SANDCASTLE, IS_FBCODE, IS_REMOTE_GPU, IS_WINDOWS, TEST_MPS, TEST_XPU, TEST_HPU, \ - _TestParametrizer, compose_parametrize_fns, dtype_name, \ - TEST_WITH_MIOPEN_SUGGEST_NHWC, NATIVE_DEVICES, skipIfTorchDynamo, \ - get_tracked_input, clear_tracked_input, PRINT_REPRO_ON_FAILURE, \ - TEST_WITH_TORCHINDUCTOR -from torch.testing._internal.common_cuda import _get_torch_cuda_version, \ - TEST_CUSPARSE_GENERIC, TEST_HIPSPARSE_GENERIC, _get_torch_rocm_version +from torch.testing._internal.common_cuda import ( + _get_torch_cuda_version, + _get_torch_rocm_version, + TEST_CUSPARSE_GENERIC, + TEST_HIPSPARSE_GENERIC, +) from torch.testing._internal.common_dtype import get_all_dtypes +from torch.testing._internal.common_utils import ( + _TestParametrizer, + clear_tracked_input, + compose_parametrize_fns, + dtype_name, + get_tracked_input, + IS_FBCODE, + IS_REMOTE_GPU, + IS_SANDCASTLE, + IS_WINDOWS, + NATIVE_DEVICES, + PRINT_REPRO_ON_FAILURE, + skipCUDANonDefaultStreamIf, + skipIfTorchDynamo, + TEST_HPU, + TEST_MKL, + TEST_MPS, + TEST_WITH_ASAN, + TEST_WITH_MIOPEN_SUGGEST_NHWC, + TEST_WITH_ROCM, + TEST_WITH_TORCHINDUCTOR, + TEST_WITH_TSAN, + TEST_WITH_UBSAN, + TEST_XPU, + TestCase, +) + try: import psutil # type: ignore[import] + HAS_PSUTIL = True except ImportError: HAS_PSUTIL = False @@ -276,21 +302,21 @@ def _dtype_test_suffix(dtypes): - """ Returns the test suffix for a dtype, sequence of dtypes, or None. """ + """Returns the test suffix for a dtype, sequence of dtypes, or None.""" if isinstance(dtypes, (list, tuple)): if len(dtypes) == 0: - return '' - return '_' + '_'.join(dtype_name(d) for d in dtypes) + return "" + return "_" + "_".join(dtype_name(d) for d in dtypes) elif dtypes: - return f'_{dtype_name(dtypes)}' + return f"_{dtype_name(dtypes)}" else: - return '' + return "" def _update_param_kwargs(param_kwargs, name, value): - """ Adds a kwarg with the specified name and value to the param_kwargs dict. """ + """Adds a kwarg with the specified name and value to the param_kwargs dict.""" # Make name plural (e.g. devices / dtypes) if the value is composite. - plural_name = f'{name}s' + plural_name = f"{name}s" # Clear out old entries of the arg if any. if name in param_kwargs: @@ -307,7 +333,7 @@ def _update_param_kwargs(param_kwargs, name, value): class DeviceTypeTestBase(TestCase): - device_type: str = 'generic_device_type' + device_type: str = "generic_device_type" # Flag to disable test suite early due to unrecoverable error such as CUDA error. _stop_test_suite = False @@ -346,7 +372,7 @@ def _init_and_get_primary_device(cls): except Exception: # For CUDATestBase, XLATestBase, and possibly others, the primary device won't be available # until setUpClass() sets it. Call that manually here if needed. - if hasattr(cls, 'setUpClass'): + if hasattr(cls, "setUpClass"): cls.setUpClass() return cls.get_primary_device() @@ -363,28 +389,28 @@ def get_all_devices(cls): # Prefers device-specific dtype specifications over generic ones. @classmethod def _get_dtypes(cls, test): - if not hasattr(test, 'dtypes'): + if not hasattr(test, "dtypes"): return None - default_dtypes = test.dtypes.get('all') + default_dtypes = test.dtypes.get("all") msg = f"@dtypes is mandatory when using @dtypesIf however '{test.__name__}' didn't specify it" assert default_dtypes is not None, msg return test.dtypes.get(cls.device_type, default_dtypes) def _get_precision_override(self, test, dtype): - if not hasattr(test, 'precision_overrides'): + if not hasattr(test, "precision_overrides"): return self.precision return test.precision_overrides.get(dtype, self.precision) def _get_tolerance_override(self, test, dtype): - if not hasattr(test, 'tolerance_overrides'): + if not hasattr(test, "tolerance_overrides"): return self.precision, self.rel_tol return test.tolerance_overrides.get(dtype, tol(self.precision, self.rel_tol)) def _apply_precision_override_for_test(self, test, param_kwargs): - dtype = param_kwargs['dtype'] if 'dtype' in param_kwargs else None - dtype = param_kwargs['dtypes'] if 'dtypes' in param_kwargs else dtype + dtype = param_kwargs["dtype"] if "dtype" in param_kwargs else None + dtype = param_kwargs["dtypes"] if "dtypes" in param_kwargs else dtype if dtype: self.precision = self._get_precision_override(test, dtype) self.precision, self.rel_tol = self._get_tolerance_override(test, dtype) @@ -392,16 +418,17 @@ def _apply_precision_override_for_test(self, test, param_kwargs): # Creates device-specific tests. @classmethod def instantiate_test(cls, name, test, *, generic_cls=None): - - def instantiate_test_helper(cls, name, *, test, param_kwargs=None, decorator_fn=lambda _: []): + def instantiate_test_helper( + cls, name, *, test, param_kwargs=None, decorator_fn=lambda _: [] + ): # Add the device param kwarg if the test needs device or devices. param_kwargs = {} if param_kwargs is None else param_kwargs test_sig_params = inspect.signature(test).parameters - if 'device' in test_sig_params or 'devices' in test_sig_params: + if "device" in test_sig_params or "devices" in test_sig_params: device_arg: str = cls._init_and_get_primary_device() - if hasattr(test, 'num_required_devices'): + if hasattr(test, "num_required_devices"): device_arg = cls.get_all_devices() - _update_param_kwargs(param_kwargs, 'device', device_arg) + _update_param_kwargs(param_kwargs, "device", device_arg) # Apply decorators based on param kwargs. for decorator in decorator_fn(param_kwargs): @@ -424,9 +451,16 @@ def instantiated_test(self, param_kwargs=param_kwargs): # Using `__unittest_expecting_failure__` attribute, see # https://github.com/python/cpython/blob/ffa505b580464/Lib/unittest/case.py#L164 # In that case, make it fail with "unexpected success" by suppressing exception - if getattr(test, "__unittest_expecting_failure__", False) and self._stop_test_suite: + if ( + getattr(test, "__unittest_expecting_failure__", False) + and self._stop_test_suite + ): import sys - print("Suppressing fatal exception to trigger unexpected success", file=sys.stderr) + + print( + "Suppressing fatal exception to trigger unexpected success", + file=sys.stderr, + ) return # raise the runtime error as is for the test suite to record. raise rte @@ -441,7 +475,7 @@ def instantiated_test(self, param_kwargs=param_kwargs): def default_parametrize_fn(test, generic_cls, device_cls): # By default, no parametrization is needed. - yield (test, '', {}, lambda _: []) + yield (test, "", {}, lambda _: []) # Parametrization decorators set the parametrize_fn attribute on the test. parametrize_fn = getattr(test, "parametrize_fn", default_parametrize_fn) @@ -457,24 +491,42 @@ def dtype_parametrize_fn(test, generic_cls, device_cls, dtypes=dtypes): # Note that an empty test suffix is set here so that the dtype can be appended # later after the device. - yield (test, '', param_kwargs, lambda _: []) + yield (test, "", param_kwargs, lambda _: []) - parametrize_fn = compose_parametrize_fns(dtype_parametrize_fn, parametrize_fn) + parametrize_fn = compose_parametrize_fns( + dtype_parametrize_fn, parametrize_fn + ) # Instantiate the parametrized tests. - for (test, test_suffix, param_kwargs, decorator_fn) in parametrize_fn(test, generic_cls, cls): # noqa: B020 - test_suffix = '' if test_suffix == '' else '_' + test_suffix - device_suffix = '_' + cls.device_type + for ( + test, # noqa: B020 + test_suffix, + param_kwargs, + decorator_fn, + ) in parametrize_fn(test, generic_cls, cls): + test_suffix = "" if test_suffix == "" else "_" + test_suffix + device_suffix = "_" + cls.device_type # Note: device and dtype suffix placement # Special handling here to place dtype(s) after device according to test name convention. dtype_kwarg = None - if 'dtype' in param_kwargs or 'dtypes' in param_kwargs: - dtype_kwarg = param_kwargs['dtypes'] if 'dtypes' in param_kwargs else param_kwargs['dtype'] - test_name = f'{name}{test_suffix}{device_suffix}{_dtype_test_suffix(dtype_kwarg)}' - - instantiate_test_helper(cls=cls, name=test_name, test=test, param_kwargs=param_kwargs, - decorator_fn=decorator_fn) + if "dtype" in param_kwargs or "dtypes" in param_kwargs: + dtype_kwarg = ( + param_kwargs["dtypes"] + if "dtypes" in param_kwargs + else param_kwargs["dtype"] + ) + test_name = ( + f"{name}{test_suffix}{device_suffix}{_dtype_test_suffix(dtype_kwarg)}" + ) + + instantiate_test_helper( + cls=cls, + name=test_name, + test=test, + param_kwargs=param_kwargs, + decorator_fn=decorator_fn, + ) def run(self, result=None): super().run(result=result) @@ -484,14 +536,15 @@ def run(self, result=None): class CPUTestBase(DeviceTypeTestBase): - device_type = 'cpu' + device_type = "cpu" # No critical error should stop CPU test suite def _should_stop_test_suite(self): return False + class CUDATestBase(DeviceTypeTestBase): - device_type = 'cuda' + device_type = "cuda" _do_cuda_memory_leak_check = True _do_cuda_non_default_stream = True primary_device: ClassVar[str] @@ -508,12 +561,16 @@ def get_primary_device(cls): @classmethod def get_all_devices(cls): - primary_device_idx = int(cls.get_primary_device().split(':')[1]) + primary_device_idx = int(cls.get_primary_device().split(":")[1]) num_devices = torch.cuda.device_count() prim_device = cls.get_primary_device() - cuda_str = 'cuda:{0}' - non_primary_devices = [cuda_str.format(idx) for idx in range(num_devices) if idx != primary_device_idx] + cuda_str = "cuda:{0}" + non_primary_devices = [ + cuda_str.format(idx) + for idx in range(num_devices) + if idx != primary_device_idx + ] return [prim_device] + non_primary_devices @classmethod @@ -527,12 +584,15 @@ def setUpClass(cls): cls.cudnn_version = None if cls.no_cudnn else torch.backends.cudnn.version() # Acquires the current device as the primary (test) device - cls.primary_device = f'cuda:{torch.cuda.current_device()}' + cls.primary_device = f"cuda:{torch.cuda.current_device()}" + # See Note [Lazy Tensor tests in device agnostic testing] lazy_ts_backend_init = False + + class LazyTestBase(DeviceTypeTestBase): - device_type = 'lazy' + device_type = "lazy" def _should_stop_test_suite(self): return False @@ -542,14 +602,16 @@ def setUpClass(cls): import torch._lazy import torch._lazy.metrics import torch._lazy.ts_backend + global lazy_ts_backend_init if not lazy_ts_backend_init: # Need to connect the TS backend to lazy key before running tests torch._lazy.ts_backend.init() lazy_ts_backend_init = True + class MPSTestBase(DeviceTypeTestBase): - device_type = 'mps' + device_type = "mps" primary_device: ClassVar[str] @classmethod @@ -564,13 +626,14 @@ def get_all_devices(cls): @classmethod def setUpClass(cls): - cls.primary_device = 'mps:0' + cls.primary_device = "mps:0" def _should_stop_test_suite(self): return False + class XPUTestBase(DeviceTypeTestBase): - device_type = 'xpu' + device_type = "xpu" primary_device: ClassVar[str] @classmethod @@ -585,13 +648,14 @@ def get_all_devices(cls): @classmethod def setUpClass(cls): - cls.primary_device = 'xpu:0' + cls.primary_device = "xpu:0" def _should_stop_test_suite(self): return False + class HPUTestBase(DeviceTypeTestBase): - device_type = 'hpu' + device_type = "hpu" primary_device: ClassVar[str] @classmethod @@ -600,12 +664,13 @@ def get_primary_device(cls): @classmethod def setUpClass(cls): - cls.primary_device = 'hpu:0' + cls.primary_device = "hpu:0" + class PrivateUse1TestBase(DeviceTypeTestBase): primary_device: ClassVar[str] device_mod = None - device_type = 'privateuse1' + device_type = "privateuse1" @classmethod def get_primary_device(cls): @@ -613,20 +678,27 @@ def get_primary_device(cls): @classmethod def get_all_devices(cls): - primary_device_idx = int(cls.get_primary_device().split(':')[1]) + primary_device_idx = int(cls.get_primary_device().split(":")[1]) num_devices = cls.device_mod.device_count() prim_device = cls.get_primary_device() - device_str = f'{cls.device_type}:{{0}}' - non_primary_devices = [device_str.format(idx) for idx in range(num_devices) if idx != primary_device_idx] + device_str = f"{cls.device_type}:{{0}}" + non_primary_devices = [ + device_str.format(idx) + for idx in range(num_devices) + if idx != primary_device_idx + ] return [prim_device] + non_primary_devices @classmethod def setUpClass(cls): cls.device_type = torch._C._get_privateuse1_backend_name() cls.device_mod = getattr(torch, cls.device_type, None) - assert cls.device_mod is not None, f'''torch has no module of `{cls.device_type}`, you should register - a module by `torch._register_device_module`.''' - cls.primary_device = f'{cls.device_type}:{cls.device_mod.current_device()}' + assert ( + cls.device_mod is not None + ), f"""torch has no module of `{cls.device_type}`, you should register + a module by `torch._register_device_module`.""" + cls.primary_device = f"{cls.device_type}:{cls.device_mod.current_device()}" + # Adds available device-type-specific test base classes def get_device_type_test_bases(): @@ -657,20 +729,27 @@ def get_device_type_test_bases(): return test_bases + device_type_test_bases = get_device_type_test_bases() def filter_desired_device_types(device_type_test_bases, except_for=None, only_for=None): # device type cannot appear in both except_for and only_for - intersect = set(except_for if except_for else []) & set(only_for if only_for else []) - assert not intersect, f"device ({intersect}) appeared in both except_for and only_for" + intersect = set(except_for if except_for else []) & set( + only_for if only_for else [] + ) + assert ( + not intersect + ), f"device ({intersect}) appeared in both except_for and only_for" if except_for: device_type_test_bases = filter( - lambda x: x.device_type not in except_for, device_type_test_bases) + lambda x: x.device_type not in except_for, device_type_test_bases + ) if only_for: device_type_test_bases = filter( - lambda x: x.device_type in only_for, device_type_test_bases) + lambda x: x.device_type in only_for, device_type_test_bases + ) return list(device_type_test_bases) @@ -691,32 +770,36 @@ def filter_desired_device_types(device_type_test_bases, except_for=None, only_fo # - To run tests with new device type, set `TORCH_TEST_DEVICE` env variable to path # to this file. Multiple paths can be separated by `:`. # See pytorch/xla/test/pytorch_test_base.py for a more detailed example. -_TORCH_TEST_DEVICES = os.environ.get('TORCH_TEST_DEVICES', None) +_TORCH_TEST_DEVICES = os.environ.get("TORCH_TEST_DEVICES", None) if _TORCH_TEST_DEVICES: - for path in _TORCH_TEST_DEVICES.split(':'): + for path in _TORCH_TEST_DEVICES.split(":"): # runpy (a stdlib module) lacks annotations mod = runpy.run_path(path, init_globals=globals()) # type: ignore[func-returns-value] - device_type_test_bases.append(mod['TEST_CLASS']) + device_type_test_bases.append(mod["TEST_CLASS"]) -PYTORCH_CUDA_MEMCHECK = os.getenv('PYTORCH_CUDA_MEMCHECK', '0') == '1' +PYTORCH_CUDA_MEMCHECK = os.getenv("PYTORCH_CUDA_MEMCHECK", "0") == "1" -PYTORCH_TESTING_DEVICE_ONLY_FOR_KEY = 'PYTORCH_TESTING_DEVICE_ONLY_FOR' -PYTORCH_TESTING_DEVICE_EXCEPT_FOR_KEY = 'PYTORCH_TESTING_DEVICE_EXCEPT_FOR' -PYTORCH_TESTING_DEVICE_FOR_CUSTOM_KEY = 'PYTORCH_TESTING_DEVICE_FOR_CUSTOM' +PYTORCH_TESTING_DEVICE_ONLY_FOR_KEY = "PYTORCH_TESTING_DEVICE_ONLY_FOR" +PYTORCH_TESTING_DEVICE_EXCEPT_FOR_KEY = "PYTORCH_TESTING_DEVICE_EXCEPT_FOR" +PYTORCH_TESTING_DEVICE_FOR_CUSTOM_KEY = "PYTORCH_TESTING_DEVICE_FOR_CUSTOM" -def get_desired_device_type_test_bases(except_for=None, only_for=None, include_lazy=False, allow_mps=False): +def get_desired_device_type_test_bases( + except_for=None, only_for=None, include_lazy=False, allow_mps=False +): # allow callers to specifically opt tests into being tested on MPS, similar to `include_lazy` test_bases = device_type_test_bases.copy() if allow_mps and TEST_MPS and MPSTestBase not in test_bases: test_bases.append(MPSTestBase) - if only_for == 'xpu' and TEST_XPU and XPUTestBase not in test_bases: + if only_for == "xpu" and TEST_XPU and XPUTestBase not in test_bases: test_bases.append(XPUTestBase) if TEST_HPU and HPUTestBase not in test_bases: test_bases.append(HPUTestBase) # Filter out the device types based on user inputs - desired_device_type_test_bases = filter_desired_device_types(test_bases, except_for, only_for) + desired_device_type_test_bases = filter_desired_device_types( + test_bases, except_for, only_for + ) if include_lazy: # Note [Lazy Tensor tests in device agnostic testing] # Right now, test_view_ops.py runs with LazyTensor. @@ -725,7 +808,10 @@ def get_desired_device_type_test_bases(except_for=None, only_for=None, include_l # So instead, the only way to opt a specific device-agnostic test file into # lazy tensor testing is with include_lazy=True if IS_FBCODE: - print("TorchScript backend not yet supported in FBCODE/OVRSOURCE builds", file=sys.stderr) + print( + "TorchScript backend not yet supported in FBCODE/OVRSOURCE builds", + file=sys.stderr, + ) else: desired_device_type_test_bases.append(LazyTestBase) @@ -735,20 +821,29 @@ def split_if_not_empty(x: str): # run some cuda testcases on other devices if available # Usage: # export PYTORCH_TESTING_DEVICE_FOR_CUSTOM=privateuse1 - env_custom_only_for = split_if_not_empty(os.getenv(PYTORCH_TESTING_DEVICE_FOR_CUSTOM_KEY, '')) + env_custom_only_for = split_if_not_empty( + os.getenv(PYTORCH_TESTING_DEVICE_FOR_CUSTOM_KEY, "") + ) if env_custom_only_for: - desired_device_type_test_bases += filter(lambda x: x.device_type in env_custom_only_for, test_bases) + desired_device_type_test_bases += filter( + lambda x: x.device_type in env_custom_only_for, test_bases + ) desired_device_type_test_bases = list(set(desired_device_type_test_bases)) # Filter out the device types based on environment variables if available # Usage: # export PYTORCH_TESTING_DEVICE_ONLY_FOR=cuda,cpu # export PYTORCH_TESTING_DEVICE_EXCEPT_FOR=xla - env_only_for = split_if_not_empty(os.getenv(PYTORCH_TESTING_DEVICE_ONLY_FOR_KEY, '')) - env_except_for = split_if_not_empty(os.getenv(PYTORCH_TESTING_DEVICE_EXCEPT_FOR_KEY, '')) - - return filter_desired_device_types(desired_device_type_test_bases, env_except_for, env_only_for) + env_only_for = split_if_not_empty( + os.getenv(PYTORCH_TESTING_DEVICE_ONLY_FOR_KEY, "") + ) + env_except_for = split_if_not_empty( + os.getenv(PYTORCH_TESTING_DEVICE_EXCEPT_FOR_KEY, "") + ) + return filter_desired_device_types( + desired_device_type_test_bases, env_except_for, env_only_for + ) # Adds 'instantiated' device-specific test cases to the given scope. @@ -758,7 +853,14 @@ def split_if_not_empty(x: str): # device-specific tests (NB: this supports additional @parametrize usage). # # See note "Writing Test Templates" -def instantiate_device_type_tests(generic_test_class, scope, except_for=None, only_for=None, include_lazy=False, allow_mps=False): +def instantiate_device_type_tests( + generic_test_class, + scope, + except_for=None, + only_for=None, + include_lazy=False, + allow_mps=False, +): # Removes the generic test class from its enclosing scope so its tests # are not discoverable. del scope[generic_test_class.__name__] @@ -774,11 +876,15 @@ def instantiate_device_type_tests(generic_test_class, scope, except_for=None, on # Acquires members names # See Note [Overriding methods in generic tests] - generic_members = set(generic_test_class.__dict__.keys()) - set(empty_class.__dict__.keys()) - generic_tests = [x for x in generic_members if x.startswith('test')] + generic_members = set(generic_test_class.__dict__.keys()) - set( + empty_class.__dict__.keys() + ) + generic_tests = [x for x in generic_members if x.startswith("test")] # Creates device-specific test cases - for base in get_desired_device_type_test_bases(except_for, only_for, include_lazy, allow_mps): + for base in get_desired_device_type_test_bases( + except_for, only_for, include_lazy, allow_mps + ): class_name = generic_test_class.__name__ + base.device_type.upper() # type set to Any and suppressed due to unsupport runtime class: @@ -792,11 +898,15 @@ def instantiate_device_type_tests(generic_test_class, scope, except_for=None, on sig = inspect.signature(device_type_test_class.instantiate_test) if len(sig.parameters) == 3: # Instantiates the device-specific tests - device_type_test_class.instantiate_test(name, copy.deepcopy(test), generic_cls=generic_test_class) + device_type_test_class.instantiate_test( + name, copy.deepcopy(test), generic_cls=generic_test_class + ) else: device_type_test_class.instantiate_test(name, copy.deepcopy(test)) else: # Ports non-test member - assert name not in device_type_test_class.__dict__, f"Redefinition of directly defined member {name}" + assert ( + name not in device_type_test_class.__dict__ + ), f"Redefinition of directly defined member {name}" nontest = getattr(generic_test_class, name) setattr(device_type_test_class, name, nontest) @@ -846,7 +956,9 @@ class OpDTypes(Enum): unsupported_backward = 3 # Test only unsupported backward dtypes any_one = 4 # Test precisely one supported dtype none = 5 # Instantiate no dtype variants (no dtype kwarg needed) - any_common_cpu_cuda_one = 6 # Test precisely one supported dtype that is common to both cuda and cpu + any_common_cpu_cuda_one = ( + 6 # Test precisely one supported dtype that is common to both cuda and cpu + ) # Arbitrary order @@ -862,15 +974,17 @@ class OpDTypes(Enum): torch.int16, torch.int8, torch.uint8, - torch.bool + torch.bool, ) + def _serialize_sample(sample_input): # NB: For OpInfos, SampleInput.summary() prints in a cleaner way. if getattr(sample_input, "summary", None) is not None: return sample_input.summary() return str(sample_input) + # Decorator that defines the OpInfos a test template should be instantiated for. # # Example usage: @@ -908,20 +1022,31 @@ def _serialize_sample(sample_input): # These options allow tests to have considerable control over the dtypes # they're instantiated for. + class ops(_TestParametrizer): - def __init__(self, op_list, *, dtypes: Union[OpDTypes, Sequence[torch.dtype]] = OpDTypes.supported, - allowed_dtypes: Optional[Sequence[torch.dtype]] = None, skip_if_dynamo=True): + def __init__( + self, + op_list, + *, + dtypes: Union[OpDTypes, Sequence[torch.dtype]] = OpDTypes.supported, + allowed_dtypes: Optional[Sequence[torch.dtype]] = None, + skip_if_dynamo=True, + ): self.op_list = list(op_list) self.opinfo_dtypes = dtypes - self.allowed_dtypes = set(allowed_dtypes) if allowed_dtypes is not None else None + self.allowed_dtypes = ( + set(allowed_dtypes) if allowed_dtypes is not None else None + ) self.skip_if_dynamo = skip_if_dynamo def _parametrize_test(self, test, generic_cls, device_cls): - """ Parameterizes the given test function across each op and its associated dtypes. """ + """Parameterizes the given test function across each op and its associated dtypes.""" if device_cls is None: - raise RuntimeError('The @ops decorator is only intended to be used in a device-specific ' - 'context; use it with instantiate_device_type_tests() instead of ' - 'instantiate_parametrized_tests()') + raise RuntimeError( + "The @ops decorator is only intended to be used in a device-specific " + "context; use it with instantiate_device_type_tests() instead of " + "instantiate_parametrized_tests()" + ) op = check_exhausted_iterator = object() for op in self.op_list: @@ -930,17 +1055,23 @@ def _parametrize_test(self, test, generic_cls, device_cls): if isinstance(self.opinfo_dtypes, Sequence): dtypes = set(self.opinfo_dtypes) elif self.opinfo_dtypes == OpDTypes.unsupported_backward: - dtypes = set(get_all_dtypes()).difference(op.supported_backward_dtypes(device_cls.device_type)) + dtypes = set(get_all_dtypes()).difference( + op.supported_backward_dtypes(device_cls.device_type) + ) elif self.opinfo_dtypes == OpDTypes.supported_backward: dtypes = op.supported_backward_dtypes(device_cls.device_type) elif self.opinfo_dtypes == OpDTypes.unsupported: - dtypes = set(get_all_dtypes()).difference(op.supported_dtypes(device_cls.device_type)) + dtypes = set(get_all_dtypes()).difference( + op.supported_dtypes(device_cls.device_type) + ) elif self.opinfo_dtypes == OpDTypes.supported: dtypes = set(op.supported_dtypes(device_cls.device_type)) elif self.opinfo_dtypes == OpDTypes.any_one: # Tries to pick a dtype that supports both forward or backward supported = op.supported_dtypes(device_cls.device_type) - supported_backward = op.supported_backward_dtypes(device_cls.device_type) + supported_backward = op.supported_backward_dtypes( + device_cls.device_type + ) supported_both = supported.intersection(supported_backward) dtype_set = supported_both if len(supported_both) > 0 else supported for dtype in ANY_DTYPE_ORDER: @@ -953,7 +1084,9 @@ def _parametrize_test(self, test, generic_cls, device_cls): # Tries to pick a dtype that supports both CPU and CUDA supported = set(op.dtypes).intersection(op.dtypesIfCUDA) if supported: - dtypes = {next(dtype for dtype in ANY_DTYPE_ORDER if dtype in supported)} + dtypes = { + next(dtype for dtype in ANY_DTYPE_ORDER if dtype in supported) + } else: dtypes = {} @@ -971,14 +1104,15 @@ def _parametrize_test(self, test, generic_cls, device_cls): for dtype in dtypes: # Construct parameter kwargs to pass to the test. - param_kwargs = {'op': op} - _update_param_kwargs(param_kwargs, 'dtype', dtype) + param_kwargs = {"op": op} + _update_param_kwargs(param_kwargs, "dtype", dtype) # NOTE: test_wrapper exists because we don't want to apply # op-specific decorators to the original test. # Test-specific decorators are applied to the original test, # however. try: + @wraps(test) def test_wrapper(*args, **kwargs): try: @@ -991,21 +1125,29 @@ def test_wrapper(*args, **kwargs): raise Exception( # noqa: TRY002 f"Caused by {tracked_input.type_desc} " f"at index {tracked_input.index}: " - f"{_serialize_sample(tracked_input.val)}") from e + f"{_serialize_sample(tracked_input.val)}" + ) from e raise e finally: clear_tracked_input() if self.skip_if_dynamo and not TEST_WITH_TORCHINDUCTOR: - test_wrapper = skipIfTorchDynamo("Policy: we don't run OpInfo tests w/ Dynamo")(test_wrapper) + test_wrapper = skipIfTorchDynamo( + "Policy: we don't run OpInfo tests w/ Dynamo" + )(test_wrapper) # Initialize info for the last input seen. This is useful for tracking # down which inputs caused a test failure. Note that TrackedInputIter is # responsible for managing this. test.tracked_input = None - decorator_fn = partial(op.get_decorators, generic_cls.__name__, - test.__name__, device_cls.device_type, dtype) + decorator_fn = partial( + op.get_decorators, + generic_cls.__name__, + test.__name__, + device_cls.device_type, + dtype, + ) yield (test_wrapper, test_name, param_kwargs, decorator_fn) except Exception as ex: @@ -1013,8 +1155,11 @@ def test_wrapper(*args, **kwargs): print(f"Failed to instantiate {test_name} for op {op.name}!") raise ex if op is check_exhausted_iterator: - raise ValueError('An empty op_list was passed to @ops. ' - 'Note that this may result from reuse of a generator.') + raise ValueError( + "An empty op_list was passed to @ops. " + "Note that this may result from reuse of a generator." + ) + # Decorator that skips a test if the given condition is true. # Notes: @@ -1025,91 +1170,92 @@ def test_wrapper(*args, **kwargs): # probably define a new decorator instead (see below). # (3) Prefer the existing decorators to defining the 'device_type' kwarg. class skipIf: - def __init__(self, dep, reason, device_type=None): self.dep = dep self.reason = reason self.device_type = device_type def __call__(self, fn): - @wraps(fn) def dep_fn(slf, *args, **kwargs): if self.device_type is None or self.device_type == slf.device_type: - if (isinstance(self.dep, str) and getattr(slf, self.dep, True)) or (isinstance(self.dep, bool) and self.dep): + if (isinstance(self.dep, str) and getattr(slf, self.dep, True)) or ( + isinstance(self.dep, bool) and self.dep + ): raise unittest.SkipTest(self.reason) return fn(slf, *args, **kwargs) + return dep_fn # Skips a test on CPU if the condition is true. class skipCPUIf(skipIf): - def __init__(self, dep, reason): - super().__init__(dep, reason, device_type='cpu') + super().__init__(dep, reason, device_type="cpu") # Skips a test on CUDA if the condition is true. class skipCUDAIf(skipIf): - def __init__(self, dep, reason): - super().__init__(dep, reason, device_type='cuda') + super().__init__(dep, reason, device_type="cuda") + # Skips a test on Lazy if the condition is true. class skipLazyIf(skipIf): - def __init__(self, dep, reason): - super().__init__(dep, reason, device_type='lazy') + super().__init__(dep, reason, device_type="lazy") + # Skips a test on Meta if the condition is true. class skipMetaIf(skipIf): - def __init__(self, dep, reason): - super().__init__(dep, reason, device_type='meta') + super().__init__(dep, reason, device_type="meta") + # Skips a test on MPS if the condition is true. class skipMPSIf(skipIf): - def __init__(self, dep, reason): - super().__init__(dep, reason, device_type='mps') + super().__init__(dep, reason, device_type="mps") + class skipHPUIf(skipIf): def __init__(self, dep, reason): - super().__init__(dep, reason, device_type='hpu') + super().__init__(dep, reason, device_type="hpu") + # Skips a test on XLA if the condition is true. class skipXLAIf(skipIf): - def __init__(self, dep, reason): - super().__init__(dep, reason, device_type='xla') + super().__init__(dep, reason, device_type="xla") -class skipPRIVATEUSE1If(skipIf): +class skipPRIVATEUSE1If(skipIf): def __init__(self, dep, reason): device_type = torch._C._get_privateuse1_backend_name() super().__init__(dep, reason, device_type=device_type) + def _has_sufficient_memory(device, size): - if torch.device(device).type == 'cuda': + if torch.device(device).type == "cuda": if not torch.cuda.is_available(): return False gc.collect() torch.cuda.empty_cache() # torch.cuda.mem_get_info, aka cudaMemGetInfo, returns a tuple of (free memory, total memory) of a GPU - if device == 'cuda': - device = 'cuda:0' + if device == "cuda": + device = "cuda:0" return torch.cuda.memory.mem_get_info(device)[0] >= size - if device == 'xla': - raise unittest.SkipTest('TODO: Memory availability checks for XLA?') + if device == "xla": + raise unittest.SkipTest("TODO: Memory availability checks for XLA?") - if device != 'cpu': - raise unittest.SkipTest('Unknown device type') + if device != "cpu": + raise unittest.SkipTest("Unknown device type") # CPU if not HAS_PSUTIL: - raise unittest.SkipTest('Need psutil to determine if memory is sufficient') + raise unittest.SkipTest("Need psutil to determine if memory is sufficient") # The sanitizers have significant memory overheads if TEST_WITH_ASAN or TEST_WITH_TSAN or TEST_WITH_UBSAN: @@ -1132,8 +1278,8 @@ def largeTensorTest(size, device=None): In other tests, the `device=` argument needs to be specified. """ if isinstance(size, str): - assert size.endswith(('GB', 'gb')), "only bytes or GB supported" - size = 1024 ** 3 * int(size[:-2]) + assert size.endswith(("GB", "gb")), "only bytes or GB supported" + size = 1024**3 * int(size[:-2]) def inner(fn): @wraps(fn) @@ -1141,23 +1287,27 @@ def dep_fn(self, *args, **kwargs): size_bytes = size(self, *args, **kwargs) if callable(size) else size _device = device if device is not None else self.get_primary_device() if not _has_sufficient_memory(_device, size_bytes): - raise unittest.SkipTest(f'Insufficient {_device} memory') + raise unittest.SkipTest(f"Insufficient {_device} memory") return fn(self, *args, **kwargs) + return dep_fn + return inner class expectedFailure: - def __init__(self, device_type): self.device_type = device_type def __call__(self, fn): - @wraps(fn) def efail_fn(slf, *args, **kwargs): - if not hasattr(slf, "device_type") and hasattr(slf, "device") and isinstance(slf.device, str): + if ( + not hasattr(slf, "device_type") + and hasattr(slf, "device") + and isinstance(slf.device, str) + ): target_device_type = slf.device else: target_device_type = slf.device_type @@ -1168,19 +1318,18 @@ def efail_fn(slf, *args, **kwargs): except Exception: return else: - slf.fail('expected test to fail, but it passed') + slf.fail("expected test to fail, but it passed") return fn(slf, *args, **kwargs) + return efail_fn class onlyOn: - def __init__(self, device_type): self.device_type = device_type def __call__(self, fn): - @wraps(fn) def only_fn(slf, *args, **kwargs): if self.device_type != slf.device_type: @@ -1197,12 +1346,13 @@ def only_fn(slf, *args, **kwargs): # Skips the test if the number of available devices of the variant's device # type is less than the 'num_required_devices' arg. class deviceCountAtLeast: - def __init__(self, num_required_devices): self.num_required_devices = num_required_devices def __call__(self, fn): - assert not hasattr(fn, 'num_required_devices'), f"deviceCountAtLeast redefinition for {fn.__name__}" + assert not hasattr( + fn, "num_required_devices" + ), f"deviceCountAtLeast redefinition for {fn.__name__}" fn.num_required_devices = self.num_required_devices @wraps(fn) @@ -1215,6 +1365,7 @@ def multi_fn(slf, devices, *args, **kwargs): return multi_fn + # Only runs the test on the native device type (currently CPU, CUDA, Meta and PRIVATEUSE1) def onlyNativeDeviceTypes(fn): @wraps(fn) @@ -1227,6 +1378,7 @@ def only_fn(self, *args, **kwargs): return only_fn + # Specifies per-dtype precision overrides. # Ex. # @@ -1245,11 +1397,14 @@ def only_fn(self, *args, **kwargs): # explicitly and computed using self.precision (e.g. # self.precision *2, max(1, self.precision)). class precisionOverride: - def __init__(self, d): - assert isinstance(d, dict), "precisionOverride not given a dtype : precision dict!" + assert isinstance( + d, dict + ), "precisionOverride not given a dtype : precision dict!" for dtype in d.keys(): - assert isinstance(dtype, torch.dtype), f"precisionOverride given unknown dtype {dtype}" + assert isinstance( + dtype, torch.dtype + ), f"precisionOverride given unknown dtype {dtype}" self.d = d @@ -1257,6 +1412,7 @@ def __call__(self, fn): fn.precision_overrides = self.d return fn + # Specifies per-dtype tolerance overrides tol(atol, rtol). It has priority over # precisionOverride. # Ex. @@ -1274,14 +1430,19 @@ def __call__(self, fn): # # The above example sets atol = 1e-2 and rtol = 1e-3 for torch.float and # atol = 1e-4 and rtol = 0 for torch.double. -tol = namedtuple('tol', ['atol', 'rtol']) +tol = namedtuple("tol", ["atol", "rtol"]) + class toleranceOverride: def __init__(self, d): assert isinstance(d, dict), "toleranceOverride not given a dtype : tol dict!" for dtype, prec in d.items(): - assert isinstance(dtype, torch.dtype), f"toleranceOverride given unknown dtype {dtype}" - assert isinstance(prec, tol), "toleranceOverride not given a dtype : tol dict!" + assert isinstance( + dtype, torch.dtype + ), f"toleranceOverride given unknown dtype {dtype}" + assert isinstance( + prec, tol + ), "toleranceOverride not given a dtype : tol dict!" self.d = d @@ -1289,6 +1450,7 @@ def __call__(self, fn): fn.tolerance_overrides = self.d return fn + # Decorator that instantiates a variant of the test for each given dtype. # Notes: # (1) Tests that accept the dtype argument MUST use this decorator. @@ -1300,23 +1462,27 @@ def __call__(self, fn): # @dtypes(torch.float32, torch.float64) # @dtypes((torch.long, torch.float32), (torch.int, torch.float64)) class dtypes: - def __init__(self, *args, device_type="all"): if len(args) > 0 and isinstance(args[0], (list, tuple)): for arg in args: - assert isinstance(arg, (list, tuple)), \ - "When one dtype variant is a tuple or list, " \ - "all dtype variants must be. " \ + assert isinstance(arg, (list, tuple)), ( + "When one dtype variant is a tuple or list, " + "all dtype variants must be. " f"Received non-list non-tuple dtype {str(arg)}" - assert all(isinstance(dtype, torch.dtype) for dtype in arg), f"Unknown dtype in {str(arg)}" + ) + assert all( + isinstance(dtype, torch.dtype) for dtype in arg + ), f"Unknown dtype in {str(arg)}" else: - assert all(isinstance(arg, torch.dtype) for arg in args), f"Unknown dtype in {str(args)}" + assert all( + isinstance(arg, torch.dtype) for arg in args + ), f"Unknown dtype in {str(args)}" self.args = args self.device_type = device_type def __call__(self, fn): - d = getattr(fn, 'dtypes', {}) + d = getattr(fn, "dtypes", {}) assert self.device_type not in d, f"dtypes redefinition for {self.device_type}" d[self.device_type] = self.args fn.dtypes = d @@ -1325,44 +1491,45 @@ def __call__(self, fn): # Overrides specified dtypes on the CPU. class dtypesIfCPU(dtypes): - def __init__(self, *args): - super().__init__(*args, device_type='cpu') + super().__init__(*args, device_type="cpu") # Overrides specified dtypes on CUDA. class dtypesIfCUDA(dtypes): - def __init__(self, *args): - super().__init__(*args, device_type='cuda') + super().__init__(*args, device_type="cuda") -class dtypesIfMPS(dtypes): +class dtypesIfMPS(dtypes): def __init__(self, *args): - super().__init__(*args, device_type='mps') + super().__init__(*args, device_type="mps") -class dtypesIfPRIVATEUSE1(dtypes): +class dtypesIfPRIVATEUSE1(dtypes): def __init__(self, *args): super().__init__(*args, device_type=torch._C._get_privateuse1_backend_name()) + def onlyCPU(fn): - return onlyOn('cpu')(fn) + return onlyOn("cpu")(fn) def onlyCUDA(fn): - return onlyOn('cuda')(fn) + return onlyOn("cuda")(fn) def onlyMPS(fn): - return onlyOn('mps')(fn) + return onlyOn("mps")(fn) def onlyXPU(fn): - return onlyOn('xpu')(fn) + return onlyOn("xpu")(fn) + def onlyHPU(fn): - return onlyOn('hpu')(fn) + return onlyOn("hpu")(fn) + def onlyPRIVATEUSE1(fn): device_type = torch._C._get_privateuse1_backend_name() @@ -1372,10 +1539,11 @@ def onlyPRIVATEUSE1(fn): return unittest.skip(reason)(fn) return onlyOn(device_type)(fn) + def onlyCUDAAndPRIVATEUSE1(fn): @wraps(fn) def only_fn(self, *args, **kwargs): - if self.device_type not in ('cuda', torch._C._get_privateuse1_backend_name()): + if self.device_type not in ("cuda", torch._C._get_privateuse1_backend_name()): reason = f"onlyCUDAAndPRIVATEUSE1: doesn't run on {self.device_type}" raise unittest.SkipTest(reason) @@ -1383,19 +1551,19 @@ def only_fn(self, *args, **kwargs): return only_fn -def disablecuDNN(fn): +def disablecuDNN(fn): @wraps(fn) def disable_cudnn(self, *args, **kwargs): - if self.device_type == 'cuda' and self.has_cudnn(): + if self.device_type == "cuda" and self.has_cudnn(): with torch.backends.cudnn.flags(enabled=False): return fn(self, *args, **kwargs) return fn(self, *args, **kwargs) return disable_cudnn -def disableMkldnn(fn): +def disableMkldnn(fn): @wraps(fn) def disable_mkldnn(self, *args, **kwargs): if torch.backends.mkldnn.is_available(): @@ -1407,23 +1575,28 @@ def disable_mkldnn(self, *args, **kwargs): def expectedFailureCPU(fn): - return expectedFailure('cpu')(fn) + return expectedFailure("cpu")(fn) def expectedFailureCUDA(fn): - return expectedFailure('cuda')(fn) + return expectedFailure("cuda")(fn) + def expectedFailureXPU(fn): - return expectedFailure('xpu')(fn) + return expectedFailure("xpu")(fn) + def expectedFailureMeta(fn): - return skipIfTorchDynamo()(expectedFailure('meta')(fn)) + return skipIfTorchDynamo()(expectedFailure("meta")(fn)) + def expectedFailureXLA(fn): - return expectedFailure('xla')(fn) + return expectedFailure("xla")(fn) + def expectedFailureHPU(fn): - return expectedFailure('hpu')(fn) + return expectedFailure("hpu")(fn) + # Skips a test on CPU if LAPACK is not available. def skipCPUIfNoLapack(fn): @@ -1432,7 +1605,9 @@ def skipCPUIfNoLapack(fn): # Skips a test on CPU if FFT is not available. def skipCPUIfNoFFT(fn): - return skipCPUIf(not torch._C.has_spectral, "PyTorch is built without FFT support")(fn) + return skipCPUIf(not torch._C.has_spectral, "PyTorch is built without FFT support")( + fn + ) # Skips a test on CPU if MKL is not available. @@ -1442,29 +1617,41 @@ def skipCPUIfNoMkl(fn): # Skips a test on CPU if MKL Sparse is not available (it's not linked on Windows). def skipCPUIfNoMklSparse(fn): - return skipCPUIf(IS_WINDOWS or not TEST_MKL, "PyTorch is built without MKL support")(fn) + return skipCPUIf( + IS_WINDOWS or not TEST_MKL, "PyTorch is built without MKL support" + )(fn) # Skips a test on CPU if mkldnn is not available. def skipCPUIfNoMkldnn(fn): - return skipCPUIf(not torch.backends.mkldnn.is_available(), "PyTorch is built without mkldnn support")(fn) + return skipCPUIf( + not torch.backends.mkldnn.is_available(), + "PyTorch is built without mkldnn support", + )(fn) # Skips a test on CUDA if MAGMA is not available. def skipCUDAIfNoMagma(fn): - return skipCUDAIf('no_magma', "no MAGMA library detected")(skipCUDANonDefaultStreamIf(True)(fn)) + return skipCUDAIf("no_magma", "no MAGMA library detected")( + skipCUDANonDefaultStreamIf(True)(fn) + ) + def has_cusolver(): return not TEST_WITH_ROCM + def has_hipsolver(): rocm_version = _get_torch_rocm_version() # hipSOLVER is disabled on ROCM < 5.3 return rocm_version >= (5, 3) + # Skips a test on CUDA/ROCM if cuSOLVER/hipSOLVER is not available def skipCUDAIfNoCusolver(fn): - return skipCUDAIf(not has_cusolver() and not has_hipsolver(), "cuSOLVER not available")(fn) + return skipCUDAIf( + not has_cusolver() and not has_hipsolver(), "cuSOLVER not available" + )(fn) # Skips a test if both cuSOLVER and MAGMA are not available @@ -1475,6 +1662,7 @@ def skipCUDAIfNoMagmaAndNoCusolver(fn): # cuSolver is disabled on cuda < 10.1.243, tests depend on MAGMA return skipCUDAIfNoMagma(fn) + # Skips a test if both cuSOLVER/hipSOLVER and MAGMA are not available def skipCUDAIfNoMagmaAndNoLinalgsolver(fn): if has_cusolver() or has_hipsolver(): @@ -1483,45 +1671,62 @@ def skipCUDAIfNoMagmaAndNoLinalgsolver(fn): # cuSolver is disabled on cuda < 10.1.243, tests depend on MAGMA return skipCUDAIfNoMagma(fn) + # Skips a test on CUDA when using ROCm. def skipCUDAIfRocm(func=None, *, msg="test doesn't currently work on the ROCm stack"): def dec_fn(fn): reason = f"skipCUDAIfRocm: {msg}" return skipCUDAIf(TEST_WITH_ROCM, reason=reason)(fn) + if func: return dec_fn(func) return dec_fn + # Skips a test on CUDA when not using ROCm. def skipCUDAIfNotRocm(fn): - return skipCUDAIf(not TEST_WITH_ROCM, "test doesn't currently work on the CUDA stack")(fn) + return skipCUDAIf( + not TEST_WITH_ROCM, "test doesn't currently work on the CUDA stack" + )(fn) + # Skips a test on CUDA if ROCm is unavailable or its version is lower than requested. def skipCUDAIfRocmVersionLessThan(version=None): - def dec_fn(fn): @wraps(fn) def wrap_fn(self, *args, **kwargs): - if self.device_type == 'cuda': + if self.device_type == "cuda": if not TEST_WITH_ROCM: reason = "ROCm not available" raise unittest.SkipTest(reason) rocm_version_tuple = _get_torch_rocm_version() - if rocm_version_tuple is None or version is None or rocm_version_tuple < tuple(version): - reason = f"ROCm {rocm_version_tuple} is available but {version} required" + if ( + rocm_version_tuple is None + or version is None + or rocm_version_tuple < tuple(version) + ): + reason = ( + f"ROCm {rocm_version_tuple} is available but {version} required" + ) raise unittest.SkipTest(reason) return fn(self, *args, **kwargs) return wrap_fn + return dec_fn + # Skips a test on CUDA when using ROCm. def skipCUDAIfNotMiopenSuggestNHWC(fn): - return skipCUDAIf(not TEST_WITH_MIOPEN_SUGGEST_NHWC, "test doesn't currently work without MIOpen NHWC activation")(fn) + return skipCUDAIf( + not TEST_WITH_MIOPEN_SUGGEST_NHWC, + "test doesn't currently work without MIOpen NHWC activation", + )(fn) + # Skips a test for specified CUDA versions, given in the form of a list of [major, minor]s. -def skipCUDAVersionIn(versions : List[Tuple[int, int]] = None): +def skipCUDAVersionIn(versions: List[Tuple[int, int]] = None): def dec_fn(fn): @wraps(fn) def wrap_fn(self, *args, **kwargs): @@ -1534,10 +1739,12 @@ def wrap_fn(self, *args, **kwargs): return fn(self, *args, **kwargs) return wrap_fn + return dec_fn + # Skips a test for CUDA versions less than specified, given in the form of [major, minor]. -def skipCUDAIfVersionLessThan(versions : Tuple[int, int] = None): +def skipCUDAIfVersionLessThan(versions: Tuple[int, int] = None): def dec_fn(fn): @wraps(fn) def wrap_fn(self, *args, **kwargs): @@ -1550,15 +1757,16 @@ def wrap_fn(self, *args, **kwargs): return fn(self, *args, **kwargs) return wrap_fn + return dec_fn + # Skips a test on CUDA if cuDNN is unavailable or its version is lower than requested. def skipCUDAIfCudnnVersionLessThan(version=0): - def dec_fn(fn): @wraps(fn) def wrap_fn(self, *args, **kwargs): - if self.device_type == 'cuda': + if self.device_type == "cuda": if self.no_cudnn: reason = "cuDNN not available" raise unittest.SkipTest(reason) @@ -1569,46 +1777,69 @@ def wrap_fn(self, *args, **kwargs): return fn(self, *args, **kwargs) return wrap_fn + return dec_fn + # Skips a test on CUDA if cuSparse generic API is not available def skipCUDAIfNoCusparseGeneric(fn): - return skipCUDAIf(not TEST_CUSPARSE_GENERIC, "cuSparse Generic API not available")(fn) + return skipCUDAIf(not TEST_CUSPARSE_GENERIC, "cuSparse Generic API not available")( + fn + ) + def skipCUDAIfNoHipsparseGeneric(fn): - return skipCUDAIf(not TEST_HIPSPARSE_GENERIC, "hipSparse Generic API not available")(fn) + return skipCUDAIf( + not TEST_HIPSPARSE_GENERIC, "hipSparse Generic API not available" + )(fn) + def skipCUDAIfNoSparseGeneric(fn): - return skipCUDAIf(not (TEST_CUSPARSE_GENERIC or TEST_HIPSPARSE_GENERIC), "Sparse Generic API not available")(fn) + return skipCUDAIf( + not (TEST_CUSPARSE_GENERIC or TEST_HIPSPARSE_GENERIC), + "Sparse Generic API not available", + )(fn) + def skipCUDAIfNoCudnn(fn): return skipCUDAIfCudnnVersionLessThan(0)(fn) + def skipCUDAIfMiopen(fn): return skipCUDAIf(torch.version.hip is not None, "Marked as skipped for MIOpen")(fn) + def skipCUDAIfNoMiopen(fn): - return skipCUDAIf(torch.version.hip is None, "MIOpen is not available")(skipCUDAIfNoCudnn(fn)) + return skipCUDAIf(torch.version.hip is None, "MIOpen is not available")( + skipCUDAIfNoCudnn(fn) + ) + def skipLazy(fn): return skipLazyIf(True, "test doesn't work with lazy tensors")(fn) + def skipMeta(fn): return skipMetaIf(True, "test doesn't work with meta tensors")(fn) + def skipXLA(fn): return skipXLAIf(True, "Marked as skipped for XLA")(fn) + def skipMPS(fn): return skipMPSIf(True, "test doesn't work on MPS backend")(fn) + def skipHPU(fn): return skipHPUIf(True, "test doesn't work on HPU backend")(fn) + def skipPRIVATEUSE1(fn): return skipPRIVATEUSE1If(True, "test doesn't work on privateuse1 backend")(fn) + # TODO: the "all" in the name isn't true anymore for quite some time as we have also have for example XLA and MPS now. # This should probably enumerate all available device type test base classes. def get_all_device_types() -> List[str]: - return ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda'] + return ["cpu"] if not torch.cuda.is_available() else ["cpu", "cuda"] diff --git a/torch/testing/_internal/common_dtype.py b/torch/testing/_internal/common_dtype.py index 4c3f8484ed3109..26b44e2b5baae8 100644 --- a/torch/testing/_internal/common_dtype.py +++ b/torch/testing/_internal/common_dtype.py @@ -8,92 +8,138 @@ # Functions and classes for describing the dtypes a function supports # NOTE: these helpers should correspond to PyTorch's C++ dispatch macros + # Verifies each given dtype is a torch.dtype def _validate_dtypes(*dtypes): for dtype in dtypes: assert isinstance(dtype, torch.dtype) return dtypes + # class for tuples corresponding to a PyTorch dispatch macro class _dispatch_dtypes(tuple): def __add__(self, other): assert isinstance(other, tuple) return _dispatch_dtypes(tuple.__add__(self, other)) + _empty_types = _dispatch_dtypes(()) + + def empty_types(): return _empty_types + _floating_types = _dispatch_dtypes((torch.float32, torch.float64)) + + def floating_types(): return _floating_types + _floating_types_and_half = _floating_types + (torch.half,) + + def floating_types_and_half(): return _floating_types_and_half + def floating_types_and(*dtypes): return _floating_types + _validate_dtypes(*dtypes) + _floating_and_complex_types = _floating_types + (torch.cfloat, torch.cdouble) + + def floating_and_complex_types(): return _floating_and_complex_types + def floating_and_complex_types_and(*dtypes): return _floating_and_complex_types + _validate_dtypes(*dtypes) + _double_types = _dispatch_dtypes((torch.float64, torch.complex128)) + + def double_types(): return _double_types + # NB: Does not contain uint16/uint32/uint64 for BC reasons -_integral_types = _dispatch_dtypes((torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64)) +_integral_types = _dispatch_dtypes( + (torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64) +) + + def integral_types(): return _integral_types + def integral_types_and(*dtypes): return _integral_types + _validate_dtypes(*dtypes) + _all_types = _floating_types + _integral_types + + def all_types(): return _all_types + def all_types_and(*dtypes): return _all_types + _validate_dtypes(*dtypes) + _complex_types = _dispatch_dtypes((torch.cfloat, torch.cdouble)) + + def complex_types(): return _complex_types + def complex_types_and(*dtypes): return _complex_types + _validate_dtypes(*dtypes) + _all_types_and_complex = _all_types + _complex_types + + def all_types_and_complex(): return _all_types_and_complex + def all_types_and_complex_and(*dtypes): return _all_types_and_complex + _validate_dtypes(*dtypes) + _all_types_and_half = _all_types + (torch.half,) + + def all_types_and_half(): return _all_types_and_half + def custom_types(*dtypes): """Create a list of arbitrary dtypes""" return _empty_types + _validate_dtypes(*dtypes) + # The functions below are used for convenience in our test suite and thus have no corresponding C++ dispatch macro + # See AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS. -def get_all_dtypes(include_half=True, - include_bfloat16=True, - include_bool=True, - include_complex=True, - include_complex32=False, - include_qint=False, - ) -> List[torch.dtype]: - dtypes = get_all_int_dtypes() + get_all_fp_dtypes(include_half=include_half, include_bfloat16=include_bfloat16) +def get_all_dtypes( + include_half=True, + include_bfloat16=True, + include_bool=True, + include_complex=True, + include_complex32=False, + include_qint=False, +) -> List[torch.dtype]: + dtypes = get_all_int_dtypes() + get_all_fp_dtypes( + include_half=include_half, include_bfloat16=include_bfloat16 + ) if include_bool: dtypes.append(torch.bool) if include_complex: @@ -102,12 +148,23 @@ def get_all_dtypes(include_half=True, dtypes += get_all_qint_dtypes() return dtypes + def get_all_math_dtypes(device) -> List[torch.dtype]: - return get_all_int_dtypes() + get_all_fp_dtypes(include_half=device.startswith('cuda'), - include_bfloat16=False) + get_all_complex_dtypes() + return ( + get_all_int_dtypes() + + get_all_fp_dtypes( + include_half=device.startswith("cuda"), include_bfloat16=False + ) + + get_all_complex_dtypes() + ) + def get_all_complex_dtypes(include_complex32=False) -> List[torch.dtype]: - return [torch.complex32, torch.complex64, torch.complex128] if include_complex32 else [torch.complex64, torch.complex128] + return ( + [torch.complex32, torch.complex64, torch.complex128] + if include_complex32 + else [torch.complex64, torch.complex128] + ) def get_all_int_dtypes() -> List[torch.dtype]: From ba19ed9a1a4efc7fa9ccdd6ace6a85fd1f43b49f Mon Sep 17 00:00:00 2001 From: Brian Hirsh Date: Thu, 13 Jun 2024 14:52:04 -0700 Subject: [PATCH 037/171] FunctionalTensor: dispatch metadata directly to inner tensor (#127927) Fixes https://github.com/pytorch/pytorch/issues/127374 The error in the linked repro is: ``` AssertionError: Please convert all Tensors to FakeTensors first or instantiate FakeTensorMode with 'allow_non_fake_inputs'. Found in aten.sym_storage_offset.default(_to_functional_tensor(FakeTensor(..., device='cuda:0', size=(16, 4), dtype=torch.uint8), device='cuda:0')) ``` Where we hit FakeTensor.__torch_dispatch__, but our input is a C++ `FunctionalTensorWrapper`. What should actually have happened is that the call to `aten.sym_storage_offset` hits the `Functionalize` dispatch key, which should remove the `FunctionalTensorWrapper` and redispatch. I spent some time debugging and haven't actually figured out why this isn't happening. Instead, this PR just skips that step completely, and asks `FunctionalTensor` to directly unwrap the C++ `FunctionalTensorWrapper` when querying tensor metadata. Pull Request resolved: https://github.com/pytorch/pytorch/pull/127927 Approved by: https://github.com/tugsbayasgalan --- torch/_subclasses/functional_tensor.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torch/_subclasses/functional_tensor.py b/torch/_subclasses/functional_tensor.py index 4040774fe225f7..193c827a497bbe 100644 --- a/torch/_subclasses/functional_tensor.py +++ b/torch/_subclasses/functional_tensor.py @@ -174,10 +174,10 @@ def __torch_dispatch__(self, func, types, args=(), kwargs=None): torch.ops.aten.is_contiguous.memory_format, ]: assert len(args) == 2 and isinstance(args[0], FunctionalTensor) - return func(args[0].elem, args[1]) + return func(torch._from_functional_tensor(args[0].elem), args[1]) assert len(args) == 1 and isinstance(args[0], FunctionalTensor) - return func(args[0].elem) + return func(torch._from_functional_tensor(args[0].elem)) # Originally I tried to implement my subclass without giving it a torch_dispatch, but I gave up: # - _make_wrapper_subclass requires a __torch_dispatch__ # - If we want to use _make_subclass(), we have a problem: the subclass will share a TensorImpl with the inner tensor, From 271852aa7e400da51bcba8defb02e40596b3513e Mon Sep 17 00:00:00 2001 From: Brian Hirsh Date: Thu, 13 Jun 2024 14:52:05 -0700 Subject: [PATCH 038/171] inductor: pre-grad bmm pass shouldn't match if output is mutated (#128570) This PR is enough to get this test to pass when using `TORCHDYNAMO_INLINE_INBUILT_NN_MODULES`: ``` TORCHDYNAMO_INLINE_INBUILT_NN_MODULES=1 python test/inductor/test_group_batch_fusion.py -k TestPostGradBatchLinearFusion.test_batch_linear_post_grad_fusion ``` inductor has a pre-grad pass to swap out multiple `linear` layers with with `addbmm`, but it also needs to insert an `unbind()` at the end. If that unbind is then followed by a mutation (like `add_()`), the autograd engine will complain (autograd does not let you mutate the output of multiple-out-view ops like unbind). I made a tweak to the pattern matching logic to avoid matching if the output of the linear is used in an op that mutates its input. My hope is that: (1) this situation is rare enough that it won't materially impact pattern matching in real world code (2) I had to use a heuristic for "is an op a mutable op", since the graph we get is from dynamo, so it can contain code like `operator.iadd` in it. Pull Request resolved: https://github.com/pytorch/pytorch/pull/128570 Approved by: https://github.com/eellison, https://github.com/mlazos ghstack dependencies: #127927 --- .../_inductor/fx_passes/group_batch_fusion.py | 20 +++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/torch/_inductor/fx_passes/group_batch_fusion.py b/torch/_inductor/fx_passes/group_batch_fusion.py index 9a9d4cd136daec..35f1b52990fd98 100644 --- a/torch/_inductor/fx_passes/group_batch_fusion.py +++ b/torch/_inductor/fx_passes/group_batch_fusion.py @@ -569,6 +569,22 @@ def is_node_meta_valid(node: Optional[torch.fx.Node]): return True +# Poor person's check for if a node in the graph mutates its input. +# (the graph is torch IR, so we will see torch fns and python operators) +def _is_mutable_node(tgt): + if str(tgt).endswith("_"): + # e.g. torch.mul_, torch.Tensor.mul_ + return True + if ( + hasattr(tgt, "__module__") + and tgt.__module__ == "_operator" + and tgt.__name__.startswith("i") + ): + # e.g. operator.iand, operator.imul + return True + return False + + def is_linear_node_can_be_fused(node: torch.fx.Node): input = get_arg_value(node, 0, "input") weight = get_arg_value(node, 1, "weight") @@ -578,6 +594,10 @@ def is_linear_node_can_be_fused(node: torch.fx.Node): and is_node_meta_valid(weight) and len(input.meta["example_value"].shape) == 2 and len(weight.meta["example_value"].shape) == 2 + # the mm -> bmm transform adds an unbind() op, + # which is not safe for autograd when the output of the mm is mutated. + # don't pattern match if any users of the mm mutate the input. + and not any(_is_mutable_node(user.target) for user in node.users) ) From d67923b955c9ca9a06ec54c344e0b960a8856b1c Mon Sep 17 00:00:00 2001 From: Sanket Jayant Purandare Date: Fri, 14 Jun 2024 11:24:21 -0700 Subject: [PATCH 039/171] Adding kwargs to composable AC API to enable full capabilities (#128516) Summary: Firstly, this does not change any existing behaviour, since all the default values for kwargs were hardcoded into the ``_checkpoint_without_reentrant_generator`` call. Secondly, this is needed for unlocking the full potential of composable checkpointing making it equivalent to ``torch.utils.checkpoint.checkpoint(use_reentrant=False)``. Finally, an added benefit is now composable checkpointing can be used under ``FakeTensorMode`` by passing ``preserve_rng_state=False``. Pull Request resolved: https://github.com/pytorch/pytorch/pull/128516 Approved by: https://github.com/awgu --- .../_composable/test_checkpoint.py | 116 +++++++++++++++++- .../_composable/checkpoint_activation.py | 56 +++++++-- 2 files changed, 157 insertions(+), 15 deletions(-) diff --git a/test/distributed/_composable/test_checkpoint.py b/test/distributed/_composable/test_checkpoint.py index d4c53e167b6504..bfd830818a9792 100644 --- a/test/distributed/_composable/test_checkpoint.py +++ b/test/distributed/_composable/test_checkpoint.py @@ -2,8 +2,9 @@ import unittest from collections import deque, OrderedDict -from contextlib import ContextDecorator +from contextlib import ContextDecorator, contextmanager, nullcontext from copy import deepcopy +from functools import partial from typing import Tuple import torch @@ -11,6 +12,7 @@ from torch.distributed._composable import checkpoint from torch.testing._internal.common_cuda import TEST_CUDA from torch.testing._internal.common_utils import run_tests, TestCase +from torch.utils.checkpoint import CheckpointError class MemoryDelta(ContextDecorator): @@ -68,7 +70,7 @@ def __init__(self, device: torch.device): self.w1 = nn.Parameter(torch.randn((100, 100), device=device)) self.w2 = nn.Parameter(torch.randn((100, 100), device=device)) - def forward(self, x: torch.Tensor) -> torch.Tensor: + def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: z = x @ self.w1 z = nn.functional.relu(z) z = z @ self.w2 @@ -219,6 +221,116 @@ def forward(self, x): self.assertEqual(None, checkpoint.state(m)._ac_generator) + def test_checkpoint_kwargs(self): + class MyModel(torch.nn.Module): + def __init__(self, raise_exp: bool, change_shape_in_recomp: bool): + super().__init__() + self.fwd_count = 0 + self.raise_exp = raise_exp + self.change_shape_in_recomp = change_shape_in_recomp + self.a = torch.nn.Linear(2, 2) + + def forward(self, x): + if self.raise_exp and self.fwd_count == 0: + raise RuntimeError("foo") + if self.raise_exp and self.fwd_count == 1: + raise RuntimeError("bar") + if self.change_shape_in_recomp and self.fwd_count == 1: + x.relu_() + random_tensor = torch.randn(1, 2) + x = self.a(x + random_tensor) + self.fwd_count += 1 + return x + + m = MyModel(True, False) + m0, m1, m2, m3 = (deepcopy(m) for _ in range(4)) + + # composable checkpoint does not support use_reentrant=True + with self.assertRaisesRegex( + NotImplementedError, + "use_reentrant=True is not supported in composable checkpoint. " + "Please use torch.utils.checkpoint.checkpoint instead.", + ): + checkpoint(m, use_reentrant=True) + + # check giving an unsupported kwarg + with self.assertRaisesRegex(ValueError, "Unexpected keyword arguments: foo"): + checkpoint(m0, foo="bar") + + handled_fwd_exp = False + handled_recomp_exp = False + + @contextmanager + def fwd_ctx(mod: MyModel): + try: + mod.raise_exp = False + yield + finally: + nonlocal handled_fwd_exp + handled_fwd_exp = True + mod.raise_exp = True + + @contextmanager + def recomp_ctx(mod: MyModel): + try: + mod.raise_exp = False + yield + finally: + nonlocal handled_recomp_exp + handled_recomp_exp = True + mod.raise_exp = True + + # Test different context functions + x = torch.randn(1, 2, requires_grad=True) + checkpoint( + m1, context_fn=lambda: (partial(fwd_ctx, m1)(), partial(recomp_ctx, m1)()) + ) + m1(x.clone()).sum().backward() + self.assertEqual((handled_fwd_exp, handled_recomp_exp), (True, True)) + + checkpoint(m2, context_fn=lambda: (nullcontext(), partial(recomp_ctx, m2)())) + with self.assertRaisesRegex(RuntimeError, "foo"): + m2(x.clone()) + + handled_fwd_exp = False # Reset flag + checkpoint(m3, context_fn=lambda: (partial(fwd_ctx, m3)(), nullcontext())) + with self.assertRaisesRegex(RuntimeError, "bar"): + m3(x.clone()).sum().backward() + self.assertEqual(handled_fwd_exp, True) + + # Test determinism check failure + m4 = MyModel(False, True) + m5 = deepcopy(m4) + # Determinism check should not throw an error, + # but autograd should throw a RuntimeError + checkpoint(m4, determinism_check="none") + with self.assertRaises(RuntimeError): + m4(x.clone()).sum().backward() + + # Determinism check should throw a CheckpointError + checkpoint(m5, determinism_check="default") + with self.assertRaises(CheckpointError): + m5(x.clone()).sum().backward() + + # Test preserving random state + m6 = MyModel(False, False) + m7, m8 = (deepcopy(m6) for _ in range(2)) + checkpoint(m7, preserve_rng_state=False) + checkpoint(m8, preserve_rng_state=True) + + for mi in (m6, m7, m8): + torch.manual_seed(42) + loss = mi(x.clone()).sum() + torch.manual_seed(41) + loss.backward() + # check that m6 and m7 have at least one different grad + self.assertNotEqual( + (p1.grad for p1 in m6.parameters()), (p2.grad for p2 in m7.parameters()) + ) + # check that m6 and m8 have identical grads + for p1, p2 in zip(m6.parameters(), m8.parameters()): + self.assertEqual(p1.grad, p2.grad) + if __name__ == "__main__": run_tests() diff --git a/torch/distributed/_composable/checkpoint_activation.py b/torch/distributed/_composable/checkpoint_activation.py index 6716f43a74a066..f1464d9e80459f 100644 --- a/torch/distributed/_composable/checkpoint_activation.py +++ b/torch/distributed/_composable/checkpoint_activation.py @@ -1,6 +1,6 @@ # mypy: allow-untyped-defs from contextlib import contextmanager, nullcontext -from typing import Any, Tuple +from typing import Any, ContextManager, Dict, Optional, Tuple import torch import torch.nn as nn @@ -13,21 +13,23 @@ @contextmanager -def _no_hook(module: nn.Module): +def _no_hook(module: nn.Module, user_ctx: Optional[ContextManager] = None): r""" Disable hooks installed by checkpoint to avoid unintentional recursion during backward recomputation. """ - orig_enable_hook = checkpoint.state(module).enable_hook - checkpoint.state(module).enable_hook = False - try: - yield - finally: - checkpoint.state(module).enable_hook = orig_enable_hook + + with user_ctx if user_ctx else nullcontext(): + orig_enable_hook = checkpoint.state(module).enable_hook + checkpoint.state(module).enable_hook = False + try: + yield + finally: + checkpoint.state(module).enable_hook = orig_enable_hook @contract() -def checkpoint(module: nn.Module) -> nn.Module: +def checkpoint(module: nn.Module, **kwargs) -> nn.Module: r""" This is a composable activation checkpointing API. Unlike functional activation checkpointing APIs, this one does not require changing model @@ -61,16 +63,44 @@ def checkpoint(module: nn.Module) -> nn.Module: """ torch._C._log_api_usage_once("torch.distributed.checkpoint") - def forward_pre_hook(module: nn.Module, inputs: Tuple[Any, ...]) -> None: + use_reentrant = kwargs.pop("use_reentrant", False) + if use_reentrant: + raise NotImplementedError( + "use_reentrant=True is not supported in composable checkpoint. " + "Please use torch.utils.checkpoint.checkpoint instead." + ) + preserve_rng_state = kwargs.pop("preserve_rng_state", True) + user_context_fns = kwargs.pop("context_fn", None) + determinism_check = kwargs.pop("determinism_check", _DEFAULT_DETERMINISM_MODE) + debug = kwargs.pop("debug", False) + + if kwargs: + raise ValueError( + "Unexpected keyword arguments: " + ",".join(arg for arg in kwargs) + ) + + def forward_pre_hook( + module: nn.Module, args: Tuple[Any, ...], kwargs: Dict[str, Any] + ) -> None: if checkpoint.state(module).enable_hook: def context_fns(): - return nullcontext(), _no_hook(module) + if user_context_fns is not None: + ctx1, ctx2 = user_context_fns() + return ctx1, _no_hook(module, ctx2) + else: + return nullcontext(), _no_hook(module) checkpoint.state( module )._ac_generator = _checkpoint_without_reentrant_generator( - module, True, context_fns, _DEFAULT_DETERMINISM_MODE, False, *inputs + module, + preserve_rng_state, + context_fns, + determinism_check, + debug, + *args, + **kwargs, ) next(checkpoint.state(module)._ac_generator) @@ -90,6 +120,6 @@ def forward_hook(module: nn.Module, inputs: Tuple[Any, ...], output: Any) -> Any checkpoint.state(module)._ac_generator = None checkpoint.state(module).enable_hook = True - module.register_forward_pre_hook(forward_pre_hook) + module.register_forward_pre_hook(forward_pre_hook, with_kwargs=True) module.register_forward_hook(forward_hook, prepend=True, always_call=True) return module From 5d9a609b4f6c94fb930188e4d7c99f53d989c022 Mon Sep 17 00:00:00 2001 From: angelayi Date: Sat, 15 Jun 2024 00:26:02 +0000 Subject: [PATCH 040/171] [export] Add print_readable to unflattener (#128617) Taking inspiration from `GraphModule.print_readable` (aka I copied its [code](https://github.com/pytorch/pytorch/blob/17b45e905a82908c7dd93cdefa3f460561f945d3/torch/fx/graph_module.py#L824)), I added a `print_readable` to the unflattened module, because it's kind of nontrivial to print the contents of this module. Example print from `python test/export/test_unflatten.py -k test_unflatten_nested` ``` class UnflattenedModule(torch.nn.Module): def forward(self, x: "f32[2, 3]"): # No stacktrace found for following nodes rootparam: "f32[2, 3]" = self.rootparam # File: /data/users/angelayi/pytorch2/test/export/test_unflatten.py:99 in forward, code: x = x * self.rootparam mul: "f32[2, 3]" = torch.ops.aten.mul.Tensor(x, rootparam); x = rootparam = None # No stacktrace found for following nodes foo: "f32[2, 3]" = self.foo(mul); mul = None bar: "f32[2, 3]" = self.bar(foo); foo = None return (bar,) class foo(torch.nn.Module): def forward(self, mul: "f32[2, 3]"): # No stacktrace found for following nodes child1param: "f32[2, 3]" = self.child1param nested: "f32[2, 3]" = self.nested(mul); mul = None # File: /data/users/angelayi/pytorch2/test/export/test_unflatten.py:79 in forward, code: return x + self.child1param add: "f32[2, 3]" = torch.ops.aten.add.Tensor(nested, child1param); nested = child1param = None return add class nested(torch.nn.Module): def forward(self, mul: "f32[2, 3]"): # File: /data/users/angelayi/pytorch2/test/export/test_unflatten.py:67 in forward, code: return x / x div: "f32[2, 3]" = torch.ops.aten.div.Tensor(mul, mul); mul = None return div class bar(torch.nn.Module): def forward(self, add: "f32[2, 3]"): # No stacktrace found for following nodes child2buffer: "f32[2, 3]" = self.child2buffer # File: /data/users/angelayi/pytorch2/test/export/test_unflatten.py:87 in forward, code: return x - self.child2buffer sub: "f32[2, 3]" = torch.ops.aten.sub.Tensor(add, child2buffer); add = child2buffer = None return sub ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/128617 Approved by: https://github.com/zhxchen17, https://github.com/pianpwk --- test/dynamo/test_autograd_function.py | 4 +- test/dynamo/test_higher_order_ops.py | 12 ++--- test/dynamo/test_subclasses.py | 6 +-- torch/export/unflatten.py | 33 +++++++++++++ torch/fx/graph_module.py | 69 ++++++++++++++++++++------- 5 files changed, 95 insertions(+), 29 deletions(-) diff --git a/test/dynamo/test_autograd_function.py b/test/dynamo/test_autograd_function.py index 814844d449ccbb..330d15a2475d63 100644 --- a/test/dynamo/test_autograd_function.py +++ b/test/dynamo/test_autograd_function.py @@ -525,14 +525,14 @@ def forward(self, L_x_: "f32[]", L_z_: "f32[]", L_weird_b: "f32[]", L_weird_c: " autograd_function_apply: "f32[]" = torch._functorch.autograd_function.autograd_function_apply(fwd_body_0, bwd_body_0, l_x_, l_z_, l_weird_b, l_weird_c, args_tensor_mask = [True, False, True]); fwd_body_0 = bwd_body_0 = l_x_ = l_z_ = l_weird_b = l_weird_c = None return (autograd_function_apply,) - class GraphModule(torch.nn.Module): + class fwd_body_0(torch.nn.Module): def forward(self, function_ctx, l_x_: "f32[]", l_z_: "f32[]", l_weird_b: "f32[]", l_weird_c: "f32[]"): mul: "f32[]" = l_weird_b * l_weird_c clone: "f32[]" = l_x_.clone(); l_x_ = None mul_1: "f32[]" = mul * clone; mul = clone = None return (mul_1, [l_weird_b, l_weird_c]) - class GraphModule(torch.nn.Module): + class bwd_body_0(torch.nn.Module): def forward(self, function_ctx, mul_1: "f32[]", l_weird_b: "f32[]", l_weird_c: "f32[]"): _set_grad_enabled = torch._C._set_grad_enabled(False) diff --git a/test/dynamo/test_higher_order_ops.py b/test/dynamo/test_higher_order_ops.py index dca6d28d1912dd..d38611197bdbb3 100644 --- a/test/dynamo/test_higher_order_ops.py +++ b/test/dynamo/test_higher_order_ops.py @@ -334,7 +334,7 @@ def forward(self, L_d_x_: "f32[]", L_d_y_0_: "f32[]", L_d_y_1_2_: "f32[]"): getitem: "f32[]" = wrap[0]; wrap = None return (getitem,) - class GraphModule(torch.nn.Module): + class wrap_body_0(torch.nn.Module): def forward(self, l_d_x_: "f32[]", l_d_y_0_: "f32[]", l_d_y_1_2_: "f32[]"): sin: "f32[]" = l_d_x_.sin(); l_d_x_ = None cos: "f32[]" = l_d_y_0_.cos(); l_d_y_0_ = None @@ -372,7 +372,7 @@ def forward(self, L_x_: "f32[3, 1]"): getitem: "f32[3]" = wrap[0]; wrap = None return (getitem,) - class GraphModule(torch.nn.Module): + class wrap_body_0(torch.nn.Module): def forward(self, l_x_: "f32[3, 1]"): view: "f32[3]" = l_x_.view(3); l_x_ = None add: "f32[3]" = view + 0.5; view = None @@ -394,7 +394,7 @@ def forward(self, s0: "Sym(s0)", L_x_: "f32[s0, 1]"): getitem: "f32[s0]" = wrap[0]; wrap = None return (getitem,) - class GraphModule(torch.nn.Module): + class wrap_body_0(torch.nn.Module): def forward(self, l_x_: "f32[s0, 1]", size: "Sym(s0)"): view: "f32[s0]" = l_x_.view(size); l_x_ = size = None add: "f32[s0]" = view + 0.5; view = None @@ -1791,7 +1791,7 @@ def forward(self, L_arg1_0_: "f32[3]", L_arg2_0_: "f32[3]"): getitem_1: "f32[3]" = wrap[1]; wrap = None return (getitem, getitem_1) - class GraphModule(torch.nn.Module): + class wrap_body_0(torch.nn.Module): def forward(self, l_arg1_0_: "f32[3]", l_arg2_0_: "f32[3]"): add: "f32[3]" = l_arg1_0_ + 1; l_arg1_0_ = None @@ -1990,7 +1990,7 @@ def forward(self, L_x_: "f32[2, 3]"): add: "f32[2, 3]" = a + b; a = b = None return (add,) - class GraphModule(torch.nn.Module): + class wrap_body_0(torch.nn.Module): def forward(self, l_x_: "f32[2, 3]"): sin: "f32[2, 3]" = l_x_.sin() cos: "f32[2, 3]" = l_x_.cos(); l_x_ = None @@ -2025,7 +2025,7 @@ def forward(self, L_x_: "f32[3]"): getitem: "f32[3]" = wrap[0]; wrap = None return (getitem,) - class GraphModule(torch.nn.Module): + class wrap_body_0(torch.nn.Module): def forward(self, l_x_: "f32[3]"): neg: "f32[3]" = -l_x_; l_x_ = None return (neg,) diff --git a/test/dynamo/test_subclasses.py b/test/dynamo/test_subclasses.py index 954859f50994f4..1c0be0f3adf5cc 100644 --- a/test/dynamo/test_subclasses.py +++ b/test/dynamo/test_subclasses.py @@ -927,7 +927,7 @@ def forward(self, L_x_: "f32[3, 4]"): getitem: "f32[3, 4]" = wrap[0]; wrap = None return (getitem,) - class GraphModule(torch.nn.Module): + class wrap_body_0(torch.nn.Module): def forward(self, l_x_: "f32[3, 4]"): add_: "f32[3, 4]" = l_x_.add_(1.0); l_x_ = None return (add_,) @@ -951,7 +951,7 @@ def forward(self, L_x_ : torch.Tensor): getitem = wrap[0]; wrap = None return (getitem,) - class GraphModule(torch.nn.Module): + class wrap_body_0(torch.nn.Module): def forward(self, l_x_): add_ = l_x_.add_(1.0); l_x_ = None return (add_,) @@ -981,7 +981,7 @@ def forward(self, L_x_ : torch.Tensor): getitem = wrap[0]; wrap = None return (getitem,) - class GraphModule(torch.nn.Module): + class wrap_body_0(torch.nn.Module): def forward(self, l_x_): add_ = l_x_.add_(1.0); l_x_ = None return (add_,) diff --git a/torch/export/unflatten.py b/torch/export/unflatten.py index 7b6ed6f1b5a974..58bd2607602882 100644 --- a/torch/export/unflatten.py +++ b/torch/export/unflatten.py @@ -21,6 +21,7 @@ TensorArgument, ) from torch.fx._symbolic_trace import is_fx_tracing +from torch.fx.graph_module import _print_readable from torch.utils._pytree import GetAttrKey, SequenceKey from ._remove_effect_tokens_pass import _remove_effect_tokens @@ -133,6 +134,22 @@ def finalize(self): if node.op == "placeholder": self.arg_names.append(node.target) + def print_readable( + self, + print_output=True, + include_stride=False, + include_device=False, + colored=False, + ): + return _print_readable( + self, + "InterpreterModule", + print_output, + include_stride, + include_device, + colored, + ) + class FlatArgsAdapter(abc.ABC): """ @@ -465,6 +482,22 @@ def forward(self, *args, **kwargs): ) return pytree.tree_unflatten(tree_out, signature.out_spec) + def print_readable( + self, + print_output=True, + include_stride=False, + include_device=False, + colored=False, + ): + return _print_readable( + self, + "UnflattenedModule", + print_output, + include_stride, + include_device, + colored, + ) + def unflatten( module: ExportedProgram, flat_args_adapter: Optional[FlatArgsAdapter] = None diff --git a/torch/fx/graph_module.py b/torch/fx/graph_module.py index 995f5cbc2a34ee..cc3bd3b2311d4a 100644 --- a/torch/fx/graph_module.py +++ b/torch/fx/graph_module.py @@ -257,6 +257,50 @@ def _assign_attr(from_obj: Any, to_module: torch.nn.Module, target: str): setattr(to_module, field, from_obj) +def _print_readable( + module, + module_name, + print_output=True, + include_stride=False, + include_device=False, + colored=False, +): + graph = module.graph + assert graph is not None and isinstance(graph, torch.fx.Graph), "print_readable must be used on a module with a graph" + + verbose_python_code = graph.python_code( + root_module="self", + verbose=True, + include_stride=include_stride, + include_device=include_device, + colored=colored, + ) + module_code = verbose_python_code.src + module_code = module_code.lstrip("\n") + module_code = f"class {module_name}(torch.nn.Module):\n" + module_code + module_code = _addindent(module_code, 4) + + submodule_code_list = [""] + for submodule_name, submodule in module.named_children(): + if hasattr(submodule, "graph"): + submodule_code_list.append( + _print_readable( + submodule, + submodule_name, + print_output=False, + include_stride=include_stride, + include_device=include_device, + ) + ) + submodule_code = "\n".join(submodule_code_list) + submodule_code = _addindent(submodule_code, 4) + + output = module_code + submodule_code + if print_output: + print(module_code + submodule_code) + return output + + class _WrappedCall: def __init__(self, cls, cls_call): self.cls = cls @@ -825,25 +869,14 @@ def print_readable(self, print_output=True, include_stride=False, include_device """ Return the Python code generated for current GraphModule and its children GraphModules """ - verbose_python_code = self._graph.python_code( - root_module="self", verbose=True, include_stride=include_stride, include_device=include_device, colored=colored + return _print_readable( + self, + self._get_name(), + print_output, + include_stride, + include_device, + colored, ) - module_code = verbose_python_code.src - module_code = module_code.lstrip("\n") - module_code = f"class {self._get_name()}(torch.nn.Module):\n" + module_code - module_code = _addindent(module_code, 4) - - submodule_code_list = [""] - for submodule in self.children(): - if isinstance(submodule, GraphModule): - submodule_code_list.append(submodule.print_readable(print_output=False)) - submodule_code = "\n".join(submodule_code_list) - submodule_code = _addindent(submodule_code, 4) - - output = module_code + submodule_code - if print_output: - print(module_code + submodule_code) - return output def __str__(self) -> str: orig_str = super().__str__() From 6616ad030f9ee47cfb3796bf4af1277d5b845ad5 Mon Sep 17 00:00:00 2001 From: leslie-fang-intel Date: Thu, 13 Jun 2024 22:05:18 -0700 Subject: [PATCH 041/171] [Inductor] Fix the High Order Op layout issue (#128275) Fix the issue: https://github.com/pytorch/pytorch/issues/127995 - In current implementation of creating `FallbackKernel`, the `device` of the `NoneLayout` is set to `None` when `example_output` returns from `cls.process_kernel` is `None`. https://github.com/pytorch/pytorch/blob/921aa194c77f5279b15415eaa213813ddcdb3b29/torch/_inductor/ir.py#L5632-L5649 - If a `ExternalKernel schedulerNode` has None device, the previous buffer will not flush before codegen this `ExternalKernel schedulerNode` which causes the wrong generated code. https://github.com/pytorch/pytorch/blob/ef2b5ed500cba0b8b2bf04e6006a0d64c910f440/torch/_inductor/scheduler.py#L2701-L2709 **Test Plan** ``` python -u -m pytest -s -v test/higher_order_ops/test_with_effects.py -k test_compile_inductor_external_op_return_none ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/128275 Approved by: https://github.com/eellison --- test/higher_order_ops/test_with_effects.py | 27 ++++++++++++++++++++++ torch/_inductor/ir.py | 5 ++-- 2 files changed, 29 insertions(+), 3 deletions(-) diff --git a/test/higher_order_ops/test_with_effects.py b/test/higher_order_ops/test_with_effects.py index cd8b80e8e88677..9e453ea2d7447d 100644 --- a/test/higher_order_ops/test_with_effects.py +++ b/test/higher_order_ops/test_with_effects.py @@ -198,6 +198,33 @@ def f(x): res = torch.compile(f, backend="inductor")(*inputs) self.assertTrue(torch.allclose(res, f(*inputs))) + @unittest.skipIf(IS_WINDOWS, "Skipped on Windows!") + @skipIfNoDynamoSupport + def test_compile_inductor_external_op_return_none(self): + with torch.library._scoped_library("mylib", "FRAGMENT") as lib: + torch.library.define( + "mylib::inplace_add", + "(Tensor input, Tensor(a!) output) -> ()", + lib=lib, + ) + + def inplace_add(input: torch.Tensor, output: torch.Tensor) -> None: + assert input.device == output.device + output.add_(input) + + lib.impl("inplace_add", inplace_add, "CompositeExplicitAutograd") + + def f(x): + out = torch.empty(3) + out = torch.zeros_like(out) + torch.ops.mylib.inplace_add(x, out) + return out + + inputs = (torch.randn(3),) + + res = torch.compile(f, backend="inductor")(*inputs) + self.assertTrue(torch.allclose(res, f(*inputs))) + def test_compile_aot_eager_requires_grad(self): def f(x): torch.ops.aten._print("moo") diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py index 8609802309ea48..9e1c90e9953783 100644 --- a/torch/_inductor/ir.py +++ b/torch/_inductor/ir.py @@ -5648,9 +5648,10 @@ def create(cls, kernel, *args, **kwargs): unbacked_bindings, ) = cls.process_kernel(kernel, *args, **kwargs) + device = cls.find_device(tensor_args, example_output) if example_output is None: packed = cls( - NoneLayout(None), + NoneLayout(device), kernel, tensor_args, non_tensor_args, @@ -5659,9 +5660,7 @@ def create(cls, kernel, *args, **kwargs): ) else: - device = cls.find_device(tensor_args, example_output) assert device, "Not sure where to find device info" - packed = cls( MultiOutputLayout(device), kernel, From 73ba432d32a766e6d4e1e92175ebba91cb45fdf3 Mon Sep 17 00:00:00 2001 From: Yueming Hao Date: Sat, 15 Jun 2024 00:41:34 +0000 Subject: [PATCH 042/171] [custom_op]Fix None return schema (#128667) Fixes #125044 If users define a schema returns `None`, it will be parsed to a `torch.NoneType`. Auto functionalization support the `()` as a empty return but not for `None`. So, `None` return fails the check for [`can_auto_functionalize`](https://github.com/pytorch/pytorch/blob/findhao/fix_none_return_functionalize/torch/_higher_order_ops/auto_functionalize.py#L71) even we can take this as a `()` return. This PR is a fix to skip the check for None return. I hope it can be fixed in a [deeper level](https://github.com/pytorch/pytorch/pull/128667/commits/31e44c72ca424adeecc3ef022d79906e3b6b54db), but this fix breaks a lot of existing schemas. So it's better to fix this issue in the auto_functionalize.py at this moment. Pull Request resolved: https://github.com/pytorch/pytorch/pull/128667 Approved by: https://github.com/zou3519 --- test/dynamo/test_misc.py | 17 +++++++++++++++++ torch/_higher_order_ops/auto_functionalize.py | 3 +++ 2 files changed, 20 insertions(+) diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py index 7a744a56c56326..221d826aa11214 100644 --- a/test/dynamo/test_misc.py +++ b/test/dynamo/test_misc.py @@ -574,6 +574,23 @@ def f(a, mode): cleanup_op("mylib::foo") del lib + def test_auto_functionalize_can_with_none_return(self): + with torch.library._scoped_library("mylib", "FRAGMENT") as lib: + lib.define("foo(Tensor x, Tensor(a!) out) -> None") + + def foo_impl(x, out): + out.copy_(x) + + lib.impl("foo", foo_impl, "CompositeExplicitAutograd") + x = torch.randn(3) + out = torch.zeros(3) + + @torch.compile + def f(x, out): + torch.ops.mylib.foo(x, out) + + f(x, out) + def test_user_defined_setattr1(self): @torch.compile(backend="eager", fullgraph=True) def fn(obj): diff --git a/torch/_higher_order_ops/auto_functionalize.py b/torch/_higher_order_ops/auto_functionalize.py index 189f746b77a010..ac93570a3838b6 100644 --- a/torch/_higher_order_ops/auto_functionalize.py +++ b/torch/_higher_order_ops/auto_functionalize.py @@ -97,6 +97,9 @@ def can_auto_functionalize(op: torch._ops.OperatorBase) -> bool: # Tensor[], Tensor?[], Tensor[]?. return False + if len(schema.returns) == 1 and isinstance(schema.returns[0].type, torch.NoneType): + # Skip schema returns -> None + return True # The returns must not alias anything for ret in schema.returns: if ret.alias_info is None and type(ret.type) is torch.TensorType: From 3f47c72268cd2e68d6207087d51a895f149d989b Mon Sep 17 00:00:00 2001 From: Fuzzkatt Date: Sat, 15 Jun 2024 01:32:53 +0000 Subject: [PATCH 043/171] add multiprocessing checks in test_dataloader.py (#128244) Add multiprocessing checks in test_dataloader.py for tests requiring multiprocessing similar to test_multiprocessing.py: https://github.com/pytorch/pytorch/blob/main/test/test_multiprocessing.py#L41-L52. Change all Jetson skips to TEST_CUDA_IPC checks since that is the root cause of the failures on Jetson in the first place. Pull Request resolved: https://github.com/pytorch/pytorch/pull/128244 Approved by: https://github.com/eqy, https://github.com/malfet --- test/test_dataloader.py | 22 +++++++++++++++++----- 1 file changed, 17 insertions(+), 5 deletions(-) diff --git a/test/test_dataloader.py b/test/test_dataloader.py index 7993d083e92592..e332b650c0a262 100644 --- a/test/test_dataloader.py +++ b/test/test_dataloader.py @@ -37,6 +37,7 @@ TEST_CUDA, TEST_NUMPY, TEST_WITH_ASAN, + TEST_WITH_ROCM, TEST_WITH_TSAN, TestCase, ) @@ -85,7 +86,17 @@ # sharding on sandcastle. This line silences flake warnings load_tests = load_tests -if TEST_CUDA: +TEST_CUDA_IPC = ( + torch.cuda.is_available() + and sys.platform != "darwin" + and sys.platform != "win32" + and not IS_JETSON + and not TEST_WITH_ROCM +) # https://github.com/pytorch/pytorch/issues/90940 + +TEST_MULTIGPU = TEST_CUDA_IPC and torch.cuda.device_count() > 1 + +if TEST_CUDA_IPC: torch.cuda.memory._set_allocator_settings("expandable_segments:False") if not NO_MULTIPROCESSING_SPAWN: @@ -1352,7 +1363,7 @@ def test_sequential_pin_memory(self): self.assertTrue(input.is_pinned()) self.assertTrue(target.is_pinned()) - @unittest.skipIf(IS_JETSON, "Not working on Jetson") + @unittest.skipIf(not TEST_CUDA_IPC, "CUDA IPC not available") def test_multiple_dataloaders(self): for multiprocessing_context in supported_multiprocessing_contexts: loader1_it = iter(self._get_data_loader(self.dataset, num_workers=1)) @@ -1830,7 +1841,7 @@ def test_chain_iterable_style_dataset(self): list(iter(ChainDataset([dataset1, self.dataset]))) @unittest.skipIf(IS_MACOS, "Not working on macos") - @unittest.skipIf(IS_MACOS or IS_JETSON, "Not working on macos or Jetson") + @unittest.skipIf(not TEST_CUDA_IPC, "CUDA IPC not available") @skipIfRocm # https://github.com/pytorch/pytorch/issues/90940 def test_multiprocessing_contexts(self): reference = [ @@ -1919,13 +1930,13 @@ def _test_multiprocessing_iterdatapipe(self, with_dill): ) @skipIfNoNumpy - @unittest.skipIf(IS_JETSON, "Not working on Jetson") + @unittest.skipIf(not TEST_CUDA_IPC, "CUDA IPC not available") def test_multiprocessing_iterdatapipe(self): self._test_multiprocessing_iterdatapipe(with_dill=False) @unittest.expectedFailure @skipIfNoNumpy - @unittest.skipIf(IS_JETSON, "Not working on Jetson") + @unittest.skipIf(not TEST_CUDA_IPC, "CUDA IPC not available") @skipIfNoDill def test_multiprocessing_iterdatapipe_with_dill(self): self._test_multiprocessing_iterdatapipe(with_dill=True) @@ -2878,6 +2889,7 @@ class TestDataLoaderDeviceType(TestCase): "context", [ctx for ctx in supported_multiprocessing_contexts if ctx is not None], ) + @unittest.skipIf(not TEST_CUDA_IPC, "CUDA IPC not available") def test_nested_tensor_multiprocessing(self, device, context): # The 'fork' multiprocessing context doesn't work for CUDA so skip it if "cuda" in device and context == "fork": From f37121bb74cdeaac6e9ce3e46cafc8b7972e9664 Mon Sep 17 00:00:00 2001 From: Huy Do Date: Sat, 15 Jun 2024 01:39:48 +0000 Subject: [PATCH 044/171] Add model name, quantization and device to gpt_fast micro benchmark output (#128091) A small enhancement to https://hud.pytorch.org/benchmark/llms with these columns in the output. Pull Request resolved: https://github.com/pytorch/pytorch/pull/128091 Approved by: https://github.com/yanboliang --- benchmarks/gpt_fast/benchmark.py | 32 +++++++++++++++++++++----------- benchmarks/gpt_fast/generate.py | 32 ++++++++++++++++++++++---------- 2 files changed, 43 insertions(+), 21 deletions(-) diff --git a/benchmarks/gpt_fast/benchmark.py b/benchmarks/gpt_fast/benchmark.py index 998878fa969284..16f3e55af17b04 100644 --- a/benchmarks/gpt_fast/benchmark.py +++ b/benchmarks/gpt_fast/benchmark.py @@ -21,6 +21,8 @@ class Experiment: metric: str target: float actual: float + dtype: str + device: str class SimpleMLP(nn.Module): @@ -41,7 +43,7 @@ def forward(self, x): return x -def run_mlp_layer_norm_gelu(): +def run_mlp_layer_norm_gelu(device: str = "cuda"): dtype_flops_utilization_map = { torch.bfloat16: "0.71", } @@ -53,9 +55,9 @@ def run_mlp_layer_norm_gelu(): for D in input_shapes: mod = SimpleMLP( input_dim=D, hidden_dim=intermediate_size, output_dim=D, dtype=dtype - ).to("cuda") + ).to(device) - x = torch.randn(D, device="cuda", dtype=torch.bfloat16) + x = torch.randn(D, device=device, dtype=torch.bfloat16) with FlopCounterMode(display=False) as mode: mod(x) @@ -78,12 +80,14 @@ def run_mlp_layer_norm_gelu(): "flops_utilization", expected_flops_utilization, f"{flops_utilization:.02f}", + dtype_str, + device, ) ) return results -def run_layer_norm(): +def run_layer_norm(device: str = "cuda"): dtype_memory_bandwidth_map = { torch.bfloat16: "1017", } @@ -93,9 +97,9 @@ def run_layer_norm(): for dtype, expected_memory_bandwidth in dtype_memory_bandwidth_map.items(): memory_bandwidth = 0 for D in input_shapes: - mod = nn.LayerNorm(D).to("cuda") + mod = nn.LayerNorm(D).to(device) - x = torch.randn(BS, D, device="cuda", dtype=dtype) + x = torch.randn(BS, D, device=device, dtype=dtype) compiled_mod = torch.compile(mod, dynamic=False) @@ -113,13 +117,15 @@ def run_layer_norm(): "memory_bandwidth(GB/s)", expected_memory_bandwidth, f"{memory_bandwidth:.02f}", + dtype_str, + device, ) ) return results @torch._inductor.config.patch(coordinate_descent_tuning=True) -def run_gather_gemv(): +def run_gather_gemv(device: str = "cuda"): E = 8 dtype_memory_bandwidth_map = { torch.int8: "1113", @@ -134,9 +140,9 @@ def run_gather_gemv(): def gather_gemv(W, score_idxs, x): return W[score_idxs].to(x.dtype) @ x - W = torch.randn(E, D, D, device="cuda").to(dtype=dtype) - x = torch.randn(D, device="cuda", dtype=torch.bfloat16) - score_idxs = torch.tensor([3, 5], device="cuda") + W = torch.randn(E, D, D, device=device).to(dtype=dtype) + x = torch.randn(D, device=device, dtype=torch.bfloat16) + score_idxs = torch.tensor([3, 5], device=device) compiled_fn = torch.compile(gather_gemv, dynamic=False) @@ -154,13 +160,15 @@ def gather_gemv(W, score_idxs, x): "memory_bandwidth(GB/s)", expected_memory_bandwidth, f"{memory_bandwidth:.02f}", + dtype_str, + device, ) ) return results @torch._inductor.config.patch(coordinate_descent_tuning=True) -def run_gemv(): +def run_gemv(device: str = "cuda"): dtype_memory_bandwidth_map = { torch.int8: "990", torch.bfloat16: "1137", @@ -193,6 +201,8 @@ def gemv(W, x): "memory_bandwidth(GB/s)", expected_memory_bandwidth, f"{memory_bandwidth:.02f}", + dtype_str, + device, ) ) return results diff --git a/benchmarks/gpt_fast/generate.py b/benchmarks/gpt_fast/generate.py index a4e4b06c79d7b1..3ec72bf1e3195e 100644 --- a/benchmarks/gpt_fast/generate.py +++ b/benchmarks/gpt_fast/generate.py @@ -172,8 +172,8 @@ def run_experiment( max_new_tokens: int = 200, top_k: int = 200, temperature: float = 0.8, + device: str = "cuda", ) -> None: - device = "cuda" print(f"Loading model {x.name}") t0 = time.time() model = _load_model(x) @@ -221,7 +221,7 @@ def run_experiment( # token_per_sec and memory_bandwidth target numbers are for A100-40GB, which are different from the typical A100-80GB. -def run_llama2_7b_bf16(): +def run_llama2_7b_bf16(device: str = "cuda"): from benchmark import Experiment model = GPTModelConfig( @@ -235,22 +235,26 @@ def run_llama2_7b_bf16(): token_per_sec, memory_bandwidth = run_experiment(model) return [ Experiment( - "llama2_7b_bf16", + model.name, "token_per_sec", model.token_per_sec, f"{token_per_sec:.02f}", + model.mode, + device, ), Experiment( - "llama2_7b_bf16", + model.name, "memory_bandwidth(GB/s)", model.memory_bandwidth, f"{memory_bandwidth:.02f}", + model.mode, + device, ), ] # token_per_sec and memory_bandwidth target numbers are for A100-40GB, which are different from the typical A100-80GB. -def run_llama2_7b_int8(): +def run_llama2_7b_int8(device: str = "cuda"): from benchmark import Experiment model = GPTModelConfig( @@ -264,22 +268,26 @@ def run_llama2_7b_int8(): token_per_sec, memory_bandwidth = run_experiment(model) return [ Experiment( - "llama2_7b_int8", + model.name, "token_per_sec", model.token_per_sec, f"{token_per_sec:.02f}", + model.mode, + device, ), Experiment( - "llama2_7b_int8", + model.name, "memory_bandwidth(GB/s)", model.memory_bandwidth, f"{memory_bandwidth:.02f}", + model.mode, + device, ), ] # token_per_sec and memory_bandwidth target numbers are for A100-40GB, which are different from the typical A100-80GB. -def run_mixtral_8x7b_int8(): +def run_mixtral_8x7b_int8(device: str = "cuda"): from benchmark import Experiment # We reduced the original number of layers from 32 to 16 to adapt CI memory limitation. @@ -294,15 +302,19 @@ def run_mixtral_8x7b_int8(): token_per_sec, memory_bandwidth = run_experiment(model) return [ Experiment( - "mixtral_8x7b_int8", + model.name, "token_per_sec", model.token_per_sec, f"{token_per_sec:.02f}", + model.mode, + device, ), Experiment( - "mixtral_8x7b_int8", + model.name, "memory_bandwidth(GB/s)", model.memory_bandwidth, f"{memory_bandwidth:.02f}", + model.mode, + device, ), ] From 5efe71f1345a38e102971d68df5f2e985a1daf7c Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Sat, 15 Jun 2024 01:46:23 +0000 Subject: [PATCH 045/171] Revert "[export] Add print_readable to unflattener (#128617)" This reverts commit 5d9a609b4f6c94fb930188e4d7c99f53d989c022. Reverted https://github.com/pytorch/pytorch/pull/128617 on behalf of https://github.com/huydhn due to Sorry for reverting your change but another failed test shows up in trunk inductor/test_flex_attention.py where it needs to be updated https://hud.pytorch.org/pytorch/pytorch/commit/5d9a609b4f6c94fb930188e4d7c99f53d989c022. I guess it is easier to revert and reland this ([comment](https://github.com/pytorch/pytorch/pull/128617#issuecomment-2169030779)) --- test/dynamo/test_autograd_function.py | 4 +- test/dynamo/test_higher_order_ops.py | 12 ++--- test/dynamo/test_subclasses.py | 6 +-- torch/export/unflatten.py | 33 ------------- torch/fx/graph_module.py | 69 +++++++-------------------- 5 files changed, 29 insertions(+), 95 deletions(-) diff --git a/test/dynamo/test_autograd_function.py b/test/dynamo/test_autograd_function.py index 330d15a2475d63..814844d449ccbb 100644 --- a/test/dynamo/test_autograd_function.py +++ b/test/dynamo/test_autograd_function.py @@ -525,14 +525,14 @@ def forward(self, L_x_: "f32[]", L_z_: "f32[]", L_weird_b: "f32[]", L_weird_c: " autograd_function_apply: "f32[]" = torch._functorch.autograd_function.autograd_function_apply(fwd_body_0, bwd_body_0, l_x_, l_z_, l_weird_b, l_weird_c, args_tensor_mask = [True, False, True]); fwd_body_0 = bwd_body_0 = l_x_ = l_z_ = l_weird_b = l_weird_c = None return (autograd_function_apply,) - class fwd_body_0(torch.nn.Module): + class GraphModule(torch.nn.Module): def forward(self, function_ctx, l_x_: "f32[]", l_z_: "f32[]", l_weird_b: "f32[]", l_weird_c: "f32[]"): mul: "f32[]" = l_weird_b * l_weird_c clone: "f32[]" = l_x_.clone(); l_x_ = None mul_1: "f32[]" = mul * clone; mul = clone = None return (mul_1, [l_weird_b, l_weird_c]) - class bwd_body_0(torch.nn.Module): + class GraphModule(torch.nn.Module): def forward(self, function_ctx, mul_1: "f32[]", l_weird_b: "f32[]", l_weird_c: "f32[]"): _set_grad_enabled = torch._C._set_grad_enabled(False) diff --git a/test/dynamo/test_higher_order_ops.py b/test/dynamo/test_higher_order_ops.py index d38611197bdbb3..dca6d28d1912dd 100644 --- a/test/dynamo/test_higher_order_ops.py +++ b/test/dynamo/test_higher_order_ops.py @@ -334,7 +334,7 @@ def forward(self, L_d_x_: "f32[]", L_d_y_0_: "f32[]", L_d_y_1_2_: "f32[]"): getitem: "f32[]" = wrap[0]; wrap = None return (getitem,) - class wrap_body_0(torch.nn.Module): + class GraphModule(torch.nn.Module): def forward(self, l_d_x_: "f32[]", l_d_y_0_: "f32[]", l_d_y_1_2_: "f32[]"): sin: "f32[]" = l_d_x_.sin(); l_d_x_ = None cos: "f32[]" = l_d_y_0_.cos(); l_d_y_0_ = None @@ -372,7 +372,7 @@ def forward(self, L_x_: "f32[3, 1]"): getitem: "f32[3]" = wrap[0]; wrap = None return (getitem,) - class wrap_body_0(torch.nn.Module): + class GraphModule(torch.nn.Module): def forward(self, l_x_: "f32[3, 1]"): view: "f32[3]" = l_x_.view(3); l_x_ = None add: "f32[3]" = view + 0.5; view = None @@ -394,7 +394,7 @@ def forward(self, s0: "Sym(s0)", L_x_: "f32[s0, 1]"): getitem: "f32[s0]" = wrap[0]; wrap = None return (getitem,) - class wrap_body_0(torch.nn.Module): + class GraphModule(torch.nn.Module): def forward(self, l_x_: "f32[s0, 1]", size: "Sym(s0)"): view: "f32[s0]" = l_x_.view(size); l_x_ = size = None add: "f32[s0]" = view + 0.5; view = None @@ -1791,7 +1791,7 @@ def forward(self, L_arg1_0_: "f32[3]", L_arg2_0_: "f32[3]"): getitem_1: "f32[3]" = wrap[1]; wrap = None return (getitem, getitem_1) - class wrap_body_0(torch.nn.Module): + class GraphModule(torch.nn.Module): def forward(self, l_arg1_0_: "f32[3]", l_arg2_0_: "f32[3]"): add: "f32[3]" = l_arg1_0_ + 1; l_arg1_0_ = None @@ -1990,7 +1990,7 @@ def forward(self, L_x_: "f32[2, 3]"): add: "f32[2, 3]" = a + b; a = b = None return (add,) - class wrap_body_0(torch.nn.Module): + class GraphModule(torch.nn.Module): def forward(self, l_x_: "f32[2, 3]"): sin: "f32[2, 3]" = l_x_.sin() cos: "f32[2, 3]" = l_x_.cos(); l_x_ = None @@ -2025,7 +2025,7 @@ def forward(self, L_x_: "f32[3]"): getitem: "f32[3]" = wrap[0]; wrap = None return (getitem,) - class wrap_body_0(torch.nn.Module): + class GraphModule(torch.nn.Module): def forward(self, l_x_: "f32[3]"): neg: "f32[3]" = -l_x_; l_x_ = None return (neg,) diff --git a/test/dynamo/test_subclasses.py b/test/dynamo/test_subclasses.py index 1c0be0f3adf5cc..954859f50994f4 100644 --- a/test/dynamo/test_subclasses.py +++ b/test/dynamo/test_subclasses.py @@ -927,7 +927,7 @@ def forward(self, L_x_: "f32[3, 4]"): getitem: "f32[3, 4]" = wrap[0]; wrap = None return (getitem,) - class wrap_body_0(torch.nn.Module): + class GraphModule(torch.nn.Module): def forward(self, l_x_: "f32[3, 4]"): add_: "f32[3, 4]" = l_x_.add_(1.0); l_x_ = None return (add_,) @@ -951,7 +951,7 @@ def forward(self, L_x_ : torch.Tensor): getitem = wrap[0]; wrap = None return (getitem,) - class wrap_body_0(torch.nn.Module): + class GraphModule(torch.nn.Module): def forward(self, l_x_): add_ = l_x_.add_(1.0); l_x_ = None return (add_,) @@ -981,7 +981,7 @@ def forward(self, L_x_ : torch.Tensor): getitem = wrap[0]; wrap = None return (getitem,) - class wrap_body_0(torch.nn.Module): + class GraphModule(torch.nn.Module): def forward(self, l_x_): add_ = l_x_.add_(1.0); l_x_ = None return (add_,) diff --git a/torch/export/unflatten.py b/torch/export/unflatten.py index 58bd2607602882..7b6ed6f1b5a974 100644 --- a/torch/export/unflatten.py +++ b/torch/export/unflatten.py @@ -21,7 +21,6 @@ TensorArgument, ) from torch.fx._symbolic_trace import is_fx_tracing -from torch.fx.graph_module import _print_readable from torch.utils._pytree import GetAttrKey, SequenceKey from ._remove_effect_tokens_pass import _remove_effect_tokens @@ -134,22 +133,6 @@ def finalize(self): if node.op == "placeholder": self.arg_names.append(node.target) - def print_readable( - self, - print_output=True, - include_stride=False, - include_device=False, - colored=False, - ): - return _print_readable( - self, - "InterpreterModule", - print_output, - include_stride, - include_device, - colored, - ) - class FlatArgsAdapter(abc.ABC): """ @@ -482,22 +465,6 @@ def forward(self, *args, **kwargs): ) return pytree.tree_unflatten(tree_out, signature.out_spec) - def print_readable( - self, - print_output=True, - include_stride=False, - include_device=False, - colored=False, - ): - return _print_readable( - self, - "UnflattenedModule", - print_output, - include_stride, - include_device, - colored, - ) - def unflatten( module: ExportedProgram, flat_args_adapter: Optional[FlatArgsAdapter] = None diff --git a/torch/fx/graph_module.py b/torch/fx/graph_module.py index cc3bd3b2311d4a..995f5cbc2a34ee 100644 --- a/torch/fx/graph_module.py +++ b/torch/fx/graph_module.py @@ -257,50 +257,6 @@ def _assign_attr(from_obj: Any, to_module: torch.nn.Module, target: str): setattr(to_module, field, from_obj) -def _print_readable( - module, - module_name, - print_output=True, - include_stride=False, - include_device=False, - colored=False, -): - graph = module.graph - assert graph is not None and isinstance(graph, torch.fx.Graph), "print_readable must be used on a module with a graph" - - verbose_python_code = graph.python_code( - root_module="self", - verbose=True, - include_stride=include_stride, - include_device=include_device, - colored=colored, - ) - module_code = verbose_python_code.src - module_code = module_code.lstrip("\n") - module_code = f"class {module_name}(torch.nn.Module):\n" + module_code - module_code = _addindent(module_code, 4) - - submodule_code_list = [""] - for submodule_name, submodule in module.named_children(): - if hasattr(submodule, "graph"): - submodule_code_list.append( - _print_readable( - submodule, - submodule_name, - print_output=False, - include_stride=include_stride, - include_device=include_device, - ) - ) - submodule_code = "\n".join(submodule_code_list) - submodule_code = _addindent(submodule_code, 4) - - output = module_code + submodule_code - if print_output: - print(module_code + submodule_code) - return output - - class _WrappedCall: def __init__(self, cls, cls_call): self.cls = cls @@ -869,14 +825,25 @@ def print_readable(self, print_output=True, include_stride=False, include_device """ Return the Python code generated for current GraphModule and its children GraphModules """ - return _print_readable( - self, - self._get_name(), - print_output, - include_stride, - include_device, - colored, + verbose_python_code = self._graph.python_code( + root_module="self", verbose=True, include_stride=include_stride, include_device=include_device, colored=colored ) + module_code = verbose_python_code.src + module_code = module_code.lstrip("\n") + module_code = f"class {self._get_name()}(torch.nn.Module):\n" + module_code + module_code = _addindent(module_code, 4) + + submodule_code_list = [""] + for submodule in self.children(): + if isinstance(submodule, GraphModule): + submodule_code_list.append(submodule.print_readable(print_output=False)) + submodule_code = "\n".join(submodule_code_list) + submodule_code = _addindent(submodule_code, 4) + + output = module_code + submodule_code + if print_output: + print(module_code + submodule_code) + return output def __str__(self) -> str: orig_str = super().__str__() From 846bb30e13a534b931dfc1d27e058b63aa88d90d Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Sat, 15 Jun 2024 01:58:20 +0000 Subject: [PATCH 046/171] Revert "[1/N] Change #include to #include (#128301)" This reverts commit bd72e28314d8d63bb347becb8309f5ac7761c6b5. Reverted https://github.com/pytorch/pytorch/pull/128301 on behalf of https://github.com/huydhn due to Sorry for reverting your change but it fails XLA build https://hud.pytorch.org/pytorch/pytorch/commit/bd72e28314d8d63bb347becb8309f5ac7761c6b5. Please rebase your PR before relanding because I think the failure is hidden by an unrelated broken trunk XLA failure from your current base commit ([comment](https://github.com/pytorch/pytorch/pull/128301#issuecomment-2169035822)) --- aten/src/ATen/CPUGeneratorImpl.h | 2 +- aten/src/ATen/InferSize.h | 2 +- aten/src/ATen/SavedTensorHooks.cpp | 2 +- aten/src/ATen/SavedTensorHooks.h | 2 +- aten/src/ATen/TensorIndexing.h | 2 +- aten/src/ATen/native/BatchLinearAlgebra.h | 2 +- aten/src/ATen/record_function.h | 2 +- c10/core/ConstantSymNodeImpl.h | 6 +- c10/core/ScalarTypeToTypeMeta.h | 2 +- c10/core/SymBool.h | 4 +- c10/core/SymInt.h | 4 +- c10/core/SymIntArrayRef.h | 4 +- c10/core/SymNodeImpl.h | 12 +-- c10/core/SymbolicShapeMeta.cpp | 4 +- c10/core/TensorImpl.cpp | 2 +- c10/core/TensorImpl.h | 14 ++-- c10/core/TensorOptions.h | 32 +++---- c10/core/UndefinedTensorImpl.cpp | 2 +- c10/core/impl/InlineDeviceGuard.h | 4 +- c10/core/impl/InlineStreamGuard.h | 4 +- c10/core/impl/PyObjectSlot.h | 8 +- c10/core/impl/TorchDispatchModeTLS.cpp | 12 +-- c10/cuda/CUDACachingAllocator.cpp | 6 +- c10/cuda/CUDAFunctions.cpp | 2 +- c10/cuda/CUDAGuard.h | 4 +- c10/cuda/impl/CUDAGuardImpl.h | 4 +- c10/test/core/DeviceGuard_test.cpp | 5 +- c10/test/core/SymInt_test.cpp | 2 +- c10/test/core/impl/InlineDeviceGuard_test.cpp | 16 ++-- c10/test/core/impl/InlineStreamGuard_test.cpp | 16 ++-- c10/test/util/optional_test.cpp | 32 +++---- c10/util/Backtrace.cpp | 10 +-- c10/util/OptionalArrayRef.h | 18 ++-- c10/xpu/test/impl/XPUStreamTest.cpp | 2 +- torch/csrc/Module.cpp | 8 +- torch/csrc/Storage.cpp | 2 +- .../csrc/api/include/torch/expanding_array.h | 2 +- torch/csrc/api/include/torch/fft.h | 84 +++++++++---------- torch/csrc/api/include/torch/nested.h | 12 +-- .../include/torch/nn/functional/activation.h | 6 +- .../include/torch/nn/functional/embedding.h | 6 +- .../api/include/torch/nn/functional/loss.h | 4 +- .../torch/nn/functional/normalization.h | 6 +- .../api/include/torch/nn/functional/pooling.h | 16 ++-- .../include/torch/nn/functional/upsampling.h | 28 +++---- .../api/include/torch/nn/modules/batchnorm.h | 4 +- .../csrc/api/include/torch/nn/modules/conv.h | 6 +- .../api/include/torch/nn/modules/pooling.h | 6 +- .../csrc/api/include/torch/nn/modules/utils.h | 2 +- .../api/include/torch/nn/options/activation.h | 10 +-- .../api/include/torch/nn/options/embedding.h | 24 +++--- .../csrc/api/include/torch/nn/options/loss.h | 4 +- .../include/torch/nn/options/normalization.h | 2 +- .../api/include/torch/nn/options/pooling.h | 8 +- .../api/include/torch/nn/options/upsampling.h | 14 ++-- .../api/include/torch/nn/options/vision.h | 2 +- .../api/include/torch/nn/utils/clip_grad.h | 4 +- .../torch/nn/utils/convert_parameters.h | 2 +- torch/csrc/api/include/torch/optim/lbfgs.h | 10 +-- .../csrc/api/include/torch/optim/optimizer.h | 10 +-- .../include/torch/serialize/input-archive.h | 10 +-- torch/csrc/api/include/torch/types.h | 4 +- torch/csrc/api/src/jit.cpp | 2 +- torch/csrc/api/src/nn/modules/activation.cpp | 8 +- torch/csrc/api/src/nn/modules/conv.cpp | 2 +- torch/csrc/api/src/nn/modules/embedding.cpp | 10 +-- torch/csrc/api/src/nn/modules/pooling.cpp | 20 ++--- torch/csrc/api/src/nn/modules/upsampling.cpp | 4 +- torch/csrc/api/src/optim/lbfgs.cpp | 16 ++-- .../csrc/api/src/serialize/input-archive.cpp | 8 +- torch/csrc/autograd/FunctionsManual.cpp | 22 ++--- torch/csrc/autograd/FunctionsManual.h | 2 +- torch/csrc/autograd/TraceTypeManual.cpp | 2 +- torch/csrc/autograd/VariableTypeManual.cpp | 2 +- torch/csrc/autograd/VariableTypeUtils.h | 4 +- torch/csrc/autograd/autograd.h | 4 +- .../autograd_not_implemented_fallback.cpp | 2 +- torch/csrc/autograd/engine.cpp | 6 +- torch/csrc/autograd/function.h | 4 +- .../csrc/autograd/functions/accumulate_grad.h | 2 +- torch/csrc/autograd/functions/comm.cpp | 2 +- torch/csrc/autograd/functions/comm.h | 6 +- torch/csrc/autograd/init.cpp | 8 +- torch/csrc/autograd/input_buffer.cpp | 6 +- torch/csrc/autograd/input_buffer.h | 2 +- torch/csrc/autograd/profiler_legacy.cpp | 2 +- torch/csrc/autograd/profiler_legacy.h | 6 +- torch/csrc/autograd/profiler_python.cpp | 10 +-- torch/csrc/autograd/python_function.cpp | 2 +- torch/csrc/autograd/python_function.h | 2 +- torch/csrc/autograd/python_variable.cpp | 8 +- .../autograd/python_variable_indexing.cpp | 4 +- torch/csrc/autograd/record_function_ops.h | 4 +- .../autograd/utils/grad_layout_contract.h | 2 +- .../csrc/autograd/utils/python_arg_parsing.h | 2 +- torch/csrc/autograd/variable.h | 12 +-- torch/csrc/cuda/comm.cpp | 2 +- torch/csrc/cuda/comm.h | 8 +- torch/csrc/cuda/memory_snapshot.h | 2 +- torch/csrc/cuda/nccl.h | 2 +- torch/csrc/cuda/python_nccl.cpp | 2 +- .../autograd/engine/dist_engine.cpp | 2 +- torch/csrc/distributed/c10d/NCCLUtils.cpp | 8 +- torch/csrc/distributed/c10d/NCCLUtils.hpp | 16 ++-- .../distributed/c10d/ProcessGroupCudaP2P.hpp | 4 +- .../distributed/c10d/ProcessGroupGloo.cpp | 4 +- .../distributed/c10d/ProcessGroupGloo.hpp | 2 +- .../csrc/distributed/c10d/ProcessGroupMPI.cpp | 6 +- .../csrc/distributed/c10d/ProcessGroupMPI.hpp | 6 +- .../distributed/c10d/ProcessGroupNCCL.cpp | 30 +++---- .../distributed/c10d/ProcessGroupNCCL.hpp | 8 +- torch/csrc/distributed/c10d/TCPStore.cpp | 4 +- torch/csrc/distributed/c10d/TCPStore.hpp | 6 +- torch/csrc/distributed/c10d/TraceUtils.h | 4 +- torch/csrc/distributed/c10d/Types.hpp | 2 +- torch/csrc/distributed/c10d/Utils.hpp | 2 +- torch/csrc/distributed/c10d/Work.hpp | 2 +- torch/csrc/distributed/c10d/init.cpp | 4 +- .../csrc/distributed/c10d/intra_node_comm.hpp | 4 +- torch/csrc/distributed/c10d/logger.cpp | 2 +- torch/csrc/distributed/c10d/reducer.cpp | 8 +- torch/csrc/distributed/c10d/reducer.hpp | 6 +- torch/csrc/distributed/c10d/reducer_cuda.cpp | 4 +- torch/csrc/distributed/c10d/reducer_timer.hpp | 2 +- torch/csrc/distributed/c10d/sequence_num.cpp | 10 +-- torch/csrc/distributed/c10d/sequence_num.hpp | 2 +- .../rpc/profiler/remote_profiler_manager.cpp | 4 +- .../rpc/profiler/remote_profiler_manager.h | 2 +- torch/csrc/distributed/rpc/py_rref.cpp | 2 +- .../csrc/distributed/rpc/python_functions.cpp | 4 +- .../rpc/request_callback_no_python.cpp | 2 +- torch/csrc/distributed/rpc/rref_impl.h | 2 +- torch/csrc/distributed/rpc/script_call.h | 2 +- .../csrc/distributed/rpc/tensorpipe_cuda.cpp | 2 +- .../csrc/distributed/rpc/tensorpipe_utils.cpp | 4 +- .../csrc/dynamo/python_compiled_autograd.cpp | 2 +- torch/csrc/functorch/init.cpp | 20 ++--- torch/csrc/inductor/aoti_torch/utils.h | 26 +++--- torch/csrc/jit/api/compilation_unit.h | 6 +- torch/csrc/jit/api/function_impl.h | 2 +- torch/csrc/jit/api/module.cpp | 4 +- torch/csrc/jit/api/module.h | 8 +- torch/csrc/jit/api/object.cpp | 2 +- torch/csrc/jit/api/object.h | 6 +- torch/csrc/jit/codegen/fuser/compiler.cpp | 2 +- .../jit/codegen/fuser/cpu/fused_kernel.cpp | 4 +- torch/csrc/jit/codegen/fuser/executor.cpp | 12 +-- torch/csrc/jit/codegen/fuser/kernel_spec.h | 4 +- .../csrc/jit/codegen/onednn/graph_helper.cpp | 2 +- .../jit/codegen/onednn/graph_rewriter.cpp | 2 +- .../jit/codegen/onednn/prepare_binary.cpp | 2 +- torch/csrc/jit/cuda/cuda.h | 4 +- torch/csrc/jit/frontend/builtin_functions.cpp | 2 +- .../frontend/canonicalize_modified_loop.cpp | 2 +- .../jit/frontend/concrete_module_type.cpp | 10 +-- .../jit/frontend/function_schema_parser.cpp | 4 +- torch/csrc/jit/frontend/ir_emitter.cpp | 30 +++---- .../csrc/jit/frontend/parse_string_literal.h | 8 +- torch/csrc/jit/frontend/parser.cpp | 4 +- torch/csrc/jit/frontend/schema_matching.cpp | 28 +++---- torch/csrc/jit/frontend/schema_matching.h | 8 +- .../csrc/jit/frontend/schema_type_parser.cpp | 6 +- .../csrc/jit/frontend/script_type_parser.cpp | 20 ++--- torch/csrc/jit/frontend/source_range.cpp | 2 +- torch/csrc/jit/frontend/source_range.h | 10 +-- torch/csrc/jit/frontend/sugared_value.cpp | 2 +- torch/csrc/jit/frontend/sugared_value.h | 14 ++-- torch/csrc/jit/frontend/tracer.cpp | 4 +- torch/csrc/jit/ir/alias_analysis.cpp | 14 ++-- torch/csrc/jit/ir/constants.cpp | 12 +-- torch/csrc/jit/ir/constants.h | 14 ++-- torch/csrc/jit/ir/ir.cpp | 12 +-- torch/csrc/jit/ir/ir.h | 18 ++-- torch/csrc/jit/ir/scope.h | 2 +- .../mobile/compatibility/backport_manager.cpp | 8 +- .../compatibility/runtime_compatibility.h | 2 +- torch/csrc/jit/mobile/flatbuffer_loader.cpp | 6 +- torch/csrc/jit/mobile/flatbuffer_loader.h | 12 +-- torch/csrc/jit/mobile/frame.h | 2 +- torch/csrc/jit/mobile/function.cpp | 4 +- torch/csrc/jit/mobile/import.cpp | 6 +- torch/csrc/jit/mobile/import.h | 6 +- torch/csrc/jit/mobile/import_data.h | 6 +- .../mobile/model_tracer/MobileModelRunner.h | 2 +- .../jit/mobile/model_tracer/TracerRunner.cpp | 8 +- torch/csrc/jit/mobile/module.cpp | 6 +- torch/csrc/jit/mobile/promoted_prim_ops.cpp | 2 +- .../operator_upgraders/upgraders_entry.cpp | 2 +- torch/csrc/jit/operator_upgraders/utils.cpp | 4 +- torch/csrc/jit/operator_upgraders/utils.h | 2 +- torch/csrc/jit/passes/autocast.cpp | 8 +- torch/csrc/jit/passes/canonicalize.cpp | 8 +- .../passes/canonicalize_graph_fuser_ops.cpp | 4 +- .../csrc/jit/passes/constant_propagation.cpp | 14 ++-- .../jit/passes/create_autodiff_subgraphs.cpp | 8 +- .../csrc/jit/passes/device_type_analysis.cpp | 6 +- torch/csrc/jit/passes/dtype_analysis.cpp | 6 +- torch/csrc/jit/passes/erase_number_types.cpp | 2 +- torch/csrc/jit/passes/freeze_module.cpp | 8 +- .../csrc/jit/passes/frozen_ops_to_mkldnn.cpp | 2 +- torch/csrc/jit/passes/graph_fuser.cpp | 4 +- .../csrc/jit/passes/graph_rewrite_helper.cpp | 2 +- .../jit/passes/inline_autodiff_subgraphs.cpp | 2 +- .../jit/passes/integer_value_refinement.cpp | 4 +- torch/csrc/jit/passes/onnx/constant_fold.cpp | 56 ++++++------- torch/csrc/jit/passes/onnx/constant_fold.h | 2 +- torch/csrc/jit/passes/onnx/constant_map.cpp | 20 ++--- .../jit/passes/onnx/function_extraction.cpp | 16 ++-- .../jit/passes/onnx/list_model_parameters.cpp | 2 +- .../pattern_conversion/pattern_conversion.cpp | 2 +- .../pattern_encapsulation.cpp | 2 +- torch/csrc/jit/passes/onnx/peephole.cpp | 8 +- .../jit/passes/onnx/scalar_type_analysis.cpp | 16 ++-- .../jit/passes/onnx/shape_type_inference.cpp | 26 +++--- .../passes/onnx/unpack_quantized_weights.cpp | 4 +- .../csrc/jit/passes/peephole_dict_idioms.cpp | 16 ++-- .../csrc/jit/passes/peephole_list_idioms.cpp | 8 +- torch/csrc/jit/passes/quantization/helper.cpp | 12 +-- torch/csrc/jit/passes/quantization/helper.h | 2 +- .../passes/quantization/insert_observers.cpp | 4 +- .../quantization/insert_quant_dequant.cpp | 8 +- torch/csrc/jit/passes/remove_mutation.h | 4 +- .../passes/replacement_of_old_operators.cpp | 2 +- torch/csrc/jit/passes/shape_analysis.cpp | 42 +++++----- .../jit/passes/symbolic_shape_analysis.cpp | 34 ++++---- .../csrc/jit/passes/symbolic_shape_cache.cpp | 4 +- .../passes/symbolic_shape_runtime_fusion.cpp | 2 +- torch/csrc/jit/passes/tensorexpr_fuser.cpp | 4 +- .../passes/utils/check_alias_annotation.cpp | 6 +- torch/csrc/jit/passes/utils/memory_dag.h | 2 +- .../csrc/jit/passes/utils/subgraph_utils.cpp | 2 +- torch/csrc/jit/python/init.cpp | 10 +-- torch/csrc/jit/python/module_python.h | 4 +- torch/csrc/jit/python/pybind_utils.cpp | 4 +- torch/csrc/jit/python/pybind_utils.h | 14 ++-- torch/csrc/jit/python/python_ir.cpp | 6 +- torch/csrc/jit/python/python_ivalue.h | 2 +- torch/csrc/jit/python/python_list.h | 4 +- .../csrc/jit/python/python_sugared_value.cpp | 18 ++-- torch/csrc/jit/python/python_sugared_value.h | 2 +- torch/csrc/jit/python/python_tree_views.cpp | 4 +- torch/csrc/jit/python/script_init.cpp | 12 +-- torch/csrc/jit/runtime/autodiff.cpp | 4 +- .../jit/runtime/decomposition_registry.cpp | 6 +- torch/csrc/jit/runtime/graph_executor.h | 2 +- torch/csrc/jit/runtime/graph_executor_impl.h | 2 +- torch/csrc/jit/runtime/interpreter.cpp | 6 +- torch/csrc/jit/runtime/interpreter.h | 6 +- torch/csrc/jit/runtime/jit_exception.h | 6 +- torch/csrc/jit/runtime/operator.h | 2 +- .../runtime/profiling_graph_executor_impl.cpp | 10 +-- torch/csrc/jit/runtime/register_ops_utils.h | 2 +- torch/csrc/jit/runtime/register_prim_ops.cpp | 8 +- .../jit/runtime/register_prim_ops_fulljit.cpp | 16 ++-- .../csrc/jit/runtime/register_special_ops.cpp | 8 +- .../runtime/simple_graph_executor_impl.cpp | 2 +- torch/csrc/jit/runtime/static/fusion.cpp | 4 +- torch/csrc/jit/runtime/static/impl.cpp | 8 +- torch/csrc/jit/runtime/static/ops.cpp | 68 +++++++-------- torch/csrc/jit/runtime/static/ops.h | 24 +++--- torch/csrc/jit/runtime/symbolic_script.cpp | 4 +- torch/csrc/jit/runtime/symbolic_script.h | 2 +- .../jit/runtime/symbolic_shape_registry.cpp | 6 +- .../callstack_debug_info_serialization.cpp | 2 +- torch/csrc/jit/serialization/export.cpp | 2 +- .../jit/serialization/export_bytecode.cpp | 2 +- .../csrc/jit/serialization/export_module.cpp | 4 +- .../serialization/flatbuffer_serializer.cpp | 4 +- torch/csrc/jit/serialization/import.h | 18 ++-- .../csrc/jit/serialization/import_source.cpp | 4 +- torch/csrc/jit/serialization/import_source.h | 4 +- torch/csrc/jit/serialization/pickle.cpp | 6 +- torch/csrc/jit/serialization/pickler.cpp | 2 +- torch/csrc/jit/serialization/python_print.cpp | 2 +- .../source_range_serialization.cpp | 4 +- torch/csrc/jit/tensorexpr/codegen.cpp | 2 +- torch/csrc/jit/tensorexpr/eval.cpp | 2 +- torch/csrc/jit/tensorexpr/expr.h | 14 ++-- .../jit/tensorexpr/external_functions.cpp | 16 ++-- .../csrc/jit/tensorexpr/external_functions.h | 4 +- torch/csrc/jit/tensorexpr/graph_opt.cpp | 4 +- torch/csrc/jit/tensorexpr/ir.h | 2 +- torch/csrc/jit/tensorexpr/ir_simplifier.cpp | 12 +-- torch/csrc/jit/tensorexpr/kernel.cpp | 16 ++-- torch/csrc/jit/tensorexpr/llvm_codegen.cpp | 6 +- torch/csrc/jit/tensorexpr/llvm_codegen.h | 14 ++-- torch/csrc/jit/tensorexpr/llvm_jit.h | 2 +- .../csrc/jit/tensorexpr/operators/conv2d.cpp | 4 +- torch/csrc/jit/tensorexpr/operators/misc.cpp | 2 +- .../csrc/jit/tensorexpr/operators/pointwise.h | 2 +- .../jit/tensorexpr/operators/quantization.cpp | 4 +- .../csrc/jit/tensorexpr/operators/softmax.cpp | 12 +-- torch/csrc/jit/tensorexpr/tensor.cpp | 24 +++--- torch/csrc/jit/tensorexpr/tensor.h | 8 +- torch/csrc/jit/testing/file_check.cpp | 10 +-- torch/csrc/lazy/backend/backend_device.cpp | 12 +-- torch/csrc/lazy/backend/backend_device.h | 2 +- torch/csrc/lazy/core/ir_builder.h | 6 +- torch/csrc/lazy/core/ir_dump_util.cpp | 6 +- torch/csrc/lazy/core/lazy_graph_executor.cpp | 2 +- torch/csrc/lazy/core/shape.cpp | 4 +- torch/csrc/lazy/core/shape.h | 4 +- torch/csrc/lazy/core/shape_inference.h | 2 +- torch/csrc/lazy/core/tensor.cpp | 8 +- torch/csrc/lazy/core/unique.h | 2 +- torch/csrc/lazy/core/util.h | 4 +- torch/csrc/lazy/python/python_util.cpp | 4 +- torch/csrc/lazy/python/python_util.h | 2 +- torch/csrc/lazy/ts_backend/ir_builder.h | 2 +- .../lazy/ts_backend/ts_eager_fallback.cpp | 2 +- .../lazy/ts_backend/ts_native_functions.cpp | 8 +- torch/csrc/profiler/collection.cpp | 6 +- torch/csrc/profiler/collection.h | 2 +- torch/csrc/profiler/python/init.cpp | 2 +- torch/csrc/profiler/unwind/unwind.cpp | 4 +- torch/csrc/profiler/unwind/unwind.h | 2 +- torch/csrc/profiler/unwind/unwind_error.h | 2 +- torch/csrc/profiler/util.h | 2 +- torch/csrc/tensor/python_tensor.cpp | 2 +- torch/csrc/utils/nested.cpp | 2 +- torch/csrc/utils/python_arg_parser.cpp | 6 +- torch/csrc/utils/python_arg_parser.h | 28 +++---- torch/csrc/utils/python_dispatch.cpp | 24 +++--- torch/csrc/utils/python_raii.h | 6 +- torch/csrc/utils/python_symnode.h | 2 +- torch/csrc/utils/schema_info.cpp | 4 +- torch/csrc/utils/tensor_new.cpp | 20 ++--- torch/csrc/utils/torch_dispatch_mode.h | 2 +- torch/custom_class_detail.h | 2 +- torch/library.h | 14 ++-- 330 files changed, 1207 insertions(+), 1208 deletions(-) diff --git a/aten/src/ATen/CPUGeneratorImpl.h b/aten/src/ATen/CPUGeneratorImpl.h index e15ca23d6bf748..34dd33a475b917 100644 --- a/aten/src/ATen/CPUGeneratorImpl.h +++ b/aten/src/ATen/CPUGeneratorImpl.h @@ -3,7 +3,7 @@ #include #include #include -#include +#include namespace at { diff --git a/aten/src/ATen/InferSize.h b/aten/src/ATen/InferSize.h index 4bf820312d2f68..411cf12d513414 100644 --- a/aten/src/ATen/InferSize.h +++ b/aten/src/ATen/InferSize.h @@ -4,7 +4,7 @@ #include #include #include -#include +#include #include #include diff --git a/aten/src/ATen/SavedTensorHooks.cpp b/aten/src/ATen/SavedTensorHooks.cpp index 6837348c932101..7aa9b0f02ea365 100644 --- a/aten/src/ATen/SavedTensorHooks.cpp +++ b/aten/src/ATen/SavedTensorHooks.cpp @@ -32,7 +32,7 @@ void SavedTensorDefaultHooks::disable(const std::string& message) { } void SavedTensorDefaultHooks::enable() { - tls.disabled_error_message = std::nullopt; + tls.disabled_error_message = c10::nullopt; } /* static */ bool SavedTensorDefaultHooks::set_tracing(bool is_tracing) { diff --git a/aten/src/ATen/SavedTensorHooks.h b/aten/src/ATen/SavedTensorHooks.h index 9cf1ea37c35390..b69b9c25e8e6a5 100644 --- a/aten/src/ATen/SavedTensorHooks.h +++ b/aten/src/ATen/SavedTensorHooks.h @@ -1,8 +1,8 @@ #pragma once #include +#include #include -#include #include #include diff --git a/aten/src/ATen/TensorIndexing.h b/aten/src/ATen/TensorIndexing.h index 1fe9e7ebdcb012..eb36c0e02fa4db 100644 --- a/aten/src/ATen/TensorIndexing.h +++ b/aten/src/ATen/TensorIndexing.h @@ -5,8 +5,8 @@ #include #include #include +#include #include -#include #ifndef AT_PER_OPERATOR_HEADERS #include diff --git a/aten/src/ATen/native/BatchLinearAlgebra.h b/aten/src/ATen/native/BatchLinearAlgebra.h index 58d46aacd47314..c8402640aa08ac 100644 --- a/aten/src/ATen/native/BatchLinearAlgebra.h +++ b/aten/src/ATen/native/BatchLinearAlgebra.h @@ -1,6 +1,6 @@ #pragma once -#include +#include #include #include #include diff --git a/aten/src/ATen/record_function.h b/aten/src/ATen/record_function.h index 63fbcb55e96d2b..014260fb220f89 100644 --- a/aten/src/ATen/record_function.h +++ b/aten/src/ATen/record_function.h @@ -3,8 +3,8 @@ #include #include #include +#include #include -#include #include #include diff --git a/c10/core/ConstantSymNodeImpl.h b/c10/core/ConstantSymNodeImpl.h index 791a81cace4176..3c0fb66f7469fe 100644 --- a/c10/core/ConstantSymNodeImpl.h +++ b/c10/core/ConstantSymNodeImpl.h @@ -3,8 +3,8 @@ #include #include #include +#include #include -#include #include #include @@ -73,14 +73,14 @@ class C10_API ConstantSymNodeImpl : public SymNodeImpl { if constexpr (is_int_()) { return ::std::get(value_); } else { - return std::nullopt; + return c10::nullopt; } } std::optional constant_bool() override { if constexpr (is_bool_()) { return ::std::get(value_); } else { - return std::nullopt; + return c10::nullopt; } } bool is_constant() override { diff --git a/c10/core/ScalarTypeToTypeMeta.h b/c10/core/ScalarTypeToTypeMeta.h index 5e9e1a936af5af..d2694c96221eb4 100644 --- a/c10/core/ScalarTypeToTypeMeta.h +++ b/c10/core/ScalarTypeToTypeMeta.h @@ -30,7 +30,7 @@ inline ScalarType typeMetaToScalarType(caffe2::TypeMeta dtype) { inline optional optTypeMetaToScalarType( optional type_meta) { if (!type_meta.has_value()) { - return std::nullopt; + return c10::nullopt; } return type_meta->toScalarType(); } diff --git a/c10/core/SymBool.h b/c10/core/SymBool.h index 06ce32c1a7160b..9f9f141293a375 100644 --- a/c10/core/SymBool.h +++ b/c10/core/SymBool.h @@ -3,9 +3,9 @@ #include #include #include +#include #include #include -#include #include #include @@ -68,7 +68,7 @@ class C10_API SymBool { std::optional maybe_as_bool() const { if (!is_heap_allocated()) { - return std::make_optional(data_); + return c10::make_optional(data_); } return toSymNodeImplUnowned()->constant_bool(); } diff --git a/c10/core/SymInt.h b/c10/core/SymInt.h index eef34aac24ca6d..025c351334a016 100644 --- a/c10/core/SymInt.h +++ b/c10/core/SymInt.h @@ -5,7 +5,7 @@ #include #include #include -#include +#include #include #include @@ -231,7 +231,7 @@ class C10_API SymInt { std::optional maybe_as_int() const { if (!is_heap_allocated()) { - return std::make_optional(data_); + return c10::make_optional(data_); } auto* node = toSymNodeImplUnowned(); if (auto c = node->constant_int()) { diff --git a/c10/core/SymIntArrayRef.h b/c10/core/SymIntArrayRef.h index ce7253c60ec59f..760f4ba4e79a21 100644 --- a/c10/core/SymIntArrayRef.h +++ b/c10/core/SymIntArrayRef.h @@ -3,8 +3,8 @@ #include #include #include +#include #include -#include namespace c10 { using SymIntArrayRef = ArrayRef; @@ -23,7 +23,7 @@ inline std::optional asIntArrayRefSlowOpt( c10::SymIntArrayRef ar) { for (const c10::SymInt& sci : ar) { if (sci.is_heap_allocated()) { - return std::nullopt; + return c10::nullopt; } } diff --git a/c10/core/SymNodeImpl.h b/c10/core/SymNodeImpl.h index 39e4bbbc2c6cd7..bb92b09775b7b4 100644 --- a/c10/core/SymNodeImpl.h +++ b/c10/core/SymNodeImpl.h @@ -3,9 +3,9 @@ #include #include #include +#include #include #include -#include #include #include @@ -207,19 +207,19 @@ class C10_API SymNodeImpl : public c10::intrusive_ptr_target { TORCH_CHECK(false, "NYI"); }; virtual std::optional nested_int() { - return std::nullopt; + return c10::nullopt; } virtual std::optional nested_int_coeff() { - return std::nullopt; + return c10::nullopt; } virtual std::optional constant_int() { - return std::nullopt; + return c10::nullopt; } virtual std::optional constant_bool() { - return std::nullopt; + return c10::nullopt; } virtual std::optional maybe_as_int() { - return std::nullopt; + return c10::nullopt; } virtual bool is_constant() { return false; diff --git a/c10/core/SymbolicShapeMeta.cpp b/c10/core/SymbolicShapeMeta.cpp index b59a95a4a2faf4..62b03d36ec71c9 100644 --- a/c10/core/SymbolicShapeMeta.cpp +++ b/c10/core/SymbolicShapeMeta.cpp @@ -56,7 +56,7 @@ normalize_sym_sizes_strides(SymIntArrayRef sizes, SymIntArrayRef strides) { // Couldn't find. Tell the caller to do the normal computation // Alternately, if everything is hinted, we want the normal computation // too - return std::nullopt; + return c10::nullopt; } // Populate the SymNode array std::vector size_nodes; @@ -69,7 +69,7 @@ normalize_sym_sizes_strides(SymIntArrayRef sizes, SymIntArrayRef strides) { for (const auto& s : strides) { stride_nodes.emplace_back(s.wrap_node(base)); } - return std::make_optional( + return c10::make_optional( std::tuple, std::vector>( std::move(base), std::move(size_nodes), std::move(stride_nodes))); } diff --git a/c10/core/TensorImpl.cpp b/c10/core/TensorImpl.cpp index 130292aaa70d6a..516a61f0200462 100644 --- a/c10/core/TensorImpl.cpp +++ b/c10/core/TensorImpl.cpp @@ -8,9 +8,9 @@ #include #include #include +#include #include #include -#include #include diff --git a/c10/core/TensorImpl.h b/c10/core/TensorImpl.h index 67543614c021bc..877c1c09543cb5 100644 --- a/c10/core/TensorImpl.h +++ b/c10/core/TensorImpl.h @@ -24,12 +24,12 @@ #include #include #include +#include #include #include #include #include #include -#include #include #include @@ -233,8 +233,8 @@ struct C10_API ExtraMeta { std::unique_ptr symbolic_shape_meta_ = nullptr; std::unique_ptr named_tensor_meta_ = nullptr; intrusive_ptr backend_meta_ = nullptr; - std::optional custom_data_ptr_error_msg_ = std::nullopt; - std::optional custom_storage_error_msg_ = std::nullopt; + std::optional custom_data_ptr_error_msg_ = c10::nullopt; + std::optional custom_storage_error_msg_ = c10::nullopt; ExtraMeta() = default; ExtraMeta(const ExtraMeta& other) { @@ -260,8 +260,8 @@ struct C10_API ExtraMeta { std::unique_ptr symbolic_shape_meta, std::unique_ptr named_tensor_meta, intrusive_ptr backend_meta, - std::optional custom_data_ptr_error_msg = std::nullopt, - std::optional custom_storage_access_error_msg = std::nullopt) + std::optional custom_data_ptr_error_msg = c10::nullopt, + std::optional custom_storage_access_error_msg = c10::nullopt) : symbolic_shape_meta_(std::move(symbolic_shape_meta)), named_tensor_meta_(std::move(named_tensor_meta)), backend_meta_(std::move(backend_meta)), @@ -1737,7 +1737,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { void set_sizes_and_strides( c10::SymIntArrayRef sizes, c10::SymIntArrayRef strides, - std::optional storage_offset = std::nullopt); + std::optional storage_offset = c10::nullopt); // This is renamed to avoid breaking overload BC void generic_set_sizes_contiguous(c10::SymIntArrayRef sizes); void generic_set_sizes_contiguous(c10::IntArrayRef sizes) { @@ -1834,7 +1834,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { void set_sizes_and_strides( IntArrayRef new_size, IntArrayRef new_stride, - std::optional storage_offset = std::nullopt) { + std::optional storage_offset = c10::nullopt) { TORCH_CHECK( allow_tensor_metadata_change(), "set_sizes_and_strides ", diff --git a/c10/core/TensorOptions.h b/c10/core/TensorOptions.h index 9c23c767ffc5ef..d99005d3d28f85 100644 --- a/c10/core/TensorOptions.h +++ b/c10/core/TensorOptions.h @@ -13,7 +13,7 @@ #include #include #include -#include +#include #include #include @@ -284,10 +284,10 @@ struct C10_API TensorOptions { return has_device_; } - /// Returns the device of the `TensorOptions`, or `std::nullopt` if + /// Returns the device of the `TensorOptions`, or `c10::nullopt` if /// device is not specified. std::optional device_opt() const noexcept { - return has_device_ ? std::make_optional(device_) : std::nullopt; + return has_device_ ? c10::make_optional(device_) : c10::nullopt; } /// Returns the device index of the `TensorOptions`. @@ -305,10 +305,10 @@ struct C10_API TensorOptions { return has_dtype_; } - /// Returns the dtype of the `TensorOptions`, or `std::nullopt` if + /// Returns the dtype of the `TensorOptions`, or `c10::nullopt` if /// device is not specified. std::optional dtype_opt() const noexcept { - return has_dtype_ ? std::make_optional(dtype_) : std::nullopt; + return has_dtype_ ? c10::make_optional(dtype_) : c10::nullopt; } /// Returns the layout of the `TensorOptions`. @@ -321,10 +321,10 @@ struct C10_API TensorOptions { return has_layout_; } - /// Returns the layout of the `TensorOptions`, or `std::nullopt` if + /// Returns the layout of the `TensorOptions`, or `c10::nullopt` if /// layout is not specified. std::optional layout_opt() const noexcept { - return has_layout_ ? std::make_optional(layout_) : std::nullopt; + return has_layout_ ? c10::make_optional(layout_) : c10::nullopt; } /// Returns the `requires_grad` property of the `TensorOptions`. @@ -338,10 +338,10 @@ struct C10_API TensorOptions { } /// Returns the `requires_grad` property of the `TensorOptions`, or - /// `std::nullopt` if `requires_grad` is not specified. + /// `c10::nullopt` if `requires_grad` is not specified. std::optional requires_grad_opt() const noexcept { - return has_requires_grad_ ? std::make_optional(requires_grad_) - : std::nullopt; + return has_requires_grad_ ? c10::make_optional(requires_grad_) + : c10::nullopt; } /// Returns the `pinned_memory` property of the `TensorOptions`. @@ -378,10 +378,10 @@ struct C10_API TensorOptions { } /// Returns the `pinned_memory` property of the `TensorOptions`, or - /// `std::nullopt` if `pinned_memory` is not specified. + /// `c10::nullopt` if `pinned_memory` is not specified. std::optional pinned_memory_opt() const noexcept { - return has_pinned_memory_ ? std::make_optional(pinned_memory_) - : std::nullopt; + return has_pinned_memory_ ? c10::make_optional(pinned_memory_) + : c10::nullopt; } /// Returns whether the `memory_layout` is specified @@ -393,10 +393,10 @@ struct C10_API TensorOptions { // behavior of memory_format varies from function to function. /// Returns the `memory_layout` property of `TensorOptions, or - /// `std::nullopt` if `memory_format` is not specified. + /// `c10::nullopt` if `memory_format` is not specified. std::optional memory_format_opt() const noexcept { - return has_memory_format_ ? std::make_optional(memory_format_) - : std::nullopt; + return has_memory_format_ ? c10::make_optional(memory_format_) + : c10::nullopt; } // Resolves the ATen backend specified by the current construction axes. diff --git a/c10/core/UndefinedTensorImpl.cpp b/c10/core/UndefinedTensorImpl.cpp index 2a715d78bdb767..1b16a5d5b9fd7e 100644 --- a/c10/core/UndefinedTensorImpl.cpp +++ b/c10/core/UndefinedTensorImpl.cpp @@ -5,7 +5,7 @@ namespace c10 { // should this use the globalContext? Can it get a context passed in somehow? UndefinedTensorImpl::UndefinedTensorImpl() - : TensorImpl(DispatchKey::Undefined, caffe2::TypeMeta(), std::nullopt) { + : TensorImpl(DispatchKey::Undefined, caffe2::TypeMeta(), c10::nullopt) { set_storage_access_should_throw(); // TODO: accessing the sizes on an undefined tensor is not meaningful // and should error too, but empirically it does not! diff --git a/c10/core/impl/InlineDeviceGuard.h b/c10/core/impl/InlineDeviceGuard.h index a70c194efccf8c..3e9f91eff61700 100644 --- a/c10/core/impl/InlineDeviceGuard.h +++ b/c10/core/impl/InlineDeviceGuard.h @@ -404,7 +404,7 @@ class InlineOptionalDeviceGuard { /// Returns the device that was set immediately prior to initialization of /// the, guard, or nullopt if the guard is uninitialized. optional original_device() const { - return guard_.has_value() ? std::make_optional(guard_->original_device()) + return guard_.has_value() ? make_optional(guard_->original_device()) : nullopt; } @@ -412,7 +412,7 @@ class InlineOptionalDeviceGuard { /// either from construction, or via set_device, if the guard is initialized, /// or nullopt if the guard is uninitialized. optional current_device() const { - return guard_.has_value() ? std::make_optional(guard_->current_device()) + return guard_.has_value() ? make_optional(guard_->current_device()) : nullopt; } diff --git a/c10/core/impl/InlineStreamGuard.h b/c10/core/impl/InlineStreamGuard.h index 5ac913c4ff7fff..b99e7db72addc6 100644 --- a/c10/core/impl/InlineStreamGuard.h +++ b/c10/core/impl/InlineStreamGuard.h @@ -173,7 +173,7 @@ class InlineOptionalStreamGuard { /// Returns the stream that was set at the time the guard was most recently /// initialized, or nullopt if the guard is uninitialized. optional original_stream() const { - return guard_.has_value() ? std::make_optional(guard_->original_stream()) + return guard_.has_value() ? make_optional(guard_->original_stream()) : nullopt; } @@ -181,7 +181,7 @@ class InlineOptionalStreamGuard { /// either from construction, or via reset_stream, if the guard is /// initialized, or nullopt if the guard is uninitialized. optional current_stream() const { - return guard_.has_value() ? std::make_optional(guard_->current_stream()) + return guard_.has_value() ? make_optional(guard_->current_stream()) : nullopt; } diff --git a/c10/core/impl/PyObjectSlot.h b/c10/core/impl/PyObjectSlot.h index 8f2833b5c7da41..518b0e63e49217 100644 --- a/c10/core/impl/PyObjectSlot.h +++ b/c10/core/impl/PyObjectSlot.h @@ -2,8 +2,8 @@ #include #include +#include #include -#include #include @@ -106,13 +106,13 @@ struct C10_API PyObjectSlot { // after we query here. The only time when we can conclude a tensor // is definitely uninitialized is when we have just allocated it and // it cannot have escaped to other threads yet - return std::nullopt; + return c10::nullopt; } else if (interpreter == self_interpreter) { // NB: pyobj_ could still be null! if (!ignore_hermetic_tls && c10::impl::HermeticPyObjectTLS::get_state()) { - return std::nullopt; + return c10::nullopt; } else { - return std::make_optional(_unchecked_untagged_pyobj()); + return c10::make_optional(_unchecked_untagged_pyobj()); } } else { TORCH_CHECK( diff --git a/c10/core/impl/TorchDispatchModeTLS.cpp b/c10/core/impl/TorchDispatchModeTLS.cpp index c9a3274ed896c3..f1847cb005b4ce 100644 --- a/c10/core/impl/TorchDispatchModeTLS.cpp +++ b/c10/core/impl/TorchDispatchModeTLS.cpp @@ -16,7 +16,7 @@ bool TorchDispatchModeTLS::any_modes_set(bool skip_infra_modes) { if (!skip_infra_modes) { for (const auto i : c10::irange( static_cast(TorchDispatchModeKey::NUM_MODE_KEYS))) { - if (torchDispatchModeState.infra_modes_[i] != std::nullopt) { + if (torchDispatchModeState.infra_modes_[i] != c10::nullopt) { return true; } } @@ -48,7 +48,7 @@ const std::shared_ptr TorchDispatchModeTLS:: if (torchDispatchModeState.infra_modes_[i].has_value()) { // NOLINTNEXTLINE(bugprone-unchecked-optional-access) out = std::move(torchDispatchModeState.infra_modes_[i].value()); - torchDispatchModeState.infra_modes_[i] = std::nullopt; + torchDispatchModeState.infra_modes_[i] = c10::nullopt; break; } } @@ -70,7 +70,7 @@ const std:: if (torchDispatchModeState.infra_modes_[i].has_value()) { // NOLINTNEXTLINE(bugprone-unchecked-optional-access) auto out_mode = torchDispatchModeState.infra_modes_[i].value(); - torchDispatchModeState.infra_modes_[i] = std::nullopt; + torchDispatchModeState.infra_modes_[i] = c10::nullopt; if (!any_modes_set()) { c10::impl::tls_set_dispatch_key_included(DispatchKey::Python, false); c10::impl::tls_set_dispatch_key_included( @@ -114,7 +114,7 @@ int64_t TorchDispatchModeTLS::stack_len() { int64_t infra_modes_len = 0; for (const auto i : c10::irange(static_cast(TorchDispatchModeKey::NUM_MODE_KEYS))) { - if (torchDispatchModeState.infra_modes_[i] != std::nullopt) { + if (torchDispatchModeState.infra_modes_[i] != c10::nullopt) { infra_modes_len += 1; } } @@ -131,7 +131,7 @@ void TorchDispatchModeTLS::set_mode( TorchDispatchModeKey mode_key) { TORCH_CHECK( torchDispatchModeState.infra_modes_[static_cast(mode_key)] == - std::nullopt, + c10::nullopt, "trying to set the current ", to_string(mode_key), ", but one already exists"); @@ -149,7 +149,7 @@ const std::optional> TorchDispatchModeTLS::unset_mode(TorchDispatchModeKey mode_key) { auto out = torchDispatchModeState.infra_modes_[static_cast(mode_key)]; torchDispatchModeState.infra_modes_[static_cast(mode_key)] = - std::nullopt; + c10::nullopt; if (out.has_value() && !any_modes_set()) { c10::impl::tls_set_dispatch_key_included(DispatchKey::Python, false); c10::impl::tls_set_dispatch_key_included( diff --git a/c10/cuda/CUDACachingAllocator.cpp b/c10/cuda/CUDACachingAllocator.cpp index e4535292cebaac..11bea6056e9d85 100644 --- a/c10/cuda/CUDACachingAllocator.cpp +++ b/c10/cuda/CUDACachingAllocator.cpp @@ -411,7 +411,7 @@ struct ExpandableSegment { return rangeFromHandles(begin, end); } while (end > handles_.size()) { - handles_.emplace_back(std::nullopt); + handles_.emplace_back(c10::nullopt); } for (auto i : c10::irange(begin, end)) { TORCH_INTERNAL_ASSERT(!handles_.at(i)); @@ -426,7 +426,7 @@ struct ExpandableSegment { if (status == CUDA_ERROR_OUT_OF_MEMORY) { for (auto j : c10::irange(begin, i)) { auto h = handles_.at(j).value(); - handles_.at(j) = std::nullopt; + handles_.at(j) = c10::nullopt; C10_CUDA_DRIVER_CHECK(DriverAPI::get()->cuMemRelease_(h)); } trimHandles(); @@ -507,7 +507,7 @@ struct ExpandableSegment { C10_CUDA_CHECK(cudaStreamSynchronize(stream_)); for (auto i : c10::irange(begin, end)) { CUmemGenericAllocationHandle h = handles_.at(i).value(); - handles_.at(i) = std::nullopt; + handles_.at(i) = c10::nullopt; C10_CUDA_DRIVER_CHECK(DriverAPI::get()->cuMemUnmap_( ptr_ + segment_size_ * i, segment_size_)); C10_CUDA_DRIVER_CHECK(DriverAPI::get()->cuMemRelease_(h)); diff --git a/c10/cuda/CUDAFunctions.cpp b/c10/cuda/CUDAFunctions.cpp index 8d88000b89db94..2b53eb4d7c7cb7 100644 --- a/c10/cuda/CUDAFunctions.cpp +++ b/c10/cuda/CUDAFunctions.cpp @@ -166,7 +166,7 @@ std::optional getDeviceIndexWithPrimaryContext() { return device_index; } } - return std::nullopt; + return c10::nullopt; } namespace _internal { diff --git a/c10/cuda/CUDAGuard.h b/c10/cuda/CUDAGuard.h index 65f5c5d191b7fb..254522893d5e08 100644 --- a/c10/cuda/CUDAGuard.h +++ b/c10/cuda/CUDAGuard.h @@ -242,7 +242,7 @@ struct OptionalCUDAStreamGuard { optional original_stream() const { auto r = guard_.original_stream(); if (r.has_value()) { - return std::make_optional(CUDAStream(CUDAStream::UNCHECKED, r.value())); + return make_optional(CUDAStream(CUDAStream::UNCHECKED, r.value())); } else { return nullopt; } @@ -254,7 +254,7 @@ struct OptionalCUDAStreamGuard { optional current_stream() const { auto r = guard_.current_stream(); if (r.has_value()) { - return std::make_optional(CUDAStream(CUDAStream::UNCHECKED, r.value())); + return make_optional(CUDAStream(CUDAStream::UNCHECKED, r.value())); } else { return nullopt; } diff --git a/c10/cuda/impl/CUDAGuardImpl.h b/c10/cuda/impl/CUDAGuardImpl.h index 1ef2fcb2c08f4d..ec50c8152b33e2 100644 --- a/c10/cuda/impl/CUDAGuardImpl.h +++ b/c10/cuda/impl/CUDAGuardImpl.h @@ -14,9 +14,9 @@ #include #include #include +#include #include #include -#include namespace c10::cuda::impl { @@ -45,7 +45,7 @@ struct CUDAGuardImpl final : public c10::impl::DeviceGuardImplInterface { const auto err = C10_CUDA_ERROR_HANDLED(c10::cuda::GetDevice(&device)); C10_CUDA_CHECK_WARN(err); if (err != cudaSuccess) { - return std::nullopt; + return c10::nullopt; } return Device(DeviceType::CUDA, device); } diff --git a/c10/test/core/DeviceGuard_test.cpp b/c10/test/core/DeviceGuard_test.cpp index 0869ea1168d167..63049ae7b555a2 100644 --- a/c10/test/core/DeviceGuard_test.cpp +++ b/c10/test/core/DeviceGuard_test.cpp @@ -36,7 +36,6 @@ TEST(OptionalDeviceGuard, ResetDeviceDifferentDeviceType) { g.reset_device(Device(DeviceType::HIP, 2), &hip_impl); ASSERT_EQ(FakeGuardImpl::getDeviceIndex(), 0); ASSERT_EQ(FakeGuardImpl::getDeviceIndex(), 2); - ASSERT_EQ(g.current_device(), std::make_optional(Device(DeviceType::HIP, 2))); - ASSERT_EQ( - g.original_device(), std::make_optional(Device(DeviceType::HIP, 0))); + ASSERT_EQ(g.current_device(), make_optional(Device(DeviceType::HIP, 2))); + ASSERT_EQ(g.original_device(), make_optional(Device(DeviceType::HIP, 0))); } diff --git a/c10/test/core/SymInt_test.cpp b/c10/test/core/SymInt_test.cpp index 7cefa1e4a771bf..8055ec7a325111 100644 --- a/c10/test/core/SymInt_test.cpp +++ b/c10/test/core/SymInt_test.cpp @@ -8,7 +8,7 @@ using namespace c10; #ifndef C10_MOBILE static void check(int64_t value) { const auto i = SymInt(value); - EXPECT_EQ(i.maybe_as_int(), std::make_optional(value)); + EXPECT_EQ(i.maybe_as_int(), c10::make_optional(value)); } TEST(SymIntTest, ConcreteInts) { diff --git a/c10/test/core/impl/InlineDeviceGuard_test.cpp b/c10/test/core/impl/InlineDeviceGuard_test.cpp index 2b4ad0c5b2381f..69db93e307bfe8 100644 --- a/c10/test/core/impl/InlineDeviceGuard_test.cpp +++ b/c10/test/core/impl/InlineDeviceGuard_test.cpp @@ -170,12 +170,12 @@ TEST(InlineOptionalDeviceGuard, SetDevice) { MaybeTestGuard g; DeviceIndex i = 1; g.set_device(dev(i)); - ASSERT_EQ(g.original_device(), std::make_optional(dev(init_i))); - ASSERT_EQ(g.current_device(), std::make_optional(dev(i))); + ASSERT_EQ(g.original_device(), make_optional(dev(init_i))); + ASSERT_EQ(g.current_device(), make_optional(dev(i))); ASSERT_EQ(TestGuardImpl::getDeviceIndex(), i); g.set_device(dev(i)); - ASSERT_EQ(g.original_device(), std::make_optional(dev(init_i))); - ASSERT_EQ(g.current_device(), std::make_optional(dev(i))); + ASSERT_EQ(g.original_device(), make_optional(dev(init_i))); + ASSERT_EQ(g.current_device(), make_optional(dev(i))); ASSERT_EQ(TestGuardImpl::getDeviceIndex(), i); } @@ -185,11 +185,11 @@ TEST(InlineOptionalDeviceGuard, SetIndex) { DeviceIndex i = 1; MaybeTestGuard g; g.set_index(i); - ASSERT_EQ(g.original_device(), std::make_optional(dev(init_i))); - ASSERT_EQ(g.current_device(), std::make_optional(dev(i))); + ASSERT_EQ(g.original_device(), make_optional(dev(init_i))); + ASSERT_EQ(g.current_device(), make_optional(dev(i))); ASSERT_EQ(TestGuardImpl::getDeviceIndex(), i); g.set_index(i); - ASSERT_EQ(g.original_device(), std::make_optional(dev(init_i))); - ASSERT_EQ(g.current_device(), std::make_optional(dev(i))); + ASSERT_EQ(g.original_device(), make_optional(dev(init_i))); + ASSERT_EQ(g.current_device(), make_optional(dev(i))); ASSERT_EQ(TestGuardImpl::getDeviceIndex(), i); } diff --git a/c10/test/core/impl/InlineStreamGuard_test.cpp b/c10/test/core/impl/InlineStreamGuard_test.cpp index 06c4b96ef913ef..692504cebd1ccc 100644 --- a/c10/test/core/impl/InlineStreamGuard_test.cpp +++ b/c10/test/core/impl/InlineStreamGuard_test.cpp @@ -109,8 +109,8 @@ TEST(InlineOptionalStreamGuard, Constructor) { ASSERT_EQ(TestGuardImpl::getDeviceIndex(), 1); ASSERT_EQ(TestGuardImpl::getCurrentStreamIdFor(1), 2); ASSERT_EQ(TestGuardImpl::getCurrentStreamIdFor(0), 0); - ASSERT_EQ(g.original_stream(), std::make_optional(stream(0, 0))); - ASSERT_EQ(g.current_stream(), std::make_optional(stream(1, 2))); + ASSERT_EQ(g.original_stream(), make_optional(stream(0, 0))); + ASSERT_EQ(g.current_stream(), make_optional(stream(1, 2))); } ASSERT_EQ(TestGuardImpl::getDeviceIndex(), 0); ASSERT_EQ(TestGuardImpl::getCurrentStreamIdFor(1), 0); @@ -120,8 +120,8 @@ TEST(InlineOptionalStreamGuard, Constructor) { ASSERT_EQ(TestGuardImpl::getDeviceIndex(), 1); ASSERT_EQ(TestGuardImpl::getCurrentStreamIdFor(1), 2); ASSERT_EQ(TestGuardImpl::getCurrentStreamIdFor(0), 0); - ASSERT_EQ(g.original_stream(), std::make_optional(stream(0, 0))); - ASSERT_EQ(g.current_stream(), std::make_optional(stream(1, 2))); + ASSERT_EQ(g.original_stream(), make_optional(stream(0, 0))); + ASSERT_EQ(g.current_stream(), make_optional(stream(1, 2))); } ASSERT_EQ(TestGuardImpl::getDeviceIndex(), 0); ASSERT_EQ(TestGuardImpl::getCurrentStreamIdFor(1), 0); @@ -146,8 +146,8 @@ TEST(InlineOptionalStreamGuard, ResetStreamSameDevice) { ASSERT_EQ(TestGuardImpl::getDeviceIndex(), 1); ASSERT_EQ(TestGuardImpl::getCurrentStreamIdFor(1), 3); ASSERT_EQ(TestGuardImpl::getCurrentStreamIdFor(0), 0); - ASSERT_EQ(g.original_stream(), std::make_optional(stream(0, 0))); - ASSERT_EQ(g.current_stream(), std::make_optional(stream(1, 3))); + ASSERT_EQ(g.original_stream(), make_optional(stream(0, 0))); + ASSERT_EQ(g.current_stream(), make_optional(stream(1, 3))); } ASSERT_EQ(TestGuardImpl::getDeviceIndex(), 0); ASSERT_EQ(TestGuardImpl::getCurrentStreamIdFor(1), 0); @@ -164,8 +164,8 @@ TEST(InlineOptionalStreamGuard, ResetStreamDifferentDevice) { ASSERT_EQ(TestGuardImpl::getCurrentStreamIdFor(2), 3); ASSERT_EQ(TestGuardImpl::getCurrentStreamIdFor(1), 0); ASSERT_EQ(TestGuardImpl::getCurrentStreamIdFor(0), 0); - ASSERT_EQ(g.original_stream(), std::make_optional(stream(0, 0))); - ASSERT_EQ(g.current_stream(), std::make_optional(stream(2, 3))); + ASSERT_EQ(g.original_stream(), make_optional(stream(0, 0))); + ASSERT_EQ(g.current_stream(), make_optional(stream(2, 3))); } ASSERT_EQ(TestGuardImpl::getDeviceIndex(), 0); ASSERT_EQ(TestGuardImpl::getCurrentStreamIdFor(2), 0); diff --git a/c10/test/util/optional_test.cpp b/c10/test/util/optional_test.cpp index e9496d9dc2887e..aa4c5a527ce667 100644 --- a/c10/test/util/optional_test.cpp +++ b/c10/test/util/optional_test.cpp @@ -1,4 +1,4 @@ -#include +#include #include #include @@ -67,7 +67,7 @@ TYPED_TEST(OptionalTest, Empty) { EXPECT_FALSE(empty.has_value()); // NOLINTNEXTLINE(bugprone-unchecked-optional-access,hicpp-avoid-goto,cppcoreguidelines-avoid-goto) - EXPECT_THROW(empty.value(), std::bad_optional_access); + EXPECT_THROW(empty.value(), c10::bad_optional_access); } TYPED_TEST(OptionalTest, Initialized) { @@ -111,32 +111,32 @@ TEST_P(SelfCompareTest, SelfCompare) { INSTANTIATE_TEST_SUITE_P( nullopt, SelfCompareTest, - testing::Values(std::nullopt)); + testing::Values(c10::nullopt)); INSTANTIATE_TEST_SUITE_P( int, SelfCompareTest, - testing::Values(std::make_optional(2))); + testing::Values(c10::make_optional(2))); TEST(OptionalTest, Nullopt) { std::optional x = 2; - EXPECT_THAT(std::nullopt, Not(Eq(x))); - EXPECT_THAT(x, Not(Eq(std::nullopt))); + EXPECT_THAT(c10::nullopt, Not(Eq(x))); + EXPECT_THAT(x, Not(Eq(c10::nullopt))); - EXPECT_THAT(x, Ne(std::nullopt)); - EXPECT_THAT(std::nullopt, Ne(x)); + EXPECT_THAT(x, Ne(c10::nullopt)); + EXPECT_THAT(c10::nullopt, Ne(x)); - EXPECT_THAT(x, Not(Lt(std::nullopt))); - EXPECT_THAT(std::nullopt, Lt(x)); + EXPECT_THAT(x, Not(Lt(c10::nullopt))); + EXPECT_THAT(c10::nullopt, Lt(x)); - EXPECT_THAT(x, Not(Le(std::nullopt))); - EXPECT_THAT(std::nullopt, Le(x)); + EXPECT_THAT(x, Not(Le(c10::nullopt))); + EXPECT_THAT(c10::nullopt, Le(x)); - EXPECT_THAT(x, Gt(std::nullopt)); - EXPECT_THAT(std::nullopt, Not(Gt(x))); + EXPECT_THAT(x, Gt(c10::nullopt)); + EXPECT_THAT(c10::nullopt, Not(Gt(x))); - EXPECT_THAT(x, Ge(std::nullopt)); - EXPECT_THAT(std::nullopt, Not(Ge(x))); + EXPECT_THAT(x, Ge(c10::nullopt)); + EXPECT_THAT(c10::nullopt, Not(Ge(x))); } // Ensure comparisons work... diff --git a/c10/util/Backtrace.cpp b/c10/util/Backtrace.cpp index d461267000befc..7d0fedbb335a29 100644 --- a/c10/util/Backtrace.cpp +++ b/c10/util/Backtrace.cpp @@ -1,7 +1,7 @@ #include +#include #include #include -#include #include #include @@ -150,19 +150,19 @@ std::optional parse_frame_information( auto function_name_start = frame_string.find('('); if (function_name_start == std::string::npos) { - return std::nullopt; + return c10::nullopt; } function_name_start += 1; auto offset_start = frame_string.find('+', function_name_start); if (offset_start == std::string::npos) { - return std::nullopt; + return c10::nullopt; } offset_start += 1; const auto offset_end = frame_string.find(')', offset_start); if (offset_end == std::string::npos) { - return std::nullopt; + return c10::nullopt; } frame.object_file = frame_string.substr(0, function_name_start - 1); @@ -186,7 +186,7 @@ std::optional parse_frame_information( skip >> frame.offset_into_function; #else #warning Unknown standard library, backtraces may have incomplete debug information - return std::nullopt; + return c10::nullopt; #endif // defined(__GLIBCXX__) // Some system-level functions don't have sufficient debug information, so diff --git a/c10/util/OptionalArrayRef.h b/c10/util/OptionalArrayRef.h index ae4f4f1f2c67bd..98237bba92f56d 100644 --- a/c10/util/OptionalArrayRef.h +++ b/c10/util/OptionalArrayRef.h @@ -12,9 +12,9 @@ #pragma once #include +#include #include #include -#include #include #include @@ -27,16 +27,16 @@ class OptionalArrayRef final { constexpr OptionalArrayRef() noexcept = default; - constexpr OptionalArrayRef(std::nullopt_t) noexcept {} + constexpr OptionalArrayRef(nullopt_t) noexcept {} OptionalArrayRef(const OptionalArrayRef& other) = default; OptionalArrayRef(OptionalArrayRef&& other) noexcept = default; - constexpr OptionalArrayRef(const std::optional>& other) noexcept + constexpr OptionalArrayRef(const optional>& other) noexcept : wrapped_opt_array_ref(other) {} - constexpr OptionalArrayRef(std::optional>&& other) noexcept + constexpr OptionalArrayRef(optional>&& other) noexcept : wrapped_opt_array_ref(std::move(other)) {} constexpr OptionalArrayRef(const T& value) noexcept @@ -89,8 +89,8 @@ class OptionalArrayRef final { // Assignment - constexpr OptionalArrayRef& operator=(std::nullopt_t) noexcept { - wrapped_opt_array_ref = std::nullopt; + constexpr OptionalArrayRef& operator=(nullopt_t) noexcept { + wrapped_opt_array_ref = c10::nullopt; return *this; } @@ -99,13 +99,13 @@ class OptionalArrayRef final { OptionalArrayRef& operator=(OptionalArrayRef&& other) noexcept = default; constexpr OptionalArrayRef& operator=( - const std::optional>& other) noexcept { + const optional>& other) noexcept { wrapped_opt_array_ref = other; return *this; } constexpr OptionalArrayRef& operator=( - std::optional>&& other) noexcept { + optional>&& other) noexcept { wrapped_opt_array_ref = std::move(other); return *this; } @@ -213,7 +213,7 @@ class OptionalArrayRef final { } private: - std::optional> wrapped_opt_array_ref; + optional> wrapped_opt_array_ref; }; using OptionalIntArrayRef = OptionalArrayRef; diff --git a/c10/xpu/test/impl/XPUStreamTest.cpp b/c10/xpu/test/impl/XPUStreamTest.cpp index 6cbe3ae6721587..01a1dbb62621b2 100644 --- a/c10/xpu/test/impl/XPUStreamTest.cpp +++ b/c10/xpu/test/impl/XPUStreamTest.cpp @@ -1,9 +1,9 @@ #include +#include #include #include #include -#include #include #include diff --git a/torch/csrc/Module.cpp b/torch/csrc/Module.cpp index 96788c5d79f376..00a2c0bbe30267 100644 --- a/torch/csrc/Module.cpp +++ b/torch/csrc/Module.cpp @@ -1,8 +1,8 @@ #include +#include #include #include #include -#include #ifndef _MSC_VER #include @@ -1817,7 +1817,7 @@ Call this whenever a new thread is created in order to propagate values from transposed_, output_padding_, std::move(groups_), - std::nullopt); + c10::nullopt); }, py::arg("input"), py::arg("weight"), @@ -1842,7 +1842,7 @@ Call this whenever a new thread is created in order to propagate values from at::SymIntArrayRef output_padding_, c10::SymInt groups_, std::optional> bias_sizes_opt) { - c10::OptionalArrayRef ref = std::nullopt; + c10::OptionalArrayRef ref = c10::nullopt; if (bias_sizes_opt) { ref = (*bias_sizes_opt); } @@ -2031,7 +2031,7 @@ Call this whenever a new thread is created in order to propagate values from py_module.def( "_get_accelerator", - [](std::optional check = std::nullopt) { + [](std::optional check = c10::nullopt) { return c10::Device( at::getAccelerator(check.value_or(false)) .value_or(c10::DeviceType::CPU), diff --git a/torch/csrc/Storage.cpp b/torch/csrc/Storage.cpp index 77520b6f1cdb1f..aa5584abd39e4c 100644 --- a/torch/csrc/Storage.cpp +++ b/torch/csrc/Storage.cpp @@ -153,7 +153,7 @@ static bool THPStorage_isPreservable(THPStorage* self) { if (storage.unsafeGetStorageImpl()->pyobj_slot()->check_pyobj( getPyInterpreter(), /*ignore_hermetic_tls=*/true) != - std::make_optional((PyObject*)self)) { + c10::make_optional((PyObject*)self)) { return false; } if (storage.use_count() <= 1) { diff --git a/torch/csrc/api/include/torch/expanding_array.h b/torch/csrc/api/include/torch/expanding_array.h index 62c12d2e0ac8b4..f0901b06af68cb 100644 --- a/torch/csrc/api/include/torch/expanding_array.h +++ b/torch/csrc/api/include/torch/expanding_array.h @@ -2,8 +2,8 @@ #include #include +#include #include -#include #include #include diff --git a/torch/csrc/api/include/torch/fft.h b/torch/csrc/api/include/torch/fft.h index ef6d9b1bc23620..d9a3430a7a2496 100644 --- a/torch/csrc/api/include/torch/fft.h +++ b/torch/csrc/api/include/torch/fft.h @@ -15,9 +15,9 @@ namespace fft { /// ``` inline Tensor fft( const Tensor& self, - std::optional n = std::nullopt, + std::optional n = c10::nullopt, int64_t dim = -1, - std::optional norm = std::nullopt) { + std::optional norm = c10::nullopt) { return torch::fft_fft_symint(self, n, dim, norm); } @@ -31,9 +31,9 @@ inline Tensor fft( /// ``` inline Tensor ifft( const Tensor& self, - std::optional n = std::nullopt, + std::optional n = c10::nullopt, int64_t dim = -1, - std::optional norm = std::nullopt) { + std::optional norm = c10::nullopt) { return torch::fft_ifft_symint(self, n, dim, norm); } @@ -47,9 +47,9 @@ inline Tensor ifft( /// ``` inline Tensor fft2( const Tensor& self, - OptionalIntArrayRef s = std::nullopt, + OptionalIntArrayRef s = c10::nullopt, IntArrayRef dim = {-2, -1}, - std::optional norm = std::nullopt) { + std::optional norm = c10::nullopt) { return torch::fft_fft2(self, s, dim, norm); } @@ -63,9 +63,9 @@ inline Tensor fft2( /// ``` inline Tensor ifft2( const Tensor& self, - at::OptionalIntArrayRef s = std::nullopt, + at::OptionalIntArrayRef s = c10::nullopt, IntArrayRef dim = {-2, -1}, - std::optional norm = std::nullopt) { + std::optional norm = c10::nullopt) { return torch::fft_ifft2(self, s, dim, norm); } @@ -79,9 +79,9 @@ inline Tensor ifft2( /// ``` inline Tensor fftn( const Tensor& self, - at::OptionalIntArrayRef s = std::nullopt, - at::OptionalIntArrayRef dim = std::nullopt, - std::optional norm = std::nullopt) { + at::OptionalIntArrayRef s = c10::nullopt, + at::OptionalIntArrayRef dim = c10::nullopt, + std::optional norm = c10::nullopt) { return torch::fft_fftn(self, s, dim, norm); } @@ -95,9 +95,9 @@ inline Tensor fftn( /// ``` inline Tensor ifftn( const Tensor& self, - at::OptionalIntArrayRef s = std::nullopt, - at::OptionalIntArrayRef dim = std::nullopt, - std::optional norm = std::nullopt) { + at::OptionalIntArrayRef s = c10::nullopt, + at::OptionalIntArrayRef dim = c10::nullopt, + std::optional norm = c10::nullopt) { return torch::fft_ifftn(self, s, dim, norm); } @@ -112,9 +112,9 @@ inline Tensor ifftn( /// ``` inline Tensor rfft( const Tensor& self, - std::optional n = std::nullopt, + std::optional n = c10::nullopt, int64_t dim = -1, - std::optional norm = std::nullopt) { + std::optional norm = c10::nullopt) { return torch::fft_rfft_symint(self, n, dim, norm); } @@ -131,9 +131,9 @@ inline Tensor rfft( /// ``` inline Tensor irfft( const Tensor& self, - std::optional n = std::nullopt, + std::optional n = c10::nullopt, int64_t dim = -1, - std::optional norm = std::nullopt) { + std::optional norm = c10::nullopt) { return torch::fft_irfft_symint(self, n, dim, norm); } @@ -147,9 +147,9 @@ inline Tensor irfft( /// ``` inline Tensor rfft2( const Tensor& self, - at::OptionalIntArrayRef s = std::nullopt, + at::OptionalIntArrayRef s = c10::nullopt, IntArrayRef dim = {-2, -1}, - std::optional norm = std::nullopt) { + std::optional norm = c10::nullopt) { return torch::fft_rfft2(self, s, dim, norm); } @@ -163,9 +163,9 @@ inline Tensor rfft2( /// ``` inline Tensor irfft2( const Tensor& self, - at::OptionalIntArrayRef s = std::nullopt, + at::OptionalIntArrayRef s = c10::nullopt, IntArrayRef dim = {-2, -1}, - std::optional norm = std::nullopt) { + std::optional norm = c10::nullopt) { return torch::fft_irfft2(self, s, dim, norm); } @@ -179,9 +179,9 @@ inline Tensor irfft2( /// ``` inline Tensor rfftn( const Tensor& self, - at::OptionalIntArrayRef s = std::nullopt, - at::OptionalIntArrayRef dim = std::nullopt, - std::optional norm = std::nullopt) { + at::OptionalIntArrayRef s = c10::nullopt, + at::OptionalIntArrayRef dim = c10::nullopt, + std::optional norm = c10::nullopt) { return torch::fft_rfftn(self, s, dim, norm); } @@ -195,9 +195,9 @@ inline Tensor rfftn( /// ``` inline Tensor irfftn( const Tensor& self, - at::OptionalIntArrayRef s = std::nullopt, - at::OptionalIntArrayRef dim = std::nullopt, - std::optional norm = std::nullopt) { + at::OptionalIntArrayRef s = c10::nullopt, + at::OptionalIntArrayRef dim = c10::nullopt, + std::optional norm = c10::nullopt) { return torch::fft_irfftn(self, s, dim, norm); } @@ -215,9 +215,9 @@ inline Tensor irfftn( /// ``` inline Tensor hfft( const Tensor& self, - std::optional n = std::nullopt, + std::optional n = c10::nullopt, int64_t dim = -1, - std::optional norm = std::nullopt) { + std::optional norm = c10::nullopt) { return torch::fft_hfft_symint(self, n, dim, norm); } @@ -234,9 +234,9 @@ inline Tensor hfft( /// ``` inline Tensor ihfft( const Tensor& self, - std::optional n = std::nullopt, + std::optional n = c10::nullopt, int64_t dim = -1, - std::optional norm = std::nullopt) { + std::optional norm = c10::nullopt) { return torch::fft_ihfft_symint(self, n, dim, norm); } @@ -253,9 +253,9 @@ inline Tensor ihfft( /// ``` inline Tensor hfft2( const Tensor& self, - at::OptionalIntArrayRef s = std::nullopt, + at::OptionalIntArrayRef s = c10::nullopt, IntArrayRef dim = {-2, -1}, - std::optional norm = std::nullopt) { + std::optional norm = c10::nullopt) { return torch::fft_hfft2(self, s, dim, norm); } @@ -273,9 +273,9 @@ inline Tensor hfft2( /// ``` inline Tensor ihfft2( const Tensor& self, - at::OptionalIntArrayRef s = std::nullopt, + at::OptionalIntArrayRef s = c10::nullopt, IntArrayRef dim = {-2, -1}, - std::optional norm = std::nullopt) { + std::optional norm = c10::nullopt) { return torch::fft_ihfft2(self, s, dim, norm); } @@ -292,9 +292,9 @@ inline Tensor ihfft2( /// ``` inline Tensor hfftn( const Tensor& self, - at::OptionalIntArrayRef s = std::nullopt, + at::OptionalIntArrayRef s = c10::nullopt, IntArrayRef dim = {-2, -1}, - std::optional norm = std::nullopt) { + std::optional norm = c10::nullopt) { return torch::fft_hfftn(self, s, dim, norm); } @@ -312,9 +312,9 @@ inline Tensor hfftn( /// ``` inline Tensor ihfftn( const Tensor& self, - at::OptionalIntArrayRef s = std::nullopt, + at::OptionalIntArrayRef s = c10::nullopt, IntArrayRef dim = {-2, -1}, - std::optional norm = std::nullopt) { + std::optional norm = c10::nullopt) { return torch::fft_ihfftn(self, s, dim, norm); } @@ -364,7 +364,7 @@ inline Tensor rfftfreq(int64_t n, const TensorOptions& options) { /// ``` inline Tensor fftshift( const Tensor& x, - at::OptionalIntArrayRef dim = std::nullopt) { + at::OptionalIntArrayRef dim = c10::nullopt) { return torch::fft_fftshift(x, dim); } @@ -381,7 +381,7 @@ inline Tensor fftshift( /// ``` inline Tensor ifftshift( const Tensor& x, - at::OptionalIntArrayRef dim = std::nullopt) { + at::OptionalIntArrayRef dim = c10::nullopt) { return torch::fft_ifftshift(x, dim); } diff --git a/torch/csrc/api/include/torch/nested.h b/torch/csrc/api/include/torch/nested.h index 2e4365e0031cc0..780aab42304723 100644 --- a/torch/csrc/api/include/torch/nested.h +++ b/torch/csrc/api/include/torch/nested.h @@ -26,7 +26,7 @@ inline at::Tensor nested_tensor( auto out = at::_nested_tensor_from_tensor_list( nested_tensor_data, c10::typeMetaToScalarType(options.dtype()), - std::nullopt, + c10::nullopt, options.device(), options.pinned_memory()); if (options.has_requires_grad() && options.requires_grad()) { @@ -55,7 +55,7 @@ inline at::Tensor nested_tensor( auto out = at::_nested_tensor_from_tensor_list( tensor_list, c10::typeMetaToScalarType(options.dtype()), - std::nullopt, + c10::nullopt, options.device(), options.pinned_memory()); if (options.has_requires_grad() && options.requires_grad()) { @@ -72,10 +72,10 @@ inline at::Tensor nested_tensor( /// ``` inline at::Tensor as_nested_tensor( at::TensorList list, - std::optional dtype = std::nullopt, - std::optional device = std::nullopt) { + std::optional dtype = c10::nullopt, + std::optional device = c10::nullopt) { return at::_nested_tensor_from_tensor_list( - list, dtype, std::nullopt, device, std::nullopt); + list, dtype, c10::nullopt, device, c10::nullopt); } /// Nested to padded tensor @@ -87,7 +87,7 @@ inline at::Tensor as_nested_tensor( inline at::Tensor to_padded_tensor( const at::Tensor& self, double padding, - at::OptionalIntArrayRef output_size = std::nullopt) { + at::OptionalIntArrayRef output_size = c10::nullopt) { return at::nested_to_padded_tensor(self, padding, output_size); } diff --git a/torch/csrc/api/include/torch/nn/functional/activation.h b/torch/csrc/api/include/torch/nn/functional/activation.h index 5ae6fcc317602a..89e596f71d143d 100644 --- a/torch/csrc/api/include/torch/nn/functional/activation.h +++ b/torch/csrc/api/include/torch/nn/functional/activation.h @@ -236,7 +236,7 @@ inline Tensor softmax( std::optional dtype) { Tensor ret; - if (dtype == std::nullopt) { + if (dtype == c10::nullopt) { ret = input.softmax(dim); } else { ret = input.softmax(dim, dtype); @@ -273,7 +273,7 @@ inline Tensor softmin( std::optional dtype) { Tensor ret; - if (dtype == std::nullopt) { + if (dtype == c10::nullopt) { ret = (-input).softmax(dim); } else { ret = (-input).softmax(dim, dtype); @@ -310,7 +310,7 @@ inline Tensor log_softmax( std::optional dtype) { Tensor ret; - if (dtype == std::nullopt) { + if (dtype == c10::nullopt) { ret = input.log_softmax(dim); } else { ret = input.log_softmax(dim, dtype); diff --git a/torch/csrc/api/include/torch/nn/functional/embedding.h b/torch/csrc/api/include/torch/nn/functional/embedding.h index 602268ab2eba30..b06b0a3dc1e851 100644 --- a/torch/csrc/api/include/torch/nn/functional/embedding.h +++ b/torch/csrc/api/include/torch/nn/functional/embedding.h @@ -31,7 +31,7 @@ inline Tensor embedding( bool sparse) { auto input_ = input; - if (padding_idx != std::nullopt) { + if (padding_idx != c10::nullopt) { if (*padding_idx > 0) { TORCH_CHECK( *padding_idx < weight.size(0), @@ -46,7 +46,7 @@ inline Tensor embedding( padding_idx = -1; } - if (max_norm != std::nullopt) { + if (max_norm != c10::nullopt) { input_ = input_.contiguous(); // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions) _no_grad_embedding_renorm_(weight, input_, *max_norm, norm_type); @@ -149,7 +149,7 @@ inline Tensor embedding_bag( TORCH_CHECK(false, "mode has to be one of sum, mean or max"); } - if (max_norm != std::nullopt) { + if (max_norm != c10::nullopt) { // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions) _no_grad_embedding_renorm_(weight, input_, *max_norm, norm_type); } diff --git a/torch/csrc/api/include/torch/nn/functional/loss.h b/torch/csrc/api/include/torch/nn/functional/loss.h index 6a425e606caf26..d1e285d0a0f19a 100644 --- a/torch/csrc/api/include/torch/nn/functional/loss.h +++ b/torch/csrc/api/include/torch/nn/functional/loss.h @@ -346,7 +346,7 @@ inline Tensor smooth_l1_loss( const Tensor& input, const Tensor& target, SmoothL1LossFuncOptions::reduction_t reduction, - std::optional beta_opt = std::nullopt) { + std::optional beta_opt = c10::nullopt) { if (target.sizes() != input.sizes()) { TORCH_WARN( "Using a target size (", @@ -405,7 +405,7 @@ inline Tensor smooth_l1_loss( const SmoothL1LossFuncOptions& options, double beta) { TORCH_CHECK( - options.beta() == std::nullopt, + options.beta() == c10::nullopt, "expected beta not to be provided in 'options', but got ", options.beta().value()); return detail::smooth_l1_loss(input, target, options.reduction(), beta); diff --git a/torch/csrc/api/include/torch/nn/functional/normalization.h b/torch/csrc/api/include/torch/nn/functional/normalization.h index 965cfcd9ac83fa..53bd61839f7451 100644 --- a/torch/csrc/api/include/torch/nn/functional/normalization.h +++ b/torch/csrc/api/include/torch/nn/functional/normalization.h @@ -17,7 +17,7 @@ inline Tensor normalize( int64_t dim, double eps, std::optional out) { - if (out == std::nullopt) { + if (out == c10::nullopt) { auto denom = input.norm(p, dim, true).clamp_min(eps).expand_as(input); return input / denom; } else { @@ -115,7 +115,7 @@ inline Tensor local_response_norm( /*padding=*/0, /*ceil_mode=*/false, /*count_include_pad=*/true, - /*divisor_override=*/std::nullopt) + /*divisor_override=*/c10::nullopt) .squeeze(1); } else { auto sizes = input.sizes(); @@ -132,7 +132,7 @@ inline Tensor local_response_norm( /*padding=*/0, /*ceil_mode=*/false, /*count_include_pad=*/true, - /*divisor_override=*/std::nullopt) + /*divisor_override=*/c10::nullopt) .squeeze(1); div = div.view(sizes); } diff --git a/torch/csrc/api/include/torch/nn/functional/pooling.h b/torch/csrc/api/include/torch/nn/functional/pooling.h index 798467c0e0a681..be3009f62201a2 100644 --- a/torch/csrc/api/include/torch/nn/functional/pooling.h +++ b/torch/csrc/api/include/torch/nn/functional/pooling.h @@ -820,15 +820,15 @@ inline std::tuple fractional_max_pool2d_with_indices( const std::optional>& output_size, const std::optional>& output_ratio, const Tensor& _random_samples) { - if (output_size == std::nullopt && output_ratio == std::nullopt) { + if (output_size == c10::nullopt && output_ratio == c10::nullopt) { TORCH_CHECK( false, "fractional_max_pool2d requires specifying either ", "an output_size or an output_ratio"); } std::optional> output_size_ = output_size; - if (output_size_ == std::nullopt) { - TORCH_INTERNAL_ASSERT(output_ratio != std::nullopt); + if (output_size_ == c10::nullopt) { + TORCH_INTERNAL_ASSERT(output_ratio != c10::nullopt); output_size_ = { (int64_t)(static_cast(input.size(-2)) * (*output_ratio.value())[0]), @@ -913,7 +913,7 @@ inline std::tuple fractional_max_pool3d_with_indices( const std::optional>& output_size, const std::optional>& output_ratio, const Tensor& _random_samples) { - if (output_size == std::nullopt && output_ratio == std::nullopt) { + if (output_size == c10::nullopt && output_ratio == c10::nullopt) { TORCH_CHECK( false, "fractional_max_pool3d requires specifying either ", @@ -921,8 +921,8 @@ inline std::tuple fractional_max_pool3d_with_indices( } std::optional> output_size_ = output_size; - if (output_size_ == std::nullopt) { - TORCH_INTERNAL_ASSERT(output_ratio != std::nullopt); + if (output_size_ == c10::nullopt) { + TORCH_INTERNAL_ASSERT(output_ratio != c10::nullopt); output_size_ = { (int64_t)(static_cast(input.size(-3)) * (*output_ratio.value())[0]), @@ -1066,7 +1066,7 @@ inline Tensor lp_pool2d( /*padding=*/0, ceil_mode, /*count_include_pad=*/true, - /*divisor_override=*/std::nullopt); + /*divisor_override=*/c10::nullopt); return (torch::sign(out) * relu(torch::abs(out))) .mul(kw * kh) @@ -1116,7 +1116,7 @@ inline Tensor lp_pool3d( /*padding=*/0, ceil_mode, /*count_include_pad=*/true, - /*divisor_override=*/std::nullopt); + /*divisor_override=*/c10::nullopt); return (torch::sign(out) * relu(torch::abs(out))) .mul(kd * kw * kh) diff --git a/torch/csrc/api/include/torch/nn/functional/upsampling.h b/torch/csrc/api/include/torch/nn/functional/upsampling.h index 75707ef091a783..38c5c51f9a475e 100644 --- a/torch/csrc/api/include/torch/nn/functional/upsampling.h +++ b/torch/csrc/api/include/torch/nn/functional/upsampling.h @@ -19,13 +19,13 @@ inline std::vector _interp_output_size( std::optional>, std::optional> closed_over_args) { auto [input, size, scale_factor, recompute_scale_factor] = closed_over_args; - if (size == std::nullopt && scale_factor == std::nullopt) { + if (size == c10::nullopt && scale_factor == c10::nullopt) { TORCH_CHECK(false, "either size or scale_factor should be defined"); } - if (size != std::nullopt && scale_factor != std::nullopt) { + if (size != c10::nullopt && scale_factor != c10::nullopt) { TORCH_CHECK(false, "only one of size or scale_factor should be defined"); } - if (scale_factor != std::nullopt) { + if (scale_factor != c10::nullopt) { if (static_cast(scale_factor.value().size()) != dim) { TORCH_CHECK( false, @@ -36,14 +36,14 @@ inline std::vector _interp_output_size( torch::ArrayRef(*scale_factor)); } } - if (size != std::nullopt) { + if (size != c10::nullopt) { return *size; } - TORCH_INTERNAL_ASSERT(scale_factor != std::nullopt); + TORCH_INTERNAL_ASSERT(scale_factor != c10::nullopt); auto scale_factors = *scale_factor; - if (recompute_scale_factor == std::nullopt) { + if (recompute_scale_factor == c10::nullopt) { // only warn when the scales have floating values since // the result for ints is the same with/without recompute_scale_factor bool is_float_scale_factor = false; @@ -83,14 +83,14 @@ inline Tensor interpolate( bool antialias) { if (std::holds_alternative(mode) || std::get_if(&mode)) { - if (align_corners != std::nullopt) { + if (align_corners != c10::nullopt) { TORCH_CHECK( false, "align_corners option can only be set with the " "interpolating modes: linear | bilinear | bicubic | trilinear"); } } else { - if (align_corners == std::nullopt) { + if (align_corners == c10::nullopt) { TORCH_WARN( "Default upsampling behavior when mode=", enumtype::get_enum_name(mode), @@ -114,8 +114,8 @@ inline Tensor interpolate( auto scale_factor_len = input.dim() - 2; std::vector> scale_factor_list( - scale_factor_len, std::nullopt); - if (scale_factor != std::nullopt && !recompute_scale_factor.value_or(false)) { + scale_factor_len, c10::nullopt); + if (scale_factor != c10::nullopt && !recompute_scale_factor.value_or(false)) { auto _scale_factor_repeated = *scale_factor; scale_factor_list = {}; for (const auto& elem : _scale_factor_repeated) { @@ -181,7 +181,7 @@ inline Tensor interpolate( input, _interp_output_size(3, std::move(closed_over_args))); } else if (input.dim() == 3 && std::get_if(&mode)) { TORCH_CHECK( - align_corners != std::nullopt, "align_corners should be specified."); + align_corners != c10::nullopt, "align_corners should be specified."); return torch::upsample_linear1d( input, _interp_output_size(1, std::move(closed_over_args)), @@ -195,7 +195,7 @@ inline Tensor interpolate( TORCH_CHECK(false, "Got 4D input, but linear mode needs 3D input"); } else if (input.dim() == 4 && std::get_if(&mode)) { TORCH_CHECK( - align_corners != std::nullopt, "align_corners should be specified."); + align_corners != c10::nullopt, "align_corners should be specified."); if (antialias) { return torch::_upsample_bilinear2d_aa( input, @@ -218,7 +218,7 @@ inline Tensor interpolate( TORCH_CHECK(false, "Got 5D input, but bilinear mode needs 4D input"); } else if (input.dim() == 5 && std::get_if(&mode)) { TORCH_CHECK( - align_corners != std::nullopt, "align_corners should be specified."); + align_corners != c10::nullopt, "align_corners should be specified."); return torch::upsample_trilinear3d( input, _interp_output_size(3, std::move(closed_over_args)), @@ -228,7 +228,7 @@ inline Tensor interpolate( scale_factor_list.at(2)); } else if (input.dim() == 4 && std::get_if(&mode)) { TORCH_CHECK( - align_corners != std::nullopt, "align_corners should be specified."); + align_corners != c10::nullopt, "align_corners should be specified."); if (antialias) { return torch::_upsample_bicubic2d_aa( input, diff --git a/torch/csrc/api/include/torch/nn/modules/batchnorm.h b/torch/csrc/api/include/torch/nn/modules/batchnorm.h index 0f5e32746936eb..ec76c6b4a6fbc6 100644 --- a/torch/csrc/api/include/torch/nn/modules/batchnorm.h +++ b/torch/csrc/api/include/torch/nn/modules/batchnorm.h @@ -106,7 +106,7 @@ class BatchNormImplBase : public NormImplBase { this->_check_input_dim(input); // NOLINTNEXTLINE(cppcoreguidelines-init-variables) double exponential_average_factor; - if (this->options.momentum() == std::nullopt) { + if (this->options.momentum() == c10::nullopt) { exponential_average_factor = 0.0; } else { exponential_average_factor = this->options.momentum().value(); @@ -116,7 +116,7 @@ class BatchNormImplBase : public NormImplBase { if (this->num_batches_tracked.defined()) { this->num_batches_tracked += 1; if (this->options.momentum() == - std::nullopt) { // use cumulative moving average + c10::nullopt) { // use cumulative moving average exponential_average_factor = 1.0 / this->num_batches_tracked.template item(); } else { // use exponential moving average diff --git a/torch/csrc/api/include/torch/nn/modules/conv.h b/torch/csrc/api/include/torch/nn/modules/conv.h index e44fd44b954abe..9c55254ddb9103 100644 --- a/torch/csrc/api/include/torch/nn/modules/conv.h +++ b/torch/csrc/api/include/torch/nn/modules/conv.h @@ -350,7 +350,7 @@ class TORCH_API ConvTranspose1dImpl explicit ConvTranspose1dImpl(ConvTranspose1dOptions options_); Tensor forward( const Tensor& input, - const std::optional& output_size = std::nullopt); + const std::optional& output_size = c10::nullopt); protected: FORWARD_HAS_DEFAULT_ARGS({1, AnyValue(std::optional())}) @@ -392,7 +392,7 @@ class TORCH_API ConvTranspose2dImpl explicit ConvTranspose2dImpl(ConvTranspose2dOptions options_); Tensor forward( const Tensor& input, - const std::optional& output_size = std::nullopt); + const std::optional& output_size = c10::nullopt); protected: FORWARD_HAS_DEFAULT_ARGS({1, AnyValue(std::optional())}) @@ -434,7 +434,7 @@ class TORCH_API ConvTranspose3dImpl explicit ConvTranspose3dImpl(ConvTranspose3dOptions options_); Tensor forward( const Tensor& input, - const std::optional& output_size = std::nullopt); + const std::optional& output_size = c10::nullopt); protected: FORWARD_HAS_DEFAULT_ARGS({1, AnyValue(std::optional())}) diff --git a/torch/csrc/api/include/torch/nn/modules/pooling.h b/torch/csrc/api/include/torch/nn/modules/pooling.h index 0fac60edbcde40..6bcdca463b1ba9 100644 --- a/torch/csrc/api/include/torch/nn/modules/pooling.h +++ b/torch/csrc/api/include/torch/nn/modules/pooling.h @@ -507,7 +507,7 @@ class TORCH_API MaxUnpool1dImpl : public MaxUnpoolImpl<1, MaxUnpool1dImpl> { Tensor forward( const Tensor& input, const Tensor& indices, - const std::optional>& output_size = std::nullopt); + const std::optional>& output_size = c10::nullopt); protected: FORWARD_HAS_DEFAULT_ARGS({2, AnyValue(std::optional>())}) @@ -539,7 +539,7 @@ class TORCH_API MaxUnpool2dImpl : public MaxUnpoolImpl<2, MaxUnpool2dImpl> { Tensor forward( const Tensor& input, const Tensor& indices, - const std::optional>& output_size = std::nullopt); + const std::optional>& output_size = c10::nullopt); protected: FORWARD_HAS_DEFAULT_ARGS({2, AnyValue(std::optional>())}) @@ -571,7 +571,7 @@ class TORCH_API MaxUnpool3dImpl : public MaxUnpoolImpl<3, MaxUnpool3dImpl> { Tensor forward( const Tensor& input, const Tensor& indices, - const std::optional>& output_size = std::nullopt); + const std::optional>& output_size = c10::nullopt); protected: FORWARD_HAS_DEFAULT_ARGS({2, AnyValue(std::optional>())}) diff --git a/torch/csrc/api/include/torch/nn/modules/utils.h b/torch/csrc/api/include/torch/nn/modules/utils.h index 6eaa0c1fb2c73e..869027a241492d 100644 --- a/torch/csrc/api/include/torch/nn/modules/utils.h +++ b/torch/csrc/api/include/torch/nn/modules/utils.h @@ -1,8 +1,8 @@ #pragma once #include +#include #include -#include #include diff --git a/torch/csrc/api/include/torch/nn/options/activation.h b/torch/csrc/api/include/torch/nn/options/activation.h index ac6cbc4ea4deab..165212e0e860cd 100644 --- a/torch/csrc/api/include/torch/nn/options/activation.h +++ b/torch/csrc/api/include/torch/nn/options/activation.h @@ -252,7 +252,7 @@ struct TORCH_API SoftmaxFuncOptions { /// If specified, the input tensor is casted to `dtype` before the operation /// is performed. This is useful for preventing data type overflows. Default: /// None. - TORCH_ARG(std::optional, dtype) = std::nullopt; + TORCH_ARG(std::optional, dtype) = c10::nullopt; }; } // namespace functional @@ -293,7 +293,7 @@ struct TORCH_API SoftminFuncOptions { /// If specified, the input tensor is casted to `dtype` before the operation /// is performed. This is useful for preventing data type overflows. Default: /// None. - TORCH_ARG(std::optional, dtype) = std::nullopt; + TORCH_ARG(std::optional, dtype) = c10::nullopt; }; } // namespace functional @@ -334,7 +334,7 @@ struct TORCH_API LogSoftmaxFuncOptions { /// If specified, the input tensor is casted to `dtype` before the operation /// is performed. This is useful for preventing data type overflows. Default: /// None. - TORCH_ARG(std::optional, dtype) = std::nullopt; + TORCH_ARG(std::optional, dtype) = c10::nullopt; }; } // namespace functional @@ -640,10 +640,10 @@ struct TORCH_API MultiheadAttentionOptions { /// add a new batch of zeros to the key and value sequences at dim=1. TORCH_ARG(bool, add_zero_attn) = false; - /// total number of features in key. Default: std::nullopt. + /// total number of features in key. Default: c10::nullopt. TORCH_ARG(int64_t, kdim); - /// total number of features in key. Default: std::nullopt. + /// total number of features in key. Default: c10::nullopt. TORCH_ARG(int64_t, vdim); }; diff --git a/torch/csrc/api/include/torch/nn/options/embedding.h b/torch/csrc/api/include/torch/nn/options/embedding.h index a3d2fdb72f54da..20eacf90733552 100644 --- a/torch/csrc/api/include/torch/nn/options/embedding.h +++ b/torch/csrc/api/include/torch/nn/options/embedding.h @@ -28,10 +28,10 @@ struct TORCH_API EmbeddingOptions { /// Embedding, the embedding vector at `padding_idx` will default to all /// zeros, but can be updated to another value to be used as the padding /// vector. - TORCH_ARG(std::optional, padding_idx) = std::nullopt; + TORCH_ARG(std::optional, padding_idx) = c10::nullopt; /// If given, each embedding vector with norm larger than `max_norm` is /// renormalized to have norm `max_norm`. - TORCH_ARG(std::optional, max_norm) = std::nullopt; + TORCH_ARG(std::optional, max_norm) = c10::nullopt; /// The p of the p-norm to compute for the `max_norm` option. Default ``2``. TORCH_ARG(double, norm_type) = 2.; /// If given, this will scale gradients by the inverse of frequency of the @@ -55,10 +55,10 @@ struct TORCH_API EmbeddingFromPretrainedOptions { /// If specified, the entries at `padding_idx` do not contribute to the /// gradient; therefore, the embedding vector at `padding_idx` is not updated /// during training, i.e. it remains as a fixed "pad". - TORCH_ARG(std::optional, padding_idx) = std::nullopt; + TORCH_ARG(std::optional, padding_idx) = c10::nullopt; /// If given, each embedding vector with norm larger than `max_norm` is /// renormalized to have norm `max_norm`. - TORCH_ARG(std::optional, max_norm) = std::nullopt; + TORCH_ARG(std::optional, max_norm) = c10::nullopt; /// The p of the p-norm to compute for the `max_norm` option. Default ``2``. TORCH_ARG(double, norm_type) = 2.; /// If given, this will scale gradients by the inverse of frequency of the @@ -84,10 +84,10 @@ struct TORCH_API EmbeddingFuncOptions { /// If specified, the entries at `padding_idx` do not contribute to the /// gradient; therefore, the embedding vector at `padding_idx` is not updated /// during training, i.e. it remains as a fixed "pad". - TORCH_ARG(std::optional, padding_idx) = std::nullopt; + TORCH_ARG(std::optional, padding_idx) = c10::nullopt; /// If given, each embedding vector with norm larger than `max_norm` is /// renormalized to have norm `max_norm`. - TORCH_ARG(std::optional, max_norm) = std::nullopt; + TORCH_ARG(std::optional, max_norm) = c10::nullopt; /// The p of the p-norm to compute for the `max_norm` option. Default ``2``. TORCH_ARG(double, norm_type) = 2.; /// If given, this will scale gradients by the inverse of frequency of the @@ -120,7 +120,7 @@ struct TORCH_API EmbeddingBagOptions { TORCH_ARG(int64_t, embedding_dim); /// If given, each embedding vector with norm larger than `max_norm` is /// renormalized to have norm `max_norm`. - TORCH_ARG(std::optional, max_norm) = std::nullopt; + TORCH_ARG(std::optional, max_norm) = c10::nullopt; /// The p of the p-norm to compute for the `max_norm` option. Default ``2``. TORCH_ARG(double, norm_type) = 2.; /// If given, this will scale gradients by the inverse of frequency of the @@ -148,7 +148,7 @@ struct TORCH_API EmbeddingBagOptions { /// zeros, but can be updated to another value to be used as the padding /// vector. Note that the embedding vector at `padding_idx` is excluded from /// the reduction. - TORCH_ARG(std::optional, padding_idx) = std::nullopt; + TORCH_ARG(std::optional, padding_idx) = c10::nullopt; }; // ============================================================================ @@ -161,7 +161,7 @@ struct TORCH_API EmbeddingBagFromPretrainedOptions { TORCH_ARG(bool, freeze) = true; /// If given, each embedding vector with norm larger than `max_norm` is /// renormalized to have norm `max_norm`. - TORCH_ARG(std::optional, max_norm) = std::nullopt; + TORCH_ARG(std::optional, max_norm) = c10::nullopt; /// The p of the p-norm to compute for the `max_norm` option. Default ``2``. TORCH_ARG(double, norm_type) = 2.; /// If given, this will scale gradients by the inverse of frequency of the @@ -184,7 +184,7 @@ struct TORCH_API EmbeddingBagFromPretrainedOptions { /// gradient; therefore, the embedding vector at padding_idx is not updated /// during training, i.e. it remains as a fixed "pad". Note that the embedding /// vector at `padding_idx` is excluded from the reduction. - TORCH_ARG(std::optional, padding_idx) = std::nullopt; + TORCH_ARG(std::optional, padding_idx) = c10::nullopt; }; // ============================================================================ @@ -205,7 +205,7 @@ struct TORCH_API EmbeddingBagFuncOptions { TORCH_ARG(torch::Tensor, offsets) = Tensor(); /// If given, each embedding vector with norm larger than `max_norm` is /// renormalized to have norm `max_norm`. - TORCH_ARG(std::optional, max_norm) = std::nullopt; + TORCH_ARG(std::optional, max_norm) = c10::nullopt; /// The p of the p-norm to compute for the `max_norm` option. Default ``2``. TORCH_ARG(double, norm_type) = 2.; /// If given, this will scale gradients by the inverse of frequency of the @@ -233,7 +233,7 @@ struct TORCH_API EmbeddingBagFuncOptions { /// gradient; therefore, the embedding vector at padding_idx is not updated /// during training, i.e. it remains as a fixed "pad". Note that the embedding /// vector at `padding_idx` is excluded from the reduction. - TORCH_ARG(std::optional, padding_idx) = std::nullopt; + TORCH_ARG(std::optional, padding_idx) = c10::nullopt; }; } // namespace functional diff --git a/torch/csrc/api/include/torch/nn/options/loss.h b/torch/csrc/api/include/torch/nn/options/loss.h index 5a6e7aa3ab20be..f1fc7a4d411156 100644 --- a/torch/csrc/api/include/torch/nn/options/loss.h +++ b/torch/csrc/api/include/torch/nn/options/loss.h @@ -451,7 +451,7 @@ struct TORCH_API TripletMarginWithDistanceLossOptions { /// closeness of two tensors. If not specified, `F::pairwise_distance` will /// be used. Default: nullopt TORCH_ARG(std::optional, distance_function) = - std::nullopt; + c10::nullopt; /// Specifies a nonnegative margin representing the minimum difference /// between the positive and negative distances required for the loss to be 0. /// Larger margins penalize cases where the negative examples are not distance @@ -548,7 +548,7 @@ struct TORCH_API SmoothL1LossOptions { /// Specifies the threshold at which to change between L1 and L2 loss. /// If beta is not specified, a value of 1.0 will be used. /// Default: nullopt - TORCH_ARG(std::optional, beta) = std::nullopt; + TORCH_ARG(std::optional, beta) = c10::nullopt; }; namespace functional { diff --git a/torch/csrc/api/include/torch/nn/options/normalization.h b/torch/csrc/api/include/torch/nn/options/normalization.h index 4b6dcd6ffe0c27..a1e5b1a0aeab1c 100644 --- a/torch/csrc/api/include/torch/nn/options/normalization.h +++ b/torch/csrc/api/include/torch/nn/options/normalization.h @@ -133,7 +133,7 @@ struct TORCH_API NormalizeFuncOptions { TORCH_ARG(double, eps) = 1e-12; /// the output tensor. If `out` is used, this /// operation won't be differentiable. - TORCH_ARG(std::optional, out) = std::nullopt; + TORCH_ARG(std::optional, out) = c10::nullopt; }; } // namespace functional diff --git a/torch/csrc/api/include/torch/nn/options/pooling.h b/torch/csrc/api/include/torch/nn/options/pooling.h index 75408890e7cd12..8f6cee99bff6ae 100644 --- a/torch/csrc/api/include/torch/nn/options/pooling.h +++ b/torch/csrc/api/include/torch/nn/options/pooling.h @@ -32,7 +32,7 @@ struct AvgPoolOptions { /// if specified, it will be used as divisor, otherwise size of the pooling /// region will be used. - TORCH_ARG(std::optional, divisor_override) = std::nullopt; + TORCH_ARG(std::optional, divisor_override) = c10::nullopt; }; /// `AvgPoolOptions` specialized for the `AvgPool1d` module. @@ -401,7 +401,7 @@ struct MaxUnpoolFuncOptions { TORCH_ARG(ExpandingArray, padding) = 0; /// the targeted output size - TORCH_ARG(std::optional>, output_size) = std::nullopt; + TORCH_ARG(std::optional>, output_size) = c10::nullopt; }; /// `MaxUnpoolFuncOptions` specialized for @@ -450,12 +450,12 @@ struct FractionalMaxPoolOptions { TORCH_ARG(ExpandingArray, kernel_size); /// the target output size of the image - TORCH_ARG(std::optional>, output_size) = std::nullopt; + TORCH_ARG(std::optional>, output_size) = c10::nullopt; /// If one wants to have an output size as a ratio of the input size, this /// option can be given. This has to be a number or tuple in the range (0, 1) using ExpandingArrayDouble = torch::ExpandingArray; - TORCH_ARG(std::optional, output_ratio) = std::nullopt; + TORCH_ARG(std::optional, output_ratio) = c10::nullopt; TORCH_ARG(torch::Tensor, _random_samples) = Tensor(); }; diff --git a/torch/csrc/api/include/torch/nn/options/upsampling.h b/torch/csrc/api/include/torch/nn/options/upsampling.h index df8eb194180acc..21df2b89998de5 100644 --- a/torch/csrc/api/include/torch/nn/options/upsampling.h +++ b/torch/csrc/api/include/torch/nn/options/upsampling.h @@ -20,10 +20,10 @@ namespace nn { /// ``` struct TORCH_API UpsampleOptions { /// output spatial sizes. - TORCH_ARG(std::optional>, size) = std::nullopt; + TORCH_ARG(std::optional>, size) = c10::nullopt; /// multiplier for spatial size. - TORCH_ARG(std::optional>, scale_factor) = std::nullopt; + TORCH_ARG(std::optional>, scale_factor) = c10::nullopt; /// the upsampling algorithm: one of "nearest", "linear", "bilinear", /// "bicubic" and "trilinear". Default: "nearest" @@ -40,7 +40,7 @@ struct TORCH_API UpsampleOptions { /// aligned, and thus preserving the values at those pixels. This only has /// effect when :attr:`mode` is "linear", "bilinear", "bicubic", or /// "trilinear". Default: "False" - TORCH_ARG(std::optional, align_corners) = std::nullopt; + TORCH_ARG(std::optional, align_corners) = c10::nullopt; }; namespace functional { @@ -65,10 +65,10 @@ struct TORCH_API InterpolateFuncOptions { mode_t; /// output spatial sizes. - TORCH_ARG(std::optional>, size) = std::nullopt; + TORCH_ARG(std::optional>, size) = c10::nullopt; /// multiplier for spatial size. - TORCH_ARG(std::optional>, scale_factor) = std::nullopt; + TORCH_ARG(std::optional>, scale_factor) = c10::nullopt; /// the upsampling algorithm: one of "nearest", "linear", "bilinear", /// "bicubic", "trilinear", "area", "nearest-exact". Default: "nearest" @@ -83,7 +83,7 @@ struct TORCH_API InterpolateFuncOptions { /// this operation *independent* of input size when `scale_factor` is /// kept the same. It is *required* when interpolating mode is "linear", /// "bilinear", "bicubic" or "trilinear". Default: "False" - TORCH_ARG(std::optional, align_corners) = std::nullopt; + TORCH_ARG(std::optional, align_corners) = c10::nullopt; /// recompute the scale_factor for use in the /// interpolation calculation. When `scale_factor` is passed as a parameter, @@ -95,7 +95,7 @@ struct TORCH_API InterpolateFuncOptions { /// used in the interpolation computation. Note that when `scale_factor` is /// floating-point, the recomputed scale_factor may differ from the one passed /// in due to rounding and precision issues. - TORCH_ARG(std::optional, recompute_scale_factor) = std::nullopt; + TORCH_ARG(std::optional, recompute_scale_factor) = c10::nullopt; /// flag to apply anti-aliasing. Using anti-alias /// option together with :attr:`align_corners` equals "False", interpolation diff --git a/torch/csrc/api/include/torch/nn/options/vision.h b/torch/csrc/api/include/torch/nn/options/vision.h index a5204f0dffb624..c012b40d21f695 100644 --- a/torch/csrc/api/include/torch/nn/options/vision.h +++ b/torch/csrc/api/include/torch/nn/options/vision.h @@ -28,7 +28,7 @@ struct TORCH_API GridSampleFuncOptions { /// padding mode for outside grid values. Default: Zeros TORCH_ARG(padding_mode_t, padding_mode) = torch::kZeros; /// Specifies perspective to pixel as point. Default: false - TORCH_ARG(std::optional, align_corners) = std::nullopt; + TORCH_ARG(std::optional, align_corners) = c10::nullopt; }; } // namespace functional diff --git a/torch/csrc/api/include/torch/nn/utils/clip_grad.h b/torch/csrc/api/include/torch/nn/utils/clip_grad.h index 8a2a569c03335c..fbb533662c7be3 100644 --- a/torch/csrc/api/include/torch/nn/utils/clip_grad.h +++ b/torch/csrc/api/include/torch/nn/utils/clip_grad.h @@ -64,7 +64,7 @@ inline double clip_grad_norm_( // synchronizing the CPU and the gradients' device until the very end to // preserve async execution on the device. When checking for finite-ness, this // optional ensures we only sync once. - std::optional total_norm = std::nullopt; + std::optional total_norm = c10::nullopt; if (error_if_nonfinite) { total_norm = total_norm_tensor.item().toDouble(); TORCH_CHECK( @@ -79,7 +79,7 @@ inline double clip_grad_norm_( auto clip_coef = max_norm / (total_norm_tensor + 1e-6); auto clip_coef_clamped = - torch::clamp(clip_coef, std::nullopt /* min */, 1.0 /* max */); + torch::clamp(clip_coef, c10::nullopt /* min */, 1.0 /* max */); for (auto& param : params_with_grad) { param.grad().data().mul_(clip_coef_clamped); } diff --git a/torch/csrc/api/include/torch/nn/utils/convert_parameters.h b/torch/csrc/api/include/torch/nn/utils/convert_parameters.h index b8bfee33473f2a..6f62d483c4d8b8 100644 --- a/torch/csrc/api/include/torch/nn/utils/convert_parameters.h +++ b/torch/csrc/api/include/torch/nn/utils/convert_parameters.h @@ -15,7 +15,7 @@ inline std::optional _check_param_device( const torch::Tensor& param, std::optional old_param_device) { // Meet the first parameter - if (old_param_device == std::nullopt) { + if (old_param_device == c10::nullopt) { old_param_device = param.is_cuda() ? param.get_device() : -1; } else { bool warn = false; diff --git a/torch/csrc/api/include/torch/optim/lbfgs.h b/torch/csrc/api/include/torch/optim/lbfgs.h index 0832afff5f8f20..001b0cd33f2596 100644 --- a/torch/csrc/api/include/torch/optim/lbfgs.h +++ b/torch/csrc/api/include/torch/optim/lbfgs.h @@ -17,11 +17,11 @@ struct TORCH_API LBFGSOptions : public OptimizerCloneableOptions { LBFGSOptions(double lr = 1); TORCH_ARG(double, lr) = 1; TORCH_ARG(int64_t, max_iter) = 20; - TORCH_ARG(std::optional, max_eval) = std::nullopt; + TORCH_ARG(std::optional, max_eval) = c10::nullopt; TORCH_ARG(double, tolerance_grad) = 1e-7; TORCH_ARG(double, tolerance_change) = 1e-9; TORCH_ARG(int64_t, history_size) = 100; - TORCH_ARG(std::optional, line_search_fn) = std::nullopt; + TORCH_ARG(std::optional, line_search_fn) = c10::nullopt; public: void serialize(torch::serialize::InputArchive& archive) override; @@ -45,7 +45,7 @@ struct TORCH_API LBFGSParamState TORCH_ARG(std::deque, old_dirs); TORCH_ARG(std::deque, old_stps); TORCH_ARG(std::deque, ro); - TORCH_ARG(std::optional>, al) = std::nullopt; + TORCH_ARG(std::optional>, al) = c10::nullopt; public: void serialize(torch::serialize::InputArchive& archive) override; @@ -66,13 +66,13 @@ class TORCH_API LBFGS : public Optimizer { TORCH_CHECK( param_groups_.size() == 1, "LBFGS doesn't support per-parameter options (parameter groups)"); - if (defaults.max_eval() == std::nullopt) { + if (defaults.max_eval() == c10::nullopt) { auto max_eval_val = (defaults.max_iter() * 5) / 4; static_cast(param_groups_[0].options()) .max_eval(max_eval_val); static_cast(*defaults_.get()).max_eval(max_eval_val); } - _numel_cache = std::nullopt; + _numel_cache = c10::nullopt; } explicit LBFGS(std::vector params, LBFGSOptions defaults = {}) : LBFGS({OptimizerParamGroup(std::move(params))}, defaults) {} diff --git a/torch/csrc/api/include/torch/optim/optimizer.h b/torch/csrc/api/include/torch/optim/optimizer.h index dd5bd600ff3e79..1f448e4fffd61c 100644 --- a/torch/csrc/api/include/torch/optim/optimizer.h +++ b/torch/csrc/api/include/torch/optim/optimizer.h @@ -186,22 +186,22 @@ class TORCH_API Optimizer { }; /* How do we decide whether to serialize undefined tensors or - std::nullopt values into the output archive? + c10::nullopt values into the output archive? Answer: we strictly follow the behavior of Python API. To be more specific: For optimizer options: a) For undefined tensor: currently no tensor is used as an options argument in -Python API, so we don't need to worry about it now. b) For std::nullopt value: -we serialize std::nullopt values into the output archive, to follow the exact +Python API, so we don't need to worry about it now. b) For c10::nullopt value: +we serialize c10::nullopt values into the output archive, to follow the exact same behavior as Python API. For optimizer param state: a) For undefined tensor: in param state, undefined tensor in C++ impl is equivalent to missing key in Python impl. Since we don't serialize missing keys in Python API, we skip undefined tensors when serializing the param state. b) -For std::nullopt value: in param state, std::nullopt value in C++ impl is +For c10::nullopt value: in param state, c10::nullopt value in C++ impl is equivalent to missing key in Python impl. Since we don't serialize missing keys -in Python API, we skip std::nullopt values when serializing the param state. */ +in Python API, we skip c10::nullopt values when serializing the param state. */ /// Serializes an `Optimizer` into an `OutputArchive`. TORCH_API serialize::OutputArchive& operator<<( diff --git a/torch/csrc/api/include/torch/serialize/input-archive.h b/torch/csrc/api/include/torch/serialize/input-archive.h index 3650cfcfea23f9..f77b34aad0bd43 100644 --- a/torch/csrc/api/include/torch/serialize/input-archive.h +++ b/torch/csrc/api/include/torch/serialize/input-archive.h @@ -1,10 +1,10 @@ #pragma once #include +#include #include #include #include -#include #include #include @@ -76,27 +76,27 @@ class TORCH_API InputArchive final { /// is not specified, the module is loaded to the original device. void load_from( const std::string& filename, - std::optional device = std::nullopt); + std::optional device = c10::nullopt); /// Loads the `InputArchive` from a serialized representation stored in the /// given `stream`. Storage are remapped using device option. If device /// is not specified, the module is loaded to the original device. void load_from( std::istream& stream, - std::optional device = std::nullopt); + std::optional device = c10::nullopt); // Loads given the specified flat array. void load_from( const char* data, size_t size, - std::optional device = std::nullopt); + std::optional device = c10::nullopt); // Loads given the specified read and size functions. void load_from( const std::function& read_func, const std::function& size_func, - std::optional device = std::nullopt); + std::optional device = c10::nullopt); // Returns the vector of keys in the input archive. std::vector keys(); diff --git a/torch/csrc/api/include/torch/types.h b/torch/csrc/api/include/torch/types.h index febda7ac6bb852..8a23cd122b8d1d 100644 --- a/torch/csrc/api/include/torch/types.h +++ b/torch/csrc/api/include/torch/types.h @@ -2,7 +2,7 @@ #include -#include +#include #include #include @@ -38,7 +38,7 @@ namespace torch { // the `func()` function defined in `at::` namespace is always hidden. using namespace at; // NOLINT -using std::nullopt; +using c10::nullopt; using std::optional; using Dtype = at::ScalarType; diff --git a/torch/csrc/api/src/jit.cpp b/torch/csrc/api/src/jit.cpp index 07064dbdc9e786..16d9d0040a6592 100644 --- a/torch/csrc/api/src/jit.cpp +++ b/torch/csrc/api/src/jit.cpp @@ -11,7 +11,7 @@ namespace jit { std::shared_ptr compile(const std::string& source) { auto module = std::make_shared(); - module->define(std::nullopt, source, nativeResolver(), nullptr); + module->define(c10::nullopt, source, nativeResolver(), nullptr); return module; } diff --git a/torch/csrc/api/src/nn/modules/activation.cpp b/torch/csrc/api/src/nn/modules/activation.cpp index 518072d0653f12..56218ad091de5d 100644 --- a/torch/csrc/api/src/nn/modules/activation.cpp +++ b/torch/csrc/api/src/nn/modules/activation.cpp @@ -130,7 +130,7 @@ void SoftmaxImpl::pretty_print(std::ostream& stream) const { } Tensor SoftmaxImpl::forward(const Tensor& input) { - return F::detail::softmax(input, options.dim(), std::nullopt); + return F::detail::softmax(input, options.dim(), c10::nullopt); } // ============================================================================ @@ -144,7 +144,7 @@ void SoftminImpl::pretty_print(std::ostream& stream) const { } Tensor SoftminImpl::forward(const Tensor& input) { - return F::detail::softmin(input, options.dim(), std::nullopt); + return F::detail::softmin(input, options.dim(), c10::nullopt); } // ============================================================================ @@ -159,7 +159,7 @@ void LogSoftmaxImpl::pretty_print(std::ostream& stream) const { } Tensor LogSoftmaxImpl::forward(const Tensor& input) { - return F::detail::log_softmax(input, options.dim(), std::nullopt); + return F::detail::log_softmax(input, options.dim(), c10::nullopt); } // ============================================================================ @@ -174,7 +174,7 @@ Tensor Softmax2dImpl::forward(const Tensor& input) { TORCH_CHECK( input.dim() == 4 || input.dim() == 3, "Softmax2d requires a 3D or 4D tensor as input"); - return F::detail::softmax(input, /*dim=*/-3, std::nullopt); + return F::detail::softmax(input, /*dim=*/-3, c10::nullopt); } // ============================================================================ diff --git a/torch/csrc/api/src/nn/modules/conv.cpp b/torch/csrc/api/src/nn/modules/conv.cpp index 26e52df637f852..197c3cf0725cd0 100644 --- a/torch/csrc/api/src/nn/modules/conv.cpp +++ b/torch/csrc/api/src/nn/modules/conv.cpp @@ -176,7 +176,7 @@ std::vector ConvTransposeNdImpl::_output_padding( std::vector ret; std::optional output_size_ = output_size; - if (output_size_ == std::nullopt) { + if (output_size_ == c10::nullopt) { ret = at::IntArrayRef(this->options.output_padding()).vec(); } else { auto k = input.dim() - 2; diff --git a/torch/csrc/api/src/nn/modules/embedding.cpp b/torch/csrc/api/src/nn/modules/embedding.cpp index 4c6683d1f36b58..553a93875e1784 100644 --- a/torch/csrc/api/src/nn/modules/embedding.cpp +++ b/torch/csrc/api/src/nn/modules/embedding.cpp @@ -20,7 +20,7 @@ EmbeddingImpl::EmbeddingImpl(EmbeddingOptions options_) } void EmbeddingImpl::reset() { - if (options.padding_idx() != std::nullopt) { + if (options.padding_idx() != c10::nullopt) { if (*options.padding_idx() > 0) { TORCH_CHECK( *options.padding_idx() < options.num_embeddings(), @@ -50,7 +50,7 @@ void EmbeddingImpl::reset() { void EmbeddingImpl::reset_parameters() { torch::nn::init::normal_(weight); - if (options.padding_idx() != std::nullopt) { + if (options.padding_idx() != c10::nullopt) { torch::NoGradGuard no_grad; weight[*options.padding_idx()].fill_(0); } @@ -59,10 +59,10 @@ void EmbeddingImpl::reset_parameters() { void EmbeddingImpl::pretty_print(std::ostream& stream) const { stream << "torch::nn::Embedding(num_embeddings=" << options.num_embeddings() << ", embedding_dim=" << options.embedding_dim(); - if (options.padding_idx() != std::nullopt) { + if (options.padding_idx() != c10::nullopt) { stream << ", padding_idx=" << *options.padding_idx(); } - if (options.max_norm() != std::nullopt) { + if (options.max_norm() != c10::nullopt) { stream << ", max_norm=" << *options.max_norm(); } if (options.norm_type() != 2) { @@ -154,7 +154,7 @@ void EmbeddingBagImpl::pretty_print(std::ostream& stream) const { stream << "torch::nn::EmbeddingBag(num_embeddings=" << options.num_embeddings() << ", embedding_dim=" << options.embedding_dim(); - if (options.max_norm() != std::nullopt) { + if (options.max_norm() != c10::nullopt) { stream << ", max_norm=" << *options.max_norm(); } if (options.norm_type() != 2) { diff --git a/torch/csrc/api/src/nn/modules/pooling.cpp b/torch/csrc/api/src/nn/modules/pooling.cpp index a02d8cd712aa09..0b11b914dcc1c7 100644 --- a/torch/csrc/api/src/nn/modules/pooling.cpp +++ b/torch/csrc/api/src/nn/modules/pooling.cpp @@ -281,19 +281,19 @@ FractionalMaxPool2dImpl::FractionalMaxPool2dImpl( void FractionalMaxPool2dImpl::reset() { _random_samples = register_buffer("_random_samples", options._random_samples()); - if (options.output_size() == std::nullopt && - options.output_ratio() == std::nullopt) { + if (options.output_size() == c10::nullopt && + options.output_ratio() == c10::nullopt) { TORCH_CHECK( false, "FractionalMaxPool2d requires specifying either ", "an output size, or a pooling ratio"); } - if (options.output_size() != std::nullopt && - options.output_ratio() != std::nullopt) { + if (options.output_size() != c10::nullopt && + options.output_ratio() != c10::nullopt) { TORCH_CHECK( false, "only one of output_size and output_ratio may be specified"); } - if (options.output_ratio() != std::nullopt) { + if (options.output_ratio() != c10::nullopt) { at::ArrayRef output_ratio = at::ArrayRef(options.output_ratio().value()); if (!(0 < output_ratio[0] && output_ratio[0] < 1 && 0 < output_ratio[1] && @@ -340,19 +340,19 @@ FractionalMaxPool3dImpl::FractionalMaxPool3dImpl( void FractionalMaxPool3dImpl::reset() { _random_samples = register_buffer("_random_samples", options._random_samples()); - if (options.output_size() == std::nullopt && - options.output_ratio() == std::nullopt) { + if (options.output_size() == c10::nullopt && + options.output_ratio() == c10::nullopt) { TORCH_CHECK( false, "FractionalMaxPool3d requires specifying either ", "an output size, or a pooling ratio"); } - if (options.output_size() != std::nullopt && - options.output_ratio() != std::nullopt) { + if (options.output_size() != c10::nullopt && + options.output_ratio() != c10::nullopt) { TORCH_CHECK( false, "only one of output_size and output_ratio may be specified"); } - if (options.output_ratio() != std::nullopt) { + if (options.output_ratio() != c10::nullopt) { at::ArrayRef output_ratio = at::ArrayRef(options.output_ratio().value()); if (!(0 < output_ratio[0] && output_ratio[0] < 1 && 0 < output_ratio[1] && diff --git a/torch/csrc/api/src/nn/modules/upsampling.cpp b/torch/csrc/api/src/nn/modules/upsampling.cpp index 378d5aadb92031..8e7bb2fe33cd84 100644 --- a/torch/csrc/api/src/nn/modules/upsampling.cpp +++ b/torch/csrc/api/src/nn/modules/upsampling.cpp @@ -15,7 +15,7 @@ void UpsampleImpl::reset() {} void UpsampleImpl::pretty_print(std::ostream& stream) const { stream << "torch::nn::Upsample("; - if (options.scale_factor() != std::nullopt) { + if (options.scale_factor() != c10::nullopt) { stream << "scale_factor=" << at::ArrayRef(*options.scale_factor()); } else { stream << "size=" << at::ArrayRef(*options.size()); @@ -43,7 +43,7 @@ Tensor UpsampleImpl::forward(const Tensor& input) { options.scale_factor(), mode, options.align_corners(), - std::nullopt, + c10::nullopt, false); } diff --git a/torch/csrc/api/src/optim/lbfgs.cpp b/torch/csrc/api/src/optim/lbfgs.cpp index dbf17f718614a0..10739be6238697 100644 --- a/torch/csrc/api/src/optim/lbfgs.cpp +++ b/torch/csrc/api/src/optim/lbfgs.cpp @@ -68,7 +68,7 @@ bool if_container_equal(T lhs, T rhs) { bool operator==(const LBFGSParamState& lhs, const LBFGSParamState& rhs) { auto isNull = [](const std::optional>& val) { - return val == std::nullopt; + return val == c10::nullopt; }; return (lhs.func_evals() == rhs.func_evals()) && (lhs.n_iter() == rhs.n_iter()) && (lhs.t() == rhs.t()) && @@ -97,7 +97,7 @@ void LBFGSParamState::serialize( _TORCH_OPTIM_SERIALIZE_TORCH_ARG_DEQUE(old_stps); _TORCH_OPTIM_SERIALIZE_TORCH_ARG_DEQUE(ro); // Python version only serializes state vars if explicitly defined - if (al() != std::nullopt) { + if (al() != c10::nullopt) { _TORCH_OPTIM_SERIALIZE_TORCH_ARG(al); } } @@ -131,7 +131,7 @@ Tensor LBFGS::_gather_flat_grad() { } int64_t LBFGS::_numel() { - if (_numel_cache == std::nullopt) { + if (_numel_cache == c10::nullopt) { auto res = 0; for (const auto& p : param_groups_.at(0).params()) { res += p.numel(); @@ -194,12 +194,12 @@ static double _cubic_interpolate( double x2, double f2, double g2, - std::optional> bounds = std::nullopt) { + std::optional> bounds = c10::nullopt) { // ported from https://github.com/torch/optim/blob/master/polyinterp.lua // Compute bounds of interpolation area // NOLINTNEXTLINE(cppcoreguidelines-init-variables) double xmin_bound, xmax_bound; - if (bounds != std::nullopt) { + if (bounds != c10::nullopt) { std::tie(xmin_bound, xmax_bound) = *bounds; } else { std::tie(xmin_bound, xmax_bound) = @@ -509,7 +509,7 @@ Tensor LBFGS::step(LossClosure closure) { // multiplied by the gradient int64_t num_old = static_cast(old_dirs.size()); - if (state.al() == std::nullopt) { + if (state.al() == c10::nullopt) { state.al(std::vector(history_size)); } auto& al = state.al(); @@ -557,7 +557,7 @@ Tensor LBFGS::step(LossClosure closure) { // optional line search: user function auto ls_func_evals = 0; - if (line_search_fn != std::nullopt) { + if (line_search_fn != c10::nullopt) { TORCH_CHECK( *line_search_fn == "strong_wolfe", "only 'strong_wolfe' is supported"); @@ -627,7 +627,7 @@ void LBFGS::load(serialize::InputArchive& archive) { TORCH_WARN( "Your serialized LBFGS optimizer is still using the old serialization format. " "The func_evals and n_iter value in state will be set to 0, ro will be set to an empty deque " - "and al will be set to std::nullopt because the old LBFGS optimizer didn't save these values." + "and al will be set to c10::nullopt because the old LBFGS optimizer didn't save these values." "You should re-save your LBFGS optimizer to use the new serialization format."); Tensor d, t, H_diag, prev_flat_grad, prev_loss; std::deque old_dirs, old_stps; diff --git a/torch/csrc/api/src/serialize/input-archive.cpp b/torch/csrc/api/src/serialize/input-archive.cpp index 8644b6193e0be8..852f4eab1b52b1 100644 --- a/torch/csrc/api/src/serialize/input-archive.cpp +++ b/torch/csrc/api/src/serialize/input-archive.cpp @@ -93,20 +93,20 @@ void InputArchive::read(const std::string& key, InputArchive& archive) { void InputArchive::load_from( const std::string& filename, - std::optional device /*= std::nullopt*/) { + std::optional device /*= c10::nullopt*/) { module_ = torch::jit::load(filename, std::move(device)); } void InputArchive::load_from( std::istream& stream, - std::optional device /*= std::nullopt*/) { + std::optional device /*= c10::nullopt*/) { module_ = torch::jit::load(stream, std::move(device)); } void InputArchive::load_from( const char* data, size_t size, - std::optional device /*= std::nullopt*/) { + std::optional device /*= c10::nullopt*/) { using caffe2::serialize::ReadAdapterInterface; class OurAdapter : public ReadAdapterInterface { public: @@ -136,7 +136,7 @@ void InputArchive::load_from( void InputArchive::load_from( const std::function& read_func, const std::function& size_func, - std::optional device /*= std::nullopt*/) { + std::optional device /*= c10::nullopt*/) { using caffe2::serialize::ReadAdapterInterface; class OurAdapter : public ReadAdapterInterface { public: diff --git a/torch/csrc/autograd/FunctionsManual.cpp b/torch/csrc/autograd/FunctionsManual.cpp index 7ca1a172096817..9d897c667c906f 100644 --- a/torch/csrc/autograd/FunctionsManual.cpp +++ b/torch/csrc/autograd/FunctionsManual.cpp @@ -630,7 +630,7 @@ Tensor div_tensor_self_backward( T other, ScalarType self_st) { return div_tensor_self_backward( - grad, std::move(other), self_st, std::nullopt); + grad, std::move(other), self_st, c10::nullopt); } template Tensor div_tensor_self_backward(const Tensor&, Tensor, ScalarType); template Tensor div_tensor_self_backward(const Tensor&, Scalar, ScalarType); @@ -652,7 +652,7 @@ Tensor div_tensor_other_backward( const Tensor& grad, const Tensor& self, const Tensor& other) { - return div_tensor_other_backward(grad, self, other, std::nullopt); + return div_tensor_other_backward(grad, self, other, c10::nullopt); } Tensor permute_backwards(const Tensor& grad, IntArrayRef fwd_dims) { @@ -1282,12 +1282,12 @@ Tensor convolution_jvp( at::SymIntArrayRef output_padding, const c10::SymInt& groups) { auto bias_t_opt = - bias_t.defined() ? std::optional(bias_t) : std::nullopt; + bias_t.defined() ? std::optional(bias_t) : c10::nullopt; return ( at::convolution_symint( input_t, weight_p, - std::nullopt, + c10::nullopt, stride, padding, dilation, @@ -1324,12 +1324,12 @@ Tensor _convolution_jvp( bool cudnn_enabled, bool allow_tf32) { auto bias_t_opt = - bias_t.defined() ? std::optional(bias_t) : std::nullopt; + bias_t.defined() ? std::optional(bias_t) : c10::nullopt; return ( at::_convolution_symint( input_t, weight_p, - std::nullopt, + c10::nullopt, stride, padding, dilation, @@ -6193,7 +6193,7 @@ Tensor batch_norm_jvp( std::optional result_p = weight_p.defined() ? std::optional((input_p - mean_p) * invstd_p) - : std::nullopt; + : c10::nullopt; return _affine_jvp( result_p, result_t, @@ -6232,7 +6232,7 @@ Tensor layer_norm_jvp( std::optional result_p = weight_p.defined() ? std::optional((input_p - mean_p) * invstd_p) - : std::nullopt; + : c10::nullopt; return _affine_jvp( result_p, result_t, @@ -6273,7 +6273,7 @@ Tensor group_norm_jvp( /*eps=*/0) .view(input_shape); - std::optional result_p = std::nullopt; + std::optional result_p = c10::nullopt; if (weight_p.defined()) { std::vector view_size(input_t_reshaped.dim(), 1); view_size[1] = input_t_reshaped.size(1); @@ -6706,7 +6706,7 @@ std::tuple _cudnn_convolution_backward( grad_output, self, weight, - std::nullopt, + c10::nullopt, stride, padding, dilation, @@ -6956,7 +6956,7 @@ Tensor to_sparse_backward( if (self_layout == c10::kStrided) { return grad.to_dense(); } else { - OptionalIntArrayRef blocksize = std::nullopt; + OptionalIntArrayRef blocksize = c10::nullopt; if (self_blocksize.has_value()) { blocksize = c10::asIntArrayRefSlowOpt(*self_blocksize); } diff --git a/torch/csrc/autograd/FunctionsManual.h b/torch/csrc/autograd/FunctionsManual.h index 3c461dd88ee56a..dedff70be1ba34 100644 --- a/torch/csrc/autograd/FunctionsManual.h +++ b/torch/csrc/autograd/FunctionsManual.h @@ -39,7 +39,7 @@ TORCH_API inline std::optional wrap_opt_if( const Tensor& t, const bool cond) { using OptTensor = std::optional; - return cond ? OptTensor(t) : static_cast(std::nullopt); + return cond ? OptTensor(t) : static_cast(c10::nullopt); } TORCH_API Tensor diff --git a/torch/csrc/autograd/TraceTypeManual.cpp b/torch/csrc/autograd/TraceTypeManual.cpp index 1473058a3a53df..46e4014d8dd139 100644 --- a/torch/csrc/autograd/TraceTypeManual.cpp +++ b/torch/csrc/autograd/TraceTypeManual.cpp @@ -1,11 +1,11 @@ #include #include #include +#include #include #include #include #include -#include using namespace at; diff --git a/torch/csrc/autograd/VariableTypeManual.cpp b/torch/csrc/autograd/VariableTypeManual.cpp index 92096dca9a6989..20f66694677e82 100644 --- a/torch/csrc/autograd/VariableTypeManual.cpp +++ b/torch/csrc/autograd/VariableTypeManual.cpp @@ -2,6 +2,7 @@ #include #include #include +#include #include #include #include @@ -10,7 +11,6 @@ #include #include #include -#include #include diff --git a/torch/csrc/autograd/VariableTypeUtils.h b/torch/csrc/autograd/VariableTypeUtils.h index 3b598898f80c4a..d5fe8a70dae177 100644 --- a/torch/csrc/autograd/VariableTypeUtils.h +++ b/torch/csrc/autograd/VariableTypeUtils.h @@ -217,7 +217,7 @@ inline at::Tensor as_view( tensor, diff_view_meta->get_backward_view().chain( base, tensor, std::move(view_func), std::move(rev_view_func)), - std::nullopt, + c10::nullopt, /*shared_view_info*/ true, creation_meta, allow_tensor_metadata_change); @@ -225,7 +225,7 @@ inline at::Tensor as_view( return make_variable_differentiable_view( tensor, ViewInfo(base, std::move(view_func), std::move(rev_view_func)), - std::nullopt, + c10::nullopt, /*shared_view_info*/ true, creation_meta, allow_tensor_metadata_change); diff --git a/torch/csrc/autograd/autograd.h b/torch/csrc/autograd/autograd.h index bd5d4a462102b2..94ee179225a4ca 100644 --- a/torch/csrc/autograd/autograd.h +++ b/torch/csrc/autograd/autograd.h @@ -47,7 +47,7 @@ namespace torch::autograd { TORCH_API void backward( const variable_list& tensors, const variable_list& grad_tensors = {}, - std::optional retain_graph = std::nullopt, + std::optional retain_graph = c10::nullopt, bool create_graph = false, const variable_list& inputs = {}); @@ -81,7 +81,7 @@ TORCH_API variable_list grad( const variable_list& outputs, const variable_list& inputs, const variable_list& grad_outputs = {}, - std::optional retain_graph = std::nullopt, + std::optional retain_graph = c10::nullopt, bool create_graph = false, bool allow_unused = false); diff --git a/torch/csrc/autograd/autograd_not_implemented_fallback.cpp b/torch/csrc/autograd/autograd_not_implemented_fallback.cpp index f922c3fc763260..eff2a27c105f36 100644 --- a/torch/csrc/autograd/autograd_not_implemented_fallback.cpp +++ b/torch/csrc/autograd/autograd_not_implemented_fallback.cpp @@ -345,7 +345,7 @@ static void autogradNotImplementedFallbackImpl( [&](size_t idx, size_t _, const at::Tensor& t) { storage_saved.push_back( t.has_storage() ? std::optional(t.storage()) - : std::nullopt); + : c10::nullopt); impl_saved.push_back(t.getIntrusivePtr()); }, &stack_args_copy, diff --git a/torch/csrc/autograd/engine.cpp b/torch/csrc/autograd/engine.cpp index cb9f5caca0eef1..cfacbf0e3be7fe 100644 --- a/torch/csrc/autograd/engine.cpp +++ b/torch/csrc/autograd/engine.cpp @@ -735,10 +735,10 @@ void GraphTask::exec_post_processing() { for (const auto& leaf_stream : leaf_streams) { // stash_current_cuda/privateuse1_streams() stashed streams for all device // IDs that already had a CUDA/privateuse1 context before the GraphTask - // executed. For inactive devices, it stashed a std::nullopt. I don't + // executed. For inactive devices, it stashed a c10::nullopt. I don't // expect GraphTask's backward pass ran leaf nodes on any new devices, so // the stashed streams should be enough. If leaf_stream.device_index() - // happens to be for a new device, operator* on the std::nullopt should + // happens to be for a new device, operator* on the c10::nullopt should // throw an error. const auto caller_current_stream = // NOLINTNEXTLINE(bugprone-unchecked-optional-access) @@ -1554,7 +1554,7 @@ void GraphTask::stash_current_streams() { idx)) { caller_current_streams_[idx] = guard.getStream({accelerator, idx}); } else { - caller_current_streams_[idx] = std::nullopt; + caller_current_streams_[idx] = c10::nullopt; } } } diff --git a/torch/csrc/autograd/function.h b/torch/csrc/autograd/function.h index 4f7f53c90ec1ed..c8c3538a061f17 100644 --- a/torch/csrc/autograd/function.h +++ b/torch/csrc/autograd/function.h @@ -242,14 +242,14 @@ struct TORCH_API Node : std::enable_shared_from_this { std::optional stream() { auto opt_device_type = at::getAccelerator(); if (!opt_device_type.has_value()) { - return std::nullopt; + return c10::nullopt; } for (const auto& metadata : input_metadata_) { if (metadata.device().type() == opt_device_type.value()) return metadata.stream(); } - return std::nullopt; + return c10::nullopt; } void clear_input_metadata() { diff --git a/torch/csrc/autograd/functions/accumulate_grad.h b/torch/csrc/autograd/functions/accumulate_grad.h index 99597a73762ff7..2efde9d5f2f2e6 100644 --- a/torch/csrc/autograd/functions/accumulate_grad.h +++ b/torch/csrc/autograd/functions/accumulate_grad.h @@ -224,7 +224,7 @@ struct TORCH_API AccumulateGrad : public Node { // variable_grad += new_grad; // } else { // result = at::empty_strided(variable.sizes(), variable.strides(), - // variable.options().memory_format(std::nullopt)); + // variable.options().memory_format(c10::nullopt)); // update_grad(at::native::add_out(result, variable_grad, // new_grad, 1.0); // } diff --git a/torch/csrc/autograd/functions/comm.cpp b/torch/csrc/autograd/functions/comm.cpp index 5093f51e7eff88..1aed18cb79a5ee 100644 --- a/torch/csrc/autograd/functions/comm.cpp +++ b/torch/csrc/autograd/functions/comm.cpp @@ -105,7 +105,7 @@ variable_list Gather::apply(variable_list&& inputs) { std::move(source_devices), std::move(input_sizes), dim_, - /*streams=*/std::nullopt, + /*streams=*/c10::nullopt, /*unsqueeze_scalars=*/unsqueeze_scalars); grad_fn->set_next_edges(collect_next_edges(inputs)); } diff --git a/torch/csrc/autograd/functions/comm.h b/torch/csrc/autograd/functions/comm.h index 2730827a1eb3c4..0924cd030fcef8 100644 --- a/torch/csrc/autograd/functions/comm.h +++ b/torch/csrc/autograd/functions/comm.h @@ -6,7 +6,7 @@ #include #include -#include +#include #include #include @@ -17,10 +17,10 @@ namespace autograd { struct TORCH_CUDA_CU_API Scatter : public Node { explicit Scatter( std::vector devices, - std::optional> chunk_sizes = std::nullopt, + std::optional> chunk_sizes = c10::nullopt, int64_t dim = 0, std::optional>> streams = - std::nullopt, + c10::nullopt, bool unsqueeze_scalars = false); ~Scatter() override; diff --git a/torch/csrc/autograd/init.cpp b/torch/csrc/autograd/init.cpp index b22199ee1ad696..e6a907ee2f0a40 100644 --- a/torch/csrc/autograd/init.cpp +++ b/torch/csrc/autograd/init.cpp @@ -1084,7 +1084,7 @@ static PyObject* push_on_torch_dispatch_stack( using c10::impl::TorchDispatchModeKey; // When we push a mode onto the mode stack, we need to // check if it's an "infra" mode, by checking its _mode_key attribute. - std::optional mode_key = std::nullopt; + std::optional mode_key = c10::nullopt; py::object maybe_mode_key_obj = PyObject_FastGetAttrString(arg, "_mode_key"); if (maybe_mode_key_obj) { @@ -1108,7 +1108,7 @@ static PyObject* pop_torch_dispatch_stack( PyObject* _unused, PyObject* maybe_mode_key) { HANDLE_TH_ERRORS - std::optional mode_key = std::nullopt; + std::optional mode_key = c10::nullopt; PyObject* r = nullptr; if (maybe_mode_key != Py_None) { mode_key = py::cast(maybe_mode_key); @@ -1174,7 +1174,7 @@ static PyObject* get_dispatch_mode(PyObject* _unused, PyObject* arg) { auto mode_key = py::cast(arg); auto maybe_mode = c10::impl::TorchDispatchModeTLS::get_mode(mode_key); - if (maybe_mode == std::nullopt) { + if (maybe_mode == c10::nullopt) { Py_RETURN_NONE; } // NOLINTNEXTLINE(bugprone-unchecked-optional-access) @@ -1190,7 +1190,7 @@ static PyObject* unset_dispatch_mode(PyObject* _unused, PyObject* arg) { auto mode_key = py::cast(arg); const auto maybe_mode = c10::impl::TorchDispatchModeTLS::unset_mode(mode_key); - if (maybe_mode == std::nullopt) { + if (maybe_mode == c10::nullopt) { Py_RETURN_NONE; } // NOLINTNEXTLINE(bugprone-unchecked-optional-access) diff --git a/torch/csrc/autograd/input_buffer.cpp b/torch/csrc/autograd/input_buffer.cpp index f2b08e364318a4..6c12bbadc5d2d5 100644 --- a/torch/csrc/autograd/input_buffer.cpp +++ b/torch/csrc/autograd/input_buffer.cpp @@ -11,7 +11,7 @@ #include #include #include -#include +#include #include #include @@ -159,7 +159,7 @@ void InputBuffer::add( // Accumulation happens on the var device's default stream. TORCH_INTERNAL_ASSERT(device_of(var)); - std::optional opt_accumulate_stream = std::nullopt; + std::optional opt_accumulate_stream = c10::nullopt; const auto device_type = device_of(var).value().type(); // NOLINTNEXTLINE(bugprone-unchecked-optional-access) if (device_of(var)->is_cuda() || device_of(var)->is_privateuseone()) { @@ -179,7 +179,7 @@ void InputBuffer::add( record_stream_any_impl(var, *opt_accumulate_stream); } } else { - std::optional opt_sync_stream = std::nullopt; + std::optional opt_sync_stream = c10::nullopt; const auto guard = c10::impl::VirtualGuardImpl{device_type}; if (on_consumer && !on_producer) { // (3a) diff --git a/torch/csrc/autograd/input_buffer.h b/torch/csrc/autograd/input_buffer.h index e445ef897fc1aa..7e471ef528bb03 100644 --- a/torch/csrc/autograd/input_buffer.h +++ b/torch/csrc/autograd/input_buffer.h @@ -9,8 +9,8 @@ #include #include +#include #include -#include namespace torch::autograd { diff --git a/torch/csrc/autograd/profiler_legacy.cpp b/torch/csrc/autograd/profiler_legacy.cpp index 53a24eaa150dbe..b9387479667e86 100644 --- a/torch/csrc/autograd/profiler_legacy.cpp +++ b/torch/csrc/autograd/profiler_legacy.cpp @@ -122,7 +122,7 @@ using torch::profiler::impl::ProfilerStateBase; struct ProfilerLegacyThreadLocalState : public ProfilerStateBase { explicit ProfilerLegacyThreadLocalState( const torch::profiler::impl::ProfilerConfig& config) - : ProfilerStateBase(config), remoteProfiledEvents_{std::nullopt} {} + : ProfilerStateBase(config), remoteProfiledEvents_{c10::nullopt} {} ~ProfilerLegacyThreadLocalState() override = default; static ProfilerLegacyThreadLocalState* getTLS() { diff --git a/torch/csrc/autograd/profiler_legacy.h b/torch/csrc/autograd/profiler_legacy.h index 59198129b2b278..9bd88b0b3dc51e 100644 --- a/torch/csrc/autograd/profiler_legacy.h +++ b/torch/csrc/autograd/profiler_legacy.h @@ -336,7 +336,7 @@ TORCH_API void enableProfilerLegacy( using thread_event_lists = std::vector>; TORCH_API thread_event_lists disableProfilerLegacy( std::optional profilerDisableOptions = - std::nullopt); + c10::nullopt); // adds profiledEvents to the current thread local recorded events. Each event // will be marked with node ID given by fromNodeId. @@ -377,9 +377,9 @@ struct TORCH_API TLSLegacyProfilerGuard { explicit TLSLegacyProfilerGuard( const torch::profiler::impl::ProfilerConfig& cfg, std::optional> - resultCallback = std::nullopt, + resultCallback = c10::nullopt, std::optional profilerDisableOptions = - std::nullopt) + c10::nullopt) : cb_(std::move(resultCallback)), profilerDisableOptions_(profilerDisableOptions) { enableProfilerLegacy(cfg); diff --git a/torch/csrc/autograd/profiler_python.cpp b/torch/csrc/autograd/profiler_python.cpp index e930faa1fdebe4..5fcc7b86a2fab8 100644 --- a/torch/csrc/autograd/profiler_python.cpp +++ b/torch/csrc/autograd/profiler_python.cpp @@ -18,6 +18,7 @@ #include #include #include +#include #include #include #include @@ -28,7 +29,6 @@ #include #include #include -#include namespace py = pybind11; @@ -349,7 +349,7 @@ TensorMetadata toTensorMetadata(PyObject* self) { std::optional ValueCache::recordIfTensor(py::handle p) { return THPVariable_CheckExact(p.ptr()) ? std::optional{toTensorMetadata(p.ptr())} - : std::nullopt; + : c10::nullopt; } std::vector> ValueCache::unpackTensorMap( @@ -379,7 +379,7 @@ void ValueCache::store(const PyCallKey& key, no_ephemeral_t) { template <> ExtraFields::args_t ValueCache::load( const PyCallKey& key) const { - return {std::get(state_).at(key), std::nullopt}; + return {std::get(state_).at(key), c10::nullopt}; } template <> @@ -419,7 +419,7 @@ ExtraFields::args_t ValueCache::load( return { /*frame_state_=*/std::get(state_).at(*cache.location_), /*module_info_=*/std::move(info), - /*optimizer_info_=*/std::nullopt}; + /*optimizer_info_=*/c10::nullopt}; } template <> @@ -465,7 +465,7 @@ ExtraFields::args_t ValueCache::load< return { // NOLINTNEXTLINE(bugprone-unchecked-optional-access) /*frame_state_=*/std::get(state_).at(*cache.location_), - /*module_info_=*/std::nullopt, + /*module_info_=*/c10::nullopt, /*optimizer_info_=*/std::move(info)}; } diff --git a/torch/csrc/autograd/python_function.cpp b/torch/csrc/autograd/python_function.cpp index a5ba07b2cdb53a..0227229d1f7fb9 100644 --- a/torch/csrc/autograd/python_function.cpp +++ b/torch/csrc/autograd/python_function.cpp @@ -778,7 +778,7 @@ static void _get_tensors_to_save( for (const auto i : c10::irange(num_saved)) { PyObject* obj = PyTuple_GET_ITEM(self->to_save, i); if (obj == Py_None) { - tensors_to_save.emplace_back(std::nullopt); + tensors_to_save.emplace_back(c10::nullopt); continue; } else if (THPVariable_Check(obj)) { const auto& tensor = THPVariable_Unpack(obj); diff --git a/torch/csrc/autograd/python_function.h b/torch/csrc/autograd/python_function.h index 0bf3c8bbab70b7..c2744f365476f0 100644 --- a/torch/csrc/autograd/python_function.h +++ b/torch/csrc/autograd/python_function.h @@ -10,7 +10,7 @@ #include #include -#include +#include #include #include diff --git a/torch/csrc/autograd/python_variable.cpp b/torch/csrc/autograd/python_variable.cpp index 94596c32a705e3..65f4b0efd3c188 100644 --- a/torch/csrc/autograd/python_variable.cpp +++ b/torch/csrc/autograd/python_variable.cpp @@ -347,7 +347,7 @@ bool isResurrectable(THPVariable* self) { // Check if this is hermetic. If it is, no resurrection. if (tensor.unsafeGetTensorImpl()->pyobj_slot()->check_pyobj( getPyInterpreter(), /*ignore_hermetic_tls=*/false) != - std::make_optional((PyObject*)self)) { + c10::make_optional((PyObject*)self)) { return false; } return true; @@ -455,7 +455,7 @@ static int THPVariable_clear(THPVariable* self) { if (!self->cdata.unsafeIsBorrowed() && tensor.unsafeGetTensorImpl()->pyobj_slot()->check_pyobj( getPyInterpreter(), /*ignore_hermetic_tls=*/false) == - std::make_optional((PyObject*)self)) { + c10::make_optional((PyObject*)self)) { // TODO: empirically, on OS X this assert appears to be untrue // In test_py_tensors_multi_async_call - ProcessGroupRpcTestWithSpawn // distributed/rpc/test_process_group_agent.py @@ -587,14 +587,14 @@ static PyObject* view_func_impl( auto& view_func = view_info.view_fn(); // Determine new SymInt / tensor state as needed. - std::optional> new_symints = std::nullopt; + std::optional> new_symints = c10::nullopt; if (symint_visitor_fn != Py_None) { new_symints = map_py_func( py::cast(symint_visitor_fn), view_func.get_symints()); } - std::optional> new_tensors = std::nullopt; + std::optional> new_tensors = c10::nullopt; if (tensor_visitor_fn != Py_None) { new_tensors = map_py_func( py::cast(tensor_visitor_fn), diff --git a/torch/csrc/autograd/python_variable_indexing.cpp b/torch/csrc/autograd/python_variable_indexing.cpp index e9b40b0dc8f75c..fdcafd6cd70910 100644 --- a/torch/csrc/autograd/python_variable_indexing.cpp +++ b/torch/csrc/autograd/python_variable_indexing.cpp @@ -100,7 +100,7 @@ static inline Variable sequenceToVariable( c10::TensorOptions options, PyObject* seq) { return torch::utils::indexing_tensor_from_data( - options, kLong, std::nullopt, seq); + options, kLong, c10::nullopt, seq); } inline Variable valueToTensor( @@ -201,7 +201,7 @@ static inline Variable applySlicing( // as null may need to be changed after we reach a better solution for // nested tensor size std::optional result_sizes = result.is_nested() - ? std::optional(std::nullopt) + ? std::optional(c10::nullopt) : std::optional(result.sym_sizes()); result = at::indexing::handleDimInMultiDimIndexing( /*prev_dim_result=*/result, diff --git a/torch/csrc/autograd/record_function_ops.h b/torch/csrc/autograd/record_function_ops.h index a84d47c5b4829f..a145523c1bf8a5 100644 --- a/torch/csrc/autograd/record_function_ops.h +++ b/torch/csrc/autograd/record_function_ops.h @@ -1,7 +1,7 @@ #pragma once #include +#include #include -#include namespace torch::autograd::profiler { @@ -17,7 +17,7 @@ struct PythonRecordFunction : public torch::CustomClassHolder { // callbacks. TORCH_API c10::intrusive_ptr record_function_enter_new( const std::string& name, - const std::optional& args = std::nullopt); + const std::optional& args = c10::nullopt); // Schedules RecordFunction's end callbacks to be run on completion of a future. TORCH_API c10::intrusive_ptr _call_end_callbacks_on_fut_new( diff --git a/torch/csrc/autograd/utils/grad_layout_contract.h b/torch/csrc/autograd/utils/grad_layout_contract.h index 7189e02047251d..1dad10663dd70b 100644 --- a/torch/csrc/autograd/utils/grad_layout_contract.h +++ b/torch/csrc/autograd/utils/grad_layout_contract.h @@ -67,7 +67,7 @@ inline at::Tensor clone_obey_contract( .new_empty_strided_symint( variable.sym_sizes(), variable.sym_strides(), - variable.options().memory_format(std::nullopt)) + variable.options().memory_format(c10::nullopt)) .copy_(new_grad)); } else { // (2) diff --git a/torch/csrc/autograd/utils/python_arg_parsing.h b/torch/csrc/autograd/utils/python_arg_parsing.h index e3fd671fb57cf9..326221e44d147a 100644 --- a/torch/csrc/autograd/utils/python_arg_parsing.h +++ b/torch/csrc/autograd/utils/python_arg_parsing.h @@ -31,7 +31,7 @@ parse_to_conversion(PythonArgs& r, bool allow_copy) { if (!allow_copy && !r.isNone(2)) throw std::runtime_error(".to() does not accept copy argument"); return std::make_tuple( - std::nullopt, + c10::nullopt, r.scalartype(0), r.toBool(1), r.toBool(2), diff --git a/torch/csrc/autograd/variable.h b/torch/csrc/autograd/variable.h index 2ce91146dc8d06..d60f37085f3808 100644 --- a/torch/csrc/autograd/variable.h +++ b/torch/csrc/autograd/variable.h @@ -351,8 +351,8 @@ struct TORCH_API ViewFunc { /// Returns a clone of this ViewFunc, optionally with the specified saved /// state. virtual std::unique_ptr clone_and_set( - std::optional> = std::nullopt, - std::optional> = std::nullopt) const = 0; + std::optional> = c10::nullopt, + std::optional> = c10::nullopt) const = 0; protected: /// Sets the values of any SymInts in the saved state. The input vector size @@ -382,8 +382,8 @@ struct ChainedViewFunc : public ViewFunc { } virtual at::Tensor operator()(const at::Tensor&) const override; virtual std::unique_ptr clone_and_set( - std::optional> = std::nullopt, - std::optional> = std::nullopt) const override; + std::optional> = c10::nullopt, + std::optional> = c10::nullopt) const override; private: std::unique_ptr first; @@ -398,8 +398,8 @@ struct ErroringViewFunc : public ViewFunc { TORCH_CHECK(false, error_msg); } virtual std::unique_ptr clone_and_set( - std::optional> = std::nullopt, - std::optional> = std::nullopt) const override { + std::optional> = c10::nullopt, + std::optional> = c10::nullopt) const override { return std::make_unique(error_msg); } diff --git a/torch/csrc/cuda/comm.cpp b/torch/csrc/cuda/comm.cpp index 52331909fe1dc2..d8f968eae5f5cb 100644 --- a/torch/csrc/cuda/comm.cpp +++ b/torch/csrc/cuda/comm.cpp @@ -11,9 +11,9 @@ #include #include #include +#include #include #include -#include #include #include diff --git a/torch/csrc/cuda/comm.h b/torch/csrc/cuda/comm.h index 860629bcf2e9a3..92009a1c40ada5 100644 --- a/torch/csrc/cuda/comm.h +++ b/torch/csrc/cuda/comm.h @@ -3,8 +3,8 @@ #include #include #include +#include #include -#include #include #include @@ -29,15 +29,15 @@ TORCH_CUDA_CU_API std::vector& scatter_out( std::vector& out_tensors, int64_t dim = 0, const std::optional>>& - streams = std::nullopt); + streams = c10::nullopt); TORCH_CUDA_CU_API std::vector scatter( const at::Tensor& tensor, at::IntArrayRef devices, - const std::optional>& chunk_sizes = std::nullopt, + const std::optional>& chunk_sizes = c10::nullopt, int64_t dim = 0, const std::optional>>& - streams = std::nullopt); + streams = c10::nullopt); TORCH_CUDA_CU_API at::Tensor& gather_out( at::TensorList tensors, diff --git a/torch/csrc/cuda/memory_snapshot.h b/torch/csrc/cuda/memory_snapshot.h index fe5699af416012..eb22767a78f905 100644 --- a/torch/csrc/cuda/memory_snapshot.h +++ b/torch/csrc/cuda/memory_snapshot.h @@ -1,8 +1,8 @@ #pragma once +#include #include #include -#include #include namespace torch::cuda { diff --git a/torch/csrc/cuda/nccl.h b/torch/csrc/cuda/nccl.h index 6561ccb6e76c1a..37d1be15cbd701 100644 --- a/torch/csrc/cuda/nccl.h +++ b/torch/csrc/cuda/nccl.h @@ -2,9 +2,9 @@ #include #include +#include #include -#include #include // NCCL BFloat16 is enabled only for CUDA 11+ and NCCL versions 2.10+, or for diff --git a/torch/csrc/cuda/python_nccl.cpp b/torch/csrc/cuda/python_nccl.cpp index f62311efbd9361..5060f9289a9e14 100644 --- a/torch/csrc/cuda/python_nccl.cpp +++ b/torch/csrc/cuda/python_nccl.cpp @@ -60,7 +60,7 @@ static std::vector> unpack_streams( PyObject* obj, size_t size) { if (obj == Py_None) { - return std::vector>(size, std::nullopt); + return std::vector>(size, c10::nullopt); } auto streams = THPUtils_PySequence_to_CUDAStreamList(obj); if (streams.size() != size) { diff --git a/torch/csrc/distributed/autograd/engine/dist_engine.cpp b/torch/csrc/distributed/autograd/engine/dist_engine.cpp index d37e695c77194a..062a15da4964c0 100644 --- a/torch/csrc/distributed/autograd/engine/dist_engine.cpp +++ b/torch/csrc/distributed/autograd/engine/dist_engine.cpp @@ -98,7 +98,7 @@ void DistEngine::globalCpuThread( InputBuffer::variables(std::move(task.inputs_))]() mutable { InputBuffer inputs(variables.size()); for (const auto i : c10::irange(variables.size())) { - inputs.add(i, std::move(variables[i]), std::nullopt, std::nullopt); + inputs.add(i, std::move(variables[i]), c10::nullopt, c10::nullopt); } execute_graph_task_until_ready_queue_empty( /*node_task*/ NodeTask(graphTask, graphRoot, std::move(inputs)), diff --git a/torch/csrc/distributed/c10d/NCCLUtils.cpp b/torch/csrc/distributed/c10d/NCCLUtils.cpp index 98af0d51a3d050..6507fe6abc2a2b 100644 --- a/torch/csrc/distributed/c10d/NCCLUtils.cpp +++ b/torch/csrc/distributed/c10d/NCCLUtils.cpp @@ -18,7 +18,7 @@ namespace c10d { ncclComm_t NCCLComm::getNcclComm() { std::unique_lock lock(mutex_); if (aborted_) { - auto commFailureMsg = commFailureReason_ != std::nullopt + auto commFailureMsg = commFailureReason_ != c10::nullopt ? c10::str(" Original reason for failure was: ", *commFailureReason_) : ""; TORCH_CHECK_WITH( @@ -76,7 +76,7 @@ std::shared_ptr NCCLComm::split( C10D_NCCL_CHECK( ncclCommSplit( source->ncclComm_, color_id, rank, &(comm->ncclComm_), &config), - std::nullopt); + c10::nullopt); ++source->ncclCommSplitCounter_; comm->rank_ = rank; return comm; @@ -186,11 +186,11 @@ std::string ncclGetErrorWithVersion(ncclResult_t error) { // thrown in the NCCL codebase. std::string getNcclErrorDetailStr( ncclResult_t error, - std::optional processGroupFailureReason /* = std::nullopt */ + std::optional processGroupFailureReason /* = c10::nullopt */ ) { // Prioritize failure reason provided by PG NCCL first, as it can abort // communicators when it encounters collective timeouts, etc. - if (processGroupFailureReason != std::nullopt) { + if (processGroupFailureReason != c10::nullopt) { return *processGroupFailureReason; } std::string interpret; diff --git a/torch/csrc/distributed/c10d/NCCLUtils.hpp b/torch/csrc/distributed/c10d/NCCLUtils.hpp index 06568f6ce7d2f1..9ce25b55dc133a 100644 --- a/torch/csrc/distributed/c10d/NCCLUtils.hpp +++ b/torch/csrc/distributed/c10d/NCCLUtils.hpp @@ -11,8 +11,8 @@ #include #include +#include #include -#include #if defined(NCCL_MAJOR) && (NCCL_MAJOR == 2) && defined(NCCL_MINOR) && \ (NCCL_MINOR >= 14) @@ -183,7 +183,7 @@ bool shouldBroadcastNCCLUniqueID(bool isSendRecvSelf); // thrown in the NCCL codebase. TORCH_API std::string getNcclErrorDetailStr( ncclResult_t error, - std::optional processGroupFailureReason = std::nullopt); + std::optional processGroupFailureReason = c10::nullopt); // Write NCCL debug info to local disk or any storage users define. // There are some constrains we set for the debug info writer: @@ -221,7 +221,7 @@ class NCCLComm { : ncclComm_(ncclComm), aborted_(false), ncclAsyncErr_(ncclSuccess), - commFailureReason_(std::nullopt), + commFailureReason_(c10::nullopt), initialized_(false) {} NCCLComm() : NCCLComm(nullptr) {} @@ -249,7 +249,7 @@ class NCCLComm { auto comm = std::make_shared(); C10D_NCCL_CHECK( ncclCommInitRank(&(comm->ncclComm_), numRanks, commId, rank), - std::nullopt); + c10::nullopt); comm->ncclId_ = commId; comm->rank_ = rank; comm->initialized_ = true; @@ -271,12 +271,12 @@ class NCCLComm { C10D_NCCL_CHECK_NONBLOCKING( ncclCommInitRankConfig( &(comm->ncclComm_), numRanks, commId, rank, &config), - std::nullopt); + c10::nullopt); } else { C10D_NCCL_CHECK( ncclCommInitRankConfig( &(comm->ncclComm_), numRanks, commId, rank, &config), - std::nullopt); + c10::nullopt); // under blocking mode, comm is initialized after NCCL CHECK isInitialized = true; } @@ -301,7 +301,7 @@ class NCCLComm { LOG(INFO) << "Communicator was aborted before trying to dump its state."; return dump; } - C10D_NCCL_CHECK(::ncclCommDump(ncclComm_, dump), std::nullopt); + C10D_NCCL_CHECK(::ncclCommDump(ncclComm_, dump), c10::nullopt); return dump; } #endif @@ -336,7 +336,7 @@ class NCCLComm { } void ncclCommAbort( - std::optional commFailureReason = std::nullopt) { + std::optional commFailureReason = c10::nullopt) { std::unique_lock lock(mutex_); #ifdef ENABLE_NCCL_ERROR_CHECKING if (aborted_) { diff --git a/torch/csrc/distributed/c10d/ProcessGroupCudaP2P.hpp b/torch/csrc/distributed/c10d/ProcessGroupCudaP2P.hpp index 23ee93b91d7a64..cff4ad09b70648 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupCudaP2P.hpp +++ b/torch/csrc/distributed/c10d/ProcessGroupCudaP2P.hpp @@ -128,7 +128,7 @@ class TORCH_API ProcessGroupCudaP2P : public Backend { const BarrierOptions& opts = BarrierOptions()) override; c10::intrusive_ptr intra_node_barrier( - c10::optional> ranks = std::nullopt); + c10::optional> ranks = c10::nullopt); at::Tensor get_p2p_buffer( size_t rank, @@ -136,7 +136,7 @@ class TORCH_API ProcessGroupCudaP2P : public Backend { c10::ScalarType dtype, int64_t storage_offest = 0); - void shutdown(c10::optional reason = std::nullopt); + void shutdown(c10::optional reason = c10::nullopt); private: c10::intrusive_ptr nccl_backend_; diff --git a/torch/csrc/distributed/c10d/ProcessGroupGloo.cpp b/torch/csrc/distributed/c10d/ProcessGroupGloo.cpp index a6ed8fd26a161e..cba0249829e68b 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupGloo.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupGloo.cpp @@ -2425,7 +2425,7 @@ class AsyncScatterWork : public ProcessGroupGloo::AsyncWork { seq, "gloo:scatter", !inputs.empty() ? std::optional>(inputs[0]) - : std::nullopt), + : c10::nullopt), context(context), outputs(outputs), inputs(inputs), @@ -2888,7 +2888,7 @@ class AsyncBarrierWork : public ProcessGroupGloo::AsyncWork { OpType::BARRIER, seq, "gloo:barrier", - std::nullopt), + c10::nullopt), context(context), priorWork(std::move(priorWork)), tag(tag) {} diff --git a/torch/csrc/distributed/c10d/ProcessGroupGloo.hpp b/torch/csrc/distributed/c10d/ProcessGroupGloo.hpp index 9f1e63d58adf2d..87c87b8f1ae9bd 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupGloo.hpp +++ b/torch/csrc/distributed/c10d/ProcessGroupGloo.hpp @@ -74,7 +74,7 @@ class TORCH_API ProcessGroupGloo : public Backend { uint64_t seq, const char* profilingTitle = nullptr, const std::optional>& inputTensors = - std::nullopt); + c10::nullopt); ~AsyncWork() override = default; diff --git a/torch/csrc/distributed/c10d/ProcessGroupMPI.cpp b/torch/csrc/distributed/c10d/ProcessGroupMPI.cpp index 91e9f938f1dd3e..6d02f89f6005b8 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupMPI.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupMPI.cpp @@ -673,7 +673,7 @@ c10::intrusive_ptr ProcessGroupMPI::scatter( "mpi:scatter", !inputTensors.empty() ? std::optional>(inputTensors[0]) - : std::nullopt); + : c10::nullopt); } else { auto entry = std::make_unique( nullptr, &outputTensors, std::move(runFunc)); @@ -682,7 +682,7 @@ c10::intrusive_ptr ProcessGroupMPI::scatter( "mpi:scatter", !inputTensors.empty() ? std::optional>(inputTensors[0]) - : std::nullopt); + : c10::nullopt); } } @@ -932,7 +932,7 @@ c10::intrusive_ptr ProcessGroupMPI::barrier(const BarrierOptions& opts) { }; auto entry = std::make_unique(nullptr, nullptr, std::move(runFunc)); - return enqueue(std::move(entry), "mpi:barrier", std::nullopt); + return enqueue(std::move(entry), "mpi:barrier", c10::nullopt); } c10::intrusive_ptr ProcessGroupMPI::_allgather_base( diff --git a/torch/csrc/distributed/c10d/ProcessGroupMPI.hpp b/torch/csrc/distributed/c10d/ProcessGroupMPI.hpp index 5eb06b7395570e..6e52e680e5c201 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupMPI.hpp +++ b/torch/csrc/distributed/c10d/ProcessGroupMPI.hpp @@ -87,7 +87,7 @@ class TORCH_API ProcessGroupMPI : public Backend { std::vector outputTensors, const char* profilingTitle = nullptr, const std::optional>& inputTensors = - std::nullopt) + c10::nullopt) : Work(-1, OpType::UNKNOWN, profilingTitle, inputTensors), outputTensors_(std::move(outputTensors)), future_(c10::make_intrusive( @@ -115,7 +115,7 @@ class TORCH_API ProcessGroupMPI : public Backend { std::vector outputTensors, const char* profilingTitle = nullptr, const std::optional>& inputTensors = - std::nullopt); + c10::nullopt); ~AsyncWork() override; @@ -244,7 +244,7 @@ class TORCH_API ProcessGroupMPI : public Backend { std::unique_ptr entry, const char* profilingTitle = nullptr, const std::optional>& inputTensors = - std::nullopt); + c10::nullopt); bool stop_; diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp index af940f53bf24c8..e7699b55245147 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp @@ -19,6 +19,7 @@ #include #include #include +#include #include #include #include @@ -31,7 +32,6 @@ #include #include #include -#include namespace c10d { @@ -376,7 +376,7 @@ std::string dump_nccl_trace( bool includeStackTraces, bool onlyActive) { return NCCLTraceBuffer::get()->dump( - std::nullopt, includeCollectives, includeStackTraces, onlyActive); + c10::nullopt, includeCollectives, includeStackTraces, onlyActive); } #endif @@ -393,7 +393,7 @@ std::optional)>>& get_cpp_trace_dumper() { static std::optional< std::function)>> - dumper(std::nullopt); + dumper(c10::nullopt); return dumper; } @@ -658,7 +658,7 @@ void ProcessGroupNCCL::WorkNCCL::synchronizeInternal( if (blockingWait_) { while (!isCompleted()) { bool timedOut = checkTimeout( - timeout == kNoTimeout ? std::nullopt : std::make_optional(timeout)); + timeout == kNoTimeout ? c10::nullopt : c10::make_optional(timeout)); // Explicitly abort ncclComms here before throwing this timed out // exception to users. // If throwing timed out excepiton without aborting nccl communicators @@ -1245,7 +1245,7 @@ void ProcessGroupNCCL::heartbeatMonitor() { : heartbeatTimeoutInSec_ * 1000; auto lastTimePollStore = std::chrono::steady_clock::now(); auto lastTimeHeartBeatCheck = std::chrono::steady_clock::now(); - std::optional dumpPipe = std::nullopt; + std::optional dumpPipe = c10::nullopt; if (uid_ == 0) { // DumpPipe is one per-trainer process, and its convenient to name them // after 'global' ranks in the system, So we assume processgroup (uid)==0 is @@ -1881,7 +1881,7 @@ std::exception_ptr ProcessGroupNCCL::checkForNCCLErrorsInternal( // Prioritize commFailureReason over checkForNcclError() result if // commFailureReason is set. auto commFailureReason = ncclComm->getNcclCommFailureReason(); - if (commFailureReason != std::nullopt) { + if (commFailureReason != c10::nullopt) { return std::make_exception_ptr(C10_BUILD_ERROR( DistBackendError, c10::str( @@ -2050,7 +2050,7 @@ std::shared_ptr ProcessGroupNCCL::getNCCLComm( bool singleP2POp = isP2POp(opType, batchP2P); // For point-to-point communication, lower rank of the two will get unique id. if (rank_ == 0 || (singleP2POp && p2pRank == 0)) { - C10D_NCCL_CHECK(ncclGetUniqueId(&ncclID), std::nullopt); + C10D_NCCL_CHECK(ncclGetUniqueId(&ncclID), c10::nullopt); } if (shouldBroadcastNCCLUniqueID(isSendRecvSelf)) { @@ -2086,7 +2086,7 @@ std::shared_ptr ProcessGroupNCCL::getNCCLComm( for (const auto i : c10::irange(ncclActiveGroupCounter_)) { (void)i; // comms have not been initiated yet, so can only check in blocking-way - C10D_NCCL_CHECK(ncclGroupEnd(), std::nullopt); + C10D_NCCL_CHECK(ncclGroupEnd(), c10::nullopt); } // GPU world size and GPU rank @@ -2182,7 +2182,7 @@ std::shared_ptr ProcessGroupNCCL::getNCCLComm( // See [Group Start/End Note] for (const auto i : c10::irange(ncclActiveGroupCounter_)) { (void)i; - C10D_NCCL_CHECK(ncclGroupStart(), std::nullopt); + C10D_NCCL_CHECK(ncclGroupStart(), c10::nullopt); } ncclStreams_.emplace(deviceKey, std::move(streamVal)); @@ -2334,7 +2334,7 @@ c10::intrusive_ptr ProcessGroupNCCL::initWork( seqCollective_, profilingTitle, profilingTitle != nullptr ? std::optional>(inputs) - : std::nullopt, + : c10::nullopt, desyncDebug_, enableTiming_.load(), dist_debug_level_); @@ -4190,23 +4190,23 @@ c10::intrusive_ptr ProcessGroupNCCL::recv( } void ProcessGroupNCCL::groupStart() { - C10D_NCCL_CHECK(ncclGroupStart(), std::nullopt); + C10D_NCCL_CHECK(ncclGroupStart(), c10::nullopt); ++ncclActiveGroupCounter_; } void ProcessGroupNCCL::groupEnd() { - C10D_NCCL_CHECK(ncclGroupEnd(), std::nullopt); + C10D_NCCL_CHECK(ncclGroupEnd(), c10::nullopt); --ncclActiveGroupCounter_; } void ProcessGroupNCCL::groupEndNonblocking(std::shared_ptr comm) { #ifndef NCCL_HAS_COMM_NONBLOCKING - C10D_NCCL_CHECK(ncclGroupEnd(), std::nullopt); + C10D_NCCL_CHECK(ncclGroupEnd(), c10::nullopt); #else if (!nccl_use_nonblocking()) { - C10D_NCCL_CHECK(ncclGroupEnd(), std::nullopt); + C10D_NCCL_CHECK(ncclGroupEnd(), c10::nullopt); } else { - C10D_NCCL_CHECK_TIMEOUT_GROUPEND(ncclGroupEnd(), comm, std::nullopt); + C10D_NCCL_CHECK_TIMEOUT_GROUPEND(ncclGroupEnd(), comm, c10::nullopt); } #endif --ncclActiveGroupCounter_; diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp index 763ef9829618f1..faaabe411bfccb 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp +++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp @@ -254,7 +254,7 @@ class TORCH_API ProcessGroupNCCL : public Backend { OpType opType, uint64_t seq, const char* profilingTitle = nullptr, - const std::optional>& inputs = std::nullopt, + const std::optional>& inputs = c10::nullopt, bool desyncDebug = false, bool enableTiming = false, DebugLevel distDebugLevel = DebugLevel::Off); @@ -311,7 +311,7 @@ class TORCH_API ProcessGroupNCCL : public Backend { // and False otherwise. // In case of timeout, set exception on the WorkNCCL object. bool checkTimeout( - std::optional timeout = std::nullopt); + std::optional timeout = c10::nullopt); std::vector result() override; @@ -662,9 +662,9 @@ class TORCH_API ProcessGroupNCCL : public Backend { // Provides an API to abort the ProcessGroup (similar to ncclCommAbort) // instead of relying on ProcessGroupNCCL destructor. // return true if abort is successful, otherwise false - bool abort(std::optional abortReason = std::nullopt); + bool abort(std::optional abortReason = c10::nullopt); - void shutdown(std::optional reason = std::nullopt); + void shutdown(std::optional reason = c10::nullopt); void eagerConnectSingleDevice(at::Device device) override; diff --git a/torch/csrc/distributed/c10d/TCPStore.cpp b/torch/csrc/distributed/c10d/TCPStore.cpp index 2de969d135e8f8..fe24c31f9068bd 100644 --- a/torch/csrc/distributed/c10d/TCPStore.cpp +++ b/torch/csrc/distributed/c10d/TCPStore.cpp @@ -293,7 +293,7 @@ TCPStore::TCPStore( masterPort, isServer, numWorkers ? std::optional(*numWorkers) - : std::nullopt, + : c10::nullopt, waitWorkers, timeout}} {} @@ -376,7 +376,7 @@ TCPStore::~TCPStore() = default; void TCPStore::waitForWorkers() { detail::timing_guard tguard(clientCounters_["waitForWorkers"]); - if (numWorkers_ == std::nullopt) { + if (numWorkers_ == c10::nullopt) { return; } diff --git a/torch/csrc/distributed/c10d/TCPStore.hpp b/torch/csrc/distributed/c10d/TCPStore.hpp index 9fd29b1c844cc6..25783f2d2acea9 100644 --- a/torch/csrc/distributed/c10d/TCPStore.hpp +++ b/torch/csrc/distributed/c10d/TCPStore.hpp @@ -49,7 +49,7 @@ struct TCPStoreOptions { std::uint16_t port = kDefaultPort; bool isServer = false; - std::optional numWorkers = std::nullopt; + std::optional numWorkers = c10::nullopt; bool waitWorkers = true; std::chrono::milliseconds timeout = Store::kDefaultTimeout; @@ -60,7 +60,7 @@ struct TCPStoreOptions { // If specified, and if isServer is true, the underlying TCPServer will take // over the bound socket associated to this fd. This option is useful to avoid // port assignment races in certain scenarios. - std::optional masterListenFd = std::nullopt; + std::optional masterListenFd = c10::nullopt; // A boolean value indicating whether to use the experimental libUV backend. bool useLibUV = true; @@ -73,7 +73,7 @@ class TORCH_API TCPStore : public Store { [[deprecated("Use TCPStore(host, opts) instead.")]] explicit TCPStore( const std::string& masterAddr, std::uint16_t masterPort, - std::optional numWorkers = std::nullopt, + std::optional numWorkers = c10::nullopt, bool isServer = false, const std::chrono::milliseconds& timeout = kDefaultTimeout, bool waitWorkers = true); diff --git a/torch/csrc/distributed/c10d/TraceUtils.h b/torch/csrc/distributed/c10d/TraceUtils.h index 9ff71f9d41b848..de623d77fe9e0e 100644 --- a/torch/csrc/distributed/c10d/TraceUtils.h +++ b/torch/csrc/distributed/c10d/TraceUtils.h @@ -516,7 +516,7 @@ struct NCCLTraceBuffer { std::chrono::milliseconds timeout_ms, bool isP2P) { if (!enabled_) { - return std::nullopt; + return c10::nullopt; } auto traceback = torch::CapturedTraceback::gather(true, true, capture_cpp_stack_); @@ -621,7 +621,7 @@ struct NCCLTraceBuffer { bool can_compute_duration = false; Event* startEvent = nullptr; Event* endEvent = nullptr; - std::optional duration = std::nullopt; + std::optional duration = c10::nullopt; std::unique_lock guard(mutex_); diff --git a/torch/csrc/distributed/c10d/Types.hpp b/torch/csrc/distributed/c10d/Types.hpp index 7cdb9f62ebbb85..669957a7267358 100644 --- a/torch/csrc/distributed/c10d/Types.hpp +++ b/torch/csrc/distributed/c10d/Types.hpp @@ -121,7 +121,7 @@ struct BroadcastOptions { struct AllreduceOptions { ReduceOp reduceOp = ReduceOp::SUM; std::chrono::milliseconds timeout = kUnsetTimeout; - std::optional sparseIndices = std::nullopt; + std::optional sparseIndices = c10::nullopt; }; struct AllreduceCoalescedOptions : AllreduceOptions {}; diff --git a/torch/csrc/distributed/c10d/Utils.hpp b/torch/csrc/distributed/c10d/Utils.hpp index b77a914da4e677..a03337e975148b 100644 --- a/torch/csrc/distributed/c10d/Utils.hpp +++ b/torch/csrc/distributed/c10d/Utils.hpp @@ -440,7 +440,7 @@ inline at::Tensor newLikeFlat( sizes.insert(sizes.end(), t.sizes().begin(), t.sizes().end()); strides.insert(strides.end(), t.strides().begin(), t.strides().end()); return at::empty_strided( - sizes, strides, t.options().memory_format(std::nullopt)); + sizes, strides, t.options().memory_format(c10::nullopt)); } inline at::Tensor newLikeFlat(std::vector& tensors) { diff --git a/torch/csrc/distributed/c10d/Work.hpp b/torch/csrc/distributed/c10d/Work.hpp index c10e5007b9f544..d29b838321176d 100644 --- a/torch/csrc/distributed/c10d/Work.hpp +++ b/torch/csrc/distributed/c10d/Work.hpp @@ -51,7 +51,7 @@ class TORCH_API Work : public torch::CustomClassHolder { OpType opType = OpType::UNKNOWN, const char* profilingTitle = nullptr, const std::optional>& inputTensors = - std::nullopt); + c10::nullopt); ~Work() override; diff --git a/torch/csrc/distributed/c10d/init.cpp b/torch/csrc/distributed/c10d/init.cpp index 5145c969a95b00..6f1b28886b989b 100644 --- a/torch/csrc/distributed/c10d/init.cpp +++ b/torch/csrc/distributed/c10d/init.cpp @@ -1415,7 +1415,7 @@ Example:: bool multiTenant, std::optional masterListenFd, bool useLibUV) { - std::optional numWorkers = std::nullopt; + std::optional numWorkers = c10::nullopt; if (worldSize.has_value() && worldSize.value() > -1) { numWorkers = static_cast(worldSize.value()); } @@ -2648,7 +2648,7 @@ options :class:`~torch.distributed.ProcessGroupNCCL.Options`). py::arg("store"), py::arg("rank"), py::arg("world_size"), - py::arg("buffer_size") = std::nullopt) + py::arg("buffer_size") = c10::nullopt) .def("barrier", &IntraNodeComm::barrier, py::arg("ranks") = py::none()); #ifdef NCCL_HAS_COMM_CTA_CGA diff --git a/torch/csrc/distributed/c10d/intra_node_comm.hpp b/torch/csrc/distributed/c10d/intra_node_comm.hpp index b4d70f580da5cb..5d7e2d426d30a1 100644 --- a/torch/csrc/distributed/c10d/intra_node_comm.hpp +++ b/torch/csrc/distributed/c10d/intra_node_comm.hpp @@ -33,7 +33,7 @@ class TORCH_API IntraNodeComm : public c10::intrusive_ptr_target { c10::intrusive_ptr store, size_t rank, size_t worldSize, - std::optional bufferSize = std::nullopt); + std::optional bufferSize = c10::nullopt); ~IntraNodeComm() override; @@ -65,7 +65,7 @@ class TORCH_API IntraNodeComm : public c10::intrusive_ptr_target { /** * Perform a barrier among the specified ranks. */ - void barrier(std::optional> ranks = std::nullopt); + void barrier(std::optional> ranks = c10::nullopt); at::Tensor getBuffer( size_t rank, diff --git a/torch/csrc/distributed/c10d/logger.cpp b/torch/csrc/distributed/c10d/logger.cpp index 48f8786842f01f..711039bf485954 100644 --- a/torch/csrc/distributed/c10d/logger.cpp +++ b/torch/csrc/distributed/c10d/logger.cpp @@ -234,7 +234,7 @@ void Logger::set_event_time( Timer& timer, Timer::Event event) { auto timestamp = timer.getTimestamp(event); - if (timestamp != std::nullopt) { + if (timestamp != c10::nullopt) { // TODO: should we set this as human-readable time instead of unixtime? event_time = *timestamp; } diff --git a/torch/csrc/distributed/c10d/reducer.cpp b/torch/csrc/distributed/c10d/reducer.cpp index 6c5f7a79ff9fbf..6a2812ab24b9cd 100644 --- a/torch/csrc/distributed/c10d/reducer.cpp +++ b/torch/csrc/distributed/c10d/reducer.cpp @@ -61,7 +61,7 @@ class CpuTimer : public Timer { // calculate the valid avg_time. // In this case, skip calculating the avg_time and return. if (end_time < start_time) { - return std::nullopt; + return c10::nullopt; } return end_time - start_time; } @@ -499,7 +499,7 @@ std::vector Reducer::get_grad_buckets( bucket.lengths, bucket.sizes_vec, variables_for_bucket, - std::nullopt); + c10::nullopt); } return gradBuckets; } @@ -1655,9 +1655,9 @@ void Reducer::finalize_backward() { } } - if (installed_futures_ != std::nullopt) { + if (installed_futures_ != c10::nullopt) { c10::collectAll(*installed_futures_)->wait(); - installed_futures_ = std::nullopt; + installed_futures_ = c10::nullopt; } // See Note [Skip allreducing local_used_maps_dev] diff --git a/torch/csrc/distributed/c10d/reducer.hpp b/torch/csrc/distributed/c10d/reducer.hpp index aa3c40ae95bbf2..1f72b0eb37b9f6 100644 --- a/torch/csrc/distributed/c10d/reducer.hpp +++ b/torch/csrc/distributed/c10d/reducer.hpp @@ -262,9 +262,9 @@ class TORCH_API Reducer { // List of futures installed by Reducer::install_futures that should be // awaited at the end of backwards pass. std::optional>> - installed_futures_{std::nullopt}; + installed_futures_{c10::nullopt}; // Mixed precision parameter dtype for bucket type checking. - std::optional mixed_precision_param_dtype_{std::nullopt}; + std::optional mixed_precision_param_dtype_{c10::nullopt}; // Work handle for allreduce on local_used_map_ c10::intrusive_ptr local_used_work_; @@ -389,7 +389,7 @@ class TORCH_API Reducer { bool expect_sparse_gradient = false; // Sparse indices tensor - std::optional sparse_tensor_indices = std::nullopt; + std::optional sparse_tensor_indices = c10::nullopt; // TODO(@pietern) // Memory copies from gradient tensors into the bucket are potentially diff --git a/torch/csrc/distributed/c10d/reducer_cuda.cpp b/torch/csrc/distributed/c10d/reducer_cuda.cpp index a158e44fc047c0..84bff02072b606 100644 --- a/torch/csrc/distributed/c10d/reducer_cuda.cpp +++ b/torch/csrc/distributed/c10d/reducer_cuda.cpp @@ -59,7 +59,7 @@ class CudaTimer : public Timer { // If it is never recorded/created, skip synchronize and calculation. // Otherwise it will throw cuda errors. if (!start_event.isCreated() || !end_event.isCreated()) { - return std::nullopt; + return c10::nullopt; } // set_runtime_stats_and_log is called at the beginning of forward call, // when it is cheap to synchronize the cuda events of previous iteration, @@ -74,7 +74,7 @@ class CudaTimer : public Timer { // calculate the valid avg_time. // In this case, skip calculating the avg_time and return. if (milliseconds < 0) { - return std::nullopt; + return c10::nullopt; } return int64_t(milliseconds * kMilliSecondToNanosSecond); } diff --git a/torch/csrc/distributed/c10d/reducer_timer.hpp b/torch/csrc/distributed/c10d/reducer_timer.hpp index dbea3958db43da..f9b9f11c8c9632 100644 --- a/torch/csrc/distributed/c10d/reducer_timer.hpp +++ b/torch/csrc/distributed/c10d/reducer_timer.hpp @@ -47,7 +47,7 @@ class TORCH_API Timer { std::optional getTimestamp(Event event) { auto time = getTimeRef(event); if (time == kUnsetTime) { - return std::nullopt; + return c10::nullopt; } else { return time; } diff --git a/torch/csrc/distributed/c10d/sequence_num.cpp b/torch/csrc/distributed/c10d/sequence_num.cpp index 3807d629d830c5..fd76247199f618 100644 --- a/torch/csrc/distributed/c10d/sequence_num.cpp +++ b/torch/csrc/distributed/c10d/sequence_num.cpp @@ -10,7 +10,7 @@ SequenceNum::SequenceNum(const uint64_t num) : num_(num) {} SequenceNum::SequenceNum(const SequenceNum& other) { if (!other.isSet()) { - num_ = std::nullopt; + num_ = c10::nullopt; } else { num_ = other.get(); } @@ -23,7 +23,7 @@ uint64_t SequenceNum::get() const { void SequenceNum::increment() { std::lock_guard lock(lock_); - TORCH_CHECK(num_ != std::nullopt); + TORCH_CHECK(num_ != c10::nullopt); num_ = ++(*num_); } @@ -32,7 +32,7 @@ void SequenceNum::increment() { uint64_t SequenceNum::getAndIncrement() { uint64_t curVal = 0; std::lock_guard lock(lock_); - TORCH_CHECK(num_ != std::nullopt); + TORCH_CHECK(num_ != c10::nullopt); curVal = *num_; num_ = ++(*num_); return curVal; @@ -45,13 +45,13 @@ void SequenceNum::set(const uint64_t num) { bool SequenceNum::isSet() const { std::lock_guard lock(lock_); - return num_ != std::nullopt; + return num_ != c10::nullopt; } SequenceNum& SequenceNum::operator=(const SequenceNum& other) { std::lock_guard lock(lock_); if (!other.isSet()) { - num_ = std::nullopt; + num_ = c10::nullopt; } else { num_ = other.get(); } diff --git a/torch/csrc/distributed/c10d/sequence_num.hpp b/torch/csrc/distributed/c10d/sequence_num.hpp index 38bd4cb5ed9d38..ce31f4b5527282 100644 --- a/torch/csrc/distributed/c10d/sequence_num.hpp +++ b/torch/csrc/distributed/c10d/sequence_num.hpp @@ -1,9 +1,9 @@ #pragma once #include +#include #include #include -#include #include namespace c10d { diff --git a/torch/csrc/distributed/rpc/profiler/remote_profiler_manager.cpp b/torch/csrc/distributed/rpc/profiler/remote_profiler_manager.cpp index eb45679873f039..3a37e7b02a5f05 100644 --- a/torch/csrc/distributed/rpc/profiler/remote_profiler_manager.cpp +++ b/torch/csrc/distributed/rpc/profiler/remote_profiler_manager.cpp @@ -10,7 +10,7 @@ namespace rpc { const std::string REMOTE_PROFILING_KEY_PREFIX = "#remote_op: "; constexpr int kAutoIncrementBits = 48; /*static */ thread_local std::optional - RemoteProfilerManager::currentThreadLocalKey_ = std::nullopt; + RemoteProfilerManager::currentThreadLocalKey_ = c10::nullopt; /*static */ RemoteProfilerManager& RemoteProfilerManager::getInstance() { static RemoteProfilerManager* handler = new RemoteProfilerManager(); return *handler; @@ -32,7 +32,7 @@ bool RemoteProfilerManager::isCurrentKeySet() const { } void RemoteProfilerManager::unsetCurrentKey() { - currentThreadLocalKey_ = std::nullopt; + currentThreadLocalKey_ = c10::nullopt; } void RemoteProfilerManager::eraseKey(const ProfilingId& globallyUniqueId) { diff --git a/torch/csrc/distributed/rpc/profiler/remote_profiler_manager.h b/torch/csrc/distributed/rpc/profiler/remote_profiler_manager.h index 2889120b67ca69..c6f8b353806b5b 100644 --- a/torch/csrc/distributed/rpc/profiler/remote_profiler_manager.h +++ b/torch/csrc/distributed/rpc/profiler/remote_profiler_manager.h @@ -1,8 +1,8 @@ #pragma once +#include #include #include #include -#include #include namespace torch { diff --git a/torch/csrc/distributed/rpc/py_rref.cpp b/torch/csrc/distributed/rpc/py_rref.cpp index 887f25b6c16dd4..ed7847a1f5faa2 100644 --- a/torch/csrc/distributed/rpc/py_rref.cpp +++ b/torch/csrc/distributed/rpc/py_rref.cpp @@ -119,7 +119,7 @@ TypePtr tryInferTypeWithTypeHint( /////////////////////////// PyRRef ////////////////////////////////// PyRRef::PyRRef(c10::intrusive_ptr rref) - : rref_(std::move(rref)), profilingFuture_(std::nullopt) { + : rref_(std::move(rref)), profilingFuture_(c10::nullopt) { TORCH_CHECK(rref_, "PyRRef must not wrap nullptr"); C10_LOG_API_USAGE_ONCE("torch.distributed.rref"); } diff --git a/torch/csrc/distributed/rpc/python_functions.cpp b/torch/csrc/distributed/rpc/python_functions.cpp index 51ee554abda743..57acbc0370252f 100644 --- a/torch/csrc/distributed/rpc/python_functions.cpp +++ b/torch/csrc/distributed/rpc/python_functions.cpp @@ -261,7 +261,7 @@ c10::intrusive_ptr pyRpcTorchscript( functionSchema, argsTuple.cast(), kwargsDict.cast(), - std::nullopt); + c10::nullopt); } DCHECK(!PyGILState_Check()); c10::intrusive_ptr fut = rpcTorchscript( @@ -408,7 +408,7 @@ PyRRef pyRemoteTorchscript( // Acquire GIL for py::args and py::kwargs processing. py::gil_scoped_acquire ag; stack = torch::jit::createStackForSchema( - functionSchema, args, kwargs, std::nullopt); + functionSchema, args, kwargs, c10::nullopt); } DCHECK(!PyGILState_Check()); auto rrefPtr = remoteTorchscript( diff --git a/torch/csrc/distributed/rpc/request_callback_no_python.cpp b/torch/csrc/distributed/rpc/request_callback_no_python.cpp index 3b6b04047c4e0c..fb73cf2abf483e 100644 --- a/torch/csrc/distributed/rpc/request_callback_no_python.cpp +++ b/torch/csrc/distributed/rpc/request_callback_no_python.cpp @@ -440,7 +440,7 @@ c10::intrusive_ptr RequestCallbackNoPython:: true /* cleanup TLS state */, false /* consolidate events */); { TLSLegacyProfilerGuard g( - profilingConfig, std::nullopt, requestThreadOptions); + profilingConfig, c10::nullopt, requestThreadOptions); TORCH_INTERNAL_ASSERT( profilerEnabled(), "Expected profiler to be enabled!"); // Kick off processing for nested work and get Future result in diff --git a/torch/csrc/distributed/rpc/rref_impl.h b/torch/csrc/distributed/rpc/rref_impl.h index 507d6bc846587c..d6da3f2ea455f0 100644 --- a/torch/csrc/distributed/rpc/rref_impl.h +++ b/torch/csrc/distributed/rpc/rref_impl.h @@ -3,10 +3,10 @@ #include #include #include +#include #include #include #include -#include #include diff --git a/torch/csrc/distributed/rpc/script_call.h b/torch/csrc/distributed/rpc/script_call.h index 5db4adf95f85ba..dacded5cc1e62a 100644 --- a/torch/csrc/distributed/rpc/script_call.h +++ b/torch/csrc/distributed/rpc/script_call.h @@ -1,10 +1,10 @@ #pragma once +#include #include #include #include #include -#include #include namespace torch { diff --git a/torch/csrc/distributed/rpc/tensorpipe_cuda.cpp b/torch/csrc/distributed/rpc/tensorpipe_cuda.cpp index 8259efeee1f9b3..50cc97785f61da 100644 --- a/torch/csrc/distributed/rpc/tensorpipe_cuda.cpp +++ b/torch/csrc/distributed/rpc/tensorpipe_cuda.cpp @@ -94,7 +94,7 @@ class TensorpipeCudaConverter : public TensorpipeDeviceTypeConverter { message.tensors.push_back(std::move(tensor)); - return std::nullopt; + return c10::nullopt; } at::DataPtr allocateTensorForReceiving( diff --git a/torch/csrc/distributed/rpc/tensorpipe_utils.cpp b/torch/csrc/distributed/rpc/tensorpipe_utils.cpp index 9d38b5538d554a..929ae30f8a6d4d 100644 --- a/torch/csrc/distributed/rpc/tensorpipe_utils.cpp +++ b/torch/csrc/distributed/rpc/tensorpipe_utils.cpp @@ -59,7 +59,7 @@ class TensorpipeCpuConverter : public TensorpipeDeviceTypeConverter { message.tensors.push_back(std::move(tensor)); - return std::make_optional(std::move(storageData)); + return c10::make_optional(std::move(storageData)); } else { tensorpipe::CpuBuffer buffer; buffer.ptr = static_cast(storage.mutable_data()); @@ -70,7 +70,7 @@ class TensorpipeCpuConverter : public TensorpipeDeviceTypeConverter { message.tensors.push_back(std::move(tensor)); - return std::nullopt; + return c10::nullopt; } } diff --git a/torch/csrc/dynamo/python_compiled_autograd.cpp b/torch/csrc/dynamo/python_compiled_autograd.cpp index 7913caad5449fc..2e5cb3bfab02e1 100644 --- a/torch/csrc/dynamo/python_compiled_autograd.cpp +++ b/torch/csrc/dynamo/python_compiled_autograd.cpp @@ -591,7 +591,7 @@ CacheNode* _compiled_autograd_impl( if (next.is_valid() && output.defined()) { input_buffers.lookup(next.function.get()) .add( - next.input_nr, std::move(output), std::nullopt, std::nullopt); + next.input_nr, std::move(output), c10::nullopt, c10::nullopt); } } } diff --git a/torch/csrc/functorch/init.cpp b/torch/csrc/functorch/init.cpp index b54a7285f63549..53da5a634746c2 100644 --- a/torch/csrc/functorch/init.cpp +++ b/torch/csrc/functorch/init.cpp @@ -242,7 +242,7 @@ int64_t _grad_increment_nesting() { // See NOTE [grad and vjp interaction with no_grad] bool prev_grad_mode = c10::GradMode::is_enabled(); return initAndPushDynamicLayer( - TransformType::Grad, std::nullopt, std::nullopt, prev_grad_mode); + TransformType::Grad, c10::nullopt, c10::nullopt, prev_grad_mode); } int64_t _grad_decrement_nesting() { @@ -257,9 +257,9 @@ int64_t _jvp_increment_nesting() { c10::AutogradState::get_tls_state().get_fw_grad_mode(); return initAndPushDynamicLayer( TransformType::Jvp, - std::nullopt, - std::nullopt, - std::nullopt, + c10::nullopt, + c10::nullopt, + c10::nullopt, prev_fwd_grad_mode); } @@ -287,10 +287,10 @@ int64_t _vmap_decrement_nesting() { int64_t _func_increment_nesting(bool reapply_views) { return initAndPushDynamicLayer( TransformType::Functionalize, - std::nullopt, - std::nullopt, - std::nullopt, - std::nullopt, + c10::nullopt, + c10::nullopt, + c10::nullopt, + c10::nullopt, /*functionalize_add_back_views=*/reapply_views); } @@ -528,7 +528,7 @@ void initFuncTorchBindings(PyObject* module) { "get_interpreter_stack", []() -> std::optional> { const auto& stack = getDynamicLayerStack(); if (stack.empty()) { - return std::nullopt; + return c10::nullopt; } std::vector result; result.reserve(stack.size()); @@ -540,7 +540,7 @@ void initFuncTorchBindings(PyObject* module) { m.def("peek_interpreter_stack", []() -> std::optional { const auto& stack = getDynamicLayerStack(); if (stack.empty()) { - return std::nullopt; + return c10::nullopt; } auto result = stack.back().interpreter(); return result; diff --git a/torch/csrc/inductor/aoti_torch/utils.h b/torch/csrc/inductor/aoti_torch/utils.h index eca21f6bf348c4..6e7bd355c57c31 100644 --- a/torch/csrc/inductor/aoti_torch/utils.h +++ b/torch/csrc/inductor/aoti_torch/utils.h @@ -7,9 +7,9 @@ #include #include #include +#include #include #include -#include #define AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE(...) \ try { \ @@ -66,41 +66,41 @@ inline void assert_inf_and_nan( // utility functions to convert a pointer to an optional value template inline std::optional pointer_to_optional(T* ptr) { - return ptr ? std::make_optional(*ptr) : std::nullopt; + return ptr ? c10::make_optional(*ptr) : c10::nullopt; } template >> inline std::optional pointer_to_optional(U* ptr) { - return ptr ? std::make_optional(T(*ptr)) : std::nullopt; + return ptr ? c10::make_optional(T(*ptr)) : c10::nullopt; } template <> inline std::optional pointer_to_optional(AtenTensorHandle* ptr) { - return ptr ? std::make_optional(*tensor_handle_to_tensor_pointer(*ptr)) - : std::nullopt; + return ptr ? c10::make_optional(*tensor_handle_to_tensor_pointer(*ptr)) + : c10::nullopt; } template <> inline std::optional pointer_to_optional( const AtenTensorHandle* ptr) { - return ptr ? std::make_optional(*tensor_handle_to_tensor_pointer(*ptr)) - : std::nullopt; + return ptr ? c10::make_optional(*tensor_handle_to_tensor_pointer(*ptr)) + : c10::nullopt; } template <> inline std::optional pointer_to_optional( AtenGeneratorHandle* ptr) { - return ptr ? std::make_optional(*generator_handle_to_generator_pointer(*ptr)) - : std::nullopt; + return ptr ? c10::make_optional(*generator_handle_to_generator_pointer(*ptr)) + : c10::nullopt; } inline std::optional pointer_to_optional_device( int32_t* device_type, int32_t device_index) { - return device_type ? std::make_optional(c10::Device( + return device_type ? c10::make_optional(c10::Device( static_cast(*device_type), static_cast(device_index))) - : std::nullopt; + : c10::nullopt; } // utility functions to convert a pointer to a list @@ -180,8 +180,8 @@ inline std::optional> pointer_to_optional_list( U** ptr, int64_t len) { return ptr - ? std::make_optional>(pointer_to_list(*ptr, len)) - : std::nullopt; + ? c10::make_optional>(pointer_to_list(*ptr, len)) + : c10::nullopt; } } // namespace torch::aot_inductor diff --git a/torch/csrc/jit/api/compilation_unit.h b/torch/csrc/jit/api/compilation_unit.h index d1c2c829d660c3..8e28ef4717b934 100644 --- a/torch/csrc/jit/api/compilation_unit.h +++ b/torch/csrc/jit/api/compilation_unit.h @@ -12,7 +12,7 @@ #include #include #include -#include +#include #include #include @@ -97,7 +97,7 @@ struct TORCH_API CompilationUnit { const Self* self, // see [name mangling] bool shouldMangle = false, - std::optional operator_set_version = std::nullopt); + std::optional operator_set_version = c10::nullopt); void define_hooks( const std::optional& prefix, @@ -293,7 +293,7 @@ struct TORCH_API CompilationUnit { const std::unordered_map& function_table, bool shouldMangle = false, FunctionType type = FunctionType::Method, - std::optional version = std::nullopt) const; + std::optional version = c10::nullopt) const; // Define a property on \p self. struct PropertyPair; diff --git a/torch/csrc/jit/api/function_impl.h b/torch/csrc/jit/api/function_impl.h index 01e7a3c98e3024..6ed8cb36199ef2 100644 --- a/torch/csrc/jit/api/function_impl.h +++ b/torch/csrc/jit/api/function_impl.h @@ -13,7 +13,7 @@ struct TORCH_API GraphFunction : public Function { std::shared_ptr graph, std::function function_creator, std::optional executor_execution_mode = - std::nullopt) + c10::nullopt) : name_(std::move(name)), graph_(std::move(graph)), executor_execution_mode_(executor_execution_mode), diff --git a/torch/csrc/jit/api/module.cpp b/torch/csrc/jit/api/module.cpp index ae878376bab318..45b99eb8e47aa6 100644 --- a/torch/csrc/jit/api/module.cpp +++ b/torch/csrc/jit/api/module.cpp @@ -158,11 +158,11 @@ void Module::to(at::Device device, at::ScalarType dtype, bool non_blocking) { } void Module::to(at::ScalarType dtype, bool non_blocking) { - to_impl(/*device=*/std::nullopt, dtype, non_blocking); + to_impl(/*device=*/c10::nullopt, dtype, non_blocking); } void Module::to(at::Device device, bool non_blocking) { - to_impl(device, /*dtype=*/std::nullopt, non_blocking); + to_impl(device, /*dtype=*/c10::nullopt, non_blocking); } static void module_state_to( diff --git a/torch/csrc/jit/api/module.h b/torch/csrc/jit/api/module.h index 9b2648737b0ce0..92b9c96c3a6ecf 100644 --- a/torch/csrc/jit/api/module.h +++ b/torch/csrc/jit/api/module.h @@ -15,8 +15,8 @@ #include #include #include +#include #include -#include #include #include @@ -238,7 +238,7 @@ struct TORCH_API Module : public Object { Module copy() const; - Module deepcopy(std::optional device = std::nullopt) const; + Module deepcopy(std::optional device = c10::nullopt) const; // Clones both the underlying `ClassType` and the module instance(data), this // function creates a new `ClassType` and returns a new instance that has the @@ -334,7 +334,7 @@ struct TORCH_API Module : public Object { TORCH_API Module freeze( const Module& module, const std::optional>& preserved_attrs = - std::nullopt, + c10::nullopt, bool optimize_numerics = true); // C++ equivalent api of `torch.jit.optimize_for_inference`. See documentation @@ -552,7 +552,7 @@ struct slot_list_impl { : module_(std::move(module)), recurse_(recurse), return_module_(return_module), - size_(std::nullopt) { + size_(c10::nullopt) { if (!recurse && !return_module && Policy::all_slots) { size_ = module_.num_slots(); } diff --git a/torch/csrc/jit/api/object.cpp b/torch/csrc/jit/api/object.cpp index f95d576d6c8cb2..b707e767727650 100644 --- a/torch/csrc/jit/api/object.cpp +++ b/torch/csrc/jit/api/object.cpp @@ -20,7 +20,7 @@ std::optional Object::find_method(const std::string& basename) const { return Method(_ivalue(), fn); } } - return std::nullopt; + return c10::nullopt; } void Object::define(const std::string& src, const ResolverPtr& resolver) { diff --git a/torch/csrc/jit/api/object.h b/torch/csrc/jit/api/object.h index 2c0f7e3b164f05..164f6e2ac073af 100644 --- a/torch/csrc/jit/api/object.h +++ b/torch/csrc/jit/api/object.h @@ -2,8 +2,8 @@ #include #include +#include #include -#include #include @@ -129,7 +129,7 @@ struct TORCH_API Object { const Property get_property(const std::string& name) const { for (const auto& prop : type()->properties()) { if (prop.name == name) { - std::optional setter = std::nullopt; + std::optional setter = c10::nullopt; if (prop.setter) { setter = Method(_ivalue(), prop.setter); } @@ -142,7 +142,7 @@ struct TORCH_API Object { const std::vector get_properties() const { return c10::fmap(type()->properties(), [&](ClassType::Property prop) { - std::optional setter = std::nullopt; + std::optional setter = c10::nullopt; if (prop.setter) { setter = Method(_ivalue(), prop.setter); } diff --git a/torch/csrc/jit/codegen/fuser/compiler.cpp b/torch/csrc/jit/codegen/fuser/compiler.cpp index 7e03b576d12184..b4bc3e8f4727e3 100644 --- a/torch/csrc/jit/codegen/fuser/compiler.cpp +++ b/torch/csrc/jit/codegen/fuser/compiler.cpp @@ -231,7 +231,7 @@ std::shared_ptr compileKernel( size_t input_index = 0; for (const auto& p : graph->inputs()) { if (p->type()->isSubtypeOf(*FloatType::get())) { - flat_inputs.emplace_back(p, std::nullopt); + flat_inputs.emplace_back(p, c10::nullopt); } if (!p->type()->isSubtypeOf(*TensorType::get())) { continue; diff --git a/torch/csrc/jit/codegen/fuser/cpu/fused_kernel.cpp b/torch/csrc/jit/codegen/fuser/cpu/fused_kernel.cpp index db9d57a679cb15..5f692d50e6b54e 100644 --- a/torch/csrc/jit/codegen/fuser/cpu/fused_kernel.cpp +++ b/torch/csrc/jit/codegen/fuser/cpu/fused_kernel.cpp @@ -3,9 +3,9 @@ #include #include #include +#include #include #include -#include #include #include @@ -65,7 +65,7 @@ std::optional exec(const std::wstring& cmd) { std::unique_ptr pipe( _wpopen(cmd.c_str(), L"r"), _pclose); if (!pipe) { - return std::nullopt; + return c10::nullopt; } while (fgetws(buffer.data(), static_cast(buffer.size()), pipe.get()) != nullptr) { diff --git a/torch/csrc/jit/codegen/fuser/executor.cpp b/torch/csrc/jit/codegen/fuser/executor.cpp index 411dbe62a2e157..8abb99283ffc75 100644 --- a/torch/csrc/jit/codegen/fuser/executor.cpp +++ b/torch/csrc/jit/codegen/fuser/executor.cpp @@ -4,6 +4,7 @@ #include #include #include +#include #include #include #include @@ -11,7 +12,6 @@ #include #include #include -#include #include #include // TODO: remove, debugging only @@ -44,7 +44,7 @@ static std::optional> getMapSize( try { map_size = at::infer_size(map_size, arg.sizes()); } catch (...) { - return std::nullopt; + return c10::nullopt; } } else { auto tensor_sizes = arg.sizes().vec(); @@ -52,13 +52,13 @@ static std::optional> getMapSize( const auto dim = at::maybe_wrap_dim(chunk_desc.dim(), tensor_sizes.size()); if (tensor_sizes[dim] % num_chunks != 0) { - return std::nullopt; + return c10::nullopt; } tensor_sizes[dim] /= num_chunks; try { map_size = at::infer_size(map_size, tensor_sizes); } catch (...) { - return std::nullopt; + return c10::nullopt; } } } @@ -83,12 +83,12 @@ static std::optional> canRunKernel( if (!map_size) { map_size = getMapSize(spec, args, broadcast_group); if (!map_size) - return std::nullopt; + return c10::nullopt; } else { const auto group_map_size = getMapSize(spec, args, broadcast_group); // Note: this checks that group_map_size is defined AND equal to map_size if (map_size != group_map_size) - return std::nullopt; + return c10::nullopt; } } diff --git a/torch/csrc/jit/codegen/fuser/kernel_spec.h b/torch/csrc/jit/codegen/fuser/kernel_spec.h index eacdbc7ec3f336..2fc52f2d76f0f2 100644 --- a/torch/csrc/jit/codegen/fuser/kernel_spec.h +++ b/torch/csrc/jit/codegen/fuser/kernel_spec.h @@ -2,13 +2,13 @@ #include #include +#include #include #include #include #include #include #include -#include #include #include @@ -122,7 +122,7 @@ struct TORCH_API KernelSpec { std::lock_guard guard{mutex_}; const auto it = kernels_.find(arg_spec); if (it == kernels_.end()) - return std::nullopt; + return c10::nullopt; return it->second; } void cacheKernel(const ArgSpec& arg_spec, std::shared_ptr kernel) diff --git a/torch/csrc/jit/codegen/onednn/graph_helper.cpp b/torch/csrc/jit/codegen/onednn/graph_helper.cpp index 30f32f5994c1d4..16484dd4653c8f 100644 --- a/torch/csrc/jit/codegen/onednn/graph_helper.cpp +++ b/torch/csrc/jit/codegen/onednn/graph_helper.cpp @@ -26,7 +26,7 @@ static std::optional getDimensions(Value* v) { if (v->type()->isSubtypeOf(TensorType::get())) { return v->type()->cast()->sizes().size(); } else { - return std::nullopt; + return c10::nullopt; } } diff --git a/torch/csrc/jit/codegen/onednn/graph_rewriter.cpp b/torch/csrc/jit/codegen/onednn/graph_rewriter.cpp index 71e74501656913..dfbfe467e9765b 100644 --- a/torch/csrc/jit/codegen/onednn/graph_rewriter.cpp +++ b/torch/csrc/jit/codegen/onednn/graph_rewriter.cpp @@ -132,7 +132,7 @@ std::optional GraphRewriter::tryMerge(Node* consumer, Node* producer) { bool canMerge = llgaHelper_.shouldMerge(producer, consumer) && aliasDb_.moveBeforeTopologicallyValid(producer, consumer); if (!canMerge) { - return std::nullopt; + return c10::nullopt; } llgaHelper_.mergeNodeIntoSubgraph(producer, consumer, aliasDb_); return consumer; diff --git a/torch/csrc/jit/codegen/onednn/prepare_binary.cpp b/torch/csrc/jit/codegen/onednn/prepare_binary.cpp index d09b5777f97347..a4f6d268694e36 100644 --- a/torch/csrc/jit/codegen/onednn/prepare_binary.cpp +++ b/torch/csrc/jit/codegen/onednn/prepare_binary.cpp @@ -69,7 +69,7 @@ static void handleBinaryOpInputs(Node* node) { auto second_input_typeptr = node->input(1)->type()->expect(); std::optional second_input_type = second_input_typeptr->scalarType(); - if (second_input_type != std::nullopt) { + if (second_input_type != c10::nullopt) { // dtype of the second tensor might not be available in the IR auto dtypeOfSecondInput = second_input_type.value(); if (dtypeOfFirstInput != dtypeOfSecondInput) { diff --git a/torch/csrc/jit/cuda/cuda.h b/torch/csrc/jit/cuda/cuda.h index edac94a7357bf8..80b2e2a82f788f 100644 --- a/torch/csrc/jit/cuda/cuda.h +++ b/torch/csrc/jit/cuda/cuda.h @@ -15,7 +15,7 @@ class CUDAStream final : public CustomClassHolder { public: // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) CUDAStream( - std::optional device = std::nullopt, + std::optional device = c10::nullopt, int64_t priority = 0) { c10::DeviceIndex device_index = device.has_value() ? device->index() : c10::cuda::current_device(); @@ -157,7 +157,7 @@ TORCH_LIBRARY(cuda, m) { auto stream_class = m.class_("Stream").def( torch::init, int64_t>(), "", - {torch::arg("device") = std::nullopt, torch::arg("priority") = 0}); + {torch::arg("device") = c10::nullopt, torch::arg("priority") = 0}); auto event_class = m.class_("Event").def( torch::init(), "", diff --git a/torch/csrc/jit/frontend/builtin_functions.cpp b/torch/csrc/jit/frontend/builtin_functions.cpp index 2b3bdc42e4cc1e..c1c1d87176b759 100644 --- a/torch/csrc/jit/frontend/builtin_functions.cpp +++ b/torch/csrc/jit/frontend/builtin_functions.cpp @@ -121,7 +121,7 @@ struct BuiltinFunctionRegistry { void loadSource(const std::string& source, const std::string& the_namespace) { std::shared_ptr cu = std::make_shared(); modules.emplace_back(cu); - cu->define(std::nullopt, source, nativeResolver(), /*self=*/nullptr); + cu->define(c10::nullopt, source, nativeResolver(), /*self=*/nullptr); for (auto& method : cu->get_functions()) { builtins_by_name_[Symbol::fromQualString( the_namespace + "::" + method->name())] diff --git a/torch/csrc/jit/frontend/canonicalize_modified_loop.cpp b/torch/csrc/jit/frontend/canonicalize_modified_loop.cpp index f2ef8b0e953c4b..943551e80692f1 100644 --- a/torch/csrc/jit/frontend/canonicalize_modified_loop.cpp +++ b/torch/csrc/jit/frontend/canonicalize_modified_loop.cpp @@ -28,7 +28,7 @@ static void canonicalizeModifiedLoop(Node* n) { g->insertConstant(std::numeric_limits::max())); auto inp_condition = toIValue(loop.inputCond()); - if (inp_condition == std::nullopt || inp_condition->toBool() == false) { + if (inp_condition == c10::nullopt || inp_condition->toBool() == false) { condition = g->insert(aten::__and__, {condition, loop.inputCond()}); } loop.replaceInputCondition(condition); diff --git a/torch/csrc/jit/frontend/concrete_module_type.cpp b/torch/csrc/jit/frontend/concrete_module_type.cpp index cfdef51afc31c1..c15116ac3e2446 100644 --- a/torch/csrc/jit/frontend/concrete_module_type.cpp +++ b/torch/csrc/jit/frontend/concrete_module_type.cpp @@ -151,7 +151,7 @@ TypePtr ConcreteModuleType::getJitType() const { std::optional ConcreteModuleType::getPyClass() const { if (!data_.pyClass_) { - return std::nullopt; + return c10::nullopt; } return data_.pyClass_; } @@ -162,7 +162,7 @@ std::optional> ConcreteModuleType::findOverloads( if (it != data_.overloads_.end()) { return it->second; } - return std::nullopt; + return c10::nullopt; } std::optional ConcreteModuleType::findFunctionAttribute( @@ -171,7 +171,7 @@ std::optional ConcreteModuleType::findFunctionAttribute( if (it != data_.functionAttributes_.end()) { return it->second.function_->function(); } - return std::nullopt; + return c10::nullopt; } std::optional ConcreteModuleType::findBuiltinFunction( @@ -180,7 +180,7 @@ std::optional ConcreteModuleType::findBuiltinFunction( if (it != data_.builtinFunctions_.end()) { return it->second; } - return std::nullopt; + return c10::nullopt; } std::optional ConcreteModuleType::findFailedAttribute( @@ -189,7 +189,7 @@ std::optional ConcreteModuleType::findFailedAttribute( if (it != data_.failedAttributes_.end()) { return it->second; } - return std::nullopt; + return c10::nullopt; } bool ConcreteModuleType::isIgnoredAttribute(const std::string& name) const { diff --git a/torch/csrc/jit/frontend/function_schema_parser.cpp b/torch/csrc/jit/frontend/function_schema_parser.cpp index 00ccce8567fb61..ba86a891d31dd0 100644 --- a/torch/csrc/jit/frontend/function_schema_parser.cpp +++ b/torch/csrc/jit/frontend/function_schema_parser.cpp @@ -2,10 +2,10 @@ #include #include +#include #include #include #include -#include #include #include @@ -25,7 +25,7 @@ struct SchemaParser { explicit SchemaParser(const std::string& str, bool allow_typevars) : L(std::make_shared( c10::string_view(str), - std::nullopt, + c10::nullopt, 0, nullptr, Source::DONT_COPY)), diff --git a/torch/csrc/jit/frontend/ir_emitter.cpp b/torch/csrc/jit/frontend/ir_emitter.cpp index 788483aef224ff..350305b83567c8 100644 --- a/torch/csrc/jit/frontend/ir_emitter.cpp +++ b/torch/csrc/jit/frontend/ir_emitter.cpp @@ -32,8 +32,8 @@ #include +#include #include -#include #include #include @@ -292,7 +292,7 @@ struct Environment { if (msg != runner->error_messages.end()) { return msg->second(); } else { - return std::nullopt; + return c10::nullopt; } } @@ -1267,7 +1267,7 @@ struct to_ir { {}); auto refinements = RefinementSet(findIsNoneRefinements( cond_op.lhs(), lhs_val, cond_op.rhs(), rhs_val, expr.kind())); - return CondValue(cond_value, refinements, std::nullopt); + return CondValue(cond_value, refinements, c10::nullopt); } } break; default: { @@ -1294,7 +1294,7 @@ struct to_ir { } } auto expr_out = emitToBool(expr.range(), emitExpr(expr)); - std::optional static_if = std::nullopt; + std::optional static_if = c10::nullopt; auto kind = expr_out->node()->kind(); if (kind == aten::is_scripting) { static_if = true; @@ -2291,7 +2291,7 @@ struct to_ir { Value* result = graph->insertNode(graph->createIsInstance(lhs_val, rhs_types)) ->output(); - return CondValue(result, std::move(refinement), std::nullopt); + return CondValue(result, std::move(refinement), c10::nullopt); } void emitIf(const If& stmt) { @@ -2752,7 +2752,7 @@ struct to_ir { getAugOp(stmt, lhs->type()), /*args=*/{lhs, rhs}, /*kwargs=*/{}, - /*self=*/std::nullopt); + /*self=*/c10::nullopt); } } @@ -2968,7 +2968,7 @@ struct to_ir { auto outputs = rhs_output->asTuple( rhs_loc, method, - starred_unpack ? std::nullopt : std::optional{n_binders}); + starred_unpack ? c10::nullopt : std::optional{n_binders}); if (outputs.size() < n_binders) { throw ErrorReport(tl) << "need " << (starred_unpack ? "at least " : "") << n_binders @@ -4796,11 +4796,11 @@ struct to_ir { tuple_args.reserve(3); start ? tuple_args.emplace_back(start) - : tuple_args.emplace_back(std::nullopt); + : tuple_args.emplace_back(c10::nullopt); end ? tuple_args.emplace_back(end) - : tuple_args.emplace_back(std::nullopt); + : tuple_args.emplace_back(c10::nullopt); step ? tuple_args.emplace_back(step) - : tuple_args.emplace_back(std::nullopt); + : tuple_args.emplace_back(c10::nullopt); return emitTupleSlice(loc, args[0], tuple_args); } @@ -4886,7 +4886,7 @@ struct to_ir { }; std::vector dims(subscript_exprs.size()); std::vector> exprs( - subscript_exprs.size(), std::nullopt); + subscript_exprs.size(), c10::nullopt); auto handle_indexing = [&](const Expr& subscript_expr, int expr_idx, @@ -5231,7 +5231,7 @@ struct to_ir { val_range, "begin", emitExpr(Expr(slice.start().get()))); tuple_args.emplace_back(begin); } else { - tuple_args.emplace_back(std::nullopt); + tuple_args.emplace_back(c10::nullopt); } if (slice.end().present()) { @@ -5239,7 +5239,7 @@ struct to_ir { NamedValue(val_range, "end", emitExpr(Expr(slice.end().get()))); tuple_args.emplace_back(end); } else { - tuple_args.emplace_back(std::nullopt); + tuple_args.emplace_back(c10::nullopt); } if (slice.step().present()) { @@ -5247,7 +5247,7 @@ struct to_ir { NamedValue(val_range, "step", emitExpr(Expr(slice.step().get()))); tuple_args.emplace_back(step); } else { - tuple_args.emplace_back(std::nullopt); + tuple_args.emplace_back(c10::nullopt); } auto tupleSliceValue = emitTupleSlice(val_range, s_tuple_val, tuple_args); @@ -5327,7 +5327,7 @@ struct FunctionResolver : public Resolver { CompilationUnit::CompilationUnit(const std::string& source) : CompilationUnit() { // calles the define with native resolver to generate the graph for functions - define(std::nullopt, source, nativeResolver(), nullptr); + define(c10::nullopt, source, nativeResolver(), nullptr); } // This pair represents a pair of functions (getter and setter) obtained from diff --git a/torch/csrc/jit/frontend/parse_string_literal.h b/torch/csrc/jit/frontend/parse_string_literal.h index 13bbbf89cc343f..5b924864bebd8a 100644 --- a/torch/csrc/jit/frontend/parse_string_literal.h +++ b/torch/csrc/jit/frontend/parse_string_literal.h @@ -1,7 +1,7 @@ #pragma once +#include #include #include -#include namespace torch { namespace jit { @@ -15,17 +15,17 @@ inline bool isCharCount(char c, const std::string& str, size_t start, int len) { inline std::optional parseOctal(const std::string& str, size_t pos) { //\xxx where x are 0-7 if (pos + 3 >= str.size()) - return std::nullopt; + return c10::nullopt; size_t c = 0; for (size_t i = 1, b = 64; i < 4; ++i, b /= 8) { // NOLINTNEXTLINE(bugprone-signed-char-misuse) int d = str[pos + i]; if (d < '0' || d > '7') - return std::nullopt; + return c10::nullopt; c += b * (d - '0'); } if (c >= 256) - return std::nullopt; + return c10::nullopt; return c; } diff --git a/torch/csrc/jit/frontend/parser.cpp b/torch/csrc/jit/frontend/parser.cpp index 5bf6144d8c7d5b..ae2c98028e0717 100644 --- a/torch/csrc/jit/frontend/parser.cpp +++ b/torch/csrc/jit/frontend/parser.cpp @@ -1,10 +1,10 @@ #include +#include #include #include #include #include -#include namespace torch::jit { @@ -241,7 +241,7 @@ struct ParserImpl { return create_compound('=', r, {}); // no reduction } break; default: - return std::nullopt; + return c10::nullopt; } } TreeRef parseTrinary( diff --git a/torch/csrc/jit/frontend/schema_matching.cpp b/torch/csrc/jit/frontend/schema_matching.cpp index a91f204a404cfd..87ec9992141d89 100644 --- a/torch/csrc/jit/frontend/schema_matching.cpp +++ b/torch/csrc/jit/frontend/schema_matching.cpp @@ -3,6 +3,7 @@ #include #include #include +#include #include #include #include @@ -12,7 +13,6 @@ #include #include #include -#include namespace torch::jit { @@ -261,7 +261,7 @@ std::optional findInputWithName( return i; } } - return std::nullopt; + return c10::nullopt; } /// Creates a list with the provided values if each value's type can be matched @@ -364,7 +364,7 @@ static std::optional tryMatchSchema( std::ostream* failure_messages, bool allow_conversions) { if (isBlockListedSchema(schema)) { - return std::nullopt; + return c10::nullopt; } auto err = [&]() -> std::ostream& { @@ -392,7 +392,7 @@ static std::optional tryMatchSchema( std::optional actual_named_value; if (arg.name() == "self" && self) { actual_named_value = self; - self = std::nullopt; + self = c10::nullopt; } else if (!arg.kwarg_only() && used_args < args.size()) { // Try to convert all the remaining non-kwarg arguments (used_args) to a // list. Allow zeros(IntArrayRef sizes) to work with zeros(1, 2) or @@ -417,7 +417,7 @@ static std::optional tryMatchSchema( allow_conversions, type_env); if (!list) { - return std::nullopt; + return c10::nullopt; } used_args = args.size(); positional_inputs.push_back(list); @@ -437,7 +437,7 @@ static std::optional tryMatchSchema( err() << "Argument " << nv.name() << " specified twice in schema, submit a bug report!\n"; } - return std::nullopt; + return c10::nullopt; } used_kwarg[*kwarg_idx] = true; actual_named_value = nv; @@ -450,7 +450,7 @@ static std::optional tryMatchSchema( err() << "Argument " << schema.arguments()[schema_i].name() << " not provided.\n"; } - return std::nullopt; + return c10::nullopt; } // Make sure the actual_named_value found matches the type of arg @@ -464,16 +464,16 @@ static std::optional tryMatchSchema( allow_conversions, type_env); if (!positional) { - return std::nullopt; + return c10::nullopt; } positional_inputs.push_back(positional); } // check for unused self argument - if (self != std::nullopt) { + if (self != c10::nullopt) { if (failure_messages) { err() << "Provided self argument not used in schema.\n"; } - return std::nullopt; + return c10::nullopt; } if (schema.is_vararg()) { @@ -488,7 +488,7 @@ static std::optional tryMatchSchema( err() << "Expected at most " << used_args << " arguments " << "but found " << args.size() << " positional arguments.\n"; } - return std::nullopt; + return c10::nullopt; } // check for unused kwargs for (const auto i : c10::irange(kwargs.size())) { @@ -501,7 +501,7 @@ static std::optional tryMatchSchema( err() << "Keyword argument " << nv.name() << " specified twice.\n"; } } - return std::nullopt; + return c10::nullopt; } } @@ -518,7 +518,7 @@ static std::optional tryMatchSchema( std::all_of(returns.begin(), returns.end(), [&](const Argument& r) { return r.name().length() > 0; }); - c10::OptNameList return_field_names = std::nullopt; + c10::OptNameList return_field_names = c10::nullopt; if (return_has_field_names) { return_field_names = fmap(returns, [&](const Argument& r) { return r.name(); }); @@ -633,7 +633,7 @@ static Value* packOutputs( if (field_names) { auto types = fmap(values, [](Value* v) { return v->type(); }); named_tuple = - TupleType::createNamed(std::nullopt, field_names.value(), types); + TupleType::createNamed(c10::nullopt, field_names.value(), types); } return g.insertNode(g.createTuple(values, named_tuple))->output(); } diff --git a/torch/csrc/jit/frontend/schema_matching.h b/torch/csrc/jit/frontend/schema_matching.h index 8a24863cbe71d0..0c69df521df6a2 100644 --- a/torch/csrc/jit/frontend/schema_matching.h +++ b/torch/csrc/jit/frontend/schema_matching.h @@ -10,7 +10,7 @@ namespace jit { // Try to match a list of inputs and keyword 'attributes' to this // schema. Return the flat list of positional inputs to the call or -// `std::nullopt` on failure (`failure_messages` contains a good error +// `c10::nullopt` on failure (`failure_messages` contains a good error // report in this case) struct MatchedSchema { @@ -28,7 +28,7 @@ TORCH_API MatchedSchema matchSchema( Graph& graph, at::ArrayRef args, at::ArrayRef kwargs, - const std::optional& self = std::nullopt); + const std::optional& self = c10::nullopt); TORCH_API std::pair matchSchemas( const std::vector& schemas, @@ -36,7 +36,7 @@ TORCH_API std::pair matchSchemas( Graph& graph, at::ArrayRef args, at::ArrayRef kwargs, - const std::optional& self = std::nullopt, + const std::optional& self = c10::nullopt, bool render_errors = false); TORCH_API bool convertibleToList( @@ -51,7 +51,7 @@ TORCH_API Value* emitBuiltinCall( Symbol name, at::ArrayRef args, at::ArrayRef kwargs, - const std::optional& self = std::nullopt); + const std::optional& self = c10::nullopt); TORCH_API std::optional findInputWithName( const std::string& name, diff --git a/torch/csrc/jit/frontend/schema_type_parser.cpp b/torch/csrc/jit/frontend/schema_type_parser.cpp index f7bc4a04cb6ce5..2adacb976a042f 100644 --- a/torch/csrc/jit/frontend/schema_type_parser.cpp +++ b/torch/csrc/jit/frontend/schema_type_parser.cpp @@ -155,7 +155,7 @@ std::optional SchemaTypeParser::parseAliasAnnotation() { Symbol::fromQualString("alias::$" + std::to_string(next_id++))); alias_info.setIsWrite(true); } else { - return std::nullopt; + return c10::nullopt; } return alias_info; @@ -172,7 +172,7 @@ std::optional SchemaTypeParser::parseTensorDType( if (type != type_map.end()) { return type->second; } - return std::nullopt; + return c10::nullopt; } std::optional SchemaTypeParser::tryToParseDeviceType() { @@ -297,7 +297,7 @@ TypePtr SchemaTypeParser::parseRefinedTensor() { // Parsing ranks, supports mix of sized and unsized ranks, or, just strided // ranks if (L.cur().kind == '*') { - dims.emplace_back(std::nullopt); + dims.emplace_back(c10::nullopt); L.next(); if (L.cur().kind == ':') { throw ErrorReport(L.cur()) << "Strides for unsized ranks not supported"; diff --git a/torch/csrc/jit/frontend/script_type_parser.cpp b/torch/csrc/jit/frontend/script_type_parser.cpp index db21737f4c4ba1..9295a3ed4007ab 100644 --- a/torch/csrc/jit/frontend/script_type_parser.cpp +++ b/torch/csrc/jit/frontend/script_type_parser.cpp @@ -137,10 +137,10 @@ std::optional> ScriptTypeParser::parseBroadcastList( } if (expr.kind() != TK_SUBSCRIPT) - return std::nullopt; + return c10::nullopt; auto subscript = Subscript(expr); if (subscript.value().kind() != TK_VAR) - return std::nullopt; + return c10::nullopt; auto var = Var(subscript.value()); auto subscript_exprs = subscript.subscript_exprs(); @@ -151,10 +151,10 @@ std::optional> ScriptTypeParser::parseBroadcastList( TypePtr opt_type = OptionalType::create(broadcast_list->first); return std::pair(opt_type, broadcast_list->second); } else { - return std::nullopt; + return c10::nullopt; } } else if (var.name().name().find("BroadcastingList") != 0) { - return std::nullopt; + return c10::nullopt; } if (subscript_exprs.size() != 1) @@ -352,7 +352,7 @@ std::vector ScriptTypeParser::evaluateDefaults( CompilationUnit cu; cu.define( - std::nullopt, + c10::nullopt, /*properties=*/{}, /*propResolvers=*/{}, {def}, @@ -407,7 +407,7 @@ std::vector ScriptTypeParser::parseArgsFromDecl( auto decl_arg = *it; TypePtr type; - std::optional N = std::nullopt; + std::optional N = c10::nullopt; if (!decl_arg.type().present()) { // If this param doesn't have a type, default to "tensor" type = TensorType::getInferred(); @@ -421,7 +421,7 @@ std::vector ScriptTypeParser::parseArgsFromDecl( type = parseTypeFromExpr(decl_arg.type().get()); } } - std::optional default_value = std::nullopt; + std::optional default_value = c10::nullopt; if (decl_arg.defaultValue().present()) { default_value = *defaults_it++; } @@ -431,7 +431,7 @@ std::vector ScriptTypeParser::parseArgsFromDecl( N, default_value, decl_arg.kwarg_only(), - /*alias_info=*/std::nullopt); + /*alias_info=*/c10::nullopt); retval.push_back(arg); } return retval; @@ -455,8 +455,8 @@ std::vector ScriptTypeParser::parseReturnFromDecl(const Decl& decl) { return {Argument( "", parsed_type, - /*N =*/std::nullopt, - /*default_value =*/std::nullopt, + /*N =*/c10::nullopt, + /*default_value =*/c10::nullopt, /*kwarg_only =*/false)}; } FunctionSchema ScriptTypeParser::parseSchemaFromDef( diff --git a/torch/csrc/jit/frontend/source_range.cpp b/torch/csrc/jit/frontend/source_range.cpp index b1dfecbbf6418c..20ffbfd4601e36 100644 --- a/torch/csrc/jit/frontend/source_range.cpp +++ b/torch/csrc/jit/frontend/source_range.cpp @@ -154,7 +154,7 @@ size_t SourceRangeHasher::operator()(const torch::jit::SourceRange& key) const { std::optional Source::findSourceRangeThatGenerated( const SourceRange& range) { if (!gen_ranges_) { - return std::nullopt; + return c10::nullopt; } return gen_ranges_->findSourceRangeThatGenerated(range); } diff --git a/torch/csrc/jit/frontend/source_range.h b/torch/csrc/jit/frontend/source_range.h index a8f22a800b022f..1f8715ad009691 100644 --- a/torch/csrc/jit/frontend/source_range.h +++ b/torch/csrc/jit/frontend/source_range.h @@ -1,6 +1,6 @@ #pragma once #include -#include +#include #include #include @@ -190,7 +190,7 @@ struct TORCH_API Source { explicit Source( c10::string_view text_view, - std::optional filename = std::nullopt, + std::optional filename = c10::nullopt, size_t starting_line_no = 0, std::shared_ptr gen_ranges = nullptr, CopiesString copies_str = COPIES_STRING) @@ -210,7 +210,7 @@ struct TORCH_API Source { explicit Source( StringCordView str, - std::optional filename = std::nullopt, + std::optional filename = c10::nullopt, size_t starting_line_no = 0, std::shared_ptr gen_ranges = nullptr) : text_view_(std::move(str)), @@ -360,7 +360,7 @@ struct TORCH_API SourceRange { std::optional> file_line_col() const { if (!source_view_ || !source()->filename()) { - return std::nullopt; + return c10::nullopt; } auto lineno = source_view_->lineno_for_offset(start_); @@ -383,7 +383,7 @@ struct TORCH_API SourceRange { std::optional findSourceRangeThatGenerated() const { if (!source_view_) { - return std::nullopt; + return c10::nullopt; } return source_view_->findSourceRangeThatGenerated(*this); } diff --git a/torch/csrc/jit/frontend/sugared_value.cpp b/torch/csrc/jit/frontend/sugared_value.cpp index 94a11b21b1f22c..4b65903529d23a 100644 --- a/torch/csrc/jit/frontend/sugared_value.cpp +++ b/torch/csrc/jit/frontend/sugared_value.cpp @@ -658,7 +658,7 @@ void IterableTree::addChild( // iterables run for the minimum length of all its leaves unroll_length_ = std::min(*child_len, *unroll_length_); } else { - unroll_length_ = std::nullopt; + unroll_length_ = c10::nullopt; } } children_.push_back(iter_value); diff --git a/torch/csrc/jit/frontend/sugared_value.h b/torch/csrc/jit/frontend/sugared_value.h index 1ca59ced6e68b8..97b092cad3ce7a 100644 --- a/torch/csrc/jit/frontend/sugared_value.h +++ b/torch/csrc/jit/frontend/sugared_value.h @@ -1,7 +1,7 @@ #pragma once +#include #include #include -#include #include #include @@ -122,13 +122,13 @@ struct TORCH_API SugaredValue // to support containers of Heterogenous types, like Module Containers & // Tuples virtual std::optional staticLen() { - return std::nullopt; + return c10::nullopt; } // When iterating over this SugaredValue, should we emit the for loop as an // unrolled loop. bool shouldEmitUnrolled() { - return staticLen() != std::nullopt; + return staticLen() != c10::nullopt; } // return length of this thing, if not then it can't be iterated. @@ -323,7 +323,7 @@ struct TORCH_API BuiltinModule : public SugaredValue { } auto sym = Symbol::fromQualString(name + "::" + field); - return std::make_shared(sym, std::nullopt); + return std::make_shared(sym, c10::nullopt); } private: @@ -506,7 +506,7 @@ struct TORCH_API PrintValue : public SugaredValue { // is a noop when the input is a subtype of 'type' struct TORCH_API CastValue : public BuiltinFunction { CastValue(TypePtr type, c10::Symbol method) - : BuiltinFunction(method, std::nullopt), type_(std::move(type)) {} + : BuiltinFunction(method, c10::nullopt), type_(std::move(type)) {} std::shared_ptr call( const SourceRange& loc, GraphFunction& m, @@ -638,7 +638,7 @@ struct TORCH_API RangeValue : SugaredValue { const SourceRange& loc, GraphFunction& m, std::vector input, - std::optional static_len = std::nullopt); + std::optional static_len = c10::nullopt); std::string kind() const override { return "range"; @@ -730,7 +730,7 @@ struct TORCH_API IterableTree : SugaredValue { TypePtr type_hint = nullptr) override; private: - std::optional unroll_length_ = std::nullopt; + std::optional unroll_length_ = c10::nullopt; std::vector children_; }; diff --git a/torch/csrc/jit/frontend/tracer.cpp b/torch/csrc/jit/frontend/tracer.cpp index a90d5bb897f454..9616e0f83dfbe2 100644 --- a/torch/csrc/jit/frontend/tracer.cpp +++ b/torch/csrc/jit/frontend/tracer.cpp @@ -818,8 +818,8 @@ void addInputs(Node* n, const char* name, std::optional value) { n, name, value.has_value() - ? std::make_optional(value->guard_int(__FILE__, __LINE__)) - : std::nullopt); + ? c10::make_optional(value->guard_int(__FILE__, __LINE__)) + : c10::nullopt); } void addInputs( diff --git a/torch/csrc/jit/ir/alias_analysis.cpp b/torch/csrc/jit/ir/alias_analysis.cpp index 6f674f30b90fca..f9b2ed5dd7ce9d 100644 --- a/torch/csrc/jit/ir/alias_analysis.cpp +++ b/torch/csrc/jit/ir/alias_analysis.cpp @@ -105,7 +105,7 @@ class MutableTypePtrHelper { } } if (mutable_types.empty()) { - return std::nullopt; + return c10::nullopt; } return mutable_types; } @@ -121,7 +121,7 @@ class MutableTypePtrHelper { return {AliasTypeSet{ FutureType::create(*toSingleType(*maybe_mut_types))}}; } - return std::nullopt; + return c10::nullopt; } case TypeKind::AwaitType: { if (auto maybe_mut_types = mapTypeToAliasTypeSet( @@ -129,7 +129,7 @@ class MutableTypePtrHelper { return { AliasTypeSet{AwaitType::create(*toSingleType(*maybe_mut_types))}}; } - return std::nullopt; + return c10::nullopt; } case TypeKind::TupleType: { std::vector mutable_types; @@ -142,12 +142,12 @@ class MutableTypePtrHelper { } } if (mutable_types.empty()) { - return std::nullopt; + return c10::nullopt; } return {AliasTypeSet{TupleType::create(mutable_types)}}; } default: - return std::nullopt; + return c10::nullopt; } } ska::flat_hash_map* mutable_type_cache_; @@ -1896,7 +1896,7 @@ bool AliasDb::mayAliasWildcard(const at::ArrayRef vs) const { std::optional AliasDb::tryGetOrCreateWildcard(const TypePtr& type) { auto maybe_mut_types = mapTypeToAliasTypeSetPtr(type); if (!maybe_mut_types) { - return std::nullopt; + return c10::nullopt; } auto mut_type = toSingleType(*maybe_mut_types); auto existing_wildcard = wildcardIndex_.find(*mut_type); @@ -1970,7 +1970,7 @@ std::optional AliasDb::setWildcard(const Value* v) { std::optional maybe_wildcardElement = tryGetOrCreateWildcard(v->type()); if (!maybe_wildcardElement) { - return std::nullopt; + return c10::nullopt; } // Ensure that we create a corresponding Element for `v` still, as it is an // invariant that all mutable values have an Element diff --git a/torch/csrc/jit/ir/constants.cpp b/torch/csrc/jit/ir/constants.cpp index a0f8c8760a130b..ef697a5af76800 100644 --- a/torch/csrc/jit/ir/constants.cpp +++ b/torch/csrc/jit/ir/constants.cpp @@ -69,7 +69,7 @@ std::optional tryInsertConstant( at::Tensor ref = val.toTensor(); if (!insertableTensor(val.toTensor())) { n->destroy(); - return std::nullopt; + return c10::nullopt; } if (!ref.defined()) { n->destroy(); @@ -99,7 +99,7 @@ std::optional tryInsertConstant( n->output()->setType(val.type()); } else { n->destroy(); - return std::nullopt; + return c10::nullopt; } } else if (val.isString()) { n->s_(attr::value, val.toStringRef()); @@ -125,7 +125,7 @@ std::optional tryInsertConstant( n->output()->setType(val.type()); } else { n->destroy(); - return std::nullopt; + return c10::nullopt; }; } else if (val.isObject()) { const auto& ref = val.toObjectRef(); @@ -137,14 +137,14 @@ std::optional tryInsertConstant( n->output()->setType(val.type()); } else { n->destroy(); - return std::nullopt; + return c10::nullopt; } } else if ((val.isGenericDict() && insertableIValue(val)) || (val.isEnum())) { n->ival_(attr::value, val); n->output()->setType(val.type()); } else { n->destroy(); - return std::nullopt; + return c10::nullopt; } if (loc) n->setSourceRange(*loc); @@ -155,7 +155,7 @@ std::optional tryInsertConstant( std::optional toIValue(const Value* v) { if (v->node()->kind() != prim::Constant || v->type()->cast()) { - return std::nullopt; + return c10::nullopt; } const Node* node = v->node(); const TypePtr& type = v->type(); diff --git a/torch/csrc/jit/ir/constants.h b/torch/csrc/jit/ir/constants.h index 160dad5eab4c61..118da1e932d9c1 100644 --- a/torch/csrc/jit/ir/constants.h +++ b/torch/csrc/jit/ir/constants.h @@ -25,27 +25,27 @@ struct TORCH_API constant_not_supported_error : public std::runtime_error { TORCH_API Value* insertConstant( Graph& g, const IValue& val, - std::optional loc = std::nullopt, - std::optional scope = std::nullopt); + std::optional loc = c10::nullopt, + std::optional scope = c10::nullopt); // note: prefer g.insertConsant(val, loc) which does exactly the same thing // this function is only declared/defined here because its implementation is // closely related to the implementation of prim::Constant that is also in // constants.cpp. // -// returns a std::nullopt if the IValue kind cannot be inserted as a constant +// returns a c10::nullopt if the IValue kind cannot be inserted as a constant TORCH_API std::optional tryInsertConstant( Graph& g, const IValue& val, - std::optional loc = std::nullopt, - std::optional scope = std::nullopt); + std::optional loc = c10::nullopt, + std::optional scope = c10::nullopt); //////////////////////////////////////////////////////////////////////////////// // Helper for retrieving constants //////////////////////////////////////////////////////////////////////////////// // attempt to convert a (possibly constant) Value* into an interpreter value -// (IValue). returns std::nullopt if the Value* was not constant +// (IValue). returns c10::nullopt if the Value* was not constant TORCH_API std::optional toIValue(const Value* v); // if a value is a constant then try to turn into type T using the @@ -55,7 +55,7 @@ std::optional constant_as(const Value* v) { if (auto ivalue = toIValue(v)) { return ivalue->to(); } - return std::nullopt; + return c10::nullopt; } } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/ir/ir.cpp b/torch/csrc/jit/ir/ir.cpp index 3b449ea7ea21f7..a6b0116d7fb63f 100644 --- a/torch/csrc/jit/ir/ir.cpp +++ b/torch/csrc/jit/ir/ir.cpp @@ -412,7 +412,7 @@ std::ostream& operator<<(std::ostream& out, const Graph& g) { static void checkSameDevice(const Node* node) { bool has_device = false; - std::optional device = std::nullopt; + std::optional device = c10::nullopt; auto checkValue = [&](const Value* v) { if (TensorTypePtr type = v->type()->cast()) { if (type->device() && !has_device) { @@ -1297,7 +1297,7 @@ Node::Node(Graph* graph_, NodeKind kind_) graph_(graph_), owning_block_(nullptr), scope_(graph_->current_scope_), - callstack_(std::nullopt), + callstack_(c10::nullopt), op_(nullptr), topo_position_(0) { graph_->all_nodes.emplace(this); @@ -2101,11 +2101,11 @@ std::vector inlineCallTo( std::unordered_map new_callstack_entries; - std::optional module_instance_info = std::nullopt; + std::optional module_instance_info = c10::nullopt; if (to_replace->kind() == prim::CallMethod) { auto class_type_ptr = to_replace->input(0)->type()->cast(); if (to_replace->input(0)->node()->kind() == prim::GetAttr) { - module_instance_info = std::make_optional(ModuleInstanceInfo( + module_instance_info = c10::make_optional(ModuleInstanceInfo( class_type_ptr, to_replace->input(0)->node()->s(attr::name))); } else if ( !to_replace->owningGraph()->inputs().empty() && @@ -2113,11 +2113,11 @@ std::vector inlineCallTo( // This CallMethod must correspond to method of the same object // to which this graph belongs. module_instance_info = - std::make_optional(ModuleInstanceInfo(class_type_ptr, "SELF")); + c10::make_optional(ModuleInstanceInfo(class_type_ptr, "SELF")); } else { // Not sure if it is possible to come here ever. // TODO: Remove this else. Or add assert - module_instance_info = std::make_optional( + module_instance_info = c10::make_optional( ModuleInstanceInfo(class_type_ptr, "INSTANCE_NAME_UNKNOWN")); } } diff --git a/torch/csrc/jit/ir/ir.h b/torch/csrc/jit/ir/ir.h index 3db67b2f9798ce..859da3cb3cae99 100644 --- a/torch/csrc/jit/ir/ir.h +++ b/torch/csrc/jit/ir/ir.h @@ -20,7 +20,7 @@ #include #include #include -#include +#include #include #include @@ -348,7 +348,7 @@ struct TORCH_API Node { // is changed, we need to rely on this name // to retrieve old schemas to successfully apply upgraders // for this operator. - std::optional historic_schema_name_ = std::nullopt; + std::optional historic_schema_name_ = c10::nullopt; protected: Node(Graph* graph_, NodeKind kind_); // defined after graph @@ -534,7 +534,7 @@ struct TORCH_API Node { if (auto v = get(name)) { return v->template to(); } - return std::nullopt; + return c10::nullopt; } // Returns true if the value of input name is statically known @@ -1368,8 +1368,8 @@ struct Graph : std::enable_shared_from_this { // Insert constant IValue into the graph. TORCH_API Value* insertConstant( const IValue& val, - std::optional loc = std::nullopt, - std::optional scope = std::nullopt); + std::optional loc = c10::nullopt, + std::optional scope = c10::nullopt); // Schema-driven insert: // This inserts a node into the graph with inputs determined from args and @@ -1733,14 +1733,14 @@ struct OperatorMap { std::optional find(const Operator& op) { const auto it = map.find(Symbol::fromQualString(op.schema().name())); if (it == map.end()) { - return std::nullopt; + return c10::nullopt; } for (auto vit = it->second.begin(); vit != it->second.end(); ++vit) { if (vit->first->schema() == op.schema()) { return vit->second; } } - return std::nullopt; + return c10::nullopt; } // TODO: return iterator @@ -1809,14 +1809,14 @@ struct FunctionSchemaMap { std::optional find(const FunctionSchema& schema) const { const auto it = map.find(Symbol::fromQualString(schema.name())); if (it == map.end()) { - return std::nullopt; + return c10::nullopt; } for (auto vit = it->second.begin(); vit != it->second.end(); ++vit) { if (vit->first == schema) { return vit->second; } } - return std::nullopt; + return c10::nullopt; } // TODO: return iterator diff --git a/torch/csrc/jit/ir/scope.h b/torch/csrc/jit/ir/scope.h index 5eafeb0fc4aac2..54498030322384 100644 --- a/torch/csrc/jit/ir/scope.h +++ b/torch/csrc/jit/ir/scope.h @@ -1,10 +1,10 @@ #pragma once #include #include +#include #include #include #include -#include #include namespace torch { diff --git a/torch/csrc/jit/mobile/compatibility/backport_manager.cpp b/torch/csrc/jit/mobile/compatibility/backport_manager.cpp index 3b8319ad8f90ee..f0dd562cc1cd27 100644 --- a/torch/csrc/jit/mobile/compatibility/backport_manager.cpp +++ b/torch/csrc/jit/mobile/compatibility/backport_manager.cpp @@ -408,7 +408,7 @@ std::stringstream backport_v6_to_v5(std::stringstream& input_model_stream) { } // Loading the TS module is required for this backport, because bytecode needs // to be re-emitted (refer to the comments below) - Module torch_script = torch::jit::load(rai, std::nullopt, extra_files); + Module torch_script = torch::jit::load(rai, c10::nullopt, extra_files); // The RAII guard to change the flag, emitBytecodeDefaultInputs, to true, so // that TS stores the default argument values in the constant table, and emits @@ -476,7 +476,7 @@ std::stringstream backport_v7_to_v6(std::stringstream& input_model_stream) { } // Loading the TS module is required for this backport, because bytecode needs // to be re-emitted (refer to the comments below) - Module torch_script = torch::jit::load(rai, std::nullopt, extra_files); + Module torch_script = torch::jit::load(rai, c10::nullopt, extra_files); // The RAII guard to change the flag, emit_default_input_instructions, to // false to keep the same behavior in bytecode version 6. Change the flag, @@ -502,7 +502,7 @@ std::stringstream backport_v7_to_v6(std::stringstream& input_model_stream) { std::stringstream backport_v9_to_v8(std::stringstream& input_model_stream) { ExtraFilesMap extra_files; Module torch_script = - torch::jit::load(input_model_stream, std::nullopt, extra_files); + torch::jit::load(input_model_stream, c10::nullopt, extra_files); std::stringstream intermediate_model_stream; // TODO(@pavithran) : Check if debug info is available and use load/save while // backporting hardcode debaug info to be false untill supported. @@ -540,7 +540,7 @@ std::stringstream backport_v8_to_v7(std::stringstream& input_model_stream) { extra_files.emplace(record.substr(found + 1), ""); } } - Module torch_script = torch::jit::load(rai, std::nullopt, extra_files); + Module torch_script = torch::jit::load(rai, c10::nullopt, extra_files); std::stringstream intermediate_model_stream; { BytecodeEmitModeGuard argNumGuard( diff --git a/torch/csrc/jit/mobile/compatibility/runtime_compatibility.h b/torch/csrc/jit/mobile/compatibility/runtime_compatibility.h index d89165bb1d2950..2e65f1f38bd8d2 100644 --- a/torch/csrc/jit/mobile/compatibility/runtime_compatibility.h +++ b/torch/csrc/jit/mobile/compatibility/runtime_compatibility.h @@ -1,7 +1,7 @@ #pragma once #include -#include +#include #include #include diff --git a/torch/csrc/jit/mobile/flatbuffer_loader.cpp b/torch/csrc/jit/mobile/flatbuffer_loader.cpp index 2094d4a87a1719..bca40735891394 100644 --- a/torch/csrc/jit/mobile/flatbuffer_loader.cpp +++ b/torch/csrc/jit/mobile/flatbuffer_loader.cpp @@ -19,6 +19,7 @@ #include #include #include +#include #include #include #include @@ -34,7 +35,6 @@ #include #include #include -#include #ifndef DISABLE_UPGRADER #include @@ -364,7 +364,7 @@ std::unique_ptr FlatbufferLoader::parseFunction( (operator_version < caffe2::serialize::kProducedFileFormatVersion); for (const auto* op : *method->operators()) { - std::optional num_args = std::nullopt; + std::optional num_args = c10::nullopt; if (op->num_args_serialized() > -1) { num_args = op->num_args_serialized(); } @@ -399,7 +399,7 @@ std::unique_ptr FlatbufferLoader::parseFunction( auto arg = c10::Argument( arg_tb->name()->str(), std::move(type_ptr), - std::nullopt /*N*/, + c10::nullopt /*N*/, std::move(default_value)); args.emplace_back(std::move(arg)); } diff --git a/torch/csrc/jit/mobile/flatbuffer_loader.h b/torch/csrc/jit/mobile/flatbuffer_loader.h index 62b2c795bf84d4..9ac9636f3f14be 100644 --- a/torch/csrc/jit/mobile/flatbuffer_loader.h +++ b/torch/csrc/jit/mobile/flatbuffer_loader.h @@ -9,8 +9,8 @@ #include #include #include +#include #include -#include /** * Defines the public API for loading flatbuffer-serialized mobile modules. @@ -58,7 +58,7 @@ using ExtraFilesMap = std::unordered_map; TORCH_API mobile::Module parse_and_initialize_mobile_module( void* data, size_t size, // of `data`, in bytes. - std::optional device = std::nullopt, + std::optional device = c10::nullopt, ExtraFilesMap* extra_files = nullptr, bool should_copy_tensor_memory = false); @@ -74,7 +74,7 @@ TORCH_API mobile::Module parse_and_initialize_mobile_module( TORCH_API mobile::Module parse_and_initialize_mobile_module( std::shared_ptr data, size_t size, // of `data`, in bytes. - std::optional device = std::nullopt, + std::optional device = c10::nullopt, ExtraFilesMap* extra_files = nullptr); // Parse a mobile::Module from raw bytes, also returning JIT-related metadata. @@ -87,7 +87,7 @@ TORCH_API mobile::Module parse_and_initialize_mobile_module_for_jit( size_t size, // of `data`, in bytes. ExtraFilesMap& jit_sources, std::vector& jit_constants, - std::optional device = std::nullopt, + std::optional device = c10::nullopt, ExtraFilesMap* extra_files = nullptr); // Load a mobile::Module from a filepath. @@ -100,7 +100,7 @@ TORCH_API mobile::Module parse_and_initialize_mobile_module_for_jit( // directly. TORCH_API mobile::Module load_mobile_module_from_file( const std::string& filename, - std::optional device = std::nullopt, + std::optional device = c10::nullopt, ExtraFilesMap* extra_files = nullptr); TORCH_API uint64_t get_bytecode_version(std::istream& in); @@ -114,7 +114,7 @@ TORCH_API mobile::ModuleInfo get_module_info_from_flatbuffer( // its entirity to a buffer TORCH_API mobile::Module load_mobile_module_from_stream_with_copy( std::istream& in, - std::optional device = std::nullopt, + std::optional device = c10::nullopt, ExtraFilesMap* extra_files = nullptr); TORCH_API mobile::Module parse_flatbuffer_no_object( diff --git a/torch/csrc/jit/mobile/frame.h b/torch/csrc/jit/mobile/frame.h index 4ad3817af624ec..45c51fef0085e5 100644 --- a/torch/csrc/jit/mobile/frame.h +++ b/torch/csrc/jit/mobile/frame.h @@ -2,8 +2,8 @@ #include +#include #include -#include namespace torch { namespace jit { diff --git a/torch/csrc/jit/mobile/function.cpp b/torch/csrc/jit/mobile/function.cpp index 9c3626e361da48..36f19fb1fac41c 100644 --- a/torch/csrc/jit/mobile/function.cpp +++ b/torch/csrc/jit/mobile/function.cpp @@ -72,7 +72,7 @@ bool Function::initialize_operators(bool should_check_operators) { const auto& opname = code_.op_names_[i]; int num_args = code_.operator_input_sizes_[i]; std::optional num_specified_args = - num_args < 0 ? std::nullopt : std::optional(num_args); + num_args < 0 ? c10::nullopt : std::optional(num_args); auto func = makeOperatorFunction(opname, num_specified_args); if (!func.has_value()) { unsupported_op_names.insert(operator_str(opname)); @@ -189,7 +189,7 @@ std::optional> makeOperatorFunction( TORCH_CHECK(false, "arguments are missing for operator ", opname); } } else { - return std::nullopt; + return c10::nullopt; } } } diff --git a/torch/csrc/jit/mobile/import.cpp b/torch/csrc/jit/mobile/import.cpp index 1fa2fe47904b56..da7b87bae6110d 100644 --- a/torch/csrc/jit/mobile/import.cpp +++ b/torch/csrc/jit/mobile/import.cpp @@ -5,6 +5,7 @@ #include #include #include +#include #include #include #include @@ -22,7 +23,6 @@ #include #include #include -#include #include #include @@ -267,7 +267,7 @@ void BytecodeDeserializer::parseFunctionSchema( args.emplace_back( name, std::move(type), - std::nullopt /*N*/, + c10::nullopt /*N*/, std::move(default_value)); } tryRegisterMethod(args, *function); @@ -704,7 +704,7 @@ void _load_extra_only_for_mobile( // TODO: the current flatbuffers implementation will always load the // whole module including the extra files. Ideally it should be // possible to just get the extra files given data - load_mobile_module_from_file(filename, std::nullopt, &extra_files); + load_mobile_module_from_file(filename, c10::nullopt, &extra_files); break; } default: { diff --git a/torch/csrc/jit/mobile/import.h b/torch/csrc/jit/mobile/import.h index 73ebe18976d60c..77a801e62571df 100644 --- a/torch/csrc/jit/mobile/import.h +++ b/torch/csrc/jit/mobile/import.h @@ -45,15 +45,15 @@ TORCH_API mobile::Module _load_for_mobile( TORCH_API mobile::Module _load_for_mobile( std::istream& in, - std::optional device = std::nullopt); + std::optional device = c10::nullopt); TORCH_API mobile::Module _load_for_mobile( const std::string& filename, - std::optional device = std::nullopt); + std::optional device = c10::nullopt); TORCH_API mobile::Module _load_for_mobile( std::unique_ptr rai, - std::optional device = std::nullopt); + std::optional device = c10::nullopt); /** * Load only the contents of the "extra/" files whose names are diff --git a/torch/csrc/jit/mobile/import_data.h b/torch/csrc/jit/mobile/import_data.h index d2d2fa7f998e22..25e1fd81341c18 100644 --- a/torch/csrc/jit/mobile/import_data.h +++ b/torch/csrc/jit/mobile/import_data.h @@ -2,8 +2,8 @@ #include #include +#include #include -#include #include #include @@ -19,7 +19,7 @@ namespace jit { */ TORCH_API std::map _load_parameters( std::istream& in, - std::optional device = std::nullopt); + std::optional device = c10::nullopt); /** * Loads named parameters from the serialized data in @p filename. @@ -28,7 +28,7 @@ TORCH_API std::map _load_parameters( */ TORCH_API std::map _load_parameters( const std::string& filename, - std::optional device = std::nullopt); + std::optional device = c10::nullopt); // NOTE: Please prefer using _load_parameters over using the function below. TORCH_API std::map mobile_module_to_parameter_map( diff --git a/torch/csrc/jit/mobile/model_tracer/MobileModelRunner.h b/torch/csrc/jit/mobile/model_tracer/MobileModelRunner.h index 813d7be7e7a2a9..b6abe86c0fdca7 100644 --- a/torch/csrc/jit/mobile/model_tracer/MobileModelRunner.h +++ b/torch/csrc/jit/mobile/model_tracer/MobileModelRunner.h @@ -104,7 +104,7 @@ class MobileModelRunner { */ bool has_new_style_bundled_inputs() const { return module_->find_method("get_bundled_inputs_functions_and_info") != - std::nullopt; + c10::nullopt; } /** diff --git a/torch/csrc/jit/mobile/model_tracer/TracerRunner.cpp b/torch/csrc/jit/mobile/model_tracer/TracerRunner.cpp index 3687f84f703971..585747c14d8210 100644 --- a/torch/csrc/jit/mobile/model_tracer/TracerRunner.cpp +++ b/torch/csrc/jit/mobile/model_tracer/TracerRunner.cpp @@ -117,10 +117,10 @@ void call_dependent_methods(std::set& root_ops) { if (is_training && has_batchnorm) { at::batch_norm( at::ones({2, 2}), - std::nullopt, - std::nullopt, - std::nullopt, - std::nullopt, + c10::nullopt, + c10::nullopt, + c10::nullopt, + c10::nullopt, true, 0.1, 0.1, diff --git a/torch/csrc/jit/mobile/module.cpp b/torch/csrc/jit/mobile/module.cpp index bcf4e5e1f6ba7c..23dfe9ff367852 100644 --- a/torch/csrc/jit/mobile/module.cpp +++ b/torch/csrc/jit/mobile/module.cpp @@ -90,10 +90,10 @@ void Module::unsafeCopyMethod( std::optional Module::find_method(const std::string& basename) const { for (const auto& fn : cu_->methods()) { if (fn->name() == basename) { - return std::make_optional(Method(this, fn.get())); + return c10::make_optional(Method(this, fn.get())); } } - return std::nullopt; + return c10::nullopt; } namespace { @@ -324,7 +324,7 @@ static std::optional print_type(const c10::Type& t) { if (auto dyn = t.castRaw()) { return dyn->fallback()->annotation_str(); } - return std::nullopt; + return c10::nullopt; } TORCH_API ModuleInfo get_module_info(const mobile::Module& module) { diff --git a/torch/csrc/jit/mobile/promoted_prim_ops.cpp b/torch/csrc/jit/mobile/promoted_prim_ops.cpp index 1d9d6fb3abcfaf..8e49749042424c 100644 --- a/torch/csrc/jit/mobile/promoted_prim_ops.cpp +++ b/torch/csrc/jit/mobile/promoted_prim_ops.cpp @@ -118,7 +118,7 @@ void toPrimDType(Stack& stack) { pop(stack, non_blocking, copy); std::optional scalarType = pop(stack).toOptional(); - std::optional device = std::nullopt; + std::optional device = c10::nullopt; at::Tensor self = pop(stack).toTensor(); push(stack, to_dispatch(self, device, scalarType, non_blocking, copy)); } diff --git a/torch/csrc/jit/operator_upgraders/upgraders_entry.cpp b/torch/csrc/jit/operator_upgraders/upgraders_entry.cpp index 21889a84b44070..3185b0eaf123ca 100644 --- a/torch/csrc/jit/operator_upgraders/upgraders_entry.cpp +++ b/torch/csrc/jit/operator_upgraders/upgraders_entry.cpp @@ -122,7 +122,7 @@ std::shared_ptr create_upgrader_graph( const std::string& upgrader_name, const std::string& upgrader_body) { auto cu = std::make_shared(); - cu->define(std::nullopt, upgrader_body, nativeResolver(), nullptr); + cu->define(c10::nullopt, upgrader_body, nativeResolver(), nullptr); Function& jitFunc = cu->get_function(upgrader_name); GraphFunction& graphFunction = toGraphFunction(jitFunc); return graphFunction.graph(); diff --git a/torch/csrc/jit/operator_upgraders/utils.cpp b/torch/csrc/jit/operator_upgraders/utils.cpp index 98819b08d640b4..fef7b92c83c95a 100644 --- a/torch/csrc/jit/operator_upgraders/utils.cpp +++ b/torch/csrc/jit/operator_upgraders/utils.cpp @@ -1,9 +1,9 @@ #include +#include #include #include #include -#include #include #include #include @@ -27,7 +27,7 @@ std::optional findUpgrader( if (pos != upgraders_for_schema.end()) { return *pos; } - return std::nullopt; + return c10::nullopt; } bool isOpCurrentBasedOnUpgraderEntries( diff --git a/torch/csrc/jit/operator_upgraders/utils.h b/torch/csrc/jit/operator_upgraders/utils.h index 95e794261e6b97..a30b8c1182b9cf 100644 --- a/torch/csrc/jit/operator_upgraders/utils.h +++ b/torch/csrc/jit/operator_upgraders/utils.h @@ -1,8 +1,8 @@ #pragma once #include +#include #include #include -#include #include #include diff --git a/torch/csrc/jit/passes/autocast.cpp b/torch/csrc/jit/passes/autocast.cpp index bbd56744afb7d8..635162e0495319 100644 --- a/torch/csrc/jit/passes/autocast.cpp +++ b/torch/csrc/jit/passes/autocast.cpp @@ -4,10 +4,10 @@ #include #include #include +#include #include #include #include -#include #include #include @@ -65,7 +65,7 @@ std::optional parseAutocast( const AutocastContext& context) { if (!isAutocastNode(value)) { // Not an autocast... - return std::nullopt; + return c10::nullopt; } if (value->node()->kind() == prim::CreateObject) { AutocastScope scope; @@ -135,7 +135,7 @@ std::optional parseAutocast( AT_ERROR("Unsupported autocast syntax"); } - return std::nullopt; + return c10::nullopt; } void castTensorInputs( @@ -269,7 +269,7 @@ void updateAutocastEnabledCheck(Node* node, bool is_jit_enabled) { void handleBlock(Block* block, AutocastContext initial_state) { std::stack autocast_stack; - std::optional incompatible_amp = std::nullopt; + std::optional incompatible_amp = c10::nullopt; // The current autocast enabled/disabled state auto current_state = [&] { diff --git a/torch/csrc/jit/passes/canonicalize.cpp b/torch/csrc/jit/passes/canonicalize.cpp index 2aa6aff76bc1d7..20a883a8d06fdd 100644 --- a/torch/csrc/jit/passes/canonicalize.cpp +++ b/torch/csrc/jit/passes/canonicalize.cpp @@ -144,7 +144,7 @@ bool isBeforeOrAfter(const Use& a, const Use& b, bool checking_before) { std::optional firstOrLastUse(Value* v, bool find_first) { if (v->uses().empty()) { - return std::nullopt; + return c10::nullopt; } Use extreme_use = v->uses()[0]; for (size_t i = 1; i < v->uses().size(); ++i) { @@ -176,12 +176,12 @@ static std::vector sort_indexes(at::ArrayRef values) { // if neither has any uses, use original ordering. Since the // only values that jitter are ones added by the compiler and are guaranteed // to have uses, original ordering is fine. - if (first_uses[i1] == std::nullopt && first_uses[i2] == std::nullopt) { + if (first_uses[i1] == c10::nullopt && first_uses[i2] == c10::nullopt) { return i1 < i2; } - if (first_uses[i1] == std::nullopt) { + if (first_uses[i1] == c10::nullopt) { return false; - } else if (first_uses[i2] == std::nullopt) { + } else if (first_uses[i2] == c10::nullopt) { return true; } diff --git a/torch/csrc/jit/passes/canonicalize_graph_fuser_ops.cpp b/torch/csrc/jit/passes/canonicalize_graph_fuser_ops.cpp index b3e190445b8fe3..72d419eeb9c163 100644 --- a/torch/csrc/jit/passes/canonicalize_graph_fuser_ops.cpp +++ b/torch/csrc/jit/passes/canonicalize_graph_fuser_ops.cpp @@ -26,14 +26,14 @@ static std::optional> getChunkOutputs(Node* chunk) { // number of chunks if (static_cast(list_use.user->outputs().size()) != chunk->get(attr::chunks).value()) { - return std::nullopt; + return c10::nullopt; } auto unpack_outputs = list_use.user->outputs(); for (const auto i : c10::irange(unpack_outputs.size())) { outputs.emplace_back(unpack_outputs[i], i); } } else { - return std::nullopt; + return c10::nullopt; } } return outputs; diff --git a/torch/csrc/jit/passes/constant_propagation.cpp b/torch/csrc/jit/passes/constant_propagation.cpp index 5ec8b561cba80a..6334cd75faa903 100644 --- a/torch/csrc/jit/passes/constant_propagation.cpp +++ b/torch/csrc/jit/passes/constant_propagation.cpp @@ -28,14 +28,14 @@ std::optional> runNodeIfInputsAreConstant( if (auto ival = toIValue(input)) { stack.push_back(*ival); } else { - return std::nullopt; + return c10::nullopt; } } switch (n->kind()) { case prim::ListUnpack: { if (stack.back().toList().size() != n->outputs().size()) { - return std::nullopt; + return c10::nullopt; } listUnpack(stack, n->outputs().size()); } break; @@ -78,14 +78,14 @@ std::optional> runNodeIfInputsAreConstant( // vararg schemas require the number of inputs at the top of the stack // but this is broken in other places in constant prop, so disable it // for now - return std::nullopt; + return c10::nullopt; } try { auto op = n->getOperation(); op(stack); } catch (...) { - return std::nullopt; + return c10::nullopt; } } break; } @@ -95,13 +95,13 @@ std::optional> runNodeIfInputsAreConstant( const at::Tensor& t = v.toTensor(); if (t.defined() && t.requires_grad()) { // requires grad tensors cannot be constants - return std::nullopt; + return c10::nullopt; } } // Weak form of const propagation if (ignore_custom_classes) { if (v.isCustomClass()) { - return std::nullopt; + return c10::nullopt; } } // see [Constant Object Weak CompilationUnit Reference] @@ -123,7 +123,7 @@ std::optional> runNodeIfInputsAreConstant( } if (v.isObject()) { if (!v.toObject()->is_weak_compilation_ref()) { - return std::nullopt; + return c10::nullopt; } } } diff --git a/torch/csrc/jit/passes/create_autodiff_subgraphs.cpp b/torch/csrc/jit/passes/create_autodiff_subgraphs.cpp index 46eca6f2b221f5..c5fe65537669a8 100644 --- a/torch/csrc/jit/passes/create_autodiff_subgraphs.cpp +++ b/torch/csrc/jit/passes/create_autodiff_subgraphs.cpp @@ -287,7 +287,7 @@ class SubgraphSlicer { aliasDb_.moveBeforeTopologicallyValid(producer, consumer); if (!canMerge) { - return std::nullopt; + return c10::nullopt; } SubgraphUtils::mergeNodeIntoSubgraphAndUpdateAliasing( @@ -305,11 +305,11 @@ class SubgraphSlicer { std::optional getProfileNodeRequiresGrad(Node* n) { TORCH_INTERNAL_ASSERT(n->kind() == prim::profile); if (!n->hasAttribute(attr::profiled_type)) { - return std::nullopt; + return c10::nullopt; } auto& type = n->ty(attr::profiled_type); if (type->castRaw() == nullptr) { - return std::nullopt; + return c10::nullopt; } return type->expectRef().requiresGrad(); } @@ -403,7 +403,7 @@ std::optional findRequiresGradForOutput( } } - return std::nullopt; + return c10::nullopt; } void AddRequiresGradToDifferentiableGraph( diff --git a/torch/csrc/jit/passes/device_type_analysis.cpp b/torch/csrc/jit/passes/device_type_analysis.cpp index c9c9188d37dc5e..7670292696ae69 100644 --- a/torch/csrc/jit/passes/device_type_analysis.cpp +++ b/torch/csrc/jit/passes/device_type_analysis.cpp @@ -2,12 +2,12 @@ #include #include #include +#include #include #include #include #include #include -#include #include namespace torch { @@ -88,7 +88,7 @@ bool propWithNoDevice(Node* n) { } if (input_num == n->inputs().size()) { // No tensor found - return setReturnsToDevice(n, std::nullopt); + return setReturnsToDevice(n, c10::nullopt); } auto tensor_type = n->inputs()[input_num]->type()->expect(); @@ -108,7 +108,7 @@ bool propWithNoDevice(Node* n) { only_seen_cpu_zerodim = false; } else { // Bail on the type not match case - return setReturnsToDevice(n, std::nullopt); + return setReturnsToDevice(n, c10::nullopt); } } } diff --git a/torch/csrc/jit/passes/dtype_analysis.cpp b/torch/csrc/jit/passes/dtype_analysis.cpp index 2311cb791a449c..f63ea6f3419489 100644 --- a/torch/csrc/jit/passes/dtype_analysis.cpp +++ b/torch/csrc/jit/passes/dtype_analysis.cpp @@ -3,13 +3,13 @@ #include #include #include +#include #include #include #include #include #include #include -#include #ifndef AT_PER_OPERATOR_HEADERS #include @@ -102,7 +102,7 @@ static bool canBeInferredWithMetaTensor(Node* n) { std::optional inferWithMetaTensor(Node* n) { GRAPH_DEBUG("inferWithMetaTensor", getHeader(n)); if (!canBeInferredWithMetaTensor(n)) { - return std::nullopt; + return c10::nullopt; } Operation op = n->getOperation(); try { @@ -116,7 +116,7 @@ std::optional inferWithMetaTensor(Node* n) { } catch (...) { GRAPH_DEBUG("caught exception with Metatensor run!"); }; - return std::nullopt; + return c10::nullopt; } bool setDtype( diff --git a/torch/csrc/jit/passes/erase_number_types.cpp b/torch/csrc/jit/passes/erase_number_types.cpp index ccafee9aa4ae43..540f1a7e13fb84 100644 --- a/torch/csrc/jit/passes/erase_number_types.cpp +++ b/torch/csrc/jit/passes/erase_number_types.cpp @@ -41,7 +41,7 @@ void EraseNumberTypesOnBlock(Block* block) { WithInsertPoint guard(*it); Value* r = block->owningGraph()->insertConstant( - scalar_to_tensor(s), std::nullopt, it->scope()); + scalar_to_tensor(s), c10::nullopt, it->scope()); r->copyMetadata(it->output()); it->output()->replaceAllUsesWith(r); it.destroyCurrent(); diff --git a/torch/csrc/jit/passes/freeze_module.cpp b/torch/csrc/jit/passes/freeze_module.cpp index 23bc873addc714..4d67d5d2178134 100644 --- a/torch/csrc/jit/passes/freeze_module.cpp +++ b/torch/csrc/jit/passes/freeze_module.cpp @@ -170,7 +170,7 @@ class AttributePropagator { std::optional resolveName(const std::string& name) { auto sub_names = splitName(name); if (sub_names.empty()) { - return std::nullopt; + return c10::nullopt; } auto& attr_name = sub_names.back(); auto cur_module = module_; @@ -189,7 +189,7 @@ class AttributePropagator { } } if (!found) { - return std::nullopt; + return c10::nullopt; } } @@ -207,7 +207,7 @@ class AttributePropagator { return std::make_pair(std::move(cur_module), std::move(attr_name)); } - return std::nullopt; + return c10::nullopt; } bool _loadModulePath(Value* input, std::shared_ptr& graph) { @@ -230,7 +230,7 @@ class AttributePropagator { std::shared_ptr& graph) { bool success = _loadModulePath(input, graph); if (!success) { - return std::nullopt; + return c10::nullopt; } return names_; } diff --git a/torch/csrc/jit/passes/frozen_ops_to_mkldnn.cpp b/torch/csrc/jit/passes/frozen_ops_to_mkldnn.cpp index b508cd905c586b..c28e99a445258a 100644 --- a/torch/csrc/jit/passes/frozen_ops_to_mkldnn.cpp +++ b/torch/csrc/jit/passes/frozen_ops_to_mkldnn.cpp @@ -1105,7 +1105,7 @@ class MKLDNNSubgraphSlicer { aliasDb_.moveAfterTopologicallyValid(consumer, producer); if (!canMerge) { - return std::nullopt; + return c10::nullopt; } SubgraphUtils::mergeNodeIntoSubgraphAndUpdateAliasing( diff --git a/torch/csrc/jit/passes/graph_fuser.cpp b/torch/csrc/jit/passes/graph_fuser.cpp index 5136615cd2e441..98487830726216 100644 --- a/torch/csrc/jit/passes/graph_fuser.cpp +++ b/torch/csrc/jit/passes/graph_fuser.cpp @@ -494,7 +494,7 @@ struct GraphFuser { AT_ASSERT(group->kind() == prim::FusionGroup); auto it = std::find(group->inputs().begin(), group->inputs().end(), input); if (it == group->inputs().end()) { - return std::nullopt; + return c10::nullopt; } size_t input_index = it - group->inputs().begin(); auto& subgraph = getSubgraph(group); @@ -505,7 +505,7 @@ struct GraphFuser { AT_ASSERT(subgraph_input->uses().size() == 1); return node; } - return std::nullopt; + return c10::nullopt; } void fuseChunkByReusingExistingFusedChunk( diff --git a/torch/csrc/jit/passes/graph_rewrite_helper.cpp b/torch/csrc/jit/passes/graph_rewrite_helper.cpp index 430dbb3fd1c851..edb9f5b9589a06 100644 --- a/torch/csrc/jit/passes/graph_rewrite_helper.cpp +++ b/torch/csrc/jit/passes/graph_rewrite_helper.cpp @@ -287,7 +287,7 @@ bool isClampFusable( vmap.find("output_max") != vmap.end(), "Expected to find output_max as well given " "output_min exist in pattern graph."); - // If output_min/max are not constant, we get std::nullopt. + // If output_min/max are not constant, we get c10::nullopt. auto output_min = graph_rewrite_helper::getIValue("output_min", match_vmap, vmap); auto output_max = diff --git a/torch/csrc/jit/passes/inline_autodiff_subgraphs.cpp b/torch/csrc/jit/passes/inline_autodiff_subgraphs.cpp index 226826e946098a..f8d63e87f07b7e 100644 --- a/torch/csrc/jit/passes/inline_autodiff_subgraphs.cpp +++ b/torch/csrc/jit/passes/inline_autodiff_subgraphs.cpp @@ -68,7 +68,7 @@ graph_node_list::iterator scanNode(Node* node, size_t threshold) { // so the profiles will have outdated requires_grad=False. // conservatively update them to maybe requiring grad, bc we might create // autodiff graphs when the tensors maybe require grad - UpdateDifferentiableGraphRequiresGrad(subgraph, std::nullopt); + UpdateDifferentiableGraphRequiresGrad(subgraph, c10::nullopt); SubgraphUtils::unmergeSubgraph(node); return next_node; } diff --git a/torch/csrc/jit/passes/integer_value_refinement.cpp b/torch/csrc/jit/passes/integer_value_refinement.cpp index cf9b577f927b28..16a329b3b11f34 100644 --- a/torch/csrc/jit/passes/integer_value_refinement.cpp +++ b/torch/csrc/jit/passes/integer_value_refinement.cpp @@ -93,7 +93,7 @@ struct IntegerValueRefiner { auto other_output = other_if_block->outputs().at(i); auto other_const_value = other_output->type()->cast() ? constant_as(other_output) - : std::nullopt; + : c10::nullopt; if (!other_const_value || block_output->node()->kind() == prim::Constant) { continue; @@ -211,7 +211,7 @@ struct IntegerValueRefiner { return maybe_refinement->second; } } - return std::nullopt; + return c10::nullopt; } std::shared_ptr graph_; diff --git a/torch/csrc/jit/passes/onnx/constant_fold.cpp b/torch/csrc/jit/passes/onnx/constant_fold.cpp index 61d97057c5b429..4eeba79aae90c4 100644 --- a/torch/csrc/jit/passes/onnx/constant_fold.cpp +++ b/torch/csrc/jit/passes/onnx/constant_fold.cpp @@ -5,9 +5,9 @@ #include #include +#include #include #include -#include namespace torch { namespace jit { @@ -72,15 +72,15 @@ std::optional runTorchSlice_opset9( TORCH_WARN( "Constant folding - Invalid number of inputs found for opset 9 " "onnx::Slice op. Constant folding not applied."); - return std::nullopt; + return c10::nullopt; } if (!(node->hasAttributeS("starts") && node->hasAttributeS("ends"))) { - return std::nullopt; + return c10::nullopt; } auto startsAttr = node->is(attr::starts); auto endsAttr = node->is(attr::ends); if (startsAttr.size() != endsAttr.size()) { - return std::nullopt; + return c10::nullopt; } std::vector axesAttr; if (node->hasAttributeS("axes")) { @@ -98,7 +98,7 @@ std::optional runTorchSlice_opset9( handleNegativeStartEndIndex(start, end, axis, updated_val.sizes()); int64_t length = end - start; if (length < 0 || start > updated_val.sizes()[axis] - length) - return std::nullopt; + return c10::nullopt; updated_val = at::narrow(updated_val, axis, start, length); } return std::optional(updated_val); @@ -114,7 +114,7 @@ std::optional runTorchSlice_opset10( TORCH_WARN( "Constant folding - Invalid number of inputs found for opset opset >= 10 onnx::Slice op. " "Constant folding not applied."); - return std::nullopt; + return c10::nullopt; } // Checking validity of 'starts' and 'ends' input if (inputTensorValues[1].sizes().size() != 1 || @@ -122,12 +122,12 @@ std::optional runTorchSlice_opset10( TORCH_WARN( "Constant folding - Invalid 'starts' or 'ends' inputs found for opset >= 10 onnx::Slice op. " "Constant folding not applied."); - return std::nullopt; + return c10::nullopt; } if (inputTensorValues[1].sizes()[0] != inputTensorValues[2].sizes()[0]) { // Number of elements of 'starts' and 'ends' 1-D input tensors should be the // same - return std::nullopt; + return c10::nullopt; } // Checking 'axes' input, if available. std::vector axes; @@ -136,7 +136,7 @@ std::optional runTorchSlice_opset10( TORCH_WARN( "Constant folding - Invalid 'axes' input found for opset >= 10 onnx::Slice op. " "Constant folding not applied."); - return std::nullopt; + return c10::nullopt; } if (inputTensorValues[3].sizes()[0] != inputTensorValues[1].sizes()[0]) { // Number of elements of 'axes' and 'ends' 1-D input tensors should be the @@ -144,7 +144,7 @@ std::optional runTorchSlice_opset10( TORCH_WARN( "Constant folding - Invalid 'axes' or 'ends' inputs found for opset >= 10 onnx::Slice op. " "Constant folding not applied."); - return std::nullopt; + return c10::nullopt; } auto axes_a = inputTensorValues[3].accessor(); axes.resize(inputTensorValues[3].sizes()[0]); @@ -162,7 +162,7 @@ std::optional runTorchSlice_opset10( TORCH_WARN( "Constant folding - Invalid 'steps' input found for opset >= 10 onnx::Slice op. " "Constant folding not applied."); - return std::nullopt; + return c10::nullopt; } if (inputTensorValues[4].sizes()[0] != inputTensorValues[1].sizes()[0]) { // Number of elements of 'steps' and 'ends' 1-D input tensors should be @@ -170,7 +170,7 @@ std::optional runTorchSlice_opset10( TORCH_WARN( "Constant folding - Invalid 'steps' or 'ends' inputs found for opset >= 10 onnx::Slice op. " "Constant folding not applied."); - return std::nullopt; + return c10::nullopt; } auto steps_a = inputTensorValues[4].accessor(); for (const auto i : c10::irange(inputTensorValues[4].sizes()[0])) { @@ -179,7 +179,7 @@ std::optional runTorchSlice_opset10( TORCH_WARN( "Constant folding - Only steps=1 can be constant folded for opset >= 10 onnx::Slice op. " "Constant folding not applied."); - return std::nullopt; + return c10::nullopt; } } } @@ -192,7 +192,7 @@ std::optional runTorchSlice_opset10( handleNegativeStartEndIndex(start, end, axis, updated_val.sizes()); int64_t length = end - start; if (length < 0 || start > updated_val.sizes()[axis] - length) - return std::nullopt; + return c10::nullopt; updated_val = at::narrow(updated_val, axis, start, length); } return std::optional(updated_val); @@ -272,11 +272,11 @@ std::optional runTorchBackendForOnnx( } else { TORCH_WARN( "Constant folding - unsupported opset version. Constant folding not applied."); - return std::nullopt; + return c10::nullopt; } } else if (node->kind() == onnx::Concat) { if (!node->hasAttributeS("axis")) { - return std::nullopt; + return c10::nullopt; } updated_val = at::cat(at::TensorList(inputTensorValues), node->i(attr::axis)); @@ -310,7 +310,7 @@ std::optional runTorchBackendForOnnx( TORCH_WARN( "Constant folding - Invalid 'axes' inputs found for opset 13 onnx::Unsqueeze op. " "Constant folding not applied."); - return std::nullopt; + return c10::nullopt; } auto axes_a = inputTensorValues[1].accessor(); std::vector axes; @@ -332,7 +332,7 @@ std::optional runTorchBackendForOnnx( } else if (opset_version >= ONNX_OPSET_9) { assert(inputTensorValues.size() == 1); if (!node->hasAttributeS("axes")) { - return std::nullopt; + return c10::nullopt; } updated_val = inputTensorValues[0]; std::vector axesAttr = node->is(attr::axes); @@ -345,7 +345,7 @@ std::optional runTorchBackendForOnnx( TORCH_WARN( "Constant folding - unsupported opset version. " "Constant folding not applied."); - return std::nullopt; + return c10::nullopt; } } else if (node->kind() == onnx::Squeeze) { assert(inputTensorValues.size() == 2 || inputTensorValues.size() == 1); @@ -359,7 +359,7 @@ std::optional runTorchBackendForOnnx( TORCH_WARN( "Constant folding - Invalid 'axes' inputs found for opset 13 onnx::Squeeze op. " "Constant folding not applied."); - return std::nullopt; + return c10::nullopt; } auto axes_a = inputTensorValues[1].accessor(); std::vector axes; @@ -389,12 +389,12 @@ std::optional runTorchBackendForOnnx( TORCH_WARN( "Constant folding - unsupported opset version. " "Constant folding not applied."); - return std::nullopt; + return c10::nullopt; } } else if (node->kind() == onnx::Transpose) { assert(inputTensorValues.size() == 1); if (!node->hasAttributeS("perm")) { - return std::nullopt; + return c10::nullopt; } updated_val = inputTensorValues[0].permute(node->is(attr::perm)); return std::optional(updated_val); @@ -405,7 +405,7 @@ std::optional runTorchBackendForOnnx( ONNXTypeToATenType(node->i(attr::to)).value()); return std::optional(updated_val); } - return std::nullopt; + return c10::nullopt; } else if (node->kind() == onnx::Reshape) { assert(inputTensorValues.size() == 2); updated_val = inputTensorValues[0]; @@ -441,10 +441,10 @@ std::optional runTorchBackendForOnnx( } else if (node->kind() == onnx::ReduceL1 || node->kind() == onnx::ReduceL2) { assert(inputTensorValues.size() == 1); if (!node->hasAttributeS("axes")) { - return std::nullopt; + return c10::nullopt; } if (!node->hasAttributeS("keepdims")) { - return std::nullopt; + return c10::nullopt; } int p = node->kind() == onnx::ReduceL1 ? 1 : 2; updated_val = at::norm( @@ -485,7 +485,7 @@ std::optional runTorchBackendForOnnx( // at::index_select only supports indices with rank <= 1. // See https://pytorch.org/docs/main/generated/torch.index_select.html if (q > 1) { - return std::nullopt; + return c10::nullopt; } // If the device of indices tensor is not the same with it of the input // tensor, move it to the device of the input tensor @@ -539,7 +539,7 @@ std::optional runTorchBackendForOnnx( updated_val = at::softmax(inputTensorValues[0], axis); return std::optional(updated_val); } else { - return std::nullopt; + return c10::nullopt; } } @@ -652,7 +652,7 @@ void ConstantFoldONNX(Block* b, ParamMap& paramsDict, int opset_version) { } auto updatedValWrapped = onnx_constant_fold::runTorchBackendForOnnx( node, inputTensorValues, opset_version); - if (updatedValWrapped == std::nullopt) { + if (updatedValWrapped == c10::nullopt) { // Constant folding is not supported for this op. Skip it. continue; } diff --git a/torch/csrc/jit/passes/onnx/constant_fold.h b/torch/csrc/jit/passes/onnx/constant_fold.h index d25ebee32a787e..201c3def32685a 100644 --- a/torch/csrc/jit/passes/onnx/constant_fold.h +++ b/torch/csrc/jit/passes/onnx/constant_fold.h @@ -2,8 +2,8 @@ #include +#include #include -#include namespace torch { namespace jit { diff --git a/torch/csrc/jit/passes/onnx/constant_map.cpp b/torch/csrc/jit/passes/onnx/constant_map.cpp index 99c801dcf77367..f9c96d0430df02 100644 --- a/torch/csrc/jit/passes/onnx/constant_map.cpp +++ b/torch/csrc/jit/passes/onnx/constant_map.cpp @@ -34,14 +34,14 @@ bool ConstantValueMap::HasRank(const std::string& tensorName) { std::optional ConstantValueMap::GetRank(const std::string& tensorName) { if (!HasRank(tensorName)) { - return std::nullopt; + return c10::nullopt; } return ConstantValueMap::getInstance().rankMap[tensorName]; } void ConstantValueMap::SetAllGraphInputsStatic(bool all_static) { ConstantValueMap::getInstance().allGraphInputsStatic = - std::make_optional(all_static); + c10::make_optional(all_static); } std::optional ConstantValueMap::GetAllGraphInputsStatic() { @@ -71,7 +71,7 @@ bool ConstantValueMap::HasShape(const std::string& tensorName) { std::optional ConstantValueMap::GetShape( const std::string& tensorName) { if (!HasShape(tensorName)) { - return std::nullopt; + return c10::nullopt; } return ConstantValueMap::getInstance().shapeMap[tensorName]; } @@ -90,7 +90,7 @@ bool ConstantValueMap::HasValue(const std::string& tensorName) { std::optional ConstantValueMap::GetValue( const std::string& tensorName) { if (!HasValue(tensorName)) { - return std::nullopt; + return c10::nullopt; } return ConstantValueMap::getInstance().tensorValueMap[tensorName]; } @@ -121,7 +121,7 @@ std::optional> ConstantValueMap::GetShapeInto1DInt64Vector( return shape_value; } } - return std::nullopt; + return c10::nullopt; } std::optional> ConstantValueMap:: @@ -152,7 +152,7 @@ std::optional> ConstantValueMap:: } } } - return std::nullopt; + return c10::nullopt; } // accessor for 1DInt64 case. @@ -183,7 +183,7 @@ bool ConstantValueMap::HasTypeReliable(const std::string& tensorName) { std::optional ConstantValueMap::GetTypeReliable( const std::string& tensorName) { if (!HasTypeReliable(tensorName)) { - return std::nullopt; + return c10::nullopt; } return ConstantValueMap::getInstance().typeReliableMap[tensorName]; } @@ -202,7 +202,7 @@ bool ConstantValueMap::HasUseInferredType(const std::string& tensorName) { std::optional ConstantValueMap::GetUseInferredType( const std::string& tensorName) { if (!HasUseInferredType(tensorName)) { - return std::nullopt; + return c10::nullopt; } return ConstantValueMap::getInstance().useInferredTypeMap[tensorName]; } @@ -221,7 +221,7 @@ bool ConstantValueMap::HasShapeValue(const std::string& tensorName) { std::optional ConstantValueMap::GetShapeValue( const std::string& tensorName) { if (!HasShapeValue(tensorName)) { - return std::nullopt; + return c10::nullopt; } return ConstantValueMap::getInstance().shapeValueMap[tensorName]; } @@ -284,7 +284,7 @@ void ConstantValueMap::ClearMaps() { ConstantValueMap::getInstance().inferredShapeData.clear(); ConstantValueMap::getInstance().symbolDimMap.clear(); ConstantValueMap::getInstance().dimSymbolMap.clear(); - ConstantValueMap::getInstance().allGraphInputsStatic = std::nullopt; + ConstantValueMap::getInstance().allGraphInputsStatic = c10::nullopt; ConstantValueMap::getInstance().allGraphInputsReliableComputed = false; } diff --git a/torch/csrc/jit/passes/onnx/function_extraction.cpp b/torch/csrc/jit/passes/onnx/function_extraction.cpp index febf412e5d1224..c545c7aba823a1 100644 --- a/torch/csrc/jit/passes/onnx/function_extraction.cpp +++ b/torch/csrc/jit/passes/onnx/function_extraction.cpp @@ -225,16 +225,16 @@ std::optional FunctionExtractor::FunctionContext::FindAttrName( auto v_it = scope_ctxs_[scope_key_]->env_to_subgraph_.find(ref_n->outputs().at(0)); if (v_it == scope_ctxs_[scope_key_]->env_to_subgraph_.end()) { - return std::nullopt; + return c10::nullopt; } auto* n_in_def = v_it->second->node(); auto n_attr_it = node_attr_to_name_.find(n_in_def); if (n_attr_it == node_attr_to_name_.end()) { - return std::nullopt; + return c10::nullopt; } auto name_it = n_attr_it->second.find(attr.toUnqualString()); if (name_it == n_attr_it->second.end()) { - return std::nullopt; + return c10::nullopt; } return name_it->second; } @@ -301,7 +301,7 @@ std::optional FunctionExtractor::FindCommonAncestor( ScopePtr a, ScopePtr b) { if (!IsValidScope(a) || !IsValidScope(b)) { - return std::nullopt; + return c10::nullopt; } auto diff = @@ -327,20 +327,20 @@ std::optional FunctionExtractor::FindCommonAncestor( } } - return std::nullopt; + return c10::nullopt; } std::optional FunctionExtractor::FindCommonAncestor( const scope_list& scopes) { if (scopes.empty()) { - return std::nullopt; + return c10::nullopt; } std::optional common_ancestor = scopes.at(0); for (const auto& scope : scopes) { common_ancestor = FindCommonAncestor(common_ancestor.value(), scope); if (!common_ancestor.has_value()) { - return std::nullopt; + return c10::nullopt; } } @@ -410,7 +410,7 @@ std::optional FunctionExtractor::InferScope(Node* n) { } } - return std::nullopt; + return c10::nullopt; } std::shared_ptr FunctionExtractor::ConstructFuncGraph( diff --git a/torch/csrc/jit/passes/onnx/list_model_parameters.cpp b/torch/csrc/jit/passes/onnx/list_model_parameters.cpp index 6a1e3b08f3b9a8..b28de0fdee4cd5 100644 --- a/torch/csrc/jit/passes/onnx/list_model_parameters.cpp +++ b/torch/csrc/jit/passes/onnx/list_model_parameters.cpp @@ -52,7 +52,7 @@ std::deque findSubModuleAttr( Value* addParamAsArgument(Function* function, std::string& name, IValue& attr) { auto schema = function->getSchema(); auto args = schema.arguments(); - args.emplace_back(name, nullptr, std::nullopt, attr); + args.emplace_back(name, nullptr, c10::nullopt, attr); auto new_schema = FunctionSchema( schema.name(), schema.overload_name(), diff --git a/torch/csrc/jit/passes/onnx/pattern_conversion/pattern_conversion.cpp b/torch/csrc/jit/passes/onnx/pattern_conversion/pattern_conversion.cpp index cd975d0375fcbb..6c064b70ae614f 100644 --- a/torch/csrc/jit/passes/onnx/pattern_conversion/pattern_conversion.cpp +++ b/torch/csrc/jit/passes/onnx/pattern_conversion/pattern_conversion.cpp @@ -46,7 +46,7 @@ Value* ConvertSliceToIndex(Node* slice, Value* size, Node* insertBefore) { aten::slice, {index, graph->insertConstant( - scalar_to_tensor(at::Scalar(0)), std::nullopt, slice->scope()), + scalar_to_tensor(at::Scalar(0)), c10::nullopt, slice->scope()), start, end, step}); diff --git a/torch/csrc/jit/passes/onnx/pattern_conversion/pattern_encapsulation.cpp b/torch/csrc/jit/passes/onnx/pattern_conversion/pattern_encapsulation.cpp index 7a98567a529bee..6110954990455b 100644 --- a/torch/csrc/jit/passes/onnx/pattern_conversion/pattern_encapsulation.cpp +++ b/torch/csrc/jit/passes/onnx/pattern_conversion/pattern_encapsulation.cpp @@ -84,7 +84,7 @@ std::optional EncapsulatePatternIntoSubblock(Node* n) { return EncapsulateInplaceIndexPutForONNX(n); } } - return std::nullopt; + return c10::nullopt; } } // namespace jit diff --git a/torch/csrc/jit/passes/onnx/peephole.cpp b/torch/csrc/jit/passes/onnx/peephole.cpp index 18c31ea656610d..b468e739a03f3d 100644 --- a/torch/csrc/jit/passes/onnx/peephole.cpp +++ b/torch/csrc/jit/passes/onnx/peephole.cpp @@ -16,7 +16,7 @@ #include #endif -#include +#include #if defined(_MSC_VER) #include @@ -105,14 +105,14 @@ std::optional fusibleExpandTo( at::IntArrayRef from, at::IntArrayRef to) { if (from.size() > to.size()) { - return std::nullopt; + return c10::nullopt; } for (const auto i : c10::irange(from.size())) { auto fdim = from[from.size() - 1 - i]; auto tdim = to[to.size() - 1 - i]; if (fdim != 1 && fdim != tdim) { - return std::nullopt; + return c10::nullopt; } } @@ -168,7 +168,7 @@ void fuseBroadcast(Block* b) { .sizes() .concrete_sizes() .value()); // to - if (axis == std::nullopt) { + if (axis == c10::nullopt) { continue; } diff --git a/torch/csrc/jit/passes/onnx/scalar_type_analysis.cpp b/torch/csrc/jit/passes/onnx/scalar_type_analysis.cpp index 009566499275b4..427e5771a9f0f7 100644 --- a/torch/csrc/jit/passes/onnx/scalar_type_analysis.cpp +++ b/torch/csrc/jit/passes/onnx/scalar_type_analysis.cpp @@ -100,7 +100,7 @@ static bool IsImplicitCastSupported(const NodeKind& nodeKind) { static std::optional PromoteScalarTypes( const std::vector& types) { if (types.empty()) { - return std::nullopt; + return c10::nullopt; } auto st = types[0]; for (const auto i : c10::irange(1, types.size())) { @@ -131,9 +131,9 @@ static std::optional PromoteScalarTypesWithCategory( return 0; }; - if (std::nullopt == typeFromScalar) { + if (c10::nullopt == typeFromScalar) { return typeFromTensor; - } else if (std::nullopt == typeFromTensor) { + } else if (c10::nullopt == typeFromTensor) { return typeFromScalar; } @@ -155,7 +155,7 @@ static std::optional InferExpectedScalarType(const Node* n) { if (auto* tensor_type = input->type()->castRaw()) { return tensor_type->scalarType(); } - return std::nullopt; + return c10::nullopt; }; auto emplace_type_from_scalar = [&typesFromTensors, &typesFromScalars](at::ScalarType scalar_type) { @@ -252,7 +252,7 @@ static std::optional InferExpectedScalarType(const Node* n) { } }); - std::optional st = std::nullopt; + std::optional st = c10::nullopt; const auto output_st = get_scalar_type(n->output()); if (IsComparisonOp(n->kind())) { @@ -313,7 +313,7 @@ static void UpdateScalarTypeForInputs( for (auto input : n->inputs()) { auto input_tensor_type = input->type()->cast(); auto input_scalar_type = - input_tensor_type ? input_tensor_type->scalarType() : std::nullopt; + input_tensor_type ? input_tensor_type->scalarType() : c10::nullopt; // We skip the 'condition' input (i.e., the first input) in case of // onnx:Where operator. @@ -393,7 +393,7 @@ static void RecoverScalarTypeForOutput( static void LowPrecisionCastNodeForStandardOps(Node* n, int opset_version) { TORCH_INTERNAL_ASSERT(n->outputs().size() == 1); if (n->output()->type()->cast() == nullptr || - n->output()->type()->cast()->scalarType() == std::nullopt) { + n->output()->type()->cast()->scalarType() == c10::nullopt) { // skip LowPrecisionCast if op output type is null. return; } @@ -401,7 +401,7 @@ static void LowPrecisionCastNodeForStandardOps(Node* n, int opset_version) { n->output()->type()->cast()->scalarType().value(); for (size_t i = 0; i < n->inputs().size(); ++i) { if (n->input(i)->type()->cast() == nullptr || - n->input(i)->type()->cast()->scalarType() == std::nullopt) { + n->input(i)->type()->cast()->scalarType() == c10::nullopt) { // skip LowPrecisionCast if any op input type node is null. return; } diff --git a/torch/csrc/jit/passes/onnx/shape_type_inference.cpp b/torch/csrc/jit/passes/onnx/shape_type_inference.cpp index 3691f0bf7b09ce..65d065adeb2b57 100644 --- a/torch/csrc/jit/passes/onnx/shape_type_inference.cpp +++ b/torch/csrc/jit/passes/onnx/shape_type_inference.cpp @@ -98,7 +98,7 @@ c10::ShapeSymbol ONNXDimToShapeSymbol( if (dim.has_dim_value()) { return c10::ShapeSymbol::fromStaticSize(dim.dim_value()); } - std::optional sym = std::nullopt; + std::optional sym = c10::nullopt; if (dim.has_dim_param()) { // If this param is already known, assign the same Symbol. GRAPH_UPDATE("Got dim_param:", dim.dim_param()); @@ -267,7 +267,7 @@ Value* CloneValueFromListConstruct( // is preserved. If the elemtype is Int, insert a onnx::Concat node into // the graph. TypePtr elem = v->type()->castRaw()->getElementType(); - std::optional scalar_type = std::nullopt; + std::optional scalar_type = c10::nullopt; if (elem->cast()) { scalar_type = at::kLong; if (isValidToTransformToONNXConcatNode(v->node())) { @@ -332,7 +332,7 @@ Node* CloneNodeToGraph( // Try to lookup input value and insert it into the graph. // If the input value is unknown, set it to graph input in the new // graph, and copy over metadata, such as datatype and shape. - ::std::optional val = ::std::nullopt; + ::std::optional val = ::c10::nullopt; auto v0 = params_dict.find(v->debugName()); if (v0 != params_dict.end()) { val = v0->second.toTensor(); @@ -420,13 +420,13 @@ void ConvertGraphToONNXProto( std::optional ComputeConstantFolding(Node* n, int opset_version) { if (n->inputs().empty()) { - return std::nullopt; + return c10::nullopt; } std::vector inputTensorValues; for (auto i : c10::irange(n->inputs().size())) { if (TensorTypePtr input_type = n->input(i)->type()->cast()) { if (!ConstantValueMap::HasValue(n->input(i)->debugName())) { - return std::nullopt; + return c10::nullopt; } auto tensor_value = ConstantValueMap::GetValue(n->input(i)->debugName()).value(); @@ -434,7 +434,7 @@ std::optional ComputeConstantFolding(Node* n, int opset_version) { } } if (inputTensorValues.size() < n->inputs().size()) { - return std::nullopt; + return c10::nullopt; } try { return onnx_constant_fold::runTorchBackendForOnnx( @@ -443,7 +443,7 @@ std::optional ComputeConstantFolding(Node* n, int opset_version) { auto ex_str = std::string(ex.what()); ex_str = ex_str.substr(0, ex_str.find('\n')); TORCH_WARN("Constant folding in symbolic shape inference fails: ", ex_str); - return std::nullopt; + return c10::nullopt; } } @@ -500,7 +500,7 @@ std::optional<::c10::SymbolicShape> ComputeShapeFromReshape( std::numeric_limits::max() / input_shape.static_size()) { TORCH_WARN( "ComputeShapeFromReshape(), shape_ratio overflows, skip shape inference."); - return std::nullopt; + return c10::nullopt; } else { shape_ratio *= static_cast(input_shape.static_size()); } @@ -523,7 +523,7 @@ std::optional<::c10::SymbolicShape> ComputeShapeFromReshape( } else { auto value = target_shape.value(); if (sym_map.find(value) == sym_map.end()) { - return std::nullopt; + return c10::nullopt; } sym_map[value]--; if (sym_map[value] == 0) { @@ -535,7 +535,7 @@ std::optional<::c10::SymbolicShape> ComputeShapeFromReshape( // sym_map is used to match shape symbols between the input and shape. // If there is a mismatch, the output shape cannot be estimated. if (!sym_map.empty()) { - return std::nullopt; + return c10::nullopt; } TORCH_INTERNAL_ASSERT( @@ -565,7 +565,7 @@ std::optional<::c10::SymbolicShape> ComputeShapeFromExpand( const std::vector& reshape) { for (const auto& it : reshape) { if (it < 0) { - return std::nullopt; + return c10::nullopt; } } std::vector<::c10::ShapeSymbol> final_shape; @@ -607,7 +607,7 @@ std::optional<::c10::SymbolicShape> ComputeShapeFromTile( "ONNX Tile input shapes do not match."); for (const auto& it : reshape) { if (it < 0) { - return std::nullopt; + return c10::nullopt; } } std::vector<::c10::ShapeSymbol> final_shape; @@ -688,7 +688,7 @@ std::optional> GetValueFromListConstructNode( } return lc_node->inputs().size() == shape_size.size() ? std::optional>(shape_size) - : std::nullopt; + : c10::nullopt; } void SetShapeValueFromListConstructNode(Node* lc_node) { diff --git a/torch/csrc/jit/passes/onnx/unpack_quantized_weights.cpp b/torch/csrc/jit/passes/onnx/unpack_quantized_weights.cpp index d889295dca19e2..7390bea56e77b0 100644 --- a/torch/csrc/jit/passes/onnx/unpack_quantized_weights.cpp +++ b/torch/csrc/jit/passes/onnx/unpack_quantized_weights.cpp @@ -655,11 +655,11 @@ void UnpackQuantizedTensorInputs(std::shared_ptr& graph) { auto input_scale = graph->insertInput(index + 1, input_name + "_scale") ->setType(TensorType::create( - at::kDouble, at::kCPU, 0, /*requires_grad=*/std::nullopt)); + at::kDouble, at::kCPU, 0, /*requires_grad=*/c10::nullopt)); auto input_zero_point = graph->insertInput(index + 2, input_name + "_zero_point") ->setType(TensorType::create( - at::kLong, at::kCPU, 0, /*requires_grad=*/std::nullopt)); + at::kLong, at::kCPU, 0, /*requires_grad=*/c10::nullopt)); std::vector converted{input_value, input_scale, input_zero_point}; auto input_tuple = graph->prependNode(graph->createTuple(converted))->output(); diff --git a/torch/csrc/jit/passes/peephole_dict_idioms.cpp b/torch/csrc/jit/passes/peephole_dict_idioms.cpp index 171b787d17b048..d3a5cfa36261b0 100644 --- a/torch/csrc/jit/passes/peephole_dict_idioms.cpp +++ b/torch/csrc/jit/passes/peephole_dict_idioms.cpp @@ -34,7 +34,7 @@ class DictNodeImpl : public DictNodeImplBase { auto key_opt = toIValue(dict_creation_node->input(i)); // Key is not constant if we cannot convert to IValue - if (key_opt == std::nullopt) { + if (key_opt == c10::nullopt) { has_non_const_key_ = true; continue; } @@ -129,7 +129,7 @@ class DictNode { if (impl_ && impl_->contains(key)) { return impl_->get(key); } - return std::nullopt; + return c10::nullopt; } private: @@ -185,14 +185,14 @@ class PeepholeOptimizeDictIdiomsImpl { const DictNode& dict_node = getDictNode(dict_creation_node); auto key_opt = toIValue(key); // Key is not constant if we cannot convert to IValue - if (key_opt == std::nullopt) { - return std::nullopt; + if (key_opt == c10::nullopt) { + return c10::nullopt; } IValue key_ival = *key_opt; if (dict_node.canOptimize()) { return dict_node.getOrNullopt(key_ival); } - return std::nullopt; + return c10::nullopt; } std::optional computeLen(Node* dict_creation_node) { @@ -200,13 +200,13 @@ class PeepholeOptimizeDictIdiomsImpl { if (dict_node.canOptimize()) { return static_cast(dict_node.size()); } - return std::nullopt; + return c10::nullopt; } bool optimizeLen(Node* len_node, Node* creation_node) { if (creation_node->kind() == prim::DictConstruct) { auto len = computeLen(creation_node); - if (len != std::nullopt) { + if (len != c10::nullopt) { WithInsertPoint guard(len_node); len_node->output()->replaceAllUsesWith(graph_->insertConstant(len)); return true; @@ -219,7 +219,7 @@ class PeepholeOptimizeDictIdiomsImpl { if (creation_node->kind() == prim::DictConstruct) { auto key = getitem_node->input(1); auto value = getValueFromDict(creation_node, key); - if (value != std::nullopt) { + if (value != c10::nullopt) { getitem_node->output()->replaceAllUsesWith(*value); return true; } diff --git a/torch/csrc/jit/passes/peephole_list_idioms.cpp b/torch/csrc/jit/passes/peephole_list_idioms.cpp index f644fe4f1de1c8..9c106e13edf1f8 100644 --- a/torch/csrc/jit/passes/peephole_list_idioms.cpp +++ b/torch/csrc/jit/passes/peephole_list_idioms.cpp @@ -21,7 +21,7 @@ static std::optional normalizeIndex(int64_t index, size_t len) { if (index >= 0 && index < static_cast(len)) { return index; } else { - return std::nullopt; + return c10::nullopt; } } @@ -136,7 +136,7 @@ struct ListLenRefiner { return maybe_refinement->second; } } - return std::nullopt; + return c10::nullopt; } std::shared_ptr graph_; @@ -199,8 +199,8 @@ struct PeepholeOptimizeListIdiomsImpl { auto step_val = toIValue(slice_node->input(3)); // All args must be constant to apply this optimization. - if (start_val == std::nullopt || end_val == std::nullopt || - step_val == std::nullopt) { + if (start_val == c10::nullopt || end_val == c10::nullopt || + step_val == c10::nullopt) { return false; } diff --git a/torch/csrc/jit/passes/quantization/helper.cpp b/torch/csrc/jit/passes/quantization/helper.cpp index 7eea68eb106546..8a74ec01086a58 100644 --- a/torch/csrc/jit/passes/quantization/helper.cpp +++ b/torch/csrc/jit/passes/quantization/helper.cpp @@ -325,7 +325,7 @@ std::optional getClampScalarInputUse(Value* v) { } } } - return std::nullopt; + return c10::nullopt; } void cloneMethod( @@ -503,7 +503,7 @@ std::optional> getFixedQParams(Node* n) { if (isAtenFunc(n, fixed_qparam_funcs)) { return _fixed_qparams_map.at(n->kind()); } - return std::nullopt; + return c10::nullopt; } bool userDefinedCallFunction(Node* n) { @@ -534,13 +534,13 @@ bool nodeQuantizable(Node* n, QuantType quant_type) { bool useQuantizable(const Use& use, QuantType quant_type) { if (quant_type == QuantType::STATIC) { for (const auto& func_input : _observe_inputs_aten_func) { - if (matchAtenFuncToUse(use, func_input.func_name, std::nullopt)) { + if (matchAtenFuncToUse(use, func_input.func_name, c10::nullopt)) { return use.offset == static_cast(func_input.arg_index); } } for (const auto& func_input : _observe_inputs_call_func) { - if (matchCallFuncToUse(use, func_input.func_name, std::nullopt)) { + if (matchCallFuncToUse(use, func_input.func_name, c10::nullopt)) { return use.offset == static_cast(func_input.arg_index); } } @@ -653,7 +653,7 @@ std::optional getInvokedModuleOpt( if (m.attr(p).isModule()) { m = m.attr(p).toModule(); } else { - return std::nullopt; + return c10::nullopt; } } return m; @@ -691,7 +691,7 @@ std::optional getModuleName(Value* value) { if (type && type->name()) { return removeTorchMangle(type->name()->qualifiedName()); } - return std::nullopt; + return c10::nullopt; } static bool is_module( diff --git a/torch/csrc/jit/passes/quantization/helper.h b/torch/csrc/jit/passes/quantization/helper.h index 21efbff7aa6941..680e3c7ca43d52 100644 --- a/torch/csrc/jit/passes/quantization/helper.h +++ b/torch/csrc/jit/passes/quantization/helper.h @@ -150,7 +150,7 @@ TORCH_API Module getInvokedModule(Module& module, Node* n, Value* self); // Given an CallMethod node, get the module instance corresponding // to the instance Value if the instance is a module, otherwise return -// std::nullopt +// c10::nullopt std::optional getInvokedModuleOpt( const Module& module, Node* n, diff --git a/torch/csrc/jit/passes/quantization/insert_observers.cpp b/torch/csrc/jit/passes/quantization/insert_observers.cpp index f906efacceca7b..145448210958ac 100644 --- a/torch/csrc/jit/passes/quantization/insert_observers.cpp +++ b/torch/csrc/jit/passes/quantization/insert_observers.cpp @@ -49,7 +49,7 @@ void fillQConfigMap( const QConfigDict& qconfig_dict, ModuleQConfigMap& map, const std::string& key = "", - const std::optional& parent_qconfig = std::nullopt) { + const std::optional& parent_qconfig = c10::nullopt) { std::optional qconfig; if (qconfig_dict.find(key) != qconfig_dict.end()) { GRAPH_DEBUG("Got module config for key:", key); @@ -1414,7 +1414,7 @@ InsertObserversHelper::insertObserversFor( if (!isObserved(v, block_observed_values)) { block_output_observers.emplace_back(getObserverFor(v)); } else { - block_output_observers.emplace_back(std::nullopt); + block_output_observers.emplace_back(c10::nullopt); } } } diff --git a/torch/csrc/jit/passes/quantization/insert_quant_dequant.cpp b/torch/csrc/jit/passes/quantization/insert_quant_dequant.cpp index 3d24834261d2a0..92fb2fc79bcc91 100644 --- a/torch/csrc/jit/passes/quantization/insert_quant_dequant.cpp +++ b/torch/csrc/jit/passes/quantization/insert_quant_dequant.cpp @@ -234,7 +234,7 @@ std::optional findObserverName(Value* v) { return module_instance->node()->s(attr::name); } } - return std::nullopt; + return c10::nullopt; } bool isPlaceholderObserver(Value* observer) { @@ -268,7 +268,7 @@ std::optional getEmbeddingBagObsName( auto op_name = observer_module.attr("custom_op").toStringRef(); return isPlaceholderObserver(observer) ? std::move(op_name) : ""; } - return std::nullopt; + return c10::nullopt; } bool isEmbeddingBagOp( @@ -792,7 +792,7 @@ class InsertQuantDeQuantHelper { const std::vector& inputs, bool is_scalar = false, const std::optional>& qparams_opt = - std::nullopt); + c10::nullopt); bool isQuantized(Value* v) { return quantized_values_.count(v) != 0; @@ -1269,7 +1269,7 @@ std::optional> getDequantizedInputs(Value* output) { return inputs; } } - return std::nullopt; + return c10::nullopt; } void InsertQuantDeQuantHelper::propagateQuantizationOps(Block* block) { diff --git a/torch/csrc/jit/passes/remove_mutation.h b/torch/csrc/jit/passes/remove_mutation.h index 1242555358f771..be8fc12b11f3d7 100644 --- a/torch/csrc/jit/passes/remove_mutation.h +++ b/torch/csrc/jit/passes/remove_mutation.h @@ -11,7 +11,7 @@ namespace jit { struct TORCH_API MutationRemover { MutationRemover( std::shared_ptr graph, - std::optional> mutation_filter = std::nullopt) + std::optional> mutation_filter = c10::nullopt) : mutation_filter_(mutation_filter), aliasDb_(nullptr), graph_(std::move(graph)) {} @@ -71,7 +71,7 @@ TORCH_API bool RemoveListMutation(const std::shared_ptr& graph); // return true if graph is modified TORCH_API bool RemoveTensorMutation( const std::shared_ptr& graph, - std::optional> mutation_filter = std::nullopt); + std::optional> mutation_filter = c10::nullopt); // Replaces in-place aten activation ops with their functional equivalence TORCH_API bool InplaceToFunctionalActivation( diff --git a/torch/csrc/jit/passes/replacement_of_old_operators.cpp b/torch/csrc/jit/passes/replacement_of_old_operators.cpp index 2d3b3a2aba7fc5..38255ad1418771 100644 --- a/torch/csrc/jit/passes/replacement_of_old_operators.cpp +++ b/torch/csrc/jit/passes/replacement_of_old_operators.cpp @@ -30,7 +30,7 @@ struct OldOpsReplacerWithUpgraders { Node* node = graph_it.next(); while (node) { // load the schema name for this op - std::optional schema_name = std::nullopt; + std::optional schema_name = c10::nullopt; if (auto op_schema = node->maybeSchema()) { schema_name = getFullSchemaName(*op_schema); } else { diff --git a/torch/csrc/jit/passes/shape_analysis.cpp b/torch/csrc/jit/passes/shape_analysis.cpp index 7290e1936128c7..abc7bb6411dbae 100644 --- a/torch/csrc/jit/passes/shape_analysis.cpp +++ b/torch/csrc/jit/passes/shape_analysis.cpp @@ -151,7 +151,7 @@ bool containsTensorType(const TypePtr& t) { } // for each node in the schema with type Tensor, extract the T type -// returns std::nullopt if any Tensor in the schema does not have a known +// returns c10::nullopt if any Tensor in the schema does not have a known // shape ignores non-tensor in the list of inputs std::optional> gatherTensorTypes( Node* node, @@ -160,26 +160,26 @@ std::optional> gatherTensorTypes( auto schema_opt = node->maybeSchema(); if (!schema_opt) { - return std::nullopt; + return c10::nullopt; } auto& schema = *schema_opt; auto& args = schema.arguments(); // can't handle varargs primitives because we don't know what should be a // Tensor if (schema.is_vararg()) { - return std::nullopt; + return c10::nullopt; } for (const auto i : c10::irange(args.size())) { if (args[i].type()->isSubtypeOf(*ListType::ofTensors())) { - return std::nullopt; + return c10::nullopt; } else if (args[i].type()->isSubtypeOf(*TensorType::get())) { if (auto type = node->input(i)->type()->cast()) { if (complete && !type->isComplete()) { - return std::nullopt; + return c10::nullopt; } tensor_types.push_back(type); } else { - return std::nullopt; + return c10::nullopt; } } else /* non-tensor type */ { continue; @@ -217,7 +217,7 @@ std::optional getPromotedTypeForArithmeticOp(Node* node) { auto dtt = node->inputs()[i]->type()->expect(); auto inputDtype = dtt->scalarType(); if (!dtt || !inputDtype) { - return std::nullopt; + return c10::nullopt; } if (dtt->dim() && *dtt->dim() > 0) { dimmed = unionScalarTypes(dimmed, *inputDtype); @@ -552,7 +552,7 @@ class ShapePropagator : public PropertyPropBase { tryScalarTypeFromJitType(*input_base_type); if (auto grad_index = node->schema().argumentIndexWithName("dtype")) { auto inp = toIValue(node->inputs().at(*grad_index)); - if (inp == std::nullopt) { + if (inp == c10::nullopt) { return; } else if (!inp->isNone()) { default_type = inp->toScalarType(); @@ -562,14 +562,14 @@ class ShapePropagator : public PropertyPropBase { at::Device default_device = at::kCPU; if (auto device_index = node->schema().argumentIndexWithName("device")) { auto inp = toIValue(node->inputs().at(*device_index)); - if (inp == std::nullopt) { + if (inp == c10::nullopt) { return; } else if (!inp->isNone()) { default_device = inp->toDevice(); } } node->output()->setType(TensorType::create( - default_type, default_device, dims, /*requires_grad=*/std::nullopt)); + default_type, default_device, dims, /*requires_grad=*/c10::nullopt)); } // returns whether any such values were found @@ -612,10 +612,10 @@ class ShapePropagator : public PropertyPropBase { if (typ->isSubtypeOf(*IntType::get()) || typ->isSubtypeOf(*BoolType::get())) { node->output()->setType(TensorType::create( - at::kLong, at::kCPU, 0, /*requires_grad=*/std::nullopt)); + at::kLong, at::kCPU, 0, /*requires_grad=*/c10::nullopt)); } else if (node->input()->type()->isSubtypeOf(*FloatType::get())) { node->output()->setType(TensorType::create( - at::kDouble, at::kCPU, 0, /*requires_grad=*/std::nullopt)); + at::kDouble, at::kCPU, 0, /*requires_grad=*/c10::nullopt)); } return; } @@ -750,7 +750,7 @@ class ShapePropagator : public PropertyPropBase { if (input_node->kind() == prim::ListConstruct) { return input_node->inputs().size(); } - return std::nullopt; + return c10::nullopt; } // is it ok to try to run the op @@ -778,7 +778,7 @@ class ShapePropagator : public PropertyPropBase { auto max_dims = any_type->dim(); for (auto& type : tensor_types) { if (!max_dims || !type->dim()) { - max_dims = std::nullopt; + max_dims = c10::nullopt; } else { max_dims = std::max(*max_dims, *type->dim()); } @@ -787,7 +787,7 @@ class ShapePropagator : public PropertyPropBase { t, any_type->device(), max_dims, - /*requires_grad=*/std::nullopt); + /*requires_grad=*/c10::nullopt); }; using type_vec_t = std::vector; @@ -1245,7 +1245,7 @@ class ShapePropagator : public PropertyPropBase { int64_t num_reduced_dim = 0, bool upcast_integer = false, std::optional opt_dtype = - std::nullopt) -> type_vec_t { + c10::nullopt) -> type_vec_t { if (auto type = node->input(0)->type()->cast()) { if (!type->scalarType() || !type->dim()) { return {}; @@ -1418,7 +1418,7 @@ class ShapePropagator : public PropertyPropBase { : maybe_dtype_option->toScalarType()); return {TensorType::create( - dtype, device, dim, /*requires_grad=*/std::nullopt)}; + dtype, device, dim, /*requires_grad=*/c10::nullopt)}; }; static const auto factory_like_with_ndim = [](Node* node, @@ -1448,7 +1448,7 @@ class ShapePropagator : public PropertyPropBase { } return {TensorType::create( - in_type, in_dev, dim, /*requires_grad=*/std::nullopt)}; + in_type, in_dev, dim, /*requires_grad=*/c10::nullopt)}; }; // Requirements: @@ -1748,7 +1748,7 @@ class ShapePropagator : public PropertyPropBase { if (auto dtype_index = node->schema().argumentIndexWithName("dtype")) { auto inp = toIValue(node->inputs().at(*dtype_index)); - if (inp == std::nullopt) { + if (inp == c10::nullopt) { return nullptr; } if (!inp->isNone()) { @@ -1758,7 +1758,7 @@ class ShapePropagator : public PropertyPropBase { if (auto device_index = node->schema().argumentIndexWithName("device")) { auto inp = toIValue(node->inputs().at(*device_index)); - if (inp == std::nullopt) { + if (inp == c10::nullopt) { return nullptr; } if (!inp->isNone()) { @@ -1769,7 +1769,7 @@ class ShapePropagator : public PropertyPropBase { default_type, default_device, type->dim(), - /*requires_grad=*/std::nullopt)); + /*requires_grad=*/c10::nullopt)); } } return nullptr; diff --git a/torch/csrc/jit/passes/symbolic_shape_analysis.cpp b/torch/csrc/jit/passes/symbolic_shape_analysis.cpp index 6ac9576a8e2bc5..951c093cefe55a 100644 --- a/torch/csrc/jit/passes/symbolic_shape_analysis.cpp +++ b/torch/csrc/jit/passes/symbolic_shape_analysis.cpp @@ -61,7 +61,7 @@ namespace jit { // %y.2: Tensor(5, SS(-1), (New Symbolic Shape)) = aten::view(%y, %2) // // x.view([5, y.size(0), inp]) -// will have inputs equal to [5, SS(-1), std::nullopt] +// will have inputs equal to [5, SS(-1), c10::nullopt] struct ShapeArg : public std:: @@ -73,17 +73,17 @@ struct ShapeArg } ShapeArg(int64_t int_value) { - this->first = std::nullopt; + this->first = c10::nullopt; this->second = int_value; } ShapeArg(c10::ShapeSymbol ss) { if (ss.is_static()) { - this->first = std::nullopt; + this->first = c10::nullopt; this->second = ss.value(); } else { this->first = ss; - this->second = std::nullopt; + this->second = c10::nullopt; } } @@ -97,8 +97,8 @@ struct ShapeArg private: ShapeArg() { - this->first = std::nullopt; - this->second = std::nullopt; + this->first = c10::nullopt; + this->second = c10::nullopt; } }; @@ -215,7 +215,7 @@ std::optional normIndex(int64_t index, size_t len) { if (index >= 0 && index < static_cast(len)) { return index; } else { - return std::nullopt; + return c10::nullopt; } } @@ -608,7 +608,7 @@ struct SymbolicShapeOpAnalyzer { std::optional> run( std::vector& inputs) { if (!shape_compute_graph_) { - return std::nullopt; + return c10::nullopt; } inputs_ = inputs; substituteConstantInputs(); @@ -788,7 +788,7 @@ c10::SymbolicShape combine_bounds( c10::SymbolicShape& upper_bound) { // TODO: At some point we might want to add support for dynamic dims TORCH_INTERNAL_ASSERT(lower_bound.rank() == upper_bound.rank()); - if (lower_bound.rank() == std::nullopt) { + if (lower_bound.rank() == c10::nullopt) { return c10::SymbolicShape(); } std::vector merged_shapes; @@ -837,14 +837,14 @@ struct SymbolicShapeGraphAnalyzer { return use.user->kind() == aten::cat; })) { GRAPH_DEBUG("Non cat list use ", getHeader(curr)); - return std::nullopt; + return c10::nullopt; } continue; } if (!partial_evaluated_graphs.count(curr)) { GRAPH_DEBUG("No graph ", getHeader(curr)); - return std::nullopt; + return c10::nullopt; } auto outputs = curr->outputs(); @@ -852,13 +852,13 @@ struct SymbolicShapeGraphAnalyzer { auto tt = v->type()->cast(); if (!tt) { GRAPH_DEBUG("Non tensor node", getHeader(curr)); - return std::nullopt; + return c10::nullopt; } auto symbolic_sizes = tt->symbolic_sizes(); // TODO: dont require # of dimensions of tensors set ? if (!symbolic_sizes.rank()) { GRAPH_DEBUG("No rank on output ", getHeader(curr)); - return std::nullopt; + return c10::nullopt; } } auto partial_eval_graph = partial_evaluated_graphs[curr]; @@ -1133,11 +1133,11 @@ calculateSymbolicShapesOnOp( const FunctionSchema* schema, const std::vector& inputs) { auto bounded_graphs = boundedGraphsForSchema(*schema); - auto has_shape_compute = shapeComputeGraphForSchema(*schema) != std::nullopt; - if (!has_shape_compute && bounded_graphs == std::nullopt) { + auto has_shape_compute = shapeComputeGraphForSchema(*schema) != c10::nullopt; + if (!has_shape_compute && bounded_graphs == c10::nullopt) { // Avoid doing all this work for functions that don't have a // supported schema - return std::nullopt; + return c10::nullopt; } if (auto cached_ret_vec = get_cached_shape_function(schema, inputs)) { @@ -1172,7 +1172,7 @@ calculateSymbolicShapesOnOp( cache_shape_function(schema, inputs, merged_res); return merged_res; } - return std::nullopt; + return c10::nullopt; } auto op_analyzer = SymbolicShapeOpAnalyzer(schema); diff --git a/torch/csrc/jit/passes/symbolic_shape_cache.cpp b/torch/csrc/jit/passes/symbolic_shape_cache.cpp index d01d11983a622c..4a742b3f5f6351 100644 --- a/torch/csrc/jit/passes/symbolic_shape_cache.cpp +++ b/torch/csrc/jit/passes/symbolic_shape_cache.cpp @@ -120,7 +120,7 @@ get_cached_shape_function( get_cache_key(schema, arg_vec, ss_map, /* deep_copy */ false); auto cached_ret_vec = shapeCache.Get(cache_key); if (cached_ret_vec == nullptr) { - return std::nullopt; + return c10::nullopt; } // Decanonicalize the return values auto inverse_ss_map = std::unordered_map(); @@ -148,7 +148,7 @@ void CanonicalizedSymbolicShape::init( std::unordered_map& ss_map) { auto sizes = orig_shape.sizes(); if (!sizes) { - values_ = std::nullopt; + values_ = c10::nullopt; return; } values_ = std::vector(); diff --git a/torch/csrc/jit/passes/symbolic_shape_runtime_fusion.cpp b/torch/csrc/jit/passes/symbolic_shape_runtime_fusion.cpp index 3cf23732a9ad65..9c213f2480d51d 100644 --- a/torch/csrc/jit/passes/symbolic_shape_runtime_fusion.cpp +++ b/torch/csrc/jit/passes/symbolic_shape_runtime_fusion.cpp @@ -190,7 +190,7 @@ TryGeneralizeInputDimensionsToSymbolicShapes( } auto tt = v->type()->expectRef(); if (!tt.sizes().isComplete() || !tt.strides().isComplete()) { - return std::nullopt; + return c10::nullopt; } input_striding.push_back(summarizeInputStrides(tt)); std::vector shape_vec = *tt.symbolic_sizes().sizes(); diff --git a/torch/csrc/jit/passes/tensorexpr_fuser.cpp b/torch/csrc/jit/passes/tensorexpr_fuser.cpp index 684f47f4efb93a..c9b9b974600dc4 100644 --- a/torch/csrc/jit/passes/tensorexpr_fuser.cpp +++ b/torch/csrc/jit/passes/tensorexpr_fuser.cpp @@ -782,7 +782,7 @@ class TensorExprFuser { std::optional tryMerge(Node* fusion_group, Node* to_merge) { if (!canMerge(fusion_group, to_merge)) { - return std::nullopt; + return c10::nullopt; } std::vector nodes_to_merge = {to_merge}; @@ -799,7 +799,7 @@ class TensorExprFuser { GRAPH_UPDATE("Trying to move node next to fusion group: ", getHeader(n)); if (!aliasDb_->moveBeforeTopologicallyValid(n, move_point)) { GRAPH_UPDATE("Failed to move because of AliasDB checks!"); - return std::nullopt; + return c10::nullopt; } move_point = n; } diff --git a/torch/csrc/jit/passes/utils/check_alias_annotation.cpp b/torch/csrc/jit/passes/utils/check_alias_annotation.cpp index 6082058952ce9e..4c081200715a71 100644 --- a/torch/csrc/jit/passes/utils/check_alias_annotation.cpp +++ b/torch/csrc/jit/passes/utils/check_alias_annotation.cpp @@ -196,7 +196,7 @@ std::optional toIValueProp(const Value* v) { genericList.push_back(*elem); } else { // One of the list elements isn't constant. - return std::nullopt; + return c10::nullopt; } } @@ -213,7 +213,7 @@ std::optional toIValueProp(const Value* v) { return IValue( fmap(genericList, [](const IValue& v) { return v.toTensor(); })); } else { - return std::nullopt; + return c10::nullopt; } } @@ -222,7 +222,7 @@ std::optional toIValueProp(const Value* v) { return maybe_stack->at(0); } } - return std::nullopt; + return c10::nullopt; } // batch_norm and instance_norm have incorrect annotations, because diff --git a/torch/csrc/jit/passes/utils/memory_dag.h b/torch/csrc/jit/passes/utils/memory_dag.h index 1d2292fe90c5ba..da5584f9d4bd35 100644 --- a/torch/csrc/jit/passes/utils/memory_dag.h +++ b/torch/csrc/jit/passes/utils/memory_dag.h @@ -2,12 +2,12 @@ #include #include +#include #include #include #include #include #include -#include #include #include #include diff --git a/torch/csrc/jit/passes/utils/subgraph_utils.cpp b/torch/csrc/jit/passes/utils/subgraph_utils.cpp index f4dfc4ce99c940..377621c04b6dbf 100644 --- a/torch/csrc/jit/passes/utils/subgraph_utils.cpp +++ b/torch/csrc/jit/passes/utils/subgraph_utils.cpp @@ -429,7 +429,7 @@ Node* createSingletonSubgraphAndUpdateAliasing( Symbol subgraphKind, AliasDb& db) { return executeSubgraphMergeAndUpdateAliasing( - to_merge, std::nullopt, db, [&]() { + to_merge, c10::nullopt, db, [&]() { return createSingletonSubgraph(to_merge, subgraphKind); }); } diff --git a/torch/csrc/jit/python/init.cpp b/torch/csrc/jit/python/init.cpp index 862aaba7d7dc14..1bfc6c94a707f4 100644 --- a/torch/csrc/jit/python/init.cpp +++ b/torch/csrc/jit/python/init.cpp @@ -157,7 +157,7 @@ std::optional toTypeInferredIValueOptional(py::handle input) { try { return toTypeInferredIValue(input); } catch (const c10::Error& e) { - return std::nullopt; + return c10::nullopt; } } } // anonymous namespace @@ -219,7 +219,7 @@ void initJITBindings(PyObject* module) { "_jit_shape_compute_graph_for_node", [](Node* n) -> std::optional> { if (!n->maybeSchema()) { - return std::nullopt; + return c10::nullopt; } return shapeComputeGraphForSchema(n->schema()); }) @@ -227,7 +227,7 @@ void initJITBindings(PyObject* module) { "_jit_decomposition_graph_for_node", [](Node* n) -> std::optional> { if (!n->maybeSchema()) { - return std::nullopt; + return c10::nullopt; } return GetDecomposition(n->schema()); }) @@ -1165,7 +1165,7 @@ void initJITBindings(PyObject* module) { c10::kCPU, std::vector{1}, std::vector{1}, - std::nullopt)); + c10::nullopt)); } } }) @@ -1680,7 +1680,7 @@ void initJITBindings(PyObject* module) { [op, symbol, allow_numbers_as_tensors]( c10::DispatchKey dk_, py::args args, py::kwargs kwargs) { std::optional dk = - std::make_optional(dk_); + c10::make_optional(dk_); ToIValueAllowNumbersAsTensors g(allow_numbers_as_tensors); return _get_operation_for_overload_or_packet( {op}, symbol, args, kwargs, /*is_overload*/ true, dk); diff --git a/torch/csrc/jit/python/module_python.h b/torch/csrc/jit/python/module_python.h index b1ddf6f37c6786..5c7fbbb42d6cfc 100644 --- a/torch/csrc/jit/python/module_python.h +++ b/torch/csrc/jit/python/module_python.h @@ -14,7 +14,7 @@ inline std::optional as_module(py::handle obj) { if (py::isinstance(obj, ScriptModule)) { return py::cast(obj.attr("_c")); } - return std::nullopt; + return c10::nullopt; } inline std::optional as_object(py::handle obj) { @@ -29,7 +29,7 @@ inline std::optional as_object(py::handle obj) { if (py::isinstance(obj, RecursiveScriptClass)) { return py::cast(obj.attr("_c")); } - return std::nullopt; + return c10::nullopt; } } // namespace torch::jit diff --git a/torch/csrc/jit/python/pybind_utils.cpp b/torch/csrc/jit/python/pybind_utils.cpp index 2dbcfee423ae7a..a731640223c096 100644 --- a/torch/csrc/jit/python/pybind_utils.cpp +++ b/torch/csrc/jit/python/pybind_utils.cpp @@ -754,7 +754,7 @@ std::pair, Stack> getOpWithStack( std::shared_ptr op = operations.at(0); // Create a stack full of the arguments and keyword arguments. stack = createStackForSchema( - op->schema(), std::move(args), kwargs, std::nullopt); + op->schema(), std::move(args), kwargs, c10::nullopt); return std::make_pair(std::move(op), std::move(stack)); } else { @@ -762,7 +762,7 @@ std::pair, Stack> getOpWithStack( std::shared_ptr found_op = nullptr; for (const auto& op : operations) { try { - stack = createStackForSchema(op->schema(), args, kwargs, std::nullopt); + stack = createStackForSchema(op->schema(), args, kwargs, c10::nullopt); found_op = op; break; } catch (schema_match_error& error) { diff --git a/torch/csrc/jit/python/pybind_utils.h b/torch/csrc/jit/python/pybind_utils.h index cd8a7335167d4a..23fda5b0d784ec 100644 --- a/torch/csrc/jit/python/pybind_utils.h +++ b/torch/csrc/jit/python/pybind_utils.h @@ -36,8 +36,8 @@ #include #endif #include +#include #include -#include #include #include @@ -62,7 +62,7 @@ void clear_registered_instances(void* ptr); TORCH_PYTHON_API IValue toIValue( py::handle obj, const TypePtr& type, - std::optional N = std::nullopt); + std::optional N = c10::nullopt); TORCH_PYTHON_API py::object toPyObject(IValue ivalue); @@ -111,7 +111,7 @@ struct VISIBILITY_HIDDEN PythonFutureWrapper explicit PythonFutureWrapper( c10::intrusive_ptr fut, - std::optional unwrap_func = std::nullopt) + std::optional unwrap_func = c10::nullopt) : fut(std::move(fut)), unwrap_func(std::move(unwrap_func)) {} explicit PythonFutureWrapper(const PythonFutureWrapper&) = delete; @@ -1205,7 +1205,7 @@ inline std::optional maybeTorchFunctionDispatch( /*module_name=*/qualname.prefix().c_str())); } - return std::nullopt; + return c10::nullopt; } inline py::object invokeScriptFunctionFromPython( @@ -1219,7 +1219,7 @@ inline py::object invokeScriptFunctionFromPython( callee, args, kwargs, - /*self=*/std::nullopt, + /*self=*/c10::nullopt, [&](Graph& graph, const MatchedSchema& match) { return graph.insertFunctionCall(&callee, match); }); @@ -1255,7 +1255,7 @@ TORCH_PYTHON_API py::object invokeOperatorFromPython( const std::vector>& operations, py::args args, const py::kwargs& kwargs, - std::optional dk = std::nullopt); + std::optional dk = c10::nullopt); TORCH_PYTHON_API std::optional _maybe_handle_torch_function( const std::string& ns, @@ -1276,6 +1276,6 @@ TORCH_PYTHON_API py::object _get_operation_for_overload_or_packet( py::args args, const py::kwargs& kwargs, bool is_overload, - std::optional dk = std::nullopt); + std::optional dk = c10::nullopt); } // namespace torch::jit diff --git a/torch/csrc/jit/python/python_ir.cpp b/torch/csrc/jit/python/python_ir.cpp index c80208b9d00df8..79957999f543dc 100644 --- a/torch/csrc/jit/python/python_ir.cpp +++ b/torch/csrc/jit/python/python_ir.cpp @@ -138,17 +138,17 @@ std::optional ConcretePythonOp::autogradFunction() const { auto r = py::getattr(obj, "__self__", py::none()); if (r.is_none()) - return std::nullopt; + return c10::nullopt; auto apply = py::getattr(r, "apply", py::none()); if (apply.is_none()) - return std::nullopt; + return c10::nullopt; auto c = PyObject_RichCompareBool(apply.ptr(), obj.ptr(), Py_NE); if (PyErr_Occurred()) throw py::error_already_set(); if (c) - return std::nullopt; + return c10::nullopt; return THPObjectPtr(r.release().ptr()); } diff --git a/torch/csrc/jit/python/python_ivalue.h b/torch/csrc/jit/python/python_ivalue.h index 6d0bf1afc3b06f..4cdc8e430b9a81 100644 --- a/torch/csrc/jit/python/python_ivalue.h +++ b/torch/csrc/jit/python/python_ivalue.h @@ -31,7 +31,7 @@ struct C10_EXPORT ConcretePyObjectHolder final : PyObjectHolder { return torch::jit::tryToInferType(py_obj_); } - IValue toIValue(const TypePtr& type, std::optional N = std::nullopt) + IValue toIValue(const TypePtr& type, std::optional N = c10::nullopt) override { pybind11::gil_scoped_acquire ag; return torch::jit::toIValue(py_obj_, type, N); diff --git a/torch/csrc/jit/python/python_list.h b/torch/csrc/jit/python/python_list.h index f73cb5048529bd..b5bb88b3aeb20d 100644 --- a/torch/csrc/jit/python/python_list.h +++ b/torch/csrc/jit/python/python_list.h @@ -4,10 +4,10 @@ #include #include #include +#include #include #include #include -#include #include namespace torch::jit { @@ -175,7 +175,7 @@ class ScriptList final { // Remove and return the element at the specified index from the list. If no // index is passed, the last element is removed and returned. - IValue pop(std::optional idx = std::nullopt) { + IValue pop(std::optional idx = c10::nullopt) { IValue ret; if (idx) { diff --git a/torch/csrc/jit/python/python_sugared_value.cpp b/torch/csrc/jit/python/python_sugared_value.cpp index c5d48f5cbe7474..d6f014759c05e9 100644 --- a/torch/csrc/jit/python/python_sugared_value.cpp +++ b/torch/csrc/jit/python/python_sugared_value.cpp @@ -28,7 +28,7 @@ std::optional as_function(const py::object& obj) { if (py::isinstance(obj)) { return py::cast(obj); } - return std::nullopt; + return c10::nullopt; } FunctionSchema PythonValue::getSchema( @@ -66,8 +66,8 @@ FunctionSchema PythonValue::getSchema( args.emplace_back( /*name=*/*names_it, /*type=*/TensorType::get(), - /*N=*/std::nullopt, - /*default_value=*/std::nullopt, + /*N=*/c10::nullopt, + /*default_value=*/c10::nullopt, /*kwarg_only=*/false); } @@ -95,8 +95,8 @@ FunctionSchema PythonValue::getSchema( args.emplace_back( /*name=*/*names_it, /*type=*/std::move(*types_it), - /*N=*/std::nullopt, - /*default_value=*/std::nullopt, + /*N=*/c10::nullopt, + /*default_value=*/c10::nullopt, /*kwarg_only=*/false); } rets.push_back(Argument("0", std::move(ret_type), {}, {}, false)); @@ -240,10 +240,10 @@ std::shared_ptr CUDAPythonModuleValue::attr( // these APIs. if (field == "current_device" || field == "set_device") { return std::make_shared( - Symbol::cuda("_" + field), std::nullopt); + Symbol::cuda("_" + field), c10::nullopt); } else { return std::make_shared( - Symbol::cuda(field), std::nullopt); + Symbol::cuda(field), c10::nullopt); } } @@ -673,7 +673,7 @@ std::shared_ptr ModuleValue::tryGetAttr( if (const auto fnAttr = concreteType_->findFunctionAttribute(field)) { return std::make_shared(*fnAttr); } else if (const auto builtin = concreteType_->findBuiltinFunction(field)) { - return std::make_shared(*builtin, /*self=*/std::nullopt); + return std::make_shared(*builtin, /*self=*/c10::nullopt); } // 5. Check if it's an attribute of the original Python class that this @@ -1263,7 +1263,7 @@ std::shared_ptr toSugaredValue( py::module::import("torch.jit._builtins").attr("_find_builtin")(obj); if (!builtin_name.is_none()) { return std::make_shared( - Symbol::fromQualString(py::str(builtin_name)), std::nullopt); + Symbol::fromQualString(py::str(builtin_name)), c10::nullopt); } if (py::cast(py::module::import("torch._jit_internal") diff --git a/torch/csrc/jit/python/python_sugared_value.h b/torch/csrc/jit/python/python_sugared_value.h index 508d95c8c538d0..cb397796c9f55e 100644 --- a/torch/csrc/jit/python/python_sugared_value.h +++ b/torch/csrc/jit/python/python_sugared_value.h @@ -32,7 +32,7 @@ std::optional as_function(const py::object& obj); struct VISIBILITY_HIDDEN PythonValue : public SugaredValue { PythonValue( py::object the_self, - std::optional rcb = std::nullopt, + std::optional rcb = c10::nullopt, Value* module_self = nullptr) : self(std::move(the_self)), rcb(std::move(rcb)), diff --git a/torch/csrc/jit/python/python_tree_views.cpp b/torch/csrc/jit/python/python_tree_views.cpp index 0cd93887471e31..50d18b908107ee 100644 --- a/torch/csrc/jit/python/python_tree_views.cpp +++ b/torch/csrc/jit/python/python_tree_views.cpp @@ -14,7 +14,7 @@ namespace torch::jit { std::optional maybeConvertToString(const py::object& obj) { if (obj.is_none()) { - return std::nullopt; + return c10::nullopt; } std::stringstream ss; ss << py::str(obj); @@ -180,7 +180,7 @@ void initTreeViewBindings(PyObject* module) { return std::optional(property.setter().get().name()); } - return std::optional(std::nullopt); + return std::optional(c10::nullopt); }); py::class_(m, "ClassDef") diff --git a/torch/csrc/jit/python/script_init.cpp b/torch/csrc/jit/python/script_init.cpp index 565f0b16363855..c46762a88615bb 100644 --- a/torch/csrc/jit/python/script_init.cpp +++ b/torch/csrc/jit/python/script_init.cpp @@ -220,7 +220,7 @@ std::optional tryCalculateDefaultParam( return toIValue(def_value, arg.type()); } } catch (...) { - return std::nullopt; + return c10::nullopt; } } @@ -702,13 +702,13 @@ void pyCompilationUnitDefine( const ResolutionCallback* rcb, const uint32_t _frames_up) { if (rcb && *rcb) { - cu.define(std::nullopt, src, pythonResolver(*rcb), nullptr); + cu.define(c10::nullopt, src, pythonResolver(*rcb), nullptr); } else { py::object py_default_rcb = py::module::import("torch._jit_internal") .attr("createResolutionCallbackFromFrame")(_frames_up); auto default_rcb = py_default_rcb.cast(); - cu.define(std::nullopt, src, pythonResolver(default_rcb), nullptr); + cu.define(c10::nullopt, src, pythonResolver(default_rcb), nullptr); } } @@ -1315,7 +1315,7 @@ void initJitScriptBindings(PyObject* module) { "find_method", [](mobile::Module& m, const std::string& method_name) { auto method = m.find_method(method_name); - return method != std::nullopt; + return method != c10::nullopt; }, py::arg("method_name")) .def( @@ -1372,7 +1372,7 @@ void initJitScriptBindings(PyObject* module) { return std::optional( StrongFunctionPtr(std::move(self), fn)); } else { - return std::optional(std::nullopt); + return std::optional(c10::nullopt); } }) .def( @@ -2124,7 +2124,7 @@ void initJitScriptBindings(PyObject* module) { m.def( "_get_graph_executor_optimize", - [](std::optional new_setting = std::nullopt) { + [](std::optional new_setting = c10::nullopt) { bool old_value = getGraphExecutorOptimize(); if (new_setting) { setGraphExecutorOptimize(*new_setting); diff --git a/torch/csrc/jit/runtime/autodiff.cpp b/torch/csrc/jit/runtime/autodiff.cpp index 047a35e417fff8..3987521f658f97 100644 --- a/torch/csrc/jit/runtime/autodiff.cpp +++ b/torch/csrc/jit/runtime/autodiff.cpp @@ -134,11 +134,11 @@ static std::optional> build_script_grad( auto graph = node->owningGraph(); auto maybe_schema = node->maybeSchema(); if (!maybe_schema) { - return std::nullopt; + return c10::nullopt; } auto compiled_graphs = gradientInfoForSchema(*maybe_schema); if (!compiled_graphs) { - return std::nullopt; + return c10::nullopt; } // Use forward graph to replace node in grad_desc.f value_list new_outputs; diff --git a/torch/csrc/jit/runtime/decomposition_registry.cpp b/torch/csrc/jit/runtime/decomposition_registry.cpp index 989a48bf06ab22..de205ed834c3bc 100644 --- a/torch/csrc/jit/runtime/decomposition_registry.cpp +++ b/torch/csrc/jit/runtime/decomposition_registry.cpp @@ -63,7 +63,7 @@ void loadDecompositionFunctions() { [&](const std::string& name) -> std::shared_ptr { return src; }, 1); compilation_unit->define( - std::nullopt, GetSerializedDecompositions(), resolver, nullptr); + c10::nullopt, GetSerializedDecompositions(), resolver, nullptr); loadModule(*compilation_unit); } @@ -117,7 +117,7 @@ std::optional> GetDecomposition( } GRAPH_DEBUG("Could not find schema: ", schema); - return std::nullopt; + return c10::nullopt; } std::optional GetDecompositionFunction( @@ -127,7 +127,7 @@ std::optional GetDecompositionFunction( GRAPH_DEBUG("Trying to find schema: ", schema); if (cache_it == schema_to_function.end()) { GRAPH_DEBUG("Could not find schema: ", schema); - return std::nullopt; + return c10::nullopt; } auto& func = toGraphFunction(*cache_it->second); // Simple Executor: diff --git a/torch/csrc/jit/runtime/graph_executor.h b/torch/csrc/jit/runtime/graph_executor.h index 971e45e818ca6d..fce8d4a02e66c1 100644 --- a/torch/csrc/jit/runtime/graph_executor.h +++ b/torch/csrc/jit/runtime/graph_executor.h @@ -87,7 +87,7 @@ struct TORCH_API GraphExecutor { // current global fusion strategy settings. const ExecutionPlan& getPlanFor( Stack& inputs, - std::optional remaining_bailout_depth = std::nullopt); + std::optional remaining_bailout_depth = c10::nullopt); GraphExecutorState getDebugState(); void debugFlushCompilationCache(); diff --git a/torch/csrc/jit/runtime/graph_executor_impl.h b/torch/csrc/jit/runtime/graph_executor_impl.h index 70069ac1907b0f..22a563f00be289 100644 --- a/torch/csrc/jit/runtime/graph_executor_impl.h +++ b/torch/csrc/jit/runtime/graph_executor_impl.h @@ -78,7 +78,7 @@ struct GraphExecutorImplBase { virtual const ExecutionPlan& getPlanFor( Stack& stack, - std::optional remaining_bailout_depth = std::nullopt) = 0; + std::optional remaining_bailout_depth = c10::nullopt) = 0; virtual GraphExecutorState getDebugState() = 0; virtual ~GraphExecutorImplBase() = default; diff --git a/torch/csrc/jit/runtime/interpreter.cpp b/torch/csrc/jit/runtime/interpreter.cpp index 0f6eb900e361df..18231173dd70e0 100644 --- a/torch/csrc/jit/runtime/interpreter.cpp +++ b/torch/csrc/jit/runtime/interpreter.cpp @@ -169,7 +169,7 @@ struct InterpreterStateImpl : c10::intrusive_ptr_target { } void enterFrame(const Code& code, size_t base_pointer) { - frames.emplace_back(Frame{code.pImpl, 0, base_pointer, std::nullopt}); + frames.emplace_back(Frame{code.pImpl, 0, base_pointer, c10::nullopt}); registers.resize(registers.size() + code.pImpl->register_size_); } @@ -181,7 +181,7 @@ struct InterpreterStateImpl : c10::intrusive_ptr_target { void callFunction( Function& f, Stack& stack, - std::optional bailOut = std::nullopt, + std::optional bailOut = c10::nullopt, bool next = true) { bool newFrame = f.call(stack, bailOut, [&](const Code& code) { enterFrame(code, stack.size() - code.num_inputs()); @@ -1244,7 +1244,7 @@ void InterpreterContinuation::operator()() { auto prev_dist_id = DistAutogradContainer::currentContextId(); DistAutogradContainer::forceCurrentContextId(dist_autograd_context_id_); #endif - if (tls_state_ != std::nullopt) { + if (tls_state_ != c10::nullopt) { at::ThreadLocalStateGuard g(*tls_state_); state.runAsync(stack); } else { diff --git a/torch/csrc/jit/runtime/interpreter.h b/torch/csrc/jit/runtime/interpreter.h index ffafd3ab096a9b..a28b1eb93526b5 100644 --- a/torch/csrc/jit/runtime/interpreter.h +++ b/torch/csrc/jit/runtime/interpreter.h @@ -1,6 +1,6 @@ #pragma once +#include #include -#include #include #include @@ -124,7 +124,7 @@ struct InterpreterContinuation { InterpreterState state_, Stack stack_, int64_t dist_autograd_context_id = 0, - std::optional tls_state = std::nullopt) + std::optional tls_state = c10::nullopt) : state(std::move(state_)), stack(std::move(stack_)), tls_state_(std::move(tls_state)) @@ -140,7 +140,7 @@ struct InterpreterContinuation { private: InterpreterState state; Stack stack; - std::optional tls_state_ = std::nullopt; + std::optional tls_state_ = c10::nullopt; #ifdef USE_DISTRIBUTED int64_t dist_autograd_context_id_; #endif diff --git a/torch/csrc/jit/runtime/jit_exception.h b/torch/csrc/jit/runtime/jit_exception.h index cb4f572a8bd3c0..34c3ebd6fca849 100644 --- a/torch/csrc/jit/runtime/jit_exception.h +++ b/torch/csrc/jit/runtime/jit_exception.h @@ -2,8 +2,8 @@ #include +#include #include -#include #include namespace torch::jit { @@ -11,8 +11,8 @@ namespace torch::jit { struct TORCH_API JITException : public std::runtime_error { explicit JITException( const std::string& msg, - std::optional python_class_name = std::nullopt, - std::optional original_msg = std::nullopt); + std::optional python_class_name = c10::nullopt, + std::optional original_msg = c10::nullopt); std::optional getPythonClassName() const { return python_class_name_; diff --git a/torch/csrc/jit/runtime/operator.h b/torch/csrc/jit/runtime/operator.h index 2e609f18ecc074..dbc2638457c056 100644 --- a/torch/csrc/jit/runtime/operator.h +++ b/torch/csrc/jit/runtime/operator.h @@ -322,7 +322,7 @@ std::optional OperatorGenerator( torch::detail::SelectiveStr schema_str, Func&& op, AliasAnalysisKind alias_analysis) { - return std::nullopt; + return c10::nullopt; } template diff --git a/torch/csrc/jit/runtime/profiling_graph_executor_impl.cpp b/torch/csrc/jit/runtime/profiling_graph_executor_impl.cpp index 54ec8e8441fa7e..48c7a1959ab220 100644 --- a/torch/csrc/jit/runtime/profiling_graph_executor_impl.cpp +++ b/torch/csrc/jit/runtime/profiling_graph_executor_impl.cpp @@ -1,5 +1,6 @@ #include +#include #include #include #include @@ -36,7 +37,6 @@ #include #include #include -#include C10_DEFINE_bool( torch_jit_enable_new_executor, @@ -118,11 +118,11 @@ static FusionStrategy getInitialStrategy() { } // defer initial value so that we can load in gflags -static std::optional fusion_strategy = std::nullopt; +static std::optional fusion_strategy = c10::nullopt; FusionStrategy getFusionStrategy() { std::lock_guard guard(fusion_strategy_lock); - if (fusion_strategy == std::nullopt) { + if (fusion_strategy == c10::nullopt) { fusion_strategy = getInitialStrategy(); } return *fusion_strategy; @@ -130,7 +130,7 @@ FusionStrategy getFusionStrategy() { FusionStrategy setFusionStrategy(FusionStrategy& strategy) { std::lock_guard guard(fusion_strategy_lock); - if (fusion_strategy == std::nullopt) { + if (fusion_strategy == c10::nullopt) { fusion_strategy = getInitialStrategy(); } FusionStrategy old_strategy = *fusion_strategy; @@ -320,7 +320,7 @@ static bool guardDifferentiableGraph(Node* dnode) { // we inline the differentiable graph as a fallback // ideally we would set this up for re-profiling UpdateDifferentiableGraphRequiresGrad( - dnode->g(attr::Subgraph), std::nullopt); + dnode->g(attr::Subgraph), c10::nullopt); SubgraphUtils::unmergeSubgraph(dnode); return false; } diff --git a/torch/csrc/jit/runtime/register_ops_utils.h b/torch/csrc/jit/runtime/register_ops_utils.h index ebdc5ba205cd56..3386bc3e4a4918 100644 --- a/torch/csrc/jit/runtime/register_ops_utils.h +++ b/torch/csrc/jit/runtime/register_ops_utils.h @@ -878,6 +878,6 @@ struct OperatorGeneratorArgs { TORCH_API at::Generator make_generator_for_device( c10::Device device, - std::optional seed = std::nullopt); + std::optional seed = c10::nullopt); } // namespace torch::jit diff --git a/torch/csrc/jit/runtime/register_prim_ops.cpp b/torch/csrc/jit/runtime/register_prim_ops.cpp index f6eccede28bab1..bb9c08465c0ae9 100644 --- a/torch/csrc/jit/runtime/register_prim_ops.cpp +++ b/torch/csrc/jit/runtime/register_prim_ops.cpp @@ -1,5 +1,6 @@ #include #include +#include #include #include #include @@ -7,7 +8,6 @@ #include #include #include -#include #include #include @@ -1807,7 +1807,7 @@ static const std::vector stringOpGenArgs{ std::string::size_type prev_pos = 0; std::string::size_type pos = 0; c10::List splits; - if (ivalue == std::nullopt) { + if (ivalue == c10::nullopt) { // if separator is not specified, // a different splitting algorithm is applied as Python splits = splitNoneSeparator(string); @@ -2463,8 +2463,8 @@ static const std::vector opGenArgs1{ // NOLINTNEXTLINE(cppcoreguidelines-init-variables) bool copy; pop(stack, self, non_blocking, copy); - std::optional device = std::nullopt; - std::optional scalarType = std::nullopt; + std::optional device = c10::nullopt; + std::optional scalarType = c10::nullopt; push( stack, to_dispatch(self, device, scalarType, non_blocking, copy)); }, diff --git a/torch/csrc/jit/runtime/register_prim_ops_fulljit.cpp b/torch/csrc/jit/runtime/register_prim_ops_fulljit.cpp index 035a5d35c4630f..4359b852b6a38a 100644 --- a/torch/csrc/jit/runtime/register_prim_ops_fulljit.cpp +++ b/torch/csrc/jit/runtime/register_prim_ops_fulljit.cpp @@ -430,13 +430,13 @@ at::Tensor interpolate( std::optional align_corners, std::optional recompute_scale_factor) { if ((mode == "nearest" || mode == "area")) { - if (align_corners != std::nullopt) { + if (align_corners != c10::nullopt) { throw std::runtime_error( "align_corners option can only be set with the " "interpolating modes: linear | bilinear | bicubic | trilinear"); } } else { - if (align_corners == std::nullopt) { + if (align_corners == c10::nullopt) { TORCH_WARN( "Default upsampling behavior when mode=", mode, @@ -451,7 +451,7 @@ at::Tensor interpolate( double scale_factors_2 = -1.0; double scale_factors_3 = -1.0; - if (!scale_factors.isNone() && recompute_scale_factor == std::nullopt) { + if (!scale_factors.isNone() && recompute_scale_factor == c10::nullopt) { recompute_scale_factor = true; bool warn_recompute_scale_factor = false; @@ -510,7 +510,7 @@ at::Tensor interpolate( return at::upsample_nearest1d( input, _output_size(input, 1, size, scale_factors), - std::make_optional(scale_factors_1)); + c10::make_optional(scale_factors_1)); if (input_dim == dim2d && mode == "nearest") return at::upsample_nearest2d( input, @@ -538,7 +538,7 @@ at::Tensor interpolate( input, _output_size(input, 1, size, scale_factors), *align_corners, - std::make_optional(scale_factors_1)); + c10::make_optional(scale_factors_1)); if (input_dim == dim1d && mode == "bilinear") throw std::runtime_error("Got 3D input, but bilinear mode needs 4D input"); if (input_dim == dim1d && mode == "bicubic") @@ -646,7 +646,7 @@ void upsample_nearest_op(Stack& stack) { pop(stack, input, size, scale_factor_int); IValue scale_factor_double = convert_scale_factor_to_double(scale_factor_int); at::Tensor res = interpolate( - input, size, scale_factor_double, "nearest", std::nullopt, std::nullopt); + input, size, scale_factor_double, "nearest", c10::nullopt, c10::nullopt); push(stack, std::move(res)); } @@ -664,7 +664,7 @@ void upsample_op(Stack& stack) { scale_factor_double, mode, align_corners.toOptional(), - std::nullopt); + c10::nullopt); push(stack, std::move(res)); } @@ -675,7 +675,7 @@ void upsample_bilinear_op(Stack& stack) { pop(stack, input, size, scale_factor_int); IValue scale_factor_double = convert_scale_factor_to_double(scale_factor_int); at::Tensor res = interpolate( - input, size, scale_factor_double, "bilinear", true, std::nullopt); + input, size, scale_factor_double, "bilinear", true, c10::nullopt); push(stack, std::move(res)); } diff --git a/torch/csrc/jit/runtime/register_special_ops.cpp b/torch/csrc/jit/runtime/register_special_ops.cpp index 63fdee6de8042c..5b8c70c404ae98 100644 --- a/torch/csrc/jit/runtime/register_special_ops.cpp +++ b/torch/csrc/jit/runtime/register_special_ops.cpp @@ -301,9 +301,9 @@ RegisterOperators reg({ at::native::scalar_tensor( scalar_val, typeMetaToScalarType(c10::get_default_dtype()), - std::nullopt /* layout */, + c10::nullopt /* layout */, at::kCPU, - std::nullopt /* pin_memory*/)) + c10::nullopt /* pin_memory*/)) DEFINE_TORCH_TENSOR_OP( int, int64_t, @@ -314,9 +314,9 @@ RegisterOperators reg({ at::native::scalar_tensor( scalar_val, typeMetaToScalarType(c10::get_default_complex_dtype()), - std::nullopt /* layout */, + c10::nullopt /* layout */, at::kCPU, - std::nullopt /* pin_memory */)) + c10::nullopt /* pin_memory */)) // reference python implementation: internal_new_from_data in // tensor_new.cpp diff --git a/torch/csrc/jit/runtime/simple_graph_executor_impl.cpp b/torch/csrc/jit/runtime/simple_graph_executor_impl.cpp index fd908b48ee043f..c1dbbddc6d337a 100644 --- a/torch/csrc/jit/runtime/simple_graph_executor_impl.cpp +++ b/torch/csrc/jit/runtime/simple_graph_executor_impl.cpp @@ -1,8 +1,8 @@ #include +#include #include #include -#include namespace torch::jit { diff --git a/torch/csrc/jit/runtime/static/fusion.cpp b/torch/csrc/jit/runtime/static/fusion.cpp index 86925200b7f463..ffac37efc9b765 100644 --- a/torch/csrc/jit/runtime/static/fusion.cpp +++ b/torch/csrc/jit/runtime/static/fusion.cpp @@ -173,7 +173,7 @@ static std::optional tryMerge( Node* to_merge, AliasDb* aliasDb) { if (!canMerge(fusion_group, to_merge, aliasDb)) { - return std::nullopt; + return c10::nullopt; } std::vector nodes_to_merge = {to_merge}; @@ -190,7 +190,7 @@ static std::optional tryMerge( GRAPH_UPDATE("Trying to move node next to fusion group: ", getHeader(n)); if (!aliasDb->moveBeforeTopologicallyValid(n, move_point)) { GRAPH_UPDATE("Failed to move because of AliasDb checks!"); - return std::nullopt; + return c10::nullopt; } move_point = n; } diff --git a/torch/csrc/jit/runtime/static/impl.cpp b/torch/csrc/jit/runtime/static/impl.cpp index 0c989efcad7577..9dc31446d1e1c7 100644 --- a/torch/csrc/jit/runtime/static/impl.cpp +++ b/torch/csrc/jit/runtime/static/impl.cpp @@ -320,7 +320,7 @@ std::pair, std::optional> PrepareForStaticModule( const StaticModuleOptions& opts, std::vector sample_inputs) { PrepareGraphForStaticModule(graph, opts, std::move(sample_inputs)); - return std::make_pair(graph, std::nullopt); + return std::make_pair(graph, c10::nullopt); } } // namespace @@ -573,7 +573,7 @@ StaticModule::StaticModule( const auto num_schema_args = schema_->arguments().size(); DCHECK(num_schema_args > 0); if (removeSelfFromGraphInput(graph_)) { - module_ = std::nullopt; + module_ = c10::nullopt; num_inputs_ = num_schema_args - 1; } } @@ -1251,7 +1251,7 @@ bool BlockRunner::fast_check_and_correct_overlap_with( auto& tensor = tensor_ival.toTensor(); if (planner_->overlapWithInternalBuffer(tensor.data_ptr())) { DLOG(INFO) << "Detected alias for node: " << PrintNode(n.node()); - tensor_ival = at::native::clone(tensor, std::nullopt); + tensor_ival = at::native::clone(tensor, c10::nullopt); n.set_outputs_memory_overlap_detected(); return true; } @@ -2218,7 +2218,7 @@ bool ProcessedNode::check_and_correct_overlap_with( auto& tensor = output_ival.toTensor(); if (!checkNoMemoryOverlap(input, tensor)) { DLOG(INFO) << "Detected alias for node: " << PrintNode(node()); - output_ival = at::native::clone(tensor, std::nullopt); + output_ival = at::native::clone(tensor, c10::nullopt); set_outputs_memory_overlap_detected(); return true; } diff --git a/torch/csrc/jit/runtime/static/ops.cpp b/torch/csrc/jit/runtime/static/ops.cpp index 35a74c0bac089b..b1b8a081c4ce63 100644 --- a/torch/csrc/jit/runtime/static/ops.cpp +++ b/torch/csrc/jit/runtime/static/ops.cpp @@ -75,7 +75,7 @@ static void repeat_out( } // return an empty tensor if one of the repeat dimensions is zero - at::native::resize_(result, target_size, std::nullopt); + at::native::resize_(result, target_size, c10::nullopt); if (zero_tensor) { return; } @@ -101,7 +101,7 @@ at::Tensor& reshape_copy_out( const auto& shape = infer_size ? at::infer_size_dv(proposed_shape, self.numel()) : proposed_shape; - at::native::resize_(out, shape, std::nullopt); + at::native::resize_(out, shape, c10::nullopt); auto self_contig = self.expect_contiguous(); @@ -214,7 +214,7 @@ at::Tensor& to_copy_out( at::native::resize_impl_cpu_( out.unsafeGetTensorImpl(), self.sizes(), self.strides()); } else { - at::native::resize_(out, self.sizes(), std::nullopt); + at::native::resize_(out, self.sizes(), c10::nullopt); } auto is_unsupported_dtype = [](ScalarType t) { #define TORCH_OPS_UNSUPPORTED_TYPE(_, type) \ @@ -233,7 +233,7 @@ at::Tensor& to_copy_out( // expensive. if (self.is_contiguous() && !non_blocking && // Did the user request us to make a copy that isn't contiguous? - (memory_format == std::nullopt || + (memory_format == c10::nullopt || memory_format == c10::MemoryFormat::Preserve || memory_format == c10::MemoryFormat::Contiguous) && // CopyKernel.cpp handles this case specially, so let's not mess @@ -303,7 +303,7 @@ static Tensor& c2_argmin_out( out_dims.push_back(in_dims[i]); next_size *= in_dims[i]; } - at::native::resize_(output, out_dims, std::nullopt); + at::native::resize_(output, out_dims, c10::nullopt); const auto n = in_dims[dim_]; @@ -370,7 +370,7 @@ static at::Tensor& dequantize_copy_out(Tensor& out, const Tensor& self) { if (C10_UNLIKELY(!self.is_quantized())) { // fallback to dequantize_cpu equivalent case: make sure out is at::kFloat DCHECK(out.scalar_type() == kFloat); - return at::native::to_copy_out(out, self, false, false, std::nullopt); + return at::native::to_copy_out(out, self, false, false, c10::nullopt); } return get_qtensorimpl(self)->quantizer()->dequantize_out(out, self); } @@ -658,11 +658,11 @@ REGISTER_OPERATOR_FUNCTOR( out_t, at::cpu::clamp(in0_t, clamp_min, clamp_max), in3_s, - std::nullopt, - std::nullopt); + c10::nullopt, + c10::nullopt); return; } - at::native::resize_(out_t, in0_t.sizes(), std::nullopt); + at::native::resize_(out_t, in0_t.sizes(), c10::nullopt); auto output_size = in0_t.numel(); @@ -700,7 +700,7 @@ REGISTER_OPERATOR_FUNCTOR(aten::clamp, aten_clamp, [](Node* n) -> SROperator { at::cpu::clamp_out(out_t, in0_t, in1_s, in2_s); return; } - at::native::resize_(out_t, in0_t.sizes(), std::nullopt); + at::native::resize_(out_t, in0_t.sizes(), c10::nullopt); auto output_size = in0_t.numel(); auto min = in1_s.has_value() ? in1_s->toFloat() : -std::numeric_limits::infinity(); @@ -830,7 +830,7 @@ void varStackFastOut( ? std::array{num_inputs, 1} : std::array{1, num_inputs}; - at::native::resize_(out, output_size, std::nullopt); + at::native::resize_(out, output_size, c10::nullopt); AT_DISPATCH_ALL_TYPES(out.scalar_type(), "varStackFastOut", [&]() { auto* out_data = out.mutable_data_ptr(); @@ -952,7 +952,7 @@ REGISTER_OPERATOR_FUNCTOR(aten::relu, aten_relu, [](Node* n) -> SROperator { at::cpu::threshold_out(out_t, in0_t, 0, 0); return; } - at::native::resize_(out_t, in0_t.sizes(), std::nullopt); + at::native::resize_(out_t, in0_t.sizes(), c10::nullopt); int64_t nn = in0_t.numel(); te->call({out_t.data_ptr(), in0_t.data_ptr(), &nn}); }; @@ -975,7 +975,7 @@ REGISTER_OPERATOR_FUNCTOR(aten::tanh, aten_tanh, [](Node* n) -> SROperator { at::cpu::tanh_out(out_t, in0_t); return; } - at::native::resize_(out_t, in0_t.sizes(), std::nullopt); + at::native::resize_(out_t, in0_t.sizes(), c10::nullopt); int64_t nn = in0_t.numel(); te->call({out_t.data_ptr(), in0_t.data_ptr(), &nn}); }; @@ -1036,7 +1036,7 @@ REGISTER_OPERATOR_FUNCTOR( at::cpu::sigmoid_out(out_t, in0_t); return; } - at::native::resize_(out_t, in0_t.sizes(), std::nullopt); + at::native::resize_(out_t, in0_t.sizes(), c10::nullopt); int64_t nn = in0_t.numel(); te->call({out_t.data_ptr(), in0_t.data_ptr(), &nn}); }; @@ -1048,12 +1048,12 @@ REGISTER_OPERATOR_FUNCTOR(aten::logit, aten_logit, [](Node* n) -> SROperator { LogAndDumpSchema(n); return nullptr; } - std::optional clamp = std::nullopt; + std::optional clamp = c10::nullopt; if (n->inputs()[1]->node()->kind() == prim::Constant) { auto clamp_d = toIValue(n->inputs()[1])->toOptional(); clamp = clamp_d - ? std::make_optional(static_cast(clamp_d.value())) - : std::nullopt; + ? c10::make_optional(static_cast(clamp_d.value())) + : c10::nullopt; } auto te = clamp ? createLogit() : nullptr; float clamp_value = clamp ? *clamp : 0.0f; @@ -1070,7 +1070,7 @@ REGISTER_OPERATOR_FUNCTOR(aten::logit, aten_logit, [](Node* n) -> SROperator { at::native::logit_out(in0_t, in1_d, out_t); return; } - at::native::resize_(out_t, in0_t.sizes(), std::nullopt); + at::native::resize_(out_t, in0_t.sizes(), c10::nullopt); int64_t nn = in0_t.numel(); float c = clamp_value; te->call({out_t.data_ptr(), in0_t.data_ptr(), &nn, &c}); @@ -1454,7 +1454,7 @@ C10_ALWAYS_INLINE void to_copy_functor_impl( if (memory_format == c10::MemoryFormat::Preserve) { if (self.is_non_overlapping_and_dense()) { - memory_format = std::nullopt; + memory_format = c10::nullopt; copy_strides = true; } else { memory_format = self.suggest_memory_format(); @@ -1485,7 +1485,7 @@ C10_ALWAYS_INLINE void to_copy_functor_impl( args->dtype, args->layout, self.device(), - std::nullopt, + c10::nullopt, memory_format); } else { if (has_memory_format) { @@ -1905,7 +1905,7 @@ REGISTER_OPERATOR_FUNCTOR(aten::div, aten_div, [](Node* n) -> SROperator { return [te = createDiv()](ProcessedNode* p_node) { const auto& in0_t = p_node->Input(0).toTensor(); - std::optional rounding_mode = std::nullopt; + std::optional rounding_mode = c10::nullopt; if (p_node->num_inputs() > 2) { rounding_mode = p_node->Input(2).toOptional(); } @@ -2112,14 +2112,14 @@ REGISTER_OPERATOR_FUNCTOR(aten::layer_norm, aten_layer_norm, [](Node* n) -> SROp if (p_node->Output(0).isNone()) { p_node->Output(0) = at::native::empty_like( *X, - std::nullopt /* dtype */, - std::nullopt /* layout */, - std::nullopt /* device */, - std::nullopt /* pin_memory */, + c10::nullopt /* dtype */, + c10::nullopt /* layout */, + c10::nullopt /* device */, + c10::nullopt /* pin_memory */, at::MemoryFormat::Contiguous); } else { at::native::resize_( - p_node->Output(0).toTensor(), X->sizes(), std::nullopt); + p_node->Output(0).toTensor(), X->sizes(), c10::nullopt); } at::Tensor& output = p_node->Output(0).toTensor(); at::native::layer_norm_cpu_out(output, *X, *gamma, *beta, eps, M, N); @@ -2231,12 +2231,12 @@ REGISTER_OPERATOR_FUNCTOR(quantized::linear, quantized_linear, [](Node* n) -> SR p_node->Output(0) = at::native::empty_affine_quantized( {0}, c10::kQUInt8, - std::nullopt, + c10::nullopt, c10::kCPU, false, output_scale, output_zero_point, - std::nullopt); + c10::nullopt); } auto& out_t = p_node->Output(0).toTensor(); fastResizeToZero(out_t); @@ -2277,12 +2277,12 @@ REGISTER_OPERATOR_FUNCTOR( p_node->Output(0) = at::native::empty_affine_quantized( {0}, c10::kQUInt8, - std::nullopt, + c10::nullopt, c10::kCPU, false, output_scale, output_zero_point, - std::nullopt); + c10::nullopt); } auto& out_t = p_node->Output(0).toTensor(); fastResizeToZero(out_t); @@ -2463,7 +2463,7 @@ REGISTER_OPERATOR_FUNCTOR(aten::full_like, aten_full_like, [](Node* n) -> SROper in0_t, dtype, layout, device, pin_memory, memory_format); } auto& out_t = p_node->Output(0).toTensor(); - at::native::resize_(out_t, in0_t.sizes(), std::nullopt); + at::native::resize_(out_t, in0_t.sizes(), c10::nullopt); at::native::fill_out(out_t, in1_s); }; }); @@ -2528,7 +2528,7 @@ REGISTER_OPERATOR_FUNCTOR(aten::zeros, aten_zeros, [](Node* n) -> SROperator { const auto layout = p_node->Input(2).toOptional(); if (!hasTensorWithOptions(p_node->Output(0), dtype, layout)) { p_node->Output(0) = at::compositeexplicitautograd::zeros( - size, dtype, layout, std::nullopt, std::nullopt); + size, dtype, layout, c10::nullopt, c10::nullopt); return; } auto& out_t = p_node->Output(0).toTensor(); @@ -2709,7 +2709,7 @@ unsigned char abs_if_signed(unsigned char val) { // Computes f(x) = sign(x) * ln(|1 + x|) for each x in the input tensor void signed_log1p_out(at::Tensor& out, const at::Tensor& input) { - at::native::resize_(out, input.sizes(), std::nullopt); + at::native::resize_(out, input.sizes(), c10::nullopt); const auto input_contig = input.expect_contiguous(); auto output_contig = out.expect_contiguous(); @@ -2750,7 +2750,7 @@ REGISTER_OPERATOR_FUNCTOR( signed_log1p_out(out, input); return; } - at::native::resize_(out, input.sizes(), std::nullopt); + at::native::resize_(out, input.sizes(), c10::nullopt); int64_t nn = input.numel(); te->call({out.data_ptr(), input.data_ptr(), &nn}); }; diff --git a/torch/csrc/jit/runtime/static/ops.h b/torch/csrc/jit/runtime/static/ops.h index 623340daec068b..362837e7ce78f0 100644 --- a/torch/csrc/jit/runtime/static/ops.h +++ b/torch/csrc/jit/runtime/static/ops.h @@ -57,8 +57,8 @@ inline at::Tensor create_empty_from(const at::Tensor& t) { c10::typeMetaToScalarType(t.dtype()), t.layout(), t.device(), - std::nullopt, - std::nullopt); + c10::nullopt, + c10::nullopt); } inline at::Tensor create_empty_from( @@ -69,20 +69,20 @@ inline at::Tensor create_empty_from( c10::typeMetaToScalarType(t.dtype()), t.layout(), t.device(), - std::nullopt, - std::nullopt); + c10::nullopt, + c10::nullopt); } inline at::Tensor create_empty(c10::ScalarType dtype) { return at::detail::empty_cpu( - {0}, dtype, std::nullopt, std::nullopt, std::nullopt, std::nullopt); + {0}, dtype, c10::nullopt, c10::nullopt, c10::nullopt, c10::nullopt); } inline at::Tensor create_empty_from( const at::Tensor& t, c10::ScalarType dtype) { return at::detail::empty_cpu( - {0}, dtype, t.layout(), t.device(), std::nullopt, std::nullopt); + {0}, dtype, t.layout(), t.device(), c10::nullopt, c10::nullopt); } inline at::Tensor create_empty_from(const at::Tensor& t, c10::Layout layout) { @@ -91,8 +91,8 @@ inline at::Tensor create_empty_from(const at::Tensor& t, c10::Layout layout) { c10::typeMetaToScalarType(t.dtype()), layout, t.device(), - std::nullopt, - std::nullopt); + c10::nullopt, + c10::nullopt); } inline at::Tensor create_empty_from(const at::Tensor& t, c10::Device device) { @@ -101,8 +101,8 @@ inline at::Tensor create_empty_from(const at::Tensor& t, c10::Device device) { c10::typeMetaToScalarType(t.dtype()), t.layout(), device, - std::nullopt, - std::nullopt); + c10::nullopt, + c10::nullopt); } inline at::Tensor create_empty_from( @@ -113,7 +113,7 @@ inline at::Tensor create_empty_from( c10::typeMetaToScalarType(t.dtype()), t.layout(), t.device(), - std::nullopt, + c10::nullopt, memory_format); } @@ -122,7 +122,7 @@ inline at::Tensor create_empty_from( c10::ScalarType dtype, c10::MemoryFormat memory_format) { return at::detail::empty_cpu( - {0}, dtype, t.layout(), t.device(), std::nullopt, memory_format); + {0}, dtype, t.layout(), t.device(), c10::nullopt, memory_format); } inline bool checkResizedDataPtr(at::Tensor& t) { diff --git a/torch/csrc/jit/runtime/symbolic_script.cpp b/torch/csrc/jit/runtime/symbolic_script.cpp index 92d901e43a5d21..6aa65c528a42b8 100644 --- a/torch/csrc/jit/runtime/symbolic_script.cpp +++ b/torch/csrc/jit/runtime/symbolic_script.cpp @@ -1609,7 +1609,7 @@ static void loadModule(const CompilationUnit& module) { static void loadFunctions() { for (const std::string& str : functions) { - compilation_unit.define(std::nullopt, str, nativeResolver(), nullptr); + compilation_unit.define(c10::nullopt, str, nativeResolver(), nullptr); } loadModule(compilation_unit); } @@ -1635,7 +1635,7 @@ std::optional gradientInfoForSchema( return sym_script_it->second; } } - return std::nullopt; + return c10::nullopt; } bool hasGradientInfoForSchema(const FunctionSchema& schema) { diff --git a/torch/csrc/jit/runtime/symbolic_script.h b/torch/csrc/jit/runtime/symbolic_script.h index 0715f0deeb1208..271bf66916f3d6 100644 --- a/torch/csrc/jit/runtime/symbolic_script.h +++ b/torch/csrc/jit/runtime/symbolic_script.h @@ -2,9 +2,9 @@ // This file is temporary until native_functions.yaml and derivatives.yaml are // merged. Ideally this should all go into native_functions.yaml +#include #include #include -#include namespace torch::jit { struct GradientPair { diff --git a/torch/csrc/jit/runtime/symbolic_shape_registry.cpp b/torch/csrc/jit/runtime/symbolic_shape_registry.cpp index f8cfca26c702a6..ddea031aba73c8 100644 --- a/torch/csrc/jit/runtime/symbolic_shape_registry.cpp +++ b/torch/csrc/jit/runtime/symbolic_shape_registry.cpp @@ -365,7 +365,7 @@ void loadFunctions() { [&](const std::string& name) -> std::shared_ptr { return src; }, 1); compilation_unit->define( - std::nullopt, shape_compute_functions, resolver, nullptr); + c10::nullopt, shape_compute_functions, resolver, nullptr); loadModule(*compilation_unit); } catch (...) { // Reset the cache and compilation unit so that we don't get weird errors @@ -391,7 +391,7 @@ std::optional> shapeComputeGraphForSchema( } GRAPH_DEBUG("Could not find schema: ", schema); - return std::nullopt; + return c10::nullopt; } TORCH_API std::optional boundedGraphsForSchema( @@ -406,7 +406,7 @@ TORCH_API std::optional boundedGraphsForSchema( return cache_it->second; } - return std::nullopt; + return c10::nullopt; } void RegisterShapeComputeGraphForSchema( diff --git a/torch/csrc/jit/serialization/callstack_debug_info_serialization.cpp b/torch/csrc/jit/serialization/callstack_debug_info_serialization.cpp index 2bc464a0de172c..4a326285b29740 100644 --- a/torch/csrc/jit/serialization/callstack_debug_info_serialization.cpp +++ b/torch/csrc/jit/serialization/callstack_debug_info_serialization.cpp @@ -173,7 +173,7 @@ std::optional InlinedCallStackDeserializer:: const c10::IValue& iv, const std::shared_ptr& cu) { if (iv.isNone()) { - return std::nullopt; + return c10::nullopt; } auto tup = iv.toTuple(); auto it = cached_module_instance_info_.find(tup); diff --git a/torch/csrc/jit/serialization/export.cpp b/torch/csrc/jit/serialization/export.cpp index 2cfe34cd4abd2a..6ef9bdbf4abfac 100644 --- a/torch/csrc/jit/serialization/export.cpp +++ b/torch/csrc/jit/serialization/export.cpp @@ -5,6 +5,7 @@ #include #include #include +#include #include #include #include @@ -20,7 +21,6 @@ #include #include #include -#include C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wnewline-eof") #include diff --git a/torch/csrc/jit/serialization/export_bytecode.cpp b/torch/csrc/jit/serialization/export_bytecode.cpp index 4b895f9d657b44..9f194cd0ad31b7 100644 --- a/torch/csrc/jit/serialization/export_bytecode.cpp +++ b/torch/csrc/jit/serialization/export_bytecode.cpp @@ -166,7 +166,7 @@ mobile::Code compileGraphToMobileCode( // and is not allowed. For an operator with num_args = -1, it means the // number of arguments is not available for this operator, we don't do any // backward compatibility adaptation at runtime. - std::optional num_args = std::nullopt; + std::optional num_args = c10::nullopt; auto it = op_to_specified_args.find(unique_name); if (it != op_to_specified_args.end()) { num_args = it->second; diff --git a/torch/csrc/jit/serialization/export_module.cpp b/torch/csrc/jit/serialization/export_module.cpp index 779e63a8436092..5bd7714c4e8d20 100644 --- a/torch/csrc/jit/serialization/export_module.cpp +++ b/torch/csrc/jit/serialization/export_module.cpp @@ -259,7 +259,7 @@ std::pair getFunctionTuple( if (namedType && namedType->name()) { return type_name_uniquer_.getUniqueName(namedType).qualifiedName(); } - return std::nullopt; + return c10::nullopt; }; auto makeArgTuple = [&](const std::vector& args) { @@ -765,7 +765,7 @@ std::optional type_printer( if (namedType && namedType->name()) { return type_name_uniquer.getUniqueName(namedType).qualifiedName(); } - return std::nullopt; + return c10::nullopt; } } // namespace diff --git a/torch/csrc/jit/serialization/flatbuffer_serializer.cpp b/torch/csrc/jit/serialization/flatbuffer_serializer.cpp index e1ad60afa5c387..5a47fe900f3fdc 100644 --- a/torch/csrc/jit/serialization/flatbuffer_serializer.cpp +++ b/torch/csrc/jit/serialization/flatbuffer_serializer.cpp @@ -69,7 +69,7 @@ auto print_type(const c10::Type& t) -> std::optional { if (auto dyn = t.castRaw()) { return dyn->fallback()->annotation_str(); } - return std::nullopt; + return c10::nullopt; } class FlatbufferSerializer { @@ -306,7 +306,7 @@ flatbuffers::Offset FlatbufferSerializer:: if (auto dyn = t.castRaw()) { return dyn->fallback()->annotation_str(); } - return std::nullopt; + return c10::nullopt; }; flatbuffers::Offset schema_offset = 0; diff --git a/torch/csrc/jit/serialization/import.h b/torch/csrc/jit/serialization/import.h index 2da1e639ee80a7..b090a1c80a3cd4 100644 --- a/torch/csrc/jit/serialization/import.h +++ b/torch/csrc/jit/serialization/import.h @@ -21,19 +21,19 @@ class DeserializationStorageContext; TORCH_API Module import_ir_module( std::shared_ptr cu, const std::string& filename, - std::optional device = std::nullopt, + std::optional device = c10::nullopt, bool load_debug_files = true); TORCH_API Module import_ir_module( std::shared_ptr cu, std::istream& in, - std::optional device = std::nullopt, + std::optional device = c10::nullopt, bool load_debug_files = true); TORCH_API Module import_ir_module( std::shared_ptr cu, std::unique_ptr rai, - std::optional device = std::nullopt, + std::optional device = c10::nullopt, bool load_debug_files = true); TORCH_API Module import_ir_module( @@ -80,7 +80,7 @@ TORCH_API Module import_ir_module( /// `torch::jit::ExportModule` in C++. TORCH_API Module load( std::istream& in, - std::optional device = std::nullopt, + std::optional device = c10::nullopt, bool load_debug_files = true); TORCH_API Module load( @@ -96,7 +96,7 @@ TORCH_API Module load( /// Python or `torch::jit::ExportModule` in C++. TORCH_API Module load( const std::string& filename, - std::optional device = std::nullopt, + std::optional device = c10::nullopt, bool load_debug_files = true); TORCH_API Module load( @@ -112,7 +112,7 @@ TORCH_API Module load( /// Python or `torch::jit::ExportModule` in C++. TORCH_API Module load( std::shared_ptr rai, - std::optional device = std::nullopt, + std::optional device = c10::nullopt, bool load_debug_files = true); TORCH_API Module load( @@ -131,17 +131,17 @@ TORCH_API Module parse_and_initialize_jit_module( std::shared_ptr data, size_t size, ExtraFilesMap& extra_files, - std::optional device = std::nullopt); + std::optional device = c10::nullopt); TORCH_API Module load_jit_module_from_file( const std::string& filename, ExtraFilesMap& extra_files, - std::optional device = std::nullopt); + std::optional device = c10::nullopt); TORCH_API Module load_jit_module_from_stream( std::istream& in, ExtraFilesMap& extra_files, - std::optional device = std::nullopt); + std::optional device = c10::nullopt); TORCH_API Module parse_and_initialize_jit_module( std::shared_ptr data, diff --git a/torch/csrc/jit/serialization/import_source.cpp b/torch/csrc/jit/serialization/import_source.cpp index 017ae5bd3da7cf..f67c2a22e9eb13 100644 --- a/torch/csrc/jit/serialization/import_source.cpp +++ b/torch/csrc/jit/serialization/import_source.cpp @@ -372,7 +372,7 @@ std::optional SourceImporterImpl:: if (replacements.count(demangled_classname)) { auto lhs = Var(assign.lhs()); if (!assign.type().present() || assign.type().get().kind() != TK_VAR) { - return std::nullopt; + return c10::nullopt; } auto type = Var(assign.type().get()); @@ -389,7 +389,7 @@ std::optional SourceImporterImpl:: assign.range(), assign.lhs_list(), assign.rhs(), maybe_typename); } } - return std::nullopt; + return c10::nullopt; } void SourceImporterImpl::importClass( diff --git a/torch/csrc/jit/serialization/import_source.h b/torch/csrc/jit/serialization/import_source.h index a86a1f91926df7..9b364f379b4091 100644 --- a/torch/csrc/jit/serialization/import_source.h +++ b/torch/csrc/jit/serialization/import_source.h @@ -2,6 +2,7 @@ #include #include +#include #include #include #include @@ -12,7 +13,6 @@ #include #include #include -#include #include #include #include @@ -66,7 +66,7 @@ struct SourceImporterImpl : public Resolver, std::shared_ptr cu_; std::unordered_map> env_; SourceLoader source_loader_; - std::optional version_ = std::nullopt; + std::optional version_ = c10::nullopt; std::unordered_set loaded_sources_; // named types and functions loaded from a file but not yet defined because // their type has not been requested yet. diff --git a/torch/csrc/jit/serialization/pickle.cpp b/torch/csrc/jit/serialization/pickle.cpp index c05bf330e7af3c..be36a4e2d8dd5e 100644 --- a/torch/csrc/jit/serialization/pickle.cpp +++ b/torch/csrc/jit/serialization/pickle.cpp @@ -92,9 +92,9 @@ IValue pickle_load(const std::vector& data) { "data", /*pickle_prefix=*/"", /*tensor_prefix=*/"", - /*type_resolver=*/std::nullopt, - /*obj_loader=*/std::nullopt, - /*device=*/std::nullopt, + /*type_resolver=*/c10::nullopt, + /*obj_loader=*/c10::nullopt, + /*device=*/c10::nullopt, reader); #else AT_ERROR( diff --git a/torch/csrc/jit/serialization/pickler.cpp b/torch/csrc/jit/serialization/pickler.cpp index 04d3fc9a435614..173ab5c13e5da4 100644 --- a/torch/csrc/jit/serialization/pickler.cpp +++ b/torch/csrc/jit/serialization/pickler.cpp @@ -605,7 +605,7 @@ std::optional type_printer(const c10::Type& type) { if (auto dyn = type.castRaw()) { return dyn->fallback()->annotation_str(type_printer); } - return std::nullopt; + return c10::nullopt; } } // namespace diff --git a/torch/csrc/jit/serialization/python_print.cpp b/torch/csrc/jit/serialization/python_print.cpp index 2292f11fd555ea..f1b0865032c392 100644 --- a/torch/csrc/jit/serialization/python_print.cpp +++ b/torch/csrc/jit/serialization/python_print.cpp @@ -1725,7 +1725,7 @@ static std::optional printType( if (namedType && namedType->name()) { return type_name_uniquer.getUniqueName(namedType).qualifiedName(); } - return std::nullopt; + return c10::nullopt; } void jitModuleToPythonCodeAndConstants( diff --git a/torch/csrc/jit/serialization/source_range_serialization.cpp b/torch/csrc/jit/serialization/source_range_serialization.cpp index 6892493312b002..118becd20dc7c6 100644 --- a/torch/csrc/jit/serialization/source_range_serialization.cpp +++ b/torch/csrc/jit/serialization/source_range_serialization.cpp @@ -68,7 +68,7 @@ std::shared_ptr SourceRangeDeserializer::deserialize_source( const auto& textIndex = tup_elems[0].toIntList(); int64_t fnameIndex = tup_elems[1].toInt(); int64_t starting_line_no_ = tup_elems[2].toInt(); - std::optional filename = std::nullopt; + std::optional filename = c10::nullopt; TORCH_CHECK( (uint64_t)fnameIndex < text_table_.size(), @@ -248,7 +248,7 @@ std::optional ConcreteSourceRangeUnpickler:: return (entry - 1)->range; } - return std::nullopt; + return c10::nullopt; } TORCH_API void setShouldUseFormatWithStringTable( diff --git a/torch/csrc/jit/tensorexpr/codegen.cpp b/torch/csrc/jit/tensorexpr/codegen.cpp index 1ba4d54c4d29ca..e1464d0efc3ec0 100644 --- a/torch/csrc/jit/tensorexpr/codegen.cpp +++ b/torch/csrc/jit/tensorexpr/codegen.cpp @@ -99,7 +99,7 @@ static std::optional bufSize(BufPtr buf) { size_t size = elementSize(buf->dtype().scalar_type()) * buf->dtype().lanes(); for (auto& d : buf->dims()) { if (!d->isConstant()) { - return std::nullopt; + return c10::nullopt; } size = size * (*intValue(d)); } diff --git a/torch/csrc/jit/tensorexpr/eval.cpp b/torch/csrc/jit/tensorexpr/eval.cpp index ceab479dc87946..5666097f2dd45b 100644 --- a/torch/csrc/jit/tensorexpr/eval.cpp +++ b/torch/csrc/jit/tensorexpr/eval.cpp @@ -1305,7 +1305,7 @@ std::optional evalInt(ExprPtr e) { return ExprEval(cast(ExprHandle(e))) .value(); } catch (std::runtime_error& err) { - return std::nullopt; + return c10::nullopt; } } diff --git a/torch/csrc/jit/tensorexpr/expr.h b/torch/csrc/jit/tensorexpr/expr.h index c410c902ea4e4e..8c8de89975750c 100644 --- a/torch/csrc/jit/tensorexpr/expr.h +++ b/torch/csrc/jit/tensorexpr/expr.h @@ -6,11 +6,11 @@ #pragma once #include +#include #include #include #include #include -#include #include @@ -207,10 +207,10 @@ class TORCH_API Buf : public ExprNode { const std::string& name_hint, const std::vector& dims, Dtype dtype, - std::optional initializer = std::nullopt, - std::optional> strides = std::nullopt, - std::optional qscale = std::nullopt, - std::optional qzero = std::nullopt); + std::optional initializer = c10::nullopt, + std::optional> strides = c10::nullopt, + std::optional qscale = c10::nullopt, + std::optional qzero = c10::nullopt); // TODO: unique_name VarPtr base_handle() const { @@ -232,7 +232,7 @@ class TORCH_API Buf : public ExprNode { const std::vector& dims, Dtype dtype, ExprPtr initializer = nullptr, - std::optional> strides = std::nullopt, + std::optional> strides = c10::nullopt, ExprPtr qscale = nullptr, ExprPtr qzero = nullptr) : Buf(alloc(name_hint, kHandle), @@ -248,7 +248,7 @@ class TORCH_API Buf : public ExprNode { std::vector dims, Dtype dtype, ExprPtr initializer = nullptr, - std::optional> strides = std::nullopt, + std::optional> strides = c10::nullopt, ExprPtr qscale = nullptr, ExprPtr qzero = nullptr); diff --git a/torch/csrc/jit/tensorexpr/external_functions.cpp b/torch/csrc/jit/tensorexpr/external_functions.cpp index decfe0bceb3215..a3146ccfaff550 100644 --- a/torch/csrc/jit/tensorexpr/external_functions.cpp +++ b/torch/csrc/jit/tensorexpr/external_functions.cpp @@ -123,7 +123,7 @@ std::vector constructTensors( } } else { // handle quantized - std::vector> qdata(bufs_num, std::nullopt); + std::vector> qdata(bufs_num, c10::nullopt); for (const auto& qd : *qdataArg) { qdata[qd.first] = qd.second; } @@ -233,7 +233,7 @@ std::vector constructTensors2( } } else { // handle quantized - std::vector> qdata(bufs_in_num, std::nullopt); + std::vector> qdata(bufs_in_num, c10::nullopt); for (const auto& qd : *qdataArg) { qdata[qd.first - bufs_out_num] = qd.second; } @@ -993,10 +993,10 @@ void nnc_aten_upsample_nearest2d( x, (output_size_h != -1) ? std::optional({output_size_h, output_size_w}) - : std::nullopt, + : c10::nullopt, (scale_factor_h != -1.f) ? std::optional>( {scale_factor_h, scale_factor_w}) - : std::nullopt); + : c10::nullopt); memcpy(buf_data[0], r.const_data_ptr(), r.element_size() * r.numel()); } @@ -1043,10 +1043,10 @@ void nnc_aten_upsample_nearest2d_out( x, (output_size_h != -1) ? std::optional({output_size_h, output_size_w}) - : std::nullopt, + : c10::nullopt, (scale_factor_h != -1.f) ? std::optional>( {scale_factor_h, scale_factor_w}) - : std::nullopt); + : c10::nullopt); buf_data[0] = r.data_ptr(); c10::raw::intrusive_ptr::incref(r.getIntrusivePtr().get()); buf_data[bufs_in_num + bufs_out_num] = r.getIntrusivePtr().get(); @@ -1089,7 +1089,7 @@ void nnc_aten_quantize_per_tensor_out( buf_dims, buf_strides, buf_dtypes, - std::nullopt, + c10::nullopt, bufs_out_num); // NOLINTNEXTLINE(facebook-hte-LocalUncheckedArrayBounds) at::Tensor x = tensors[1]; @@ -1214,7 +1214,7 @@ void nnc_aten_conv1d_out( buf_dims, buf_strides, buf_dtypes, - std::nullopt, + c10::nullopt, bufs_out_num); at::Tensor r; diff --git a/torch/csrc/jit/tensorexpr/external_functions.h b/torch/csrc/jit/tensorexpr/external_functions.h index 9dc859d2247158..1fd90a3f056b8a 100644 --- a/torch/csrc/jit/tensorexpr/external_functions.h +++ b/torch/csrc/jit/tensorexpr/external_functions.h @@ -75,7 +75,7 @@ std::vector constructTensors( int64_t* buf_strides, int8_t* buf_dtypes, std::optional>> qdataArg = - std::nullopt); + c10::nullopt); std::vector constructTensors2( int64_t bufs_in_num, @@ -85,7 +85,7 @@ std::vector constructTensors2( int64_t* buf_strides, int8_t* buf_dtypes, std::optional>> qdataArg = - std::nullopt, + c10::nullopt, size_t bufs_out_num = 0); #ifdef C10_MOBILE diff --git a/torch/csrc/jit/tensorexpr/graph_opt.cpp b/torch/csrc/jit/tensorexpr/graph_opt.cpp index 0699dfd63da543..01511b2b4d8c5c 100644 --- a/torch/csrc/jit/tensorexpr/graph_opt.cpp +++ b/torch/csrc/jit/tensorexpr/graph_opt.cpp @@ -351,7 +351,7 @@ static std::optional inferScalarType(Node* n) { if (tt->scalarType() && *tt->scalarType() != scalar_type) { GRAPH_DEBUG( "Inputs of ", n, " have different scalar types, cannot fixup!"); - return std::nullopt; + return c10::nullopt; } } } @@ -369,7 +369,7 @@ static std::optional inferDevice(Node* n) { } if (tt->device() && *tt->device() != device) { GRAPH_DEBUG("Inputs of ", n, " have different devices, cannot fixup!"); - return std::nullopt; + return c10::nullopt; } } } diff --git a/torch/csrc/jit/tensorexpr/ir.h b/torch/csrc/jit/tensorexpr/ir.h index 90c5400472514a..89c3f96aba6e3a 100644 --- a/torch/csrc/jit/tensorexpr/ir.h +++ b/torch/csrc/jit/tensorexpr/ir.h @@ -367,7 +367,7 @@ inline std::optional intValue(const ExprPtr& e) { } AT_FORALL_INT_TYPES(TYPE_CASE); #undef TYPE_CASE - return std::nullopt; + return c10::nullopt; } inline std::optional intValue(const ExprHandle& e) { diff --git a/torch/csrc/jit/tensorexpr/ir_simplifier.cpp b/torch/csrc/jit/tensorexpr/ir_simplifier.cpp index b69d167dba535f..afb7aefdda652f 100644 --- a/torch/csrc/jit/tensorexpr/ir_simplifier.cpp +++ b/torch/csrc/jit/tensorexpr/ir_simplifier.cpp @@ -1885,7 +1885,7 @@ static std::optional isModRound(TermPtr e) { if (!mod) { mod = to(m); } else { - return std::nullopt; + return c10::nullopt; } } else { // Take care of special cases before multiplying the scalar and variable. @@ -1911,14 +1911,14 @@ static std::optional isModRound(TermPtr e) { if (!mod) { // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) - return std::nullopt; + return c10::nullopt; } mod_divisor = IRSimplifier::simplify(mod->rhs()); other = mod->lhs(); if (!(div = to
(other))) { - return std::nullopt; + return c10::nullopt; } divisor = IRSimplifier::simplify(div->rhs()); @@ -1953,16 +1953,16 @@ static std::optional isModRound(TermPtr e) { // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) denom = IRSimplifier::simplify(alloc
(other, c)); } else { - return std::nullopt; + return c10::nullopt; } } else { - return std::nullopt; + return c10::nullopt; } } // Deny cases in which divisor=1. Such cases are considered as Mods. if (divisor->isConstant() && immediateEquals(divisor, 1)) { - return std::nullopt; + return c10::nullopt; } if (!scalar) { diff --git a/torch/csrc/jit/tensorexpr/kernel.cpp b/torch/csrc/jit/tensorexpr/kernel.cpp index 81c171d5671175..d18a3d65f21ed0 100644 --- a/torch/csrc/jit/tensorexpr/kernel.cpp +++ b/torch/csrc/jit/tensorexpr/kernel.cpp @@ -129,12 +129,12 @@ bool& getOptConditionals() { std::optional pickDeviceType( const at::ArrayRef& inputs) { - std::optional device = std::nullopt; + std::optional device = c10::nullopt; for (auto const& input : inputs) { auto tt = input->type()->cast(); if (tt && tt->device()) { if (device && *device != *tt->device()) { - return std::nullopt; + return c10::nullopt; } device = *tt->device(); } @@ -144,7 +144,7 @@ std::optional pickDeviceType( static std::optional pickDeviceType( const std::shared_ptr& graph) { - std::optional device = std::nullopt; + std::optional device = c10::nullopt; for (auto const& node : graph->nodes()) { for (auto const& input : node->inputs()) { if (auto tt = input->type()->cast()) { @@ -184,10 +184,10 @@ static std::optional getTensorInfoJit(torch::jit::Value* v) { c10::ScalarType dtype = c10::ScalarType::Float; if (!it) { - return std::nullopt; + return c10::nullopt; } if (!it->isComplete()) { - return std::nullopt; + return c10::nullopt; } if (it->scalarType()) { // TODO: ideally we should be strict here and return nullopt if the dtype is @@ -197,7 +197,7 @@ static std::optional getTensorInfoJit(torch::jit::Value* v) { } auto concrete_sizes = it->sizes().concrete_sizes(); if (!concrete_sizes) { - return std::nullopt; + return c10::nullopt; } return TensorInfo{*concrete_sizes, dtype}; } @@ -712,7 +712,7 @@ static std::optional tripCount(ForPtr loop) { if (auto val = to(tc.node())) { return val->value(); } - return std::nullopt; + return c10::nullopt; } // Prune innermost loops until iterations satisfies a minimum grain size. @@ -1314,7 +1314,7 @@ Tensor TensorExprKernel::convertSymbolicOutputToCorrectStrides( BufPtr buf = bufs_.at(v); TORCH_INTERNAL_ASSERT(buf != nullptr); TORCH_INTERNAL_ASSERT(tt != nullptr); - TORCH_INTERNAL_ASSERT(tt->symbolic_sizes().rank() != std::nullopt); + TORCH_INTERNAL_ASSERT(tt->symbolic_sizes().rank() != c10::nullopt); auto stride_desc = getSymbolicStrideDesc(v); TORCH_INTERNAL_ASSERT(stride_desc.size() == 1); diff --git a/torch/csrc/jit/tensorexpr/llvm_codegen.cpp b/torch/csrc/jit/tensorexpr/llvm_codegen.cpp index 1cae1fe9b2dc22..dec03637847e29 100644 --- a/torch/csrc/jit/tensorexpr/llvm_codegen.cpp +++ b/torch/csrc/jit/tensorexpr/llvm_codegen.cpp @@ -85,15 +85,15 @@ C10_DEFINE_bool( namespace torch::jit::tensorexpr { std::optional& LLVMTargetTriple() { - static std::optional triple = std::nullopt; + static std::optional triple = c10::nullopt; return triple; } std::optional& LLVMTargetCPU() { - static std::optional cpu = std::nullopt; + static std::optional cpu = c10::nullopt; return cpu; } std::optional& LLVMTargetAttrs() { - static std::optional attrs = std::nullopt; + static std::optional attrs = c10::nullopt; return attrs; } bool& LLVMAOTWorkflow() { diff --git a/torch/csrc/jit/tensorexpr/llvm_codegen.h b/torch/csrc/jit/tensorexpr/llvm_codegen.h index 1d96b4dd0467e3..74271fa879f3de 100644 --- a/torch/csrc/jit/tensorexpr/llvm_codegen.h +++ b/torch/csrc/jit/tensorexpr/llvm_codegen.h @@ -7,7 +7,7 @@ #include #include -#include +#include #include #include @@ -27,9 +27,9 @@ class TORCH_API LLVMCodeGen : public CodeGen { at::Device device = at::kCPU, const std::string& kernel_func_name = "func", Dtype dtype = kInt, - std::optional triple = std::nullopt, - std::optional cpu = std::nullopt, - std::optional attrs = std::nullopt); + std::optional triple = c10::nullopt, + std::optional cpu = c10::nullopt, + std::optional attrs = c10::nullopt); explicit LLVMCodeGen(StmtPtr stmt); LLVMCodeGen() = delete; @@ -126,9 +126,9 @@ struct TORCH_API LLVMCodeGenBuilder { at::Device device_ = at::kCPU; std::string kernelFuncName_ = "func"; Dtype dtype_ = kInt; - std::optional triple_ = std::nullopt; - std::optional cpu_ = std::nullopt; - std::optional attrs_ = std::nullopt; + std::optional triple_ = c10::nullopt; + std::optional cpu_ = c10::nullopt; + std::optional attrs_ = c10::nullopt; }; TORCH_API std::optional& LLVMTargetTriple(); diff --git a/torch/csrc/jit/tensorexpr/llvm_jit.h b/torch/csrc/jit/tensorexpr/llvm_jit.h index beadbdd5e537e7..98238e0043885f 100644 --- a/torch/csrc/jit/tensorexpr/llvm_jit.h +++ b/torch/csrc/jit/tensorexpr/llvm_jit.h @@ -3,8 +3,8 @@ #ifdef TORCH_ENABLE_LLVM #include #include +#include #include -#include C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wsuggest-override") #include diff --git a/torch/csrc/jit/tensorexpr/operators/conv2d.cpp b/torch/csrc/jit/tensorexpr/operators/conv2d.cpp index bfce006d55177e..bdf313f0ad0515 100644 --- a/torch/csrc/jit/tensorexpr/operators/conv2d.cpp +++ b/torch/csrc/jit/tensorexpr/operators/conv2d.cpp @@ -51,7 +51,7 @@ Tensor conv2d_depthwise_static( Tensor conv = Reduce( "conv2d_depthwise", {N, K, OH, OW}, - std::nullopt, // TODO + c10::nullopt, // TODO Sum(), [&](const std::vector& v) { return init_func(v); }, [&](const std::vector& v) { @@ -123,7 +123,7 @@ Tensor conv2d_depthwise_dynamic( return Reduce( "conv2d_depthwise", {N, K, OH, OW}, - std::nullopt, // TODO + c10::nullopt, // TODO Sum(), [&](const std::vector& v) { return init_func(v); }, [&](const std::vector& v) { diff --git a/torch/csrc/jit/tensorexpr/operators/misc.cpp b/torch/csrc/jit/tensorexpr/operators/misc.cpp index 6ff6dd733885be..938cab6ffd8830 100644 --- a/torch/csrc/jit/tensorexpr/operators/misc.cpp +++ b/torch/csrc/jit/tensorexpr/operators/misc.cpp @@ -165,7 +165,7 @@ std::optional getTensorInfo(BufHandle b) { for (auto dim : b.dims()) { auto val = intValue(dim.node()); if (!val) { - return std::nullopt; + return c10::nullopt; } dims.push_back(*val); } diff --git a/torch/csrc/jit/tensorexpr/operators/pointwise.h b/torch/csrc/jit/tensorexpr/operators/pointwise.h index 589674117c1a00..0ce10424b3d30a 100644 --- a/torch/csrc/jit/tensorexpr/operators/pointwise.h +++ b/torch/csrc/jit/tensorexpr/operators/pointwise.h @@ -9,7 +9,7 @@ namespace tensorexpr { TORCH_API Tensor computeSign( const std::vector& inputs, const std::vector& outputShape, - std::optional> outputStrides = std::nullopt); + std::optional> outputStrides = c10::nullopt); Tensor computeOneOperand( const std::string& name, diff --git a/torch/csrc/jit/tensorexpr/operators/quantization.cpp b/torch/csrc/jit/tensorexpr/operators/quantization.cpp index 204a4c2211f7a9..66c0688538a1d7 100644 --- a/torch/csrc/jit/tensorexpr/operators/quantization.cpp +++ b/torch/csrc/jit/tensorexpr/operators/quantization.cpp @@ -171,7 +171,7 @@ Tensor computeQuantizePerTensor( ExprHandleVectorToExprVector(outputShape), dtype, nullptr, - std::nullopt, + c10::nullopt, qscale.node(), qzero.node()); return Tensor(buf, vars, e.node()); @@ -731,7 +731,7 @@ Tensor computeUpsampleNearest2d( "upsample_nearest2d", outputShape, Dtype(*outputType), - std::nullopt, // initializer + c10::nullopt, // initializer fmap(strides, [&](ExprPtr stride) { return ExprHandle(stride); }), ExprHandle(A.node()->qscale()), ExprHandle(A.node()->qzero())); diff --git a/torch/csrc/jit/tensorexpr/operators/softmax.cpp b/torch/csrc/jit/tensorexpr/operators/softmax.cpp index f73e06086d3d9e..9bd82afd177d46 100644 --- a/torch/csrc/jit/tensorexpr/operators/softmax.cpp +++ b/torch/csrc/jit/tensorexpr/operators/softmax.cpp @@ -103,7 +103,7 @@ Tensor computeSoftmax( auto max = Reduce( "aten_softmax_max", non_softmax_dims, - std::nullopt, + c10::nullopt, Maximum(dtype), [&](ParameterList& indices) { return tensorOrConstant( @@ -113,7 +113,7 @@ Tensor computeSoftmax( auto e = Compute( "aten_softmax_exp", outputShape, - std::nullopt, + c10::nullopt, [&](ParameterList& indices) { auto inp = tensorOrConstant( inputs[0], convert_indices_to_expr_handle(indices)); @@ -122,7 +122,7 @@ Tensor computeSoftmax( auto sum = Reduce( "aten_softmax_sum", non_softmax_dims, - std::nullopt, + c10::nullopt, Sum(), [&](ParameterList& indices) { return e.load(move_softmax_dim_index_to_pos(indices)); @@ -130,7 +130,7 @@ Tensor computeSoftmax( {outputShape[softmax_dim]}); if (!log_softmax) { auto result = Compute( - "aten_softmax", outputShape, std::nullopt, [&](ParameterList& indices) { + "aten_softmax", outputShape, c10::nullopt, [&](ParameterList& indices) { return e.load(indices) / sum.load(remove_softmax_dim_index(indices)); }); return Tensor( @@ -142,12 +142,12 @@ Tensor computeSoftmax( auto log_sum = Compute( "aten_softmax_log_sum", non_softmax_dims, - std::nullopt, + c10::nullopt, [&](ParameterList& indices) { return log(sum.load(indices)); }); auto result = Compute( "aten_log_softmax", outputShape, - std::nullopt, + c10::nullopt, [&](ParameterList& indices) { auto inp = tensorOrConstant( inputs[0], convert_indices_to_expr_handle(indices)); diff --git a/torch/csrc/jit/tensorexpr/tensor.cpp b/torch/csrc/jit/tensorexpr/tensor.cpp index 5a9af09f9d87eb..5bc734bb80b838 100644 --- a/torch/csrc/jit/tensorexpr/tensor.cpp +++ b/torch/csrc/jit/tensorexpr/tensor.cpp @@ -103,14 +103,14 @@ Tensor Compute( const std::function&)>& body_func) { std::vector args = create_index_vars(dims); ExprHandle body = body_func(args); - BufHandle buf = Buf::make(name, dims, body.dtype(), std::nullopt, strides); + BufHandle buf = Buf::make(name, dims, body.dtype(), c10::nullopt, strides); return Tensor(buf, args, body); } Tensor Compute( const std::string& name, const std::vector& dims, const std::function&)>& body_func) { - return Compute(name, dims, std::nullopt, body_func); + return Compute(name, dims, c10::nullopt, body_func); } Tensor Compute( @@ -124,14 +124,14 @@ Tensor Compute( std::vector args = create_index_vars(dims); ExprHandle body = body_func(args[0]); - BufHandle buf = Buf::make(name, dims, body.dtype(), std::nullopt, strides); + BufHandle buf = Buf::make(name, dims, body.dtype(), c10::nullopt, strides); return Tensor(buf, args, body); } Tensor Compute( const std::string& name, const std::vector& dims, const std::function& body_func) { - return Compute(name, dims, std::nullopt, body_func); + return Compute(name, dims, c10::nullopt, body_func); } Tensor Compute( @@ -145,7 +145,7 @@ Tensor Compute( } std::vector args = create_index_vars(dims); ExprHandle body = body_func(args[0], args[1]); - BufHandle buf = Buf::make(name, dims, body.dtype(), std::nullopt, strides); + BufHandle buf = Buf::make(name, dims, body.dtype(), c10::nullopt, strides); return Tensor(buf, args, body); } Tensor Compute( @@ -153,7 +153,7 @@ Tensor Compute( const std::vector& dims, const std::function& body_func) { - return Compute(name, dims, std::nullopt, body_func); + return Compute(name, dims, c10::nullopt, body_func); } Tensor Compute( @@ -168,7 +168,7 @@ Tensor Compute( } std::vector args = create_index_vars(dims); ExprHandle body = body_func(args[0], args[1], args[2]); - BufHandle buf = Buf::make(name, dims, body.dtype(), std::nullopt, strides); + BufHandle buf = Buf::make(name, dims, body.dtype(), c10::nullopt, strides); return Tensor(buf, args, body); } Tensor Compute( @@ -177,7 +177,7 @@ Tensor Compute( const std::function< ExprHandle(const VarHandle&, const VarHandle&, const VarHandle&)>& body_func) { - return Compute(name, dims, std::nullopt, body_func); + return Compute(name, dims, c10::nullopt, body_func); } Tensor Compute( @@ -194,7 +194,7 @@ Tensor Compute( } std::vector args = create_index_vars(dims); ExprHandle body = body_func(args[0], args[1], args[2], args[3]); - BufHandle buf = Buf::make(name, dims, body.dtype(), std::nullopt, strides); + BufHandle buf = Buf::make(name, dims, body.dtype(), c10::nullopt, strides); return Tensor(buf, args, body); } Tensor Compute( @@ -205,7 +205,7 @@ Tensor Compute( const VarHandle&, const VarHandle&, const VarHandle&)>& body_func) { - return Compute(name, dims, std::nullopt, body_func); + return Compute(name, dims, c10::nullopt, body_func); } Tensor Reduce( @@ -229,7 +229,7 @@ Tensor Reduce( const Reducer& reducer, const BufHandle& buffer, const std::vector& reduce_dims) { - return Reduce(name, dims, std::nullopt, reducer, buffer, reduce_dims); + return Reduce(name, dims, c10::nullopt, reducer, buffer, reduce_dims); } Tensor Reduce( @@ -253,7 +253,7 @@ Tensor Reduce( const Reducer& reducer, Tensor tensor, const std::vector& reduce_dims) { - return Reduce(name, dims, std::nullopt, reducer, tensor, reduce_dims); + return Reduce(name, dims, c10::nullopt, reducer, tensor, reduce_dims); } } // namespace torch::jit::tensorexpr diff --git a/torch/csrc/jit/tensorexpr/tensor.h b/torch/csrc/jit/tensorexpr/tensor.h index 3fb55152b70d64..7b589d0974b37b 100644 --- a/torch/csrc/jit/tensorexpr/tensor.h +++ b/torch/csrc/jit/tensorexpr/tensor.h @@ -161,7 +161,7 @@ Tensor Reduce( if (reduce_vars.empty()) { ExprHandle body = Reducer::getReduceBody(body_func, vars); BufHandle func_result = Buf::make( - func_name, dims, body.dtype(), std::nullopt, std::move(strides)); + func_name, dims, body.dtype(), c10::nullopt, std::move(strides)); return Tensor(std::move(func_result), vars, std::move(body)); } @@ -206,7 +206,7 @@ Tensor Reduce( return Reduce( func_name, dims, - std::nullopt, + c10::nullopt, reducer, init_func, body_func, @@ -238,7 +238,7 @@ Tensor Reduce( const BodyFunc& body_func, const std::vector& reduce_dims) { return Reduce( - func_name, dims, std::nullopt, reducer, body_func, reduce_dims); + func_name, dims, c10::nullopt, reducer, body_func, reduce_dims); } // Overload which allows inline lambda functions for the body_func. @@ -259,7 +259,7 @@ Tensor Reduce( const Reducer& reducer, const BodyFunc&& body_func, const std::vector& reduce_dims) { - return Reduce(func_name, dims, std::nullopt, reducer, body_func, reduce_dims); + return Reduce(func_name, dims, c10::nullopt, reducer, body_func, reduce_dims); } TORCH_API Tensor Reduce( diff --git a/torch/csrc/jit/testing/file_check.cpp b/torch/csrc/jit/testing/file_check.cpp index 027eb2aa0acf6b..ec0011f40d775c 100644 --- a/torch/csrc/jit/testing/file_check.cpp +++ b/torch/csrc/jit/testing/file_check.cpp @@ -10,13 +10,13 @@ // API modified from llvm::FileCheck #include +#include #include #include #include #include #include #include -#include #include #include @@ -43,13 +43,13 @@ struct Check { Check( CheckType type, std::string str, - std::optional count = std::nullopt) + std::optional count = c10::nullopt) : type_(type), count_(count), search_str_(std::move(str)) {} Check( CheckType type, c10::string_view str, - std::optional count = std::nullopt) + std::optional count = c10::nullopt) : Check(type, std::string(str.begin(), str.end()), count) {} CheckType type_; @@ -234,7 +234,7 @@ struct FileCheckImpl { TORCH_API void addCheck( CheckType type, const std::string& s, - std::optional count = std::nullopt) { + std::optional count = c10::nullopt) { addCheck(Check(type, s, count)); } @@ -264,7 +264,7 @@ struct FileCheckImpl { } size_t end_check_string = suffix_pos + check_suffix.size(); CheckType type = check_pair.first; - std::optional count = std::nullopt; + std::optional count = c10::nullopt; auto end_line = source->text_str().find("\n", end_check_string); bool exactly = false; if (type == CHECK_COUNT) { diff --git a/torch/csrc/lazy/backend/backend_device.cpp b/torch/csrc/lazy/backend/backend_device.cpp index 3eac703be175fc..6d146ca0881ceb 100644 --- a/torch/csrc/lazy/backend/backend_device.cpp +++ b/torch/csrc/lazy/backend/backend_device.cpp @@ -2,10 +2,10 @@ #include #include +#include #include #include #include -#include namespace torch { namespace lazy { @@ -60,7 +60,7 @@ std::optional GetBackendDevice(at::ITensorListRef tensors) { return lt->GetDevice(); } } - return std::nullopt; + return c10::nullopt; } std::optional GetBackendDevice(at::TensorList tensors) { @@ -71,19 +71,19 @@ std::optional GetBackendDevice(const at::Tensor& tensor) { if (auto lt = TryGetLtcTensor(tensor)) { return lt->GetDevice(); } - return std::nullopt; + return c10::nullopt; } std::optional GetBackendDevice( const std::optional& device) { if (device) { - return std::make_optional(atenDeviceToBackendDevice(*device)); + return c10::make_optional(atenDeviceToBackendDevice(*device)); } - return std::nullopt; + return c10::nullopt; } std::optional GetBackendDevice() { - return std::nullopt; + return c10::nullopt; } } // namespace lazy diff --git a/torch/csrc/lazy/backend/backend_device.h b/torch/csrc/lazy/backend/backend_device.h index fdfc2ac15d9a89..e80c800a2ecead 100644 --- a/torch/csrc/lazy/backend/backend_device.h +++ b/torch/csrc/lazy/backend/backend_device.h @@ -7,7 +7,7 @@ #include #include #include -#include +#include namespace c10 { struct Device; diff --git a/torch/csrc/lazy/core/ir_builder.h b/torch/csrc/lazy/core/ir_builder.h index 570dc942e6a68a..981e1667772944 100644 --- a/torch/csrc/lazy/core/ir_builder.h +++ b/torch/csrc/lazy/core/ir_builder.h @@ -1,12 +1,12 @@ #pragma once #include +#include #include #include #include #include #include -#include #include // This file is part of the backend interface. So, ops shouldn't be added or @@ -61,7 +61,7 @@ struct IrBuilder { virtual NodePtr MakeCast( const Value& input0, const at::ScalarType& dtype, - const std::optional& stype = std::nullopt) const = 0; + const std::optional& stype = c10::nullopt) const = 0; virtual NodePtr MakeTensorList(const OpList& inputs) const = 0; virtual NodePtr MakeGeneric( const OpKind& op, @@ -96,7 +96,7 @@ static inline NodePtr MakeExpand( static inline NodePtr MakeCast( const Value& input0, const at::ScalarType& dtype, - const std::optional& stype = std::nullopt) { + const std::optional& stype = c10::nullopt) { return getIrBuilder()->MakeCast(input0, dtype, stype); } static inline NodePtr MakeTensorList(const OpList& inputs) { diff --git a/torch/csrc/lazy/core/ir_dump_util.cpp b/torch/csrc/lazy/core/ir_dump_util.cpp index d81d810a54e98f..a4fb11761a67ce 100644 --- a/torch/csrc/lazy/core/ir_dump_util.cpp +++ b/torch/csrc/lazy/core/ir_dump_util.cpp @@ -1,10 +1,10 @@ #include +#include #include #include #include #include -#include #include #include @@ -37,7 +37,7 @@ std::optional ParseAttrTag( // @lint-ignore-every CLANGTIDY facebook-hte-StdRegexIsAwful if (!std::regex_search( node_string.begin() + pos, node_string.end(), match, tag_regex)) { - return std::nullopt; + return c10::nullopt; } std::string::size_type vpos = match[1].second - node_string.begin() + 1; @@ -102,7 +102,7 @@ std::optional GetRootNodeId( const std::unordered_map& roots_ids) { auto it = roots_ids.find(node); if (it == roots_ids.end()) { - return std::nullopt; + return c10::nullopt; } return it->second; } diff --git a/torch/csrc/lazy/core/lazy_graph_executor.cpp b/torch/csrc/lazy/core/lazy_graph_executor.cpp index b01b5ead3434b3..569cd5ee5e0a18 100644 --- a/torch/csrc/lazy/core/lazy_graph_executor.cpp +++ b/torch/csrc/lazy/core/lazy_graph_executor.cpp @@ -695,7 +695,7 @@ std::vector LazyGraphExecutor::SetTensorData( // resets the ir_value. We have already done the resetting as part // of ExtractIRAndPrepareTensorData to overlap with previous execution. tensor->data()->handle = handle; - tensor->data()->tensor_data = std::nullopt; + tensor->data()->tensor_data = c10::nullopt; } tensors_data.emplace_back(std::move(handle)); } diff --git a/torch/csrc/lazy/core/shape.cpp b/torch/csrc/lazy/core/shape.cpp index bf49cfacb99f61..939e2745ed3938 100644 --- a/torch/csrc/lazy/core/shape.cpp +++ b/torch/csrc/lazy/core/shape.cpp @@ -78,7 +78,7 @@ static c10::SymbolicShape get_symbolic_shape(at::Tensor& tensor) { std::vector> symbolic_dims; for (size_t i = 0; i < sizes.size(); i++) { if (is_symbolic->at(i)) { - symbolic_dims.emplace_back(std::nullopt); + symbolic_dims.emplace_back(c10::nullopt); } else { symbolic_dims.emplace_back(sizes.at(i)); } @@ -114,7 +114,7 @@ void applySymbolicShapesOnLT( auto res_symbolic = jit::calculateSymbolicShapesOnOp(&schema, converted_args); if (!res_symbolic) { for (auto& result_shape : result_shapes) { - result_shape = result_shape.with_symbolic_dims(std::nullopt); + result_shape = result_shape.with_symbolic_dims(c10::nullopt); } } else { TORCH_INTERNAL_ASSERT( diff --git a/torch/csrc/lazy/core/shape.h b/torch/csrc/lazy/core/shape.h index 99e4a892bc589f..63566619fd1493 100644 --- a/torch/csrc/lazy/core/shape.h +++ b/torch/csrc/lazy/core/shape.h @@ -19,7 +19,7 @@ class TORCH_API Shape { Shape( at::ScalarType scalar_type, c10::ArrayRef sizes, - std::optional> is_symbolic = std::nullopt); + std::optional> is_symbolic = c10::nullopt); std::string to_string() const; @@ -64,7 +64,7 @@ class TORCH_API Shape { // Stores which dimmensions are symbolic // If nullopt, either it hasn't been initialized or the symbolic // dimmensions are not calculatable - std::optional> is_symbolic_ = std::nullopt; + std::optional> is_symbolic_ = c10::nullopt; }; TORCH_API std::ostream& operator<<(std::ostream& out, const Shape& shape); diff --git a/torch/csrc/lazy/core/shape_inference.h b/torch/csrc/lazy/core/shape_inference.h index 76ddea597a784a..77eeaaa563187f 100644 --- a/torch/csrc/lazy/core/shape_inference.h +++ b/torch/csrc/lazy/core/shape_inference.h @@ -6,11 +6,11 @@ #include #include #include +#include #include #include #include #include -#include #include namespace torch { diff --git a/torch/csrc/lazy/core/tensor.cpp b/torch/csrc/lazy/core/tensor.cpp index 972af7dafc8baa..ba0571f87df4d3 100644 --- a/torch/csrc/lazy/core/tensor.cpp +++ b/torch/csrc/lazy/core/tensor.cpp @@ -143,13 +143,13 @@ void LazyTensor::SetDataHandle(BackendDataPtr handle, bool sync) { // trimming. AssignIrValue(Value()); if (sync) { - data()->tensor_data = std::nullopt; + data()->tensor_data = c10::nullopt; } } void LazyTensor::SetIrValue(Value ir_value) { data()->handle = nullptr; - data()->tensor_data = std::nullopt; + data()->tensor_data = c10::nullopt; AssignIrValue(std::move(ir_value)); TryLimitGraphSize(); } @@ -158,7 +158,7 @@ void LazyTensor::SetInPlaceIrValue(Value ir_value) { auto tensor_shape = shape(); if (tensor_shape.Get().scalar_type() != ir_value.shape().scalar_type()) { ir_value = - MakeCast(ir_value, tensor_shape.Get().scalar_type(), std::nullopt); + MakeCast(ir_value, tensor_shape.Get().scalar_type(), c10::nullopt); } SetIrValue(std::move(ir_value)); } @@ -253,7 +253,7 @@ at::Tensor LazyTensor::ToTensor(bool detached) { if (data()->ir_value || data()->handle != nullptr) { // If we have other authoritive sources, just drop our reference and // transfer it to the caller. - data()->tensor_data = std::nullopt; + data()->tensor_data = c10::nullopt; } else { // Otherwise we need to make a copy to prevent the caller changing our // version. diff --git a/torch/csrc/lazy/core/unique.h b/torch/csrc/lazy/core/unique.h index 3088da160860b7..fc09c8d71d7d8d 100644 --- a/torch/csrc/lazy/core/unique.h +++ b/torch/csrc/lazy/core/unique.h @@ -5,7 +5,7 @@ #pragma once -#include +#include #include #include diff --git a/torch/csrc/lazy/core/util.h b/torch/csrc/lazy/core/util.h index bfd68b73355dfc..e535e5365f2277 100644 --- a/torch/csrc/lazy/core/util.h +++ b/torch/csrc/lazy/core/util.h @@ -9,8 +9,8 @@ #include #include +#include #include -#include namespace torch { namespace lazy { @@ -114,7 +114,7 @@ std::optional> ToOptionalVector( if (arrayRef) { return arrayRef->vec(); } - return std::nullopt; + return c10::nullopt; } template diff --git a/torch/csrc/lazy/python/python_util.cpp b/torch/csrc/lazy/python/python_util.cpp index 1ae663c519f562..90d9797e3fd357 100644 --- a/torch/csrc/lazy/python/python_util.cpp +++ b/torch/csrc/lazy/python/python_util.cpp @@ -13,12 +13,12 @@ namespace lazy { std::optional GetPythonFrameTop() { if (!Py_IsInitialized()) { - return std::nullopt; + return c10::nullopt; } pybind11::gil_scoped_acquire gil; PyFrameObject* frame = PyEval_GetFrame(); if (frame == nullptr) { - return std::nullopt; + return c10::nullopt; } SourceLocation loc; auto code = THPCodeObjectPtr(PyFrame_GetCode(frame)); diff --git a/torch/csrc/lazy/python/python_util.h b/torch/csrc/lazy/python/python_util.h index 271c694ee35ddc..456aafa8809716 100644 --- a/torch/csrc/lazy/python/python_util.h +++ b/torch/csrc/lazy/python/python_util.h @@ -1,7 +1,7 @@ #pragma once +#include #include #include -#include #include namespace torch { diff --git a/torch/csrc/lazy/ts_backend/ir_builder.h b/torch/csrc/lazy/ts_backend/ir_builder.h index 9fff33135a5c87..c5382923744345 100644 --- a/torch/csrc/lazy/ts_backend/ir_builder.h +++ b/torch/csrc/lazy/ts_backend/ir_builder.h @@ -34,7 +34,7 @@ struct TorchScriptIrBuilder : IrBuilder { const Value& input0, const at::ScalarType& dtype, const std::optional& stype = - std::nullopt) const override { + c10::nullopt) const override { return ReuseOrMakeNode(input0, dtype, stype); } NodePtr MakeTensorList(const OpList& inputs) const override { diff --git a/torch/csrc/lazy/ts_backend/ts_eager_fallback.cpp b/torch/csrc/lazy/ts_backend/ts_eager_fallback.cpp index a00ec260e5a145..42acc2c5df10a2 100644 --- a/torch/csrc/lazy/ts_backend/ts_eager_fallback.cpp +++ b/torch/csrc/lazy/ts_backend/ts_eager_fallback.cpp @@ -137,7 +137,7 @@ std::optional compute_target_device( } } } - return std::nullopt; + return c10::nullopt; } } // namespace diff --git a/torch/csrc/lazy/ts_backend/ts_native_functions.cpp b/torch/csrc/lazy/ts_backend/ts_native_functions.cpp index 55d0b7f5a46543..78ae6a6f6e2e55 100644 --- a/torch/csrc/lazy/ts_backend/ts_native_functions.cpp +++ b/torch/csrc/lazy/ts_backend/ts_native_functions.cpp @@ -39,10 +39,10 @@ at::Tensor CreateLtcTensor( std::optional GetLtcDevice( const std::optional& device) { if (!device) { - return std::nullopt; + return c10::nullopt; } if (device->type() != at::kLazy) { - return std::nullopt; + return c10::nullopt; } return torch::lazy::atenDeviceToBackendDevice(*device); } @@ -235,7 +235,7 @@ at::Tensor LazyNativeFunctions::_to_copy( // captured IR, or we will try to convert an eager tensor back to a lazy one // inside the torchscript executor lazy:0 -> lazy:1 is handled in case3, so // we can safely drop the device argument - device = std::nullopt; + device = c10::nullopt; torch::lazy::NodePtr node = torch::lazy::ReuseNode( lazy_self->GetIrValue(), @@ -307,7 +307,7 @@ at::Tensor LazyNativeFunctions::empty_strided_symint( std::optional pin_memory) { TORCH_LAZY_FN_COUNTER("lazy::"); at::Tensor t = - empty_symint(sym_size, dtype, layout, device, pin_memory, std::nullopt); + empty_symint(sym_size, dtype, layout, device, pin_memory, c10::nullopt); auto size = C10_AS_INTARRAYREF_SLOW(sym_size); auto stride = C10_AS_INTARRAYREF_SLOW(sym_stride); return t.as_strided(size, stride, /*storage_offset=*/0); diff --git a/torch/csrc/profiler/collection.cpp b/torch/csrc/profiler/collection.cpp index 687e8bf28787a8..e5daea953c57dd 100644 --- a/torch/csrc/profiler/collection.cpp +++ b/torch/csrc/profiler/collection.cpp @@ -200,7 +200,7 @@ auto InputOutputEncoder::getIValueGenerator(const IOType& io_type) { if (io_type == tagToIOType(tag)) { out.emplace_back(std::move(input)); } else { - out.emplace_back(std::nullopt); + out.emplace_back(c10::nullopt); } }; @@ -223,7 +223,7 @@ auto InputOutputEncoder::getIValueGenerator(const IOType& io_type) { arg.emplace_back(decode_tensor()); } if (found_undefined) { - push_value(*tag_it, std::nullopt); + push_value(*tag_it, c10::nullopt); } else { push_value(Tag::TensorListBegin, std::move(arg)); } @@ -236,7 +236,7 @@ auto InputOutputEncoder::getIValueGenerator(const IOType& io_type) { case Tag::UndefinedTensor: case Tag::Other: - push_value(*tag_it, std::nullopt); + push_value(*tag_it, c10::nullopt); break; case Tag::TERMINATOR: diff --git a/torch/csrc/profiler/collection.h b/torch/csrc/profiler/collection.h index 71cb0c02bccc81..1c0a780370a9f0 100644 --- a/torch/csrc/profiler/collection.h +++ b/torch/csrc/profiler/collection.h @@ -91,7 +91,7 @@ using op_input_t = std::variant< TensorMetadata, std::vector, c10::IValue, - std::nullopt_t>; + c10::nullopt_t>; // ============================================================================ // == ExtraFields ============================================================= diff --git a/torch/csrc/profiler/python/init.cpp b/torch/csrc/profiler/python/init.cpp index 25f93a2663dfb5..e5cae40c84e315 100644 --- a/torch/csrc/profiler/python/init.cpp +++ b/torch/csrc/profiler/python/init.cpp @@ -458,7 +458,7 @@ void initPythonBindings(PyObject* module) { [&](const c10::IValue& v) { out.append(torch::jit::toPyObject(v)); }, - [&](const std::nullopt_t&) { out.append(py::none()); }, + [&](const c10::nullopt_t&) { out.append(py::none()); }, [&](const auto& v) { out.append(py::cast(v)); }), input); } diff --git a/torch/csrc/profiler/unwind/unwind.cpp b/torch/csrc/profiler/unwind/unwind.cpp index 8a3c4487ab7763..74d7877edadf14 100644 --- a/torch/csrc/profiler/unwind/unwind.cpp +++ b/torch/csrc/profiler/unwind/unwind.cpp @@ -290,12 +290,12 @@ std::vector unwind() { std::optional> libraryFor(void* addr) { if (!addr) { - return std::nullopt; + return c10::nullopt; } std::shared_lock lock(cache_mutex_); const LibraryInfo* library_info = unwind_cache.findLibraryFor((uint64_t)addr); if (!library_info) { - return std::nullopt; + return c10::nullopt; } return std::make_pair( library_info->name(), (uint64_t)addr - library_info->load_bias()); diff --git a/torch/csrc/profiler/unwind/unwind.h b/torch/csrc/profiler/unwind/unwind.h index bf93b88fa63dcb..1c302dfca445ff 100644 --- a/torch/csrc/profiler/unwind/unwind.h +++ b/torch/csrc/profiler/unwind/unwind.h @@ -1,7 +1,7 @@ #pragma once #include +#include #include -#include #include #include diff --git a/torch/csrc/profiler/unwind/unwind_error.h b/torch/csrc/profiler/unwind/unwind_error.h index cca8f8d12187b8..ae3630057f6a4c 100644 --- a/torch/csrc/profiler/unwind/unwind_error.h +++ b/torch/csrc/profiler/unwind/unwind_error.h @@ -1,6 +1,6 @@ #pragma once +#include #include -#include #include namespace torch::unwind { diff --git a/torch/csrc/profiler/util.h b/torch/csrc/profiler/util.h index 1a607909c45220..b06a479e70cc5f 100644 --- a/torch/csrc/profiler/util.h +++ b/torch/csrc/profiler/util.h @@ -9,10 +9,10 @@ #include #include +#include #include #include #include -#include // TODO: replace with pytorch/rfcs#43 when it is ready. #define SOFT_ASSERT(cond, ...) \ diff --git a/torch/csrc/tensor/python_tensor.cpp b/torch/csrc/tensor/python_tensor.cpp index 6960034626d568..8d18180ed91955 100644 --- a/torch/csrc/tensor/python_tensor.cpp +++ b/torch/csrc/tensor/python_tensor.cpp @@ -449,7 +449,7 @@ void py_set_default_dtype(PyObject* obj) { THPDtype_Check(obj), "invalid dtype object: only floating-point types are supported as the default type"); auto scalar_type = ((THPDtype*)obj)->scalar_type; - set_default_tensor_type(/*backend=*/std::nullopt, scalar_type); + set_default_tensor_type(/*backend=*/c10::nullopt, scalar_type); } c10::DispatchKey get_default_dispatch_key() { diff --git a/torch/csrc/utils/nested.cpp b/torch/csrc/utils/nested.cpp index 360abda078df57..29ccf312851ea1 100644 --- a/torch/csrc/utils/nested.cpp +++ b/torch/csrc/utils/nested.cpp @@ -82,7 +82,7 @@ at::Tensor nested_tensor_ctor( final_device = new_list[0].device(); } auto out = at::_nested_tensor_from_tensor_list( - new_list, final_dtype, std::nullopt, final_device, pin_memory); + new_list, final_dtype, c10::nullopt, final_device, pin_memory); out.requires_grad_(args_requires_grad); return out; } diff --git a/torch/csrc/utils/python_arg_parser.cpp b/torch/csrc/utils/python_arg_parser.cpp index a1a1638f9120be..9aa80427929df0 100644 --- a/torch/csrc/utils/python_arg_parser.cpp +++ b/torch/csrc/utils/python_arg_parser.cpp @@ -268,7 +268,7 @@ static py::object dispatch_on_subclass( bool is_torch_function, const char* torch_function_name_str, std::optional maybe_mode_key = - std::nullopt) { + c10::nullopt) { py::object ret; for (auto& arg : overloaded_args) { py::object torch_function = @@ -1005,11 +1005,11 @@ std::string FunctionParameter::type_name() const { static inline std::optional parse_as_integer(const std::string& s) { if (s.empty()) - return std::nullopt; + return c10::nullopt; char* str_end = nullptr; long ans = strtol(s.c_str(), &str_end, 0); // *str_end == 0 if the entire string was parsed as an integer. - return (*str_end == 0) ? std::optional(ans) : std::nullopt; + return (*str_end == 0) ? std::optional(ans) : c10::nullopt; } /* diff --git a/torch/csrc/utils/python_arg_parser.h b/torch/csrc/utils/python_arg_parser.h index 85a4d52bc16df8..8966131f9825f2 100644 --- a/torch/csrc/utils/python_arg_parser.h +++ b/torch/csrc/utils/python_arg_parser.h @@ -399,7 +399,7 @@ inline std::optional PythonArgs::optionalTensor(int i) { if (t.defined()) { return t; } else { - return std::nullopt; + return c10::nullopt; } } @@ -435,7 +435,7 @@ inline at::Scalar PythonArgs::scalarWithDefault( inline std::optional PythonArgs::scalarOptional(int i) { if (!args[i]) - return std::nullopt; + return c10::nullopt; return scalar_slow(i); } @@ -771,7 +771,7 @@ inline at::ScalarType PythonArgs::scalartype(int i) { inline std::optional PythonArgs::scalartypeOptional(int i) { if (!args[i]) - return std::nullopt; + return c10::nullopt; return scalartype(i); } @@ -796,7 +796,7 @@ inline at::Layout PythonArgs::layoutWithDefault( inline std::optional PythonArgs::layoutOptional(int i) { if (!args[i]) - return std::nullopt; + return c10::nullopt; return layout(i); } @@ -837,7 +837,7 @@ inline at::Device PythonArgs::deviceWithDefault( inline std::optional PythonArgs::deviceOptional(int i) { if (!args[i]) - return std::nullopt; + return c10::nullopt; return device(i); } @@ -863,7 +863,7 @@ inline std::vector parseDimnameList(PyObject* arg) { inline std::optional> PythonArgs:: toDimnameListOptional(int i) { if (!args[i]) - return std::nullopt; + return c10::nullopt; return parseDimnameList(args[i]); } @@ -890,7 +890,7 @@ inline at::MemoryFormat PythonArgs::memoryformat(int i) { inline std::optional PythonArgs::memoryformatOptional(int i) { if (!args[i]) - return std::nullopt; + return c10::nullopt; return memoryformat(i); } @@ -918,7 +918,7 @@ inline std::string PythonArgs::stringWithDefault( inline std::optional PythonArgs::stringOptional(int i) { if (!args[i]) - return std::nullopt; + return c10::nullopt; return THPUtils_unpackString(args[i]); } @@ -936,7 +936,7 @@ inline c10::string_view PythonArgs::stringViewWithDefault( inline std::optional PythonArgs::stringViewOptional(int i) { if (!args[i]) - return std::nullopt; + return c10::nullopt; return THPUtils_unpackStringView(args[i]); } @@ -990,26 +990,26 @@ inline int64_t PythonArgs::toInt64WithDefault(int i, int64_t default_int) { inline std::optional PythonArgs::toInt64Optional(int i) { if (!args[i]) - return std::nullopt; + return c10::nullopt; return toInt64(i); } inline std::optional PythonArgs::toSymIntOptional(int i) { if (!args[i]) - return std::nullopt; + return c10::nullopt; return toSymInt(i); } inline std::optional PythonArgs::toBoolOptional(int i) { if (!args[i]) { - return std::nullopt; + return c10::nullopt; } return toBool(i); } inline std::optional PythonArgs::toDoubleOptional(int i) { if (!args[i]) { - return std::nullopt; + return c10::nullopt; } return toDouble(i); } @@ -1071,7 +1071,7 @@ inline bool PythonArgs::isNone(int i) { inline std::optional PythonArgs::generator(int i) { if (!args[i]) - return std::nullopt; + return c10::nullopt; return reinterpret_cast(args[i])->cdata; } diff --git a/torch/csrc/utils/python_dispatch.cpp b/torch/csrc/utils/python_dispatch.cpp index 2d18978018a1b4..ec0af99842d2e5 100644 --- a/torch/csrc/utils/python_dispatch.cpp +++ b/torch/csrc/utils/python_dispatch.cpp @@ -65,8 +65,8 @@ static c10::AliasAnalysisKind parseAliasAnalysisKind(const std::string& k) { template inline torch::CppFunction dispatch_str(const char* key, Func&& raw_f) { auto mb_key = std::string(key).empty() - ? std::nullopt - : std::make_optional(c10::parseDispatchKey(key)); + ? c10::nullopt + : c10::make_optional(c10::parseDispatchKey(key)); if (mb_key) { return torch::dispatch(*mb_key, std::forward(raw_f)); } else { @@ -217,7 +217,7 @@ static py::object ophandle_call_boxed( handle.schema(), std::move(args), kwargs, - /*self=*/std::nullopt); + /*self=*/c10::nullopt); { pybind11::gil_scoped_release no_gil_guard; handle.callBoxed(stack); @@ -264,7 +264,7 @@ void initDispatchBindings(PyObject* module) { handle.schema(), std::move(args), kwargs, - /*self=*/std::nullopt); + /*self=*/c10::nullopt); { pybind11::gil_scoped_release no_gil_guard; handle.redispatchBoxed(keyset, &stack); @@ -477,8 +477,8 @@ void initDispatchBindings(PyObject* module) { parseKind(kind), std::move(name), std::string(dispatch).empty() - ? std::nullopt - : std::make_optional(c10::parseDispatchKey(dispatch)), + ? c10::nullopt + : c10::make_optional(c10::parseDispatchKey(dispatch)), "/dev/null", // temporary workaround linenum); END_HANDLE_TH_ERRORS_PYBIND @@ -814,8 +814,8 @@ void initDispatchBindings(PyObject* module) { "_dispatch_print_registrations_for_dispatch_key", [](const char* dispatch_key = "") { auto k = std::string(dispatch_key).empty() - ? std::nullopt - : std::make_optional(c10::parseDispatchKey(dispatch_key)); + ? c10::nullopt + : c10::make_optional(c10::parseDispatchKey(dispatch_key)); auto op_names = c10::Dispatcher::singleton().getRegistrationsForDispatchKey(k); for (auto& op : op_names) { @@ -830,7 +830,7 @@ void initDispatchBindings(PyObject* module) { try { return c10::parseDispatchKey(dispatch_key); } catch (const c10::Error& err) { - return std::nullopt; + return c10::nullopt; } }); @@ -838,8 +838,8 @@ void initDispatchBindings(PyObject* module) { "_dispatch_get_registrations_for_dispatch_key", [](const char* dispatch_key = "") { auto k = std::string(dispatch_key).empty() - ? std::nullopt - : std::make_optional(c10::parseDispatchKey(dispatch_key)); + ? c10::nullopt + : c10::make_optional(c10::parseDispatchKey(dispatch_key)); auto op_names = c10::Dispatcher::singleton().getRegistrationsForDispatchKey(k); std::vector names; @@ -888,7 +888,7 @@ void initDispatchBindings(PyObject* module) { "Expected device_type string to not have a device index; got ", device_type); return c10::toString( - c10::computeDispatchKey(std::nullopt, std::nullopt, device)); + c10::computeDispatchKey(c10::nullopt, c10::nullopt, device)); }); m.def("_are_functorch_transforms_active", []() { diff --git a/torch/csrc/utils/python_raii.h b/torch/csrc/utils/python_raii.h index af63d1efba5458..bc7b5c263e0d91 100644 --- a/torch/csrc/utils/python_raii.h +++ b/torch/csrc/utils/python_raii.h @@ -1,5 +1,5 @@ +#include #include -#include #include namespace torch::impl { @@ -17,7 +17,7 @@ struct RAIIContextManager { } void exit() { - guard_ = std::nullopt; + guard_ = c10::nullopt; } private: @@ -50,7 +50,7 @@ struct DeprecatedRAIIContextManager { void enter() {} void exit() { - guard_ = std::nullopt; + guard_ = c10::nullopt; } private: diff --git a/torch/csrc/utils/python_symnode.h b/torch/csrc/utils/python_symnode.h index e82c30a8c98f75..15738b1a67e16c 100644 --- a/torch/csrc/utils/python_symnode.h +++ b/torch/csrc/utils/python_symnode.h @@ -144,7 +144,7 @@ class PythonSymNodeImpl : public c10::SymNodeImpl { py::gil_scoped_acquire acquire; const auto& r = getPyObj().attr("maybe_as_int")(); if (r.is_none()) { - return std::nullopt; + return c10::nullopt; } else { return r.cast(); } diff --git a/torch/csrc/utils/schema_info.cpp b/torch/csrc/utils/schema_info.cpp index 61eecc7cf0079e..0caa5b254d279f 100644 --- a/torch/csrc/utils/schema_info.cpp +++ b/torch/csrc/utils/schema_info.cpp @@ -8,7 +8,7 @@ void SchemaInfo::addArgumentValue( const at::IValue& value) { std::optional index = schema_.argumentIndexWithName(name); TORCH_INTERNAL_ASSERT( - index != std::nullopt, "Schema has no argument named ", name); + index != c10::nullopt, "Schema has no argument named ", name); value_map_[name] = value; alias_maps_current_ = false; } @@ -102,7 +102,7 @@ bool SchemaInfo::is_mutable(const c10::SchemaArgument& argument) { } bool SchemaInfo::has_argument(c10::string_view name) { - return schema_.argumentIndexWithName(name) != std::nullopt; + return schema_.argumentIndexWithName(name) != c10::nullopt; } bool SchemaInfo::is_mutable(c10::string_view name) { diff --git a/torch/csrc/utils/tensor_new.cpp b/torch/csrc/utils/tensor_new.cpp index e66c99bc4d4939..4fd398d1a8fafd 100644 --- a/torch/csrc/utils/tensor_new.cpp +++ b/torch/csrc/utils/tensor_new.cpp @@ -28,8 +28,8 @@ #include #include #include +#include #include -#include #include #include @@ -53,7 +53,7 @@ thread_local bool kOnlyLiftCPUTensors = false; TensorOptions build_options( c10::TensorOptions options, at::ScalarType scalar_type, - const std::optional& device = std::nullopt) { + const std::optional& device = c10::nullopt) { options = options.dtype(scalar_type); if (device.has_value()) { return options.device(device); @@ -1257,7 +1257,7 @@ void _validate_sparse_coo_tensor_args( Tensor values = internal_new_from_data( options, scalar_type, - std::nullopt, + c10::nullopt, r.pyobject(1), /*copy_variables=*/false, /*copy_numpy=*/true, @@ -1266,7 +1266,7 @@ void _validate_sparse_coo_tensor_args( Tensor indices = internal_new_from_data( values.options(), kLong, - std::nullopt, + c10::nullopt, r.pyobject(0), /*copy_variables=*/false, /*copy_numpy=*/true, @@ -1298,7 +1298,7 @@ void _validate_sparse_compressed_tensor_args( Tensor values = internal_new_from_data( options, scalar_type, - std::nullopt, + c10::nullopt, r.pyobject(ARG_VALUES), /*copy_variables=*/false, /*copy_numpy=*/true, @@ -1307,7 +1307,7 @@ void _validate_sparse_compressed_tensor_args( Tensor compressed_indices = internal_new_from_data( values.options(), kInt, - std::nullopt, + c10::nullopt, r.pyobject(ARG_COMPRESSED_INDICES), /*copy_variables=*/false, /*copy_numpy=*/true, @@ -1315,7 +1315,7 @@ void _validate_sparse_compressed_tensor_args( Tensor plain_indices = internal_new_from_data( values.options(), kInt, - std::nullopt, + c10::nullopt, r.pyobject(ARG_PLAIN_INDICES), /*copy_variables=*/false, /*copy_numpy=*/true, @@ -1369,7 +1369,7 @@ void _validate_sparse_compressed_tensor_args_template( Tensor values = internal_new_from_data( options, scalar_type, - std::nullopt, + c10::nullopt, r.pyobject(ARG_VALUES), /*copy_variables=*/false, /*copy_numpy=*/true, @@ -1378,7 +1378,7 @@ void _validate_sparse_compressed_tensor_args_template( Tensor compressed_indices = internal_new_from_data( values.options(), kInt, - std::nullopt, + c10::nullopt, r.pyobject(ARG_COMPRESSED_INDICES), /*copy_variables=*/false, /*copy_numpy=*/true, @@ -1386,7 +1386,7 @@ void _validate_sparse_compressed_tensor_args_template( Tensor plain_indices = internal_new_from_data( values.options(), kInt, - std::nullopt, + c10::nullopt, r.pyobject(ARG_PLAIN_INDICES), /*copy_variables=*/false, /*copy_numpy=*/true, diff --git a/torch/csrc/utils/torch_dispatch_mode.h b/torch/csrc/utils/torch_dispatch_mode.h index 2eb8ba7a1cbbbb..8ca4511435737e 100644 --- a/torch/csrc/utils/torch_dispatch_mode.h +++ b/torch/csrc/utils/torch_dispatch_mode.h @@ -19,7 +19,7 @@ struct StashTorchDispatchModeGuard { } ~StashTorchDispatchModeGuard() { - if (saved_mode_key_ != std::nullopt) { + if (saved_mode_key_ != c10::nullopt) { c10::impl::TorchDispatchModeTLS::set_mode( saved_mode_, saved_mode_key_.value()); } else { diff --git a/torch/custom_class_detail.h b/torch/custom_class_detail.h index 135c49ac76a925..e27721c349864c 100644 --- a/torch/custom_class_detail.h +++ b/torch/custom_class_detail.h @@ -47,7 +47,7 @@ struct arg { // Explicit constructor. explicit arg(std::string name) - : name_(std::move(name)), value_(std::nullopt) {} + : name_(std::move(name)), value_(c10::nullopt) {} // Assignment operator. This enables the pybind-like syntax of // torch::arg("name") = value. arg& operator=(const c10::IValue& rhs) { diff --git a/torch/library.h b/torch/library.h index d75e6b01982120..c860f4c2034444 100644 --- a/torch/library.h +++ b/torch/library.h @@ -215,7 +215,7 @@ class TORCH_API CppFunction final { static CppFunction makeFromBoxedKernel(c10::BoxedKernel kernel) { return CppFunction( c10::KernelFunction::makeFromBoxedKernel(std::move(kernel)), - /* cpp_signature */ std::nullopt, // not known for boxed functions + /* cpp_signature */ c10::nullopt, // not known for boxed functions /* schema */ nullptr); } @@ -337,7 +337,7 @@ template inline CppFunction dispatch(c10::DispatchKey k, Func&& raw_f) { CppFunction f(std::forward(raw_f)); if (k == c10::DispatchKey::CatchAll) { - f.dispatch_key_ = std::nullopt; + f.dispatch_key_ = c10::nullopt; } else { f.dispatch_key_ = k; } @@ -930,7 +930,7 @@ class TorchLibraryInit final { torch::Library::DEF, \ &TORCH_LIBRARY_init_##ns, \ #ns, \ - std::nullopt, \ + c10::nullopt, \ __FILE__, \ __LINE__); \ void TORCH_LIBRARY_init_##ns(torch::Library& m) @@ -960,7 +960,7 @@ class TorchLibraryInit final { torch::Library::FRAGMENT, \ &C10_CONCATENATE(TORCH_LIBRARY_FRAGMENT_init_##ns##_, uid), \ #ns, \ - std::nullopt, \ + c10::nullopt, \ __FILE__, \ __LINE__); \ void C10_CONCATENATE( \ @@ -1024,7 +1024,7 @@ class TorchLibraryInit final { ? &C10_CONCATENATE(TORCH_LIBRARY_IMPL_init_##ns##_##k##_, uid) \ : [](torch::Library&) -> void {}), \ #ns, \ - std::make_optional(c10::DispatchKey::k), \ + c10::make_optional(c10::DispatchKey::k), \ __FILE__, \ __LINE__); \ void C10_CONCATENATE( \ @@ -1039,13 +1039,13 @@ class TorchLibraryInit final { /// \private #define MAKE_TORCH_LIBRARY(ns) \ - torch::Library(torch::Library::DEF, #ns, std::nullopt, __FILE__, __LINE__) + torch::Library(torch::Library::DEF, #ns, c10::nullopt, __FILE__, __LINE__) /// \private #define MAKE_TORCH_LIBRARY_IMPL(ns, k) \ torch::Library( \ torch::Library::IMPL, \ #ns, \ - std::make_optional(c10::DispatchKey::k), \ + c10::make_optional(c10::DispatchKey::k), \ __FILE__, \ __LINE__) From 2d01f877373d3255fa8b77c714c2ca17d08e6126 Mon Sep 17 00:00:00 2001 From: vasiliy Date: Sat, 15 Jun 2024 02:05:27 +0000 Subject: [PATCH 047/171] Enable torch.empty for float8 dtypes + deterministic mode + cpu (#128744) Summary: Enables creating empty float8 tensors for: * cuda when `torch.use_deterministic_algorithms` is set to True * cpu for all settings of `torch.use_deterministic_algorithms` Context for NaN values of float8_e4m3fn and float8_e5m2: https://arxiv.org/pdf/2209.05433, Section 3, Table 1 Context for NaN values of float8_e4m3fnuz and float8_e5m2fnuz: https://arxiv.org/pdf/2206.02915, Section 3.2, "instead of reserving one exponent field to represent Inf and NaN, we reserve only a single codeword (corresponding to negative zero)" Test Plan: ``` python test/test_quantization.py -k test_empty ``` Reviewers: Subscribers: Tasks: Tags: Fixes https://github.com/pytorch/pytorch/issues/128733 Pull Request resolved: https://github.com/pytorch/pytorch/pull/128744 Approved by: https://github.com/malfet, https://github.com/drisspg --- aten/src/ATen/native/TensorFactories.h | 6 +++--- aten/src/ATen/native/cpu/FillKernel.cpp | 8 ++++++++ c10/util/Float8_e5m2-inl.h | 5 ++++- c10/util/Float8_e5m2fnuz-inl.h | 5 +++++ test/quantization/core/experimental/test_float8.py | 9 +++++++++ 5 files changed, 29 insertions(+), 4 deletions(-) diff --git a/aten/src/ATen/native/TensorFactories.h b/aten/src/ATen/native/TensorFactories.h index 58cbbfc4df3347..d7a4d6483f6ed8 100644 --- a/aten/src/ATen/native/TensorFactories.h +++ b/aten/src/ATen/native/TensorFactories.h @@ -103,10 +103,10 @@ inline void check_supported_max_int_with_precision(int64_t n, const Tensor& tens // with max value if it is integer type inline Tensor& fill_empty_deterministic_(Tensor& tensor) { if (tensor.is_floating_point() || tensor.is_complex()) { - AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2( - kBFloat16, kHalf, tensor.scalar_type(), "fill_empty_deterministic_", [&]() { + AT_DISPATCH_V2( + tensor.scalar_type(), "fill_empty_deterministic_", AT_WRAP([&]() { tensor.fill_(std::numeric_limits::quiet_NaN()); - }); + }), AT_EXPAND(AT_FLOATING_TYPES), AT_EXPAND(AT_COMPLEX_TYPES), AT_EXPAND(AT_FLOAT8_TYPES), kBFloat16, kHalf); } else { AT_DISPATCH_V2( tensor.scalar_type(), "fill_empty_deterministic_", AT_WRAP([&]() { diff --git a/aten/src/ATen/native/cpu/FillKernel.cpp b/aten/src/ATen/native/cpu/FillKernel.cpp index a24de0b48f9ef8..43a562306e3414 100644 --- a/aten/src/ATen/native/cpu/FillKernel.cpp +++ b/aten/src/ATen/native/cpu/FillKernel.cpp @@ -43,6 +43,14 @@ void fill_kernel(TensorIterator& iter, const Scalar& value_scalar) { fill_non_native_type(iter, value_scalar); } else if (iter.dtype() == ScalarType::ComplexHalf) { fill_non_native_type>(iter, value_scalar); + } else if (iter.dtype() == ScalarType::Float8_e4m3fn) { + fill_non_native_type(iter, value_scalar); + } else if (iter.dtype() == ScalarType::Float8_e5m2) { + fill_non_native_type(iter, value_scalar); + } else if (iter.dtype() == ScalarType::Float8_e4m3fnuz) { + fill_non_native_type(iter, value_scalar); + } else if (iter.dtype() == ScalarType::Float8_e5m2fnuz) { + fill_non_native_type(iter, value_scalar); } else { AT_DISPATCH_V2( iter.dtype(), "fill_cpu", AT_WRAP([&]() { diff --git a/c10/util/Float8_e5m2-inl.h b/c10/util/Float8_e5m2-inl.h index 7800ceb29924a7..5a5c1a5fc9b5b5 100644 --- a/c10/util/Float8_e5m2-inl.h +++ b/c10/util/Float8_e5m2-inl.h @@ -235,7 +235,7 @@ class numeric_limits { static constexpr bool is_specialized = true; static constexpr bool is_exact = false; static constexpr bool has_infinity = true; - static constexpr bool has_quiet_NaN = false; + static constexpr bool has_quiet_NaN = true; static constexpr bool has_signaling_NaN = false; static constexpr auto has_denorm = true; static constexpr auto has_denorm_loss = true; @@ -273,6 +273,9 @@ class numeric_limits { static constexpr c10::Float8_e5m2 infinity() { return c10::Float8_e5m2(0x7C, c10::Float8_e5m2::from_bits()); } + static constexpr c10::Float8_e5m2 quiet_NaN() { + return c10::Float8_e5m2(0x7F, c10::Float8_e5m2::from_bits()); + } static constexpr c10::Float8_e5m2 denorm_min() { return c10::Float8_e5m2(0x01, c10::Float8_e5m2::from_bits()); } diff --git a/c10/util/Float8_e5m2fnuz-inl.h b/c10/util/Float8_e5m2fnuz-inl.h index 3af233a87b8444..d81054cbee351a 100644 --- a/c10/util/Float8_e5m2fnuz-inl.h +++ b/c10/util/Float8_e5m2fnuz-inl.h @@ -270,6 +270,11 @@ class numeric_limits { static constexpr c10::Float8_e5m2fnuz infinity() { return c10::Float8_e5m2fnuz(0x80, c10::Float8_e5m2fnuz::from_bits()); } + // TODO(future): we are mapping neg_zero to both inf and NaN, this is + // surprising and we should figure out what to do about it. + static constexpr c10::Float8_e5m2fnuz quiet_NaN() { + return c10::Float8_e5m2fnuz(0x80, c10::Float8_e5m2fnuz::from_bits()); + } static constexpr c10::Float8_e5m2fnuz denorm_min() { return c10::Float8_e5m2fnuz(0x01, c10::Float8_e5m2fnuz::from_bits()); } diff --git a/test/quantization/core/experimental/test_float8.py b/test/quantization/core/experimental/test_float8.py index 1f735f29e322b9..feb14e2cbad8b1 100644 --- a/test/quantization/core/experimental/test_float8.py +++ b/test/quantization/core/experimental/test_float8.py @@ -9,6 +9,7 @@ instantiate_device_type_tests, ) from torch.testing._internal.common_utils import ( + DeterministicGuard, IS_WINDOWS, parametrize, run_tests, @@ -259,6 +260,14 @@ def test_type_promotion_fails(self, dtype, device): ): x + y + @dtypes(*FLOAT8_DTYPES) + @dtypesIfCUDA(*CUDA_FLOAT8_DTYPES) + def test_empty(self, dtype, device): + with DeterministicGuard(torch.are_deterministic_algorithms_enabled()): + for use_deterministic in (True, False): + torch.use_deterministic_algorithms(use_deterministic) + x = torch.empty(4, 4, device=device, dtype=dtype) + instantiate_device_type_tests(TestFloat8Dtype, globals()) From 62a0e39ced99828a59f410ff79e9f8f61c74c928 Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Fri, 14 Jun 2024 16:29:28 -0700 Subject: [PATCH 048/171] [dynamo][inlining-nn-modules] Update tests with new expected counts (#128463) Pull Request resolved: https://github.com/pytorch/pytorch/pull/128463 Approved by: https://github.com/yanboliang --- .../test_dynamo_with_onnxruntime_backend.py | 21 ++++++++++++++++--- 1 file changed, 18 insertions(+), 3 deletions(-) diff --git a/test/onnx/dynamo/test_dynamo_with_onnxruntime_backend.py b/test/onnx/dynamo/test_dynamo_with_onnxruntime_backend.py index 0c7a141d6a7ac3..13dbff3458d9e6 100644 --- a/test/onnx/dynamo/test_dynamo_with_onnxruntime_backend.py +++ b/test/onnx/dynamo/test_dynamo_with_onnxruntime_backend.py @@ -471,7 +471,12 @@ def generate_example_inputs(batch: int, seq: int, hidden_size: int): if test_local_backend: assert local_ort is not None - number_of_captured_graphs = 2 if test_backward else 1 + if torch._dynamo.config.inline_inbuilt_nn_modules: + # with inlining and dynamic=True, we have more graph captures + number_of_captured_graphs = 3 if test_backward else 2 + else: + number_of_captured_graphs = 2 if test_backward else 1 + execution_count = len(example_args_collection) * number_of_captured_graphs self._assert_counting_information( local_ort, @@ -564,8 +569,14 @@ def generate_example_inputs(batch: int, seq: int, hidden_size: int): if test_local_backend: assert local_ort is not None - number_of_captured_graphs = 2 if test_backward else 1 + if torch._dynamo.config.inline_inbuilt_nn_modules: + # with inlining and dynamic=True, we have more graph captures + number_of_captured_graphs = 3 if test_backward else 2 + else: + number_of_captured_graphs = 2 if test_backward else 1 + execution_count = len(example_args_collection) * number_of_captured_graphs + self._assert_counting_information( local_ort, expected_execution_count=execution_count, @@ -649,7 +660,11 @@ def generate_example_inputs(batch: int, seq: int): if test_local_backend: assert local_ort is not None - number_of_captured_graphs = 2 if test_backward else 1 + if torch._dynamo.config.inline_inbuilt_nn_modules: + # with inlining and dynamic=True, we have more graph captures + number_of_captured_graphs = 3 if test_backward else 2 + else: + number_of_captured_graphs = 2 if test_backward else 1 execution_count = len(example_args_collection) * number_of_captured_graphs self._assert_counting_information( local_ort, From 7e092a62e6192f3b7daa91f2960ecea168220fd8 Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Fri, 14 Jun 2024 13:26:00 -0700 Subject: [PATCH 049/171] [dynamo] Support weakref objects (#128533) Fixes https://github.com/pytorch/pytorch/issues/125720 I was earlier worried that DELETE_* or STORE_* on referent values should result in a graph break, because they could invalidate the weak ref. But then @zou3519 pointed out that weakref invalidation will happen EVENTUALLY, CPython provides no guarantees when the weakref will be invalidated (even when the user calls del x and x is the last reference). So any code that relies on del x to invalidate the weakref of x right away is BAD code. CPython provide no guarantees. Therefore we can (ab)use this nuance, and can just ignore DELETE_* or STORE_* on the referent objects. The only corner case is when Dynamo is reconstructing the weakref object. Dynamo will have a hard time being correct here, so just SKIP_FRAME on such a case. This is rare. Cpython notes 1) https://docs.python.org/3/library/weakref.html 2) https://docs.python.org/3/reference/datamodel.html#index-2 Pull Request resolved: https://github.com/pytorch/pytorch/pull/128533 Approved by: https://github.com/jansel --- test/dynamo/test_repros.py | 61 +++++++++++++++++++++ torch/_dynamo/guards.py | 8 +++ torch/_dynamo/source.py | 13 +++++ torch/_dynamo/variables/__init__.py | 1 + torch/_dynamo/variables/builder.py | 15 ++++++ torch/_dynamo/variables/user_defined.py | 31 ++++++++++- torch/csrc/dynamo/guards.cpp | 71 +++++++++++++++++++++++++ 7 files changed, 199 insertions(+), 1 deletion(-) diff --git a/test/dynamo/test_repros.py b/test/dynamo/test_repros.py index c30210a398407b..c53f8a49d6a64f 100644 --- a/test/dynamo/test_repros.py +++ b/test/dynamo/test_repros.py @@ -8,6 +8,7 @@ import contextlib import copy import functools +import gc import inspect import itertools import random @@ -4720,6 +4721,66 @@ def fn(instances): self.assertEqual(type(actual), type(expected)) self.assertEqual(actual.__dict__, expected.__dict__) + def test_weakref(self): + def fn(x_weak, weight, y): + if x_weak is not None and x_weak() is not weight: + return torch.sin(y) + return torch.cos(y) + + weight = torch.randn(4) + y = torch.randn(4) + x_weak = weakref.ref(weight) + + ref = fn(x_weak, weight, y) + + opt_fn = torch.compile(fn, backend="eager", fullgraph=True) + res = opt_fn(x_weak, weight, y) + self.assertEqual(ref, res) + + def test_weakref_reconstruct(self): + def fn(x_weak, weight, y): + y = torch.sin(y) + referent = x_weak() + torch._dynamo.graph_break() + if referent is not weight: + return torch.sin(y) + return torch.cos(y) + + weight = torch.randn(4) + y = torch.randn(4) + x_weak = weakref.ref(weight) + + ref = fn(x_weak, weight, y) + + cnt = torch._dynamo.testing.CompileCounter() + opt_fn = torch.compile(fn, backend=cnt) + res = opt_fn(x_weak, weight, y) + self.assertEqual(ref, res) + self.assertEqual(cnt.frame_count, 2) + + def test_weakref_del(self): + def fn(x_weak, y): + x = x_weak() + if x is not None: + return torch.sin(y) + return torch.cos(y) + + weight = torch.randn(4) + x_weak = weakref.ref(weight) + y = torch.randn(4) + + opt_fn = torch.compile(fn, backend="eager", fullgraph=True) + + ref = fn(x_weak, y) + res = opt_fn(x_weak, y) + self.assertEqual(ref, res) + + del weight + gc.collect() + ref = fn(x_weak, y) + res = opt_fn(x_weak, y) + self.assertEqual(ref, res) + def test_storage_resize_forward_full_graph(self): class TestModule(torch.nn.Module): def __init__(self): diff --git a/torch/_dynamo/guards.py b/torch/_dynamo/guards.py index fc3f12847a756a..0c882ad16fcffe 100644 --- a/torch/_dynamo/guards.py +++ b/torch/_dynamo/guards.py @@ -91,6 +91,7 @@ ShapeEnvSource, TupleIteratorGetItemSource, TypeSource, + WeakRefCallSource, ) from .types import CacheEntry, ExtraState, GuardedCode, GuardFail, GuardFn # noqa: F401 from .utils import ( @@ -1006,6 +1007,13 @@ def get_guard_manager_from_source(self, source): example_value=example_value, guard_manager_enum=guard_manager_enum, ) + elif isinstance(source, WeakRefCallSource): + assert base_guard_manager # to make mypy happy + out = base_guard_manager.weakref_call_manager( + source=source_name, + example_value=example_value, + guard_manager_enum=guard_manager_enum, + ) else: raise AssertionError( f"missing guard manager builder {source} - {source.name()}" diff --git a/torch/_dynamo/source.py b/torch/_dynamo/source.py index 69423712c53c9c..3afc91a944cdfc 100644 --- a/torch/_dynamo/source.py +++ b/torch/_dynamo/source.py @@ -142,6 +142,19 @@ def name(self): return f"G[{repr(self.global_name)}]()" +@dataclasses.dataclass(frozen=True) +class WeakRefCallSource(ChainedSource): + def reconstruct(self, codegen): + self.base.reconstruct(codegen) + codegen.extend_output(create_call_function(0, True)) + + def guard_source(self): + return self.base.guard_source() + + def name(self): + return f"{self.base.name()}()" + + @dataclasses.dataclass(frozen=True) class AttrSource(ChainedSource): member: str diff --git a/torch/_dynamo/variables/__init__.py b/torch/_dynamo/variables/__init__.py index 9ffdd64fbc962e..152698568c7c67 100644 --- a/torch/_dynamo/variables/__init__.py +++ b/torch/_dynamo/variables/__init__.py @@ -96,6 +96,7 @@ RemovableHandleVariable, UserDefinedClassVariable, UserDefinedObjectVariable, + WeakRefVariable, ) __all__ = [ diff --git a/torch/_dynamo/variables/builder.py b/torch/_dynamo/variables/builder.py index 2f7b50d36c26e1..8fe9154c53a644 100644 --- a/torch/_dynamo/variables/builder.py +++ b/torch/_dynamo/variables/builder.py @@ -14,6 +14,7 @@ import re import sys import types +import weakref from typing import Any, List, NamedTuple, Optional, Union from torch.utils._sympy.value_ranges import ValueRanges @@ -184,6 +185,7 @@ SourcelessGraphModuleVariable, UserDefinedClassVariable, UserDefinedObjectVariable, + WeakRefVariable, ) @@ -383,6 +385,8 @@ def _type_dispatch(cls): ((slice, range), cls.wrap_slice_range), (tuple(common_constant_types), cls.wrap_literal), (re.Pattern, cls.wrap_regex_pattern), + (weakref.ReferenceType, cls.wrap_weakref), + (torch.utils.hooks.RemovableHandle, cls.wrap_removable_handle), ] if config.trace_numpy and np: @@ -401,6 +405,17 @@ def wrap_regex_pattern(self, value: re.Pattern): self.install_guards(GuardBuilder.ID_MATCH) return RegexPatternVariable(value) + def wrap_weakref(self, value: weakref.ReferenceType): + self.install_guards(GuardBuilder.TYPE_MATCH) + return WeakRefVariable(value, source=self.source) + + def wrap_removable_handle(self, value): + # This means that the removable handle was created in some other frame. + # Our current infra requires the hook to be registered and removed in + # the same frame. So graph break. + # Related test - PYTORCH_TEST_WITH_DYNAMO=1 python test/test_autograd.py -k TestAutograd.test_hooks + unimplemented("unregistered hook removable handle") + @classmethod @functools.lru_cache(None) def _id_dispatch(cls): diff --git a/torch/_dynamo/variables/user_defined.py b/torch/_dynamo/variables/user_defined.py index 6c6d3182b66031..2b97d921b73b11 100644 --- a/torch/_dynamo/variables/user_defined.py +++ b/torch/_dynamo/variables/user_defined.py @@ -37,7 +37,13 @@ from ..create_parameter_op import do_not_convert_to_tracable_parameter from ..exc import ObservedException, unimplemented from ..guards import GuardBuilder, install_guard -from ..source import AttrSource, GetItemSource, ODictGetItemSource, RandomValueSource +from ..source import ( + AttrSource, + GetItemSource, + ODictGetItemSource, + RandomValueSource, + WeakRefCallSource, +) from ..utils import ( all_hook_names, build_checkpoint_variable, @@ -1078,6 +1084,29 @@ def call_method( ) +class WeakRefVariable(UserDefinedObjectVariable): + _nonvar_fields = UserDefinedObjectVariable._nonvar_fields + + def __init__(self, value, **kwargs): + super().__init__(value, **kwargs) + + def call_function( + self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]" + ) -> "VariableTracker": + call_source = None + referent = self.value() + + if self.source: + from .builder import VariableBuilder + + call_source = WeakRefCallSource(self.source) + return VariableBuilder(tx, call_source)(referent) + else: + from .builder import SourcelessBuilder + + return SourcelessBuilder.create(tx, referent) + + class KeyedJaggedTensorVariable(UserDefinedObjectVariable): @staticmethod def is_matching_object(obj): diff --git a/torch/csrc/dynamo/guards.cpp b/torch/csrc/dynamo/guards.cpp index d2eb41f51115df..a9359f09b27c27 100644 --- a/torch/csrc/dynamo/guards.cpp +++ b/torch/csrc/dynamo/guards.cpp @@ -3173,6 +3173,51 @@ class GlobalWeakRefGuardAccessor : public GuardAccessor { PyObject* _global_name; }; +/** + * Implements weakref call - x_weak() + */ +class WeakRefCallGuardAccessor : public GuardAccessor { + public: + WeakRefCallGuardAccessor( + RootGuardManager* root, + py::str name, + std::string source, + py::handle example_value, + py::handle guard_manager_enum) + : GuardAccessor( + root, + std::move(name), + std::move(source), + example_value, + guard_manager_enum) {} + + // NB: Intentional duplication between check_nopybind and + // check_verbose_nopybind. + bool check_nopybind(PyObject* obj) override { // borrowed ref + if (!PyWeakref_Check(obj)) { + return false; + } + + PyObject* x = PyWeakref_GetObject(obj); // borrowed ref + return _guard_manager->check_nopybind(x); + } + + GuardDebugInfo check_verbose_nopybind( + PyObject* obj) override { // borrowed ref + if (!PyWeakref_Check(obj)) { + return GuardDebugInfo( + false, std::string("Not a weakref obj ") + get_source(), 0); + } + + PyObject* x = PyWeakref_GetObject(obj); // borrowed ref + return _guard_manager->check_verbose_nopybind(x); + } + + std::string repr() const override { + return "WeakRefCallGuardAccessor()"; + } +}; + /** * Similar to PythonLambdaLeafGuard, this class is a way to allow developers to * supply accessor as a python function. This is useful for from_numpy source. @@ -3500,6 +3545,12 @@ PyObject* torch_c_dynamo_guards_init() { GuardAccessor, std::unique_ptr>(py_m, "TypeGuardAccessor"); // NOLINTNEXTLINE(bugprone-unused-raii) + py::class_< + WeakRefCallGuardAccessor, + GuardAccessor, + std::unique_ptr>( + py_m, "WeakRefCallGuardAccessor"); + // NOLINTNEXTLINE(bugprone-unused-raii) py::class_< TupleIteratorGetItemAccessor, GuardAccessor, @@ -3790,6 +3841,26 @@ PyObject* torch_c_dynamo_guards_init() { py::return_value_policy::reference) // return by reference because GuardManager has the ownership of accessors // and guard managers + .def( + "weakref_call_manager", + [](GuardManager& self, + std::string source, + py::handle example_value, + py::handle guard_manager_enum) -> GuardManager* { + // A unique key is used to save as the accessor key. + py::str unique_key("__weakref_call_accessor__"); + return self.get_child_manager( + std::move(unique_key), + std::move(source), + example_value, + guard_manager_enum); + }, + py::arg("source"), + py::arg("example_value"), + py::arg("guard_manager_enum"), + py::return_value_policy::reference) + // return by reference because GuardManager has the ownership of accessors + // and guard managers .def( "tuple_iterator_getitem_manager", &GuardManager::get_child_manager, From 9ebf77b13b2c7b3b04b54028c9080ee149088605 Mon Sep 17 00:00:00 2001 From: Xu Han Date: Sat, 15 Jun 2024 03:02:00 +0000 Subject: [PATCH 050/171] Fix windows inductor defination issue (#128686) Changes: 1. Add memory align macro support on Windows. 2. Fix `#pragma unroll` not support on MSVC cl compiler. `#pragma unroll` occur error on msvc `cl` compiler, but it would be supported on Windows `clang`. We'd better disable it only on `__msvc_cl__` compiler, and get better performance if we enabled `clang`. Pull Request resolved: https://github.com/pytorch/pytorch/pull/128686 Approved by: https://github.com/jgong5, https://github.com/jansel --- .../src/ATen/cpu/vec/vec256/vec256_bfloat16.h | 8 ++++ aten/src/ATen/cpu/vec/vec256/vec256_double.h | 4 ++ aten/src/ATen/cpu/vec/vec256/vec256_float.h | 4 ++ .../ATen/cpu/vec/vec256/vec256_float_neon.h | 8 ++++ .../ATen/cpu/vec/vec256/vec256_half_neon.h | 8 ++++ .../src/ATen/cpu/vec/vec512/vec512_bfloat16.h | 40 +++++++++++++++++++ aten/src/ATen/cpu/vec/vec512/vec512_double.h | 4 ++ aten/src/ATen/cpu/vec/vec512/vec512_float.h | 4 ++ aten/src/ATen/cpu/vec/vec_base.h | 9 +++++ aten/src/ATen/cpu/vec/vec_mask.h | 2 + torch/_inductor/codecache.py | 11 +++++ 11 files changed, 102 insertions(+) diff --git a/aten/src/ATen/cpu/vec/vec256/vec256_bfloat16.h b/aten/src/ATen/cpu/vec/vec256/vec256_bfloat16.h index 19e0320d8abf6b..e567c1925be840 100644 --- a/aten/src/ATen/cpu/vec/vec256/vec256_bfloat16.h +++ b/aten/src/ATen/cpu/vec/vec256/vec256_bfloat16.h @@ -794,12 +794,16 @@ Vectorized inline clamp_min(const Vectorized& a, const Vecto template <> inline void convert(const BFloat16* src, BFloat16* dst, int64_t n) { int64_t i; +#ifndef __msvc_cl__ #pragma unroll +#endif for (i = 0; i <= (n - Vectorized::size()); i += Vectorized::size()) { auto vsrc = _mm256_loadu_si256(reinterpret_cast<__m256i*>((void*)(src + i))); _mm256_storeu_si256(reinterpret_cast<__m256i*>((void*)(dst + i)), vsrc); } +#ifndef __msvc_cl__ #pragma unroll +#endif for (; i < n; i++) { dst[i] = src[i]; } @@ -992,12 +996,16 @@ Vectorized inline clamp_min(const Vectorized& a, const Vectorized inline void convert(const Half* src, Half* dst, int64_t n) { int64_t i; +#ifndef __msvc_cl__ #pragma unroll +#endif for (i = 0; i <= (n - Vectorized::size()); i += Vectorized::size()) { auto vsrc = _mm256_loadu_si256(reinterpret_cast<__m256i*>((void*)(src + i))); _mm256_storeu_si256(reinterpret_cast<__m256i*>((void*)(dst + i)), vsrc); } +#ifndef __msvc_cl__ #pragma unroll +#endif for (; i < n; i++) { dst[i] = src[i]; } diff --git a/aten/src/ATen/cpu/vec/vec256/vec256_double.h b/aten/src/ATen/cpu/vec/vec256/vec256_double.h index bed6da627af2dc..168fe4ed7f9693 100644 --- a/aten/src/ATen/cpu/vec/vec256/vec256_double.h +++ b/aten/src/ATen/cpu/vec/vec256/vec256_double.h @@ -416,11 +416,15 @@ inline Vectorized Vectorized::le(const Vectorized& other template <> inline void convert(const double* src, double* dst, int64_t n) { int64_t i; +#ifndef __msvc_cl__ #pragma unroll +#endif for (i = 0; i <= (n - Vectorized::size()); i += Vectorized::size()) { _mm256_storeu_pd(dst + i, _mm256_loadu_pd(src + i)); } +#ifndef __msvc_cl__ #pragma unroll +#endif for (; i < n; i++) { dst[i] = src[i]; } diff --git a/aten/src/ATen/cpu/vec/vec256/vec256_float.h b/aten/src/ATen/cpu/vec/vec256/vec256_float.h index 0e3664cd37b6ab..0d0fe99252a7d0 100644 --- a/aten/src/ATen/cpu/vec/vec256/vec256_float.h +++ b/aten/src/ATen/cpu/vec/vec256/vec256_float.h @@ -512,11 +512,15 @@ inline Vectorized Vectorized::le(const Vectorized& other) c template <> inline void convert(const float* src, float* dst, int64_t n) { int64_t i; +#ifndef __msvc_cl__ #pragma unroll +#endif for (i = 0; i <= (n - Vectorized::size()); i += Vectorized::size()) { _mm256_storeu_ps(dst + i, _mm256_loadu_ps(src + i)); } +#ifndef __msvc_cl__ #pragma unroll +#endif for (; i < n; i++) { dst[i] = src[i]; } diff --git a/aten/src/ATen/cpu/vec/vec256/vec256_float_neon.h b/aten/src/ATen/cpu/vec/vec256/vec256_float_neon.h index a5b993f2b9e10e..d2324818cc8b93 100644 --- a/aten/src/ATen/cpu/vec/vec256/vec256_float_neon.h +++ b/aten/src/ATen/cpu/vec/vec256/vec256_float_neon.h @@ -823,12 +823,16 @@ inline Vectorized Vectorized::le(const Vectorized& other) c template <> inline void convert(const float* src, int32_t* dst, int64_t n) { int64_t i; +#ifndef __msvc_cl__ #pragma unroll +#endif for (i = 0; i <= (n - Vectorized::size()); i += Vectorized::size()) { vst1q_s32(dst + i, vcvtq_s32_f32(vld1q_f32(src + i))); vst1q_s32(dst + i + 4, vcvtq_s32_f32(vld1q_f32(src + i + 4))); } +#ifndef __msvc_cl__ #pragma unroll +#endif for (; i < n; i++) { dst[i] = static_cast(src[i]); } @@ -837,12 +841,16 @@ inline void convert(const float* src, int32_t* dst, int64_t n) { template <> inline void convert(const int32_t* src, float* dst, int64_t n) { int64_t i; +#ifndef __msvc_cl__ #pragma unroll +#endif for (i = 0; i <= (n - Vectorized::size()); i += Vectorized::size()) { vst1q_f32(dst + i, vcvtq_f32_s32(vld1q_s32(src + i))); vst1q_f32(dst + i + 4, vcvtq_f32_s32(vld1q_s32(src + i + 4))); } +#ifndef __msvc_cl__ #pragma unroll +#endif for (; i < n; i++) { dst[i] = static_cast(src[i]); } diff --git a/aten/src/ATen/cpu/vec/vec256/vec256_half_neon.h b/aten/src/ATen/cpu/vec/vec256/vec256_half_neon.h index b150106c435978..0b51972a029b44 100644 --- a/aten/src/ATen/cpu/vec/vec256/vec256_half_neon.h +++ b/aten/src/ATen/cpu/vec/vec256/vec256_half_neon.h @@ -765,13 +765,17 @@ inline Vectorized Vectorized::le( template <> inline void convert(const float16_t* src, int16_t* dst, int64_t n) { int64_t i; +#ifndef __msvc_cl__ #pragma unroll +#endif for (i = 0; i <= (n - Vectorized::size()); i += Vectorized::size()) { vst1q_s16(dst + i, vcvtq_s16_f16(vld1q_f16(src + i))); vst1q_s16(dst + i + 8, vcvtq_s16_f16(vld1q_f16(src + i + 8))); } +#ifndef __msvc_cl__ #pragma unroll +#endif for (; i < n; i++) { dst[i] = static_cast(src[i]); } @@ -780,13 +784,17 @@ inline void convert(const float16_t* src, int16_t* dst, int64_t n) { template <> inline void convert(const int16_t* src, float16_t* dst, int64_t n) { int64_t i; +#ifndef __msvc_cl__ #pragma unroll +#endif for (i = 0; i <= (n - Vectorized::size()); i += Vectorized::size()) { vst1q_f16(dst + i, vcvtq_f16_s16(vld1q_s16(src + i))); vst1q_f16(dst + i + 8, vcvtq_f16_s16(vld1q_s16(src + i + 8))); } +#ifndef __msvc_cl__ #pragma unroll +#endif for (; i < n; i++) { dst[i] = static_cast(src[i]); } diff --git a/aten/src/ATen/cpu/vec/vec512/vec512_bfloat16.h b/aten/src/ATen/cpu/vec/vec512/vec512_bfloat16.h index c7132349418de0..df91a82d20b340 100644 --- a/aten/src/ATen/cpu/vec/vec512/vec512_bfloat16.h +++ b/aten/src/ATen/cpu/vec/vec512/vec512_bfloat16.h @@ -914,12 +914,16 @@ Vectorized inline clamp_min(const Vectorized& a, const Vecto template <> inline void convert(const BFloat16* src, BFloat16* dst, int64_t n) { int64_t i; +#ifndef __msvc_cl__ #pragma unroll +#endif for (i = 0; i <= (n - Vectorized::size()); i += Vectorized::size()) { auto vsrc = _mm512_loadu_si512(reinterpret_cast<__m512i*>((void*)(src + i))); _mm512_storeu_si512(reinterpret_cast<__m512i*>((void*)(dst + i)), vsrc); } +#ifndef __msvc_cl__ #pragma unroll +#endif for (; i < n; i++) { dst[i] = src[i]; } @@ -986,7 +990,9 @@ static inline void _transpose_mxn_half_16_16(__m256i t[], __m512i u[]) { // j0-j15 n0-n15 // k0-k15 o0-o15 // l0-l15 p0-p15 +#ifndef __msvc_cl__ #pragma unroll(4) +#endif for (int i = 0; i < 4; i++) { r[i] = _mm512_inserti64x4(_mm512_castsi256_si512(t[i]), t[i + 4], 0x01); r[i + 4] = _mm512_inserti64x4(_mm512_castsi256_si512(t[i + 8]), t[i + 12], 0x01); @@ -998,7 +1004,9 @@ static inline void _transpose_mxn_half_16_16(__m256i t[], __m512i u[]) { // u3: c4c5 d4b5 c6c7 d6b7 c12c13 d12d13 c14c15 d14d15 g4g5 h4h5 g6g7 h6h7 g12g13 h12h13 g14g15 h14h15 // i j m n // k l o p +#ifndef __msvc_cl__ #pragma unroll(4) +#endif for (int i = 0; i < 8; i += 2) { u[i] = _mm512_unpacklo_epi32(r[i], r[i + 1]); u[i + 1] = _mm512_unpackhi_epi32(r[i], r[i + 1]); @@ -1061,7 +1069,9 @@ static inline void _transpose_mxn_half_16_16(__m256i t[], __m512i u[]) { // 12-- 13-- // 6-- 7-- // 14-- 15-- +#ifndef __msvc_cl__ #pragma unroll(4) +#endif for (int i = 0; i < 4; i++) { u[i] = _mm512_permutex2var_epi16(r[i], const1, r[i + 4]); u[i + 4] = _mm512_permutex2var_epi16(r[i], const2, r[i + 4]); @@ -1095,7 +1105,9 @@ inline void transpose_mxn( // n: n0 n1 n2 n3 n4 n5 n6 n7 n8 n9 n10 n11 n12 n13 n14 n15 // o: o0 o1 o2 o3 o4 o5 o6 o7 o8 o9 o10 o11 o12 o13 o14 o15 // p: p0 p1 p2 p3 p4 p5 p6 p7 p8 p9 p10 p11 p12 p13 p14 p15 +#ifndef __msvc_cl__ #pragma unroll(16) +#endif for (int i = 0; i < 16; i++) { t[i] = _mm256_loadu_si256(reinterpret_cast(src + i * ld_src)); } @@ -1103,7 +1115,9 @@ inline void transpose_mxn( __m512i u[8]; _transpose_mxn_half_16_16(t, u); +#ifndef __msvc_cl__ #pragma unroll(8) +#endif for (int i = 0; i < 8; i++) { _mm256_storeu_si256( reinterpret_cast<__m256i*>(dst + (i * 2) * ld_dst), @@ -1125,7 +1139,9 @@ inline void transpose_mxn( __m256i t[16]; // load from src to registers // Same matrix indices as above transpose_mxn +#ifndef __msvc_cl__ #pragma unroll(16) +#endif for (int i = 0; i < 16; i++) { t[i] = _mm256_loadu_si256(reinterpret_cast(src + i * ld_src)); } @@ -1133,7 +1149,9 @@ inline void transpose_mxn( __m512i u[8]; _transpose_mxn_half_16_16(t, u); +#ifndef __msvc_cl__ #pragma unroll(8) +#endif for (int i = 0; i < 8; i++) { _mm256_storeu_si256( reinterpret_cast<__m256i*>(dst + (i * 2) * ld_dst), @@ -1164,7 +1182,9 @@ static inline void _transpose_mxn_half_32_32(__m512i r[], __m512i d[]) { // t[16]: 512 544 513 545 514 546 515 547 520 552 521 553 522 554 523 555 528 ... 571 // ... // t[31]: 964 996 965 997 966 998 967 999 972 1004 973 1005 974 1006 975 1007 980 ... 1023 +#ifndef __msvc_cl__ #pragma unroll(16) +#endif for (int i = 0; i < 16; ++i) { d[i * 2] = _mm512_unpacklo_epi16(r[i * 2], r[i * 2 + 1]); d[i * 2 + 1] = _mm512_unpackhi_epi16(r[i * 2], r[i * 2 + 1]); @@ -1189,7 +1209,9 @@ static inline void _transpose_mxn_half_32_32(__m512i r[], __m512i d[]) { // t[16]: 512 544 576 608 513 545 577 609 520 552 584 616 521 553 585 617 528 ... 633 // ... // t[31]: 902 934 966 998 903 935 967 999 910 942 974 1006 911 943 975 1007 918 ... 1023 +#ifndef __msvc_cl__ #pragma unroll(8) +#endif for (int i = 0; i < 8; ++i) { r[i * 4] = _mm512_unpacklo_epi32(d[i * 4], d[i * 4 + 2]); r[i * 4 + 1] = _mm512_unpackhi_epi32(d[i * 4], d[i * 4 + 2]); @@ -1216,7 +1238,9 @@ static inline void _transpose_mxn_half_32_32(__m512i r[], __m512i d[]) { // t[16]: 512 544 576 608 640 672 704 736 520 552 584 616 648 680 712 744 528 ... 760 // ... // t[31]: 775 807 839 871 903 935 967 999 783 815 847 879 911 943 975 1007 791 ... 1023 +#ifndef __msvc_cl__ #pragma unroll(4) +#endif for (int i = 0; i < 4; ++i) { d[i * 8] = _mm512_unpacklo_epi64(r[i * 8], r[i * 8 + 4]); d[i * 8 + 1] = _mm512_unpackhi_epi64(r[i * 8], r[i * 8 + 4]); @@ -1265,7 +1289,9 @@ static inline void _transpose_mxn_half_32_32(__m512i r[], __m512i d[]) { 0x000000000000000a, 0x0000000000000003, 0x0000000000000002); +#ifndef __msvc_cl__ #pragma unroll(8) +#endif for (int i = 0; i < 8; ++i) { r[i] = _mm512_permutex2var_epi64(d[i], /*idx*/const1, d[i + 8]); r[i + 8] = _mm512_permutex2var_epi64(d[i], /*idx*/const2, d[i + 8]); @@ -1310,7 +1336,9 @@ static inline void _transpose_mxn_half_32_32(__m512i r[], __m512i d[]) { 0x0000000000000006, 0x0000000000000005, 0x0000000000000004); +#ifndef __msvc_cl__ #pragma unroll(16) +#endif for (int i = 0; i < 16; ++i) { d[i] = _mm512_permutex2var_epi64(r[i], /*idx*/const3, r[i + 16]); d[i + 16] = _mm512_permutex2var_epi64(r[i], /*idx*/const4, r[i + 16]); @@ -1327,7 +1355,9 @@ inline void transpose_mxn( int64_t ld_dst) { // Load from memory __m512i r[32]; +#ifndef __msvc_cl__ #pragma unroll(32) +#endif for (int i = 0; i < 32; ++i) { r[i] = _mm512_loadu_si512(reinterpret_cast(src + i* ld_src)); } @@ -1336,7 +1366,9 @@ inline void transpose_mxn( _transpose_mxn_half_32_32(r, d); // Store to dst +#ifndef __msvc_cl__ #pragma unroll(32) +#endif for (int i = 0; i < 32; ++i) { _mm512_storeu_si512(dst + i* ld_dst, d[i]); } @@ -1350,7 +1382,9 @@ inline void transpose_mxn( int64_t ld_dst) { // Load from memory __m512i r[32]; +#ifndef __msvc_cl__ #pragma unroll(32) +#endif for (int i = 0; i < 32; ++i) { r[i] = _mm512_loadu_si512(reinterpret_cast(src + i* ld_src)); } @@ -1359,7 +1393,9 @@ inline void transpose_mxn( _transpose_mxn_half_32_32(r, d); // Store to dst +#ifndef __msvc_cl__ #pragma unroll(32) +#endif for (int i = 0; i < 32; ++i) { _mm512_storeu_si512(dst + i* ld_dst, d[i]); } @@ -1514,12 +1550,16 @@ Vectorized inline clamp_min(const Vectorized& a, const Vectorized inline void convert(const Half* src, Half* dst, int64_t n) { int64_t i; +#ifndef __msvc_cl__ #pragma unroll +#endif for (i = 0; i <= (n - Vectorized::size()); i += Vectorized::size()) { auto vsrc = _mm512_loadu_si512(reinterpret_cast<__m512i*>((void*)(src + i))); _mm512_storeu_si512(reinterpret_cast<__m512i*>((void*)(dst + i)), vsrc); } +#ifndef __msvc_cl__ #pragma unroll +#endif for (; i < n; i++) { dst[i] = src[i]; } diff --git a/aten/src/ATen/cpu/vec/vec512/vec512_double.h b/aten/src/ATen/cpu/vec/vec512/vec512_double.h index 508ab257e603bb..ae48dc8a3f30a6 100644 --- a/aten/src/ATen/cpu/vec/vec512/vec512_double.h +++ b/aten/src/ATen/cpu/vec/vec512/vec512_double.h @@ -443,11 +443,15 @@ inline Vectorized Vectorized::le(const Vectorized& other template <> inline void convert(const double* src, double* dst, int64_t n) { int64_t i; +#ifndef __msvc_cl__ #pragma unroll +#endif for (i = 0; i <= (n - Vectorized::size()); i += Vectorized::size()) { _mm512_storeu_pd(dst + i, _mm512_loadu_pd(src + i)); } +#ifndef __msvc_cl__ #pragma unroll +#endif for (; i < n; i++) { dst[i] = src[i]; } diff --git a/aten/src/ATen/cpu/vec/vec512/vec512_float.h b/aten/src/ATen/cpu/vec/vec512/vec512_float.h index a08df3c141a380..40e9610b67b303 100644 --- a/aten/src/ATen/cpu/vec/vec512/vec512_float.h +++ b/aten/src/ATen/cpu/vec/vec512/vec512_float.h @@ -552,11 +552,15 @@ inline Vectorized Vectorized::le(const Vectorized& other) c template <> inline void convert(const float* src, float* dst, int64_t n) { int64_t i; +#ifndef __msvc_cl__ #pragma unroll +#endif for (i = 0; i <= (n - Vectorized::size()); i += Vectorized::size()) { _mm512_storeu_ps(dst + i, _mm512_loadu_ps(src + i)); } +#ifndef __msvc_cl__ #pragma unroll +#endif for (; i < n; i++) { dst[i] = src[i]; } diff --git a/aten/src/ATen/cpu/vec/vec_base.h b/aten/src/ATen/cpu/vec/vec_base.h index d696c97b594978..9ff9019320e549 100644 --- a/aten/src/ATen/cpu/vec/vec_base.h +++ b/aten/src/ATen/cpu/vec/vec_base.h @@ -42,6 +42,15 @@ #define __FORCE_INLINE __forceinline #endif +#if defined(_MSC_FULL_VER) +/* +https://learn.microsoft.com/en-us/cpp/overview/compiler-versions?view=msvc-170 +Use _MSC_FULL_VER to identify current compiler is msvc, +Windows llvm will not have this defination. +*/ +#define __msvc_cl__ +#endif + // These macros helped us unify vec_base.h #ifdef CPU_CAPABILITY_AVX512 #if defined(__GNUC__) diff --git a/aten/src/ATen/cpu/vec/vec_mask.h b/aten/src/ATen/cpu/vec/vec_mask.h index 90f0f98962d900..6b773c40ca8c9c 100644 --- a/aten/src/ATen/cpu/vec/vec_mask.h +++ b/aten/src/ATen/cpu/vec/vec_mask.h @@ -127,7 +127,9 @@ class VecMask { static VecMask from(U* b) { using int_t = int_same_size_t; __at_align__ T mask[size()]; +#ifndef __msvc_cl__ #pragma unroll +#endif for (int i = 0; i < size(); i++) { *(int_t*)(mask + i) = b[i] ? ~(int_t)0 : (int_t)0; } diff --git a/torch/_inductor/codecache.py b/torch/_inductor/codecache.py index 745db1397eb6ff..3d265f181b159a 100644 --- a/torch/_inductor/codecache.py +++ b/torch/_inductor/codecache.py @@ -1302,7 +1302,18 @@ class VecISA: #include #endif +#ifdef __APPLE__ +// Fix Mac OS UT failed. __attribute__((aligned(64))) float in_out_ptr0[16] = {0.0}; +#else +#if defined(_WIN32) +#define __at_align__ __declspec(align(64)) +#else +#define __at_align__ __attribute__((aligned(64))) +#endif + +__at_align__ float in_out_ptr0[16] = {0.0}; +#endif extern "C" void __avx_chk_kernel() { auto tmp0 = at::vec::Vectorized(1); From 108adbc726553ddba005c42ac8d511212bec8c32 Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Fri, 14 Jun 2024 15:25:24 -0700 Subject: [PATCH 051/171] [dynamo][side effects] Raise assertion error if the object is already tracked for mutation (#128590) This issue was pointed out by @tombousso here - https://github.com/pytorch/pytorch/pull/128269#issuecomment-2163755792 Pull Request resolved: https://github.com/pytorch/pytorch/pull/128590 Approved by: https://github.com/mlazos ghstack dependencies: #128715, #128269 --- torch/_dynamo/side_effects.py | 8 ++++++++ torch/_dynamo/variables/builder.py | 5 ++++- 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/torch/_dynamo/side_effects.py b/torch/_dynamo/side_effects.py index 370f62929ecd86..28ce9811b4c384 100644 --- a/torch/_dynamo/side_effects.py +++ b/torch/_dynamo/side_effects.py @@ -219,6 +219,14 @@ def _track_obj( ): """Start tracking a new variable for mutation""" assert variable.source is not None + + if id(item) in self.id_to_variable: + raise AssertionError( + "Variable is already tracked for mutation. This could be " + "because you are not using VariableBuilder to construct " + "the variable tracker." + ) + variable.mutable_local = mutable_cls(variable.source) self.id_to_variable[id(item)] = variable self.keepalive.append(item) diff --git a/torch/_dynamo/variables/builder.py b/torch/_dynamo/variables/builder.py index 8fe9154c53a644..f36f53b6537aa0 100644 --- a/torch/_dynamo/variables/builder.py +++ b/torch/_dynamo/variables/builder.py @@ -315,7 +315,10 @@ def __call__(self, value): vt = self._wrap(value) vt.source = self.source - if self._can_lift_attrs_to_inputs(vt): + if ( + self._can_lift_attrs_to_inputs(vt) + and value not in self.tx.output.side_effects + ): vt = self.tx.output.side_effects.track_object_existing(value, vt) self.tx.output.variable_tracker_cache.add(value, self.source, vt) From 4ccbf711e2f36ed5c469ea0bb8957acf04d09209 Mon Sep 17 00:00:00 2001 From: Sahdev Zala Date: Sat, 15 Jun 2024 05:30:33 +0000 Subject: [PATCH 052/171] Learning Rate Scheduler docstring fix (#128679) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Fix docstrings in Learning Rate Scheduler. The fix can be verified by running pydocstyle path-to-file --count Related #112593 **BEFORE the PR:** pydocstyle torch/optim/lr_scheduler.py --count
 92
 **AFTER the PR:** pydocstyle torch/optim/lr_scheduler.py --count
 0 Pull Request resolved: https://github.com/pytorch/pytorch/pull/128679 Approved by: https://github.com/janeyx99 --- torch/optim/lr_scheduler.py | 190 +++++++++++++++++++++--------------- 1 file changed, 111 insertions(+), 79 deletions(-) diff --git a/torch/optim/lr_scheduler.py b/torch/optim/lr_scheduler.py index 4a5f162a0b2040..ef197b65abb461 100644 --- a/torch/optim/lr_scheduler.py +++ b/torch/optim/lr_scheduler.py @@ -1,4 +1,5 @@ # mypy: allow-untyped-defs +r"""Learning Rate Scheduler.""" import math import types import warnings @@ -55,7 +56,7 @@ def _check_verbose_deprecated_warning(verbose): - """Raises a warning when verbose is not the default value.""" + """Raise a warning when verbose is not the default value.""" if verbose != "deprecated": warnings.warn( "The verbose parameter is deprecated. Please use get_last_lr() " @@ -85,9 +86,13 @@ def _copy(_param): class LRScheduler: + r"""Adjusts the learning rate during optimization.""" + _get_lr_called_within_step: bool = False - def __init__(self, optimizer: Optimizer, last_epoch=-1, verbose="deprecated"): + def __init__( + self, optimizer: Optimizer, last_epoch=-1, verbose="deprecated" + ): # noqa: D107 # Attach optimizer if not isinstance(optimizer, Optimizer): raise TypeError(f"{type(optimizer).__name__} is not an Optimizer") @@ -139,12 +144,12 @@ def wrapper(*args, **kwargs): self._initial_step() def _initial_step(self): - """Initialize step counts and performs a step""" + """Initialize step counts and perform a step.""" self._step_count = 0 self.step() def state_dict(self): - """Returns the state of the scheduler as a :class:`dict`. + """Return the state of the scheduler as a :class:`dict`. It contains an entry for every variable in self.__dict__ which is not the optimizer. @@ -154,7 +159,7 @@ def state_dict(self): } def load_state_dict(self, state_dict: Dict[str, Any]): - """Loads the schedulers state. + """Load the scheduler's state. Args: state_dict (dict): scheduler state. Should be an object returned @@ -167,7 +172,7 @@ def get_last_lr(self) -> List[float]: return self._last_lr def get_lr(self) -> List[float]: - # Compute learning rate using chainable form of the scheduler + """Compute learning rate using chainable form of the scheduler.""" raise NotImplementedError def print_lr( @@ -199,6 +204,7 @@ def print_lr( ) def step(self, epoch: Optional[int] = None): + """Perform a step.""" # Raise a warning if old pattern is detected # https://github.com/pytorch/pytorch/issues/20124 if self._step_count == 1: @@ -278,7 +284,9 @@ def __exit__(self, type, value, traceback): class LambdaLR(LRScheduler): - """Sets the learning rate of each parameter group to the initial lr + """Sets the initial learning rate. + + The learning rate of each parameter group is set to the initial lr times a given function. When last_epoch=-1, sets initial lr as lr. Args: @@ -312,7 +320,7 @@ def __init__( lr_lambda: Union[Callable[[int], float], List[Callable[[int], float]]], last_epoch=-1, verbose="deprecated", - ): + ): # noqa: D107 self.optimizer = optimizer self.lr_lambdas: List[Callable[[int], float]] @@ -327,7 +335,7 @@ def __init__( super().__init__(optimizer, last_epoch, verbose) def state_dict(self): - """Returns the state of the scheduler as a :class:`dict`. + """Return the state of the scheduler as a :class:`dict`. It contains an entry for every variable in self.__dict__ which is not the optimizer. @@ -336,7 +344,6 @@ def state_dict(self): When saving or loading the scheduler, please make sure to also save or load the state of the optimizer. """ - state_dict = { key: value for key, value in self.__dict__.items() @@ -351,7 +358,7 @@ def state_dict(self): return state_dict def load_state_dict(self, state_dict): - """Loads the schedulers state. + """Load the scheduler's state. When saving or loading the scheduler, please make sure to also save or load the state of the optimizer. @@ -359,7 +366,6 @@ def load_state_dict(self, state_dict): state_dict (dict): scheduler state. Should be an object returned from a call to :meth:`state_dict`. """ - lr_lambdas = state_dict.pop("lr_lambdas") self.__dict__.update(state_dict) # Restore state_dict keys in order to prevent side effects @@ -371,6 +377,7 @@ def load_state_dict(self, state_dict): self.lr_lambdas[idx].__dict__.update(fn) def get_lr(self): + """Compute learning rate.""" _warn_get_lr_called_within_step(self) return [ @@ -380,8 +387,9 @@ def get_lr(self): class MultiplicativeLR(LRScheduler): - """Multiply the learning rate of each parameter group by the factor given - in the specified function. When last_epoch=-1, sets initial lr as lr. + """Multiply the learning rate of each parameter group by the factor given in the specified function. + + When last_epoch=-1, set initial lr as lr. Args: optimizer (Optimizer): Wrapped optimizer. @@ -412,7 +420,7 @@ def __init__( lr_lambda: Union[Callable[[int], float], List[Callable[[int], float]]], last_epoch=-1, verbose="deprecated", - ): + ): # noqa: D107 self.optimizer = optimizer self.lr_lambdas: List[Callable[[int], float]] @@ -427,7 +435,7 @@ def __init__( super().__init__(optimizer, last_epoch, verbose) def state_dict(self): - """Returns the state of the scheduler as a :class:`dict`. + """Return the state of the scheduler as a :class:`dict`. It contains an entry for every variable in self.__dict__ which is not the optimizer. @@ -448,7 +456,7 @@ def state_dict(self): return state_dict def load_state_dict(self, state_dict): - """Loads the schedulers state. + """Load the scheduler's state. Args: state_dict (dict): scheduler state. Should be an object returned @@ -465,6 +473,7 @@ def load_state_dict(self, state_dict): self.lr_lambdas[idx].__dict__.update(fn) def get_lr(self): + """Compute the learning rate of each parameter group.""" _warn_get_lr_called_within_step(self) if self.last_epoch > 0: @@ -477,10 +486,10 @@ def get_lr(self): class StepLR(LRScheduler): - """Decays the learning rate of each parameter group by gamma every - step_size epochs. Notice that such decay can happen simultaneously with - other changes to the learning rate from outside this scheduler. When - last_epoch=-1, sets initial lr as lr. + """Decays the learning rate of each parameter group by gamma every step_size epochs. + + Notice that such decay can happen simultaneously with other changes to the learning rate + from outside this scheduler. When last_epoch=-1, sets initial lr as lr. Args: optimizer (Optimizer): Wrapped optimizer. @@ -516,12 +525,13 @@ def __init__( gamma=0.1, last_epoch=-1, verbose="deprecated", - ): + ): # noqa: D107 self.step_size = step_size self.gamma = gamma super().__init__(optimizer, last_epoch, verbose) def get_lr(self): + """Compute the learning rate of each parameter group.""" _warn_get_lr_called_within_step(self) if (self.last_epoch == 0) or (self.last_epoch % self.step_size != 0): @@ -536,10 +546,10 @@ def _get_closed_form_lr(self): class MultiStepLR(LRScheduler): - """Decays the learning rate of each parameter group by gamma once the - number of epoch reaches one of the milestones. Notice that such decay can - happen simultaneously with other changes to the learning rate from outside - this scheduler. When last_epoch=-1, sets initial lr as lr. + """Decays the learning rate of each parameter group by gamma once the number of epoch reaches one of the milestones. + + Notice that such decay can happen simultaneously with other changes to the learning rate + from outside this scheduler. When last_epoch=-1, sets initial lr as lr. Args: optimizer (Optimizer): Wrapped optimizer. @@ -574,12 +584,13 @@ def __init__( gamma=0.1, last_epoch=-1, verbose="deprecated", - ): + ): # noqa: D107 self.milestones = Counter(milestones) self.gamma = gamma super().__init__(optimizer, last_epoch, verbose) def get_lr(self): + """Compute the learning rate of each parameter group.""" _warn_get_lr_called_within_step(self) if self.last_epoch not in self.milestones: @@ -598,8 +609,9 @@ def _get_closed_form_lr(self): class ConstantLR(LRScheduler): - """Multiply the learning rate of each parameter group by a small constant factor until the - number of epoch reaches a pre-defined milestone: total_iters. + """Multiply the learning rate of each parameter group by a small constant factor. + + The multiplication is done until the number of epoch reaches a pre-defined milestone: total_iters. Notice that such multiplication of the small constant factor can happen simultaneously with other changes to the learning rate from outside this scheduler. When last_epoch=-1, sets initial lr as lr. @@ -639,7 +651,7 @@ def __init__( total_iters=5, last_epoch=-1, verbose="deprecated", - ): + ): # noqa: D107 if factor > 1.0 or factor < 0: raise ValueError( "Constant multiplicative factor expected to be between 0 and 1." @@ -650,6 +662,7 @@ def __init__( super().__init__(optimizer, last_epoch, verbose) def get_lr(self): + """Compute the learning rate of each parameter group.""" _warn_get_lr_called_within_step(self) if self.last_epoch == 0: @@ -671,8 +684,9 @@ def _get_closed_form_lr(self): class LinearLR(LRScheduler): - """Decays the learning rate of each parameter group by linearly changing small - multiplicative factor until the number of epoch reaches a pre-defined milestone: total_iters. + """Decays the learning rate of each parameter group by linearly changing small multiplicative factor. + + The multiplication is done until the number of epoch reaches a pre-defined milestone: total_iters. Notice that such decay can happen simultaneously with other changes to the learning rate from outside this scheduler. When last_epoch=-1, sets initial lr as lr. @@ -716,7 +730,7 @@ def __init__( total_iters=5, last_epoch=-1, verbose="deprecated", - ): + ): # noqa: D107 if start_factor > 1.0 or start_factor <= 0: raise ValueError( "Starting multiplicative factor expected to be greater than 0 and less or equal to 1." @@ -733,6 +747,7 @@ def __init__( super().__init__(optimizer, last_epoch, verbose) def get_lr(self): + """Compute the learning rate.""" _warn_get_lr_called_within_step(self) if self.last_epoch == 0: @@ -771,6 +786,7 @@ def _get_closed_form_lr(self): class ExponentialLR(LRScheduler): """Decays the learning rate of each parameter group by gamma every epoch. + When last_epoch=-1, sets initial lr as lr. Args: @@ -787,11 +803,12 @@ class ExponentialLR(LRScheduler): def __init__( self, optimizer: Optimizer, gamma: float, last_epoch=-1, verbose="deprecated" - ): + ): # noqa: D107 self.gamma = gamma super().__init__(optimizer, last_epoch, verbose) def get_lr(self): + """Compute the learning rate of each parameter group.""" _warn_get_lr_called_within_step(self) if self.last_epoch == 0: @@ -803,9 +820,10 @@ def _get_closed_form_lr(self): class SequentialLR(LRScheduler): - """Receives the list of schedulers that is expected to be called sequentially during - optimization process and milestone points that provides exact intervals to reflect - which scheduler is supposed to be called at a given epoch. + """Contains a list of schedulers expected to be called sequentially during the optimization process. + + Specifically, the schedulers will be called according to the milestone points, which should provide exact + intervals by which each scheduler should be called at a given epoch. Args: optimizer (Optimizer): Wrapped optimizer. @@ -842,7 +860,7 @@ def __init__( milestones: List[int], last_epoch=-1, verbose="deprecated", - ): + ): # noqa: D107 if len(schedulers) < 1: raise ValueError( f"{self.__class__.__name__} expects at least one scheduler, but got no scheduler." @@ -892,6 +910,7 @@ def __init__( self._last_lr = schedulers[0].get_last_lr() def step(self): + """Perform a step.""" self.last_epoch += 1 idx = bisect_right(self._milestones, self.last_epoch) scheduler = self._schedulers[idx] @@ -903,7 +922,7 @@ def step(self): self._last_lr = scheduler.get_last_lr() def state_dict(self): - """Returns the state of the scheduler as a :class:`dict`. + """Return the state of the scheduler as a :class:`dict`. It contains an entry for every variable in self.__dict__ which is not the optimizer. @@ -922,7 +941,7 @@ def state_dict(self): return state_dict def load_state_dict(self, state_dict): - """Loads the schedulers state. + """Load the scheduler's state. Args: state_dict (dict): scheduler state. Should be an object returned @@ -939,8 +958,9 @@ def load_state_dict(self, state_dict): class PolynomialLR(LRScheduler): - """Decays the learning rate of each parameter group using a polynomial function - in the given total_iters. When last_epoch=-1, sets initial lr as lr. + """Decays the learning rate of each parameter group using a polynomial function in the given total_iters. + + When last_epoch=-1, sets initial lr as lr. Args: optimizer (Optimizer): Wrapped optimizer. @@ -975,12 +995,13 @@ def __init__( power=1.0, last_epoch=-1, verbose="deprecated", - ): + ): # noqa: D107 self.total_iters = total_iters self.power = power super().__init__(optimizer, last_epoch, verbose) def get_lr(self): + """Compute the learning rate.""" _warn_get_lr_called_within_step(self) if self.last_epoch == 0 or self.last_epoch > self.total_iters: @@ -1004,8 +1025,9 @@ def _get_closed_form_lr(self): class CosineAnnealingLR(LRScheduler): - r"""Set the learning rate of each parameter group using a cosine annealing - schedule, where :math:`\eta_{max}` is set to the initial lr and + r"""Set the learning rate of each parameter group using a cosine annealing schedule. + + The :math:`\eta_{max}` is set to the initial lr and :math:`T_{cur}` is the number of epochs since the last restart in SGDR: .. math:: @@ -1054,12 +1076,13 @@ def __init__( eta_min=0, last_epoch=-1, verbose="deprecated", - ): + ): # noqa: D107 self.T_max = T_max self.eta_min = eta_min super().__init__(optimizer, last_epoch, verbose) def get_lr(self): + """Retrieve the learning rate of each parameter group.""" _warn_get_lr_called_within_step(self) if self.last_epoch == 0: @@ -1097,9 +1120,10 @@ def _get_closed_form_lr(self): class ChainedScheduler(LRScheduler): - """Chains list of learning rate schedulers. It takes a sequence of chainable learning - rate schedulers and performs consecutive step() functions belonging to them by just - one call. + """Chains a list of learning rate schedulers. + + Takes in a sequence of chainable learning rate schedulers and calls their + step() functions consecutively in just one call to step(). Args: schedulers (sequence): sequence of chained schedulers. @@ -1124,7 +1148,7 @@ class ChainedScheduler(LRScheduler): def __init__( self, schedulers: Sequence[LRScheduler], optimizer: Optional[Optimizer] = None - ): + ): # noqa: D107 if len(schedulers) < 1: raise ValueError( f"{self.__class__.__name__} expects at least one scheduler to be chained, but got no scheduler." @@ -1155,6 +1179,7 @@ def __init__( ] def step(self): + """Perform a step.""" for scheduler in self._schedulers: scheduler.step() self._last_lr = [ @@ -1162,7 +1187,7 @@ def step(self): ] def state_dict(self): - """Returns the state of the scheduler as a :class:`dict`. + """Return the state of the scheduler as a :class:`dict`. It contains an entry for every variable in self.__dict__ which is not the optimizer. @@ -1181,7 +1206,7 @@ def state_dict(self): return state_dict def load_state_dict(self, state_dict): - """Loads the schedulers state. + """Load the scheduler's state. Args: state_dict (dict): scheduler state. Should be an object returned @@ -1199,6 +1224,7 @@ def load_state_dict(self, state_dict): class ReduceLROnPlateau(LRScheduler): """Reduce learning rate when a metric has stopped improving. + Models often benefit from reducing the learning rate by a factor of 2-10 once learning stagnates. This scheduler reads a metrics quantity and if no improvement is seen for a 'patience' number @@ -1269,7 +1295,7 @@ def __init__( min_lr: Union[List[float], float] = 0, eps=1e-8, verbose="deprecated", - ): + ): # noqa: D107 if factor >= 1.0: raise ValueError("Factor should be < 1.0.") self.factor = factor @@ -1308,12 +1334,13 @@ def __init__( self._reset() def _reset(self): - """Resets num_bad_epochs counter and cooldown counter.""" + """Reset num_bad_epochs counter and cooldown counter.""" self.best = self.mode_worse self.cooldown_counter = 0 self.num_bad_epochs = 0 def step(self, metrics: SupportsFloat, epoch=None): # type: ignore[override] + """Perform a step.""" # convert `metrics` to float, in case it's a zero-dim Tensor current = float(metrics) if epoch is None: @@ -1347,10 +1374,10 @@ def _reduce_lr(self, epoch): param_group["lr"] = new_lr @property - def in_cooldown(self): + def in_cooldown(self): # noqa: D102 return self.cooldown_counter > 0 - def is_better(self, a, best): + def is_better(self, a, best): # noqa: D102 if self.mode == "min" and self.threshold_mode == "rel": rel_epsilon = 1.0 - self.threshold return a < best * rel_epsilon @@ -1380,12 +1407,13 @@ def _init_is_better(self, mode, threshold, threshold_mode): self.threshold = threshold self.threshold_mode = threshold_mode - def state_dict(self): + def state_dict(self): # noqa: D102 return { key: value for key, value in self.__dict__.items() if key != "optimizer" } def load_state_dict(self, state_dict): + """Load the scheduler's state.""" self.__dict__.update(state_dict) self._init_is_better( mode=self.mode, threshold=self.threshold, threshold_mode=self.threshold_mode @@ -1393,10 +1421,10 @@ def load_state_dict(self, state_dict): class CyclicLR(LRScheduler): - r"""Sets the learning rate of each parameter group according to - cyclical learning rate policy (CLR). The policy cycles the learning - rate between two boundaries with a constant frequency, as detailed in - the paper `Cyclical Learning Rates for Training Neural Networks`_. + r"""Sets the learning rate of each parameter group according to cyclical learning rate policy (CLR). + + The policy cycles the learning rate between two boundaries with a constant frequency, + as detailed in the paper `Cyclical Learning Rates for Training Neural Networks`_. The distance between the two boundaries can be scaled on a per-iteration or per-cycle basis. @@ -1507,7 +1535,7 @@ def __init__( max_momentum=0.9, last_epoch=-1, verbose="deprecated", - ): + ): # noqa: D107 # Attach optimizer if not isinstance(optimizer, Optimizer): raise TypeError(f"{type(optimizer).__name__} is not an Optimizer") @@ -1585,6 +1613,7 @@ def _init_scale_fn(self): self.scale_mode = "iterations" def scale_fn(self, x) -> float: + """Get the scaling policy.""" if self._scale_fn_custom is not None: return self._scale_fn_custom(x) else: @@ -1603,13 +1632,13 @@ def _exp_range_scale_fn(gamma: float, x: float) -> float: return gamma**x def get_lr(self): - """Calculates the learning rate at batch index. This function treats - `self.last_epoch` as the last batch index. + """Calculate the learning rate at batch index. + + This function treats `self.last_epoch` as the last batch index. If `self.cycle_momentum` is ``True``, this function has a side effect of updating the optimizer's momentum. """ - _warn_get_lr_called_within_step(self) cycle = math.floor(1 + self.last_epoch / self.total_size) @@ -1649,7 +1678,7 @@ def get_lr(self): return lrs - def state_dict(self): + def state_dict(self): # noqa: D102 state = super().state_dict() # We are dropping the `_scale_fn_ref` attribute because it is a # `weakref.WeakMethod` and can't be pickled. @@ -1664,6 +1693,7 @@ def state_dict(self): return state def load_state_dict(self, state_dict): + """Load the scheduler's state.""" fn = state_dict.pop("_scale_fn_custom") super().load_state_dict(state_dict) if fn is not None: @@ -1672,8 +1702,9 @@ def load_state_dict(self, state_dict): class CosineAnnealingWarmRestarts(LRScheduler): - r"""Set the learning rate of each parameter group using a cosine annealing - schedule, where :math:`\eta_{max}` is set to the initial lr, :math:`T_{cur}` + r"""Set the learning rate of each parameter group using a cosine annealing schedule. + + The :math:`\eta_{max}` is set to the initial lr, :math:`T_{cur}` is the number of epochs since the last restart and :math:`T_{i}` is the number of epochs between two warm restarts in SGDR: @@ -1712,7 +1743,7 @@ def __init__( eta_min=0, last_epoch=-1, verbose="deprecated", - ): + ): # noqa: D107 if T_0 <= 0 or not isinstance(T_0, int): raise ValueError(f"Expected positive integer T_0, but got {T_0}") if T_mult < 1 or not isinstance(T_mult, int): @@ -1729,6 +1760,7 @@ def __init__( super().__init__(optimizer, last_epoch, verbose) def get_lr(self): + """Compute the initial learning rate.""" _warn_get_lr_called_within_step(self) return [ @@ -1740,7 +1772,7 @@ def get_lr(self): ] def step(self, epoch=None): - """Step could be called after every batch update + """Step could be called after every batch update. Example: >>> # xdoctest: +SKIP("Undefined vars") @@ -1766,7 +1798,6 @@ def step(self, epoch=None): >>> scheduler.step(26) >>> scheduler.step() # scheduler.step(27), instead of scheduler(20) """ - if epoch is None and self.last_epoch < 0: epoch = 0 @@ -1814,11 +1845,11 @@ class _SchedulePhase(TypedDict): class OneCycleLR(LRScheduler): - r"""Sets the learning rate of each parameter group according to the - 1cycle learning rate policy. The 1cycle policy anneals the learning - rate from an initial learning rate to some maximum learning rate and then - from that maximum learning rate to some minimum learning rate much lower - than the initial learning rate. + r"""Sets the learning rate of each parameter group according to the 1cycle learning rate policy. + + The 1cycle policy anneals the learning rate from an initial learning rate to some maximum + learning rate and then from that maximum learning rate to some minimum learning rate much + lower than the initial learning rate. This policy was initially described in the paper `Super-Convergence: Very Fast Training of Neural Networks Using Large Learning Rates`_. @@ -1937,7 +1968,7 @@ def __init__( three_phase=False, last_epoch=-1, verbose="deprecated", - ): + ): # noqa: D107 # Validate optimizer if not isinstance(optimizer, Optimizer): raise TypeError(f"{type(optimizer).__name__} is not an Optimizer") @@ -2068,16 +2099,17 @@ def _anneal_func(self, *args, **kwargs): @staticmethod def _annealing_cos(start, end, pct): - "Cosine anneal from `start` to `end` as pct goes from 0.0 to 1.0." + """Cosine anneal from `start` to `end` as pct goes from 0.0 to 1.0.""" cos_out = math.cos(math.pi * pct) + 1 return end + (start - end) / 2.0 * cos_out @staticmethod def _annealing_linear(start, end, pct): - "Linearly anneal from `start` to `end` as pct goes from 0.0 to 1.0." + """Linearly anneal from `start` to `end` as pct goes from 0.0 to 1.0.""" return (end - start) * pct + start def get_lr(self): + """Compute the learning rate of each parameter group.""" _warn_get_lr_called_within_step(self) lrs = [] From 472211c97ad418677fd6c8ddbeb4e260128bdf47 Mon Sep 17 00:00:00 2001 From: Oguz Ulgen Date: Fri, 14 Jun 2024 20:01:15 -0700 Subject: [PATCH 053/171] Make assert_size_stride to return all errors (#128764) This will help debug some problems I'm encountering, but in general, it is best to show the entire error. Pull Request resolved: https://github.com/pytorch/pytorch/pull/128764 Approved by: https://github.com/jansel --- test/dynamo/test_misc.py | 8 ++++++++ torch/csrc/dynamo/guards.cpp | 14 +++++++++++--- 2 files changed, 19 insertions(+), 3 deletions(-) diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py index 221d826aa11214..c47552fc1b2a76 100644 --- a/test/dynamo/test_misc.py +++ b/test/dynamo/test_misc.py @@ -10463,6 +10463,14 @@ def fn(x): res = opt_fn(x) self.assertEqual(ref, res) + def test_assert_size_stride(self): + x = torch.randn(2, 3, 4) + with self.assertRaisesRegex( + AssertionError, + "expected size 2==5, stride 12==9 at dim=0; expected size 3==6, stride 4==9 at dim=1; expected size 4==7, stride 1==10 at dim=2", + ): + torch._C._dynamo.guards.assert_size_stride(x, (5, 6, 7), (9, 9, 10)) + def test_module_dunder_dict(self): class MyModule(torch.nn.Module): def __init__(self): diff --git a/torch/csrc/dynamo/guards.cpp b/torch/csrc/dynamo/guards.cpp index a9359f09b27c27..d7569f580fa651 100644 --- a/torch/csrc/dynamo/guards.cpp +++ b/torch/csrc/dynamo/guards.cpp @@ -671,6 +671,8 @@ static PyObject* assert_size_stride(PyObject* dummy, PyObject* args) { PyErr_SetString(PyExc_AssertionError, "wrong number of dimensions"); return nullptr; } + std::stringstream msg; + int num_errors = 0; for (auto i : c10::irange(ndim)) { int64_t want_size = THPUtils_unpackLong(PyTuple_GET_ITEM(size, i)); int64_t want_stride = THPUtils_unpackLong(PyTuple_GET_ITEM(stride, i)); @@ -679,13 +681,19 @@ static PyObject* assert_size_stride(PyObject* dummy, PyObject* args) { if (want_size != actual_size || // ignore stride differences when size is 1 (want_stride != actual_stride && actual_size > 1)) { - std::stringstream msg; + if (num_errors > 0) + msg << "; "; msg << "expected size " << actual_size << "==" << want_size << ", stride " << actual_stride << "==" << want_stride << " at dim=" << i; - PyErr_SetString(PyExc_AssertionError, msg.str().c_str()); - return nullptr; + num_errors++; } } + + if (num_errors) { + PyErr_SetString(PyExc_AssertionError, msg.str().c_str()); + return nullptr; + } + Py_RETURN_TRUE; } From e4c32d14a8d63ba6cdd34431dc66a166c4802298 Mon Sep 17 00:00:00 2001 From: cyy Date: Sat, 15 Jun 2024 06:38:40 +0000 Subject: [PATCH 054/171] [3/N] Remove inclusion of c10/util/string_utils.h (#128504) Follows #128372 Pull Request resolved: https://github.com/pytorch/pytorch/pull/128504 Approved by: https://github.com/malfet --- caffe2/core/common.h | 2 -- caffe2/serialize/inline_container.cc | 6 +++--- test/cpp/jit/test_custom_class_registrations.cpp | 4 ++-- 3 files changed, 5 insertions(+), 7 deletions(-) diff --git a/caffe2/core/common.h b/caffe2/core/common.h index 40c25b7c755201..5058fdff5860a5 100644 --- a/caffe2/core/common.h +++ b/caffe2/core/common.h @@ -28,8 +28,6 @@ #include -#include "c10/util/string_utils.h" - namespace caffe2 { // Using statements for common classes that we refer to in caffe2 very often. diff --git a/caffe2/serialize/inline_container.cc b/caffe2/serialize/inline_container.cc index 83415da0a4f773..2761147cf333da 100644 --- a/caffe2/serialize/inline_container.cc +++ b/caffe2/serialize/inline_container.cc @@ -213,9 +213,9 @@ void PyTorchStreamReader::init() { if (version_ < static_cast(kMinSupportedFileFormatVersion)) { CAFFE_THROW( "Attempted to read a PyTorch file with version ", - c10::to_string(version_), + std::to_string(version_), ", but the minimum supported version for reading is ", - c10::to_string(kMinSupportedFileFormatVersion), + std::to_string(kMinSupportedFileFormatVersion), ". Your PyTorch script module file is too old. Please regenerate it", " with latest version of PyTorch to mitigate this issue."); } @@ -733,7 +733,7 @@ void PyTorchStreamWriter::writeEndOfFile() { auto allRecords = getAllWrittenRecords(); // If no ".data/version" or "version" record in the output model, rewrites version info if(allRecords.find(".data/version") == allRecords.end() && allRecords.find("version") == allRecords.end()) { - std::string version = c10::to_string(version_); + std::string version = std::to_string(version_); version.push_back('\n'); if (version_ >= 0x6L) { writeRecord(".data/version", version.c_str(), version.size()); diff --git a/test/cpp/jit/test_custom_class_registrations.cpp b/test/cpp/jit/test_custom_class_registrations.cpp index 819d5495b06c38..3c1ddf177e9110 100644 --- a/test/cpp/jit/test_custom_class_registrations.cpp +++ b/test/cpp/jit/test_custom_class_registrations.cpp @@ -140,7 +140,7 @@ struct TensorQueue : torch::CustomClassHolder { for (const auto index : c10::irange(queue_size)) { at::Tensor val; - queue_[index] = dict.at(key + "/" + c10::to_string(index)); + queue_[index] = dict.at(key + "/" + std::to_string(index)); queue_.push_back(val); } } @@ -152,7 +152,7 @@ struct TensorQueue : torch::CustomClassHolder { dict.insert( key + "/size", torch::tensor(static_cast(queue_.size()))); for (const auto index : c10::irange(queue_.size())) { - dict.insert(key + "/" + c10::to_string(index), queue_[index]); + dict.insert(key + "/" + std::to_string(index), queue_[index]); } return dict; } From b50c0e94c25935c3976dbef8eb07fe3d6b2fea8d Mon Sep 17 00:00:00 2001 From: Tristan Rice Date: Sat, 15 Jun 2024 07:40:18 +0000 Subject: [PATCH 055/171] TCPStoreLibUvBackend: use somaxconn and enable TCP_NODELAY (#128739) This adjusts the settings of the libuv backend to match the older TCPStore. * DEFAULT_BACKLOG: setting this to -1 will enable using the host somaxconn value instead of a hardcoded 16k value. When going over this limit with `tcp_abort_on_overflow` set it results in connections being reset. * TCP_NODELAY: Since TCPStore primarily sends small messages there's no benefit to using Nargle's algorithm and it may add additional latency for store operations. Test plan: ``` python test/distributed/test_store.py -v -k LibUv ``` Benchmark script: ``` import time import os import torch.distributed as dist rank = int(os.environ["RANK"]) store = dist.TCPStore( host_name="", port=29500, world_size=2, is_master=(rank == 0), use_libuv=True, ) if rank == 1: total_iters = 0 total_dur = 0 for iter in range(10): iters = 500000 start = time.perf_counter() for i in range(iters): store.set(f"key_{i}", f"value_{i}") dur = time.perf_counter() - start print(f"{iter}. {iters} set, qps = {iters/dur}") total_iters += iters total_dur += dur print(f"overall qps = {total_iters/total_dur}") else: print("sleeping") time.sleep(1000000000) ``` Performance seems to be negligible difference between TCP_NODELAY and not for a single host Pull Request resolved: https://github.com/pytorch/pytorch/pull/128739 Approved by: https://github.com/rsdcastro, https://github.com/kurman, https://github.com/c-p-i-o --- torch/csrc/distributed/c10d/TCPStoreLibUvBackend.cpp | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/torch/csrc/distributed/c10d/TCPStoreLibUvBackend.cpp b/torch/csrc/distributed/c10d/TCPStoreLibUvBackend.cpp index 5890ff5d95ae5e..cad6f66cebc86e 100644 --- a/torch/csrc/distributed/c10d/TCPStoreLibUvBackend.cpp +++ b/torch/csrc/distributed/c10d/TCPStoreLibUvBackend.cpp @@ -31,7 +31,10 @@ Other callbacks don't provide exception safety so avoid there. */ -#define DEFAULT_BACKLOG 16384 +// This controls how many un-accepted TCP connections can be waiting in the +// backlog. This should be at least world size to avoid issues on init. We set +// it to -1 to use the host max value which is controlled by `soconnmax`. +#define DEFAULT_BACKLOG -1 #define MAX_KEY_COUNT (128 * 1024) #define MAX_STRING_LEN (8 * 1024) #define MAX_PAYLOAD_LEN (8 * 1024 * 1024) @@ -134,6 +137,11 @@ class UvTcpSocket : public UvHandle { public: explicit UvTcpSocket(uv_loop_t* loop) { uv_tcp_init(loop, &client); + if (int err = uv_tcp_nodelay(&client, 1)) { + C10D_WARNING( + "The no-delay option cannot be enabled for the client socket. err={}", + err); + } } void startRead() { From de4f379cf29bc3189c5840ee999340d801f1e949 Mon Sep 17 00:00:00 2001 From: Michael Lazos Date: Sat, 15 Jun 2024 09:04:06 +0000 Subject: [PATCH 056/171] run mkldnn test with inlining (#128749) Fixes #ISSUE_NUMBER Pull Request resolved: https://github.com/pytorch/pytorch/pull/128749 Approved by: https://github.com/anijain2305 --- test/inductor/test_mkldnn_pattern_matcher.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/test/inductor/test_mkldnn_pattern_matcher.py b/test/inductor/test_mkldnn_pattern_matcher.py index 88d0012913ac62..a80d7239876028 100644 --- a/test/inductor/test_mkldnn_pattern_matcher.py +++ b/test/inductor/test_mkldnn_pattern_matcher.py @@ -2526,6 +2526,7 @@ def forward(self, x): om(*example_inputs) om(*example_inputs) + @torch._dynamo.config.patch("inline_inbuilt_nn_modules", True) def test_reproduce_121253_issue(self): class Mod(torch.nn.Module): def __init__(self, weight, bias, beta, alpha): @@ -2550,8 +2551,8 @@ def forward(self, x): else "mkldnn._linear_pointwise" ) for beta, alpha in zip([1.0, 0.1, 0.0], [1.0, 0.1, 1.0]): - weight = torch.randn(64, 64, dtype=dtype) - bias = torch.randn(64, dtype=dtype) + weight = torch.nn.Parameter(torch.randn(64, 64, dtype=dtype)) + bias = torch.nn.Parameter(torch.randn(64, dtype=dtype)) mod = Mod(weight, bias, beta, alpha).to(dtype).eval() with torch.no_grad(): x = torch.randn(1, 64, dtype=dtype) From 60bbdc0b40656cf70b2b098c7d715e19f031fb0d Mon Sep 17 00:00:00 2001 From: "Wang, Eikan" Date: Fri, 14 Jun 2024 08:07:19 +0000 Subject: [PATCH 057/171] Modularize aten parameter parser and checker (#125308) In this PR, we abstracted the different types of aten operation parameters as `ParameterMetadata`. This structure intends to be used to represent and store the metadata of each aten operation parameter. Currently, it only supports `Tensor`, `TensorList`, and `Scalar`. ```C++ using ParameterMetadataValue = std::variant, c10::Scalar>; ``` With this PR, we can extend other parameter-type support in a more modularize way, like `string`, `int`, `double`, and other different types to be summarized as the following list. The list is collected from all aten operations and ordered by the number of being used. - `Tensor` - `bool` - `int64_t` - `TensorList` - `Scalar` - `c10::SymIntArrayRef` - `::std::optional` - `IntArrayRef` - `double` - `c10::SymInt` - `::std::optional` - `::std::optional` - `::std::optional` - `::std::optional` - `::std::optional` - `::std::optional` - `Dimname` - `::std::optional` - `c10::string_view` - `::std::optional` - `OptionalIntArrayRef` - `::std::optional` - `OptionalSymIntArrayRef` - `::std::optional` - `::std::optional` - `ScalarType` - `ArrayRef` - `DimnameList` - `::std::optional>` - `::std::array` - `::std::optional` - `c10::List<::std::optional>` - `::std::array` - `Storage` - `::std::array` - `Device` - `DeviceIndex` - `ITensorListRef` - `Stream` - `Layout` - `MemoryFormat` Pull Request resolved: https://github.com/pytorch/pytorch/pull/125308 Approved by: https://github.com/jgong5, https://github.com/jansel --- test/inductor/test_torchinductor.py | 20 +- torch/_C/__init__.pyi.in | 1 + torch/_inductor/utils.py | 82 +++-- torch/csrc/dynamo/guards.cpp | 51 ++- torch/csrc/dynamo/guards.h | 12 +- .../inductor/aoti_eager/kernel_holder.cpp | 330 ++++++++---------- .../csrc/inductor/aoti_eager/kernel_holder.h | 44 ++- .../inductor/aoti_eager/kernel_meta_info.cpp | 182 +++++++--- .../inductor/aoti_eager/kernel_meta_info.h | 91 ++++- torch/csrc/utils/python_dispatch.cpp | 1 + 10 files changed, 502 insertions(+), 312 deletions(-) diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index d3eac170c32934..7c66547400d68c 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -769,7 +769,7 @@ def fn(a, b): ) @skipCUDAIf(not SM80OrLater, "Requires sm80") - def test_eager_aoti_support_out(self): + def test_aoti_eager_support_out(self): ns = "aten" op_name = "clamp" dispatch_key = "CPU" @@ -821,7 +821,7 @@ def test_eager_aoti_support_out(self): self.assertEqual(ref_out_tensor1, res_out_tensor1) @skipCUDAIf(not SM80OrLater, "Requires sm80") - def test_eager_aoti_cache_hit(self): + def test_aoti_eager_cache_hit(self): ns = "aten" op_name = "abs" dispatch_key = "CPU" @@ -862,7 +862,7 @@ def test_eager_aoti_cache_hit(self): self.assertEqual(ref_value, res_value) @skipCUDAIf(not SM80OrLater, "Requires sm80") - def test_eager_aoti_with_persistent_cache(self): + def test_aoti_eager_with_persistent_cache(self): def fn(a): return torch.abs(a) @@ -906,7 +906,7 @@ def fn(a): self.assertTrue(kernel_lib_path in kernel_libs_abs_path) @skipCUDAIf(not SM80OrLater, "Requires sm80") - def test_eager_aoti_with_scalar(self): + def test_aoti_eager_with_scalar(self): namespace_name = "aten" op_name = "add" op_overload_name = "Tensor" @@ -942,18 +942,18 @@ def test_eager_aoti_with_scalar(self): self.assertTrue(isinstance(op_info, dict)) self.assertTrue("meta_info" in op_info) self.assertTrue(len(op_info["meta_info"]) == 3) + # Scalar Tensor + self.assertTrue("scalar_value" not in op_info["meta_info"][0]) self.assertTrue(op_info["meta_info"][0]["sizes"] == []) self.assertTrue(op_info["meta_info"][0]["strides"] == []) # Scalar Tensor - self.assertTrue("scalar_value" not in op_info["meta_info"][0]) + self.assertTrue("scalar_value" not in op_info["meta_info"][1]) self.assertTrue(op_info["meta_info"][1]["sizes"] == []) self.assertTrue(op_info["meta_info"][1]["strides"] == []) - # Scalar Tensor - self.assertTrue("scalar_value" not in op_info["meta_info"][1]) - self.assertTrue(op_info["meta_info"][2]["sizes"] == []) - self.assertTrue(op_info["meta_info"][2]["strides"] == []) # Scalar self.assertTrue("scalar_value" in op_info["meta_info"][2]) + self.assertTrue("sizes" not in op_info["meta_info"][2]) + self.assertTrue("strides" not in op_info["meta_info"][2]) with _scoped_library("aten", "IMPL") as torch_compile_op_lib_impl: a = torch.randn(128, device=device) @@ -976,7 +976,7 @@ def test_eager_aoti_with_scalar(self): self.assertEqual(ref_values, res_values) @skipCUDAIf(not SM80OrLater, "Requires sm80") - def test_eager_aoti_override_registration(self): + def test_aoti_eager_override_registration(self): namespace_name = "aten" dispatch_key = "CPU" device = torch.device("cpu") diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in index bd1cfddd519e36..f37c3a326fda5e 100644 --- a/torch/_C/__init__.pyi.in +++ b/torch/_C/__init__.pyi.in @@ -1568,6 +1568,7 @@ class DispatchKeySet: def __or__(self, other: DispatchKeySet) -> DispatchKeySet: ... def __sub__(self, other: DispatchKeySet) -> DispatchKeySet: ... def __and__(self, other: DispatchKeySet) -> DispatchKeySet: ... + def raw_repr(self) -> _int: ... def highestPriorityTypeId(self) -> DispatchKey: ... def has(self, k: _dispatchkey) -> _bool: ... def add(self, k: _dispatchkey) -> DispatchKeySet: ... diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py index 51216e09522572..129ea8c6a45f5b 100644 --- a/torch/_inductor/utils.py +++ b/torch/_inductor/utils.py @@ -46,7 +46,6 @@ import sympy import torch -import torch.utils._pytree as pytree from torch._dynamo.device_interface import get_interface_for_device from torch._dynamo.utils import detect_fake_mode from torch.autograd import DeviceType @@ -1683,12 +1682,22 @@ def aoti_compile_with_persistent_cache( type_to_torch_dtype = {int: torch.int32, float: torch.float, bool: torch.bool} supported_scalar_types = tuple(type_to_torch_dtype.keys()) - flattened_inputs = pytree.arg_tree_leaves(*args, **kwargs) + flattened_inputs = list(args) + list(kwargs.values()) if not all( - isinstance(input, (supported_scalar_types, torch.Tensor)) + isinstance(input, (supported_scalar_types, torch.Tensor, list)) for input in flattened_inputs ): - raise NotImplementedError("Only support tensor, int, float, bool for now") + raise NotImplementedError( + "Only support tensor, tensor list, int, float, bool for now" + ) + + for input in flattened_inputs: + if isinstance(input, list) and not all( + isinstance(item, torch.Tensor) for item in input + ): + raise NotImplementedError( + "Regarding list, _impl_with_aoti_compile only support tensor list now." + ) persistent_cache = aoti_eager_cache_dir(ns, device_type) if not persistent_cache.exists(): @@ -1718,30 +1727,59 @@ def aoti_compile_with_persistent_cache( ) kernel_metadata_items = [] - for input in flattened_inputs: - # TODO(Eikan): To add dynamic support + + def extract_tensor_metadata(input: torch.Tensor) -> Dict[str, Any]: + metadata: Dict[str, Any] = {} + metadata["is_dynamic"] = dynamic + + assert isinstance(input, torch.Tensor) + metadata["device_type"] = f"{input.device.type}" + if is_cpu_device([input]): + metadata["device_index"] = -1 + else: + metadata["device_index"] = input.device.index + metadata["dtype"] = f"{input.dtype}" + metadata["sizes"] = list(input.size()) + metadata["strides"] = list(input.stride()) + metadata["requires_grad"] = input.requires_grad + metadata["dispatch_key_set"] = torch._C._dispatch_keys(input).raw_repr() + return metadata + + def extract_scalar_metadata( + input: Union[int, float, bool] + ) -> Dict[str, Any]: + assert isinstance(input, supported_scalar_types) metadata: Dict[str, Any] = {} metadata["is_dynamic"] = dynamic + # Scalar tensor + metadata["device_type"] = device_type + metadata["device_index"] = -1 if device_type == "cpu" else 0 + metadata["dtype"] = f"{type_to_torch_dtype[type(input)]}" + metadata["scalar_value"] = input + return metadata + + def extract_tensor_list_metadata( + input: List[torch.Tensor], + ) -> Dict[str, Any]: + metadata_list = [] + for item in input: + assert isinstance(item, torch.Tensor) + metadata_list.append(extract_tensor_metadata(item)) + metadata: Dict[str, Any] = {} + metadata["tensor_list"] = metadata_list + return metadata + + for idx, input in enumerate(flattened_inputs): if isinstance(input, torch.Tensor): - metadata["device_type"] = f"{input.device.type}" - if is_cpu_device([input]): - metadata["device_index"] = -1 - else: - metadata["device_index"] = input.device.index - metadata["dtype"] = f"{input.dtype}" - metadata["sizes"] = list(input.size()) - metadata["strides"] = list(input.stride()) + metadata = extract_tensor_metadata(input) + elif isinstance(input, list): + assert all(isinstance(item, torch.Tensor) for item in input) + metadata = extract_tensor_list_metadata(input) else: - assert isinstance(input, supported_scalar_types) - # Scalar tensor - metadata["device_type"] = device_type - metadata["device_index"] = -1 if device_type == "cpu" else 0 - metadata["dtype"] = f"{type_to_torch_dtype[type(input)]}" - metadata["sizes"] = [] - metadata["strides"] = [] - metadata["scalar_value"] = input + metadata = extract_scalar_metadata(input) + metadata["arg_order"] = idx kernel_metadata_items.append(metadata) kernel_meta_info: Dict[str, Any] = {} diff --git a/torch/csrc/dynamo/guards.cpp b/torch/csrc/dynamo/guards.cpp index d7569f580fa651..386c536a4db453 100644 --- a/torch/csrc/dynamo/guards.cpp +++ b/torch/csrc/dynamo/guards.cpp @@ -73,16 +73,17 @@ TensorCheck::TensorCheck( TensorCheck::TensorCheck( const LocalState& state, PyTypeObject* pt, - uint64_t dispatch_key, + c10::DispatchKeySet dispatch_key_set, at::ScalarType dtype, at::DeviceIndex device_index, + bool requires_grad, std::vector> dynamic_dims_sizes, std::vector> dynamic_dims_strides) : pytype(pt), - dispatch_key_(dispatch_key), + dispatch_key_(state.apply(dispatch_key_set).raw_repr()), dtype_(dtype), device_index_(device_index), - requires_grad_(false), + requires_grad_(requires_grad), sizes_(std::move(dynamic_dims_sizes)), strides_(std::move(dynamic_dims_strides)), dim_(static_cast(sizes_.size())) {} @@ -90,18 +91,46 @@ TensorCheck::TensorCheck( // See note in guards.py [Note - On Export Tensor Guards] // Logic parallel to here must be maintained in python bool TensorCheck::check(const LocalState& state, const at::Tensor& v) { - if (dispatch_key_ != state.apply(v.key_set()).raw_repr() || - dtype_ != v.dtype().toScalarType() || - device_index_ != v.device().index() || - requires_grad_ != v.requires_grad()) { + // In terms of a sparse_csr tensor, it does not support strides informatio + c10::SymIntArrayRef sym_strides(std::vector(v.ndimension(), -1)); + bool does_not_support_stride = v.layout() == c10::kSparseCsr || + v.layout() == c10::kSparseCsc || v.layout() == c10::kSparseBsc || + v.layout() == c10::kSparseBsr; + if (!does_not_support_stride) { + sym_strides = v.sym_strides(); + } + + return check( + state, + v.key_set(), + v.dtype().toScalarType(), + v.device(), + v.sym_sizes(), + sym_strides, + v.requires_grad()); +} + +bool TensorCheck::check( + const LocalState& state, + const c10::DispatchKeySet& dispatch_key_set, + const at::ScalarType& dtype, + const c10::Device& device, + const c10::SymIntArrayRef& sym_sizes, + const c10::SymIntArrayRef& sym_strides, + const bool& requires_grad) { + if (dispatch_key_ != state.apply(dispatch_key_set).raw_repr() || + dtype_ != dtype || device_index_ != device.index() || + requires_grad_ != requires_grad) { return false; } - auto ndim = v.ndimension(); - if (ndim != dim_) { + + auto ndim = sym_sizes.size(); + if (ndim != static_cast(dim_)) { return false; } - const auto& sizes = v.sym_sizes(); - const auto& strides = v.sym_strides(); + + const auto& sizes = sym_sizes; + const auto& strides = sym_strides; for (auto i : c10::irange(ndim)) { auto known_size = sizes_[i]; auto known_stride = strides_[i]; diff --git a/torch/csrc/dynamo/guards.h b/torch/csrc/dynamo/guards.h index 26accf742181ae..cc2e4b438ee707 100644 --- a/torch/csrc/dynamo/guards.h +++ b/torch/csrc/dynamo/guards.h @@ -28,6 +28,7 @@ struct LocalState { LocalState() : dispatch_modifier(c10::impl::tls_local_dispatch_key_set()), + override_dispatch_key_set(c10::BackendComponent::InvalidBit), grad_mode_enabled(at::GradMode::is_enabled()) {} void overrideDispatchKeySet(c10::DispatchKeySet ks) { @@ -47,13 +48,22 @@ class TensorCheck { TensorCheck( const LocalState& state, PyTypeObject* pt, - uint64_t dispatch_key, + c10::DispatchKeySet dispatch_key_set, at::ScalarType dtype, at::DeviceIndex device_index, + bool requires_grad, std::vector> dynamic_dims_sizes, std::vector> dynamic_dims_strides); bool check(const LocalState& state, const at::Tensor& v); + bool check( + const LocalState& state, + const c10::DispatchKeySet& dispatch_key_set, + const at::ScalarType& dtype, + const c10::Device& device, + const c10::SymIntArrayRef& dynamic_dims_sizes, + const c10::SymIntArrayRef& dynamic_dims_strides, + const bool& requires_grad); std::string check_verbose( const LocalState& state, const at::Tensor& v, diff --git a/torch/csrc/inductor/aoti_eager/kernel_holder.cpp b/torch/csrc/inductor/aoti_eager/kernel_holder.cpp index 10b7daff5bfcbd..cda50f077e5723 100644 --- a/torch/csrc/inductor/aoti_eager/kernel_holder.cpp +++ b/torch/csrc/inductor/aoti_eager/kernel_holder.cpp @@ -92,57 +92,71 @@ bool unpack_ivalue( return true; } -bool unpack_tensors( +std::vector unpack_tensors( const std::vector& arguments, const torch::jit::Stack& stack, - const c10::Device& device, - std::vector& inputs, - bool with_scalar = false) { + const c10::Device& device) { + std::vector inputs; for (size_t idx = 0; idx < stack.size(); idx++) { - if (!with_scalar && stack[idx].isScalar()) { - continue; - } - - if (!unpack_ivalue(arguments[idx], stack[idx], device, inputs)) { - return false; + auto ivalue = stack[idx]; + auto ivalue_arg = arguments[idx]; + if (ivalue.isTensor()) { + unpack_tensor_ivalue(ivalue, device, inputs); + } else if (ivalue.isTensorList()) { + unpack_tensor_list_ivalue(ivalue, device, inputs); + } else if (ivalue.isOptionalTensorList()) { + unpack_optional_tensor_list_ivalue(ivalue, device, inputs); + } else if ( + *ivalue_arg.real_type() == + *c10::getTypePtr>()) { + // ivalue is c10::optional + unpack_optional_tensor_ivalue(ivalue, device, inputs); } } - - return true; + return inputs; } -std::vector get_tensor_parameter_index( +std::vector unpack_input_parameters( const std::vector& arguments, const torch::jit::Stack& stack) { - std::vector tensor_parameter_index; + std::vector inputs_metadata; for (size_t idx = 0; idx < stack.size(); idx++) { - if (stack[idx].isScalar() || stack[idx].isTensor()) { - // scalar and tensor - tensor_parameter_index.push_back(idx); + if (stack[idx].isScalar()) { + // scalar + inputs_metadata.push_back(ParameterMetadata(stack[idx].toScalar(), idx)); } else if (stack[idx].isTensorList()) { // tensor list - std::fill_n( - std::back_inserter(tensor_parameter_index), - stack[idx].toListRef().size(), - idx); + inputs_metadata.push_back( + ParameterMetadata(stack[idx].toTensorList().vec(), idx)); } else if (stack[idx].isOptionalTensorList()) { // optional tensor list: std::vector> + std::vector tensor_list; for (const auto& item : stack[idx].toListRef()) { if (item.toOptional().has_value()) { - tensor_parameter_index.push_back(idx); + tensor_list.push_back(item.toOptional().value()); } } + inputs_metadata.push_back(ParameterMetadata(tensor_list, idx)); } else if ( *arguments[idx].real_type() == *c10::getTypePtr>()) { // optional tensor if (stack[idx].toOptional().has_value()) { - tensor_parameter_index.push_back(idx); + inputs_metadata.push_back(ParameterMetadata( + stack[idx].toOptional().value(), idx)); } + } else if (stack[idx].isTensor()) { + inputs_metadata.push_back(ParameterMetadata(stack[idx].toTensor(), idx)); + } else { + TORCH_CHECK_NOT_IMPLEMENTED( + false, + "Not implemented for operations that contain a parameter which is ", + "not one of the following types: at::Tensor, at::TensorList, ", + "std::optional, std::vector> and c10::Scalar."); } } - return tensor_parameter_index; + return inputs_metadata; } } // namespace @@ -167,9 +181,9 @@ void AOTIPythonKernelHolder::operator()( const c10::OperatorHandle& op, c10::DispatchKeySet keyset, torch::jit::Stack* stack) { - AOTIKernelState kernel_state; - if (cache_lookup(op, keyset, stack, kernel_state)) { - cache_hit(kernel_state, op, keyset, stack); + AOTIKernelMetadata aoti_kernel_metadata; + if (cache_lookup(op, keyset, stack, aoti_kernel_metadata)) { + cache_hit(aoti_kernel_metadata, op, keyset, stack); } else { cache_miss(op, keyset, stack); } @@ -179,7 +193,7 @@ bool AOTIPythonKernelHolder::cache_lookup( const c10::OperatorHandle& op, const c10::DispatchKeySet& keyset, const torch::jit::Stack* stack, - AOTIKernelState& kernel_state) { + AOTIKernelMetadata& aoti_kernel_metadata) { TORCH_CHECK_NOT_IMPLEMENTED( op.schema().returns().size() == 1, "Not implemented for operations that return either multiple values or no value."); @@ -187,110 +201,32 @@ bool AOTIPythonKernelHolder::cache_lookup( op.schema().returns()[0].type()->isSubtypeOf(c10::TensorType::get()), "Not implemented for operations that return a non-Tensor value."); - std::vector inputs; - auto res = - unpack_tensors(op.schema().arguments(), *stack, device_, inputs, true); - TORCH_CHECK_NOT_IMPLEMENTED( - res && inputs.size() > 0, - "Not implemented for operations that contain a parameter which is ", - "not one of the following types: at::Tensor, at::TensorList, ", - "std::optional, std::vector>."); - - auto tensor_parameter_index = - get_tensor_parameter_index(op.schema().arguments(), *stack); - TORCH_INTERNAL_ASSERT(tensor_parameter_index.size() == inputs.size()); - auto inputs_metadata = get_inputs_metadata( - inputs, op.schema().arguments(), tensor_parameter_index); - auto aoti_kernel_state = aoti_kernel_cache_.find(inputs_metadata); - if (aoti_kernel_state == aoti_kernel_cache_.end()) { - return false; - } - - if (aoti_kernel_state->second.tensor_checks_.size() != inputs.size()) { - return false; - } - - torch::dynamo::LocalState local_state; - local_state.overrideDispatchKeySet(c10::DispatchKeySet(dispatch_key_)); - - for (size_t i = 0; i < inputs.size(); ++i) { - bool pass = aoti_kernel_state->second.tensor_checks_[i].check( - local_state, inputs[i]); - if (!pass) { - return false; + auto inputs_metadata = + unpack_input_parameters(op.schema().arguments(), *stack); + for (const auto& aoti_kernel_cache : aoti_kernel_cache_) { + if (aoti_kernel_cache.check(inputs_metadata)) { + aoti_kernel_metadata = aoti_kernel_cache; + return true; } } - kernel_state = aoti_kernel_state->second; - return true; + return false; } void AOTIPythonKernelHolder::cache_hit( - const AOTIKernelState& kernel_state, + const AOTIKernelMetadata& aoti_kernel_metadata, const c10::OperatorHandle& op, const c10::DispatchKeySet& keyset, torch::jit::Stack* stack) { - std::vector inputs; - unpack_tensors(op.schema().arguments(), *stack, device_, inputs); + auto inputs = unpack_tensors(op.schema().arguments(), *stack, device_); torch::jit::drop(*stack, op.schema().arguments().size()); - auto outputs = kernel_state.kernel_runner_->run(inputs); + auto outputs = aoti_kernel_metadata.kernel_runner_->run(inputs); for (auto& output : outputs) { stack->push_back(output); } } -AOTIKernelMetadata AOTIPythonKernelHolder::get_inputs_metadata( - const std::vector& inputs, - const std::vector& inputs_argument, - const std::vector& inputs_argument_index) { - AOTIKernelMetadata inputs_metadata; - for (size_t idx = 0; idx < inputs.size(); ++idx) { - auto input = inputs[idx]; - auto input_info = inputs_argument[inputs_argument_index[idx]]; - - auto device = input.device(); - if (device.is_cpu()) { - // If the device is CPU, set the device index to -1. - device = c10::Device(device.type(), -1); - } - - c10::Scalar scalar_value((double)1.0); - auto tensor_type = input.scalar_type(); - - bool is_scalar = input_info.type()->isSubtypeOf(*c10::NumberType::get()); - if (is_scalar) { - if (c10::isFloatingType(input.scalar_type())) { - auto scalar_numeric_value = input.item().toDouble(); - tensor_type = c10::ScalarType::Double; - scalar_value = c10::Scalar(scalar_numeric_value); - } else if (c10::isIntegralType(input.scalar_type(), false)) { - auto scalar_numeric_value = input.item().toUInt64(); - tensor_type = c10::ScalarType::UInt64; - scalar_value = c10::Scalar(scalar_numeric_value); - } else if (input.scalar_type() == c10::ScalarType::Bool) { - auto scalar_numeric_value = input.item().toBool(); - tensor_type = c10::ScalarType::Bool; - scalar_value = c10::Scalar(scalar_numeric_value); - } else { - TORCH_CHECK( - false, - "Unsupported scalar tensor type: ", - c10::toString(input.scalar_type())); - } - } - - inputs_metadata.emplace_back( - false, - tensor_type, - c10::IValue(scalar_value), - device, - input.sizes().vec(), - input.strides().vec()); - } - return inputs_metadata; -} - void AOTIPythonKernelHolder::init_aoti_kernel_cache() { if (device_.type() == c10::DeviceType::COMPILE_TIME_MAX_DEVICE_TYPES) { return; @@ -315,8 +251,46 @@ void AOTIPythonKernelHolder::init_aoti_kernel_cache() { "Failed to load AOTI kernel. Operator Name is ", op_name_with_overload_); + auto build_tensor_metadata = [](const py::dict& metadata) -> TensorMetadata { + // Access the fields of each metadata dict + auto is_dynamic = metadata["is_dynamic"].cast(); + auto device_type = metadata["device_type"].cast(); + auto device_index = metadata["device_index"].cast(); + auto data_type_obj = metadata["dtype"].cast(); + TORCH_INTERNAL_ASSERT(THPDtype_Check(data_type_obj.ptr())); + auto data_type = + reinterpret_cast(data_type_obj.ptr())->scalar_type; + auto sizes = metadata["sizes"].cast>(); + auto strides = metadata["strides"].cast>(); + auto requires_grad = metadata["requires_grad"].cast(); + auto dispatch_key_set_raw_repr = + metadata["dispatch_key_set"].cast(); + auto dispatch_key_set = c10::DispatchKeySet( + c10::DispatchKeySet::RAW, dispatch_key_set_raw_repr); + auto device = c10::Device(device_type); + device.set_index(device_index); + + auto tensor_metadata = TensorMetadata( + is_dynamic, + data_type, + device, + dispatch_key_set, + sizes, + strides, + requires_grad); + + // Build guard for tensor check + torch::dynamo::LocalState state; + state.overrideDispatchKeySet(dispatch_key_set); + tensor_metadata.build_guard(state); + + return tensor_metadata; + }; + + TORCH_INTERNAL_ASSERT(py::isinstance(result)); auto kernel_info_list = result.cast(); for (auto kernel_info : kernel_info_list) { + TORCH_INTERNAL_ASSERT(py::isinstance(kernel_info)); auto item_dict = kernel_info.cast(); // Access the kernel_path field @@ -325,97 +299,70 @@ void AOTIPythonKernelHolder::init_aoti_kernel_cache() { // Access the meta_info list auto inputs_metadata = item_dict["meta_info"].cast(); - std::vector tensor_checks; - std::vector tensor_metadata_list; - - torch::dynamo::LocalState state; + std::vector parameter_metadata_list; // Loop over the meta_info list - for (auto item : inputs_metadata) { - // Convert the handle to a dict - auto metadata = item.cast(); - - // Access the fields of each metadata dict - auto is_dynamic = metadata["is_dynamic"].cast(); - auto device_type = metadata["device_type"].cast(); - auto device_index = metadata["device_index"].cast(); - auto data_type_obj = metadata["dtype"].cast(); - TORCH_INTERNAL_ASSERT(THPDtype_Check(data_type_obj.ptr())); - auto data_type = - reinterpret_cast(data_type_obj.ptr())->scalar_type; - auto sizes = metadata["sizes"].cast>(); - auto strides = metadata["strides"].cast>(); + for (auto item_metadata : inputs_metadata) { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(py::isinstance(item_metadata)); + auto metadata = item_metadata.cast(); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(metadata.contains("arg_order")); + uint64_t arg_idx = metadata["arg_order"].cast(); bool is_scalar = metadata.contains("scalar_value"); - - std::vector> sym_optional_sizes; - std::vector> sym_optional_strides; - for (int64_t size : sizes) { - sym_optional_sizes.push_back(std::optional(size)); - } - for (int64_t stride : strides) { - sym_optional_strides.push_back(std::optional(stride)); - } - - // If an input parameter is a scalar, its detailed value is cached. - // This is done to ensure correctness during subsequent checks. - c10::Scalar scalar_value((double)1.0); - if (is_scalar) { - if (c10::isFloatingType(data_type)) { - auto scalar_numeric_value = metadata["scalar_value"].cast(); - data_type = c10::ScalarType::Double; - scalar_value = c10::Scalar(scalar_numeric_value); - } else if (c10::isIntegralType(data_type, false)) { - auto scalar_numeric_value = metadata["scalar_value"].cast(); - data_type = c10::ScalarType::UInt64; - scalar_value = c10::Scalar(scalar_numeric_value); - } else if (data_type == c10::ScalarType::Bool) { - auto scalar_numeric_value = metadata["scalar_value"].cast(); - data_type = c10::ScalarType::Bool; - scalar_value = c10::Scalar(scalar_numeric_value); - } else { - TORCH_CHECK( - false, - "Unsupported scalar tensor type: ", - c10::toString(data_type)); + bool is_tensor_list = metadata.contains("tensor_list"); + + if (is_tensor_list) { + // Tensor List + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + py::isinstance(metadata["tensor_list"])); + auto tensor_list = metadata["tensor_list"].cast(); + std::vector test_list_metadata; + for (auto item_tensor : tensor_list) { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + py::isinstance(item_tensor)); + auto metadata = item_tensor.cast(); + auto tensor_metadata = build_tensor_metadata(metadata); + test_list_metadata.push_back(tensor_metadata); } + parameter_metadata_list.push_back( + ParameterMetadata(test_list_metadata, arg_idx)); + } else if (is_scalar) { + // Scalar + auto metadata = item_metadata.cast(); + // Always cast scalar value to double to simplify the comparison + auto scalar_value = metadata["scalar_value"].cast(); + parameter_metadata_list.push_back( + ParameterMetadata(c10::Scalar(scalar_value), arg_idx)); + } else { + // Tensor + auto metadata = item_metadata.cast(); + auto tensor_metadata = build_tensor_metadata(metadata); + parameter_metadata_list.push_back( + ParameterMetadata(tensor_metadata, arg_idx)); } - - tensor_metadata_list.emplace_back( - is_dynamic, - data_type, - c10::IValue(scalar_value), - c10::Device(c10::Device(device_type).type(), device_index), - sizes, - strides); - tensor_checks.emplace_back( - state, - nullptr, - uint64_t(c10::DispatchKeySet(dispatch_key_).raw_repr()), - data_type, - c10::DeviceIndex(device_index), - sym_optional_sizes, - sym_optional_strides); } - AOTIKernelState aoti_kernel_state; - aoti_kernel_state.kernel_runner_ = load_aoti_model_runner(kernel_path); - aoti_kernel_state.tensor_checks_ = tensor_checks; - aoti_kernel_cache_[tensor_metadata_list] = aoti_kernel_state; + AOTIKernelMetadata aoti_kernel_metadata; + aoti_kernel_metadata.parameter_metadata_list_ = parameter_metadata_list; + aoti_kernel_metadata.kernel_runner_ = load_aoti_model_runner(kernel_path); + aoti_kernel_cache_.push_back(aoti_kernel_metadata); } } std::shared_ptr AOTIPythonKernelHolder:: load_aoti_model_runner(const std::string& so_path) { + TORCH_CHECK( + device_.type() == c10::DeviceType::CUDA || + device_.type() == c10::DeviceType::CPU, + "AOTI for eager does not support ", + c10::DeviceTypeName(device_.type()), + " now."); if (device_.type() == c10::DeviceType::CUDA) { #ifdef USE_CUDA return std::make_shared(so_path); #else return nullptr; #endif - } else if (device_.type() == c10::DeviceType::CPU) { - return std::make_shared(so_path); } else { - TORCH_WARN("Unsupported device type"); - return nullptr; + return std::make_shared(so_path); } } @@ -437,10 +384,7 @@ void AOTIPythonKernelHolder::cache_miss( #endif } - std::vector inputs; - TORCH_INTERNAL_ASSERT( - unpack_tensors(op.schema().arguments(), *stack, device_, inputs), - "Failed to unpack tensors for the stack to run the AOTI kernel."); + auto inputs = unpack_tensors(op.schema().arguments(), *stack, device_); auto outputs = kernel->run(inputs); torch::jit::drop(*stack, op.schema().arguments().size()); // TODO: Get the output type of this operation and then convert to the diff --git a/torch/csrc/inductor/aoti_eager/kernel_holder.h b/torch/csrc/inductor/aoti_eager/kernel_holder.h index b67e4e7d4464e8..5c65f0df513454 100644 --- a/torch/csrc/inductor/aoti_eager/kernel_holder.h +++ b/torch/csrc/inductor/aoti_eager/kernel_holder.h @@ -14,9 +14,33 @@ namespace torch::inductor { -struct AOTIKernelState { +// Represent AOTI kernel. It contains all the parameter metadata of the kernel +// and the AOTI model runner. +struct AOTIKernelMetadata { + // Represent all the parameters of AOTI kernel + std::vector parameter_metadata_list_; + // AOTI model runner to run the AOTI kernel std::shared_ptr kernel_runner_; - std::vector tensor_checks_; + AOTIKernelMetadata() : parameter_metadata_list_(), kernel_runner_(nullptr) {} + + // Check whether the given parameter metadata list is the same as the + // parameter metadata list of the AOTI kernel. + bool check( + const std::vector& parameter_metadata_list) const { + if (parameter_metadata_list_.size() != parameter_metadata_list.size()) { + return false; + } + + for (size_t i = 0; i < parameter_metadata_list_.size(); ++i) { + if (parameter_metadata_list_[i] == parameter_metadata_list[i]) { + continue; + } else { + return false; + } + } + + return true; + } }; // The AOTIPythonKernelHolder class uses the AOT Inductor to generate a kernel @@ -38,10 +62,8 @@ class AOTIPythonKernelHolder : public c10::OperatorKernel { // The Python interpreter to get OpOverload object with the given op_name and // op_overload_name. c10::impl::PyInterpreter* pyinterpreter_; - - std:: - unordered_map - aoti_kernel_cache_; + // Cache the produced kernels by AOTI and its metadata + std::vector aoti_kernel_cache_; public: AOTIPythonKernelHolder( @@ -59,13 +81,13 @@ class AOTIPythonKernelHolder : public c10::OperatorKernel { const c10::OperatorHandle& op, const c10::DispatchKeySet& keyset, const torch::jit::Stack* stack, - AOTIKernelState& kernel_state); + AOTIKernelMetadata& aoti_kernel_metadata); void cache_miss( const c10::OperatorHandle& op, const c10::DispatchKeySet& keyset, torch::jit::Stack* stack); void cache_hit( - const AOTIKernelState& kernel_state, + const AOTIKernelMetadata& aoti_kernel_metadata, const c10::OperatorHandle& op, const c10::DispatchKeySet& keyset, torch::jit::Stack* stack); @@ -81,12 +103,6 @@ class AOTIPythonKernelHolder : public c10::OperatorKernel { // the given operation. // Inductor utility function - torch._inductor.utils.load_aoti_eager_cache void init_aoti_kernel_cache(); - // Abstract the meta information of each tensor for the given operation. The - // meta infomation will be used for cache lookup as the key. - AOTIKernelMetadata get_inputs_metadata( - const std::vector& inputs, - const std::vector& inputs_argument, - const std::vector& inputs_argument_index); // Load the AOTIModelContainerRunner object from the given file path. std::shared_ptr load_aoti_model_runner( const std::string&); diff --git a/torch/csrc/inductor/aoti_eager/kernel_meta_info.cpp b/torch/csrc/inductor/aoti_eager/kernel_meta_info.cpp index a49fab21d671e1..95cc29b412c1fa 100644 --- a/torch/csrc/inductor/aoti_eager/kernel_meta_info.cpp +++ b/torch/csrc/inductor/aoti_eager/kernel_meta_info.cpp @@ -6,51 +6,89 @@ namespace torch::inductor { TensorMetadata::TensorMetadata(const at::Tensor& src_tensor) : is_symbolic_(false), + dtype_(src_tensor.scalar_type()), device_(src_tensor.device()), + dispatch_key_set_(src_tensor.key_set()), sizes_(src_tensor.sizes().vec()), - strides_(src_tensor.sizes().vec()) {} + strides_(src_tensor.strides().vec()), + requires_grad_(src_tensor.requires_grad()) {} TensorMetadata::TensorMetadata( bool is_symbolic, c10::ScalarType dtype, c10::Device device, + c10::DispatchKeySet dispatch_key_set, std::vector sizes, - std::vector strides) + std::vector strides, + bool requires_grad) : is_symbolic_(is_symbolic), dtype_(dtype), - scalar_value_((float)1.0), device_(device), + dispatch_key_set_(dispatch_key_set), sizes_(sizes), - strides_(strides) { + strides_(strides), + requires_grad_(requires_grad) { TORCH_INTERNAL_ASSERT_DEBUG_ONLY( !is_symbolic_, "Not support symbolic shape now"); } -TensorMetadata::TensorMetadata( - bool is_symbolic, - c10::ScalarType dtype, - c10::IValue scalar_value, - c10::Device device, - std::vector sizes, - std::vector strides) - : is_symbolic_(is_symbolic), - dtype_(dtype), - scalar_value_(scalar_value), - device_(device), - sizes_(sizes), - strides_(strides) { +void TensorMetadata::build_guard(const torch::dynamo::LocalState& local_state) { TORCH_INTERNAL_ASSERT_DEBUG_ONLY( !is_symbolic_, "Not support symbolic shape now"); + std::vector> sym_sizes; + std::vector> sym_strides; + std::transform( + sizes_.begin(), + sizes_.end(), + std::back_inserter(sym_sizes), + [](int64_t size) { return std::optional(size); }); + std::transform( + strides_.begin(), + strides_.end(), + std::back_inserter(sym_strides), + [](int64_t stride) { return std::optional(stride); }); + tensor_check_ = torch::dynamo::TensorCheck( + local_state, + nullptr, + dispatch_key_set_, + dtype_, + device_.index(), + requires_grad_, + sym_sizes, + sym_strides); } bool TensorMetadata::operator==(const TensorMetadata& other) const { TORCH_INTERNAL_ASSERT_DEBUG_ONLY( !is_symbolic_, "Not support symbolic shape now"); - return this->is_symbolic_ == other.is_symbolic_ && - this->dtype_ == other.dtype_ && - this->scalar_value_ == other.scalar_value_ && - this->device_.type() == other.device_.type() && - this->sizes_ == other.sizes_ && this->strides_ == other.strides_; + + if (tensor_check_.has_value()) { + auto sizes_ = c10::IntArrayRef(other.sizes_); + auto strides_ = c10::IntArrayRef(other.strides_); + auto sym_sizes = c10::SymIntArrayRef( + reinterpret_cast(sizes_.data()), sizes_.size()); + auto sym_strides = c10::SymIntArrayRef( + reinterpret_cast(strides_.data()), strides_.size()); + + torch::dynamo::LocalState local_state; + local_state.overrideDispatchKeySet(dispatch_key_set_); + auto _tensor_check = tensor_check_.value(); + auto res = _tensor_check.check( + local_state, + other.dispatch_key_set_, + other.dtype_, + other.device_, + sym_sizes, + sym_strides, + other.requires_grad_ /* Should we need to care about grad requirement?*/); + return res; + } else { + return this->is_symbolic_ == other.is_symbolic_ && + this->dtype_ == other.dtype_ && this->device_ == other.device_ && + this->dispatch_key_set_ == other.dispatch_key_set_ && + this->requires_grad_ == other.requires_grad_ && + this->sizes_ == other.sizes_ && this->strides_ == other.strides_; + } } std::ostream& operator<<( @@ -58,8 +96,6 @@ std::ostream& operator<<( const TensorMetadata& tensor_metadata) { stream << "is_symbolic_: " << tensor_metadata.is_symbolic_ << std::endl; stream << "dtype_: " << tensor_metadata.dtype_ << std::endl; - stream << "scalar_value_: " << tensor_metadata.scalar_value_.type()->str() - << "(" << tensor_metadata.scalar_value_ << ")" << std::endl; stream << "device_: " << tensor_metadata.device_ << std::endl; stream << "sizes_: "; for (const auto& size : tensor_metadata.sizes_) { @@ -70,37 +106,93 @@ std::ostream& operator<<( for (const auto& stride : tensor_metadata.strides_) { stream << stride << " "; } + + stream << "requires_grad_: " << tensor_metadata.requires_grad_ << std::endl; + stream << "dispatch_key_set_: " << tensor_metadata.dispatch_key_set_ + << std::endl; + stream << "tensor_check_: " << tensor_metadata.tensor_check_.has_value() + << std::endl; stream << std::endl; return stream; } -size_t TensorMetadataHash::operator()( - const TensorMetadata& tensor_metadata) const { - auto hash = std::hash()(tensor_metadata.is_symbolic_); - hash = c10::hash_combine( - hash, std::hash()(tensor_metadata.dtype_)); - hash = - c10::hash_combine(hash, c10::IValue::hash(tensor_metadata.scalar_value_)); - hash = c10::hash_combine( - hash, std::hash()(tensor_metadata.device_.type())); - - for (auto& e : tensor_metadata.sizes_) { - hash = c10::hash_combine(hash, std::hash()(e)); +ParameterMetadata::ParameterMetadata( + TensorMetadata tensor_metadata, + uint64_t input_order) + : tag_(TENSOR), value_(tensor_metadata), order_(input_order) {} + +ParameterMetadata::ParameterMetadata( + const at::Tensor& tensor, + uint64_t input_order) + : tag_(TENSOR), order_(input_order) { + value_ = TensorMetadata(tensor); +} + +ParameterMetadata::ParameterMetadata( + const std::vector& tensor_metadata_list, + uint64_t input_order) + : tag_(TENSOR_LIST), value_(tensor_metadata_list), order_(input_order) {} + +ParameterMetadata::ParameterMetadata( + const std::vector& tensor_list, + uint64_t input_order) + : tag_(TENSOR_LIST), order_(input_order) { + std::vector tensor_metadata_list; + for (const auto& tensor : tensor_list) { + tensor_metadata_list.push_back(TensorMetadata(tensor)); } + value_ = tensor_metadata_list; +} - for (auto& e : tensor_metadata.strides_) { - hash = c10::hash_combine(hash, std::hash()(e)); +ParameterMetadata::ParameterMetadata( + const c10::Scalar& scalar, + uint64_t input_order) + : tag_(SCALAR), order_(input_order) { + value_ = scalar; +} + +bool ParameterMetadata::operator==(const ParameterMetadata& other) const { + // Same type + if (tag_ != other.tag_) { + return false; + } + + // Same order of the input parameters + if (order_ != other.order_) { + return false; + } + + switch (tag_) { + case TENSOR: + return std::get(value_) == + std::get(other.value_); + case TENSOR_LIST: + return std::get>(value_) == + std::get>(other.value_); + case SCALAR: + TORCH_INTERNAL_ASSERT( + std::get(other.value_).isFloatingPoint() || + std::get(other.value_).isIntegral(true /*includeBool*/)); + return equal_to(std::get(other.value_)); + default: + return false; } - return hash; } -size_t AOTIKernelMetadataHash::operator()( - const AOTIKernelMetadata& aoti_kernel_metadata) const { - size_t hash = 0; - for (auto& e : aoti_kernel_metadata) { - hash = c10::hash_combine(hash, TensorMetadataHash()(e)); +bool ParameterMetadata::equal_to(const c10::Scalar& scalar) const { + TORCH_INTERNAL_ASSERT(scalar.isFloatingPoint() || scalar.isIntegral(true)); + if (tag_ != SCALAR) { + return false; } - return hash; + + auto self_scalar = std::get(value_); + if (scalar.isFloatingPoint() && self_scalar.isFloatingPoint()) { + return self_scalar.toDouble() == scalar.toDouble(); + } else if (scalar.isIntegral(true) && self_scalar.isIntegral(true)) { + return self_scalar.toInt() == scalar.toInt(); + } + + return false; } } // namespace torch::inductor diff --git a/torch/csrc/inductor/aoti_eager/kernel_meta_info.h b/torch/csrc/inductor/aoti_eager/kernel_meta_info.h index 5c22e9b75f65bc..d07814dd0ad9ca 100644 --- a/torch/csrc/inductor/aoti_eager/kernel_meta_info.h +++ b/torch/csrc/inductor/aoti_eager/kernel_meta_info.h @@ -3,6 +3,7 @@ #include #include +#include #include @@ -32,44 +33,102 @@ struct TensorMetadata { // and strides_ in the future. bool is_symbolic_; // Dtype of a tensor(For scalar, we will wrap it as a scalar tensor) - c10::ScalarType dtype_; - // Concrete scalar value. Serve for operations w/ scalar parameter - c10::IValue scalar_value_; + c10::ScalarType dtype_ = c10::ScalarType::Undefined; // Device of a tensor. c10::Device device_; + // Dispatch key set of a tensor + c10::DispatchKeySet dispatch_key_set_; // Sizes of a tensor. Currently, we only support static shape and use int64_t // to represent the sizes. In the future, we will create symbolic size and use // SymInt to represent it to support symbolic shape. std::vector sizes_; // Strides of a tensor. For symbolic shape support, it is the same as sizes_ std::vector strides_; + // requires grad + bool requires_grad_ = false; + // TensorCheck for the tensor + std::optional tensor_check_; + TensorMetadata() + : is_symbolic_(false), + dtype_(c10::ScalarType::Undefined), + device_(c10::DeviceType::COMPILE_TIME_MAX_DEVICE_TYPES), + sizes_({}), + strides_({}) {} TensorMetadata(const at::Tensor& src_tensor); TensorMetadata( bool is_symbolic, c10::ScalarType dtype, c10::Device device, + c10::DispatchKeySet dispatch_key_set, std::vector sizes, - std::vector strides); - TensorMetadata( - bool is_symbolic, - c10::ScalarType dtype, - c10::IValue scalar_value, - c10::Device device, - std::vector sizes, - std::vector strides); + std::vector strides, + bool requires_grad = false); + // Build TensorCheck for the tensor by using the data fields in TensorMetadata + void build_guard(const dynamo::LocalState& local_state); + + // Compare two TensorMetadata objects bool operator==(const TensorMetadata& other) const; }; -struct TensorMetadataHash { - size_t operator()(const TensorMetadata&) const; +// ParameterTag is to represent the type of the input parameters of a aten +// operation. Currently, we support the following types: +// 1. TENSOR: a single tensor +// 2. TENSOR_OPTIONAL: a single optional tensor +// 3. TENSOR_LIST: a list of tensors +// 4. TENSOR_LIST_OPTIONAL: a list of optional tensors +// 5. SCALAR: a scalar value +// If we need to support more types in the future, we will add more types in the +// ParameterTag enum. For example, we will extend the enum to support string, +// Dimname and so on to support more types of input parameters of aten +// operations. +enum ParameterTag { + TENSOR, + TENSOR_OPTIONAL, + TENSOR_LIST, + TENSOR_LIST_OPTIONAL, + SCALAR, + INVALID, }; -using AOTIKernelMetadata = std::vector; +// ParameterMetadataValue is to represent the value of the input parameters of a +// aten operation. +using ParameterMetadataValue = + std::variant, c10::Scalar>; + +// ParameterMetadata is to represent the metadata of the input parameters of a +// aten operation. It includes the tag of the parameter, the value of the +// parameter and the order of the parameter. +struct ParameterMetadata { + // The tag of the parameter. It indicates the type of the parameter. + ParameterTag tag_; + // The value of the parameter. It can be a tensor, a list of tensors or a + // scalar. + ParameterMetadataValue value_; + // The order of the parameter is used to distinguish the parameters with the + // same tag. For example, an operation with two input tensors, the first + // tensor is a optional tensor and the second tensor is a tensor. The first + // tensor will have the order 0 and the second tensor will have the order 1. + uint64_t order_; + + ParameterMetadata() : tag_(INVALID) {} + ParameterMetadata(TensorMetadata tensor_metadata, uint64_t input_order); + ParameterMetadata(const at::Tensor& tensor, uint64_t input_order); + ParameterMetadata( + const std::vector& tensor_list, + uint64_t input_order); + ParameterMetadata( + const std::vector& tensor_metadata_list, + uint64_t input_order); + ParameterMetadata(const c10::Scalar& scalar, uint64_t input_order); + + bool operator==(const ParameterMetadata& other) const; -struct AOTIKernelMetadataHash { - size_t operator()(const AOTIKernelMetadata&) const; + private: + // Helper function to compare two ParameterMetadata objects with the same + // SCALAR tag. + bool equal_to(const c10::Scalar& scalar) const; }; } // namespace torch::inductor diff --git a/torch/csrc/utils/python_dispatch.cpp b/torch/csrc/utils/python_dispatch.cpp index ec0af99842d2e5..c3a7aafa95f2dd 100644 --- a/torch/csrc/utils/python_dispatch.cpp +++ b/torch/csrc/utils/python_dispatch.cpp @@ -716,6 +716,7 @@ void initDispatchBindings(PyObject* module) { .def("__or__", &c10::DispatchKeySet::operator|) .def("__sub__", &c10::DispatchKeySet::operator-) .def("__and__", &c10::DispatchKeySet::operator&) + .def("raw_repr", &c10::DispatchKeySet::raw_repr) .def("highestPriorityTypeId", &c10::DispatchKeySet::highestPriorityTypeId) .def( "remove", From 7a39755da28d5a109bf0c37f72b364d3a83137b1 Mon Sep 17 00:00:00 2001 From: Yifu Wang Date: Fri, 14 Jun 2024 19:37:55 -0700 Subject: [PATCH 058/171] Introduce a prototype for SymmetricMemory (#128582) Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at bottom): This PR introduces a prototype for `SymmetricMemory` (including a CUDA implementation) - a remote-memory access-based communication primitive. It allows for user-defined communication patterns/kernels and is designed to be torch.compile-friendly. It addresses the major limitations of `IntraNodeComm` and `ProcessGroupCudaP2p` and serves as a replacement for them. ### SymmetricMemory `SymmetricMemory` represents symmetric allocations across a group of devices. The allocations represented by a `SymmetricMemory` object are accessible by all devices in the group. The class can be used for **op-level custom communication patterns** (via the get_buffer APIs and the synchronization primitives), as well as **custom communication kernels** (via the buffer and signal_pad device pointers). ### Python API Example ```python from torch._C.distributed_c10d import _SymmetricMemory # Set a store for rendezvousing symmetric allocations on a group of devices # identified by group_name. The concept of groups is logical; users can # utilize predefined groups (e.g., a group of device identified by a # ProcessGroup) or create custom ones. Note that a SymmetricMemoryAllocator # backends might employ a more efficient communication channel for the actual # rendezvous process and only use the store for bootstrapping purposes. _SymmetricMemory.set_group_info(group_name, rank, world_size, store) # Identical to empty_strided, but allows symmetric memory access to be # established for the allocated tensor via _SymmetricMemory.rendezvous(). # This function itself is not a collective operation. t = _SymmetricMemory.empty_strided_p2p((64, 64), (64, 1), torch.float32, group_name) # Users can write Python custom ops that leverages the symmetric memory access. # Below are examples of things users can do (assuming the group's world_size is 2). # Establishes symmetric memory access on tensors allocated via # _SymmetricMemory.empty_strided_p2p(). rendezvous() is a one-time process, # and the mapping between a local memory region and the associated SymmetricMemory # object is unique. Subsequent calls to rendezvous() with the same tensor will receive # the cached SymmetricMemory object. # # The function has a collective semantic and must be invoked simultaneously # from all rendezvous participants. symm_mem = _SymmetricMemory.rendezvous(t) # This represents the allocation on rank 0 and is accessible from all devices. buf = symm_mem.get_buffer(0, (64, 64), torch.float32) if symm_mem.rank == 0: symm_mem.wait_signal(src_rank=1) assert buf.eq(42).all() else: # The remote buffer can be used as a regular tensor buf.fill_(42) symm_mem.put_signal(dst_rank=0) symm_mem.barrier() if symm_mem.rank == 0: symm_mem.barrier() assert buf.eq(43).all() else: new_val = torch.empty_like(buf) new_val.fill_(43) # Contiguous copies to/from a remote buffer utilize copy engines # which bypasses SMs (i.e. no need to load the data into registers) buf.copy_(new_val) symm_mem.barrier() ``` ### Custom CUDA Comm Kernels Given a tensor, users can access the associated `SymmetricMemory` which provides pointer to remote buffers/signal_pads needed for custom communication kernels. ```cpp TORCH_API c10::intrusive_ptr get_symmetric_memory( const at::Tensor& tensor); class TORCH_API SymmetricMemory : public c10::intrusive_ptr_target { public: ... virtual std::vector get_buffer_ptrs() = 0; virtual std::vector get_signal_pad_ptrs() = 0; virtual void** get_buffer_ptrs_dev() = 0; virtual void** get_signal_pad_ptrs_dev() = 0; virtual size_t get_buffer_size() = 0; virtual size_t get_signal_pad_size() = 0; virtual int get_rank() = 0; virtual int get_world_size() = 0; ... }; ``` ### Limitations of IntraNodeComm and ProcessGroupCudaP2p Both `IntraNodeComm` (used by `ProcessGroupCudaP2p`) manages a single fixed-size workspace. This approach: - Leads to awkward UX in which the required workspace needs to be specified upfront. - Can not avoid extra copies for some algorithms in eager mode (e.g., custom/multimem all-reduce, reduce-scatter, all-gather). - Prevents torch.compile from eliminating all copies. In addition, they only offer out-of-the-box communication kernels and don't expose required pointers for user-defined, custom CUDA comm kernels. * __->__ #128582 Pull Request resolved: https://github.com/pytorch/pytorch/pull/128582 Approved by: https://github.com/wanchaol --- .lintrunner.toml | 1 + BUILD.bazel | 1 + build_variables.bzl | 2 + c10/cuda/driver_api.h | 19 +- caffe2/CMakeLists.txt | 1 + test/distributed/test_symmetric_memory.py | 156 +++++ torch/_C/_distributed_c10d.pyi | 30 + .../distributed/c10d/CUDASymmetricMemory.cu | 539 ++++++++++++++++++ .../distributed/c10d/CUDASymmetricMemory.cuh | 109 ++++ .../distributed/c10d/ProcessGroupCudaP2P.hpp | 1 + .../csrc/distributed/c10d/SymmetricMemory.cpp | 189 ++++++ .../csrc/distributed/c10d/SymmetricMemory.hpp | 152 +++++ torch/csrc/distributed/c10d/init.cpp | 39 ++ .../csrc/distributed/c10d/intra_node_comm.cpp | 99 +--- .../csrc/distributed/c10d/intra_node_comm.cu | 18 +- .../csrc/distributed/c10d/intra_node_comm.hpp | 9 +- 16 files changed, 1254 insertions(+), 111 deletions(-) create mode 100644 test/distributed/test_symmetric_memory.py create mode 100644 torch/csrc/distributed/c10d/CUDASymmetricMemory.cu create mode 100644 torch/csrc/distributed/c10d/CUDASymmetricMemory.cuh create mode 100644 torch/csrc/distributed/c10d/SymmetricMemory.cpp create mode 100644 torch/csrc/distributed/c10d/SymmetricMemory.hpp diff --git a/.lintrunner.toml b/.lintrunner.toml index 07e64fce799d87..d4e81d1e68a641 100644 --- a/.lintrunner.toml +++ b/.lintrunner.toml @@ -68,6 +68,7 @@ include_patterns = [ 'aten/src/ATen/native/cudnn/*.cpp', 'c10/**/*.h', 'c10/**/*.cpp', + 'distributed/c10d/*SymmetricMemory.*', 'torch/csrc/**/*.h', 'torch/csrc/**/*.hpp', 'torch/csrc/**/*.cpp', diff --git a/BUILD.bazel b/BUILD.bazel index 10c065f5084c7e..c563c52d861e67 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -744,6 +744,7 @@ cc_library( "torch/csrc/cuda/python_nccl.cpp", "torch/csrc/cuda/nccl.cpp", "torch/csrc/distributed/c10d/intra_node_comm.cu", + "torch/csrc/distributed/c10d/CUDASymmetricMemory.cu", "torch/csrc/distributed/c10d/Utils.cu", "torch/csrc/distributed/c10d/quantization/quantization_gpu.cu", ], diff --git a/build_variables.bzl b/build_variables.bzl index 323588c15b4c1a..b4b4d1ab139cd9 100644 --- a/build_variables.bzl +++ b/build_variables.bzl @@ -501,6 +501,7 @@ libtorch_distributed_base_sources = [ "torch/csrc/distributed/c10d/ProcessGroupMPI.cpp", "torch/csrc/distributed/c10d/ProcessGroupWrapper.cpp", "torch/csrc/distributed/c10d/Store.cpp", + "torch/csrc/distributed/c10d/SymmetricMemory.cpp", "torch/csrc/distributed/c10d/TCPStore.cpp", "torch/csrc/distributed/c10d/TCPStoreBackend.cpp", "torch/csrc/distributed/c10d/TCPStoreLibUvBackend.cpp", @@ -684,6 +685,7 @@ libtorch_cuda_distributed_extra_sources = [ "torch/csrc/distributed/c10d/UCCUtils.cpp", "torch/csrc/distributed/c10d/intra_node_comm.cpp", "torch/csrc/distributed/c10d/intra_node_comm.cu", + "torch/csrc/distributed/c10d/CUDASymmetricMemory.cu", "torch/csrc/distributed/c10d/Utils.cu", "torch/csrc/distributed/rpc/tensorpipe_cuda.cpp", "torch/csrc/distributed/c10d/quantization/quantization_gpu.cu", diff --git a/c10/cuda/driver_api.h b/c10/cuda/driver_api.h index 43bcbd1d70bace..cbbdf16823ec76 100644 --- a/c10/cuda/driver_api.h +++ b/c10/cuda/driver_api.h @@ -18,14 +18,17 @@ } \ } while (0) -#define C10_LIBCUDA_DRIVER_API(_) \ - _(cuMemAddressReserve) \ - _(cuMemRelease) \ - _(cuMemMap) \ - _(cuMemAddressFree) \ - _(cuMemSetAccess) \ - _(cuMemUnmap) \ - _(cuMemCreate) \ +#define C10_LIBCUDA_DRIVER_API(_) \ + _(cuMemAddressReserve) \ + _(cuMemRelease) \ + _(cuMemMap) \ + _(cuMemAddressFree) \ + _(cuMemSetAccess) \ + _(cuMemUnmap) \ + _(cuMemCreate) \ + _(cuMemGetAllocationGranularity) \ + _(cuMemExportToShareableHandle) \ + _(cuMemImportFromShareableHandle) \ _(cuGetErrorString) #define C10_NVML_DRIVER_API(_) \ diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt index 89c31fab113473..8426741609fe7f 100644 --- a/caffe2/CMakeLists.txt +++ b/caffe2/CMakeLists.txt @@ -560,6 +560,7 @@ if(USE_CUDA) append_filelist("libtorch_cuda_distributed_extra_sources" Caffe2_GPU_SRCS) set_source_files_properties( ${TORCH_SRC_DIR}/csrc/distributed/c10d/intra_node_comm.cpp + ${TORCH_SRC_DIR}/csrc/distributed/c10d/CUDASymmetricMemory.cu PROPERTIES COMPILE_FLAGS "-DPYTORCH_C10_DRIVER_API_SUPPORTED=1" ) endif() diff --git a/test/distributed/test_symmetric_memory.py b/test/distributed/test_symmetric_memory.py new file mode 100644 index 00000000000000..a768e059044f79 --- /dev/null +++ b/test/distributed/test_symmetric_memory.py @@ -0,0 +1,156 @@ +# Owner(s): ["module: c10d"] + +import torch + +import torch.distributed as dist +from torch._C._distributed_c10d import _SymmetricMemory +from torch.distributed.distributed_c10d import _get_process_group_store + +from torch.testing._internal.common_distributed import ( + MultiProcessTestCase, + skip_if_lt_x_gpu, +) +from torch.testing._internal.common_utils import ( + instantiate_parametrized_tests, + run_tests, + skip_but_pass_in_sandcastle_if, + skipIfRocm, +) + + +def requires_cuda_p2p_access(): + cuda_p2p_access_available = ( + torch.cuda.is_available() and torch.cuda.device_count() >= 2 + ) + num_devices = torch.cuda.device_count() + for i in range(num_devices - 1): + for j in range(i + 1, num_devices): + if not torch.cuda.can_device_access_peer(i, j): + cuda_p2p_access_available = False + break + if not cuda_p2p_access_available: + break + + return skip_but_pass_in_sandcastle_if( + not cuda_p2p_access_available, + "cuda p2p access is not available", + ) + + +@instantiate_parametrized_tests +@requires_cuda_p2p_access() +class SymmetricMemoryTest(MultiProcessTestCase): + def setUp(self) -> None: + super().setUp() + self._spawn_processes() + + @property + def world_size(self) -> int: + return 2 + + @property + def device(self) -> torch.device: + return torch.device(f"cuda:{self.rank}") + + def _init_process(self): + torch.cuda.set_device(self.device) + store = dist.FileStore(self.file_name, self.world_size) + dist.init_process_group( + backend="nccl", + world_size=self.world_size, + rank=self.rank, + store=store, + ) + _SymmetricMemory.set_group_info( + "0", + self.rank, + self.world_size, + _get_process_group_store(dist.GroupMember.WORLD), + ) + + def _verify_symmetric_memory(self, symm_mem): + self.assertEqual(symm_mem.world_size, 2) + + buf = symm_mem.get_buffer(0, (64, 64), torch.float32) + if symm_mem.rank == 0: + symm_mem.wait_signal(src_rank=1) + self.assertTrue(buf.eq(42).all()) + else: + buf.fill_(42) + symm_mem.put_signal(dst_rank=0) + + symm_mem.barrier() + + if symm_mem.rank == 0: + symm_mem.barrier() + self.assertTrue(buf.eq(43).all()) + else: + buf.fill_(43) + symm_mem.barrier() + + symm_mem.barrier() + + @skipIfRocm + @skip_if_lt_x_gpu(2) + def test_empty_strided_p2p(self) -> None: + self._init_process() + + shape = (64, 64) + stride = (64, 1) + dtype = torch.float32 + device = self.device + group_name = "0" + alloc_args = (shape, stride, dtype, device, group_name) + + t = torch.empty(shape, dtype=dtype, device=device) + with self.assertRaises(RuntimeError): + _SymmetricMemory.rendezvous(t) + + t = _SymmetricMemory.empty_strided_p2p(*alloc_args) + symm_mem = _SymmetricMemory.rendezvous(t) + + del t + self._verify_symmetric_memory(symm_mem) + + @skipIfRocm + @skip_if_lt_x_gpu(2) + def test_empty_strided_p2p_persistent(self) -> None: + self._init_process() + + shape = (64, 64) + stride = (64, 1) + dtype = torch.float32 + device = self.device + alloc_id = 42 # Persistent allocation + group_name = "0" + alloc_args = (shape, stride, dtype, device, group_name, alloc_id) + + t = _SymmetricMemory.empty_strided_p2p(*alloc_args) + data_ptr = t.data_ptr() + + # Verify that persistent allocation would fail if there's an active + # allocation with the same alloc_id. + with self.assertRaises(RuntimeError): + _SymmetricMemory.empty_strided_p2p(*alloc_args) + + # Verify that persistent allocation would succeed in lieu of activate + # allocations with the same alloc_id, and the returned tensor would + # have the same data pointer. + del t + t = _SymmetricMemory.empty_strided_p2p(*alloc_args) + self.assertEqual(t.data_ptr(), data_ptr) + + # Verify that get_symmetric_memory would fail if called before + # rendezvous. + with self.assertRaises(RuntimeError): + _SymmetricMemory.get_symmetric_memory(t) + + symm_mem_0 = _SymmetricMemory.rendezvous(t) + symm_mem_1 = _SymmetricMemory.get_symmetric_memory(t) + self.assertEqual(id(symm_mem_0), id(symm_mem_1)) + + self._verify_symmetric_memory(symm_mem_0) + + +if __name__ == "__main__": + run_tests() diff --git a/torch/_C/_distributed_c10d.pyi b/torch/_C/_distributed_c10d.pyi index cffbf22219c8e7..0095b5af434b5c 100644 --- a/torch/_C/_distributed_c10d.pyi +++ b/torch/_C/_distributed_c10d.pyi @@ -637,3 +637,33 @@ class ProcessGroupCudaP2P(Backend): storage_offset: Optional[int] = 0, ) -> torch.Tensor: ... def _shutdown(self) -> None: ... + +class _SymmetricMemory: + @staticmethod + def set_group_info( + group_name: str, rank: int, world_size: int, store: Store + ) -> None: ... + @staticmethod + def empty_strided_p2p( + size: torch.types._size, + stride: torch.types._size, + dtype: torch.dtype, + device: torch.device, + group_name: str, + ) -> torch.Tensor: ... + @property + def rank(self) -> int: ... + @property + def world_size(self) -> int: ... + @staticmethod + def rendezvous(tensor: torch.Tensor) -> _SymmetricMemory: ... + def get_buffer( + self, + rank: int, + sizes: torch.Size, + dtype: torch.dtype, + storage_offset: Optional[int] = 0, + ) -> torch.Tensor: ... + def barrier(self, channel: int = 0) -> None: ... + def put_signal(self, dst_rank: int, channel: int = 0) -> None: ... + def wait_signal(self, src_rank: int, channel: int = 0) -> None: ... diff --git a/torch/csrc/distributed/c10d/CUDASymmetricMemory.cu b/torch/csrc/distributed/c10d/CUDASymmetricMemory.cu new file mode 100644 index 00000000000000..f27db85f7ff85d --- /dev/null +++ b/torch/csrc/distributed/c10d/CUDASymmetricMemory.cu @@ -0,0 +1,539 @@ +#include + +#include +#include +#include +#include + +#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED) +#include +#endif + +#include +#include + +namespace { + +constexpr size_t signal_pad_size = 2048; +const std::string store_comm_prefix = "CUDASymmetricMemory"; + +static size_t store_comm_seq_id = 0; + +template +std::vector store_all_gather( + const c10::intrusive_ptr& store, + int rank, + int world_size, + T val) { + static_assert(std::is_trivially_copyable_v); + + std::vector peer_keys; + for (int r = 0; r < world_size; ++r) { + std::ostringstream oss; + oss << store_comm_prefix << "/" << store_comm_seq_id << "/" << r; + peer_keys.push_back(oss.str()); + } + ++store_comm_seq_id; + + { + std::vector payload( + reinterpret_cast(&val), + reinterpret_cast(&val) + sizeof(T)); + store->set(peer_keys[rank], payload); + } + + std::vector peer_vals; + for (int r = 0; r < world_size; ++r) { + if (r == rank) { + peer_vals.push_back(val); + continue; + } + store->wait({peer_keys[r]}); + auto payload = store->get(peer_keys[r]); + TORCH_CHECK(payload.size() == sizeof(T)); + T peer_val{}; + std::memcpy(&peer_val, payload.data(), sizeof(T)); + peer_vals.push_back(peer_val); + } + return peer_vals; +} + +void store_barrier( + const c10::intrusive_ptr& store, + int rank, + int world_size) { + store_all_gather(store, rank, world_size, 0); +} + +int import_remote_fd(int pid, int fd) { +#if defined(SYS_pidfd_open) and defined(SYS_pidfd_getfd) + int pidfd = syscall(SYS_pidfd_open, pid, 0); + return syscall(SYS_pidfd_getfd, pidfd, fd, 0); +#else + TORCH_CHECK( + false, + "CUDASymmetricMemory requires pidfd_open ", + "and pidfd_getfd support"); +#endif +} + +void map_block( + void** ptr, + c10d::symmetric_memory::HandleType handle, + size_t size, + int device_idx) { +#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED) + auto driver_api = c10::cuda::DriverAPI::get(); + auto dev_ptr = reinterpret_cast(ptr); + C10_CUDA_DRIVER_CHECK( + driver_api->cuMemAddressReserve_(dev_ptr, size, 0ULL, 0, 0ULL)); + C10_CUDA_DRIVER_CHECK(driver_api->cuMemMap_(*dev_ptr, size, 0, handle, 0ULL)); + + CUmemAccessDesc desc; + desc.location.type = CU_MEM_LOCATION_TYPE_DEVICE; + // NOLINTNEXTLINE(bugprone-signed-char-misuse) + desc.location.id = static_cast(device_idx); + desc.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE; + C10_CUDA_DRIVER_CHECK(driver_api->cuMemSetAccess_(*dev_ptr, size, &desc, 1)); +#else + TORCH_CHECK( + false, "CUDASymmetricMemory requires PYTORCH_C10_DRIVER_API_SUPPORTED"); +#endif +} + +} // namespace + +namespace c10d { +namespace symmetric_memory { + +CUDASymmetricMemory::CUDASymmetricMemory( + std::vector handles, + size_t block_size, + std::vector buffers, + std::vector signal_pads, + size_t buffer_size, + int local_device_idx, + int rank, + int world_size) + : handles_(std::move(handles)), + block_size_(block_size), + buffers_(std::move(buffers)), + signal_pads_(std::move(signal_pads)), + buffer_size_(buffer_size), + local_device_idx_(local_device_idx), + rank_(rank), + world_size_(world_size) { + const size_t arr_size = sizeof(void*) * world_size_; + buffers_dev_ = reinterpret_cast( + c10::cuda::CUDACachingAllocator::raw_alloc(arr_size)); + signal_pads_dev_ = reinterpret_cast( + c10::cuda::CUDACachingAllocator::raw_alloc(arr_size)); + + c10::cuda::CUDAGuard guard(local_device_idx); + AT_CUDA_CHECK(cudaMemcpy( + buffers_dev_, buffers_.data(), arr_size, cudaMemcpyHostToDevice)); + AT_CUDA_CHECK(cudaMemcpy( + signal_pads_dev_, signal_pads_.data(), arr_size, cudaMemcpyHostToDevice)); +} + +CUDASymmetricMemory::~CUDASymmetricMemory() { +#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED) + c10::cuda::CUDAGuard guard(local_device_idx_); + C10_CUDA_CHECK(cudaDeviceSynchronize()); + + auto driver_api = c10::cuda::DriverAPI::get(); + for (int r = 0; r < world_size_; ++r) { + C10_CUDA_DRIVER_CHECK(driver_api->cuMemUnmap_( + reinterpret_cast(buffers_[r]), block_size_)); + C10_CUDA_DRIVER_CHECK(driver_api->cuMemRelease_(handles_[r])); + } + c10::cuda::CUDACachingAllocator::raw_delete(buffers_dev_); + c10::cuda::CUDACachingAllocator::raw_delete(signal_pads_dev_); +#else + TORCH_CHECK( + false, "CUDASymmetricMemory requires PYTORCH_C10_DRIVER_API_SUPPORTED"); +#endif +} + +std::vector CUDASymmetricMemory::get_buffer_ptrs() { + return buffers_; +} + +std::vector CUDASymmetricMemory::get_signal_pad_ptrs() { + return signal_pads_; +} + +void** CUDASymmetricMemory::get_buffer_ptrs_dev() { + return buffers_dev_; +} + +void** CUDASymmetricMemory::get_signal_pad_ptrs_dev() { + return signal_pads_dev_; +} + +size_t CUDASymmetricMemory::get_buffer_size() { + return buffer_size_; +} + +size_t CUDASymmetricMemory::get_signal_pad_size() { + return signal_pad_size; +} + +at::Tensor CUDASymmetricMemory::get_buffer( + int rank, + c10::IntArrayRef sizes, + c10::ScalarType dtype, + int64_t storage_offset) { + const auto numel = + std::accumulate(sizes.begin(), sizes.end(), 1, std::multiplies()); + const auto element_size = c10::elementSize(dtype); + const auto req_size = (numel + storage_offset) * element_size; + TORCH_CHECK( + req_size <= buffer_size_, + "CUDASymmetricMemory::get_buffer: the requested size (", + req_size, + " bytes) exceeds the allocated size (", + buffer_size_, + " bytes)"); + auto device = c10::Device(c10::DeviceType::CUDA, local_device_idx_); + auto options = at::TensorOptions().dtype(dtype).device(device); + return at::for_blob(buffers_[rank], sizes) + .storage_offset(storage_offset) + .options(options) + .target_device(device) + .make_tensor(); +} + +void check_channel(int channel, int world_size) { + TORCH_CHECK( + channel >= 0, + "channel for barrier(), put_signal() and wait_signal() ", + "must be greater than 0 (got ", + channel, + ")"); + const size_t num_channels = signal_pad_size / sizeof(uint32_t) * world_size; + TORCH_CHECK( + static_cast(channel) < num_channels, + "The maximum supported channel for barrier(), put_signal() and wait_signal() is ", + num_channels - 1, + " (got ", + channel, + ")"); +} + +__device__ __forceinline__ void release_signal(uint32_t* addr) { +#if defined(USE_ROCM) || (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)) + CUDA_KERNEL_ASSERT(false); +#else + volatile uint32_t* signal = addr; + uint32_t val; + do { + val = *signal; + } while (val != 0 || atomicCAS_system(addr, 0, 1) != 0); +#endif +} + +__device__ __forceinline__ void acquire_signal(uint32_t* addr) { +#if defined(USE_ROCM) || (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)) + CUDA_KERNEL_ASSERT(false); +#else + volatile uint32_t* signal = addr; + uint32_t val; + do { + val = *signal; + } while (val != 1 || atomicCAS_system(addr, 1, 0) != 1); +#endif +} + +static __global__ void barrier_kernel( + uint32_t** signal_pads, + int channel, + int rank, + int world_size) { + if (threadIdx.x < world_size) { + auto target_rank = threadIdx.x; + release_signal(signal_pads[target_rank] + world_size * channel + rank); + acquire_signal(signal_pads[rank] + world_size * channel + target_rank); + } +} + +void CUDASymmetricMemory::barrier(int channel) { + check_channel(channel, world_size_); + c10::cuda::CUDAGuard guard(local_device_idx_); + barrier_kernel<<<1, C10_WARP_SIZE, 0, at::cuda::getCurrentCUDAStream()>>>( + reinterpret_cast(signal_pads_dev_), + channel, + rank_, + world_size_); + C10_CUDA_KERNEL_LAUNCH_CHECK(); +} + +static __global__ void put_signal_kernel( + uint32_t** signal_pads, + int dst_rank, + int channel, + int rank, + int world_size) { + if (threadIdx.x == 0) { + release_signal(signal_pads[dst_rank] + world_size * channel + rank); + } +} + +void CUDASymmetricMemory::put_signal(int dst_rank, int channel) { + check_channel(channel, world_size_); + c10::cuda::CUDAGuard guard(local_device_idx_); + put_signal_kernel<<<1, C10_WARP_SIZE, 0, at::cuda::getCurrentCUDAStream()>>>( + reinterpret_cast(signal_pads_dev_), + dst_rank, + channel, + rank_, + world_size_); + C10_CUDA_KERNEL_LAUNCH_CHECK(); +} + +static __global__ void wait_signal_kernel( + uint32_t** signal_pads, + int src_rank, + int channel, + int rank, + int world_size) { + if (threadIdx.x == 0) { + acquire_signal(signal_pads[rank] + world_size * channel + src_rank); + } + __threadfence_system(); +} + +void CUDASymmetricMemory::wait_signal(int src_rank, int channel) { + check_channel(channel, world_size_); + c10::cuda::CUDAGuard guard(local_device_idx_); + wait_signal_kernel<<<1, C10_WARP_SIZE, 0, at::cuda::getCurrentCUDAStream()>>>( + reinterpret_cast(signal_pads_dev_), + src_rank, + channel, + rank_, + world_size_); + C10_CUDA_KERNEL_LAUNCH_CHECK(); +} + +int CUDASymmetricMemory::get_rank() { + return rank_; +} + +int CUDASymmetricMemory::get_world_size() { + return world_size_; +} + +void* CUDASymmetricMemoryAllocator::alloc( + size_t size, + int device_idx, + const std::string& group_name) { +#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED) + auto driver_api = c10::cuda::DriverAPI::get(); + + CUmemAllocationProp prop = {}; + prop.type = CU_MEM_ALLOCATION_TYPE_PINNED; + prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE; + // NOLINTNEXTLINE(bugprone-signed-char-misuse) + prop.location.id = device_idx; + prop.requestedHandleTypes = CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR; + + size_t signal_pad_offset = at::round_up(size, 16UL); + size_t block_size = signal_pad_offset + signal_pad_size; + + size_t granularity; + C10_CUDA_DRIVER_CHECK(driver_api->cuMemGetAllocationGranularity_( + &granularity, &prop, CU_MEM_ALLOC_GRANULARITY_RECOMMENDED)); + block_size = at::round_up(block_size, granularity); + + HandleType handle; + C10_CUDA_DRIVER_CHECK( + driver_api->cuMemCreate_(&handle, block_size, &prop, 0)); + + void* ptr = nullptr; + map_block(&ptr, handle, block_size, device_idx); + + c10::cuda::CUDAGuard guard(device_idx); + AT_CUDA_CHECK(cudaMemset(ptr, 0, block_size)); + + auto block = c10::make_intrusive( + handle, device_idx, block_size, size, signal_pad_offset, group_name); + { + std::unique_lock lock(mutex_); + ptr_to_block_.emplace(ptr, std::move(block)); + } + return ptr; +#else + TORCH_CHECK( + false, "CUDASymmetricMemory requires PYTORCH_C10_DRIVER_API_SUPPORTED"); +#endif +} + +void CUDASymmetricMemoryAllocator::free(void* ptr) { +#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED) + auto block = find_block(ptr); + if (block == nullptr) { + return; + } + // Initializing CUDASymmetricMemory with an allocation transfers its + // ownership to the CUDASymmetricMemory object. + if (block->symm_mem == nullptr) { + auto driver_api = c10::cuda::DriverAPI::get(); + C10_CUDA_DRIVER_CHECK(driver_api->cuMemUnmap_( + reinterpret_cast(ptr), block->block_size)); + C10_CUDA_DRIVER_CHECK(driver_api->cuMemRelease_(block->handle)); + } + { + std::unique_lock lock(mutex_); + ptr_to_block_.erase(ptr); + } +#else + TORCH_CHECK( + false, "CUDASymmetricMemory requires PYTORCH_C10_DRIVER_API_SUPPORTED"); +#endif +} + +size_t CUDASymmetricMemoryAllocator::get_alloc_size(void* ptr) { + auto block = find_block(ptr); + TORCH_CHECK( + block != nullptr, + "CUDASymmetricMemoryAllocator::get_alloc_size: input must be allocated ", + "via CUDASymmetricMemoryAllocator::alloc"); + return block->buffer_size; +} + +struct RendezvousRequest { + int device_idx; + int block_fd; + int pid; + size_t block_size; + size_t buffer_size; + size_t signal_pad_offset; +}; + +void validate_rendezvous_requests( + const std::vector reqs, + int world_size) { + TORCH_CHECK(reqs.size() == (size_t)world_size); + + std::unordered_set device_indices; + device_indices.reserve(world_size); + for (auto req : reqs) { + device_indices.insert(req.device_idx); + } + if (device_indices.size() < (size_t)world_size) { + TORCH_CHECK( + false, + "CUDASymmetricMemoryAllocator::rendezvous: ", + "detected allocations from overlapping devices ", + "from different ranks."); + } + + for (int r = 1; r < world_size; ++r) { + TORCH_CHECK(reqs[r].block_size == reqs[0].block_size); + TORCH_CHECK(reqs[r].buffer_size == reqs[0].buffer_size); + TORCH_CHECK(reqs[r].signal_pad_offset == reqs[0].signal_pad_offset); + } +} + +c10::intrusive_ptr CUDASymmetricMemoryAllocator::rendezvous( + void* ptr) { +#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED) + auto block = find_block(ptr); + TORCH_CHECK( + block != nullptr, + "CUDASymmetricMemoryAllocator::rendezvous: input must be allocated ", + "via CUDASymmetricMemoryAllocator::alloc"); + + if (block->symm_mem != nullptr) { + return block->symm_mem; + } + + auto group_info = get_group_info(block->group_name); + auto driver_api = c10::cuda::DriverAPI::get(); + int block_fd; + C10_CUDA_DRIVER_CHECK(driver_api->cuMemExportToShareableHandle_( + &block_fd, block->handle, CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR, 0)); + + auto local_req = RendezvousRequest{ + .device_idx = block->device_idx, + .block_fd = block_fd, + .pid = getpid(), + .block_size = block->block_size, + .buffer_size = block->buffer_size, + .signal_pad_offset = block->signal_pad_offset}; + auto reqs = store_all_gather( + group_info.store, group_info.rank, group_info.world_size, local_req); + validate_rendezvous_requests(reqs, group_info.world_size); + + std::vector handles(group_info.world_size); + std::vector buffers(group_info.world_size, nullptr); + std::vector signal_pads(group_info.world_size, nullptr); + for (int r = 0; r < group_info.world_size; ++r) { + if (r == group_info.rank) { + handles[r] = block->handle; + buffers[r] = ptr; + signal_pads[r] = (void*)((uintptr_t)ptr + block->signal_pad_offset); + continue; + } + int imported_fd = import_remote_fd(reqs[r].pid, reqs[r].block_fd); + C10_CUDA_DRIVER_CHECK(driver_api->cuMemImportFromShareableHandle_( + &handles[r], + (void*)(uintptr_t)imported_fd, + CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR)); + map_block(&buffers[r], handles[r], block->block_size, block->device_idx); + signal_pads[r] = (void*)((uintptr_t)buffers[r] + block->signal_pad_offset); + close(imported_fd); + } + store_barrier(group_info.store, group_info.rank, group_info.world_size); + close(block_fd); + + // Initializing CUDASymmetricMemory with an allocation transfers its + // ownership to the CUDASymmetricMemory object. So that outstanding + // references to the CUDASymmetricMemory object can keep the allocation + // alive. + block->symm_mem = c10::make_intrusive( + std::move(handles), + block->block_size, + std::move(buffers), + std::move(signal_pads), + block->buffer_size, + block->device_idx, + group_info.rank, + group_info.world_size); + return block->symm_mem; +#else + TORCH_CHECK( + false, "CUDASymmetricMemory requires PYTORCH_C10_DRIVER_API_SUPPORTED"); +#endif +} + +bool CUDASymmetricMemoryAllocator::is_rendezvous_completed(void* ptr) { + auto block = find_block(ptr); + TORCH_CHECK( + block != nullptr, + "CUDASymmetricMemoryAllocator::is_rendezvous_completed: input must be allocated ", + "via CUDASymmetricMemoryAllocator::alloc"); + return block->symm_mem != nullptr; +} + +c10::intrusive_ptr CUDASymmetricMemoryAllocator::find_block(void* ptr) { + std::shared_lock lock(mutex_); + auto it = ptr_to_block_.find(ptr); + if (it == ptr_to_block_.end()) { + return nullptr; + } + return it->second; +} + +struct RegisterCUDASymmetricMemoryAllocator { + RegisterCUDASymmetricMemoryAllocator() { + register_allocator( + c10::DeviceType::CUDA, + c10::make_intrusive()); + } +}; + +static RegisterCUDASymmetricMemoryAllocator register_allocator_; + +} // namespace symmetric_memory +} // namespace c10d diff --git a/torch/csrc/distributed/c10d/CUDASymmetricMemory.cuh b/torch/csrc/distributed/c10d/CUDASymmetricMemory.cuh new file mode 100644 index 00000000000000..0e0e40a6bd0910 --- /dev/null +++ b/torch/csrc/distributed/c10d/CUDASymmetricMemory.cuh @@ -0,0 +1,109 @@ +#pragma once + +#include +#include +#include + +namespace c10d { +namespace symmetric_memory { + +#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED) +using HandleType = CUmemGenericAllocationHandle; +#else +using HandleType = void*; +#endif + +class CUDASymmetricMemory : public SymmetricMemory { + public: + CUDASymmetricMemory( + std::vector handles, + size_t block_size, + std::vector buffers, + std::vector signal_pads, + size_t buffer_size, + int local_device_idx, + int rank, + int world_size); + + ~CUDASymmetricMemory() override; + + std::vector get_buffer_ptrs() override; + std::vector get_signal_pad_ptrs() override; + void** get_buffer_ptrs_dev() override; + void** get_signal_pad_ptrs_dev() override; + size_t get_buffer_size() override; + size_t get_signal_pad_size() override; + + at::Tensor get_buffer( + int rank, + c10::IntArrayRef sizes, + c10::ScalarType dtype, + int64_t storage_offset) override; + + void barrier(int channel) override; + void put_signal(int dst_rank, int channel) override; + void wait_signal(int src_rank, int channel) override; + + int get_rank() override; + int get_world_size() override; + + private: + std::vector handles_; + size_t block_size_; + std::vector buffers_; + std::vector signal_pads_; + size_t buffer_size_; + int local_device_idx_; + int rank_; + int world_size_; + void** buffers_dev_; + void** signal_pads_dev_; + std::optional> finalizer_; +}; + +struct Block : public c10::intrusive_ptr_target { + HandleType handle; + int device_idx; + size_t block_size; + size_t buffer_size; + size_t signal_pad_offset; + std::string group_name; + c10::intrusive_ptr symm_mem = nullptr; + + Block( + HandleType handle, + int device_idx, + size_t block_size, + size_t buffer_size, + size_t signal_pad_offset, + const std::string& group_name) + : handle(handle), + device_idx(device_idx), + block_size(block_size), + buffer_size(buffer_size), + signal_pad_offset(signal_pad_offset), + group_name(group_name), + symm_mem(nullptr) {} +}; + +class CUDASymmetricMemoryAllocator : public SymmetricMemoryAllocator { + public: + void* alloc( + size_t size, + int device_idx, + const std::string& group_name) override; + + void free(void *ptr) override; + size_t get_alloc_size(void* ptr) override; + c10::intrusive_ptr rendezvous(void* ptr) override; + bool is_rendezvous_completed(void* ptr) override; + + private: + c10::intrusive_ptr find_block(void* ptr); + + std::shared_mutex mutex_; + std::unordered_map> ptr_to_block_; +}; + +} // namespace symmetric_memory +} // namespace c10d diff --git a/torch/csrc/distributed/c10d/ProcessGroupCudaP2P.hpp b/torch/csrc/distributed/c10d/ProcessGroupCudaP2P.hpp index cff4ad09b70648..7c41414c4e4e17 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupCudaP2P.hpp +++ b/torch/csrc/distributed/c10d/ProcessGroupCudaP2P.hpp @@ -10,6 +10,7 @@ constexpr auto kProcessGroupCudaP2PDefaultTimeout = namespace c10d { +// NOTE: this class will be be removed soon in favor of SymmetricMemory class TORCH_API ProcessGroupCudaP2P : public Backend { public: struct Options : Backend::Options { diff --git a/torch/csrc/distributed/c10d/SymmetricMemory.cpp b/torch/csrc/distributed/c10d/SymmetricMemory.cpp new file mode 100644 index 00000000000000..b3d9f31bb03420 --- /dev/null +++ b/torch/csrc/distributed/c10d/SymmetricMemory.cpp @@ -0,0 +1,189 @@ +#include + +namespace { + +using namespace c10d::symmetric_memory; + +class AllocatorMap { + public: + static AllocatorMap& get() { + static AllocatorMap instance; + return instance; + } + + void register_allocator( + c10::DeviceType device_type, + c10::intrusive_ptr allocator) { + map_[device_type] = std::move(allocator); + } + + c10::intrusive_ptr get_allocator( + c10::DeviceType device_type) { + auto it = map_.find(device_type); + TORCH_CHECK( + it != map_.end(), + "SymmetricMemory does not support device type ", + device_type); + return it->second; + } + + ~AllocatorMap() { + for (auto& it : map_) { + it.second.release(); + } + } + + private: + AllocatorMap() = default; + AllocatorMap(const AllocatorMap&) = delete; + AllocatorMap& operator=(const AllocatorMap&) = delete; + + std::unordered_map< + c10::DeviceType, + c10::intrusive_ptr> + map_; +}; + +static std::unordered_map group_info_map{}; + +// Data structures for tracking persistent allocations +static std::unordered_map alloc_id_to_dev_ptr{}; +static std::unordered_map> + alloc_id_to_storage{}; + +static at::Tensor empty_strided_p2p_persistent( + c10::IntArrayRef size, + c10::IntArrayRef stride, + c10::ScalarType dtype, + c10::Device device, + const std::string& group_name, + uint64_t alloc_id) { + // Make the allocation fails if a previous allocation with the same alloc_id + // is still active. + auto storage = alloc_id_to_storage.find(alloc_id); + if (storage != alloc_id_to_storage.end() && storage->second.use_count() > 0) { + TORCH_CHECK( + false, + "SymmetricMemory::empty_strided_p2p_persistent: ", + "can not allocate with alloc_id == ", + alloc_id, + " because a previous allocation with the same alloc_id " + "is still active."); + } + + const size_t numel = + std::accumulate(size.begin(), size.end(), 1, std::multiplies()); + const size_t element_size = c10::elementSize(dtype); + const size_t alloc_size = numel * element_size; + + auto allocator = get_allocator(device.type()); + void* dev_ptr = nullptr; + if (alloc_id_to_dev_ptr.find(alloc_id) != alloc_id_to_dev_ptr.end()) { + dev_ptr = alloc_id_to_dev_ptr[alloc_id]; + TORCH_CHECK( + alloc_size == allocator->get_alloc_size(dev_ptr), + "SymmetricMemory::empty_strided_p2p_persistent: ", + "requested allocation size (", + alloc_size, + ") is different from the size of a previous allocation ", + "with the same alloc_id ", + allocator->get_alloc_size(dev_ptr)); + } else { + dev_ptr = allocator->alloc(alloc_size, device.index(), group_name); + alloc_id_to_dev_ptr[alloc_id] = dev_ptr; + } + + auto options = at::TensorOptions().dtype(dtype).device(device); + auto allocated = at::from_blob(dev_ptr, size, stride, options); + + // Track the allocation's activeness + alloc_id_to_storage.erase(alloc_id); + alloc_id_to_storage.emplace( + alloc_id, allocated.storage().getWeakStorageImpl()); + return allocated; +} + +} // namespace + +namespace c10d { +namespace symmetric_memory { + +void register_allocator( + c10::DeviceType device_type, + c10::intrusive_ptr allocator) { + return AllocatorMap::get().register_allocator( + device_type, std::move(allocator)); +} + +c10::intrusive_ptr get_allocator( + c10::DeviceType device_type) { + return AllocatorMap::get().get_allocator(device_type); +} + +void set_group_info( + const std::string& group_name, + int rank, + int world_size, + c10::intrusive_ptr store) { + TORCH_CHECK(group_info_map.find(group_name) == group_info_map.end()); + GroupInfo group_info; + group_info.rank = rank; + group_info.world_size = world_size; + group_info.store = std::move(store); + group_info_map.emplace(group_name, std::move(group_info)); +} + +const GroupInfo& get_group_info(const std::string& group_name) { + TORCH_CHECK( + group_info_map.find(group_name) != group_info_map.end(), + "get_group_info: no group info associated with the group name ", + group_name); + return group_info_map[group_name]; +} + +at::Tensor empty_strided_p2p( + c10::IntArrayRef size, + c10::IntArrayRef stride, + c10::ScalarType dtype, + c10::Device device, + const std::string& group_name, + std::optional alloc_id) { + if (alloc_id.has_value()) { + return empty_strided_p2p_persistent( + size, stride, dtype, device, group_name, *alloc_id); + } + const size_t numel = + std::accumulate(size.begin(), size.end(), 1, std::multiplies()); + const size_t element_size = c10::elementSize(dtype); + const size_t alloc_size = numel * element_size; + + auto allocator = get_allocator(device.type()); + void* dev_ptr = allocator->alloc(alloc_size, device.index(), group_name); + + auto options = at::TensorOptions().dtype(dtype).device(device); + return at::from_blob( + dev_ptr, + size, + stride, + [allocator = std::move(allocator)](void* ptr) { allocator->free(ptr); }, + options); +} + +TORCH_API c10::intrusive_ptr rendezvous( + const at::Tensor& tensor) { + auto allocator = get_allocator(tensor.device().type()); + return allocator->rendezvous(tensor.data_ptr()); +} + +c10::intrusive_ptr get_symmetric_memory( + const at::Tensor& tensor) { + auto allocator = get_allocator(tensor.device().type()); + TORCH_CHECK( + allocator->is_rendezvous_completed(tensor.data_ptr()), + "SymmetricMemory: must invoke rendezvous on a tensor ", + "before calling get_symmetric_memory on it"); + return allocator->rendezvous(tensor.data_ptr()); +} + +} // namespace symmetric_memory +} // namespace c10d diff --git a/torch/csrc/distributed/c10d/SymmetricMemory.hpp b/torch/csrc/distributed/c10d/SymmetricMemory.hpp new file mode 100644 index 00000000000000..344b86ea5c7e3a --- /dev/null +++ b/torch/csrc/distributed/c10d/SymmetricMemory.hpp @@ -0,0 +1,152 @@ +#pragma once + +#include +#include + +namespace c10d { +namespace symmetric_memory { + +// SymmetricMemory represents symmetric allocations across a group of devices. +// The allocations represented by a SymmetricMemory object are accessible by +// all devices in the group. The class can be used for op-level custom +// communication patterns (via the get_buffer APIs and the synchronization +// primitives), as well as custom communication kernels (via the buffer and +// signal_pad device pointers). +// +// To acquire a SymmetricMemory object, each rank first allocates +// identical-sized memory via SymmetricMemoryAllocator::alloc(), then invokes +// SymmetricMemoryAllocator::rendezvous() on the memory to establish the +// association across peer buffers. The rendezvous is a one-time process, and +// the mapping between a local memory memory and the associated SymmetricMemory +// object is unique. +// +// NOTE [symmetric memory signal pad] +// Signal pads are P2P-accessible memory regions designated for +// synchronization. SymmetricMemory offers built-in synchronization primitives +// such as barriers, put_signal, and wait_signal, which are all based on signal +// pads. Users may utilize signal pads for their own synchronization logic, +// provided that the signal pads remain zero-filled following successful +// synchronization. +// +// NOTE [symmetric memory synchronization channel] +// Synchronization channels allow users to use a single SymmetricMemory object +// to perform isolated synchronizations on different streams. For example, +// consider the case in which two barriers are issued on two streams for +// different purposes. Without the concept of channels, we cannot guarantee the +// correctness of the barriers since signals issued from barrier on stream A +// can be received by the barrier on stream B. By specifying different channels +// for these two barriers, they can operate correctly in parallel. +class TORCH_API SymmetricMemory : public c10::intrusive_ptr_target { + public: + virtual ~SymmetricMemory() {} + + virtual std::vector get_buffer_ptrs() = 0; + virtual std::vector get_signal_pad_ptrs() = 0; + + // get_buffer_ptrs_dev() and get_signal_pad_ptrs_dev() each return a pointer + // to a device array of size world_size, containing buffer pointers and + // signal pad pointers, respectively. + virtual void** get_buffer_ptrs_dev() = 0; + virtual void** get_signal_pad_ptrs_dev() = 0; + virtual size_t get_buffer_size() = 0; + virtual size_t get_signal_pad_size() = 0; + + virtual at::Tensor get_buffer( + int rank, + c10::IntArrayRef sizes, + c10::ScalarType dtype, + int64_t storage_offset) = 0; + + virtual void barrier(int channel) = 0; + virtual void put_signal(int dst_rank, int channel) = 0; + virtual void wait_signal(int src_rank, int channel) = 0; + + virtual int get_rank() = 0; + virtual int get_world_size() = 0; +}; + +class SymmetricMemoryAllocator : public c10::intrusive_ptr_target { + public: + virtual ~SymmetricMemoryAllocator(){}; + + virtual void* alloc( + size_t size, + int device_idx, + const std::string& group_name) = 0; + + virtual void free(void* ptr) = 0; + virtual size_t get_alloc_size(void* ptr) = 0; + virtual c10::intrusive_ptr rendezvous(void* ptr) = 0; + virtual bool is_rendezvous_completed(void* ptr) = 0; +}; + +C10_EXPORT void register_allocator( + c10::DeviceType device_type, + c10::intrusive_ptr allocator); + +C10_EXPORT c10::intrusive_ptr get_allocator( + c10::DeviceType device_type); + +// Set a store for rendezvousing symmetric allocations on a group of devices +// identified by `group_name`. The concept of groups is logical; users can +// utilize predefined groups (e.g., a group of device identified by a +// ProcessGroup) or create custom ones. Note that a SymmetricMemoryAllocator +// backends might employ a more efficient communication channel for the actual +// rendezvous process and only use the store for bootstrapping purposes. +TORCH_API void set_group_info( + const std::string& group_name, + int rank, + int world_size, + c10::intrusive_ptr store); + +struct GroupInfo { + int rank; + int world_size; + c10::intrusive_ptr store; +}; + +C10_EXPORT const GroupInfo& get_group_info(const std::string& group_name); + +// Identical to empty_strided, but allows symmetric memory access to be +// established for the allocated tensor via SymmetricMemory::rendezvous(). This +// function itself is not a collective operation. It invokes +// SymmetricMemoryAllocator::alloc() for the requested device under the hood. +// +// NOTE [symmetric memory persistent allocation] +// If an `alloc_id` is supplied, empty_strided_p2p will perform persistent +// allocation. This makes the function cache allocated memory and ensure that +// invocations with the same `alloc_id` receive tensors backed by the same +// memory address. For safety, if a previous persistent allocation is still +// active (i.e., the storage of the returned tensor is still alive), persistent +// allocations with the same `alloc_id` will fail. This determinism coupled +// with memory planning of communication buffers (e.g., by Inductor) allows +// communication algorithms to reliably reuse previously established remote +// memory access. +TORCH_API at::Tensor empty_strided_p2p( + c10::IntArrayRef size, + c10::IntArrayRef stride, + c10::ScalarType dtype, + c10::Device device, + const std::string& group_name, + std::optional alloc_id); + +// Establishes symmetric memory access on tensors allocated via +// empty_strided_p2p() and empty_strided_p2p_persistent(). rendezvous() is a +// one-time process, and the mapping between a local memory region and the +// associated SymmetricMemory object is unique. Subsequent calls to +// rendezvous() with the same tensor, or tensors allocated with +// empty_strided_p2p_persistent() using the same alloc_id, will receive the +// cached SymmetricMemory object. +// +// The function has a collective semantic and must be invoked simultaneously +// from all rendezvous participants. +TORCH_API c10::intrusive_ptr rendezvous( + const at::Tensor& tensor); + +// Returns the SymmetricMemory object associated with the tensor. It can only +// be invoked after rendezvous() but does not need to be invoked collectively. +TORCH_API c10::intrusive_ptr get_symmetric_memory( + const at::Tensor& tensor); + +} // namespace symmetric_memory +} // namespace c10d diff --git a/torch/csrc/distributed/c10d/init.cpp b/torch/csrc/distributed/c10d/init.cpp index 6f1b28886b989b..db5778efcf3547 100644 --- a/torch/csrc/distributed/c10d/init.cpp +++ b/torch/csrc/distributed/c10d/init.cpp @@ -41,6 +41,7 @@ #include #include #include +#include #include #include @@ -975,6 +976,44 @@ This class does not support ``__members__`` property.)"); "global_ranks_in_group", &::c10d::DistributedBackendOptions::global_ranks_in_group); + using SymmetricMemory = ::c10d::symmetric_memory::SymmetricMemory; + py::class_>( + module, "_SymmetricMemory") + .def_static("set_group_info", &::c10d::symmetric_memory::set_group_info) + .def_static( + "empty_strided_p2p", + ::c10d::symmetric_memory::empty_strided_p2p, + py::arg("size"), + py::arg("stride"), + py::arg("dtype"), + py::arg("device"), + py::arg("group_name"), + py::arg("alloc_id") = py::none()) + .def_static("rendezvous", &::c10d::symmetric_memory::rendezvous) + .def_static( + "get_symmetric_memory", + &::c10d::symmetric_memory::get_symmetric_memory) + .def_property_readonly("rank", &SymmetricMemory::get_rank) + .def_property_readonly("world_size", &SymmetricMemory::get_world_size) + .def( + "get_buffer", + &SymmetricMemory::get_buffer, + py::arg("rank"), + py::arg("sizes"), + py::arg("dtype"), + py::arg("storage_offset") = 0) + .def("barrier", &SymmetricMemory::barrier, py::arg("channel") = 0) + .def( + "put_signal", + &SymmetricMemory::put_signal, + py::arg("dst_rank"), + py::arg("channel") = 0) + .def( + "wait_signal", + &SymmetricMemory::wait_signal, + py::arg("src_rank"), + py::arg("channel") = 0); + auto store = py::class_<::c10d::Store, c10::intrusive_ptr<::c10d::Store>, PythonStore>( module, diff --git a/torch/csrc/distributed/c10d/intra_node_comm.cpp b/torch/csrc/distributed/c10d/intra_node_comm.cpp index 85136a91e02564..9d7ba5abf951dd 100644 --- a/torch/csrc/distributed/c10d/intra_node_comm.cpp +++ b/torch/csrc/distributed/c10d/intra_node_comm.cpp @@ -218,23 +218,8 @@ IntraNodeComm::~IntraNodeComm() { if (!isInitialized_) { return; } - // Intentionally releasing resources without synchronizing devices. The - // teardown logic is safe for propoerly sync'd user program. We don't want - // improperly sync'd user program to hang here. - for (size_t r = 0; r < worldSize_; ++r) { - if (r == rank_) { - continue; - } - AT_CUDA_CHECK(cudaIpcCloseMemHandle(p2pStates_[r])); - AT_CUDA_CHECK(cudaIpcCloseMemHandle(buffers_[r])); - } - AT_CUDA_CHECK(cudaFree(p2pStates_[rank_])); - AT_CUDA_CHECK(cudaFree(buffers_[rank_])); - if (topoInfo_ != nullptr) { - AT_CUDA_CHECK(cudaFree(topoInfo_)); - } - AT_CUDA_CHECK(cudaFree(p2pStatesDev_)); - AT_CUDA_CHECK(cudaFree(buffersDev_)); + auto allocator = get_allocator(c10::DeviceType::CUDA); + allocator->free(symmetricMemoryPtr_); } bool IntraNodeComm::isEnabled() { @@ -344,83 +329,19 @@ bool IntraNodeComm::rendezvous() { // Detect topology Topology topology = detectTopology(nvlMesh, worldSize_); - // Initialize p2p state - auto p2pState = initP2pState(); - - // Allocate buffer - void* buffer = nullptr; - AT_CUDA_CHECK(cudaMalloc(&buffer, bufferSize_)); - - // Second handshake: exchange topology and CUDA IPC handles - struct IpcInfo { - NvlMesh nvlMesh; - Topology topology; - cudaIpcMemHandle_t p2pStateHandle, bufferHandle; - }; - - // Make p2p state and buffer available for IPC - cudaIpcMemHandle_t p2pStateHandle, bufferHandle; - AT_CUDA_CHECK(cudaIpcGetMemHandle(&p2pStateHandle, p2pState)); - AT_CUDA_CHECK(cudaIpcGetMemHandle(&bufferHandle, buffer)); - - IpcInfo ipcInfo{ - .nvlMesh = nvlMesh, - .topology = topology, - .p2pStateHandle = p2pStateHandle, - .bufferHandle = bufferHandle}; - - auto peerIpcInfos = - storeAllGather(store_, "handshake-1", rank_, worldSize_, ipcInfo); - - for (const auto& info : peerIpcInfos) { - if (!isSame(info.nvlMesh, peerIpcInfos.front().nvlMesh) || - info.topology != peerIpcInfos.front().topology) { - LOG(WARNING) << "Aborting IntraNodeComm::rendezvous because some " - "participants are observing different topologies (" - << int(info.topology) << " and " << int(topology) << ")"; - AT_CUDA_CHECK(cudaFree(p2pState)); - AT_CUDA_CHECK(cudaFree(buffer)); - return false; - } - } - - std::array p2pStates = {}, buffers = {}; - for (size_t r = 0; r < peerIpcInfos.size(); ++r) { - if (r == rank_) { - p2pStates[r] = p2pState; - buffers[r] = buffer; - } else { - AT_CUDA_CHECK(cudaIpcOpenMemHandle( - &p2pStates[r], - peerIpcInfos[r].p2pStateHandle, - cudaIpcMemLazyEnablePeerAccess)); - AT_CUDA_CHECK(cudaIpcOpenMemHandle( - &buffers[r], - peerIpcInfos[r].bufferHandle, - cudaIpcMemLazyEnablePeerAccess)); - } - } - void* p2pStatesDev = nullptr; - AT_CUDA_CHECK(cudaMalloc(&p2pStatesDev, sizeof(p2pStates))); - AT_CUDA_CHECK(cudaMemcpy( - p2pStatesDev, - p2pStates.data(), - sizeof(p2pStates), - cudaMemcpyHostToDevice)); - - void* buffersDev = nullptr; - AT_CUDA_CHECK(cudaMalloc(&buffersDev, sizeof(buffers))); - AT_CUDA_CHECK(cudaMemcpy( - buffersDev, buffers.data(), sizeof(buffers), cudaMemcpyHostToDevice)); + set_group_info("IntraNodeComm", rank_, worldSize_, store_); + auto allocator = get_allocator(c10::DeviceType::CUDA); + symmetricMemoryPtr_ = + allocator->alloc(bufferSize_, deviceIdx, "IntraNodeComm"); + symmetricMemory_ = allocator->rendezvous(symmetricMemoryPtr_); + TORCH_CHECK(symmetricMemory_->get_signal_pad_size() >= kP2pStateSize); void* topoInfo = initTopoInfo(topology, nvlMesh, rank_); isInitialized_ = true; topology_ = topology; - std::copy(p2pStates.begin(), p2pStates.end(), p2pStates_.begin()); - std::copy(buffers.begin(), buffers.end(), buffers_.begin()); - p2pStatesDev_ = p2pStatesDev; - buffersDev_ = buffersDev; + p2pStatesDev_ = symmetricMemory_->get_signal_pad_ptrs_dev(); + buffersDev_ = symmetricMemory_->get_buffer_ptrs_dev(); topoInfo_ = topoInfo; return true; #endif diff --git a/torch/csrc/distributed/c10d/intra_node_comm.cu b/torch/csrc/distributed/c10d/intra_node_comm.cu index 51fc6252d2235b..ac751ff7be1e09 100644 --- a/torch/csrc/distributed/c10d/intra_node_comm.cu +++ b/torch/csrc/distributed/c10d/intra_node_comm.cu @@ -132,6 +132,8 @@ struct P2pState { uint32_t signals1[kMaxAllReduceBlocks][kMaxDevices]; }; +static_assert(sizeof(P2pState) <= kP2pStateSize); + template static __global__ void oneShotAllReduceKernel( at::BFloat16* input, @@ -522,7 +524,7 @@ at::Tensor IntraNodeComm::oneShotAllReduce( const bool fuseInputCopy = isAligned && blocks.x < kMaxAllReduceBlocks; if (!fuseInputCopy) { AT_CUDA_CHECK(cudaMemcpyAsync( - buffers_[rank_], + symmetricMemory_->get_buffer_ptrs_dev()[rank_], input.data_ptr(), input.numel() * input.element_size(), cudaMemcpyDeviceToDevice, @@ -582,7 +584,7 @@ at::Tensor IntraNodeComm::twoShotAllReduce( at::cuda::OptionalCUDAGuard guard(input.get_device()); AT_CUDA_CHECK(cudaMemcpyAsync( - buffers_[rank_], + symmetricMemory_->get_buffer_ptrs_dev()[rank_], input.data_ptr(), input.numel() * input.element_size(), cudaMemcpyDeviceToDevice, @@ -632,7 +634,7 @@ at::Tensor IntraNodeComm::hybridCubeMeshAllReduce( at::cuda::OptionalCUDAGuard guard(input.get_device()); AT_CUDA_CHECK(cudaMemcpyAsync( - buffers_[rank_], + symmetricMemory_->get_buffer_ptrs_dev()[rank_], input.data_ptr(), input.numel() * input.element_size(), cudaMemcpyDeviceToDevice, @@ -755,15 +757,7 @@ at::Tensor IntraNodeComm::getBuffer( const std::vector& sizes, c10::ScalarType dtype, int64_t storageOffset) { - const auto numel = std::accumulate(sizes.begin(), sizes.end(), 0); - const auto elementSize = c10::elementSize(dtype); - TORCH_CHECK((numel + storageOffset) * elementSize <= bufferSize_); - auto options = at::TensorOptions().dtype(dtype).device( - at::kCUDA, at::cuda::current_device()); - return at::for_blob(buffers_[rank], sizes) - .storage_offset(storageOffset) - .options(options) - .make_tensor(); + return symmetricMemory_->get_buffer(rank, sizes, dtype, storageOffset); } } // namespace intra_node_comm diff --git a/torch/csrc/distributed/c10d/intra_node_comm.hpp b/torch/csrc/distributed/c10d/intra_node_comm.hpp index 5d7e2d426d30a1..a67df5c34586a0 100644 --- a/torch/csrc/distributed/c10d/intra_node_comm.hpp +++ b/torch/csrc/distributed/c10d/intra_node_comm.hpp @@ -4,12 +4,16 @@ #include #include #include +#include #include namespace c10d::intra_node_comm { +using namespace c10d::symmetric_memory; + constexpr size_t kMaxDevices = 8; constexpr size_t kDefaultBufferSize = 10ull * 1024 * 1024; +constexpr size_t kP2pStateSize = 2048; using NvlMesh = std::array, kMaxDevices>; using HybridCubeMesh = std::array, kMaxDevices>; @@ -27,6 +31,7 @@ enum class AllReduceAlgo : uint8_t { HCM = 3 }; +// NOTE: this class will be be removed soon in favor of SymmetricMemory class TORCH_API IntraNodeComm : public c10::intrusive_ptr_target { public: IntraNodeComm( @@ -97,8 +102,8 @@ class TORCH_API IntraNodeComm : public c10::intrusive_ptr_target { */ bool isInitialized_ = false; Topology topology_ = Topology::UNKNOWN; - std::array p2pStates_{}; - std::array buffers_{}; + void* symmetricMemoryPtr_ = nullptr; + c10::intrusive_ptr symmetricMemory_ = nullptr; void* p2pStatesDev_{}; void* buffersDev_{}; void* topoInfo_{}; From 18634048a1f939a961b7c96b0acfe78b474c821e Mon Sep 17 00:00:00 2001 From: "Wang, Eikan" Date: Sat, 15 Jun 2024 09:08:32 +0000 Subject: [PATCH 059/171] Separate AOTI Eager utils as a single file (#125819) The key change is code movement. We just moved aoti eager related code from `torch._inductor.utils` to `torch._inductor.aoti_eager` Pull Request resolved: https://github.com/pytorch/pytorch/pull/125819 Approved by: https://github.com/jansel, https://github.com/jgong5, https://github.com/desertfire ghstack dependencies: #125308 --- test/inductor/test_torchinductor.py | 11 +- torch/_inductor/aoti_eager.py | 225 ++++++++++++++++++ torch/_inductor/utils.py | 206 +--------------- .../inductor/aoti_eager/kernel_holder.cpp | 9 +- 4 files changed, 238 insertions(+), 213 deletions(-) create mode 100644 torch/_inductor/aoti_eager.py diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index 7c66547400d68c..b8be175143a285 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -28,6 +28,7 @@ import torch import torch._dynamo.config as dynamo_config +import torch._inductor.aoti_eager import torch.nn as nn from torch._dispatch.python import enable_python_dispatcher from torch._dynamo.debug_utils import aot_graph_input_parser @@ -39,14 +40,16 @@ skipIfPy312, ) from torch._dynamo.utils import ifdynstaticdefault +from torch._inductor.aoti_eager import ( + aoti_compile_with_persistent_cache, + aoti_eager_cache_dir, + load_aoti_eager_cache, +) from torch._inductor.codegen.common import DataTypePropagation, OptimizationContext from torch._inductor.fx_passes import pad_mm from torch._inductor.test_case import TestCase as InductorTestCase from torch._inductor.utils import ( add_scheduler_init_hook, - aoti_compile_with_persistent_cache, - aoti_eager_cache_dir, - load_aoti_eager_cache, run_and_get_code, run_and_get_cpp_code, run_and_get_triton_code, @@ -846,7 +849,7 @@ def test_aoti_eager_cache_hit(self): # Patch the aoti_compile_with_persistent_cache as None to ensure no new kernel is generated with mock.patch( - "torch._inductor.utils.aoti_compile_with_persistent_cache", None + "torch._inductor.aoti_eager.aoti_compile_with_persistent_cache", None ): with _scoped_library("aten", "IMPL") as torch_compile_op_lib_impl: # Get ref result from eager diff --git a/torch/_inductor/aoti_eager.py b/torch/_inductor/aoti_eager.py new file mode 100644 index 00000000000000..d77c764a00e129 --- /dev/null +++ b/torch/_inductor/aoti_eager.py @@ -0,0 +1,225 @@ +import json +import os +from pathlib import Path +from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from unittest import mock + +import torch +import torch._export +from torch._inductor.utils import is_cpu_device +from .runtime.runtime_utils import cache_dir + + +def aoti_eager_cache_dir(namespace: str, device: str) -> Path: + return Path(cache_dir()) / "aoti_eager" / namespace / device + + +def aoti_eager_op_conf_lock(op_func_name_with_overload: str) -> Any: + from filelock import FileLock + + # Avoid circular import + from torch._inductor.codecache import get_lock_dir, LOCK_TIMEOUT + + op_conf_lock_file = f"{op_func_name_with_overload}.lock" + lock_dir = get_lock_dir() + return FileLock(os.path.join(lock_dir, op_conf_lock_file), timeout=LOCK_TIMEOUT) + + +def load_aoti_eager_cache( + ns: str, op_func_name_with_overload: str, device_type: str +) -> List[Optional[Dict[str, Any]]]: + device_kernel_cache = aoti_eager_cache_dir(ns, device_type) + op_conf = device_kernel_cache / f"{op_func_name_with_overload}.json" + if not op_conf.exists(): + return [] + + with aoti_eager_op_conf_lock(op_func_name_with_overload): + with open(op_conf) as f: + json_data = json.load(f) + for item in json_data: + # Get absolution path for kernel library + kernel_lib_abs_path = device_kernel_cache / item["kernel_path"] + item["kernel_path"] = kernel_lib_abs_path.as_posix() + + # Check if the kernel library exists + if not kernel_lib_abs_path.exists(): + return [] + + for metadata in item["meta_info"]: + assert not metadata[ + "is_dynamic" + ], "Only support static shape for now" + if metadata["device_type"] == "cpu": + metadata["device_index"] = -1 + metadata["dtype"] = getattr(torch, metadata["dtype"].split(".")[-1]) + + return json_data + + +def supported_builtin_dtype_torch_dtype() -> Dict[type, torch.dtype]: + return {int: torch.int32, float: torch.float, bool: torch.bool} + + +def supported_scalar_types() -> Tuple[type, ...]: + type_to_torch_dtype = supported_builtin_dtype_torch_dtype() + supported_scalar_types = tuple(type_to_torch_dtype.keys()) + return supported_scalar_types + + +def extract_tensor_metadata(dynamic: bool, input: torch.Tensor) -> Dict[str, Any]: + metadata: Dict[str, Any] = {} + metadata["is_dynamic"] = dynamic + + assert isinstance(input, torch.Tensor) + metadata["device_type"] = f"{input.device.type}" + if is_cpu_device([input]): + metadata["device_index"] = -1 + else: + metadata["device_index"] = input.device.index + metadata["dtype"] = f"{input.dtype}" + metadata["sizes"] = list(input.size()) + metadata["strides"] = list(input.stride()) + metadata["requires_grad"] = input.requires_grad + metadata["dispatch_key_set"] = torch._C._dispatch_keys(input).raw_repr() + return metadata + + +def extract_tensor_list_metadata( + dynamic: bool, + input: List[torch.Tensor], +) -> Dict[str, Any]: + metadata_list = [] + for item in input: + assert isinstance(item, torch.Tensor) + metadata_list.append(extract_tensor_metadata(dynamic, item)) + + metadata: Dict[str, Any] = {} + metadata["tensor_list"] = metadata_list + return metadata + + +def extract_scalar_metadata( + device_type: str, input: Union[int, float, bool] +) -> Dict[str, Any]: + assert isinstance(input, supported_scalar_types()) + metadata: Dict[str, Any] = {} + metadata["is_dynamic"] = False + # Scalar tensor + metadata["device_type"] = device_type + metadata["device_index"] = -1 if device_type == "cpu" else 0 + type_to_torch_dtype = supported_builtin_dtype_torch_dtype() + metadata["dtype"] = f"{type_to_torch_dtype[type(input)]}" + metadata["scalar_value"] = input + return metadata + + +def aoti_compile_with_persistent_cache( + ns: str, + op_func_name_with_overload: str, + device_type: str, + dynamic: bool, + f: Callable[..., Any], + args: Tuple[Any], + kwargs: Dict[str, Any], + *, + dynamic_shapes: Optional[Dict[str, Any]] = None, + options: Optional[Dict[str, Any]] = None, + remove_runtime_assertions: bool = False, + disable_constraint_solver: bool = False, +) -> str: + """ + Compile the given function with persistent cache for AOTI eager mode. + """ + assert not dynamic, "Only support static shape for now" + type_to_torch_dtype = {int: torch.int32, float: torch.float, bool: torch.bool} + supported_scalar_types = tuple(type_to_torch_dtype.keys()) + flattened_inputs = list(args) + list(kwargs.values()) + if not all( + isinstance(input, (supported_scalar_types, torch.Tensor, list)) + for input in flattened_inputs + ): + raise NotImplementedError( + "Only support tensor, tensor list, int, float, bool for now" + ) + + for input in flattened_inputs: + if isinstance(input, list) and not all( + isinstance(item, torch.Tensor) for item in input + ): + raise NotImplementedError( + "Regarding list, _impl_with_aoti_compile only support tensor list now." + ) + + persistent_cache = aoti_eager_cache_dir(ns, device_type) + if not persistent_cache.exists(): + persistent_cache.mkdir(parents=True) + + persistent_cache_lib = persistent_cache / "lib" + if not persistent_cache_lib.exists(): + persistent_cache_lib.mkdir() + + with mock.patch.dict( + os.environ, + {"TORCHINDUCTOR_CACHE_DIR": persistent_cache_lib.absolute().as_posix()}, + ): + try: + kernel_lib_path = torch._export.aot_compile( + f, + args, + kwargs, + dynamic_shapes=dynamic_shapes, + remove_runtime_assertions=remove_runtime_assertions, + disable_constraint_solver=disable_constraint_solver, + # Some operations may have non-Tensor parameters like int, float, bool. These + # non-Tensor parameters will not be the input of the graph. Therefore, we do + # need to keep the same signature. + same_signature=False, + ) + + kernel_metadata_items = [] + + for idx, input in enumerate(flattened_inputs): + if isinstance(input, torch.Tensor): + metadata = extract_tensor_metadata(dynamic, input) + elif isinstance(input, list): + assert all(isinstance(item, torch.Tensor) for item in input) + metadata = extract_tensor_list_metadata(dynamic, input) + else: + metadata = extract_scalar_metadata(device_type, input) + + metadata["arg_order"] = idx + kernel_metadata_items.append(metadata) + + kernel_meta_info: Dict[str, Any] = {} + kernel_meta_info["meta_info"] = kernel_metadata_items + kernel_meta_info["kernel_path"] = ( + Path(kernel_lib_path).relative_to(persistent_cache).as_posix() + ) + + json_data = [] + update_json = True + op_conf = persistent_cache / f"{op_func_name_with_overload}.json" + mode = "r" if op_conf.exists() else "w" + with aoti_eager_op_conf_lock(op_func_name_with_overload): + with open(op_conf, mode) as op_conf_file: + try: + json_data = json.load(op_conf_file) + except Exception as e: + json_data = [] + + assert isinstance(json_data, list) + for item in json_data: + assert isinstance(item, dict) + # Same kernel meta info already exists in the json file + if item["meta_info"] == kernel_metadata_items: + update_json = False + break + + if update_json: + json_data.append(kernel_meta_info) + with open(op_conf, "w") as op_conf_file: + json.dump(json_data, op_conf_file, indent=4) + + return kernel_lib_path + except Exception as e: + return "" diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py index 129ea8c6a45f5b..34e7033a9ed81c 100644 --- a/torch/_inductor/utils.py +++ b/torch/_inductor/utils.py @@ -9,7 +9,6 @@ import inspect import io import itertools -import json import logging import math import operator @@ -23,7 +22,6 @@ import unittest from datetime import datetime from io import StringIO -from pathlib import Path from typing import ( Any, Callable, @@ -35,7 +33,6 @@ Optional, Protocol, Set, - Tuple, TypeVar, Union, ValuesView, @@ -62,7 +59,7 @@ from torch.utils._sympy.symbol import make_symbol, SymT from torch.utils._sympy.value_ranges import bound_sympy, ValueRanges from . import config -from .runtime.runtime_utils import cache_dir, ceildiv as runtime_ceildiv +from .runtime.runtime_utils import ceildiv as runtime_ceildiv log = logging.getLogger(__name__) @@ -1616,207 +1613,6 @@ def maybe_get_suppress_shape_guards_ctx(): return shape_env.suppress_guards() -def aoti_eager_cache_dir(namespace: str, device: str): - return Path(cache_dir()) / "aoti_eager" / namespace / device - - -def aoti_eager_op_conf_lock(op_func_name_with_overload: str): - from filelock import FileLock - - # Avoid circular import - from torch._inductor.codecache import get_lock_dir, LOCK_TIMEOUT - - op_conf_lock_file = f"{op_func_name_with_overload}.lock" - lock_dir = get_lock_dir() - return FileLock(os.path.join(lock_dir, op_conf_lock_file), timeout=LOCK_TIMEOUT) - - -def load_aoti_eager_cache(ns: str, op_func_name_with_overload: str, device_type: str): - device_kernel_cache = aoti_eager_cache_dir(ns, device_type) - op_conf = device_kernel_cache / f"{op_func_name_with_overload}.json" - if not op_conf.exists(): - return [] - - with aoti_eager_op_conf_lock(op_func_name_with_overload): - with open(op_conf) as f: - json_data = json.load(f) - for item in json_data: - # Get absolution path for kernel library - kernel_lib_abs_path = device_kernel_cache / item["kernel_path"] - item["kernel_path"] = kernel_lib_abs_path.as_posix() - - # Check if the kernel library exists - if not kernel_lib_abs_path.exists(): - return [] - - for metadata in item["meta_info"]: - assert not metadata[ - "is_dynamic" - ], "Only support static shape for now" - if metadata["device_type"] == "cpu": - metadata["device_index"] = -1 - metadata["dtype"] = getattr(torch, metadata["dtype"].split(".")[-1]) - - return json_data - - -def aoti_compile_with_persistent_cache( - ns: str, - op_func_name_with_overload: str, - device_type: str, - dynamic: bool, - f: Callable[..., Any], - args: Tuple[Any], - kwargs: Dict[str, Any], - *, - dynamic_shapes: Optional[Dict[str, Any]] = None, - options: Optional[Dict[str, Any]] = None, - remove_runtime_assertions: bool = False, - disable_constraint_solver: bool = False, -): - """ - Compile the given function with persistent cache for AOTI eager mode. - """ - assert not dynamic, "Only support static shape for now" - from torch._export import aot_compile - - type_to_torch_dtype = {int: torch.int32, float: torch.float, bool: torch.bool} - supported_scalar_types = tuple(type_to_torch_dtype.keys()) - flattened_inputs = list(args) + list(kwargs.values()) - if not all( - isinstance(input, (supported_scalar_types, torch.Tensor, list)) - for input in flattened_inputs - ): - raise NotImplementedError( - "Only support tensor, tensor list, int, float, bool for now" - ) - - for input in flattened_inputs: - if isinstance(input, list) and not all( - isinstance(item, torch.Tensor) for item in input - ): - raise NotImplementedError( - "Regarding list, _impl_with_aoti_compile only support tensor list now." - ) - - persistent_cache = aoti_eager_cache_dir(ns, device_type) - if not persistent_cache.exists(): - persistent_cache.mkdir(parents=True) - - persistent_cache_lib = persistent_cache / "lib" - if not persistent_cache_lib.exists(): - persistent_cache_lib.mkdir() - - with mock.patch.dict( - os.environ, - {"TORCHINDUCTOR_CACHE_DIR": persistent_cache_lib.absolute().as_posix()}, - ): - try: - kernel_lib_path = aot_compile( - f, - args, - kwargs, - dynamic_shapes=dynamic_shapes, - options=options, - remove_runtime_assertions=remove_runtime_assertions, - disable_constraint_solver=disable_constraint_solver, - # Some operations may have non-Tensor parameters like int, float, bool. These - # non-Tensor parameters will not be the input of the graph. Therefore, we do - # need to keep the same signature. - same_signature=False, - ) - - kernel_metadata_items = [] - - def extract_tensor_metadata(input: torch.Tensor) -> Dict[str, Any]: - metadata: Dict[str, Any] = {} - metadata["is_dynamic"] = dynamic - - assert isinstance(input, torch.Tensor) - metadata["device_type"] = f"{input.device.type}" - if is_cpu_device([input]): - metadata["device_index"] = -1 - else: - metadata["device_index"] = input.device.index - metadata["dtype"] = f"{input.dtype}" - metadata["sizes"] = list(input.size()) - metadata["strides"] = list(input.stride()) - metadata["requires_grad"] = input.requires_grad - metadata["dispatch_key_set"] = torch._C._dispatch_keys(input).raw_repr() - return metadata - - def extract_scalar_metadata( - input: Union[int, float, bool] - ) -> Dict[str, Any]: - assert isinstance(input, supported_scalar_types) - metadata: Dict[str, Any] = {} - metadata["is_dynamic"] = dynamic - # Scalar tensor - metadata["device_type"] = device_type - metadata["device_index"] = -1 if device_type == "cpu" else 0 - metadata["dtype"] = f"{type_to_torch_dtype[type(input)]}" - metadata["scalar_value"] = input - return metadata - - def extract_tensor_list_metadata( - input: List[torch.Tensor], - ) -> Dict[str, Any]: - metadata_list = [] - for item in input: - assert isinstance(item, torch.Tensor) - metadata_list.append(extract_tensor_metadata(item)) - - metadata: Dict[str, Any] = {} - metadata["tensor_list"] = metadata_list - return metadata - - for idx, input in enumerate(flattened_inputs): - if isinstance(input, torch.Tensor): - metadata = extract_tensor_metadata(input) - elif isinstance(input, list): - assert all(isinstance(item, torch.Tensor) for item in input) - metadata = extract_tensor_list_metadata(input) - else: - metadata = extract_scalar_metadata(input) - - metadata["arg_order"] = idx - kernel_metadata_items.append(metadata) - - kernel_meta_info: Dict[str, Any] = {} - kernel_meta_info["meta_info"] = kernel_metadata_items - kernel_meta_info["kernel_path"] = ( - Path(kernel_lib_path).relative_to(persistent_cache).as_posix() - ) - - json_data = [] - update_json = True - op_conf = persistent_cache / f"{op_func_name_with_overload}.json" - mode = "r" if op_conf.exists() else "w" - with aoti_eager_op_conf_lock(op_func_name_with_overload): - with open(op_conf, mode) as op_conf_file: - try: - json_data = json.load(op_conf_file) - except Exception as e: - json_data = [] - - assert isinstance(json_data, list) - for item in json_data: - assert isinstance(item, dict) - # Same kernel meta info already exists in the json file - if item["meta_info"] == kernel_metadata_items: - update_json = False - break - - if update_json: - json_data.append(kernel_meta_info) - with open(op_conf, "w") as op_conf_file: - json.dump(json_data, op_conf_file, indent=4) - - return kernel_lib_path - except Exception as e: - return "" - - def run_and_get_cpp_code(fn, *args, **kwargs): # We use the patch context manager instead of using it as a decorator. # In this way, we can ensure that the attribute is patched and unpatched correctly diff --git a/torch/csrc/inductor/aoti_eager/kernel_holder.cpp b/torch/csrc/inductor/aoti_eager/kernel_holder.cpp index cda50f077e5723..f93b0fb2a9da19 100644 --- a/torch/csrc/inductor/aoti_eager/kernel_holder.cpp +++ b/torch/csrc/inductor/aoti_eager/kernel_holder.cpp @@ -235,10 +235,11 @@ void AOTIPythonKernelHolder::init_aoti_kernel_cache() { py::gil_scoped_acquire gil; py::handle load_aoti_eager_cache_function = - py::module::import("torch._inductor.utils").attr("load_aoti_eager_cache"); + py::module::import("torch._inductor.aoti_eager") + .attr("load_aoti_eager_cache"); TORCH_INTERNAL_ASSERT( load_aoti_eager_cache_function.ptr() != nullptr, - "Failed to import - torch._inductor.utils.load_aoti_eager_cache"); + "Failed to import - torch._inductor.aoti_eager.load_aoti_eager_cache"); auto result = py::reinterpret_steal(PyObject_CallFunctionObjArgs( load_aoti_eager_cache_function.ptr(), @@ -431,12 +432,12 @@ std::string AOTIPythonKernelHolder::produce_aoti_kernel_lib( overload_name); py::handle aot_compile_function = - py::module::import("torch._inductor.utils") + py::module::import("torch._inductor.aoti_eager") .attr("aoti_compile_with_persistent_cache"); TORCH_INTERNAL_ASSERT( aot_compile_function.ptr() != nullptr && aot_compile_function.ptr() != Py_None, - "Failed to import - torch._inductor.utils.aoti_compile_with_persistent_cache"); + "Failed to import - torch._inductor.aoti_eager.aoti_compile_with_persistent_cache"); // Pass the python operation to the AOT Inductor to generate the kernel // library. From f0d68120f4e99ee6c05f1235d9b42a4524af39d5 Mon Sep 17 00:00:00 2001 From: David Berard Date: Sat, 15 Jun 2024 04:50:59 +0000 Subject: [PATCH 060/171] [subclasses] Handle dynamo inputs that are subclass views with (-1) in the view (#128662) When handling an input to dynamo that's a view of a subclass, dynamo does some handling to reconstruct the view. Part of this is to construct symints for the input parameters to the view. Previously, the code would just call `create_symbol()` which by default specifies a _positive_ symint (>= 0); this fails in the case where you have an aten::view that was called with a -1. Fix: just specify `positive=None` when calling `create_symbol()`, to avoid restricting the symint to >= 0 or <= 0. Pull Request resolved: https://github.com/pytorch/pytorch/pull/128662 Approved by: https://github.com/jbschlosser --- test/dynamo/test_subclasses.py | 1 + ...utograd.test_output_aliases_multiple_inputs_get_correct_one | 0 torch/_subclasses/meta_utils.py | 3 ++- 3 files changed, 3 insertions(+), 1 deletion(-) delete mode 100644 test/dynamo_expected_failures/TestAOTAutograd.test_output_aliases_multiple_inputs_get_correct_one diff --git a/test/dynamo/test_subclasses.py b/test/dynamo/test_subclasses.py index 954859f50994f4..302b07e4ddb78b 100644 --- a/test/dynamo/test_subclasses.py +++ b/test/dynamo/test_subclasses.py @@ -1392,6 +1392,7 @@ def _get_views(t): # returns (view: Tensor, expects_raises_false) yield t.select(-1, 6), False # https://github.com/pytorch/pytorch/issues/128649 yield t[2:3, 5:9], dynamic + yield t.view(-1, 15), False def f(x): return x * 2 diff --git a/test/dynamo_expected_failures/TestAOTAutograd.test_output_aliases_multiple_inputs_get_correct_one b/test/dynamo_expected_failures/TestAOTAutograd.test_output_aliases_multiple_inputs_get_correct_one deleted file mode 100644 index e69de29bb2d1d6..00000000000000 diff --git a/torch/_subclasses/meta_utils.py b/torch/_subclasses/meta_utils.py index 4ea0db56aae250..bfcd53f66ceb0b 100644 --- a/torch/_subclasses/meta_utils.py +++ b/torch/_subclasses/meta_utils.py @@ -964,7 +964,8 @@ def symint_visitor_fn(s): # assumption of it being simplified out will fail and it may be guarded on, # which will hard error. sym_source = EphemeralSource("symint_visitor_fn") - symbol = shape_env.create_symbol(s, sym_source) + + symbol = shape_env.create_symbol(s, sym_source, positive=None) return shape_env.create_symintnode(symbol, hint=s, source=sym_source) real_to_fake_mapping = {} From 94c0dcbe1d31d93cc9bc4760ba2f9ce80c093326 Mon Sep 17 00:00:00 2001 From: Sam Larsen Date: Fri, 14 Jun 2024 16:52:32 -0700 Subject: [PATCH 061/171] [inductor] Parallel compile: handle crashes in subprocesses (#128757) Summary: If any subprocess in the pool crashes, we get a BrokenProcessPool exception and the whole pool becomes unusable. Handle crashes by recreating the pool. Test Plan: * New unit test * Started a long-running test (`test/inductor/test_torchinductor.py`), periodically killed subprocess manually, made sure the test run recovers and makes progress. Pull Request resolved: https://github.com/pytorch/pytorch/pull/128757 Approved by: https://github.com/jansel --- test/inductor/test_compile_worker.py | 16 +++++++ torch/_inductor/compile_worker/__main__.py | 44 +++++++++++-------- .../_inductor/compile_worker/subproc_pool.py | 28 +++++++++--- 3 files changed, 63 insertions(+), 25 deletions(-) diff --git a/test/inductor/test_compile_worker.py b/test/inductor/test_compile_worker.py index 6b6f10a4720b65..61ee25421847c6 100644 --- a/test/inductor/test_compile_worker.py +++ b/test/inductor/test_compile_worker.py @@ -1,5 +1,6 @@ # Owner(s): ["module: inductor"] import operator +import os from torch._inductor.compile_worker.subproc_pool import ( raise_testexc, @@ -31,6 +32,21 @@ def test_exception(self): finally: pool.shutdown() + def test_crash(self): + pool = SubprocPool(2) + try: + with self.assertRaises(Exception): + a = pool.submit(os._exit, 1) + a.result() + + # Pool should still be usable after a crash + b = pool.submit(operator.add, 100, 1) + c = pool.submit(operator.sub, 100, 1) + self.assertEqual(b.result(), 101) + self.assertEqual(c.result(), 99) + finally: + pool.shutdown() + if __name__ == "__main__": from torch._inductor.test_case import run_tests diff --git a/torch/_inductor/compile_worker/__main__.py b/torch/_inductor/compile_worker/__main__.py index 7f0965415bbff6..a343dc644b5413 100644 --- a/torch/_inductor/compile_worker/__main__.py +++ b/torch/_inductor/compile_worker/__main__.py @@ -1,5 +1,6 @@ # mypy: allow-untyped-defs import argparse +import logging import os import sys import typing @@ -9,6 +10,8 @@ from torch._inductor.compile_worker.watchdog import _async_compile_initializer from torch._inductor.runtime.compile_tasks import _set_triton_ptxas_path +log = logging.getLogger(__name__) + _set_triton_ptxas_path() try: @@ -20,25 +23,28 @@ def main(): - parser = argparse.ArgumentParser() - parser.add_argument("--workers", type=int) - parser.add_argument("--parent", type=int) - args = parser.parse_args() - if os.getppid() != args.parent: - sys.exit(0) - write_fd = typing.cast(Pipe, os.fdopen(os.dup(sys.stdout.fileno()), "wb")) - read_fd = typing.cast(Pipe, os.fdopen(os.dup(sys.stdin.fileno()), "rb")) - - # nobody else should read stdin - sys.stdin.close() - - # redirect output of workers to stderr - os.dup2(sys.stderr.fileno(), sys.stdout.fileno()) - - pre_fork_setup() - - _async_compile_initializer(args.parent) - SubprocMain(args.workers, read_fd, write_fd).main() + try: + parser = argparse.ArgumentParser() + parser.add_argument("--workers", type=int) + parser.add_argument("--parent", type=int) + args = parser.parse_args() + if os.getppid() != args.parent: + sys.exit(0) + write_fd = typing.cast(Pipe, os.fdopen(os.dup(sys.stdout.fileno()), "wb")) + read_fd = typing.cast(Pipe, os.fdopen(os.dup(sys.stdin.fileno()), "rb")) + + # nobody else should read stdin + sys.stdin.close() + + # redirect output of workers to stderr + os.dup2(sys.stderr.fileno(), sys.stdout.fileno()) + + pre_fork_setup() + + _async_compile_initializer(args.parent) + SubprocMain(args.workers, read_fd, write_fd).main() + except Exception: + log.exception("Uncaught exception in compile_worker subprocess") if __name__ == "__main__": diff --git a/torch/_inductor/compile_worker/subproc_pool.py b/torch/_inductor/compile_worker/subproc_pool.py index 5aba18707b41fd..5bde798d88416e 100644 --- a/torch/_inductor/compile_worker/subproc_pool.py +++ b/torch/_inductor/compile_worker/subproc_pool.py @@ -11,6 +11,7 @@ import threading import typing from concurrent.futures import Future, ProcessPoolExecutor +from concurrent.futures.process import BrokenProcessPool from typing import Any, Callable, Dict from torch._inductor import config @@ -180,16 +181,20 @@ def __init__(self, nprocs: int, read_pipe: Pipe, write_pipe: Pipe): self.read_pipe = read_pipe self.write_pipe = write_pipe self.write_lock = threading.Lock() - self.pool = ProcessPoolExecutor( + self.nprocs = nprocs + self.pool = self._new_pool(nprocs, True) + self.running = True + + def _new_pool(self, nprocs, warm): + pool = ProcessPoolExecutor( nprocs, mp_context=multiprocessing.get_context("fork"), initializer=functools.partial(_async_compile_initializer, os.getpid()), ) - multiprocessing.util.Finalize( - None, self.pool.shutdown, exitpriority=sys.maxsize - ) - self.running = True - _warm_process_pool(self.pool, nprocs) + multiprocessing.util.Finalize(None, pool.shutdown, exitpriority=sys.maxsize) + if warm: + _warm_process_pool(pool, nprocs) + return pool def main(self): while True: @@ -210,6 +215,17 @@ def _shutdown(self): self.pool.shutdown() def submit(self, job_id, data): + while self.running: + try: + self._submit_inner(job_id, data) + return + except BrokenProcessPool: + # If any subprocess in the pool crashes, we get a BrokenProcessPool + # exception and the whole pool becomes unusable. Handle crashes by + # recreating the pool and resubmitting. + self.pool = self._new_pool(self.nprocs, False) + + def _submit_inner(self, job_id, data): future = self.pool.submit(functools.partial(SubprocMain.do_job, data)) def callback(_): From 6079c5091091d872b8dafbaa4e31a5b6194647ad Mon Sep 17 00:00:00 2001 From: Oguz Ulgen Date: Sat, 15 Jun 2024 17:52:09 +0000 Subject: [PATCH 062/171] Make config.fx_graph_remote_cache be three-value switch (#128628) Summary: We want to allow for three configurations False: Force off True: Force on None: OFF for OSS and JK config for internal Test Plan: CI Differential Revision: D58535897 Pull Request resolved: https://github.com/pytorch/pytorch/pull/128628 Approved by: https://github.com/masnesral, https://github.com/eellison --- test/inductor/test_codecache.py | 12 ++++++++++++ test/inductor/test_torchinductor.py | 1 + torch/_inductor/compile_fx.py | 4 ++-- torch/_inductor/config.py | 15 +++++++++++++-- 4 files changed, 28 insertions(+), 4 deletions(-) diff --git a/test/inductor/test_codecache.py b/test/inductor/test_codecache.py index e9f73cc14e04e7..4e1532f9d4c98e 100644 --- a/test/inductor/test_codecache.py +++ b/test/inductor/test_codecache.py @@ -109,6 +109,7 @@ def reset(self): @requires_triton() @config.patch({"fx_graph_cache": True}) + @config.patch({"fx_graph_remote_cache": False}) @parametrize("device", (GPU_TYPE, "cpu")) @parametrize("dtype", (torch.float32, torch.bfloat16)) @parametrize("dynamic", (False, True)) @@ -216,6 +217,7 @@ def put(self, filename, data): @requires_triton() @config.patch({"fx_graph_cache": True}) + @config.patch({"fx_graph_remote_cache": False}) @parametrize("device", (GPU_TYPE, "cpu")) @parametrize("dtype", (torch.float32, torch.float64)) @parametrize("dynamic", (False, True)) @@ -255,6 +257,7 @@ def fn(mod, x): @largeTensorTest("64GB", device=GPU_TYPE) @config.patch({"fx_graph_cache": True}) + @config.patch({"fx_graph_remote_cache": False}) @parametrize("device", (GPU_TYPE,)) @parametrize("dtype", (torch.float16, torch.bfloat16)) def test_cache_load_with_guards_int32_bounds(self, device, dtype): @@ -303,6 +306,7 @@ def fn(x, y): self.assertEqual(res1, res2) @config.patch({"fx_graph_cache": True}) + @config.patch({"fx_graph_remote_cache": False}) @parametrize("device", (GPU_TYPE, "cpu")) @parametrize("dtype", (torch.float32, torch.bfloat16)) def test_cache_load_with_guards_static_bounds(self, device, dtype): @@ -346,6 +350,7 @@ def fn(x): self.assertEqual(res1, res2) @config.patch({"fx_graph_cache": True}) + @config.patch({"fx_graph_remote_cache": False}) @parametrize("device", (GPU_TYPE, "cpu")) def test_constant_handling(self, device): """ @@ -378,6 +383,7 @@ def fn2(x): @requires_gpu() @requires_triton() @config.patch({"fx_graph_cache": True}) + @config.patch({"fx_graph_remote_cache": False}) def test_higher_order_op_bypass(self): """ Verify that we bypass the cache when we have higher order ops. @@ -403,6 +409,7 @@ def fn(x, y): self.assertGreater(counters["inductor"]["fxgraph_cache_bypass"], 0) @config.patch({"fx_graph_cache": True}) + @config.patch({"fx_graph_remote_cache": False}) def test_generated_kernel_count(self): """ Test that we bump the generated_kernel_count metric on a cache hit. @@ -431,6 +438,7 @@ def fn(x, y): self.assertEqual(metrics.generated_kernel_count, 2) @config.patch({"fx_graph_cache": True}) + @config.patch({"fx_graph_remote_cache": False}) def test_cache_clear(self): """ Test clearing the cache. @@ -465,6 +473,7 @@ def fn(x, y): self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0) @config.patch({"fx_graph_cache": True}) + @config.patch({"fx_graph_remote_cache": False}) def test_cache_with_nt(self): def gen_nt(r): values = torch.randn(r, 16) @@ -493,6 +502,7 @@ def fn(nt): self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 1) @config.patch({"fx_graph_cache": True}) + @config.patch({"fx_graph_remote_cache": False}) def test_cache_with_symint_non_arg_guard(self): def fn(x, ref_id): self_id = 22 @@ -516,6 +526,7 @@ def fn(x, ref_id): self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 1) @config.patch({"fx_graph_cache": True}) + @config.patch({"fx_graph_remote_cache": False}) def test_cache_guard(self): def f(x, val): if val > 5: @@ -740,6 +751,7 @@ def test_cuda_compile_command(self): class TestUtils(TestCase): + @config.patch({"fx_graph_remote_cache": False}) def test_fresh_inductor_cache(self): def fn(x, y): return x + y diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index b8be175143a285..83954feb02900e 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -10682,6 +10682,7 @@ def fn(x: torch.Tensor) -> torch.Tensor: self.assertEqual(fn_opt(*inps), fn(*inps)) + @config.patch({"fx_graph_remote_cache": False}) def test_optimize_indexing_dtype_with_constraint(self): def fn1(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: x = torch.arange(0, b.shape[0], device=GPU_TYPE) diff --git a/torch/_inductor/compile_fx.py b/torch/_inductor/compile_fx.py index b99090a450e0ed..2415178d9f8656 100644 --- a/torch/_inductor/compile_fx.py +++ b/torch/_inductor/compile_fx.py @@ -392,8 +392,8 @@ def fake_tensor_prop( def should_use_remote_fx_graph_cache(): - if config.fx_graph_remote_cache: - return True + if config.fx_graph_remote_cache is not None: + return config.fx_graph_remote_cache if not config.is_fbcode(): return False if torch.version.hip is not None: diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py index 380e5ded907136..6be72fb8ea2045 100644 --- a/torch/_inductor/config.py +++ b/torch/_inductor/config.py @@ -10,6 +10,14 @@ def is_fbcode(): return not hasattr(torch.version, "git_version") +def fx_graph_remote_cache_default(): + if os.environ.get("TORCHINDUCTOR_FX_GRAPH_REMOTE_CACHE") == "1": + return True + if os.environ.get("TORCHINDUCTOR_FX_GRAPH_REMOTE_CACHE") == "0": + return False + return None + + # add some debug printouts debug = False @@ -22,8 +30,11 @@ def is_fbcode(): # use fx aot graph codegen cache fx_graph_cache = os.environ.get("TORCHINDUCTOR_FX_GRAPH_CACHE") == "1" -# use fx aot graph codegen cache -fx_graph_remote_cache = os.environ.get("TORCHINDUCTOR_FX_GRAPH_REMOTE_CACHE") == "1" +# use remote fx aot graph codegen cache +# False: Disables the cache +# True: Enables the cache +# None: Not set -- Off for OSS, JustKnobs based for internal +fx_graph_remote_cache: Optional[bool] = fx_graph_remote_cache_default() # enable autotune local cache autotune_local_cache = True From ab13980424fea6817f18cc83dac4729743348de2 Mon Sep 17 00:00:00 2001 From: BowenBao Date: Sat, 15 Jun 2024 23:42:13 +0000 Subject: [PATCH 063/171] [ONNX] Update 'person_of_interest.rst', 'CODEOWNERS' and 'merge_rules.yaml' (#126364) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The following are all constrained under the ONNX exporter project scope. - `personal_of_interest.rst` - Moving folks no longer working on the project to emeritus. - Adding @justinchuby, @titaiwangms, @shubhambhokare1 and @xadupre, who have all made countless contributions to this project. - `CODEOWNERS` - Removing folks no longer working on the project. - Updating new owners who will now be notified with PRs related to the specific file paths. - `merge_rules.yaml` - Removing folks no longer working on the project. 🫡 Co-authored-by: Justin Chu Pull Request resolved: https://github.com/pytorch/pytorch/pull/126364 Approved by: https://github.com/titaiwangms, https://github.com/justinchuby, https://github.com/albanD --- .github/merge_rules.yaml | 2 -- CODEOWNERS | 12 ++++++------ docs/source/community/persons_of_interest.rst | 10 +++++++--- torch/_dynamo/backends/onnxrt.py | 2 +- 4 files changed, 14 insertions(+), 12 deletions(-) diff --git a/.github/merge_rules.yaml b/.github/merge_rules.yaml index 837d36f3a07dcf..6c454e48245bbc 100644 --- a/.github/merge_rules.yaml +++ b/.github/merge_rules.yaml @@ -27,11 +27,9 @@ - third_party/onnx - caffe2/python/onnx/** approved_by: - - BowenBao - justinchuby - liqunfu - shubhambhokare1 - - thiagocrepaldi - titaiwangms - wschin - xadupre diff --git a/CODEOWNERS b/CODEOWNERS index 5c81bc7b246ef9..664664058b8f36 100644 --- a/CODEOWNERS +++ b/CODEOWNERS @@ -43,12 +43,12 @@ nn/qat/ @jerryzh168 /torch/csrc/distributed/rpc/tensorpipe_agent.h @jiayisuse @osalpekar @lw # ONNX Export -/torch/_dynamo/backends/onnxrt.py @bowenbao @thiagocrepaldi @wschin -/torch/csrc/jit/passes/onnx.h @bowenbao @thiagocrepaldi -/torch/csrc/jit/passes/onnx.cpp @bowenbao @thiagocrepaldi -/torch/csrc/jit/passes/onnx/ @bowenbao @thiagocrepaldi -/torch/onnx/ @bowenbao @thiagocrepaldi @wschin -/test/onnx/ @bowenbao @thiagocrepaldi @wschin +/torch/_dynamo/backends/onnxrt.py @wschin @xadupre +/torch/csrc/jit/passes/onnx.h @titaiwangms @shubhambhokare1 @xadupre +/torch/csrc/jit/passes/onnx.cpp @titaiwangms @shubhambhokare1 @xadupre +/torch/csrc/jit/passes/onnx/ @titaiwangms @shubhambhokare1 @xadupre +/torch/onnx/ @titaiwangms @shubhambhokare1 @justinchuby @wschin @xadupre +/test/onnx/ @titaiwangms @shubhambhokare1 @justinchuby @wschin @xadupre # CI /.ci @pytorch/pytorch-dev-infra diff --git a/docs/source/community/persons_of_interest.rst b/docs/source/community/persons_of_interest.rst index d5c4ff2e1aa62b..4fb4e146ed4627 100644 --- a/docs/source/community/persons_of_interest.rst +++ b/docs/source/community/persons_of_interest.rst @@ -226,9 +226,13 @@ C10 utils and operator dispatch ONNX exporter ~~~~~~~~~~~~~ -- Aaron Bockover (`abock `__) -- Bowen Bao (`BowenBao `__) -- Thiago Crepaldi (`thiagocrepaldi `__) +- Shubham Bhokare (`shubhambhokare1 `__) +- Justin Chu (`justinchuby `__) +- Xavier Dupré (`xadupre `__) +- Titai Wang (`titaiwangms `__) +- (emeritus) Bowen Bao (`BowenBao `__) +- (emeritus) Thiago Crepaldi (`thiagocrepaldi `__) +- (emeritus) Aaron Bockover (`abock `__) - (emeritus) Gary Miguel (`garymm `__) - (emeritus) Lara Haidar (`lara-hdr `__) - (emeritus) Lu Fang (`houseroad `__) diff --git a/torch/_dynamo/backends/onnxrt.py b/torch/_dynamo/backends/onnxrt.py index 91e69923e3124b..75b1e36abbb9a9 100644 --- a/torch/_dynamo/backends/onnxrt.py +++ b/torch/_dynamo/backends/onnxrt.py @@ -3,7 +3,7 @@ # This backend is maintained by ONNX team. To direct issues # to the right people, please tag related GitHub issues with `module: onnx`. # -# Maintainers' Github IDs: wschin, thiagocrepaldi, BowenBao +# Maintainers' Github IDs: wschin, xadupre from torch.onnx._internal.onnxruntime import ( is_onnxrt_backend_supported, torch_compile_backend, From a61939467a5669288ae89da781bce9645ddce1ae Mon Sep 17 00:00:00 2001 From: Michael Lazos Date: Sun, 16 Jun 2024 07:28:09 +0000 Subject: [PATCH 064/171] Enable passing dynamo-traced complex test (#128771) Fixes https://github.com/pytorch/pytorch/issues/118159 Pull Request resolved: https://github.com/pytorch/pytorch/pull/128771 Approved by: https://github.com/anijain2305 --- torch/testing/_internal/common_optimizers.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/torch/testing/_internal/common_optimizers.py b/torch/testing/_internal/common_optimizers.py index ac4a7f920cc2e7..628bedad313dc1 100644 --- a/torch/testing/_internal/common_optimizers.py +++ b/torch/testing/_internal/common_optimizers.py @@ -1323,11 +1323,6 @@ def _get_optim_inputs_including_global_cliquey_kwargs( "test_tensor_lr", active_if=sys.version_info < (3, 9) and sys.version_info > (3, 7), ), - DecorateInfo( - skipIfTorchDynamo("Mismatched _foreach_addcdiv_ types, see #118159"), - "TestOptimRenewed", - "test_complex", - ), DecorateInfo( skipIfTorchDynamo("See #116028"), "TestOptimRenewed", From f1ee3589a1277bdc9c580612868a92e7ba27e39a Mon Sep 17 00:00:00 2001 From: Blaine Burton Rister <145300525+blaine-rister@users.noreply.github.com> Date: Sun, 16 Jun 2024 07:35:57 +0000 Subject: [PATCH 065/171] [Inductor] Emit strided block pointer from ModularIndexing and FloorDiv (#127342) **Summary** Inductor currently uses modulo and division to compute indices into certain multi-dimensional tensors, such as those arising from row padding. This PR matches on that indexing pattern, replacing it with an N-D block pointer. This should be more efficient than computing indices with division and modulo, and it can easily map to DMAs on non-GPU hardware targets. Because the 1D block size needs to map to an integer block shape in ND, we need to know that the ND block size evenly divides the size of the iteration range. This PR only generates ND block pointers when it can guarantee that the iteration order and number of elements loaded are unchanged. This means that the number of elements in a slice of the iteration range must either be: - Powers of 2. Since Triton block sizes are powers of 2, any integer power of 2 either divides the block size, or is greater than the block size. In the latter case, `CielDiv(x, y)` rounds up to 1. - Multiples of the maximum block size. Since block sizes are powers of 2, the maximum block size is a multiple of every possible block size. Note that a *slice* of the iteration range does not include the leading dimension. Thus we can support arbitrary leading dimensions like `(5,8)`. Feature proposal and discussion: https://github.com/pytorch/pytorch/issues/125077 Example kernel: ``` triton.jit def triton_(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr): xnumel = 4096 xoffset = tl.program_id(0) * XBLOCK xindex = xoffset + tl.arange(0, XBLOCK)[:] xmask = xindex < xnumel tmp0 = tl.reshape(tl.load(tl.make_block_ptr(in_ptr0, shape=[32, 16, 8], strides=[1024, 32, 1], block_shape=[32 * (32 <= ((127 + XBLOCK) // 128)) + ((127 + XBLOCK) // 128) * (((127 + XBLOCK) // 128) < 32), 16 * (16 <= ((7 + XBLOCK) // 8)) + ((7 + XBLOCK) // 8) * (((7 + XBLOCK) // 8) < 16), 8 * (8 <= XBLOCK) + XBLOCK * (XBLOCK < 8)], order=[0, 1, 2], offsets=[(xoffset // 128), (xoffset // 8) % 16, xoffset % 8]), boundary_check=[0, 1, 2]), [XBLOCK]) tmp1 = tmp0 + tmp0 tl.store(tl.make_block_ptr(out_ptr0, shape=[4096], strides=[1], block_shape=[XBLOCK], order=[0], offsets=[xoffset]), tl.broadcast_to(tmp1, [XBLOCK]).to(tl.float32)) ''', device_str='cuda') ``` **Test Plan** This PR adds a new CI test script to cover this feature. The tests can be grouped into a few main categories: - Can we generate strided block pointers for the appropriate shapes? - Powers of 2 - Non-power of 2, but multiple of the maximum block size - Arbitrary leading dimensions, with power of 2 inner dimensions - Weird strides and offsets - Reductions - Symbolic shapes that are multiples of the maximum block size (wasn't able to trace this through dynamo) - Broadcasts (some variables are missing from the indexing expression) - Do we still compile other cases correctly, even if we don't expect to be able to generate block pointers? - Unsupported static shapes - Unsupported symbolic shapes - Mixing and matching these cases: - Pointwise and reduction in the same kernel - Sanity check the test harness - Do we raise an exception if the expected number of block pointers and the actual number are different? **Follow-ups** There are a few important cases which this PR can't handle. I'm hoping these can be deferred to follow-up PRs: - Handle non-divisible shapes - Change the tiling algorithm to generate a 2D (X,Y) blocking, if doing so enables block pointers to be emitted. - Pad unsupported loads up to the nearest divisible size, then mask/slice out the extra elements? This is probably the best solution, but I'm not yet sure how to go about it in triton. - Take advantage of this analysis when `triton.use_block_ptr=False`. I'm guessing we can still avoid `%` and `/` without requiring block pointers. Maybe we could compute block indices with arange and broadcast instead? Differential Revision: D56739375 Pull Request resolved: https://github.com/pytorch/pytorch/pull/127342 Approved by: https://github.com/jansel, https://github.com/shunting314 --- test/inductor/test_codecache.py | 2 +- test/inductor/test_indexing.py | 12 +- test/inductor/test_torchinductor.py | 33 +- .../test_torchinductor_strided_blocks.py | 361 ++++++++++++++ torch/_inductor/codegen/triton.py | 471 +++++++++++++++--- torch/_inductor/runtime/runtime_utils.py | 5 + torch/_inductor/sizevars.py | 9 +- torch/testing/_internal/inductor_utils.py | 20 +- torch/utils/_sympy/functions.py | 18 +- torch/utils/_sympy/symbol.py | 5 +- 10 files changed, 834 insertions(+), 102 deletions(-) create mode 100644 test/inductor/test_torchinductor_strided_blocks.py diff --git a/test/inductor/test_codecache.py b/test/inductor/test_codecache.py index 4e1532f9d4c98e..0f22e6bc0eaaec 100644 --- a/test/inductor/test_codecache.py +++ b/test/inductor/test_codecache.py @@ -36,6 +36,7 @@ HAS_CUDA, HAS_GPU, HAS_MULTIGPU, + requires_gpu, ) from torch.utils._triton import has_triton @@ -46,7 +47,6 @@ from torch.testing._internal.triton_utils import add_kernel -requires_gpu = functools.partial(unittest.skipIf, not HAS_GPU, "requires gpu") requires_triton = functools.partial(unittest.skipIf, not HAS_TRITON, "requires triton") torch._dynamo.config.fake_tensor_cache_enabled = True diff --git a/test/inductor/test_indexing.py b/test/inductor/test_indexing.py index 19a736160908c4..90adc18e0f47c9 100644 --- a/test/inductor/test_indexing.py +++ b/test/inductor/test_indexing.py @@ -336,19 +336,21 @@ def test_print_floor_div(self): def test_print_Min_Max(self): cases = ( - (sympy.Min, "min"), - (sympy.Max, "max"), + (sympy.Min, "min", "<"), + (sympy.Max, "max", ">"), ) - for f, s in cases: + for f, s, cmp in cases: x = sympy.Symbol("x", integer=True) expr = f(-2, x) - self.assertEqual(texpr(expr), f"tl.{s}imum(-2, x)") + self.assertEqual( + texpr(expr), f"((-2) * ((-2) {cmp}= (x)) + (x) * ((x) {cmp} (-2)))" + ) self.assertEqual(cexpr(expr), f"std::{s}(-2L, x)") expr = f(x, 2 * x, 3 * x) self.assertEqual( texpr(expr), - f"tl.{s}imum(x, tl.{s}imum(2*x, 3*x))", + f"((x) * ((x) {cmp}= (((2*x) * ((2*x) {cmp}= (3*x)) + (3*x) * ((3*x) {cmp} (2*x))))) + (((2*x) * ((2*x) {cmp}= (3*x)) + (3*x) * ((3*x) {cmp} (2*x)))) * ((((2*x) * ((2*x) {cmp}= (3*x)) + (3*x) * ((3*x) {cmp} (2*x)))) {cmp} (x)))", # noqa: B950 line too long ) self.assertEqual(cexpr(expr), f"std::{s}({{x, 2L*x, 3L*x}})") diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index 83954feb02900e..ea891f6ab74334 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -123,6 +123,7 @@ HAS_CPU, HAS_GPU, HAS_MULTIGPU, + requires_gpu, skipCPUIf, skipCUDAIf, ) @@ -130,7 +131,6 @@ HAS_AVX2 = "fbgemm" in torch.backends.quantized.supported_engines aten = torch.ops.aten -requires_gpu = functools.partial(unittest.skipIf, not HAS_GPU, "requires gpu") requires_multigpu = functools.partial( unittest.skipIf, not HAS_MULTIGPU, f"requires multiple {GPU_TYPE} devices" @@ -9030,12 +9030,13 @@ def forward(arg6, arg7, arg16): assertGeneratedKernelCountEqual(self, 0) @requires_gpu() + @parametrize("use_block_ptr", (False, True)) @unittest.skipIf( not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Does not support SDPA or pre-SM80 hardware", ) @skipIfRocm - def test_sdpa(self): + def test_sdpa(self, use_block_ptr): def foo(arg0_1, arg1_1, arg2_1, arg3_1, arg4_1): view = torch.ops.aten.view.default(arg3_1, [23760, 128]) arg3_1 = None @@ -9067,6 +9068,9 @@ def foo(arg0_1, arg1_1, arg2_1, arg3_1, arg4_1): _scaled_dot_product_efficient_attention = None return (getitem,) + if self.device == "cpu": + raise unittest.SkipTest(f"requires {GPU_TYPE}") + DEVICE = torch.device(f"{GPU_TYPE}:0") DTYPE = torch.float16 B = 3 @@ -9082,13 +9086,22 @@ def foo(arg0_1, arg1_1, arg2_1, arg3_1, arg4_1): value = torch.randn((B, H, K, D), device=DEVICE, dtype=DTYPE) bias = torch.randn((B, Q, K, C_bias), device=DEVICE, dtype=DTYPE) weights = torch.randn((C_bias, H), device=DEVICE, dtype=DTYPE) + inps = (query, key, value, bias, weights) - self.common( - foo, - (query, key, value, bias, weights), - atol=0.02, - rtol=1e4, - ) + with config.patch("triton.use_block_ptr", use_block_ptr): + # Check accuracy + self.common( + foo, + inps, + atol=0.02, + rtol=1e4, + ) + + # Check code for block pointers + foo_opt = torch._dynamo.optimize("inductor")(foo) + code = run_and_get_triton_code(foo_opt, *inps) + have_block_ptr = code.count("tl.make_block_ptr") > 0 + self.assertEqual(have_block_ptr, use_block_ptr) @requires_gpu() @unittest.skipIf( @@ -11045,8 +11058,8 @@ def f(a, b): self.assertExpectedInline( "\n".join(lines), """\ - tmp0 = tl.load(in_ptr0 + (x1 + (512*x0) + (262144*r2)), rmask, eviction_policy='evict_last', other=0.0) - tmp1 = tl.load(block_ptr0, boundary_check=[1], padding_option='zero', eviction_policy='evict_first')""", + tmp0 = tl.reshape(tl.load(block_ptr0, boundary_check=[3], padding_option='zero', eviction_policy='evict_last'), [XBLOCK, RBLOCK]) + tmp1 = tl.load(block_ptr1, boundary_check=[1], padding_option='zero', eviction_policy='evict_first')""", # noqa: B950 line too long ) # Disable index propagation, so the indirect indexing isn't optimized away diff --git a/test/inductor/test_torchinductor_strided_blocks.py b/test/inductor/test_torchinductor_strided_blocks.py new file mode 100644 index 00000000000000..bd859802892df7 --- /dev/null +++ b/test/inductor/test_torchinductor_strided_blocks.py @@ -0,0 +1,361 @@ +# Owner(s): ["module: inductor"] +import contextlib +import importlib +import unittest +from typing import Any, Callable, Optional, Tuple + +import torch +import torch.utils._pytree as pytree +from torch._inductor import config +from torch._inductor.runtime.hints import TRITON_MAX_BLOCK +from torch._inductor.runtime.runtime_utils import is_power_of_2 +from torch._inductor.test_case import TestCase as InductorTestCase +from torch._inductor.utils import run_and_get_code +from torch.testing._internal.common_utils import ( + instantiate_parametrized_tests, + parametrize, +) +from torch.testing._internal.inductor_utils import ( + GPU_TYPE, + HAS_GPU, + requires_gpu, + skip_windows_ci, +) + + +skip_windows_ci(__name__, __file__) + +importlib.import_module("filelock") + +max_block: int = TRITON_MAX_BLOCK["X"] + + +@requires_gpu() +@config.patch("triton.use_block_ptr", True) +@instantiate_parametrized_tests +class TritonBlockPointerTest(InductorTestCase): + def run_and_compare( + self, + func: Callable[..., Any], + *args, + compile_kwargs: Optional[dict] = None, + expected_num_block_pointers: Optional[int] = None, + expected_num_programs: int = 1, + expected_num_triton_kernels: int = 1, + ): + """ + Runs the module through Inductor, comparing to eager reference. + """ + if compile_kwargs is None: + compile_kwargs = {} + + def flatten_tensors(tensors): + flat, spec = pytree.tree_flatten(tensors) + return flat + + compiled = torch.compile(func, backend="inductor", **compile_kwargs) + result, code = run_and_get_code(compiled, *args) + + # Check numerical accuracy + ref_tensors = flatten_tensors(func(*args)) + actual_tensors = flatten_tensors(result) + for ref, actual in zip(ref_tensors, actual_tensors): + self.assertTrue(torch.allclose(ref, actual)) + + def count_code(substr: str, expected: Optional[int]): + count = sum(prog.count(substr) for prog in code) + if expected is not None: + self.assertEqual(count, expected) + + # Check the code + self.assertEqual(len(code), expected_num_programs) + count_code("@triton.jit", expected_num_triton_kernels) + count_code("tl.make_block_ptr", expected_num_block_pointers) + + return result, code + + @parametrize( + "expected_num_block_pointers,raises", + [ + (3, False), # This should pass + (9, True), # This should fail + ], + ) + def test_expected_num_block_pointers( + self, expected_num_block_pointers: int, raises: bool + ): + """ + Checks that the test harness verifies the number of block pointers correctly. + """ + + def foo(x, y): + return x + y + + device = torch.device(GPU_TYPE) + inputs = [torch.randn(8).to(device) for arg_idx in range(2)] + + # Expect failure for bad inputs + with self.assertRaises(AssertionError) if raises else contextlib.nullcontext(): + # Expect 3 block pointers: 2 inputs 1 output + self.run_and_compare( + foo, *inputs, expected_num_block_pointers=expected_num_block_pointers + ) + + @parametrize( + "full_size,view_size,stride,offset,require_block_ptr", + [ + ((64, 32, 32), (32, 16, 8), None, None, True), + ((16, 8, 8, 8), (8, 8, 4, 2), None, None, True), + ((8, 8, 8, 8), (4, 4, 4, 4), None, None, True), + ((8, 8), (4, 4), None, 10, True), # Storage offset + ((8, 8), (4, 4), (16, 2), None, True), # Non-default strides + ((8, 8), (4, 4), (1, 8), None, True), # Transposed strides + ( + (5, 9), + (5, 8), + None, + None, + True, + ), # Non-power-of-2 leading dim: block ptr + ( + (15, 9), + (15, 3), + None, + None, + False, + ), # Non-power-of-2 inner dims: non-block ptr + ((1, 1, 1), (1, 1, 1), None, None, False), # Scalar: non-block ptr + ( + (2, 4 * max_block), + (2, 3 * max_block), + None, + None, + True, + ), # Inner dim multiple of max_block + ], + ) + def test_pointwise( + self, + full_size: Tuple[int], + view_size: Tuple[int], + stride: Optional[Tuple[int]], + offset: Optional[int], + require_block_ptr: bool, + ): + """ + Test generating strided ND block pointers for a pointwise kernel. + + If require_block_ptr is True, the generated code must contain block + pointers. However, ND block pointers are not supported for all shapes. So + we also test some odd shapes with require_block_ptr set to False, to ensure that + block pointer analysis does not break these cases. + """ + + def get_input() -> torch.Tensor: + device = torch.device(GPU_TYPE) + full = torch.randn(full_size).to(device) + + # Use the original tensor's stride by default + view_stride = full.stride() if stride is None else stride + + return torch.as_strided(full, view_size, view_stride, storage_offset=offset) + + args = [get_input() for arg_idx in range(2)] + + # Expect 3 block pointers: 2 inputs 1 output + self.run_and_compare( + torch.add, + *args, + expected_num_block_pointers=3 if require_block_ptr else None, + ) + + @parametrize( + "x_size,y_size", + [ + ((8, 8), (8, 1)), + ((8, 8), (1, 8)), + ( + (4, 1, 4), + (1, 4, 1), + ), # Very important case: index variables are disjoint! + ( + (1, 1, 1, 4), + (4, 4, 4, 4), + ), # Unmatched dims for first operand. + ], + ) + def test_broadcast(self, x_size: Tuple[int], y_size: Tuple[int]): + """ + Test that we can generate strided block pointers when inputs have different + shapes, and they are broadcast together. + """ + + def foo(x, y): + a = x + 1 + b = y * 2 + return a + b + + def get_input(view_size: Tuple[int]) -> torch.Tensor: + device = torch.device(GPU_TYPE) + full_size = tuple(2 * dim for dim in view_size) + full = torch.randn(full_size).to(device) + view = torch.as_strided(full, view_size, full.stride()) + return view + + x, y = (get_input(size) for size in (x_size, y_size)) + + # Check that input sizes are not the same + self.assertNotEqual(x.shape, y.shape) + + # Check that at least one dimension is a singleton + all_dims = x.shape + y.shape + self.assertIn(1, all_dims) + + # Expect 3 block pointers: 2 inputs one output + self.run_and_compare(foo, x, y, expected_num_block_pointers=3) + + @parametrize( + "view_size,num_block_pointers,num_triton_kernels", + [ + ((4, 4), 1, 1), + ((4, 4, 4), 1, 1), + ((8, 8, 8), 1, 1), + ((15, 15), 0, 1), # Non-power of 2 + ((3 * max_block, 2), 3, 2), # Multiple of max block. Uses loops. + ( + (2, 3 * max_block), + 3, + 2, + ), # Multiple of max block. Uses loops. + ((128, 128), 3, 2), # Test a large size, with loops. + ], + ) + def test_reduction( + self, view_size: Tuple[int], num_block_pointers: int, num_triton_kernels: int + ): + """ + Tests a reduction kernel. + """ + + device = torch.device(GPU_TYPE) + full_size = tuple(2 * dim for dim in view_size) + full = torch.randn(full_size).to(device) + view = torch.as_strided(full, view_size, full.stride()) + + # Expect at least 1 block pointer for the input. + # Add 2 more if we generate 2 kernels. + result, (code,) = self.run_and_compare( + torch.sum, + view, + expected_num_block_pointers=num_block_pointers, + expected_num_triton_kernels=num_triton_kernels, + ) + + @parametrize( + "view_size,num_block_pointers,num_triton_kernels", + [ + ((8, 8), 2, 1), # No loops. Should be supported. + ( + (128, 128), + None, + None, + ), # Looped reduction. Block pointers not yet supported. + ], + ) + def test_mixed_pointwise_reduction( + self, view_size: Tuple[int], num_block_pointers: int, num_triton_kernels: int + ): + """ + Tests mixing pointwise with reduction ops. + """ + + def foo(x, y): + return torch.sum(x + y) + + device = torch.device(GPU_TYPE) + full_size = tuple(2 * dim for dim in view_size) + + def get_input() -> torch.Tensor: + full = torch.randn(full_size).to(device) + view = torch.as_strided(full, view_size, full.stride()) + return view + + inputs = [get_input() for input_idx in range(2)] + + # Expect 2 block pointers: inputs + result, (code,) = self.run_and_compare( + foo, + *inputs, + expected_num_block_pointers=num_block_pointers, + expected_num_triton_kernels=num_triton_kernels, + ) + + def test_multiple_max_block_non_power_of_2(self): + """ + Check that we support dims of size n * MAX_BLOCK, where n is any positive integer, not + necessarily a power of 2. + """ + + def foo(x): + return x - 1 + + device = torch.device(GPU_TYPE) + full_size = (3 * max_block, 3) + view_size = (3 * max_block, 2) + full = torch.randn(full_size).to(device) + view = torch.as_strided(full, view_size, full.stride()) + + # Check that we're using dims that aren't all powers of 2 + have_np2_dim = not all(is_power_of_2(dim) for dim in view_size) + self.assertTrue(have_np2_dim) + + # Check that we need more than one stride to represent the tensor + nontrivial_dims = [dim for dim in view_size if dim > 1] + self.assertTrue(len(nontrivial_dims) > 1) + + # Expect 2 block pointers: input and output + self.run_and_compare(foo, view, expected_num_block_pointers=2) + + def test_dynamic_shapes_generic(self): + """ + Test a generic strided block with dynamic shapes. Block pointers are not + expected. This only checks that the analysis doesn't break this case. + """ + + device = torch.device(GPU_TYPE) + full_size = (8, 8) + view_size = (4, 4) + full = torch.randn(full_size).to(device) + view = torch.as_strided(full, view_size, full.stride()) + + self.run_and_compare(torch.div, view, view, compile_kwargs={"dynamic": True}) + + @unittest.skip(reason="Dynamo tracing error") + def test_dynamic_shapes_multiple_max_block(self): + """ + Test dynamic shapes, where we know the shape is a multiple of the max block + size. We should be able to generate a block pointer for this case. + """ + + def foo(x): + tile_dims = (3 * max_block * x.shape[0], 3 * x.shape[1]) + view_size = (3 * max_block * x.shape[0], 2 * x.shape[1]) + full = x.tile(tile_dims) + view = torch.as_strided(full, view_size, full.stride()) + return view + view + + device = torch.device(GPU_TYPE) + x_size = (1, 1) + x = torch.randn(x_size).to(device) + + # Expect 2 block pointers: input and output + self.run_and_compare( + x, compile_kwargs={"dynamic": True}, expected_num_block_pointers=2 + ) + + +if __name__ == "__main__": + from torch._inductor.test_case import run_tests + + if HAS_GPU: + run_tests(needs="filelock") diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py index eab10e2496ad47..381bbebf3acc5a 100644 --- a/torch/_inductor/codegen/triton.py +++ b/torch/_inductor/codegen/triton.py @@ -13,6 +13,7 @@ Callable, cast, Dict, + Iterable, List, Optional, Set, @@ -29,8 +30,9 @@ from torch._inductor.runtime.hints import AutotuneHint, DeviceProperties from torch._prims_common import is_integer_dtype +from torch.utils._sympy.functions import CeilDiv, FloorDiv, ModularIndexing from torch.utils._triton import has_triton_package -from ...utils._sympy.symbol import free_symbol_is_type, symbol_is_type, SymT +from ...utils._sympy.symbol import free_symbol_is_type, prefix_str, symbol_is_type, SymT from ...utils._sympy.value_ranges import ValueRanges from .. import config, ir @@ -119,10 +121,21 @@ def gen_common_triton_imports(): return imports.getvalue() +block_offsets = { + symt: sympy.Symbol(f"{prefix_str[symt]}offset", integer=True) + for symt in [SymT.XBLOCK, SymT.YBLOCK, SymT.RINDEX] +} + +block_sizes = { + symt: sympy.Symbol(f"{prefix_str[symt].upper()}BLOCK", integer=True, nonzero=True) + for symt in [SymT.XBLOCK, SymT.YBLOCK, SymT.RINDEX] +} + + @dataclasses.dataclass class IndexingOptions: index_str: str - mask_vars: Set[sympy.Symbol] + mask_vars: Set[str] mask_str: str expand_str: Optional[str] _has_rindex: bool @@ -146,29 +159,46 @@ def has_rmask(self): @dataclasses.dataclass class BlockPtrOptions: + params: BlockParameters constant_offset: sympy.Expr - shape: List[sympy.Expr] - strides: List[sympy.Expr] - block_shape: List[str] order: List[int] - offsets: List[str] - mask_vars: Set[sympy.Symbol] + mask_vars: Set[str] reshape_suffix: List[str] + @property + def shape(self) -> List[sympy.Expr]: + return self.params.shape + + @property + def block_shape(self) -> List[sympy.Expr]: + return self.params.block_shape + + @property + def strides(self) -> List[sympy.Expr]: + return self.params.strides + + @property + def offsets(self) -> List[sympy.Expr]: + return self.params.offsets + @staticmethod def create( - strides: List[sympy.Expr], + *, + params: BlockParameters, constant_offset: sympy.Expr, range_trees: List[IterationRangesEntry], - mask_vars: Set[sympy.Symbol], + mask_vars: Set[str], ) -> BlockPtrOptions: """Helper to create a BlockPtrOptions instance""" - block_shape = [f"{t.prefix.upper()}BLOCK" for t in range_trees] - reshape_suffix = [*block_shape] + reshape_suffix = [f"{t.prefix.upper()}BLOCK" for t in range_trees] + + # Only drop broadcast dims if the output has the same + # rank as the block. Otherwise, we will get shape errors. + drop_broadcasts = len(reshape_suffix) == len(params.strides) - broadcasting_dim = [s == 0 for s in strides] + broadcasting_dim = [s == 0 for s in params.strides] for i, is_broadcasting in enumerate(broadcasting_dim): - if is_broadcasting: + if is_broadcasting and drop_broadcasts: # drop any stride==0 dimensions for performance reshape_suffix[i] = "1" @@ -178,7 +208,7 @@ def create( if ( not V.kernel.inside_reduction - and len(strides) == len(V.kernel.numels) - 1 + and len(params.strides) == len(V.kernel.numels) - 1 and V.kernel.numels[-1] != 1 ): # Need to expand rank by 1 to match rank when self.inside_reduction=True @@ -190,23 +220,36 @@ def filter(it): return [ item for item, is_broadcasting in zip(it, broadcasting_dim) - if not is_broadcasting + if not is_broadcasting or not drop_broadcasts ] + # Drop broadcasting dimensions from the input. + params = BlockParameters( + **{key: filter(val) for key, val in dataclasses.asdict(params).items()} + ) + + def lookup_size(exprs: Iterable[sympy.Expr]) -> List[sympy.Expr]: + return [V.graph.sizevars.lookup_precomputed_size(expr) for expr in exprs] + + # Look up precomputed sizes + params.shape = lookup_size(params.shape) + params.strides = lookup_size(params.strides) + return BlockPtrOptions( + params=params, constant_offset=V.graph.sizevars.lookup_precomputed_size(constant_offset), - shape=[ - V.graph.sizevars.lookup_precomputed_size(t.numel) - for t in filter(range_trees) - ], - strides=[*map(V.graph.sizevars.lookup_precomputed_size, filter(strides))], - block_shape=filter(block_shape), - order=V.graph.sizevars.guarded_order(filter(strides)), - offsets=filter([f"{t.prefix}offset" for t in range_trees]), + order=list(reversed(range(len(params.shape)))), mask_vars=mask_vars, reshape_suffix=reshape_suffix, ) + def replace_roffset(self, expr: sympy.Expr, replacement: sympy.Expr) -> sympy.Expr: + """ + Replaces instances of roffset with the new expression. + """ + roffset = block_offsets[SymT.RINDEX] + return sympy_subs(expr, {roffset: replacement}) + def format(self, name: str, roffset=True) -> str: """ Codegen a call to tl.make_block_ptr() @@ -221,7 +264,9 @@ def format(self, name: str, roffset=True) -> str: f = V.kernel.index_to_str offsets = [*self.offsets] if not roffset: - offsets[offsets.index("roffset")] = "0" + offsets = [ + self.replace_roffset(offset, sympy.Integer(0)) for offset in offsets + ] args = [ f"{name} + ({f(self.constant_offset)})" if self.constant_offset != 0 @@ -237,31 +282,59 @@ def format(self, name: str, roffset=True) -> str: @cache_on_self def boundary_check(self) -> List[int]: """List of indices to pass to tl.load(boundary_check=...)""" - check = [] - for i in range(len(self.shape)): + sizevars = V.graph.sizevars + + # Substitute maximum block sizes in shape expressions. + # This works in multiple_of checks because block sizes are powers of 2. + block_to_max: Dict[sympy.Expr, Any] = { + block_size: TRITON_MAX_BLOCK[prefix_str[symt].upper()] + for symt, block_size in block_sizes.items() + } + + return [ + idx + for idx in range(len(self.shape)) if ( - self.block_shape[i] != "1" - and not V.graph.sizevars.statically_known_equals(self.strides[i], 0) # type: ignore[arg-type] - and not V.graph.sizevars.statically_known_multiple_of( - self.shape[i], - TRITON_MAX_BLOCK[self.block_shape[i][0]], # type: ignore[arg-type] + not sizevars.statically_known_equals( + self.strides[idx], sympy.Integer(0) ) - and not (V.kernel.no_x_dim and self.block_shape[i] == "XBLOCK") - ): - check.append(i) - return check + and not sizevars.statically_known_multiple_of( + self.shape[idx], self.block_shape[idx] + ) + and not sizevars.statically_known_multiple_of( + self.shape[idx], sympy_subs(self.block_shape[idx], block_to_max) + ) + and not ( + V.kernel.no_x_dim + and self.block_shape[idx] == block_sizes[SymT.XBLOCK] + ) + ) + ] def advance_roffset(self): - """Codegen string to pass to tl.advance(name, ...)""" - advance = ["0"] * len(self.shape) - advance[self.offsets.index("roffset")] = "RBLOCK" + """ + Codegen string to pass to tl.advance(name, ...). + + Advance is the difference between offsets in each loop iteration. + To compute it, we replace roffset with multiples of RBLOCK. + Since we expect roffset to vary in range(0, rnumel, RBLOCK), the first + iteration has roffset=0, while the second has roffset=RBLOCK. + """ + rblock = block_sizes[SymT.RINDEX] + advance = [ + ( + self.replace_roffset(offset, rblock) + - self.replace_roffset(offset, sympy.Integer(0)) + ) + for offset in self.offsets + ] return V.kernel.index_to_str(advance) def has_indirect(self): return False # block_ptr can't do indirect indexing - def has_rindex(self): - return "RBLOCK" in self.block_shape + def has_rindex(self) -> bool: + return any(free_symbol_is_type(expr, SymT.RINDEX) for expr in self.block_shape) def has_rmask(self): return self.has_rindex() @@ -365,26 +438,31 @@ def _print_Where(self, expr): q = self.doprint(expr.args[2]) return f"tl.where({c}, {p}, {q})" - def _print_Min(self, expr): + def _print_min_max_helper(self, expr: sympy.Expr, cmp: str) -> str: + """ + Helper for max/min code genereration. + cmp: > or < + """ nargs = len(expr.args) if len(expr.args) == 1: return self._print(expr.args[0]) mid = len(expr.args) // 2 - a = self._print(sympy.Min(*expr.args[:mid])) - b = self._print(sympy.Min(*expr.args[mid:])) - return f"tl.minimum({a}, {b})" + cls = type(expr) + a = self._print(cls(*expr.args[:mid])) + b = self._print(cls(*expr.args[mid:])) - def _print_Max(self, expr): - nargs = len(expr.args) - if len(expr.args) == 1: - return self._print(expr.args[0]) + # Use a macro so we can propagate constexprs. + # https://github.com/triton-lang/triton/issues/3815 + a, b = tuple(f"({x})" for x in (a, b)) + assert cmp in {">", "<"}, f"Unexpected comparator: '{cmp}'" + return f"({a} * ({a} {cmp}= {b}) + {b} * ({b} {cmp} {a}))" - mid = len(expr.args) // 2 - a = self._print(sympy.Max(*expr.args[:mid])) - b = self._print(sympy.Max(*expr.args[mid:])) + def _print_Min(self, expr): + return self._print_min_max_helper(expr, "<") - return f"tl.maximum({a}, {b})" + def _print_Max(self, expr): + return self._print_min_max_helper(expr, ">") def _print_Abs(self, expr): assert len(expr.args) == 1 @@ -1023,6 +1101,26 @@ def __getitem__(self, idx): return self.finalized_helpers[idx] +@dataclasses.dataclass +class BlockParameters: + """ + Class representing ND block dimensions, for block pointer analysis. + """ + + shape: List[sympy.Expr] = dataclasses.field(default_factory=list) + block_shape: List[sympy.Expr] = dataclasses.field(default_factory=list) + strides: List[sympy.Expr] = dataclasses.field(default_factory=list) + offsets: List[sympy.Expr] = dataclasses.field(default_factory=list) + + def __add__(self, other: BlockParameters) -> BlockParameters: + """ + Concatenates block parameters. + """ + cls = type(self) + a, b = tuple(dataclasses.asdict(x) for x in (self, other)) + return cls(**{key: a[key] + b[key] for key in a}) + + class TritonKernel(SIMDKernel): overrides = TritonKernelOverrides # type: ignore[assignment] helper_functions: HelperFunctions @@ -1059,6 +1157,19 @@ def __init__( self.codegen_range_tree() + def _get_symt(self, tree: IterationRangesEntry) -> SymT: + prefix_to_symt = {prefix: symt for symt, prefix in prefix_str.items()} + return prefix_to_symt[tree.prefix] + + def _get_block_size(self, tree: IterationRangesEntry) -> sympy.Symbol: + return block_sizes[self._get_symt(tree)] + + def _get_block_offset(self, tree: IterationRangesEntry) -> sympy.Symbol: + return block_offsets[self._get_symt(tree)] + + def _max_block_size(self, tree: IterationRangesEntry) -> int: + return TRITON_MAX_BLOCK[tree.prefix.upper()] + def codegen_range_tree(self): for tree in self.range_trees: # reduction indexing goes inside a loop @@ -1187,27 +1298,241 @@ def indexing( # workaround https://github.com/openai/triton/issues/2821 and self.index_dtype == "tl.int32" ): - index_relative_to_xyr_index = sympy_subs( - index, {v: t.expr for v, t in self.range_tree_nodes.items()} - ) - range_trees = self.active_range_trees(reorder=True) - symbols = [t.symbol() for t in range_trees] - strides = [sympy.Wild(f"stride_{s}", exclude=symbols) for s in symbols] - offset = sympy.Wild("_offset", exclude=symbols) - m = index_relative_to_xyr_index.match(sympy_dot(symbols, strides) + offset) - # TODO(jansel): it is sometimes possible to do higher dimensional block_ptrs with - # a tl.reshape the correct block. We will miss these cases today. - if m: - self.filter_masks(mask_vars) - from .triton import BlockPtrOptions + def match_strided_block( + index: sympy.Expr, range_tree: IterationRangesEntry + ) -> Optional[BlockParameters]: + """ + Matches expressions of the form: + idx = s * xindex + + This implies stride (s,), and shape (XBLOCK,). + """ + symbol = range_tree.symbol() + stride = sympy.Wild("stride", exclude=[symbol]) + m = index.match(symbol * stride) + if m is None: + return None + + return BlockParameters( + shape=[range_tree.numel], + block_shape=[self._get_block_size(range_tree)], + strides=[m[stride]], + offsets=[self._get_block_offset(range_tree)], + ) + + def match_mod_div_block( + index: sympy.Expr, range_tree: IterationRangesEntry + ) -> Optional[BlockParameters]: + """ + Matches higher-dimensional blocks coming from FloorDiv and ModularIndexing. + + Example expression to match: + sN * ((rindex//(d1 * ... * d(N-1)))) + + s1 * ModularIndexing(rindex, 1, d1) + + ... + + s(N-1) * ModularIndexing(rindex, d1 * ... * d(N-2), d(N-1)) + + This iterates over a block of shape (dN, ..., d1) and stride + (sN, ..., s1). (d1,...,d(N-1)) and (s1,...,sN) are + wildcards that we match. + + Note that dN does not appear in the expression, but we solve for it + using range tree numels and the other dims. + """ + # Bound the possible number of dims. We use the following heuristics: + # - At least one dim for each range tree node. + # - At least one dim for every FloorDiv or ModularIndexing op. + # - At least 2 dims to pattern match. + num_dims = max( + 2, + len(self.range_tree_nodes), + (index.count(FloorDiv) + index.count(ModularIndexing)), + ) + + # Pattern match to find the strides and offset. + index_var = range_tree.symbol() + wild = functools.partial(sympy.Wild, exclude=[index_var]) + dims: List[sympy.Expr] = [ + wild(f"dim_mod{idx}") for idx in range(num_dims) + ] + strides: List[sympy.Expr] = [ + wild(f"stride_mod{idx}") for idx in range(num_dims) + ] + + def get_slice_numels(dims: List[Any]) -> List[Any]: + """ + Compute the cumulative size of each dimension's slice. + This proceeds from the last dim up to the second. + """ + numels = [sympy.Integer(1)] + for dim in dims[:0:-1]: + numel = dim * numels[0] + numels.insert(0, numel) + return numels + + # The first dimension's index is computed by division. + # The remaining are computed by modulo. + slice_numels = get_slice_numels(dims[:num_dims]) + block_index_exprs = [FloorDiv(index_var, slice_numels[0])] + [ + ModularIndexing(index_var, numel, dim) + for dim, numel in zip(dims[1:], slice_numels[1:]) + ] + + # Calculate a linear index from block indices. + match_expr = sympy_dot(strides, block_index_exprs) + + # Pattern match. + match = index.match(match_expr) + if match is None: + return None + + # Provide default values for unmatched dims and strides. + for dim in dims[1:]: + if dim not in match: + match[dim] = sympy.Integer(1) + for stride in strides[1:]: + if stride not in match: + match[stride] = sympy.Integer(0) + + sizevars = V.graph.sizevars + + def get_match(expr: sympy.Expr) -> sympy.Expr: + return sizevars.lookup_precomputed_size(match[expr]) + + # Replace wildcards with matched expressions. + dims = [dims[0]] + [get_match(dim) for dim in dims[1:]] + strides = [get_match(stride) for stride in strides] + slice_numels = get_slice_numels(dims) + block_index_exprs = [ + sympy_subs(expr, match) for expr in block_index_exprs + ] + + # The leading dimension is not directly matched in our expression. + # We solve for it by dividing the range tree numel by the product of + # all other dimensions. We quit if they are not known to be divisible. + assert ( + dims[0] not in match + ), "Expected not to match the leading dimension!" + if not sizevars.statically_known_multiple_of( + range_tree.numel, slice_numels[0] + ): + return None + dims[0] = range_tree.numel / slice_numels[0] + + # Check for applicable iteration range sizes. + # When mapping a 1D block into an ND one, we need to know that + # the number of elements is not changed. This means the slice numels of + # the ND iteration range must evenly divide the length of the 1D block. + # There are two cases where we can guarantee this: + # 1. Numels are powers of 2. If numel == 2 ** n, and we know XBLOCK == 2 ** m, + # with n and m integers, then either numel is a multiple of XBLOCK, or numel + # is less than XBLOCK. (If numel is less than XBLOCK, we round up to 1 below.) + # 2. Numels are multiples of the maximum possible block size. + max_block = self._max_block_size(range_tree) + if any( + not sizevars.statically_known_multiple_of(numel, max_block) + and not sizevars.statically_known_power_of_2(numel) + for numel in slice_numels + ): + return None + + def identity(expr: sympy.Expr) -> sympy.Expr: + return expr + + # Compute the ND block shape from the linear block size. + # Use CielDiv to round leading dimensions up to 1. + # Non-leading dimensions are clamped to the size of the iteration range, + # while the leading dimension can exceed this to accomodate a larger + # block size. + linear_block_size = self._get_block_size(range_tree) + block_shape: List[sympy.Expr] = [ + CeilDiv(linear_block_size, slice_numels[0]) + ] + [ + sympy.Min(CeilDiv(linear_block_size, numel), dim) + for numel, dim in zip(slice_numels[1:], dims[1:]) + ] + + # Compute block offsets from {xyzr}offset and the matched expressions. + block_offsets: List[sympy.Expr] = [ + sympy_subs(expr, {index_var: self._get_block_offset(range_tree)}) + for expr in block_index_exprs + ] + + return BlockParameters( + shape=dims, + block_shape=block_shape, + strides=strides, + offsets=block_offsets, + ) + + def match_block_pointer_subexpr( + expr: sympy.Expr, range_tree: IterationRangesEntry + ) -> Optional[BlockParameters]: + """ + Match a block indexing subexpression involving a single range tree. + """ + for match_func in ( + match_strided_block, + match_mod_div_block, + ): + match = match_func(expr, range_tree) + if match is not None: + return match + + return None + + def match_block_pointer() -> Optional[BlockPtrOptions]: + index_relative_to_xyr_index = sympy_subs( + index, {v: t.expr for v, t in self.range_tree_nodes.items()} + ) + range_trees = self.active_range_trees(reorder=True) + + # Match each range tree separately. + range_symbols = {tree.symbol() for tree in range_trees} + index_terms = sympy.Add.make_args(index_relative_to_xyr_index) + block_params = BlockParameters() + for tree in range_trees: + # Partition the index into subexpressions pertaining to each range tree. + # For example xindex * 5 + rindex * 3 is partitioned to + # (xindex * 5, rindex * 3). + symbol = tree.symbol() + subexpr = sympy.Integer(0) + sum( + expr for expr in index_terms if symbol in expr.free_symbols + ) + + # Reject mixed terms, e.g. xindex * rindex. + # NB: the zero expression is allowed, for broadcasting. + if len(range_symbols.intersection(subexpr.free_symbols)) > 1: + return None + + # Match the subexpression for this range tree. + params = match_block_pointer_subexpr(subexpr, tree) + if params is None: + return None + block_params += params + + # Collect leftover terms as a constant offset. + offset = sum( + expr + for expr in index_terms + if not range_symbols.intersection(expr.free_symbols) + ) + + # Form the block pointer. + self.filter_masks(mask_vars) return BlockPtrOptions.create( - [m[s] for s in strides], - m[offset], - range_trees, - mask_vars, # type: ignore[arg-type] + params=block_params, + constant_offset=offset, + range_trees=range_trees, + mask_vars=mask_vars, ) + # Return a block pointer, if indexing matches the pattern. + options = match_block_pointer() + if options is not None: + return options + expand_str = None index_str = self.index_to_str(index) if isinstance(index, sympy.Integer): @@ -1274,7 +1599,8 @@ def codegen_block_ptr_store_line(self, name, indexing, block_ptr, value, other=" f"tl.broadcast_to({value}, {self.index_to_str(indexing.reshape_suffix)})" ) # drop any extra size=1 dimensions - value = triton_reshape(value, indexing.reshape_suffix, indexing.block_shape) + block_shape = [V.kernel.index_to_str(expr) for expr in indexing.block_shape] + value = triton_reshape(value, indexing.reshape_suffix, block_shape) # workaround https://github.com/openai/triton/issues/2814 value = f"{value}.to({triton_store_type(V.graph.get_dtype(name))})" return f"tl.store({block_ptr}, {value}{other})" @@ -1381,9 +1707,8 @@ def load(self, name: str, index: sympy.Expr): ) line = f"tl.load({block_ptr}{other}{ep})" # add needed size=1 dimensions - line = triton_reshape( - line, indexing.block_shape, indexing.reshape_suffix - ) + block_shape = [str(dim) for dim in indexing.block_shape] + line = triton_reshape(line, block_shape, indexing.reshape_suffix) elif isinstance(original_index, sympy.Integer): line = f"tl.load({var} + ({original_index}))" append_broadcast = indexing.expand_str diff --git a/torch/_inductor/runtime/runtime_utils.py b/torch/_inductor/runtime/runtime_utils.py index bfd6b6e5fb26ea..b458e554a7bef9 100644 --- a/torch/_inductor/runtime/runtime_utils.py +++ b/torch/_inductor/runtime/runtime_utils.py @@ -21,6 +21,11 @@ def ceildiv(numer: int, denom: int) -> int: return -(numer // -denom) +def is_power_of_2(n: int) -> bool: + """Returns whether n = 2 ** m for some integer m.""" + return n > 0 and n & n - 1 == 0 + + def next_power_of_2(n: int) -> int: """Return the smallest power of 2 greater than or equal to n""" n -= 1 diff --git a/torch/_inductor/sizevars.py b/torch/_inductor/sizevars.py index f48c0884d3ada9..2019415520442f 100644 --- a/torch/_inductor/sizevars.py +++ b/torch/_inductor/sizevars.py @@ -23,6 +23,7 @@ from torch.utils._sympy.symbol import symbol_is_type, SymT from torch.utils._sympy.value_ranges import bound_sympy +from .runtime.runtime_utils import is_power_of_2 from .utils import ( sympy_index_symbol, sympy_index_symbol_with_prefix, @@ -354,6 +355,13 @@ def statically_known_multiple_of( expr = sympy.Eq(numerator % denominator, 0) return self.is_expr_static_and_true(expr) # type: ignore[arg-type] + # See Note - [On Statically Known] + def statically_known_power_of_2(self, expr: Expr) -> bool: + """ + Returns a bool indicating if x is known to be a power of 2. + """ + return isinstance(expr, sympy.Integer) and is_power_of_2(int(expr)) + # The guard functions require you to ALREADY KNOW that a particular # condition holds. If you don't know (you want to guard on an expression # being a particular value, and then get access to that value), use @@ -376,7 +384,6 @@ def guard_lt(self, left: Expr, right: Expr) -> None: def guarded_order(self, seq): """ Return the order of a sequence as a permutation of range(len(seq)) and guard on that order not changing. - Used for generating block_ptrs. """ seq = [*map(self.remove_precomputed_replacements, seq)] seq = [(self.size_hint(var), orig_idx, var) for orig_idx, var in enumerate(seq)] diff --git a/torch/testing/_internal/inductor_utils.py b/torch/testing/_internal/inductor_utils.py index 1078a189f69c27..3b7440c48aa515 100644 --- a/torch/testing/_internal/inductor_utils.py +++ b/torch/testing/_internal/inductor_utils.py @@ -4,7 +4,9 @@ import re import unittest import functools +import os from subprocess import CalledProcessError +import sys import torch._inductor.async_compile # noqa: F401 required to warm up AsyncCompile pools from torch._inductor.codecache import CppCodeCache from torch.utils._triton import has_triton @@ -12,7 +14,11 @@ LazyVal, IS_FBCODE, ) -from torch.testing._internal.common_utils import TestCase +from torch.testing._internal.common_utils import ( + TestCase, + IS_CI, + IS_WINDOWS, +) def test_cpu(): try: @@ -79,6 +85,18 @@ def decorate_fn(fn): return decorate_fn +def skip_windows_ci(name: str, file: str) -> None: + if IS_WINDOWS and IS_CI: + module = os.path.basename(file).strip(".py") + sys.stderr.write( + f"Windows CI does not have necessary dependencies for {module} tests yet\n" + ) + if name == "__main__": + sys.exit(0) + raise unittest.SkipTest("requires sympy/functorch/filelock") + +requires_gpu = functools.partial(unittest.skipIf, not HAS_GPU, "requires gpu") + skipCUDAIf = functools.partial(skipDeviceIf, device="cuda") skipXPUIf = functools.partial(skipDeviceIf, device="xpu") skipCPUIf = functools.partial(skipDeviceIf, device="cpu") diff --git a/torch/utils/_sympy/functions.py b/torch/utils/_sympy/functions.py index 3c845f58117bc2..8322994e14cf44 100644 --- a/torch/utils/_sympy/functions.py +++ b/torch/utils/_sympy/functions.py @@ -143,15 +143,15 @@ def eval(cls, base, divisor): if isinstance(base, FloorDiv): return FloorDiv(base.args[0], base.args[1] * divisor) - # gcd in sympy is over polynomials, so you'll end up with rationals if - # you do this. Don't. - """ - if isinstance(base, sympy.Add): - for a in base.args: - gcd = sympy.gcd(a, divisor) - if gcd == divisor: - return FloorDiv(base - a, divisor) + a / gcd - """ + # Expands (x + y) // b into x // b + y // b. + # This only works if floor is an identity, i.e. x / b is an integer. + for term in sympy.Add.make_args(base): + quotient = term / divisor + if quotient.is_integer and isinstance(divisor, sympy.Integer): + # NB: this is correct even if the divisor is not an integer, but it + # creates rational expressions that cause problems with dynamic + # shapes. + return FloorDiv(base - term, divisor) + quotient try: gcd = sympy.gcd(base, divisor) diff --git a/torch/utils/_sympy/symbol.py b/torch/utils/_sympy/symbol.py index bd853faee6d2cd..565d6b3661a567 100644 --- a/torch/utils/_sympy/symbol.py +++ b/torch/utils/_sympy/symbol.py @@ -82,10 +82,11 @@ def make_symbol(prefix: SymT, idx: int, **kwargs) -> sympy.Symbol: # that it contains Basic, rather than Symbol def symbol_is_type(sym: sympy.Basic, prefix: Union[SymT, Sequence[SymT]]) -> bool: assert isinstance(sym, sympy.Symbol) + name_str = sym.name.lower() # Match capitalized names like XBLOCK, RBLOCK if isinstance(prefix, SymT): - return sym.name.startswith(prefix_str[prefix]) + return name_str.startswith(prefix_str[prefix]) else: - return sym.name.startswith(tuple(prefix_str[p] for p in prefix)) + return name_str.startswith(tuple(prefix_str[p] for p in prefix)) def free_symbol_is_type(e: sympy.Expr, prefix: SymT) -> bool: From cc518ebd381ad3c061ee687092648d6d3bb53714 Mon Sep 17 00:00:00 2001 From: "xinan.lin" Date: Sat, 15 Jun 2024 11:16:33 -0700 Subject: [PATCH 066/171] [Inductor Intel GPU backend Upstream] Reuse inductor test for Intel GPU (PART 2) (#124147) Reuse Inductor test case for Intel GPU. Pull Request resolved: https://github.com/pytorch/pytorch/pull/124147 Approved by: https://github.com/EikanWang, https://github.com/jansel --- test/inductor/test_binary_folding.py | 18 +- test/inductor/test_control_flow.py | 76 ++--- .../inductor/test_coordinate_descent_tuner.py | 6 +- test/inductor/test_debug_trace.py | 9 +- test/inductor/test_dependencies.py | 8 +- test/inductor/test_indexing.py | 8 +- test/inductor/test_minifier.py | 27 +- test/inductor/test_mmdecomp.py | 30 +- test/inductor/test_smoke.py | 18 +- test/inductor/test_split_cat_fx_passes.py | 12 +- test/inductor/test_torchinductor.py | 6 +- ...st_torchinductor_codegen_dynamic_shapes.py | 262 ++++++++++-------- .../test_torchinductor_dynamic_shapes.py | 21 +- test/inductor/test_triton_kernels.py | 214 +++++++------- torch/testing/_internal/common_device_type.py | 11 +- torch/testing/_internal/common_utils.py | 2 +- torch/testing/_internal/opinfo/core.py | 17 +- torch/testing/_internal/triton_utils.py | 3 +- 18 files changed, 405 insertions(+), 343 deletions(-) diff --git a/test/inductor/test_binary_folding.py b/test/inductor/test_binary_folding.py index 1a25e81ebf279e..a5495c8905cec5 100644 --- a/test/inductor/test_binary_folding.py +++ b/test/inductor/test_binary_folding.py @@ -27,12 +27,12 @@ raise unittest.SkipTest("requires sympy/functorch/filelock") from inductor.test_inductor_freezing import TestCase -from inductor.test_torchinductor import check_model, check_model_cuda, copy_tests +from inductor.test_torchinductor import check_model, check_model_gpu, copy_tests importlib.import_module("functorch") importlib.import_module("filelock") -from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA +from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CPU, HAS_GPU aten = torch.ops.aten @@ -243,14 +243,14 @@ class FreezingCpuTests(TestCase): copy_tests(BinaryFoldingTemplate, FreezingCpuTests, "cpu") -if HAS_CUDA and not TEST_WITH_ASAN: +if HAS_GPU and not TEST_WITH_ASAN: - class FreezingCudaTests(TestCase): - common = check_model_cuda - device = "cuda" - autocast = torch.cuda.amp.autocast + class FreezingGpuTests(TestCase): + common = check_model_gpu + device = GPU_TYPE + autocast = torch.amp.autocast(device_type=GPU_TYPE) - copy_tests(BinaryFoldingTemplate, FreezingCudaTests, "cuda") + copy_tests(BinaryFoldingTemplate, FreezingGpuTests, GPU_TYPE) del BinaryFoldingTemplate @@ -258,5 +258,5 @@ class FreezingCudaTests(TestCase): if __name__ == "__main__": from torch._inductor.test_case import run_tests - if HAS_CPU or HAS_CUDA: + if HAS_CPU or HAS_GPU: run_tests(needs="filelock") diff --git a/test/inductor/test_control_flow.py b/test/inductor/test_control_flow.py index a91f776b2b8018..cc3b211676fc2a 100644 --- a/test/inductor/test_control_flow.py +++ b/test/inductor/test_control_flow.py @@ -9,8 +9,8 @@ instantiate_parametrized_tests, parametrize, ) -from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA -from torch.testing._internal.triton_utils import requires_cuda +from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CPU, HAS_GPU +from torch.testing._internal.triton_utils import requires_gpu def _prepend_product_of_values(inputs, possible_values, num_to_prepend=1): @@ -206,8 +206,8 @@ def _run_test( self.assertEqual(cnt.frame_count, 1, "only one compilation expected") - @requires_cuda - @parametrize("device", ["cpu", "cuda"]) + @requires_gpu + @parametrize("device", ["cpu", GPU_TYPE]) @parametrize("dynamic", [False, True]) def test_cond_simple_control_flow(self, device, dynamic): # cond control flow without nesting @@ -221,8 +221,8 @@ def test_cond_simple_control_flow(self, device, dynamic): dynamic=dynamic, ) - @requires_cuda - @parametrize("device", ["cpu", "cuda"]) + @requires_gpu + @parametrize("device", ["cpu", GPU_TYPE]) @parametrize("dynamic", [False, True]) def test_cond_nested_control_flow(self, device, dynamic): # cond control flow with nesting @@ -238,8 +238,8 @@ def test_cond_nested_control_flow(self, device, dynamic): num_predicates=3, ) - @requires_cuda - @parametrize("device", ["cpu", "cuda"]) + @requires_gpu + @parametrize("device", ["cpu", GPU_TYPE]) @parametrize("dynamic", [False, True]) def test_cond_outer_code_before_after(self, device, dynamic): # some code before and after the conditional @@ -253,8 +253,8 @@ def test_cond_outer_code_before_after(self, device, dynamic): dynamic=dynamic, ) - @requires_cuda - @parametrize("device", ["cpu", "cuda"]) + @requires_gpu + @parametrize("device", ["cpu", GPU_TYPE]) @parametrize("dynamic", [False, True]) def test_cond_multiple_outputs(self, device, dynamic): # multiple outputs with different shapes @@ -269,8 +269,8 @@ def test_cond_multiple_outputs(self, device, dynamic): dynamic=dynamic, ) - @requires_cuda - @parametrize("device", ["cpu", "cuda"]) + @requires_gpu + @parametrize("device", ["cpu", GPU_TYPE]) def test_cond_advanced_dynamic_shapes(self, device): # subgraphs input shapes include symbolic expressions class Model(torch.nn.Module): @@ -297,7 +297,7 @@ def false_fn(x, y): dynamic=True, ) - @requires_cuda + @requires_gpu def test_cond_use_buffers_from_outer_scope(self): # subgraphs input shapes include symbolic expressions self._run_test( @@ -307,11 +307,11 @@ def test_cond_use_buffers_from_outer_scope(self): torch.randn(10, 20), torch.randn(10, 20), ), - device="cuda", + device=GPU_TYPE, dynamic=False, ) - @requires_cuda + @requires_gpu def test_cond_reintepret_view_inputs_outputs(self): # ReinterpretView in inputs and outputs of the subgraphs self._run_test( @@ -320,12 +320,12 @@ def test_cond_reintepret_view_inputs_outputs(self): torch.randn(10, 20), torch.randn(10, 20), ), - device="cuda", + device=GPU_TYPE, dynamic=True, ) - @requires_cuda - @parametrize("device", ["cpu", "cuda"]) + @requires_gpu + @parametrize("device", ["cpu", GPU_TYPE]) @parametrize("dynamic", [False, True]) def test_cond_subgraphs_with_parameters(self, device, dynamic): # nested Modules with parameters @@ -336,8 +336,8 @@ def test_cond_subgraphs_with_parameters(self, device, dynamic): dynamic=dynamic, ) - @requires_cuda - @parametrize("device", ["cpu", "cuda"]) + @requires_gpu + @parametrize("device", ["cpu", GPU_TYPE]) @parametrize("dynamic", [False, True]) def test_cond_non_tensor_predicates(self, device, dynamic): # model with a boolean predicate @@ -354,7 +354,7 @@ def test_cond_non_tensor_predicates(self, device, dynamic): num_predicates=0, ) - @requires_cuda + @requires_gpu def test_cond_aliasing_outputs(self): # output aliasing in subgraphs: not supported class Model(torch.nn.Module): @@ -377,8 +377,8 @@ def false_fn(x, y): torch.randn(10, 20), ) - @requires_cuda - @parametrize("device", ["cpu", "cuda"]) + @requires_gpu + @parametrize("device", ["cpu", GPU_TYPE]) def test_cond_decompose_ops_in_subgraph(self, device): class Model(torch.nn.Module): def forward(self, p, a): @@ -398,8 +398,8 @@ def false_fn(x): device=device, ) - @requires_cuda - @parametrize("device", ["cpu", "cuda"]) + @requires_gpu + @parametrize("device", ["cpu", GPU_TYPE]) def test_cond_decompose_ops_in_subgraph_recursive(self, device): def inner_fn1(x): return torch.zeros_like(x) @@ -425,7 +425,7 @@ def false_fn(x): device=device, ) - @requires_cuda + @requires_gpu def test_cond_inductor_fx_passes_recursively_applied(self): counters = {"pre_grad": 0, "post_grad": 0} @@ -450,7 +450,7 @@ def post_grad_pass_counter(gm): torch.randn(10, 20), torch.randn(10, 20), ), - device="cuda", + device=GPU_TYPE, dynamic=True, num_predicates=3, ) @@ -584,8 +584,8 @@ def _run_test( self.assertEqual(cnt.frame_count, 1, "only one compilation expected") - @requires_cuda - @parametrize("device", ["cpu", "cuda"]) + @requires_gpu + @parametrize("device", ["cpu", GPU_TYPE]) @parametrize("dynamic", [False, True]) def test_while_loop_simple_control_flow(self, device, dynamic): # while_loop control flow without nesting @@ -599,8 +599,8 @@ def test_while_loop_simple_control_flow(self, device, dynamic): dynamic=dynamic, ) - @requires_cuda - @parametrize("device", ["cpu", "cuda"]) + @requires_gpu + @parametrize("device", ["cpu", GPU_TYPE]) @parametrize("dynamic", [False, True]) def test_while_loop_nested_control_flow(self, device, dynamic): # while_loop control flow with nesting @@ -615,8 +615,8 @@ def test_while_loop_nested_control_flow(self, device, dynamic): num_counters=2, ) - @requires_cuda - @parametrize("device", ["cpu", "cuda"]) + @requires_gpu + @parametrize("device", ["cpu", GPU_TYPE]) @parametrize("dynamic", [False, True]) def test_while_loop_with_outer_code(self, device, dynamic): # while_loop control flow with outer code @@ -630,8 +630,8 @@ def test_while_loop_with_outer_code(self, device, dynamic): dynamic=dynamic, ) - @requires_cuda - @parametrize("device", ["cpu", "cuda"]) + @requires_gpu + @parametrize("device", ["cpu", GPU_TYPE]) @parametrize("dynamic", [False, True]) def test_while_loop_with_parameters(self, device, dynamic): # while_loop control flow with parameters @@ -642,8 +642,8 @@ def test_while_loop_with_parameters(self, device, dynamic): dynamic=dynamic, ) - @requires_cuda - @parametrize("device", ["cpu", "cuda"]) + @requires_gpu + @parametrize("device", ["cpu", GPU_TYPE]) # dynamic=True doesn't work now due to # https://github.com/pytorch/pytorch/issues/123596 @parametrize("dynamic", [False]) @@ -667,5 +667,5 @@ def test_while_loop_with_outer_buffers(self, device, dynamic): if __name__ == "__main__": from torch._inductor.test_case import run_tests - if HAS_CPU or HAS_CUDA: + if HAS_CPU or HAS_GPU: run_tests(needs="filelock") diff --git a/test/inductor/test_coordinate_descent_tuner.py b/test/inductor/test_coordinate_descent_tuner.py index fdd3abb1439275..0f890a63fa92c9 100644 --- a/test/inductor/test_coordinate_descent_tuner.py +++ b/test/inductor/test_coordinate_descent_tuner.py @@ -9,7 +9,7 @@ from torch._inductor.test_case import run_tests, TestCase from torch.testing._internal.common_utils import IS_LINUX -from torch.testing._internal.inductor_utils import HAS_CUDA +from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU try: import triton @@ -89,7 +89,7 @@ def f(x): with mock.patch.object( CoordescTuner, "compare_config", mock_compare_config_prefer_larger_XBLOCK ): - x = torch.ones(2, 256).cuda() + x = torch.ones(2, 256).to(GPU_TYPE) expected = f(x) # the first call get correct result when cache miss. Don't know why yet _ = torch.compile(f)(x) @@ -113,5 +113,5 @@ def test_value_too_large(self): if __name__ == "__main__": - if IS_LINUX and HAS_CUDA: + if IS_LINUX and HAS_GPU: run_tests() diff --git a/test/inductor/test_debug_trace.py b/test/inductor/test_debug_trace.py index cca6987610554d..3d11af6d995f24 100644 --- a/test/inductor/test_debug_trace.py +++ b/test/inductor/test_debug_trace.py @@ -9,8 +9,7 @@ import torch from torch._inductor import config, test_operators -from torch.testing._internal.common_cuda import TEST_CUDA -from torch.utils._triton import has_triton +from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU try: try: @@ -170,7 +169,7 @@ def body(self, ops): # intentionally only cleanup on success so debugging test is easier shutil.rmtree(filename) - @unittest.skipIf(not TEST_CUDA or not has_triton(), "requires cuda") + @unittest.skipIf(not HAS_GPU, "requires GPU") def test_debug_multi_tempalte(self): class ToyModel(torch.nn.Module): def __init__(self): @@ -188,9 +187,9 @@ def forward(self, x): with self.assertLogs( logging.getLogger("torch._inductor.debug"), level=logging.WARNING ), fresh_inductor_cache(): - m = ToyModel().to(device="cuda:0") + m = ToyModel().to(device=GPU_TYPE) m = torch.compile(m, mode="max-autotune") - input_tensor = torch.randn(100).to(device="cuda:0") + input_tensor = torch.randn(100).to(device=GPU_TYPE) m(input_tensor) diff --git a/test/inductor/test_dependencies.py b/test/inductor/test_dependencies.py index f0aefed2c44270..24d8192844f21d 100644 --- a/test/inductor/test_dependencies.py +++ b/test/inductor/test_dependencies.py @@ -10,12 +10,12 @@ from torch._inductor.utils import sympy_index_symbol from torch._inductor.virtualized import ops, V -from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA +from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CPU, HAS_GPU class TestDependencies(InductorTestCase): def _create_buffer(self, name, shape, dtype=torch.float32): - return Buffer(name, FixedLayout(torch.device("cuda:0"), dtype, shape)) + return Buffer(name, FixedLayout(torch.device(GPU_TYPE), dtype, shape)) def setUp(self): super().setUp() @@ -48,7 +48,7 @@ def inner_fn(index): ) pointwise = Pointwise.create( - device=torch.device("cuda:0"), + device=torch.device(GPU_TYPE), dtype=torch.int32, inner_fn=inner_fn, ranges=[1024 * 4], @@ -133,5 +133,5 @@ def test_normalize_with_stride_order_unequal(self): if __name__ == "__main__": from torch._inductor.test_case import run_tests - if HAS_CPU and HAS_CUDA: + if HAS_CPU and HAS_GPU: run_tests("sympy") diff --git a/test/inductor/test_indexing.py b/test/inductor/test_indexing.py index 90adc18e0f47c9..61194f52b0a3eb 100644 --- a/test/inductor/test_indexing.py +++ b/test/inductor/test_indexing.py @@ -18,7 +18,7 @@ instantiate_parametrized_tests, parametrize, ) -from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA +from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CPU, HAS_GPU from torch.utils._sympy.functions import ( FloorDiv, ModularIndexing, @@ -220,7 +220,7 @@ def test_expand_floor_div_applied(self): expected = FloorDiv(x * 15 + y, 3) self.assertEqual(expected, FloorDiv(actual, denominator)) - @unittest.skipUnless(HAS_CUDA, "Need GPU for this test") + @unittest.skipUnless(HAS_GPU, "Need GPU for this test") def test_int8_unpack(self): @torch.compile def f(x): @@ -231,7 +231,7 @@ def f(x): ) return unpacked * 2 - x = torch.randint(0, 255, (2, 4096, 5504), dtype=torch.uint8, device="cuda") + x = torch.randint(0, 255, (2, 4096, 5504), dtype=torch.uint8, device=GPU_TYPE) triton_code = run_and_get_triton_code(f, x) # Make sure the 2 load uses simpified indexing rather than something like @@ -361,5 +361,5 @@ def test_print_Min_Max(self): if __name__ == "__main__": from torch._inductor.test_case import run_tests - if HAS_CPU or HAS_CUDA: + if HAS_CPU or HAS_GPU: run_tests("sympy") diff --git a/test/inductor/test_minifier.py b/test/inductor/test_minifier.py index 6ddec1dcdec440..d7e8e530648ee4 100644 --- a/test/inductor/test_minifier.py +++ b/test/inductor/test_minifier.py @@ -7,9 +7,8 @@ from torch._dynamo.test_minifier_common import MinifierTestBase from torch._inductor import config from torch.testing._internal.common_utils import IS_JETSON, IS_MACOS, TEST_WITH_ASAN -from torch.testing._internal.inductor_utils import HAS_CUDA - -requires_cuda = unittest.skipUnless(HAS_CUDA, "requires cuda") +from torch.testing._internal.inductor_utils import GPU_TYPE +from torch.testing._internal.triton_utils import requires_gpu class MinifierTests(MinifierTestBase): @@ -39,15 +38,15 @@ def test_after_aot_cpu_compile_error(self): def test_after_aot_cpu_accuracy_error(self): self._test_after_aot("cpu", "AccuracyError") - @requires_cuda + @requires_gpu @inductor_config.patch("triton.inject_relu_bug_TESTING_ONLY", "compile_error") - def test_after_aot_cuda_compile_error(self): - self._test_after_aot("cuda", "SyntaxError") + def test_after_aot_gpu_compile_error(self): + self._test_after_aot(GPU_TYPE, "SyntaxError") - @requires_cuda + @requires_gpu @inductor_config.patch("triton.inject_relu_bug_TESTING_ONLY", "accuracy") - def test_after_aot_cuda_accuracy_error(self): - self._test_after_aot("cuda", "AccuracyError") + def test_after_aot_gpu_accuracy_error(self): + self._test_after_aot(GPU_TYPE, "AccuracyError") @inductor_config.patch("cpp.inject_relu_bug_TESTING_ONLY", "accuracy") def test_constant_in_graph(self): @@ -60,17 +59,19 @@ def inner(x): """ self._run_full_test(run_code, "aot", "AccuracyError", isolate=False) - @requires_cuda + @requires_gpu @patch.object(config, "joint_graph_constant_folding", False) def test_rmse_improves_over_atol(self): # From https://twitter.com/itsclivetime/status/1651135821045719041?s=20 run_code = """ @torch.compile() def inner(x): - return x - torch.tensor(655, dtype=torch.half, device='cuda') * 100 + return x - torch.tensor(655, dtype=torch.half, device='GPU_TYPE') * 100 -inner(torch.tensor(655 * 100, dtype=torch.half, device='cuda')) -""" +inner(torch.tensor(655 * 100, dtype=torch.half, device='GPU_TYPE')) +""".replace( + "GPU_TYPE", GPU_TYPE + ) # If we disable RMSE against fp64, this triggers accuracy error, # as the increased precision from torch.compile changes the result diff --git a/test/inductor/test_mmdecomp.py b/test/inductor/test_mmdecomp.py index 9319fed2b7941b..b67404eadffac9 100644 --- a/test/inductor/test_mmdecomp.py +++ b/test/inductor/test_mmdecomp.py @@ -9,14 +9,8 @@ from torch.testing._internal.common_cuda import SM80OrLater from torch.testing._internal.common_device_type import instantiate_device_type_tests from torch.testing._internal.common_nn import NNTestCase -from torch.testing._internal.common_utils import ( - IS_WINDOWS, - parametrize, - run_tests, - TEST_CUDA, -) -from torch.utils._triton import has_triton - +from torch.testing._internal.common_utils import IS_WINDOWS, parametrize, run_tests +from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU default_atol = { torch.float16: 1e-3, @@ -96,10 +90,10 @@ def torch_baddbmm(add, b, c, alpha, beta): class TestDecomp(NNTestCase): - _do_cuda_memory_leak_check = True - _do_cuda_non_default_stream = True + _do_cuda_memory_leak_check = GPU_TYPE == "cuda" + _do_cuda_non_default_stream = GPU_TYPE == "cuda" - @unittest.skipIf(TEST_CUDA and not has_triton(), "CUDA tests require triton") + @unittest.skipIf(not HAS_GPU, "GPU tests require triton") @parametrize("dtype", [torch.float, torch.bfloat16]) def test_simple_mm(self, device, dtype): fudge = 10 @@ -116,7 +110,7 @@ def test_simple_mm(self, device, dtype): run_comp_nocomp(torch_mm, t1, t2, rtol=rtol, atol=atol) run_comp_nocomp(torch_addmm, tadd, t1, t2, rtol=rtol, atol=atol) - @unittest.skipIf(TEST_CUDA and not has_triton(), "CUDA tests require triton") + @unittest.skipIf(not HAS_GPU, "GPU tests require triton") @parametrize( "dtype", [torch.float, torch.bfloat16] if SM80OrLater else [torch.float] ) @@ -141,7 +135,7 @@ def test_batched_mm(self, device, dtype, bs): torch_baddbmm, tadd, t1, t2, alpha, beta, rtol=rtol, atol=atol ) - @unittest.skipIf(TEST_CUDA and not has_triton(), "CUDA tests require triton") + @unittest.skipIf(not HAS_GPU, "GPU tests require triton") @config.patch(coordinate_descent_tuning=True) def test_bmm_batch2_last_dim_size_is_one(self, device): fudge = 3 @@ -153,12 +147,12 @@ def test_bmm_batch2_last_dim_size_is_one(self, device): run_comp_nocomp(torch_bmm, t1, t2, rtol=rtol, atol=atol) - @unittest.skipIf(TEST_CUDA and not has_triton(), "CUDA tests require triton") + @unittest.skipIf(not HAS_GPU, "GPU tests require triton") @parametrize("dtype", [torch.float, torch.bfloat16, torch.int]) def test_some(self, device, dtype): # this Pytorch data type is not fully supported on cuda today # - unfortunately we can't skipIf because we don't see the actual parms in skipIf - if device.startswith("cuda") and dtype == torch.int: + if device.startswith(GPU_TYPE) and dtype == torch.int: return run_comp_nocomp( @@ -172,13 +166,13 @@ def test_some(self, device, dtype): init_tensor([[1], [2], [3], [4]], dtype=dtype, device=device), ) - @unittest.skipIf(TEST_CUDA and not has_triton(), "CUDA tests require triton") + @unittest.skipIf(not HAS_GPU, "GPU tests require triton") @parametrize("dtype", [torch.float, torch.bfloat16, torch.int]) @parametrize("bs", [1, 2, 4, 10]) def test_some_batched(self, device, dtype, bs): # this Pytorch data type is not fully supported on cuda today # - unfortunately we can't skipIf because we don't see the actual parms in skipIf - if device.startswith("cuda") and dtype == torch.int: + if device.startswith(GPU_TYPE) and dtype == torch.int: return run_comp_nocomp( @@ -193,7 +187,7 @@ def test_some_batched(self, device, dtype, bs): ) -device_types = ("cpu", "cuda") +device_types = ("cpu", GPU_TYPE) instantiate_device_type_tests(TestDecomp, globals(), only_for=device_types) if __name__ == "__main__": diff --git a/test/inductor/test_smoke.py b/test/inductor/test_smoke.py index 309c007a17b3a8..b0960f0e9368be 100644 --- a/test/inductor/test_smoke.py +++ b/test/inductor/test_smoke.py @@ -7,7 +7,7 @@ from torch._inductor.test_case import TestCase from torch.testing._internal.common_utils import IS_LINUX -from torch.testing._internal.inductor_utils import HAS_CUDA +from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CUDA, HAS_GPU class MLP(torch.nn.Module): @@ -27,20 +27,20 @@ def _test_f(x): class SmokeTest(TestCase): - @unittest.skipIf(not HAS_CUDA, "Triton is not available") + @unittest.skipIf(not HAS_GPU, "Triton is not available") def test_mlp(self): torch._logging.set_logs( dynamo=logging.DEBUG, inductor=logging.DEBUG, aot=logging.DEBUG ) - mlp = torch.compile(MLP().cuda()) + mlp = torch.compile(MLP().to(GPU_TYPE)) for _ in range(3): - mlp(torch.randn(1, device="cuda")) + mlp(torch.randn(1, device=GPU_TYPE)) # set back to defaults torch._logging.set_logs() - @unittest.skipIf(not HAS_CUDA, "Triton is not available") + @unittest.skipIf(not HAS_GPU, "Triton is not available") def test_compile_decorator(self): @torch.compile def foo(x): @@ -51,8 +51,8 @@ def bar(x): return x * x for _ in range(3): - foo(torch.full((3, 4), 0.7, device="cuda")) - bar(torch.rand((2, 2), device="cuda")) + foo(torch.full((3, 4), 0.7, device=GPU_TYPE)) + bar(torch.rand((2, 2), device=GPU_TYPE)) def test_compile_invalid_options(self): with self.assertRaises(RuntimeError): @@ -62,6 +62,6 @@ def test_compile_invalid_options(self): if __name__ == "__main__": from torch._inductor.test_case import run_tests - if IS_LINUX and torch.cuda.is_available(): - if torch.cuda.get_device_properties(0).major > 5: + if IS_LINUX and HAS_GPU: + if (not HAS_CUDA) or torch.cuda.get_device_properties(0).major <= 5: run_tests() diff --git a/test/inductor/test_split_cat_fx_passes.py b/test/inductor/test_split_cat_fx_passes.py index b0cc28205daa22..0005cf76c28cb9 100644 --- a/test/inductor/test_split_cat_fx_passes.py +++ b/test/inductor/test_split_cat_fx_passes.py @@ -1,15 +1,13 @@ # Owner(s): ["module: inductor"] -import unittest import torch from torch._dynamo.utils import counters, optimus_scuba_log from torch._inductor.fx_passes.misc_patterns import numpy_compat_normalization from torch._inductor.test_case import run_tests, TestCase from torch.testing._internal.common_utils import IS_LINUX -from torch.testing._internal.inductor_utils import HAS_CUDA - -requires_cuda = unittest.skipUnless(HAS_CUDA, "requires cuda") +from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU +from torch.testing._internal.triton_utils import requires_gpu def patch(f): @@ -1205,12 +1203,12 @@ def fn(x, y): self.assertTrue(k not in {"x", "x1", "x2", "a", "axis", "keepdims"}) @patch - @requires_cuda + @requires_gpu def test_stack_normalization_axis_kwarg(self): def fn(x, y): return torch.stack([x, y], axis=1) - x, y = (torch.rand((4, 4), device="cuda") for _ in range(2)) + x, y = (torch.rand((4, 4), device=GPU_TYPE) for _ in range(2)) expected = fn(x, y) actual = torch.compile(fn)(x, y) @@ -1218,5 +1216,5 @@ def fn(x, y): if __name__ == "__main__": - if IS_LINUX and HAS_CUDA: + if IS_LINUX and HAS_GPU: run_tests() diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index ea891f6ab74334..54915479eb5c98 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -2997,7 +2997,7 @@ def fn(a, b): check_lowp=True, ) - @expectedFailureXPU + @skipIfXpu def test_mm_mixed_dtype(self): def fn(a, b): return torch.mm(a, b) @@ -3011,7 +3011,7 @@ def fn(a, b): with self.assertRaisesRegex(RuntimeError, msg): fn(t1, t2) - @expectedFailureXPU + @skipIfXpu def test_linear_mixed_dtype(self): class Net(nn.Module): def __init__(self): @@ -6347,7 +6347,6 @@ def fn(a, b): (a, b), ) - @skipIfXpu def test_nll_loss_backward(self): def fn(a, b, c): return aten.nll_loss_backward( @@ -9596,6 +9595,7 @@ def _cases_resize_as_common(): tuple(reversed(range(len(y_size)))) ), torch.preserve_format + @skipIfXpu def test_resize_as(self): def fn(x, y, memory_format): return torch.ops.aten.resize_as(x, y, memory_format=memory_format) diff --git a/test/inductor/test_torchinductor_codegen_dynamic_shapes.py b/test/inductor/test_torchinductor_codegen_dynamic_shapes.py index bd036810d4c147..07d48141a0e2d7 100644 --- a/test/inductor/test_torchinductor_codegen_dynamic_shapes.py +++ b/test/inductor/test_torchinductor_codegen_dynamic_shapes.py @@ -17,7 +17,7 @@ _check_has_dynamic_shape, GPU_TYPE, HAS_CPU, - HAS_CUDA, + HAS_GPU, ) if IS_WINDOWS and IS_CI: @@ -104,11 +104,13 @@ def run(*ex, **kwargs): # # Failed to find dynamic for loop variable (no kernels generated) # - "test_fft_real_input_dynamic_shapes": TestFailure(("cpu", "cuda"), is_skip=True), + "test_fft_real_input_dynamic_shapes": TestFailure( + ("cpu", "cuda", "xpu"), is_skip=True + ), "test_fft_real_input_real_output_dynamic_shapes": TestFailure( - ("cpu", "cuda"), is_skip=True + ("cpu", "cuda", "xpu"), is_skip=True ), - "test_to_device_dynamic_shapes": TestFailure(("cpu", "cuda"), is_skip=True), + "test_to_device_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu"), is_skip=True), # # Failed to find dynamic for loop variable: # @@ -140,112 +142,124 @@ def run(*ex, **kwargs): # # Failed to find for loop/triton kernel: # - "test_complex_fallback_dynamic_shapes": TestFailure(("cpu", "cuda")), - "test_adaptive_avg_pool2d2_dynamic_shapes": TestFailure(("cpu", "cuda")), - "test_adaptive_max_pool2d2_dynamic_shapes": TestFailure(("cpu", "cuda")), - "test_fractional_max_pool2d2_dynamic_shapes": TestFailure(("cpu", "cuda")), - "test_argmax_to_float_dynamic_shapes": TestFailure(("cpu", "cuda")), - "test_avg_pool2d7_dynamic_shapes": TestFailure(("cpu", "cuda")), - "test_avg_pool2d_backward4_dynamic_shapes": TestFailure(("cpu", "cuda")), - "test_avg_pool3d_backward4_dynamic_shapes": TestFailure(("cpu", "cuda")), - "test_baddbmm_dynamic_shapes": TestFailure(("cpu", "cuda")), - "test_bmm2_dynamic_shapes": TestFailure(("cpu", "cuda")), - "test_both_scalars_dynamic_shapes": TestFailure(("cpu", "cuda")), + "test_complex_fallback_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")), + "test_adaptive_avg_pool2d2_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")), + "test_adaptive_max_pool2d2_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")), + "test_fractional_max_pool2d2_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")), + "test_argmax_to_float_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")), + "test_avg_pool2d7_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")), + "test_avg_pool2d_backward4_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")), + "test_avg_pool3d_backward4_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")), + "test_baddbmm_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")), + "test_bmm2_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")), + "test_both_scalars_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")), "test_compar_dynamic_shapes": TestFailure(("cpu",)), - "test_const_int32_to_float_dynamic_shapes": TestFailure(("cpu", "cuda")), + "test_const_int32_to_float_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")), "test_conv2d_backward_channels_last_dynamic_shapes": TestFailure(("cpu",)), - "test_conv_backward_dynamic_shapes": TestFailure(("cpu", "cuda")), + "test_conv_backward_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")), "test_conv_functional_bn_fuse_dynamic_shapes": TestFailure(("cpu",), is_skip=True), "test_convolution2_dynamic_shapes": TestFailure(("cpu",)), "test_cumprod_zero_dim_dynamic_shapes": TestFailure(("cpu",)), "test_cumsum_dynamic_shapes": TestFailure(("cpu",)), "test_cumsum_no_mask_dynamic_shapes": TestFailure(("cpu",)), "test_cumsum_zero_dim_dynamic_shapes": TestFailure(("cpu",)), - "test_div8_dynamic_shapes": TestFailure(("cpu", "cuda")), - "test_embedding_bag_dynamic_shapes": TestFailure(("cpu", "cuda")), - "test_empty1_dynamic_shapes": TestFailure(("cpu", "cuda")), - "test_empty2_dynamic_shapes": TestFailure(("cpu", "cuda")), - "test_empty_strided_dynamic_shapes": TestFailure(("cpu", "cuda")), + "test_div8_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")), + "test_embedding_bag_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")), + "test_empty1_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")), + "test_empty2_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")), + "test_empty_strided_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")), "test_bucketize_dynamic_shapes": TestFailure("cpu"), "test_bucketize_default_kwargs_dynamic_shapes": TestFailure("cpu"), "test_bucketize_int_dynamic_shapes": TestFailure("cpu"), - "test_like_rands_dynamic_shapes": TestFailure(("cpu", "cuda")), - "test_linspace2_dynamic_shapes": TestFailure(("cpu", "cuda")), - "test_linspace3_dynamic_shapes": TestFailure(("cpu", "cuda")), + "test_like_rands_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")), + "test_linspace2_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")), + "test_linspace3_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")), "test_logcumsumexp_dynamic_shapes": TestFailure(("cpu",)), "test_logcumsumexp_zero_dim_dynamic_shapes": TestFailure(("cpu",)), - "test_max_pool2d6_dynamic_shapes": TestFailure(("cpu", "cuda")), - "test_max_pool2d8_dynamic_shapes": TestFailure(("cpu", "cuda")), + "test_max_pool2d6_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")), + "test_max_pool2d8_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")), "test_max_pool2d_with_indices_backward5_dynamic_shapes": TestFailure( ("cpu", "cuda") ), "test_max_pool2d_with_indices_backward6_dynamic_shapes": TestFailure( - ("cpu", "cuda") + ("cpu", "cuda", "xpu") ), "test_misaligned_address_issue1_dynamic_shapes": TestFailure(("cpu",)), - "test_mm_views_dynamic_shapes": TestFailure(("cpu", "cuda")), - "test_new_empty_dynamic_shapes": TestFailure(("cpu", "cuda")), - "test_new_empty_strided_dynamic_shapes": TestFailure(("cpu", "cuda")), + "test_mm_views_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")), + "test_new_empty_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")), + "test_new_empty_strided_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")), "test_new_ones_dynamic_shapes": TestFailure(("cpu",)), - "test_permute2_dynamic_shapes": TestFailure(("cpu", "cuda")), - "test_pointwise_airy_ai_dynamic_shapes": TestFailure(("cuda",)), - "test_pointwise_digamma_dynamic_shapes": TestFailure(("cuda",)), - "test_pointwise_gammainc_dynamic_shapes": TestFailure(("cuda",)), - "test_pointwise_gammaincc_dynamic_shapes": TestFailure(("cuda",)), - "test_pointwise_i0e_dynamic_shapes": TestFailure(("cuda",)), - "test_pointwise_i1e_dynamic_shapes": TestFailure(("cuda",)), - "test_pointwise_modified_bessel_k0_dynamic_shapes": TestFailure(("cuda",)), - "test_pointwise_modified_bessel_k1_dynamic_shapes": TestFailure(("cuda",)), - "test_pointwise_ndtri_dynamic_shapes": TestFailure(("cuda",)), - "test_pointwise_polygamma_dynamic_shapes": TestFailure(("cuda",)), - "test_pointwise_psi_dynamic_shapes": TestFailure(("cuda",)), - "test_pointwise_scaled_modified_bessel_k0_dynamic_shapes": TestFailure(("cuda",)), - "test_pointwise_scaled_modified_bessel_k1_dynamic_shapes": TestFailure(("cuda",)), - "test_pointwise_spherical_bessel_j0_dynamic_shapes": TestFailure(("cuda",)), - "test_pointwise_zeta_dynamic_shapes": TestFailure(("cuda",)), - "test_pointwise_chebyshev_polynomial_t_dynamic_shapes": TestFailure(("cuda",)), - "test_pointwise_chebyshev_polynomial_u_dynamic_shapes": TestFailure(("cuda",)), - "test_pointwise_chebyshev_polynomial_v_dynamic_shapes": TestFailure(("cuda",)), - "test_pointwise_chebyshev_polynomial_w_dynamic_shapes": TestFailure(("cuda",)), + "test_permute2_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")), + "test_pointwise_airy_ai_dynamic_shapes": TestFailure(("cuda", "xpu")), + "test_pointwise_digamma_dynamic_shapes": TestFailure(("cuda", "xpu")), + "test_pointwise_gammainc_dynamic_shapes": TestFailure(("cuda", "xpu")), + "test_pointwise_gammaincc_dynamic_shapes": TestFailure(("cuda", "xpu")), + "test_pointwise_i0e_dynamic_shapes": TestFailure(("cuda", "xpu")), + "test_pointwise_i1e_dynamic_shapes": TestFailure(("cuda", "xpu")), + "test_pointwise_modified_bessel_k0_dynamic_shapes": TestFailure(("cuda", "xpu")), + "test_pointwise_modified_bessel_k1_dynamic_shapes": TestFailure(("cuda", "xpu")), + "test_pointwise_ndtri_dynamic_shapes": TestFailure(("cuda", "xpu")), + "test_pointwise_polygamma_dynamic_shapes": TestFailure(("cuda", "xpu")), + "test_pointwise_psi_dynamic_shapes": TestFailure(("cuda", "xpu")), + "test_pointwise_scaled_modified_bessel_k0_dynamic_shapes": TestFailure( + ("cuda", "xpu") + ), + "test_pointwise_scaled_modified_bessel_k1_dynamic_shapes": TestFailure( + ("cuda", "xpu") + ), + "test_pointwise_spherical_bessel_j0_dynamic_shapes": TestFailure(("cuda", "xpu")), + "test_pointwise_zeta_dynamic_shapes": TestFailure(("cuda", "xpu")), + "test_pointwise_chebyshev_polynomial_t_dynamic_shapes": TestFailure( + ("cuda", "xpu") + ), + "test_pointwise_chebyshev_polynomial_u_dynamic_shapes": TestFailure( + ("cuda", "xpu") + ), + "test_pointwise_chebyshev_polynomial_v_dynamic_shapes": TestFailure( + ("cuda", "xpu") + ), + "test_pointwise_chebyshev_polynomial_w_dynamic_shapes": TestFailure( + ("cuda", "xpu") + ), "test_pointwise_shifted_chebyshev_polynomial_t_dynamic_shapes": TestFailure( - ("cuda",) + ("cuda", "xpu") ), "test_pointwise_shifted_chebyshev_polynomial_u_dynamic_shapes": TestFailure( - ("cuda",) + ("cuda", "xpu") ), "test_pointwise_shifted_chebyshev_polynomial_v_dynamic_shapes": TestFailure( - ("cuda",) + ("cuda", "xpu") ), "test_pointwise_shifted_chebyshev_polynomial_w_dynamic_shapes": TestFailure( - ("cuda",) + ("cuda", "xpu") ), - "test_pointwise_hermite_polynomial_h_dynamic_shapes": TestFailure(("cuda",)), - "test_pointwise_hermite_polynomial_he_dynamic_shapes": TestFailure(("cuda",)), - "test_pointwise_laguerre_polynomial_l_dynamic_shapes": TestFailure(("cuda",)), - "test_pointwise_legendre_polynomial_p_dynamic_shapes": TestFailure(("cuda",)), + "test_pointwise_hermite_polynomial_h_dynamic_shapes": TestFailure(("cuda", "xpu")), + "test_pointwise_hermite_polynomial_he_dynamic_shapes": TestFailure(("cuda", "xpu")), + "test_pointwise_laguerre_polynomial_l_dynamic_shapes": TestFailure(("cuda", "xpu")), + "test_pointwise_legendre_polynomial_p_dynamic_shapes": TestFailure(("cuda", "xpu")), "test_randn_generator_dynamic_shapes": TestFailure(("cpu",)), - "test_randn_like_empty_dynamic_shapes": TestFailure(("cpu", "cuda")), + "test_randn_like_empty_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")), "test_single_elem_dynamic_shapes": TestFailure(("cpu",)), "test_single_elem_indirect_dynamic_shapes": TestFailure(("cpu",)), - "test_sort_dynamic_shapes": TestFailure(("cpu", "cuda")), + "test_sort_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")), "test_split_cumsum_dynamic_shapes": TestFailure(("cpu",)), "test_split_cumsum_low_prec_dynamic_shapes": TestFailure(("cpu",)), "test_split_cumprod_dynamic_shapes": TestFailure(("cpu",)), "test_split_cumprod_low_prec_dynamic_shapes": TestFailure(("cpu",)), - "test_split_dynamic_shapes": TestFailure(("cpu", "cuda")), - "test_topk_dynamic_shapes": TestFailure(("cpu", "cuda")), - "test_unbind_dynamic_shapes": TestFailure(("cpu", "cuda")), - "test_views5_dynamic_shapes": TestFailure(("cpu", "cuda")), - "test_view_detach_dynamic_shapes": TestFailure(("cpu", "cuda")), - "test_view_on_aliased_dynamic_shapes": TestFailure(("cpu", "cuda")), + "test_split_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")), + "test_topk_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")), + "test_unbind_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")), + "test_views5_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")), + "test_view_detach_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")), + "test_view_on_aliased_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")), "test_linear_float64_dynamic_shapes": TestFailure("cpu"), "test_adaptive_avg_pool_with_output_size_0_dynamic_shapes": TestFailure( - ("cpu", "cuda") + ("cpu", "cuda", "xpu") ), - "test_zero_element_mutation_dynamic_shapes": TestFailure(("cpu", "cuda")), - "test_custom_op_3_dynamic_shapes": TestFailure(("cpu", "cuda")), + "test_zero_element_mutation_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")), + "test_custom_op_3_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")), "test_custom_op_fixed_layout_sequential_dynamic_shapes": TestFailure( - ("cpu", "cuda") + ("cpu", "cuda", "xpu") ), "test_cat_uint8_dynamic_shapes": TestFailure( ("cpu",) @@ -253,70 +267,96 @@ def run(*ex, **kwargs): # # Tests not using 'common' or directly calling 'assertEqual': # - "test_arange5_dynamic_shapes": TestFailure(("cpu", "cuda"), is_skip=True), - "test_cat_inplace_dynamic_shapes": TestFailure(("cpu", "cuda"), is_skip=True), + "test_arange5_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu"), is_skip=True), + "test_cat_inplace_dynamic_shapes": TestFailure( + ("cpu", "cuda", "xpu"), is_skip=True + ), "test_cat_of_loops_and_extern_kernel_dynamic_shapes": TestFailure( - ("cpu", "cuda"), is_skip=True + ("cpu", "cuda", "xpu"), is_skip=True ), # need to enable CL with dynamic shapes "test_scaled_dot_product_efficient_attention_dynamic_shapes": TestFailure( - ("cpu", "cuda"), is_skip=True + ("cpu", "cuda", "xpu"), is_skip=True ), "test_dropout_deterministic_dynamic_shapes": TestFailure( - ("cpu", "cuda"), is_skip=True + ("cpu", "cuda", "xpu"), is_skip=True ), - "test_dropout_dynamic_shapes": TestFailure(("cpu", "cuda"), is_skip=True), + "test_dropout_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu"), is_skip=True), "test_dtype_mismatch_issue_dynamic_shapes": TestFailure( - ("cpu", "cuda"), is_skip=True + ("cpu", "cuda", "xpu"), is_skip=True ), "test_forced_buffer_realize_dynamic_shapes": TestFailure( - ("cpu", "cuda"), is_skip=True + ("cpu", "cuda", "xpu"), is_skip=True ), "test_tmp_not_defined_issue3_dynamic_shapes": TestFailure(("cpu",), is_skip=True), - "test_gather2_dynamic_shapes": TestFailure(("cpu", "cuda"), is_skip=True), - "test_inplace_add_dynamic_shapes": TestFailure(("cpu", "cuda"), is_skip=True), + "test_gather2_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu"), is_skip=True), + "test_inplace_add_dynamic_shapes": TestFailure( + ("cpu", "cuda", "xpu"), is_skip=True + ), "test_inplace_mixed_dtype_ops_dynamic_shapes": TestFailure( - ("cpu", "cuda"), is_skip=True - ), - "test_input_mutation1_dynamic_shapes": TestFailure(("cpu", "cuda"), is_skip=True), - "test_input_mutation2_dynamic_shapes": TestFailure(("cpu", "cuda"), is_skip=True), - "test_input_mutation3_dynamic_shapes": TestFailure(("cpu", "cuda"), is_skip=True), - "test_input_mutation4_dynamic_shapes": TestFailure(("cpu", "cuda"), is_skip=True), - "test_kernel_names_dynamic_shapes": TestFailure(("cpu", "cuda"), is_skip=True), - "test_lerp_dynamic_shapes": TestFailure(("cpu", "cuda"), is_skip=True), + ("cpu", "cuda", "xpu"), is_skip=True + ), + "test_input_mutation1_dynamic_shapes": TestFailure( + ("cpu", "cuda", "xpu"), is_skip=True + ), + "test_input_mutation2_dynamic_shapes": TestFailure( + ("cpu", "cuda", "xpu"), is_skip=True + ), + "test_input_mutation3_dynamic_shapes": TestFailure( + ("cpu", "cuda", "xpu"), is_skip=True + ), + "test_input_mutation4_dynamic_shapes": TestFailure( + ("cpu", "cuda", "xpu"), is_skip=True + ), + "test_kernel_names_dynamic_shapes": TestFailure( + ("cpu", "cuda", "xpu"), is_skip=True + ), + "test_lerp_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu"), is_skip=True), "test_linear_buffer_reuse_dynamic_shapes": TestFailure( - ("cpu", "cuda"), is_skip=True + ("cpu", "cuda", "xpu"), is_skip=True ), - "test_list_clearing_dynamic_shapes": TestFailure(("cpu", "cuda"), is_skip=True), - "test_dropout2_dynamic_shapes": TestFailure(("cpu", "cuda"), is_skip=True), - "test_dropout3_dynamic_shapes": TestFailure(("cpu", "cuda"), is_skip=True), + "test_list_clearing_dynamic_shapes": TestFailure( + ("cpu", "cuda", "xpu"), is_skip=True + ), + "test_dropout2_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu"), is_skip=True), + "test_dropout3_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu"), is_skip=True), "test_masked_fill_promotion_dynamic_shapes": TestFailure( - ("cpu", "cuda"), is_skip=True + ("cpu", "cuda", "xpu"), is_skip=True + ), + "test_min_max_reduction_dynamic_shapes": TestFailure( + ("cpu", "cuda", "xpu"), is_skip=True ), - "test_min_max_reduction_dynamic_shapes": TestFailure(("cpu", "cuda"), is_skip=True), "test_multi_gpu_recompile_on_index_dynamic_shapes": TestFailure( - ("cpu", "cuda"), is_skip=True + ("cpu", "cuda", "xpu"), is_skip=True + ), + "test_output_strides_dynamic_shapes": TestFailure( + ("cpu", "cuda", "xpu"), is_skip=True ), - "test_output_strides_dynamic_shapes": TestFailure(("cpu", "cuda"), is_skip=True), - "test_pow3_dynamic_shapes": TestFailure(("cpu", "cuda"), is_skip=True), + "test_pow3_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu"), is_skip=True), "test_profiler_mark_wrapper_call_dynamic_shapes": TestFailure( - ("cpu", "cuda"), is_skip=True + ("cpu", "cuda", "xpu"), is_skip=True ), "test_rand_like_deterministic_dynamic_shapes": TestFailure( - ("cpu", "cuda"), is_skip=True + ("cpu", "cuda", "xpu"), is_skip=True + ), + "test_repeat_interleave_2_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")), + "test_slice_mutation2_dynamic_shapes": TestFailure( + ("cpu", "cuda", "xpu"), is_skip=True + ), + "test_strided_inputs_dynamic_shapes": TestFailure( + ("cpu", "cuda", "xpu"), is_skip=True ), - "test_repeat_interleave_2_dynamic_shapes": TestFailure(("cpu", "cuda")), - "test_slice_mutation2_dynamic_shapes": TestFailure(("cpu", "cuda"), is_skip=True), - "test_strided_inputs_dynamic_shapes": TestFailure(("cpu", "cuda"), is_skip=True), "test_transposed_propagates_dynamic_shapes": TestFailure( - ("cpu", "cuda"), is_skip=True + ("cpu", "cuda", "xpu"), is_skip=True ), "test_require_stride_expanded_dynamic_shapes": TestFailure( - ("cpu", "cuda"), is_skip=True + ("cpu", "cuda", "xpu"), is_skip=True + ), + "test_unspec_inputs_dynamic_shapes": TestFailure( + ("cpu", "cuda", "xpu"), is_skip=True ), - "test_unspec_inputs_dynamic_shapes": TestFailure(("cpu", "cuda"), is_skip=True), "test_zero_dim_reductions_dynamic_shapes": TestFailure( - ("cpu", "cuda"), is_skip=True + ("cpu", "cuda", "xpu"), is_skip=True ), "test_sdpa_dynamic_shapes": TestFailure(("cpu",), is_skip=True), "test_sdpa_unaligned_mask_dynamic_shapes": TestFailure(("cpu",), is_skip=True), @@ -325,11 +365,11 @@ def run(*ex, **kwargs): # "test_cudnn_rnn_dynamic_shapes": TestFailure(("cuda",)), # test_roi_align uses torchvision, which doesn't work with dynamic shapes - "test_roi_align_dynamic_shapes": TestFailure(("cpu", "cuda")), + "test_roi_align_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")), "test_aliased_buffer_reuse_dynamic_shapes": TestFailure(("cpu",)), # The input of this case has only 1 elements "test_mutations_loop_fusion_dynamic_shapes": TestFailure( - ("cpu", "cuda"), is_skip=True + ("cpu", "cuda", "xpu"), is_skip=True ), # Refinement means we don't actually generate dynamic shapes (but only on # cpu apparently?!) @@ -375,7 +415,7 @@ def common(self: TestCase, model, example_inputs, kwargs=None, **_rest): ) -if HAS_CUDA and not TEST_WITH_ASAN: +if HAS_GPU and not TEST_WITH_ASAN: class DynamicShapesCodegenGPUTests(TestCase): maxDiff = None @@ -401,5 +441,5 @@ def common(self: TestCase, model, example_inputs, kwargs=None, **_rest): if __name__ == "__main__": from torch._inductor.test_case import run_tests - if HAS_CPU or HAS_CUDA: + if HAS_CPU or HAS_GPU: run_tests(needs="filelock") diff --git a/test/inductor/test_torchinductor_dynamic_shapes.py b/test/inductor/test_torchinductor_dynamic_shapes.py index 9fcb4dfdfc3a67..8d05625002a4e9 100644 --- a/test/inductor/test_torchinductor_dynamic_shapes.py +++ b/test/inductor/test_torchinductor_dynamic_shapes.py @@ -21,7 +21,7 @@ from torch.testing._internal.common_device_type import ( instantiate_device_type_tests, onlyCPU, - onlyCUDA, + onlyOn, ) from torch.testing._internal.common_utils import ( IS_ARM64, @@ -32,7 +32,7 @@ TEST_WITH_ASAN, TEST_WITH_ROCM, ) -from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CPU, HAS_CUDA, HAS_GPU +from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CPU, HAS_GPU if IS_WINDOWS and IS_CI: sys.stderr.write( @@ -59,8 +59,10 @@ test_failures = { "test_kwargs_dynamic_shapes": TestFailure(("cpu",)), # calling div on only symint args - "test_AllenaiLongformerBase_repro_dynamic_shapes": TestFailure(("cpu", "cuda")), - "test_conv_inference_heuristics_dynamic_shapes": TestFailure("cuda"), + "test_AllenaiLongformerBase_repro_dynamic_shapes": TestFailure( + ("cpu", "cuda", "xpu") + ), + "test_conv_inference_heuristics_dynamic_shapes": TestFailure(("cuda", "xpu")), } if TEST_WITH_ROCM: @@ -96,7 +98,7 @@ class DynamicShapesCpuTests(TestCase): copy_tests(DynamicShapesCommonTemplate, DynamicShapesCpuTests, "cpu", test_failures) -if HAS_CUDA and not TEST_WITH_ASAN: +if HAS_GPU and not TEST_WITH_ASAN: class DynamicShapesGPUTests(TestCase): common = check_model_gpu @@ -440,7 +442,7 @@ def f(x): return torch.ops.aten.cat.default([g, g, g2]) cf = torch.compile(fullgraph=True)(f) - arg = torch.tensor([4, 6], device="cuda") + arg = torch.tensor([4, 6], device=GPU_TYPE) self.assertEqual(f(arg), cf(arg)) @torch._dynamo.config.patch( @@ -534,8 +536,7 @@ def fn(x): res1 = opt(x1) self.assertEqual(ref1, res1) - # Need to comment: is xpu need this? if yes we may need to add onlyGPU - @onlyCUDA + @onlyOn(GPU_TYPE) def test_pad_dynamic(self, device): def get_same_padding(x: int, k: int, s: int, d: int): return max((math.ceil(x / s) - 1) * s + (k - 1) * d + 1 - x, 0) @@ -857,11 +858,11 @@ def f(x): f(torch.tensor([5], device=device)) -instantiate_device_type_tests(TestInductorDynamic, globals()) +instantiate_device_type_tests(TestInductorDynamic, globals(), allow_xpu=True) if __name__ == "__main__": from torch._inductor.test_case import run_tests # Slow on ASAN after https://github.com/pytorch/pytorch/pull/94068 - if (HAS_CPU or HAS_CUDA) and not TEST_WITH_ASAN: + if (HAS_CPU or HAS_GPU) and not TEST_WITH_ASAN: run_tests(needs="filelock") diff --git a/test/inductor/test_triton_kernels.py b/test/inductor/test_triton_kernels.py index b54869082a889f..6c16cdab085282 100644 --- a/test/inductor/test_triton_kernels.py +++ b/test/inductor/test_triton_kernels.py @@ -16,20 +16,27 @@ from torch._inductor import metrics from torch._inductor.utils import run_and_get_code from torch.testing._internal import common_utils -from torch.testing._internal.common_utils import skipIfRocm, TEST_WITH_ROCM +from torch.testing._internal.common_utils import skipIfRocm, skipIfXpu, TEST_WITH_ROCM # Defines all the kernels for tests from torch.testing._internal.triton_utils import * # noqa: F403 +from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CUDA, HAS_GPU, HAS_XPU -if HAS_CUDA: +if HAS_GPU: import triton from triton import language as tl if not TEST_WITH_ROCM: - from triton.language.extra.cuda.libdevice import ( - fast_dividef, - fast_dividef as my_fast_dividef, - ) + if HAS_CUDA: + from triton.language.extra.cuda.libdevice import ( + fast_dividef, + fast_dividef as my_fast_dividef, + ) + elif HAS_XPU: + from triton.language.extra.intel.libdevice import ( + fast_dividef, + fast_dividef as my_fast_dividef, + ) # Define shared triton constants here. CONSTANT_C: tl.constexpr = 4 @@ -38,7 +45,7 @@ class KernelTests(torch._inductor.test_case.TestCase): - @requires_cuda + @requires_gpu def test_triton_kernel_with_kernel_param(self): @triton.jit def pass_kernel(kernel): @@ -49,19 +56,19 @@ def f(x): grid = (x.numel(),) pass_kernel[grid](kernel=x) - t1 = torch.rand(5, device="cuda") + t1 = torch.rand(5, device=GPU_TYPE) f(t1) # No need to assert anything, the goal is to make sure dynamo does # not crash - @requires_cuda + @requires_gpu def test_triton_kernel_higher_order_func(self): from torch._higher_order_ops.triton_kernel_wrap import kernel_side_table add_kernel_id = kernel_side_table.add_kernel(add_kernel) - t1 = torch.rand(5, device="cuda") - t2 = torch.rand(5, device="cuda") + t1 = torch.rand(5, device=GPU_TYPE) + t2 = torch.rand(5, device=GPU_TYPE) torch_add = t1 + t2 @@ -103,7 +110,7 @@ def test_triton_kernel_higher_order_func(self): # Make sure it is NOT modified self.assertEqual(output, torch.zeros_like(t1)) - @requires_cuda + @requires_gpu def test_triton_kernel_functionalize(self): from functorch import make_fx from torch._higher_order_ops.triton_kernel_wrap import kernel_side_table @@ -130,8 +137,8 @@ def f(x, output): ) return out["out_ptr"] - t1 = torch.rand(5, device="cuda") - t2 = torch.rand(5, device="cuda") + t1 = torch.rand(5, device=GPU_TYPE) + t2 = torch.rand(5, device=GPU_TYPE) with FunctionalTensorMode(): gm = make_fx(PythonFunctionalizeAPI().functionalize(f))(t1, t2) # Make sure t2 was not modified @@ -156,7 +163,7 @@ def forward(self, x_1, output_1): return getitem_1""", ) - @requires_cuda + @requires_gpu def test_triton_kernel_mutation_type(self): from torch._higher_order_ops.triton_kernel_wrap import kernel_side_table from torch._subclasses.fake_tensor import FakeTensorMode @@ -166,7 +173,7 @@ def test_triton_kernel_mutation_type(self): ) def prep(): - x = torch.ones(4, device="cuda", requires_grad=True) + x = torch.ones(4, device=GPU_TYPE, requires_grad=True) with FunctionalTensorMode(): x_func = FunctionalTensor.to_functional(x) self.assertTrue(torch._is_functional_tensor(x_func.elem)) @@ -224,7 +231,7 @@ def prep(): torch._functionalize_are_all_mutations_hidden_from_autograd(x_func.elem) ) - @requires_cuda + @requires_gpu @common_utils.parametrize("dynamic", [False, True]) @common_utils.parametrize("backend", ["eager", "aot_eager", "inductor"]) def test_triton_kernel_with_views(self, dynamic, backend): @@ -242,7 +249,7 @@ def call_triton_return_view(x: torch.Tensor): mul2_kernel[grid](x, output, n_elements, BLOCK_SIZE=16) return output.view(4, 4) - t = torch.rand(4, 4, device="cuda") + t = torch.rand(4, 4, device=GPU_TYPE) t_view = t.view(16) compiled_func = torch.compile( @@ -257,7 +264,7 @@ def call_triton_return_view(x: torch.Tensor): self.assertEqual(2 * t_view, compiled_func(t).view(16)) self.assertEqual(2 * t, compiled_func(t)) - @requires_cuda + @requires_gpu @common_utils.parametrize("grad_fn", [torch.no_grad, torch.enable_grad]) @common_utils.parametrize("backend", ["eager", "aot_eager", "inductor"]) def test_triton_kernel_with_grad_option(self, grad_fn, backend): @@ -269,11 +276,11 @@ def call_triton(x: torch.Tensor): mul2_kernel[grid](x, output, n_elements, BLOCK_SIZE=16) return output - t = torch.rand(5, device="cuda") + t = torch.rand(5, device=GPU_TYPE) compiled_func = torch.compile(call_triton, backend=backend, fullgraph=True) self.assertEqual(2 * t, compiled_func(t)) - @requires_cuda + @requires_gpu @common_utils.parametrize("backend", ["eager", "aot_eager", "inductor"]) def test_triton_kernel_inner_triton_function(self, backend): def f(x: torch.Tensor): @@ -298,13 +305,13 @@ def pow2_kernel( pow2_kernel[grid](x, output, n_elements, BLOCK_SIZE=16) return output - t = torch.rand(5, device="cuda") + t = torch.rand(5, device=GPU_TYPE) compiled_func = torch.compile(f, backend=backend, fullgraph=True) # TODO(oulgen): NYI - Support this # self.assertEqual(t * t, compiled_func(t)) - @requires_cuda + @requires_gpu @common_utils.parametrize("grad", [False, True]) @common_utils.parametrize("dynamic", [False, True]) @patch.object(torch._inductor.config, "implicit_fallbacks", False) @@ -322,8 +329,8 @@ def call_triton(x: torch.Tensor, y: torch.Tensor, output: torch.Tensor): return output, tmp - t1 = torch.rand(5, device="cuda", requires_grad=grad) - t2 = torch.rand(5, device="cuda", requires_grad=grad) + t1 = torch.rand(5, device=GPU_TYPE, requires_grad=grad) + t2 = torch.rand(5, device=GPU_TYPE, requires_grad=grad) o1 = torch.zeros_like(t1, requires_grad=grad) torch_add = call_triton(t1, t2, o1) @@ -346,7 +353,7 @@ def call_triton(x: torch.Tensor, y: torch.Tensor, output: torch.Tensor): else: self.assertTrue("return (buf0, )" in codes[0]) - @requires_cuda + @requires_gpu def test_triton_kernel_caching(self): from torch._inductor.utils import run_and_get_code @@ -368,14 +375,14 @@ def call_triton_add( x = add_in_loop(x, y) return x - t1 = torch.ones(5, device="cuda") - t2 = torch.ones(5, device="cuda") + t1 = torch.ones(5, device=GPU_TYPE) + t2 = torch.ones(5, device=GPU_TYPE) test, (code,) = run_and_get_code(torch.compile(call_triton_add), t1, t2) - self.assertEqual(test, 5 * torch.ones(5, device="cuda")) + self.assertEqual(test, 5 * torch.ones(5, device=GPU_TYPE)) self.assertTrue("add_kernel_autotuned_1.run" not in code) - @requires_cuda + @requires_gpu def test_triton_kernel_caching_duplicate(self): from torch._inductor.utils import run_and_get_code @@ -418,13 +425,13 @@ def call_triton(x: torch.Tensor): D.pass_kernel[grid](x, output2, n_elements, BLOCK_SIZE=16) return output1 + output2 - t = torch.ones(5, device="cuda") + t = torch.ones(5, device=GPU_TYPE) test, (code,) = run_and_get_code(torch.compile(call_triton), t) # Make sure we emitted two kernels here self.assertTrue("pass_kernel_0.run" in code) self.assertTrue("pass_kernel_1.run" in code) - @requires_cuda + @requires_gpu def test_triton_kernel_various_args(self): @triton.autotune( configs=[triton.Config({"BLOCK_SIZE": 128})], @@ -456,11 +463,11 @@ def call_triton(output): ) return output - output = torch.randn(5, device="cuda") + output = torch.randn(5, device=GPU_TYPE) # Make sure this does not crash call_triton(output) - @requires_cuda + @requires_gpu @skipIfRocm def test_triton_kernel_dependancies(self): def call_triton( @@ -476,13 +483,13 @@ def call_triton( output3 = torch.add(output2, 1) return output3 - t1 = torch.rand(5, device="cuda") - t2 = torch.rand(5, device="cuda") + t1 = torch.rand(5, device=GPU_TYPE) + t2 = torch.rand(5, device=GPU_TYPE) torch_result = call_triton(t1, t2) compiled_result = torch.compile(call_triton)(t1, t2) self.assertEqual(torch_result, compiled_result) - @requires_cuda + @requires_gpu def test_triton_kernel_reinplace_inplaceable_pass(self): def call_triton( x: torch.Tensor, @@ -495,13 +502,13 @@ def call_triton( add_kernel_autotuned[grid](output, x, output, n_elements) return output - t1 = torch.rand(5, device="cuda") - t2 = torch.rand(5, device="cuda") + t1 = torch.rand(5, device=GPU_TYPE) + t2 = torch.rand(5, device=GPU_TYPE) torch_result = call_triton(t1, t2) compiled_result = torch.compile(call_triton)(t1, t2) self.assertEqual(torch_result, compiled_result) - @requires_cuda + @requires_gpu @common_utils.parametrize("grad", [False, True]) def test_triton_kernel_multi_kernel(self, grad): @triton.jit @@ -560,16 +567,16 @@ def call_triton( return (output, outputi) t1 = torch.tensor( - [-2.0, -1.0, 0.0, 1.0, 2.0], device="cuda", requires_grad=grad + [-2.0, -1.0, 0.0, 1.0, 2.0], device=GPU_TYPE, requires_grad=grad ) t2 = torch.tensor( - [-2.0, -1.0, 0.0, 1.0, 2.0], device="cuda", requires_grad=grad + [-2.0, -1.0, 0.0, 1.0, 2.0], device=GPU_TYPE, requires_grad=grad ) float_result = 2 * t1 + 2 * t2 float_result = float_result.where(float_result >= 0, 0.0) - t1i = torch.randint(-2, 2, (5,), device="cuda") - t2i = torch.randint(-2, 2, (5,), device="cuda") + t1i = torch.randint(-2, 2, (5,), device=GPU_TYPE) + t2i = torch.randint(-2, 2, (5,), device=GPU_TYPE) o = torch.zeros_like(t1, requires_grad=grad) oi = torch.zeros_like(t1i) int_result = 2 * t1i + 2 * t2i @@ -578,7 +585,8 @@ def call_triton( self.assertEqual(float_result, result) self.assertEqual(int_result, resulti) - @requires_cuda + @requires_gpu + @skipIfXpu @skipIfRocm def test_triton_kernel_constants(self): @triton.jit @@ -620,7 +628,7 @@ def call_triton( CONSTANT_C = 10 assert CONSTANT_C != prev_c - t = torch.randn(5, device="cuda") + t = torch.randn(5, device=GPU_TYPE) torch_result = call_triton(t) compiled_result = torch.compile(call_triton)(t) @@ -629,7 +637,7 @@ def call_triton( # reset back CONSTANT_C = prev_c - @requires_cuda + @requires_gpu @common_utils.parametrize("grad", [False, True]) @common_utils.parametrize("dynamic", [False, True]) @common_utils.parametrize("backend", ["eager", "aot_eager", "inductor"]) @@ -651,8 +659,8 @@ def grid_fn(meta): add_kernel_autotuned[grid](x, y, output, n_elements) return output - t1 = torch.rand(256, device="cuda", requires_grad=grad) - t2 = torch.rand(256, device="cuda", requires_grad=grad) + t1 = torch.rand(256, device=GPU_TYPE, requires_grad=grad) + t2 = torch.rand(256, device=GPU_TYPE, requires_grad=grad) output = torch.zeros_like(t1, requires_grad=grad) torch_add = call_triton(t1, t2, output) @@ -663,7 +671,7 @@ def grid_fn(meta): output2 = torch.zeros_like(t1, requires_grad=grad) self.assertEqual(compiled_func(t1, t2, output2), torch_add) - @requires_cuda + @requires_gpu @common_utils.parametrize("grad", [False, True]) @common_utils.parametrize("dynamic", [False, True]) @common_utils.parametrize("backend", ["eager", "aot_eager", "inductor"]) @@ -692,8 +700,8 @@ def grid_fn(meta): add_kernel_2d_autotuned[grid](x, y, output, x_elements, y_elements) return output - t1 = torch.rand((512, 256), device="cuda", requires_grad=grad) - t2 = torch.rand((512, 256), device="cuda", requires_grad=grad) + t1 = torch.rand((512, 256), device=GPU_TYPE, requires_grad=grad) + t2 = torch.rand((512, 256), device=GPU_TYPE, requires_grad=grad) output = torch.zeros_like(t1, requires_grad=grad) torch_result = call_triton(t1, t2, output) @@ -703,7 +711,7 @@ def grid_fn(meta): output2 = torch.zeros_like(t1, requires_grad=grad) self.assertEqual(compiled_func(t1, t2, output2), torch_result) - @requires_cuda + @requires_gpu @common_utils.parametrize("grad", [False, True]) @common_utils.parametrize("dynamic", [False, True]) @common_utils.parametrize("backend", ["eager", "aot_eager", "inductor"]) @@ -736,8 +744,8 @@ def grid_fn(meta): return output - t1 = torch.rand(5, device="cuda", requires_grad=grad) - t2 = torch.rand(5, device="cuda", requires_grad=grad) + t1 = torch.rand(5, device=GPU_TYPE, requires_grad=grad) + t2 = torch.rand(5, device=GPU_TYPE, requires_grad=grad) o1 = torch.zeros_like(t1, requires_grad=grad) torch_add = t1 + t2 @@ -765,7 +773,7 @@ def grid_fn(meta): o6 = torch.zeros_like(t1, requires_grad=grad) self.assertEqual(compiled_func(t1, t2, o6, 2, 200), torch_add) - @requires_cuda + @requires_gpu def test_triton_kernel_mutation_not_mark_dirty(self): @torch.compile def f(x): @@ -773,7 +781,7 @@ def f(x): add_kernel[(n_elements,)](x, x, x, n_elements, 16) return x - x = torch.randn(5, device="cuda", requires_grad=True) + x = torch.randn(5, device=GPU_TYPE, requires_grad=True) x_cloned = x.clone() out = x_cloned.sin() f(x_cloned) @@ -815,7 +823,7 @@ def f(x): num_bufs_reused = code.count("# reuse") self.assertEqual(num_bufs_reused, 3) - @requires_cuda + @requires_gpu def test_triton_kernel_matmul_tracking(self): @triton.jit def ones_kernel(x_ptr, n_elements, BLOCK_SIZE: "tl.constexpr"): @@ -832,12 +840,12 @@ def f(x): ones_kernel[(4,)](out, 16, BLOCK_SIZE=16) return torch.mm(out, x) + 10 - x = torch.randn(4, 4, device="cuda") + x = torch.randn(4, 4, device=GPU_TYPE) torch_out = f(x) - python_out = torch.mm(torch.ones(4, 4, device="cuda"), x) + 10 + python_out = torch.mm(torch.ones(4, 4, device=GPU_TYPE), x) + 10 self.assertEqual(torch_out, python_out) - @requires_cuda + @requires_gpu def test_triton_kernel_strided_input(self): def f(inp): # left has strides [256, 1] @@ -855,13 +863,13 @@ def f(inp): ) return out - inp = torch.randn(64, 256, device="cuda") + inp = torch.randn(64, 256, device=GPU_TYPE) eager_out = f(inp) compiled_out = torch.compile(f)(inp) self.assertEqual(compiled_out, eager_out) - @requires_cuda + @requires_gpu def test_triton_kernel_strided_input_nonzero_offset(self): def f(inp): # right has strides [256, 1] and storage offset 128 @@ -879,13 +887,13 @@ def f(inp): ) return out - inp = torch.randn(64, 256, device="cuda") + inp = torch.randn(64, 256, device=GPU_TYPE) eager_out = f(inp) compiled_out = torch.compile(f)(inp) self.assertEqual(compiled_out, eager_out) - @requires_cuda + @requires_gpu def test_triton_kernel_slice_and_view_input(self): def f(inp): # left has strides [256, 1] @@ -907,13 +915,13 @@ def f(inp): ) return out + left - inp = torch.randn(64, 256, device="cuda") + inp = torch.randn(64, 256, device=GPU_TYPE) eager_out = f(inp) compiled_out = torch.compile(f)(inp) self.assertEqual(compiled_out, eager_out) - @requires_cuda + @requires_gpu def test_triton_kernel_fallback(self): def f(x, y): out = torch.zeros_like(x) @@ -924,13 +932,13 @@ def f(x, y): add_kernel[(4,)](x, torch.sort(y).values, out, 4, 16) return out, out2 - x = torch.randn(4, 4, device="cuda") - y = torch.randn(4, 4, device="cuda") + x = torch.randn(4, 4, device=GPU_TYPE) + y = torch.randn(4, 4, device=GPU_TYPE) eager_out = f(x, y) compiled_out = torch.compile(f)(x, y) self.assertEqual(compiled_out, eager_out) - @requires_cuda + @requires_gpu def test_triton_kernel_out_of_order(self): @triton.jit def add_kernel( @@ -955,13 +963,13 @@ def f(x, y): add_kernel[(n_elements,)](x, y, 4, out, n_elements) return out - x = torch.randn(4, device="cuda") - y = torch.randn(4, device="cuda") + x = torch.randn(4, device=GPU_TYPE) + y = torch.randn(4, device=GPU_TYPE) eager_out = f(x, y) compiled_out = torch.compile(f)(x, y) self.assertEqual(compiled_out, eager_out) - @requires_cuda + @requires_gpu @torch._dynamo.config.patch(capture_dynamic_output_shape_ops=True) @torch._dynamo.config.patch(capture_scalar_outputs=True) @common_utils.parametrize("backend", ["eager", "aot_eager", "inductor"]) @@ -989,12 +997,12 @@ def f(x): square[grid](x, output, n_elements, BLOCK_SIZE=16) return output - x = torch.randn(4, device="cuda") + x = torch.randn(4, device=GPU_TYPE) eager_out = f(x) compiled_out = torch.compile(f, fullgraph=True, backend=backend)(x) self.assertEqual(compiled_out, eager_out) - @requires_cuda + @requires_gpu @common_utils.parametrize("dynamic", [False, True]) def test_triton_kernel_equal_to_1_arg(self, dynamic): @triton.jit @@ -1022,8 +1030,8 @@ def f(x, y): ) return out - x = torch.randn(2, device="cuda") - y = torch.randn(2, device="cuda") + x = torch.randn(2, device=GPU_TYPE) + y = torch.randn(2, device=GPU_TYPE) eager_out = f(x, y) compiled_out, sources = run_and_get_code( torch.compile(f, dynamic=dynamic), x, y @@ -1037,7 +1045,7 @@ def f(x, y): self.assertTrue("equal_to_1=(3,)" in sources[0]) self.assertEqual(compiled_out, eager_out) - @requires_cuda + @requires_gpu @common_utils.parametrize("dynamic", [False, True]) def test_triton_kernel_equal_to_1_float_arg(self, dynamic): def f(x, y): @@ -1054,8 +1062,8 @@ def f(x, y): ) return out - x = torch.randn(2, device="cuda") - y = torch.randn(2, device="cuda") + x = torch.randn(2, device=GPU_TYPE) + y = torch.randn(2, device=GPU_TYPE) eager_out = f(x, y) compiled_out, sources = run_and_get_code( torch.compile(f, dynamic=dynamic), x, y @@ -1066,7 +1074,7 @@ def f(x, y): self.assertTrue("equal_to_1=()" in sources[0]) self.assertEqual(compiled_out, eager_out) - @requires_cuda + @requires_gpu @skipIfRocm def test_triton_kernel_with_imported_symbol(self): @triton.jit @@ -1092,13 +1100,13 @@ def f(x): ) return out - x = torch.randn(4, device="cuda") + x = torch.randn(4, device=GPU_TYPE) eager_out = f(x) compiled_out = torch.compile(f)(x) self.assertEqual(compiled_out, eager_out) - @requires_cuda + @requires_gpu @skipIfRocm def test_triton_kernel_with_imported_symbol_with_custom_name(self): @triton.jit @@ -1124,13 +1132,13 @@ def f(x): ) return out - x = torch.randn(4, device="cuda") + x = torch.randn(4, device=GPU_TYPE) eager_out = f(x) compiled_out = torch.compile(f)(x) self.assertEqual(compiled_out, eager_out) - @requires_cuda + @requires_gpu @common_utils.parametrize("size", [4, 16]) @common_utils.parametrize("dynamic", [False, True]) def test_triton_kernel_different_shapes(self, size, dynamic): @@ -1149,10 +1157,10 @@ def f(x, y, xx, yy): return output_1, output_2 - x = torch.rand(size, device="cuda") - y = torch.rand(size, device="cuda") - xx = torch.rand(size, size, device="cuda") - yy = torch.rand(size, size, device="cuda") + x = torch.rand(size, device=GPU_TYPE) + y = torch.rand(size, device=GPU_TYPE) + xx = torch.rand(size, size, device=GPU_TYPE) + yy = torch.rand(size, size, device=GPU_TYPE) args = [x, y, xx, yy] eager_out = f(*args) @@ -1171,7 +1179,7 @@ def f(x, y, xx, yy): self.assertEqual(compiled_out, eager_out) - @requires_cuda + @requires_gpu def test_triton_kernel_reset_to_zero(self): @triton.autotune( configs=[ @@ -1206,12 +1214,12 @@ def f(x, y): add_kernel_autotuned_reset[grid](x, y, output, n_elements) return output - x = torch.randn(4, device="cuda") + x = torch.randn(4, device=GPU_TYPE) msg = "Only configs and keys are supported for triton.autotune" with self.assertRaisesRegex(torch._dynamo.exc.Unsupported, msg): f(x, x) - @requires_cuda + @requires_gpu @common_utils.parametrize("dynamic", [False, True]) @common_utils.parametrize("backend", ["eager", "aot_eager", "inductor"]) def test_triton_kernel_triton_dtype(self, dynamic, backend): @@ -1242,8 +1250,8 @@ def f(x, y, dtype_torch, dtype_triton): ) return output - x = torch.randn(4, device="cuda") - y = torch.randn(4, device="cuda") + x = torch.randn(4, device=GPU_TYPE) + y = torch.randn(4, device=GPU_TYPE) args_list = ( [x, y, torch.float32, tl.float32], [x, y, torch.bfloat16, tl.bfloat16], @@ -1255,7 +1263,7 @@ def f(x, y, dtype_torch, dtype_triton): )(*args) self.assertEqual(compiled_out, eager_out) - @requires_cuda + @requires_gpu @common_utils.parametrize("backend", ["eager", "aot_eager", "inductor"]) def test_triton_kernel_special_kwargs_with_autotune(self, backend): @triton.autotune( @@ -1297,10 +1305,10 @@ def f(x, y): ) return output - x = torch.randn(4, device="cuda") + x = torch.randn(4, device=GPU_TYPE) f(x, x) - @requires_cuda + @requires_gpu @common_utils.parametrize("backend", ["eager", "aot_eager", "inductor"]) def test_triton_kernel_num_ctas(self, backend): @triton.jit @@ -1313,10 +1321,10 @@ def f(x): kernel.run(x, num_ctas=1, grid=(1,), warmup=False) return x - x = torch.randn(4, device="cuda") + x = torch.randn(4, device=GPU_TYPE) f(x) - @requires_cuda + @requires_gpu @common_utils.parametrize("backend", ["eager", "aot_eager", "inductor"]) def test_triton_kernel_special_kwargs_without_autotune(self, backend): @triton.jit @@ -1352,12 +1360,12 @@ def f(x, y): ) return output - x = torch.randn(4, device="cuda") + x = torch.randn(4, device=GPU_TYPE) f(x, x) def make_mutation_test(fn): - @requires_cuda + @requires_gpu def test_fn(self): from torch._higher_order_ops.triton_kernel_wrap import identify_mutated_tensors @@ -1372,7 +1380,7 @@ def test_fn(self): # Triton codegen suffers from scoping issues. # Define helpers here -if HAS_CUDA: +if HAS_GPU: @triton.jit def helper_id(p): @@ -2044,7 +2052,7 @@ def fwd_kernel( ) -if HAS_CUDA: +if HAS_GPU: t = torch.randn(4) tt = torch.randn(4, 1) tests = [ diff --git a/torch/testing/_internal/common_device_type.py b/torch/testing/_internal/common_device_type.py index 5ab4901155bf7d..ec45116f938649 100644 --- a/torch/testing/_internal/common_device_type.py +++ b/torch/testing/_internal/common_device_type.py @@ -786,13 +786,13 @@ def filter_desired_device_types(device_type_test_bases, except_for=None, only_fo def get_desired_device_type_test_bases( - except_for=None, only_for=None, include_lazy=False, allow_mps=False + except_for=None, only_for=None, include_lazy=False, allow_mps=False, allow_xpu=False ): # allow callers to specifically opt tests into being tested on MPS, similar to `include_lazy` test_bases = device_type_test_bases.copy() if allow_mps and TEST_MPS and MPSTestBase not in test_bases: test_bases.append(MPSTestBase) - if only_for == "xpu" and TEST_XPU and XPUTestBase not in test_bases: + if allow_xpu and TEST_XPU and XPUTestBase not in test_bases: test_bases.append(XPUTestBase) if TEST_HPU and HPUTestBase not in test_bases: test_bases.append(HPUTestBase) @@ -853,6 +853,7 @@ def split_if_not_empty(x: str): # device-specific tests (NB: this supports additional @parametrize usage). # # See note "Writing Test Templates" +# TODO: remove "allow_xpu" option after Interl GPU support all test case instantiate by this function. def instantiate_device_type_tests( generic_test_class, scope, @@ -860,6 +861,7 @@ def instantiate_device_type_tests( only_for=None, include_lazy=False, allow_mps=False, + allow_xpu=False, ): # Removes the generic test class from its enclosing scope so its tests # are not discoverable. @@ -883,7 +885,7 @@ def instantiate_device_type_tests( # Creates device-specific test cases for base in get_desired_device_type_test_bases( - except_for, only_for, include_lazy, allow_mps + except_for, only_for, include_lazy, allow_mps, allow_xpu ): class_name = generic_test_class.__name__ + base.device_type.upper() @@ -1250,6 +1252,9 @@ def _has_sufficient_memory(device, size): if device == "xla": raise unittest.SkipTest("TODO: Memory availability checks for XLA?") + if device == "xpu": + raise unittest.SkipTest("TODO: Memory availability checks for Intel GPU?") + if device != "cpu": raise unittest.SkipTest("Unknown device type") diff --git a/torch/testing/_internal/common_utils.py b/torch/testing/_internal/common_utils.py index 2097e25bdaa892..bb5f3fa8e33084 100644 --- a/torch/testing/_internal/common_utils.py +++ b/torch/testing/_internal/common_utils.py @@ -230,7 +230,7 @@ def maybe_load_json(filename): if os.getenv("DISABLED_TESTS_FILE", ""): disabled_tests_dict = maybe_load_json(os.getenv("DISABLED_TESTS_FILE", "")) -NATIVE_DEVICES = ('cpu', 'cuda', 'meta', torch._C._get_privateuse1_backend_name()) +NATIVE_DEVICES = ('cpu', 'cuda', 'xpu', 'meta', torch._C._get_privateuse1_backend_name()) check_names = ['orin', 'concord', 'galen', 'xavier', 'nano', 'jetson', 'tegra'] IS_JETSON = any(name in platform.platform() for name in check_names) diff --git a/torch/testing/_internal/opinfo/core.py b/torch/testing/_internal/opinfo/core.py index 491e2cb98cb9e5..d00deb480f5eb8 100644 --- a/torch/testing/_internal/opinfo/core.py +++ b/torch/testing/_internal/opinfo/core.py @@ -727,6 +727,9 @@ class OpInfo: # dtypes this function is expected to work with on ROCM dtypesIfROCM: _dispatch_dtypes = None + # dtypes this function is expected to work with on XPU + dtypesIfXPU: _dispatch_dtypes = None + # backward dtypes this function is expected to work with backward_dtypes: _dispatch_dtypes = None @@ -891,7 +894,12 @@ def __post_init__(self): assert self.dtypes is not None, f"OpInfo for {self.name} has no dtypes!" - dtypes_args = (self.dtypes, self.dtypesIfCUDA, self.dtypesIfROCM) + dtypes_args = ( + self.dtypes, + self.dtypesIfCUDA, + self.dtypesIfROCM, + self.dtypesIfXPU, + ) # Validates the dtypes are generated from the dispatch-related functions for dtype_list in dtypes_args: @@ -960,6 +968,9 @@ def __post_init__(self): if self.dtypesIfROCM is not None else self.dtypesIfCUDA ) + self.dtypesIfXPU = ( + set(self.dtypesIfXPU) if self.dtypesIfXPU is not None else self.dtypesIfCUDA + ) # NOTE: if the op is unspecified it is assumed to be under the torch namespace if not self.op: @@ -1346,6 +1357,8 @@ def supported_dtypes(self, device_type): device_type = torch.device(device_type).type if device_type == "cuda": return self.dtypesIfROCM if TEST_WITH_ROCM else self.dtypesIfCUDA + if device_type == "xpu": + return self.dtypesIfXPU return self.dtypes def supported_backward_dtypes(self, device_type): @@ -2631,6 +2644,7 @@ def __init__( dtypes=floating_types(), dtypesIfCUDA=None, dtypesIfROCM=None, + dtypesIfXPU=None, sample_inputs_func=None, **kwargs, ): @@ -2639,6 +2653,7 @@ def __init__( dtypes=dtypes, dtypesIfCUDA=dtypesIfCUDA, dtypesIfROCM=dtypesIfROCM, + dtypesIfXPU=dtypesIfXPU, sample_inputs_func=sample_inputs_func, **kwargs, ) diff --git a/torch/testing/_internal/triton_utils.py b/torch/testing/_internal/triton_utils.py index 301c3cd4723e34..ab54e8b8bec99a 100644 --- a/torch/testing/_internal/triton_utils.py +++ b/torch/testing/_internal/triton_utils.py @@ -2,10 +2,11 @@ import unittest -from torch.testing._internal.inductor_utils import HAS_CUDA +from torch.testing._internal.inductor_utils import HAS_CUDA, HAS_GPU from torch.utils._triton import has_triton requires_cuda = unittest.skipUnless(HAS_CUDA, "requires cuda") +requires_gpu = unittest.skipUnless(HAS_GPU, "requires gpu") if has_triton(): import triton From e4d8aa4d2496e3e7d32abe37835c1444a3cad3b3 Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Sat, 15 Jun 2024 20:19:42 -0700 Subject: [PATCH 067/171] [torchbench] Enable some models with inline_inbuilt_nn_modules (#128315) For all models, graph breaks/recompiles reduce. For drq, it increases and this is a legit one. Co-authored-by: Laith Sakka Pull Request resolved: https://github.com/pytorch/pytorch/pull/128315 Approved by: https://github.com/jansel --- .../aot_eager_torchbench_inference.csv | 6 ++--- .../aot_eager_torchbench_training.csv | 10 +++---- ...inductor_torchbench_freezing_inference.csv | 6 ++--- .../cpu_inductor_torchbench_inference.csv | 6 ++--- .../cu124/aot_eager_torchbench_training.csv | 2 +- .../dynamic_aot_eager_torchbench_training.csv | 2 +- .../dynamic_inductor_torchbench_training.csv | 2 +- .../dynamo_eager_torchbench_training.csv | 2 +- .../cu124/inductor_torchbench_training.csv | 2 +- ...dynamic_aot_eager_torchbench_inference.csv | 6 ++--- .../dynamic_aot_eager_torchbench_training.csv | 10 +++---- ...amic_cpu_inductor_torchbench_inference.csv | 2 +- .../dynamic_inductor_torchbench_inference.csv | 6 ++--- .../dynamic_inductor_torchbench_training.csv | 10 +++---- .../dynamo_eager_torchbench_inference.csv | 6 ++--- .../dynamo_eager_torchbench_training.csv | 10 +++---- .../inductor_torchbench_inference.csv | 6 ++--- .../inductor_torchbench_training.csv | 10 +++---- benchmarks/dynamo/common.py | 27 ++++++++++++------- benchmarks/dynamo/torchbench.py | 13 +++++++++ 20 files changed, 83 insertions(+), 61 deletions(-) diff --git a/benchmarks/dynamo/ci_expected_accuracy/aot_eager_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/aot_eager_torchbench_inference.csv index 9863aa7da6a252..43e53a25120acf 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/aot_eager_torchbench_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/aot_eager_torchbench_inference.csv @@ -14,7 +14,7 @@ Background_Matting,pass_due_to_skip,0 -DALLE2_pytorch,pass,12 +DALLE2_pytorch,pass,6 @@ -150,7 +150,7 @@ hf_Bert_large,pass,0 -hf_BigBird,pass,43 +hf_BigBird,pass,13 @@ -374,7 +374,7 @@ vgg16,pass,0 -vision_maskrcnn,pass,17 +vision_maskrcnn,pass,16 diff --git a/benchmarks/dynamo/ci_expected_accuracy/aot_eager_torchbench_training.csv b/benchmarks/dynamo/ci_expected_accuracy/aot_eager_torchbench_training.csv index 4055eda462c5b4..7be1b635392a77 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/aot_eager_torchbench_training.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/aot_eager_torchbench_training.csv @@ -30,7 +30,7 @@ alexnet,pass,6 -basic_gnn_edgecnn,pass,22 +basic_gnn_edgecnn,pass,20 @@ -66,7 +66,7 @@ dlrm,pass,6 -drq,pass,6 +drq,pass,7 @@ -98,7 +98,7 @@ hf_Bert_large,pass,6 -hf_BigBird,pass,49 +hf_BigBird,pass,19 @@ -114,7 +114,7 @@ hf_GPT2_large,pass_due_to_skip,0 -hf_Reformer,pass,26 +hf_Reformer,pass,25 @@ -282,7 +282,7 @@ vgg16,pass,6 -vision_maskrcnn,pass,34 +vision_maskrcnn,pass,33 diff --git a/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_torchbench_freezing_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_torchbench_freezing_inference.csv index 73c46e578eec2b..c577edbb8aa46d 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_torchbench_freezing_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_torchbench_freezing_inference.csv @@ -86,7 +86,7 @@ detectron2_maskrcnn_r_101_c4,pass,57 -detectron2_maskrcnn_r_101_fpn,fail_accuracy,64 +detectron2_maskrcnn_r_101_fpn,fail_accuracy,63 @@ -94,7 +94,7 @@ detectron2_maskrcnn_r_50_c4,fail_accuracy,57 -detectron2_maskrcnn_r_50_fpn,pass,64 +detectron2_maskrcnn_r_50_fpn,pass,63 @@ -334,7 +334,7 @@ vgg16,pass,0 -vision_maskrcnn,pass,28 +vision_maskrcnn,pass,27 diff --git a/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_torchbench_inference.csv index f2aafef9db9fa3..85a1e3b0751ec9 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_torchbench_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_torchbench_inference.csv @@ -86,7 +86,7 @@ detectron2_maskrcnn_r_101_c4,fail_accuracy,57 -detectron2_maskrcnn_r_101_fpn,pass,64 +detectron2_maskrcnn_r_101_fpn,pass,63 @@ -94,7 +94,7 @@ detectron2_maskrcnn_r_50_c4,pass,57 -detectron2_maskrcnn_r_50_fpn,pass,64 +detectron2_maskrcnn_r_50_fpn,pass,63 @@ -334,7 +334,7 @@ vgg16,pass,0 -vision_maskrcnn,pass,28 +vision_maskrcnn,pass,27 diff --git a/benchmarks/dynamo/ci_expected_accuracy/cu124/aot_eager_torchbench_training.csv b/benchmarks/dynamo/ci_expected_accuracy/cu124/aot_eager_torchbench_training.csv index 5131c2e9ade4be..e97c18a5de8a2c 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/cu124/aot_eager_torchbench_training.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/cu124/aot_eager_torchbench_training.csv @@ -66,7 +66,7 @@ dlrm,pass,6 -drq,pass,6 +drq,pass,7 diff --git a/benchmarks/dynamo/ci_expected_accuracy/cu124/dynamic_aot_eager_torchbench_training.csv b/benchmarks/dynamo/ci_expected_accuracy/cu124/dynamic_aot_eager_torchbench_training.csv index 1e1a4be4149e86..30bf8894129add 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/cu124/dynamic_aot_eager_torchbench_training.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/cu124/dynamic_aot_eager_torchbench_training.csv @@ -66,7 +66,7 @@ dlrm,pass,6 -drq,pass,6 +drq,pass,7 diff --git a/benchmarks/dynamo/ci_expected_accuracy/cu124/dynamic_inductor_torchbench_training.csv b/benchmarks/dynamo/ci_expected_accuracy/cu124/dynamic_inductor_torchbench_training.csv index 5d45a95c8f19b2..392f1c223be4e9 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/cu124/dynamic_inductor_torchbench_training.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/cu124/dynamic_inductor_torchbench_training.csv @@ -66,7 +66,7 @@ dlrm,pass,6 -drq,pass,6 +drq,pass,7 diff --git a/benchmarks/dynamo/ci_expected_accuracy/cu124/dynamo_eager_torchbench_training.csv b/benchmarks/dynamo/ci_expected_accuracy/cu124/dynamo_eager_torchbench_training.csv index cfc52442664401..74692ec745cfc2 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/cu124/dynamo_eager_torchbench_training.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/cu124/dynamo_eager_torchbench_training.csv @@ -66,7 +66,7 @@ dlrm,pass,6 -drq,pass,6 +drq,pass,7 diff --git a/benchmarks/dynamo/ci_expected_accuracy/cu124/inductor_torchbench_training.csv b/benchmarks/dynamo/ci_expected_accuracy/cu124/inductor_torchbench_training.csv index a62389fad7dadf..75f5164f3def46 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/cu124/inductor_torchbench_training.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/cu124/inductor_torchbench_training.csv @@ -66,7 +66,7 @@ dlrm,pass,6 -drq,pass,6 +drq,pass,7 diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_torchbench_inference.csv index 3aecea06b53009..9235a3463b6c80 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_torchbench_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_torchbench_inference.csv @@ -14,7 +14,7 @@ Background_Matting,pass_due_to_skip,0 -DALLE2_pytorch,pass,12 +DALLE2_pytorch,pass,6 @@ -150,7 +150,7 @@ hf_Bert_large,pass,0 -hf_BigBird,pass,43 +hf_BigBird,pass,13 @@ -370,7 +370,7 @@ vgg16,pass,0 -vision_maskrcnn,pass,17 +vision_maskrcnn,pass,16 diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_torchbench_training.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_torchbench_training.csv index a9a7e396c0f3bb..0c29b9579aab1e 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_torchbench_training.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_torchbench_training.csv @@ -30,7 +30,7 @@ alexnet,pass,6 -basic_gnn_edgecnn,pass,22 +basic_gnn_edgecnn,pass,20 @@ -66,7 +66,7 @@ dlrm,pass,6 -drq,pass,6 +drq,pass,7 @@ -98,7 +98,7 @@ hf_Bert_large,pass,6 -hf_BigBird,pass,49 +hf_BigBird,pass,19 @@ -114,7 +114,7 @@ hf_GPT2_large,pass_due_to_skip,0 -hf_Reformer,pass,26 +hf_Reformer,pass,25 @@ -278,7 +278,7 @@ vgg16,pass,6 -vision_maskrcnn,pass,34 +vision_maskrcnn,pass,33 diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_inductor_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_inductor_torchbench_inference.csv index da37b6c9c9e809..d6487b0ce21a40 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_inductor_torchbench_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_inductor_torchbench_inference.csv @@ -294,7 +294,7 @@ vgg16,pass,0 -vision_maskrcnn,pass,28 +vision_maskrcnn,pass,27 diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_torchbench_inference.csv index c167ea680d2ca6..eb141a99c6de26 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_torchbench_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_torchbench_inference.csv @@ -14,7 +14,7 @@ Background_Matting,pass_due_to_skip,0 -DALLE2_pytorch,pass,12 +DALLE2_pytorch,pass,6 @@ -150,7 +150,7 @@ hf_Bert_large,pass,0 -hf_BigBird,fail_accuracy,43 +hf_BigBird,fail_accuracy,13 @@ -370,7 +370,7 @@ vgg16,pass,0 -vision_maskrcnn,pass,17 +vision_maskrcnn,pass,16 diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_torchbench_training.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_torchbench_training.csv index c25fa947133749..d27bae4879e7f2 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_torchbench_training.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_torchbench_training.csv @@ -30,7 +30,7 @@ alexnet,pass,6 -basic_gnn_edgecnn,pass,22 +basic_gnn_edgecnn,pass,20 @@ -66,7 +66,7 @@ dlrm,pass,6 -drq,pass,6 +drq,pass,7 @@ -98,7 +98,7 @@ hf_Bert_large,pass,6 -hf_BigBird,pass,49 +hf_BigBird,pass,19 @@ -114,7 +114,7 @@ hf_GPT2_large,pass_due_to_skip,0 -hf_Reformer,pass,26 +hf_Reformer,pass,25 @@ -278,7 +278,7 @@ vgg16,pass,6 -vision_maskrcnn,pass,34 +vision_maskrcnn,pass,33 diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_torchbench_inference.csv index 9863aa7da6a252..43e53a25120acf 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_torchbench_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_torchbench_inference.csv @@ -14,7 +14,7 @@ Background_Matting,pass_due_to_skip,0 -DALLE2_pytorch,pass,12 +DALLE2_pytorch,pass,6 @@ -150,7 +150,7 @@ hf_Bert_large,pass,0 -hf_BigBird,pass,43 +hf_BigBird,pass,13 @@ -374,7 +374,7 @@ vgg16,pass,0 -vision_maskrcnn,pass,17 +vision_maskrcnn,pass,16 diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_torchbench_training.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_torchbench_training.csv index 4055eda462c5b4..7be1b635392a77 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_torchbench_training.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_torchbench_training.csv @@ -30,7 +30,7 @@ alexnet,pass,6 -basic_gnn_edgecnn,pass,22 +basic_gnn_edgecnn,pass,20 @@ -66,7 +66,7 @@ dlrm,pass,6 -drq,pass,6 +drq,pass,7 @@ -98,7 +98,7 @@ hf_Bert_large,pass,6 -hf_BigBird,pass,49 +hf_BigBird,pass,19 @@ -114,7 +114,7 @@ hf_GPT2_large,pass_due_to_skip,0 -hf_Reformer,pass,26 +hf_Reformer,pass,25 @@ -282,7 +282,7 @@ vgg16,pass,6 -vision_maskrcnn,pass,34 +vision_maskrcnn,pass,33 diff --git a/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_inference.csv index 74549205d747a1..78a92fb526f4a3 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_inference.csv @@ -14,7 +14,7 @@ Background_Matting,pass_due_to_skip,0 -DALLE2_pytorch,pass,12 +DALLE2_pytorch,pass,6 @@ -150,7 +150,7 @@ hf_Bert_large,pass,0 -hf_BigBird,fail_accuracy,43 +hf_BigBird,fail_accuracy,13 @@ -374,7 +374,7 @@ vgg16,pass,0 -vision_maskrcnn,pass,17 +vision_maskrcnn,pass,16 diff --git a/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_training.csv b/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_training.csv index 4055eda462c5b4..7be1b635392a77 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_training.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_training.csv @@ -30,7 +30,7 @@ alexnet,pass,6 -basic_gnn_edgecnn,pass,22 +basic_gnn_edgecnn,pass,20 @@ -66,7 +66,7 @@ dlrm,pass,6 -drq,pass,6 +drq,pass,7 @@ -98,7 +98,7 @@ hf_Bert_large,pass,6 -hf_BigBird,pass,49 +hf_BigBird,pass,19 @@ -114,7 +114,7 @@ hf_GPT2_large,pass_due_to_skip,0 -hf_Reformer,pass,26 +hf_Reformer,pass,25 @@ -282,7 +282,7 @@ vgg16,pass,6 -vision_maskrcnn,pass,34 +vision_maskrcnn,pass,33 diff --git a/benchmarks/dynamo/common.py b/benchmarks/dynamo/common.py index 154651d4fbb736..3d4ff7199f71f7 100644 --- a/benchmarks/dynamo/common.py +++ b/benchmarks/dynamo/common.py @@ -2214,6 +2214,10 @@ def skip_models_due_to_control_flow(self): def guard_on_nn_module_models(self): return set() + @property + def inline_inbuilt_nn_modules_models(self): + return set() + def get_tolerance_and_cosine_flag(self, is_training, current_device, name): raise NotImplementedError @@ -4218,16 +4222,21 @@ def detect_and_mark_batch(t): if name in runner.guard_on_nn_module_models: guard_ctx = torch._dynamo.config.patch(guard_nn_modules=True) + inline_ctx = contextlib.nullcontext() + if name in runner.inline_inbuilt_nn_modules_models: + inline_ctx = torch._dynamo.config.patch(inline_inbuilt_nn_modules=True) + with guard_ctx: - runner.run_one_model( - name, - model, - example_inputs, - optimize_ctx, - experiment, - explain=args.explain, - tag=args.tag, - ) + with inline_ctx: + runner.run_one_model( + name, + model, + example_inputs, + optimize_ctx, + experiment, + explain=args.explain, + tag=args.tag, + ) if args.generate_aot_autograd_stats: stats_file = output_filename.split(".csv")[0] + "_stats.csv" output_csv( diff --git a/benchmarks/dynamo/torchbench.py b/benchmarks/dynamo/torchbench.py index d7877c5a3fac4d..61175b46191794 100755 --- a/benchmarks/dynamo/torchbench.py +++ b/benchmarks/dynamo/torchbench.py @@ -217,6 +217,19 @@ def guard_on_nn_module_models(self): "vision_maskrcnn", } + @property + def inline_inbuilt_nn_modules_models(self): + return { + "basic_gnn_edgecnn", + "drq", + "hf_Reformer", + "DALLE2_pytorch", + "hf_BigBird", + "detectron2_maskrcnn_r_50_fpn", + "detectron2_maskrcnn_r_101_fpn", + "vision_maskrcnn", + } + def load_model( self, device, From 979edbbe128cae4c9eb26c6d4b38bd773707e154 Mon Sep 17 00:00:00 2001 From: Will Feng Date: Sat, 15 Jun 2024 20:15:18 -0700 Subject: [PATCH 068/171] [Traceable FSDP2] Dynamo support FSDP2 use_training_state context manager (#127854) Improve Dynamo to support the FSDP2 `use_training_state()` context manager. Test command: ` pytest -rA test/distributed/_composable/fsdp/test_fully_shard_compile.py::TestFullyShardCompile::test_dynamo_trace_use_training_state ` Pull Request resolved: https://github.com/pytorch/pytorch/pull/127854 Approved by: https://github.com/yanboliang --- .../fsdp/test_fully_shard_compile.py | 41 ++++++++++++++ torch/_dynamo/guards.py | 3 + torch/_dynamo/variables/__init__.py | 1 + torch/_dynamo/variables/ctx_manager.py | 56 +++++++++++++++++++ torch/_dynamo/variables/functions.py | 12 ++++ torch/_dynamo/variables/torch.py | 14 +++++ .../_composable/fsdp/fully_shard.py | 2 +- torch/distributed/utils.py | 2 +- 8 files changed, 129 insertions(+), 2 deletions(-) diff --git a/test/distributed/_composable/fsdp/test_fully_shard_compile.py b/test/distributed/_composable/fsdp/test_fully_shard_compile.py index 8a87dfdd1d4de8..8834e5177d0a96 100644 --- a/test/distributed/_composable/fsdp/test_fully_shard_compile.py +++ b/test/distributed/_composable/fsdp/test_fully_shard_compile.py @@ -4,7 +4,10 @@ import unittest import torch +import torch._dynamo.testing from torch.distributed._composable.fsdp import fully_shard +from torch.distributed._composable.fsdp._fsdp_common import TrainingState +from torch.distributed._composable.fsdp._fsdp_param_group import FSDPParamGroup from torch.testing._internal.common_distributed import skip_if_lt_x_gpu from torch.testing._internal.common_fsdp import FSDPTest, MLP from torch.testing._internal.common_utils import run_tests @@ -60,5 +63,43 @@ def patched_trace_rules_check(*args, **kwargs): self.assertTrue(trace_rules_check_count > 0) +class TestFullyShardCompile(FSDPTest): + def test_dynamo_trace_use_training_state(self): + torch._dynamo.reset() + # Construct a dummy FSDPParamGroup, since we just want to test the `use_training_state` ctx manager. + param_group = FSDPParamGroup( + [], # params: List[nn.Parameter], + torch.nn.Linear(1, 1), # module: nn.Module, + None, # mesh_info: FSDPMeshInfo, + None, # post_forward_mesh_info: Optional[FSDPMeshInfo], + None, # device: torch.device, + None, # mp_policy: MixedPrecisionPolicy, + None, # offload_policy: OffloadPolicy, + ) + + def f(x): + param_group._training_state = TrainingState.IDLE + with param_group.use_training_state(TrainingState.FORWARD): + if param_group._training_state == TrainingState.FORWARD: + return x + 1 + else: + return x + + inp = torch.zeros(1) + self.assertEqual(param_group._training_state, TrainingState.IDLE) + + eager_out = f(inp) + self.assertEqual(param_group._training_state, TrainingState.IDLE) + self.assertEqual(eager_out, inp + 1) + + cnt = torch._dynamo.testing.CompileCounterWithBackend("aot_eager") + compiled_out = torch.compile(f, backend=cnt, fullgraph=True)(inp) + self.assertEqual(param_group._training_state, TrainingState.IDLE) + self.assertEqual(eager_out, compiled_out) + self.assertEqual(cnt.frame_count, 1) + self.assertEqual(cnt.op_count, 1) + self.assertEqual(len(cnt.graphs), 1) + + if __name__ == "__main__": run_tests() diff --git a/torch/_dynamo/guards.py b/torch/_dynamo/guards.py index 0c882ad16fcffe..86ee856d907c83 100644 --- a/torch/_dynamo/guards.py +++ b/torch/_dynamo/guards.py @@ -1663,6 +1663,9 @@ def DETERMINISTIC_ALGORITHMS(self, guard: Guard): def TORCH_FUNCTION_STATE(self, guard: Guard): pass # we always guard on this via GlobalStateGuard() + def FSDP_TRAINING_STATE(self, guard: Guard): + pass # we always guard on this via GlobalStateGuard() + def DEFAULT_DEVICE(self, guard: Guard): """Guard on CURRENT_DEVICE per torch.utils._device""" assert guard.source is GuardSource.GLOBAL diff --git a/torch/_dynamo/variables/__init__.py b/torch/_dynamo/variables/__init__.py index 152698568c7c67..6a762260ffe11b 100644 --- a/torch/_dynamo/variables/__init__.py +++ b/torch/_dynamo/variables/__init__.py @@ -9,6 +9,7 @@ DeterministicAlgorithmsVariable, DisabledSavedTensorsHooksVariable, DualLevelContextManager, + FSDPParamGroupUseTrainingStateVariable, GradIncrementNestingCtxManagerVariable, GradInplaceRequiresGradCtxManagerVariable, GradModeVariable, diff --git a/torch/_dynamo/variables/ctx_manager.py b/torch/_dynamo/variables/ctx_manager.py index 637636f1e04693..ecfb420d0180f3 100644 --- a/torch/_dynamo/variables/ctx_manager.py +++ b/torch/_dynamo/variables/ctx_manager.py @@ -843,6 +843,62 @@ def reconstruct(self, codegen): ) +class FSDPParamGroupUseTrainingStateVariable(ContextWrappingVariable): + _guards_singleton = Guard(GlobalStateSource(), GuardBuilder.FSDP_TRAINING_STATE) + + @staticmethod + def create(tx, param_group_var, target_value, **kwargs): + var = FSDPParamGroupUseTrainingStateVariable( + param_group_var=param_group_var, + target_values=[target_value], + initial_values=[param_group_var.value._training_state], + **kwargs, + ) + return var + + def __init__(self, param_group_var, target_values, initial_values=None, **kwargs): + super().__init__( + target_values=target_values, initial_values=initial_values, **kwargs + ) + self.param_group_var = param_group_var + install_guard(self._guards_singleton) + + def enter(self, tx): + self._call_func(tx, self.target_values) + return variables.ConstantVariable.create(None) + + def exit(self, tx, *args): + self._call_func(tx, self.initial_values) + return variables.ConstantVariable.create(None) + + def call_function( + self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]" + ): + self._call_func(tx, self.initial_values) # undo eager initialization + return super().call_function(tx, args, kwargs) + + def _call_func(self, tx, values): + assert len(values) == 1 + value = values[0] + if self.param_group_var.value._training_state != value: + self.param_group_var.call_method( + tx, + "__setattr__", + ( + variables.ConstantVariable.create("_training_state"), + variables.EnumVariable(value), + ), + {}, + ) + self.param_group_var.value._training_state = value + + def module_name(self): + return "torch.distributed._composable.fsdp._fsdp_param_group.FSDPParamGroup" + + def fn_name(self): + return "use_training_state" + + class StreamVariable(VariableTracker): def __init__(self, proxy, value, device, **kwargs): if proxy is not None and "example_value" in proxy.node.meta: diff --git a/torch/_dynamo/variables/functions.py b/torch/_dynamo/variables/functions.py index f54b73230ceba7..8ec99af2584a11 100644 --- a/torch/_dynamo/variables/functions.py +++ b/torch/_dynamo/variables/functions.py @@ -22,6 +22,11 @@ if TYPE_CHECKING: from torch._guards import Source +try: + from torch.distributed._composable.fsdp import _fsdp_param_group +except ModuleNotFoundError: + _fsdp_param_group = None + def wrap_bound_arg(tx, val, source=None): # Source propagation is best effort since not every object we encounter has a source to begin with. @@ -338,6 +343,13 @@ def call_function( return self.obj.call_method( tx, self.fn.__name__, args, kwargs, constant=self.is_constant ) + elif ( + _fsdp_param_group is not None + and self.fn is _fsdp_param_group.FSDPParamGroup.use_training_state + ): + return variables.TorchCtxManagerClassVariable(self.fn).call_function( + tx, (self.obj, *args), kwargs + ) if self.is_constant: fn = getattr(self.obj.value, self.fn.__name__) return invoke_and_store_as_constant(tx, fn, self.get_name(), args, kwargs) diff --git a/torch/_dynamo/variables/torch.py b/torch/_dynamo/variables/torch.py index 934e9a316a4bc4..2f73258248c92f 100644 --- a/torch/_dynamo/variables/torch.py +++ b/torch/_dynamo/variables/torch.py @@ -51,6 +51,11 @@ except ModuleNotFoundError: np = None # type: ignore[assignment] +try: + from torch.distributed._composable.fsdp import _fsdp_param_group +except ModuleNotFoundError: + _fsdp_param_group = None # type: ignore[assignment] + log = logging.getLogger(__name__) supported_ctx_manager_classes = dict.fromkeys( @@ -203,6 +208,7 @@ def call_function( from . import ( DisabledSavedTensorsHooksVariable, DualLevelContextManager, + FSDPParamGroupUseTrainingStateVariable, GradIncrementNestingCtxManagerVariable, GradInplaceRequiresGradCtxManagerVariable, GradModeVariable, @@ -300,6 +306,14 @@ def call_function( return DisabledSavedTensorsHooksVariable.create( tx, args[0].as_python_constant() ) + elif ( + _fsdp_param_group is not None + and self.value is _fsdp_param_group.FSDPParamGroup.use_training_state + ): + assert len(args) == 2 + return FSDPParamGroupUseTrainingStateVariable.create( + tx, args[0], args[1].as_python_constant() + ) return super().call_function(tx, args, kwargs) diff --git a/torch/distributed/_composable/fsdp/fully_shard.py b/torch/distributed/_composable/fsdp/fully_shard.py index 32c28231ce14f3..d3e70b38eac919 100644 --- a/torch/distributed/_composable/fsdp/fully_shard.py +++ b/torch/distributed/_composable/fsdp/fully_shard.py @@ -24,7 +24,7 @@ # The decorator adds a state object to `module` that can be accessed via # `fully_shard.state(module)`. The state object and module are 1:1. -@contract(state_cls=FSDPState) +@contract(state_cls=FSDPState) # type: ignore[operator] def fully_shard( module: nn.Module, *, diff --git a/torch/distributed/utils.py b/torch/distributed/utils.py index 7c135cbbacf892..f13d066415015b 100644 --- a/torch/distributed/utils.py +++ b/torch/distributed/utils.py @@ -281,7 +281,7 @@ def _to_kwargs( def _verify_param_shape_across_processes( process_group: dist.ProcessGroup, tensors: List[torch.Tensor], - logger: Optional[dist.Logger] = None, + logger: Optional["dist.Logger"] = None, ): return dist._verify_params_across_processes(process_group, tensors, logger) From f8d60e0e0a4420def0cf6b3bc0a0c0d46c93de1c Mon Sep 17 00:00:00 2001 From: leslie-fang-intel Date: Sat, 15 Jun 2024 20:56:19 -0700 Subject: [PATCH 069/171] [Inductor][CPP] Fix Half data type cse cache issue for CPP Backend (#128498) **Summary** Fixing issue: https://github.com/pytorch/pytorch/issues/128263. After https://github.com/pytorch/pytorch/issues/115260, we cached the higher precision cse variable to avoid duplicate casting between buffers. However, it failed to check the original data type. This means if we convert `int32` to `bf16` for `store` and then convert `bf16` back to `fp32` for `load`, it would incorrectly hit the cache and reuse the `int32` cse var. This PR fixes the issue. **Test Plan** ``` python -u -m pytest -s -v test/inductor/test_cpu_repro.py -k test_issue_128263 ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/128498 Approved by: https://github.com/jgong5, https://github.com/zhuhaozhe, https://github.com/jerryzh168 --- test/inductor/test_cpu_repro.py | 26 +++++++++++++++++++++- torch/_inductor/codegen/cpp.py | 39 ++++++++++++++++++--------------- 2 files changed, 46 insertions(+), 19 deletions(-) diff --git a/test/inductor/test_cpu_repro.py b/test/inductor/test_cpu_repro.py index c63e07085b7f70..e499e3a8247683 100644 --- a/test/inductor/test_cpu_repro.py +++ b/test/inductor/test_cpu_repro.py @@ -18,7 +18,7 @@ from torch._C import FileCheck from torch._dynamo.testing import rand_strided from torch._dynamo.utils import same -from torch._inductor import codecache, config, metrics +from torch._inductor import codecache, config, metrics, test_operators from torch._inductor.codegen.common import OptimizationContext from torch._inductor.codegen.cpp import ( CppOverrides, @@ -3659,6 +3659,30 @@ def fn(x): self.common(fn, (x,)) assert metrics.generated_cpp_vec_kernel_count == 1 + def test_highp_to_lowp_cse_var_cache_with_store(self): + # Fix issue: https://github.com/pytorch/pytorch/issues/128263 + input = torch.randn(5, 128, dtype=torch.float32) + input2 = torch.randint(0, 10, (5, 128), dtype=torch.int8) + input3 = torch.randn(128, 128, dtype=torch.float32) + + class Model(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x, x2, x3): + x2 = x2.to(torch.int32) + temp = test_operators.realize(x2.to(torch.float16)) + temp2 = temp.to(torch.float32) + temp2 = temp2 * x + return torch.mm(temp, x3.to(torch.float16)), temp2 + + metrics.reset() + m = Model() + self.common( + m, + (input, input2, input3), + ) + def test_reduction_float_to_int64(self): # https://github.com/pytorch/pytorch/issues/124821 def fn(x): diff --git a/torch/_inductor/codegen/cpp.py b/torch/_inductor/codegen/cpp.py index 6e2189ef6d65b8..6b8574b9268ad7 100644 --- a/torch/_inductor/codegen/cpp.py +++ b/torch/_inductor/codegen/cpp.py @@ -188,12 +188,12 @@ def is_to_lowp_dtype(expr): return any(to_expr in expr for to_expr in to_exprs) -def get_lowp_to_fp32_expr(lowp_var, kernel): +def get_lowp_to_high_prec_expr(lowp_var, dtype, kernel): if isinstance(kernel, CppVecKernel): - return f"at::vec::convert({lowp_var})" + return f"at::vec::convert<{DTYPE_TO_CPP[dtype]}>({lowp_var})" else: assert isinstance(kernel, CppKernel) - return f"c10::convert({lowp_var})" + return f"c10::convert<{DTYPE_TO_CPP[dtype]}>({lowp_var})" index_value_name_counter = 1 @@ -1614,7 +1614,7 @@ def masked(self, mask): finally: self._load_mask = prior - def cache_fp32_cse_var_before_lowp_store(self, var_to_store): + def cache_high_prec_cse_var_before_lowp_store(self, var_to_store): """ https://github.com/pytorch/pytorch/issues/115260 For FusedSchedulerNode[node1, node2], the node2 loads what node1 stores and the buffer is @@ -1651,26 +1651,29 @@ def cache_fp32_cse_var_before_lowp_store(self, var_to_store): # only need to cache fp32 cse var while var_to_store is lowp data return - def find_fp32_var(var, cache): - fp32_cse_var = None - fp32_cse_var_name = None + def find_high_prec_var(var, cache): + high_prec_cse_var = None + high_prec_cse_var_name = None for expr, cse_var in cache.items(): if cse_var == var: if is_to_lowp_dtype(expr): m = re.search(r"tmp\d+", expr) if m is not None: - fp32_cse_var_name = m.group() - if fp32_cse_var_name: + high_prec_cse_var_name = m.group() + if high_prec_cse_var_name: for cse_var in cache.values(): - if cse_var.name == fp32_cse_var_name: - fp32_cse_var = cse_var + if cse_var.name == high_prec_cse_var_name: + high_prec_cse_var = cse_var break - assert fp32_cse_var is not None - return fp32_cse_var + assert high_prec_cse_var is not None + return high_prec_cse_var - fp32_var = find_fp32_var(var_to_store, self.cse.cache) - if fp32_var: - self.cse.cache[get_lowp_to_fp32_expr(var_to_store, self)] = fp32_var + high_prec_var = find_high_prec_var(var_to_store, self.cse.cache) + if high_prec_var and high_prec_var.dtype in DTYPE_TO_CPP: + cache_key = get_lowp_to_high_prec_expr( + var_to_store, high_prec_var.dtype, self + ) + self.cse.cache[cache_key] = high_prec_var def scale_index_with_offset( self, index: sympy.Expr, scale=1, itervar_idx=-1, offset=0 @@ -1749,7 +1752,7 @@ def load(self, name: str, index: sympy.Expr): def store(self, name, index, value, mode=None): assert "buf" in name var = self.args.output(name) - self.cache_fp32_cse_var_before_lowp_store(value) + self.cache_high_prec_cse_var_before_lowp_store(value) index = self.rename_indexing(index) if mode is None: line = f"{var}[{cexpr_index(index)}] = {value};" @@ -2344,7 +2347,7 @@ def store(self, name, index, value, mode=None): value = self.broadcast(value) opt_ctx: OptimizationContext = get_current_node_opt_ctx() var = self.args.output(name) - self.cache_fp32_cse_var_before_lowp_store(value) + self.cache_high_prec_cse_var_before_lowp_store(value) index = self.rename_indexing(index) code = self._get_store_line(value, var, index, V.graph.get_dtype(name)) self.stores.splice(code.map(lambda x: DeferredLine(name, x))) From 6cbdbb6c3c240e2cfac39eeef6eb3d55ed3c1768 Mon Sep 17 00:00:00 2001 From: Mark Saroufim Date: Sun, 16 Jun 2024 16:34:12 +0000 Subject: [PATCH 070/171] Remove top lev numpy dependency from fuzzer.py (#128759) Test CI This fixes issues like this where I don't even intend to use the fuzzer. this way if someone is calling functions from the fuzzer numpy will be imported otherwise the import should not happen at the top of the file ``` >>> import torchao Traceback (most recent call last): File "", line 1, in File "/home/marksaroufim/anaconda3/envs/fresh/lib/python3.10/site-packages/torchao/__init__.py", line 26, in from torchao.quantization import ( File "/home/marksaroufim/anaconda3/envs/fresh/lib/python3.10/site-packages/torchao/quantization/__init__.py", line 7, in from .smoothquant import * # noqa: F403 File "/home/marksaroufim/anaconda3/envs/fresh/lib/python3.10/site-packages/torchao/quantization/smoothquant.py", line 18, in import torchao.quantization.quant_api as quant_api File "/home/marksaroufim/anaconda3/envs/fresh/lib/python3.10/site-packages/torchao/quantization/quant_api.py", line 23, in from torchao.utils import ( File "/home/marksaroufim/anaconda3/envs/fresh/lib/python3.10/site-packages/torchao/utils.py", line 2, in import torch.utils.benchmark as benchmark File "/home/marksaroufim/anaconda3/envs/fresh/lib/python3.10/site-packages/torch/utils/benchmark/__init__.py", line 4, in from torch.utils.benchmark.utils.fuzzer import * # noqa: F403 File "/home/marksaroufim/anaconda3/envs/fresh/lib/python3.10/site-packages/torch/utils/benchmark/utils/fuzzer.py", line 5, in import numpy as np ModuleNotFoundError: No module named 'numpy' ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/128759 Approved by: https://github.com/Skylion007 --- torch/utils/benchmark/utils/fuzzer.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/torch/utils/benchmark/utils/fuzzer.py b/torch/utils/benchmark/utils/fuzzer.py index 08206efce377e1..5f69196960c26e 100644 --- a/torch/utils/benchmark/utils/fuzzer.py +++ b/torch/utils/benchmark/utils/fuzzer.py @@ -3,7 +3,6 @@ import itertools as it from typing import Any, Callable, Dict, List, Optional, Tuple, Union -import numpy as np import torch @@ -103,6 +102,7 @@ def _check_distribution(self, distribution): return distribution def _loguniform(self, state): + import numpy as np output = int(2 ** state.uniform( low=np.log2(self._minval) if self._minval is not None else None, high=np.log2(self._maxval) if self._maxval is not None else None, @@ -119,6 +119,7 @@ def _uniform(self, state): return state.uniform(low=self._minval, high=self._maxval) def _custom_distribution(self, state): + import numpy as np # If we directly pass the keys to `choice`, numpy will convert # them to numpy dtypes. index = state.choice( @@ -266,6 +267,7 @@ def default_tensor_constructor(size, dtype, **kwargs): return torch.randint(1, 127, size=size, dtype=dtype, device="cpu") def _make_tensor(self, params, state): + import numpy as np size, steps, allocation_size = self._get_size_and_steps(params) constructor = ( self._tensor_constructor or @@ -369,6 +371,7 @@ def __init__( also be used to set the PyTorch random seed so that random ops will create reproducible Tensors. """ + import numpy as np if seed is None: seed = np.random.RandomState().randint(0, 2 ** 32 - 1, dtype=np.int64) self._seed = seed @@ -392,6 +395,7 @@ def _unpack(values, cls): )) def take(self, n): + import numpy as np state = np.random.RandomState(self._seed) torch.manual_seed(state.randint(low=0, high=2 ** 63, dtype=np.int64)) for _ in range(n): From f9dae86222aaf15ea085c7774da70781bae46ff9 Mon Sep 17 00:00:00 2001 From: cyy Date: Sun, 16 Jun 2024 23:51:14 +0000 Subject: [PATCH 071/171] Concat namespaces in torch/csrc/utils/* (#128787) Concat namespaces in torch/csrc/utils/* Pull Request resolved: https://github.com/pytorch/pytorch/pull/128787 Approved by: https://github.com/Skylion007 --- torch/csrc/utils/byte_order.cpp | 6 ++---- torch/csrc/utils/init.cpp | 6 ++---- torch/csrc/utils/nested.cpp | 6 ++---- torch/csrc/utils/out_types.cpp | 6 ++---- torch/csrc/utils/pybind.cpp | 6 ++---- torch/csrc/utils/schema_info.cpp | 6 ++---- torch/csrc/utils/structseq.cpp | 6 ++---- torch/csrc/utils/throughput_benchmark.cpp | 6 ++---- 8 files changed, 16 insertions(+), 32 deletions(-) diff --git a/torch/csrc/utils/byte_order.cpp b/torch/csrc/utils/byte_order.cpp index 6b91d0665394ee..c10dbfbb786b1d 100644 --- a/torch/csrc/utils/byte_order.cpp +++ b/torch/csrc/utils/byte_order.cpp @@ -110,8 +110,7 @@ static inline uint64_t decodeUInt64ByteSwapped(const uint8_t* data) { } // anonymous namespace -namespace torch { -namespace utils { +namespace torch::utils { THPByteOrder THP_nativeByteOrder() { uint32_t x = 1; @@ -481,5 +480,4 @@ void THP_encodeComplexDoubleBuffer( } } -} // namespace utils -} // namespace torch +} // namespace torch::utils diff --git a/torch/csrc/utils/init.cpp b/torch/csrc/utils/init.cpp index 7caa1f8ad437a9..391b331c4f10ca 100644 --- a/torch/csrc/utils/init.cpp +++ b/torch/csrc/utils/init.cpp @@ -5,8 +5,7 @@ #include #include -namespace torch { -namespace throughput_benchmark { +namespace torch::throughput_benchmark { void initThroughputBenchmarkBindings(PyObject* module) { auto m = py::handle(module).cast(); @@ -53,5 +52,4 @@ void initThroughputBenchmarkBindings(PyObject* module) { }); } -} // namespace throughput_benchmark -} // namespace torch +} // namespace torch::throughput_benchmark diff --git a/torch/csrc/utils/nested.cpp b/torch/csrc/utils/nested.cpp index 29ccf312851ea1..e34ee21b031553 100644 --- a/torch/csrc/utils/nested.cpp +++ b/torch/csrc/utils/nested.cpp @@ -9,8 +9,7 @@ #include #include -namespace torch { -namespace utils { +namespace torch::utils { // NB: device_idx here is NOT a DeviceIndex, but index into PythonArgs static c10::TensorOptions typeIdWithDefault( @@ -87,5 +86,4 @@ at::Tensor nested_tensor_ctor( return out; } -} // namespace utils -} // namespace torch +} // namespace torch::utils diff --git a/torch/csrc/utils/out_types.cpp b/torch/csrc/utils/out_types.cpp index 7e712f20871690..6dad9c91c18c96 100644 --- a/torch/csrc/utils/out_types.cpp +++ b/torch/csrc/utils/out_types.cpp @@ -1,7 +1,6 @@ #include -namespace torch { -namespace utils { +namespace torch::utils { // Used by python binding codegen to ensure any TensorOptions arguments are // consistent with the out tensor's options @@ -45,5 +44,4 @@ void check_out_type_matches( } } -} // namespace utils -} // namespace torch +} // namespace torch::utils diff --git a/torch/csrc/utils/pybind.cpp b/torch/csrc/utils/pybind.cpp index 57c10c694861e1..9433e1540be5b1 100644 --- a/torch/csrc/utils/pybind.cpp +++ b/torch/csrc/utils/pybind.cpp @@ -2,8 +2,7 @@ #include #include -namespace pybind11 { -namespace detail { +namespace pybind11::detail { bool type_caster::load(py::handle src, bool) { if (torch::is_symint(src)) { @@ -164,5 +163,4 @@ py::handle type_caster::cast( } } -} // namespace detail -} // namespace pybind11 +} // namespace pybind11::detail diff --git a/torch/csrc/utils/schema_info.cpp b/torch/csrc/utils/schema_info.cpp index 0caa5b254d279f..34ddfe5a8c5ef8 100644 --- a/torch/csrc/utils/schema_info.cpp +++ b/torch/csrc/utils/schema_info.cpp @@ -1,8 +1,7 @@ #include #include -namespace torch { -namespace utils { +namespace torch::utils { void SchemaInfo::addArgumentValue( const std::string& name, const at::IValue& value) { @@ -433,5 +432,4 @@ void SchemaInfo::generateAliasMaps() { } } -} // namespace utils -} // namespace torch +} // namespace torch::utils diff --git a/torch/csrc/utils/structseq.cpp b/torch/csrc/utils/structseq.cpp index 5533516193cce3..f23af7bf31f52b 100644 --- a/torch/csrc/utils/structseq.cpp +++ b/torch/csrc/utils/structseq.cpp @@ -18,8 +18,7 @@ #include -namespace torch { -namespace utils { +namespace torch::utils { // NOTE: The built-in repr method from PyStructSequence was updated in // https://github.com/python/cpython/commit/c70ab02df2894c34da2223fc3798c0404b41fd79 @@ -72,5 +71,4 @@ PyObject* returned_structseq_repr(PyStructSequence* obj) { return PyUnicode_FromString(ss.str().c_str()); } -} // namespace utils -} // namespace torch +} // namespace torch::utils diff --git a/torch/csrc/utils/throughput_benchmark.cpp b/torch/csrc/utils/throughput_benchmark.cpp index 7398d5f519df7c..f07ff86e98cce9 100644 --- a/torch/csrc/utils/throughput_benchmark.cpp +++ b/torch/csrc/utils/throughput_benchmark.cpp @@ -4,8 +4,7 @@ #include #include -namespace torch { -namespace throughput_benchmark { +namespace torch::throughput_benchmark { std::ostream& operator<<( std::ostream& os, @@ -137,5 +136,4 @@ ScriptModuleInput cloneInput( } // namespace detail -} // namespace throughput_benchmark -} // namespace torch +} // namespace torch::throughput_benchmark From 74e11a4210158b85356b9905bb82972b3d268435 Mon Sep 17 00:00:00 2001 From: cyy Date: Mon, 17 Jun 2024 02:19:48 +0000 Subject: [PATCH 072/171] Enable clang-tidy on torch/csrc/mps (#128782) Fixes #ISSUE_NUMBER Pull Request resolved: https://github.com/pytorch/pytorch/pull/128782 Approved by: https://github.com/Skylion007 --- .lintrunner.toml | 1 - torch/csrc/mps/Module.cpp | 10 +++------- torch/csrc/mps/Module.h | 6 ++---- 3 files changed, 5 insertions(+), 12 deletions(-) diff --git a/.lintrunner.toml b/.lintrunner.toml index d4e81d1e68a641..3cd46419798b89 100644 --- a/.lintrunner.toml +++ b/.lintrunner.toml @@ -235,7 +235,6 @@ exclude_patterns = [ 'torch/csrc/jit/serialization/import_legacy.cpp', 'torch/csrc/jit/serialization/export.cpp', 'torch/csrc/lazy/**/*', - 'torch/csrc/mps/**/*', ] init_command = [ 'python3', diff --git a/torch/csrc/mps/Module.cpp b/torch/csrc/mps/Module.cpp index 415fda6165ddeb..468d0bf2e5d2f1 100644 --- a/torch/csrc/mps/Module.cpp +++ b/torch/csrc/mps/Module.cpp @@ -11,8 +11,7 @@ #include #endif -namespace torch { -namespace mps { +namespace torch::mps { namespace { // True for children forked after mps init @@ -225,9 +224,7 @@ static PyObject* MPSModule_elapsedTimeOfEvents( END_HANDLE_TH_ERRORS } -// NOLINTNEXTLINE(modernize-avoid-c-arrays, -// cppcoreguidelines-avoid-non-const-global-variables, -// cppcoreguidelines-avoid-c-arrays) +// NOLINTNEXTLINE(*-c-arrays, *-global-variables) static struct PyMethodDef _MPSModule_methods[] = { {"_mps_deviceSynchronize", MPSModule_deviceSynchronize, @@ -281,5 +278,4 @@ PyMethodDef* python_functions() { return _MPSModule_methods; } -} // namespace mps -} // namespace torch +} // namespace torch::mps diff --git a/torch/csrc/mps/Module.h b/torch/csrc/mps/Module.h index 3759d36d738b35..1eafc637e408f0 100644 --- a/torch/csrc/mps/Module.h +++ b/torch/csrc/mps/Module.h @@ -2,10 +2,8 @@ #include -namespace torch { -namespace mps { +namespace torch::mps { PyMethodDef* python_functions(); -} // namespace mps -} // namespace torch +} // namespace torch::mps From a52c8ace98afe76dc9e2c330b415972fd1529077 Mon Sep 17 00:00:00 2001 From: "Wang, Eikan" Date: Sun, 16 Jun 2024 05:15:16 +0000 Subject: [PATCH 073/171] [3/N] Non-Tensor: Support string parameter for aten operations (#125831) Pull Request resolved: https://github.com/pytorch/pytorch/pull/125831 Approved by: https://github.com/jansel, https://github.com/jgong5 --- test/inductor/test_torchinductor.py | 37 ++++++++++++++++++ torch/_inductor/aoti_eager.py | 38 +++++++++++-------- .../inductor/aoti_eager/kernel_holder.cpp | 10 +++++ .../inductor/aoti_eager/kernel_meta_info.cpp | 10 +++++ .../inductor/aoti_eager/kernel_meta_info.h | 9 ++++- 5 files changed, 87 insertions(+), 17 deletions(-) diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index 54915479eb5c98..38b1abc6aa085b 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -823,6 +823,43 @@ def test_aoti_eager_support_out(self): self.assertEqual(ref_tensor1, res_tensor1) self.assertEqual(ref_out_tensor1, res_out_tensor1) + @skipCUDAIf(not SM80OrLater, "Requires sm80") + def test_aoti_eager_support_str(self): + ns = "aten" + op_name = "div" + dispatch_key = "CPU" + device = "cpu" + if self.device.lower() == "cuda": + dispatch_key = "CUDA" + device = "cuda" + + a = torch.randn(128, dtype=torch.float, device=device) + b = torch.randn(128, dtype=torch.float, device=device) + rounding_mode_list = ["trunc", "floor"] + with _scoped_library("aten", "IMPL") as torch_compile_op_lib_impl: + # Get ref result from eager + ref_value_list = [] + for rounding_mode in rounding_mode_list: + ref_value = getattr(torch.ops.aten, op_name)( + a, b, rounding_mode=rounding_mode + ) + ref_value_list.append(ref_value) + + register_ops_with_aoti_compile( + ns, [op_name], dispatch_key, torch_compile_op_lib_impl + ) + + # Invoke the pre-compiled kernel and get result. + res_value_list = [] + for rounding_mode in rounding_mode_list: + res_value = getattr(torch.ops.aten, op_name)( + a, b, rounding_mode=rounding_mode + ) + res_value_list.append(res_value) + + for ref_value, res_value in zip(ref_value_list, res_value_list): + self.assertEqual(ref_value, res_value) + @skipCUDAIf(not SM80OrLater, "Requires sm80") def test_aoti_eager_cache_hit(self): ns = "aten" diff --git a/torch/_inductor/aoti_eager.py b/torch/_inductor/aoti_eager.py index d77c764a00e129..fc327011d916ec 100644 --- a/torch/_inductor/aoti_eager.py +++ b/torch/_inductor/aoti_eager.py @@ -1,7 +1,7 @@ import json import os from pathlib import Path -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Tuple from unittest import mock import torch @@ -46,12 +46,14 @@ def load_aoti_eager_cache( return [] for metadata in item["meta_info"]: - assert not metadata[ - "is_dynamic" - ], "Only support static shape for now" - if metadata["device_type"] == "cpu": + if "is_dynamic" in metadata and metadata["is_dynamic"]: + raise NotImplementedError("Only support static shape for now") + if "device_type" in metadata and metadata["device_type"] == "cpu": metadata["device_index"] = -1 - metadata["dtype"] = getattr(torch, metadata["dtype"].split(".")[-1]) + if "dtype" in metadata: + metadata["dtype"] = getattr( + torch, metadata["dtype"].split(".")[-1] + ) return json_data @@ -62,8 +64,7 @@ def supported_builtin_dtype_torch_dtype() -> Dict[type, torch.dtype]: def supported_scalar_types() -> Tuple[type, ...]: type_to_torch_dtype = supported_builtin_dtype_torch_dtype() - supported_scalar_types = tuple(type_to_torch_dtype.keys()) - return supported_scalar_types + return tuple(type_to_torch_dtype.keys()) def extract_tensor_metadata(dynamic: bool, input: torch.Tensor) -> Dict[str, Any]: @@ -98,9 +99,7 @@ def extract_tensor_list_metadata( return metadata -def extract_scalar_metadata( - device_type: str, input: Union[int, float, bool] -) -> Dict[str, Any]: +def extract_scalar_metadata(device_type: str, input: Any) -> Dict[str, Any]: assert isinstance(input, supported_scalar_types()) metadata: Dict[str, Any] = {} metadata["is_dynamic"] = False @@ -113,6 +112,13 @@ def extract_scalar_metadata( return metadata +def extract_string_metadata(input: str) -> Dict[str, Any]: + assert isinstance(input, str) + metadata: Dict[str, Any] = {} + metadata["string_value"] = input + return metadata + + def aoti_compile_with_persistent_cache( ns: str, op_func_name_with_overload: str, @@ -131,11 +137,9 @@ def aoti_compile_with_persistent_cache( Compile the given function with persistent cache for AOTI eager mode. """ assert not dynamic, "Only support static shape for now" - type_to_torch_dtype = {int: torch.int32, float: torch.float, bool: torch.bool} - supported_scalar_types = tuple(type_to_torch_dtype.keys()) flattened_inputs = list(args) + list(kwargs.values()) if not all( - isinstance(input, (supported_scalar_types, torch.Tensor, list)) + isinstance(input, (supported_scalar_types(), torch.Tensor, list, str)) for input in flattened_inputs ): raise NotImplementedError( @@ -184,8 +188,12 @@ def aoti_compile_with_persistent_cache( elif isinstance(input, list): assert all(isinstance(item, torch.Tensor) for item in input) metadata = extract_tensor_list_metadata(dynamic, input) - else: + elif isinstance(input, supported_scalar_types()): metadata = extract_scalar_metadata(device_type, input) + elif isinstance(input, str): + metadata = extract_string_metadata(input) + else: + raise NotImplementedError(f"Unsupported input type: {type(input)}") metadata["arg_order"] = idx kernel_metadata_items.append(metadata) diff --git a/torch/csrc/inductor/aoti_eager/kernel_holder.cpp b/torch/csrc/inductor/aoti_eager/kernel_holder.cpp index f93b0fb2a9da19..e0ecea6c3985c3 100644 --- a/torch/csrc/inductor/aoti_eager/kernel_holder.cpp +++ b/torch/csrc/inductor/aoti_eager/kernel_holder.cpp @@ -147,6 +147,9 @@ std::vector unpack_input_parameters( } } else if (stack[idx].isTensor()) { inputs_metadata.push_back(ParameterMetadata(stack[idx].toTensor(), idx)); + } else if (stack[idx].isString()) { + inputs_metadata.push_back( + ParameterMetadata(stack[idx].toStringRef(), idx)); } else { TORCH_CHECK_NOT_IMPLEMENTED( false, @@ -309,6 +312,7 @@ void AOTIPythonKernelHolder::init_aoti_kernel_cache() { uint64_t arg_idx = metadata["arg_order"].cast(); bool is_scalar = metadata.contains("scalar_value"); bool is_tensor_list = metadata.contains("tensor_list"); + bool is_string = metadata.contains("string_value"); if (is_tensor_list) { // Tensor List @@ -332,6 +336,12 @@ void AOTIPythonKernelHolder::init_aoti_kernel_cache() { auto scalar_value = metadata["scalar_value"].cast(); parameter_metadata_list.push_back( ParameterMetadata(c10::Scalar(scalar_value), arg_idx)); + } else if (is_string) { + // String + auto metadata = item_metadata.cast(); + auto str_value = metadata["string_value"].cast(); + parameter_metadata_list.push_back( + ParameterMetadata(str_value, arg_idx)); } else { // Tensor auto metadata = item_metadata.cast(); diff --git a/torch/csrc/inductor/aoti_eager/kernel_meta_info.cpp b/torch/csrc/inductor/aoti_eager/kernel_meta_info.cpp index 95cc29b412c1fa..801ea59088a55c 100644 --- a/torch/csrc/inductor/aoti_eager/kernel_meta_info.cpp +++ b/torch/csrc/inductor/aoti_eager/kernel_meta_info.cpp @@ -151,6 +151,13 @@ ParameterMetadata::ParameterMetadata( value_ = scalar; } +ParameterMetadata::ParameterMetadata( + const std::string& str, + uint64_t input_order) + : tag_(STRING), order_(input_order) { + value_ = str; +} + bool ParameterMetadata::operator==(const ParameterMetadata& other) const { // Same type if (tag_ != other.tag_) { @@ -174,6 +181,9 @@ bool ParameterMetadata::operator==(const ParameterMetadata& other) const { std::get(other.value_).isFloatingPoint() || std::get(other.value_).isIntegral(true /*includeBool*/)); return equal_to(std::get(other.value_)); + case STRING: + return std::get(value_) == + std::get(other.value_); default: return false; } diff --git a/torch/csrc/inductor/aoti_eager/kernel_meta_info.h b/torch/csrc/inductor/aoti_eager/kernel_meta_info.h index d07814dd0ad9ca..c5a858ff8233e3 100644 --- a/torch/csrc/inductor/aoti_eager/kernel_meta_info.h +++ b/torch/csrc/inductor/aoti_eager/kernel_meta_info.h @@ -89,13 +89,17 @@ enum ParameterTag { TENSOR_LIST, TENSOR_LIST_OPTIONAL, SCALAR, + STRING, INVALID, }; // ParameterMetadataValue is to represent the value of the input parameters of a // aten operation. -using ParameterMetadataValue = - std::variant, c10::Scalar>; +using ParameterMetadataValue = std::variant< + TensorMetadata, + std::vector, + c10::Scalar, + std::string>; // ParameterMetadata is to represent the metadata of the input parameters of a // aten operation. It includes the tag of the parameter, the value of the @@ -122,6 +126,7 @@ struct ParameterMetadata { const std::vector& tensor_metadata_list, uint64_t input_order); ParameterMetadata(const c10::Scalar& scalar, uint64_t input_order); + ParameterMetadata(const std::string& string_value, uint64_t input_order); bool operator==(const ParameterMetadata& other) const; From b40a033c380e44cdbb2b3f8931d04c8eedbe8fb3 Mon Sep 17 00:00:00 2001 From: Xu Han Date: Mon, 17 Jun 2024 05:44:34 +0000 Subject: [PATCH 074/171] [cpp_extension][inductor] Fix sleef windows depends. (#128770) # Issue: During I'm working on enable inductor on PyTorch Windows, I found the sleef lib dependency issue. image # Analysis: After we enabled SIMD on PyTorch Windows(https://github.com/pytorch/pytorch/pull/118980 ), the sleef functions are called from VEC headers. It bring the sleef to the dependency. Here is a different between Windows and Linux OS. ## Linux : Linux is default export its functions, so libtorch_cpu.so static link to sleef.a, and then It also export sleef's functions. image ## Windows: Windows is by default not export its functions, and have many limitation to export functions, reference: https://github.com/pytorch/pytorch/issues/80604 We can't package sleef functions via torch_cpu.dll like Linux. # Solution: Acturally, we also packaged sleef static lib as a part of release. We just need to help user link to sleef.lib, it should be fine. 1. Add sleef to cpp_builder for inductor. 2. Add sleef to cpp_extension for C++ extesion. Pull Request resolved: https://github.com/pytorch/pytorch/pull/128770 Approved by: https://github.com/jgong5, https://github.com/jansel --- torch/_inductor/cpp_builder.py | 3 +++ torch/utils/cpp_extension.py | 3 +++ 2 files changed, 6 insertions(+) diff --git a/torch/_inductor/cpp_builder.py b/torch/_inductor/cpp_builder.py index 413270edc314de..f75f079d72db2b 100644 --- a/torch/_inductor/cpp_builder.py +++ b/torch/_inductor/cpp_builder.py @@ -572,6 +572,9 @@ def _get_torch_related_args(include_pytorch: bool, aot_mode: bool): if not aot_mode: libraries.append("torch_python") + if _IS_WINDOWS: + libraries.append("sleef") + # Unconditionally import c10 for non-abi-compatible mode to use TORCH_CHECK - See PyTorch #108690 if not config.abi_compatible: libraries.append("c10") diff --git a/torch/utils/cpp_extension.py b/torch/utils/cpp_extension.py index 1904f8c3ecae0f..bc1a9d8e6c0f8e 100644 --- a/torch/utils/cpp_extension.py +++ b/torch/utils/cpp_extension.py @@ -961,6 +961,9 @@ def CppExtension(name, sources, *args, **kwargs): libraries.append('torch') libraries.append('torch_cpu') libraries.append('torch_python') + if IS_WINDOWS: + libraries.append("sleef") + kwargs['libraries'] = libraries kwargs['language'] = 'c++' From b0282071c48860fcf8f4c1025bc207138173617b Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Sat, 15 Jun 2024 20:19:42 -0700 Subject: [PATCH 075/171] [dynamo] override torch.nn.modules.activation._is_make_fx_tracing (#128748) Discovered while inlining `MultiHeadAttention` nn Module. Pull Request resolved: https://github.com/pytorch/pytorch/pull/128748 Approved by: https://github.com/jansel ghstack dependencies: #128315 --- test/dynamo/test_repros.py | 8 ++++++++ torch/_dynamo/trace_rules.py | 1 + torch/_dynamo/variables/torch.py | 1 + 3 files changed, 10 insertions(+) diff --git a/test/dynamo/test_repros.py b/test/dynamo/test_repros.py index c53f8a49d6a64f..2329ab305e763c 100644 --- a/test/dynamo/test_repros.py +++ b/test/dynamo/test_repros.py @@ -5224,6 +5224,14 @@ def forward(self, x): opt_mod = torch.compile(mod, backend=compiler) opt_mod(torch.randn(2, 2)) + def test_is_make_fx_tracing(self): + @torch.compile(backend="eager", fullgraph=True) + def fn(x): + torch.nn.modules.activation._is_make_fx_tracing() + return torch.sin(x) + + fn(torch.rand(4)) + instantiate_parametrized_tests(ReproTests) diff --git a/torch/_dynamo/trace_rules.py b/torch/_dynamo/trace_rules.py index a217d159c5d10c..abbef02e63c682 100644 --- a/torch/_dynamo/trace_rules.py +++ b/torch/_dynamo/trace_rules.py @@ -2766,6 +2766,7 @@ "torch.nn.grad.conv2d_weight", "torch.nn.grad.conv3d_input", "torch.nn.grad.conv3d_weight", + "torch.nn.modules.activation._is_make_fx_tracing", "torch.nn.modules.utils._list_with_default", "torch.nn.modules.utils._ntuple", "torch.nn.modules.utils._quadruple", diff --git a/torch/_dynamo/variables/torch.py b/torch/_dynamo/variables/torch.py index 2f73258248c92f..74c2193646bc0b 100644 --- a/torch/_dynamo/variables/torch.py +++ b/torch/_dynamo/variables/torch.py @@ -131,6 +131,7 @@ torch._utils.is_compiling: True, torch.compiler.is_compiling: True, torch.compiler.is_dynamo_compiling: True, + torch.nn.modules.activation._is_make_fx_tracing: False, } bin_ops = dict.fromkeys(["add", "sub", "mul", "div", "sqrt"]) From 0f81473d7b4a1bf09246410712df22541be7caf3 Mon Sep 17 00:00:00 2001 From: Ambareesh Shyam Sundar Date: Mon, 17 Jun 2024 13:41:15 +0000 Subject: [PATCH 076/171] Update fake tensor error checks for bool tensor subtraction (#128492) Fixes #127003 Pull Request resolved: https://github.com/pytorch/pytorch/pull/128492 Approved by: https://github.com/soulitzer --- test/test_binary_ufuncs.py | 2 -- torch/_refs/__init__.py | 9 +++++++++ 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/test/test_binary_ufuncs.py b/test/test_binary_ufuncs.py index ffa3e53889794c..b6bba041d4f9fa 100644 --- a/test/test_binary_ufuncs.py +++ b/test/test_binary_ufuncs.py @@ -3694,8 +3694,6 @@ def test_addsub_half_tensor(self, device): actual = op(x, y, alpha=alpha) self.assertTrue(not (actual.isnan() or actual.isinf())) - # https://github.com/pytorch/pytorch/issues/127003 - @xfailIfTorchDynamo def test_sub_typing(self, device): m1 = torch.tensor( [True, False, False, True, False, False], dtype=torch.bool, device=device diff --git a/torch/_refs/__init__.py b/torch/_refs/__init__.py index be32e5b9ad289c..707e63cea0be78 100644 --- a/torch/_refs/__init__.py +++ b/torch/_refs/__init__.py @@ -1712,6 +1712,15 @@ def sub( a, b = _maybe_broadcast(a, b) + if isinstance(a, TensorLike) and isinstance(b, TensorLike): + torch._check( + not utils.is_boolean_dtype(a.dtype) and not utils.is_boolean_dtype(b.dtype), + lambda: ( + "Subtraction, the `-` operator, with two bool tensors is not supported. " + "Use the `^` or `logical_xor()` operator instead." + ), + ) + if alpha != 1: dtype = a.dtype if isinstance(a, TensorLike) else b.dtype # type: ignore[union-attr] python_type = utils.dtype_to_type(dtype) From e3093849e5530dbb93a35462d3fd248a2dd7efe0 Mon Sep 17 00:00:00 2001 From: Nikita Shulga <2453524+malfet@users.noreply.github.com> Date: Mon, 17 Jun 2024 14:55:32 +0000 Subject: [PATCH 077/171] [Docs] Update links (#128795) From https://pytorch.org/docs/stable/nn.html#torch.nn.Embedding to https://pytorch.org/docs/stable/generated/torch.nn.Embedding.html And from https://pytorch.org/docs/stable/nn.html#torch.nn.EmbeddingBag to https://pytorch.org/docs/stable/generated/torch.nn.EmbeddingBag.html Fixes https://github.com/pytorch/pytorch/issues/128774 Pull Request resolved: https://github.com/pytorch/pytorch/pull/128795 Approved by: https://github.com/atalman --- torch/ao/nn/quantized/modules/embedding_ops.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torch/ao/nn/quantized/modules/embedding_ops.py b/torch/ao/nn/quantized/modules/embedding_ops.py index 43b8d65063a430..7418aa38307c15 100644 --- a/torch/ao/nn/quantized/modules/embedding_ops.py +++ b/torch/ao/nn/quantized/modules/embedding_ops.py @@ -72,7 +72,7 @@ class Embedding(torch.nn.Module): r""" A quantized Embedding module with quantized packed weights as inputs. We adopt the same interface as `torch.nn.Embedding`, please see - https://pytorch.org/docs/stable/nn.html#torch.nn.Embedding for documentation. + https://pytorch.org/docs/stable/generated/torch.nn.Embedding.html for documentation. Similar to :class:`~torch.nn.Embedding`, attributes will be randomly initialized at module creation time and will be overwritten later @@ -196,7 +196,7 @@ class EmbeddingBag(Embedding): r""" A quantized EmbeddingBag module with quantized packed weights as inputs. We adopt the same interface as `torch.nn.EmbeddingBag`, please see - https://pytorch.org/docs/stable/nn.html#torch.nn.EmbeddingBag for documentation. + https://pytorch.org/docs/stable/generated/torch.nn.EmbeddingBag.html for documentation. Similar to :class:`~torch.nn.EmbeddingBag`, attributes will be randomly initialized at module creation time and will be overwritten later From 24443fe16ab2fba774702351b767bfc49b7b6407 Mon Sep 17 00:00:00 2001 From: Sam Larsen Date: Sat, 15 Jun 2024 11:19:41 -0700 Subject: [PATCH 078/171] [inductor] parallel compile: Print traceback detail when there's an exception in a sub-process (#128775) Summary: We lose traceback info when an exception occurs in a subprocess because Python traceback objects don't pickle. In the subprocess-based parallel compile, we _are_ logging an exception in the subprocess, but a) those messages are easy to miss because they're not in the traceback output, and b) it seems that logging in the subproc is swallowed by default in internal builds. This PR captures the traceback in the subprocess and makes it available in the exception thrown in the main process. Users now see failures that look like this: ``` ... File "/home/slarsen/.conda/envs/pytorch-3.10_3/lib/python3.10/concurrent/futures/_base.py", line 458, in result return self.__get_result() File "/home/slarsen/.conda/envs/pytorch-3.10_3/lib/python3.10/concurrent/futures/_base.py", line 403, in __get_result raise self._exception torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised: SubprocException: An exception occurred in a subprocess: Traceback (most recent call last): File "/data/users/slarsen/pytorch-3.10_3/torch/_inductor/compile_worker/subproc_pool.py", line 270, in do_job result = SubprocMain.foo() File "/data/users/slarsen/pytorch-3.10_3/torch/_inductor/compile_worker/subproc_pool.py", line 263, in foo SubprocMain.bar() File "/data/users/slarsen/pytorch-3.10_3/torch/_inductor/compile_worker/subproc_pool.py", line 260, in bar SubprocMain.baz() File "/data/users/slarsen/pytorch-3.10_3/torch/_inductor/compile_worker/subproc_pool.py", line 257, in baz raise Exception("an error occurred") Exception: an error occurred ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/128775 Approved by: https://github.com/jansel --- test/inductor/test_compile_worker.py | 7 ++-- .../_inductor/compile_worker/subproc_pool.py | 34 +++++++++++++++++-- 2 files changed, 37 insertions(+), 4 deletions(-) diff --git a/test/inductor/test_compile_worker.py b/test/inductor/test_compile_worker.py index 61ee25421847c6..527da30f32a155 100644 --- a/test/inductor/test_compile_worker.py +++ b/test/inductor/test_compile_worker.py @@ -4,8 +4,8 @@ from torch._inductor.compile_worker.subproc_pool import ( raise_testexc, + SubprocException, SubprocPool, - TestException, ) from torch._inductor.test_case import TestCase @@ -27,7 +27,10 @@ def test_exception(self): pool = SubprocPool(2) try: a = pool.submit(raise_testexc) - with self.assertRaises(TestException): + with self.assertRaisesRegex( + SubprocException, + "torch._inductor.compile_worker.subproc_pool.TestException", + ): a.result() finally: pool.shutdown() diff --git a/torch/_inductor/compile_worker/subproc_pool.py b/torch/_inductor/compile_worker/subproc_pool.py index 5bde798d88416e..ce3da87fced847 100644 --- a/torch/_inductor/compile_worker/subproc_pool.py +++ b/torch/_inductor/compile_worker/subproc_pool.py @@ -9,6 +9,7 @@ import subprocess import sys import threading +import traceback import typing from concurrent.futures import Future, ProcessPoolExecutor from concurrent.futures.process import BrokenProcessPool @@ -74,6 +75,26 @@ def _get_ld_library_path(): return path +class _SubprocExceptionInfo: + """ + Carries exception info from subprocesses across the wire. traceback + objects are not pickleable, so we store the trace as a string and + use it for the message in the exception thrown in the main process. + """ + + def __init__(self, details): + self.details = details + + +class SubprocException(Exception): + """ + Thrown when a job in a subprocess raises an Exception. + """ + + def __init__(self, details): + super().__init__(f"An exception occurred in a subprocess:\n\n{details}") + + class SubprocPool: """ Mimic a concurrent.futures.ProcessPoolExecutor, but wrap it in @@ -147,7 +168,13 @@ def _read_thread(self): with self.futures_lock: if not self.running: return - if isinstance(result, Exception): + if isinstance(result, _SubprocExceptionInfo): + # An exception occurred in the submitted job + self.pending_futures[job_id].set_exception( + SubprocException(result.details) + ) + elif isinstance(result, Exception): + # An exception occurred in some of our subprocess machinery. self.pending_futures[job_id].set_exception(result) else: self.pending_futures[job_id].set_result(result) @@ -247,7 +274,10 @@ def callback(_): def do_job(data): # do the pickle/unpickle in the sub-subproc job = pickle.loads(data) - result = job() + try: + result = job() + except Exception as e: + result = _SubprocExceptionInfo(traceback.format_exc()) return pickle.dumps(result, pickle.HIGHEST_PROTOCOL) From 2a41fc03903de63270d325bd1886a50faf32d7e4 Mon Sep 17 00:00:00 2001 From: Joel Schlosser Date: Fri, 14 Jun 2024 10:01:23 -0400 Subject: [PATCH 079/171] Short-term fix to preserve NJT metadata cache in torch.compile (#122836) Idea: close over min / max sequence length in the main NJT view func (`_nested_view_from_jagged`) so that view replay during fake-ification propagates these correctly in torch.compile. For dynamic shapes support for min / max sequence length, this PR uses a hack that stores the values in `(val, 0)` shaped tensors. **NB: This PR changes SDPA to operate on real views instead of using `buffer_from_jagged()` / `ViewNestedFromBuffer`, which may impact the internal FIRST model. That is, it undoes the partial revert from #123215 alongside a fix to the problem that required the partial revert. We need to verify that there are no regressions there before landing.** Differential Revision: [D55448636](https://our.internmc.facebook.com/intern/diff/D55448636) Pull Request resolved: https://github.com/pytorch/pytorch/pull/122836 Approved by: https://github.com/soulitzer ghstack dependencies: #127007, #128057 --- aten/src/ATen/FunctionalInverses.cpp | 9 +- aten/src/ATen/native/native_functions.yaml | 14 +- test/dynamo/test_subclasses.py | 6 +- ...asDecompTest.test_has_decomposition.expect | 2 + test/test_nestedtensor.py | 173 ++++++++++++++++- tools/autograd/derivatives.yaml | 4 +- torch/nested/_internal/nested_tensor.py | 174 ++++++++++++++---- torch/nested/_internal/ops.py | 37 +++- torch/nested/_internal/sdpa.py | 62 +++++-- 9 files changed, 412 insertions(+), 69 deletions(-) diff --git a/aten/src/ATen/FunctionalInverses.cpp b/aten/src/ATen/FunctionalInverses.cpp index 16b59333f918fb..a1cf449cde7c7f 100644 --- a/aten/src/ATen/FunctionalInverses.cpp +++ b/aten/src/ATen/FunctionalInverses.cpp @@ -303,7 +303,7 @@ Tensor FunctionalInverses::_nested_view_from_buffer_inverse(const Tensor& base, return Tensor(); } -Tensor FunctionalInverses::_nested_view_from_jagged_inverse(const Tensor& base, const Tensor& mutated_view, InverseReturnMode inverse_return_mode, const Tensor& offsets, const Tensor& dummy, const std::optional& lengths, int64_t ragged_idx) { +Tensor FunctionalInverses::_nested_view_from_jagged_inverse(const Tensor& base, const Tensor& mutated_view, InverseReturnMode inverse_return_mode, const Tensor& offsets, const Tensor& dummy, const std::optional& lengths, int64_t ragged_idx, const c10::optional& min_seqlen, const c10::optional& max_seqlen) { auto values = at::_nested_get_values(mutated_view); if (inverse_return_mode != InverseReturnMode::NeverView) { return values; @@ -317,7 +317,12 @@ Tensor FunctionalInverses::_nested_get_values_inverse(const Tensor& base, const auto lengths = at::_nested_get_lengths(base); auto ragged_idx = at::_nested_get_ragged_idx(base); auto dummy = at::_nested_get_jagged_dummy(base); - auto nt = at::_nested_view_from_jagged(mutated_view, offsets, dummy, lengths, ragged_idx); + auto min_seqlen = at::_nested_get_min_seqlen(base); + auto max_seqlen = at::_nested_get_max_seqlen(base); + auto nt = at::_nested_view_from_jagged( + mutated_view, offsets, dummy, lengths, ragged_idx, + (min_seqlen.defined() ? c10::optional(min_seqlen) : c10::nullopt), + (max_seqlen.defined() ? c10::optional(max_seqlen) : c10::nullopt)); if (inverse_return_mode != InverseReturnMode::NeverView) { return nt; diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index f0f1ad78f8ffef..0715714a4d2d71 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -6185,12 +6185,12 @@ CompositeExplicitAutogradNonFunctional: _nested_view_from_buffer_copy autogen: _nested_view_from_buffer_copy.out -- func: _nested_view_from_jagged(Tensor(a) self, Tensor offsets, Tensor dummy, Tensor? lengths=None, int ragged_idx=1) -> Tensor(a) +- func: _nested_view_from_jagged(Tensor(a) self, Tensor offsets, Tensor dummy, Tensor? lengths=None, int ragged_idx=1, Tensor? min_seqlen=None, Tensor? max_seqlen=None) -> Tensor(a) variants: function device_check: NoCheck dispatch: {} -- func: _nested_view_from_jagged_copy(Tensor self, Tensor offsets, Tensor dummy, Tensor? lengths=None, int ragged_idx=1) -> Tensor +- func: _nested_view_from_jagged_copy(Tensor self, Tensor offsets, Tensor dummy, Tensor? lengths=None, int ragged_idx=1, Tensor? min_seqlen=None, Tensor? max_seqlen=None) -> Tensor variants: function device_check: NoCheck tags: view_copy @@ -6227,6 +6227,16 @@ device_check: NoCheck dispatch: {} +- func: _nested_get_min_seqlen(Tensor self) -> Tensor + variants: function + device_check: NoCheck + dispatch: {} + +- func: _nested_get_max_seqlen(Tensor self) -> Tensor + variants: function + device_check: NoCheck + dispatch: {} + - func: _nested_get_jagged_dummy(Tensor any) -> Tensor category_override: dummy dispatch: {} diff --git a/test/dynamo/test_subclasses.py b/test/dynamo/test_subclasses.py index 302b07e4ddb78b..f16ef15990fd8c 100644 --- a/test/dynamo/test_subclasses.py +++ b/test/dynamo/test_subclasses.py @@ -1616,15 +1616,15 @@ def backend(gm, args): guard_str, """\ Eq(s3 - 1, s0) -Eq(zf1, zf4)""", +Eq(zf1, zf6)""", ) else: self.assertExpectedInline( guard_str, """\ Eq(s4 - 1, s1) -Eq(s10 - 1, s5) -Eq(s9, s7)""", +Eq(s12 - 1, s7) +Eq(s11, s9)""", ) return gm diff --git a/test/expect/HasDecompTest.test_has_decomposition.expect b/test/expect/HasDecompTest.test_has_decomposition.expect index 1179142e15d9e7..132d25a8b12f98 100644 --- a/test/expect/HasDecompTest.test_has_decomposition.expect +++ b/test/expect/HasDecompTest.test_has_decomposition.expect @@ -446,6 +446,8 @@ aten::_nested_from_padded_and_nested_example aten::_nested_from_padded_and_nested_example.out aten::_nested_get_jagged_dummy aten::_nested_get_lengths +aten::_nested_get_max_seqlen +aten::_nested_get_min_seqlen aten::_nested_get_offsets aten::_nested_get_ragged_idx aten::_nested_get_values diff --git a/test/test_nestedtensor.py b/test/test_nestedtensor.py index 78d082702aecb0..86f58b5a0de3a0 100644 --- a/test/test_nestedtensor.py +++ b/test/test_nestedtensor.py @@ -67,6 +67,21 @@ def _iter_constructors(): yield torch.nested.nested_tensor +# Returns True if the function recompiles between inputs1 and inputs2 with the +# specified dynamic setting. +def _recompiles_for_inputs(fn, inputs1, inputs2, dynamic=True): + compile_count = [0] + + def counter(gm, example_inputs): + compile_count[0] += 1 + return gm + + compiled_f = torch.compile(fn, fullgraph=True, backend=counter, dynamic=dynamic) + compiled_f(*inputs1) + compiled_f(*inputs2) + return compile_count[0] > 1 + + # Helper function to generate a pair of random nested tensors # one is contiguous, the other is not, but they appear to have same entries # an output nested tensor consists of @@ -4818,19 +4833,18 @@ def fn(values, same_size): check_results(fn, compiled_fn, generate_inp(20)) self.assertEqual(compile_counter.frame_count, frame_count_2) - # Doesn't work until we have real views - @xfailIfTorchDynamo # Note 1: Math fallback doesn't work with bfloat16 on CUDA # Note 2: ROCm doesn't support flash attention or mem_efficient attention for NT @unittest.skipIf( TEST_WITH_ROCM, "ROCm doesn't support flash attention or mem_efficient attention for NT", ) - @parametrize( - "dtype", - [torch.float16, torch.bfloat16, torch.float32] - if SM80OrLater - else [torch.float16, torch.float32], + @dtypes( + *( + [torch.float16, torch.bfloat16, torch.float32] + if SM80OrLater + else [torch.float16, torch.float32] + ) ) def test_sdpa(self, device, dtype): batch_size = 1 @@ -5173,8 +5187,6 @@ def test_sdpa_with_constant_sequence_length(self, device, dtype): ) self.assertEqual(output._values, output_dense) - # Doesn't work until we have real views - @xfailIfTorchDynamo @onlyCUDA @unittest.skipIf( not PLATFORM_SUPPORTS_FUSED_ATTENTION, @@ -5451,6 +5463,149 @@ def test_jagged_padded_dense_conversion_kernels(self, device, dtype): padded, [offsets_wrong], total_L ) + @dtypes(torch.float32) + @skipIfTorchDynamo("Test compiles internally") + @unittest.skipIf( + sys.version_info >= (3, 12), "torch.compile is not supported on python 3.12+" + ) + @unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile") + @skipCUDAIf(not SM70OrLater, "GPU capability is < SM70") + def test_compile_preserves_metadata_cache(self, device, dtype): + # shape (B, *, D) + nt = random_nt_from_dims( + [4, None, 3, 16], + device=device, + dtype=dtype, + layout=torch.jagged, + requires_grad=True, + ) + + # expect min / max seqlen to be stored here + cache = dict(nt._metadata_cache) + + @torch.compile + def f(nt): + q = nt.transpose(-3, -2) + output = F.scaled_dot_product_attention(q, q, q).transpose(-3, -2) + return output + + output = f(nt) + output.backward(torch.ones_like(output)) + self.assertEqual(output._metadata_cache, cache) + + @dtypes(torch.float32) + @skipIfTorchDynamo("Test compiles internally") + @unittest.skipIf( + sys.version_info >= (3, 12), "torch.compile is not supported on python 3.12+" + ) + @unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile") + @skipCUDAIf(not SM70OrLater, "GPU capability is < SM70") + def test_compile_with_dynamic_max_seq_len(self, device, dtype): + # shape (B, *, D) + # max seq len: 18 + nt = torch.nested.nested_tensor( + [ + torch.randn(2, 5), + torch.randn(3, 5), + torch.randn(18, 5), + ], + layout=torch.jagged, + ) + + # max seq len: 19 + nt2 = torch.nested.nested_tensor( + [ + torch.randn(2, 5), + torch.randn(3, 5), + torch.randn(19, 5), + ], + layout=torch.jagged, + ) + + def f(nt): + # TODO: Replace with public API when we can use @properties + return torch.ones_like(nt) * nt._get_max_seqlen() + + for dynamic in [False, True, None]: + self.assertFalse(_recompiles_for_inputs(f, (nt,), (nt2,), dynamic=dynamic)) + + @dtypes(torch.float32) + @skipIfTorchDynamo("Test compiles internally") + @unittest.skipIf( + sys.version_info >= (3, 12), "torch.compile is not supported on python 3.12+" + ) + @unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile") + @skipCUDAIf(not SM70OrLater, "GPU capability is < SM70") + def test_compile_with_dynamic_min_seq_len(self, device, dtype): + # shape (B, *, D) + # min seq len: 7 + nt = torch.nested.nested_tensor( + [ + torch.randn(7, 5), + torch.randn(8, 5), + torch.randn(9, 5), + ], + layout=torch.jagged, + ) + + # min seq len: 8 + nt2 = torch.nested.nested_tensor( + [ + torch.randn(8, 5), + torch.randn(9, 5), + torch.randn(10, 5), + ], + layout=torch.jagged, + ) + + def f(nt): + # TODO: Replace with public API when we can use @properties + return torch.ones_like(nt) * nt._get_min_seqlen() + + for dynamic in [False, True, None]: + self.assertFalse(_recompiles_for_inputs(f, (nt,), (nt2,), dynamic=dynamic)) + + @dtypes(torch.float32) + @skipIfTorchDynamo("Test compiles internally") + @unittest.skipIf( + sys.version_info >= (3, 12), "torch.compile is not supported on python 3.12+" + ) + @unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile") + @skipCUDAIf(not SM70OrLater, "GPU capability is < SM70") + def test_compile_with_propagated_dynamic_max_seq_len(self, device, dtype): + # shape (B, *, D) + # max seq len: 18 + nt = torch.nested.nested_tensor( + [ + torch.randn(2, 5), + torch.randn(3, 5), + torch.randn(18, 5), + ], + layout=torch.jagged, + ) + + # max seq len: 19 + nt2 = torch.nested.nested_tensor( + [ + torch.randn(2, 5), + torch.randn(3, 5), + torch.randn(19, 5), + ], + layout=torch.jagged, + ) + + def f(nt): + nt2 = nt.sin() + 1 + # TODO: Replace with public API when we can use @properties + return torch.ones_like(nt2) * nt2._get_max_seqlen() + + ref = f(nt) + output = torch.compile(f, fullgraph=True, dynamic=False)(nt) + self.assertEqual(ref, output) + + for dynamic in [False, True, None]: + self.assertFalse(_recompiles_for_inputs(f, (nt,), (nt2,), dynamic=dynamic)) + instantiate_parametrized_tests(TestNestedTensor) instantiate_device_type_tests(TestNestedTensorDeviceType, globals()) diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml index 1e9b9091a20e94..76a7a0a1e42a4f 100644 --- a/tools/autograd/derivatives.yaml +++ b/tools/autograd/derivatives.yaml @@ -2794,14 +2794,14 @@ nested_size: non_differentiable nested_strides: non_differentiable -- name: _nested_view_from_jagged(Tensor(a) self, Tensor offsets, Tensor dummy, Tensor? lengths=None, int ragged_idx=1) -> Tensor(a) +- name: _nested_view_from_jagged(Tensor(a) self, Tensor offsets, Tensor dummy, Tensor? lengths=None, int ragged_idx=1, Tensor? min_seqlen=None, Tensor? max_seqlen=None) -> Tensor(a) self: grad.values() offsets: non_differentiable lengths: non_differentiable dummy: non_differentiable - name: _nested_get_values(Tensor(a) self) -> Tensor(a) - self: _nested_view_from_jagged(grad, at::_nested_get_offsets(self), at::_nested_get_jagged_dummy(self), at::_nested_get_lengths(self), at::_nested_get_ragged_idx(self)) + self: "_nested_view_from_jagged(grad, at::_nested_get_offsets(self), at::_nested_get_jagged_dummy(self), at::_nested_get_lengths(self), at::_nested_get_ragged_idx(self), at::_nested_get_min_seqlen(self).defined() ? c10::optional(at::_nested_get_min_seqlen(self)) : c10::nullopt, at::_nested_get_max_seqlen(self).defined() ? c10::optional(at::_nested_get_max_seqlen(self)) : c10::nullopt)" # Transformers - name: _scaled_dot_product_efficient_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_bias, bool compute_log_sumexp, float dropout_p=0.0, bool is_causal=False, *, float? scale=None) -> (Tensor output, Tensor log_sumexp, Tensor philox_seed, Tensor philox_offset) diff --git a/torch/nested/_internal/nested_tensor.py b/torch/nested/_internal/nested_tensor.py index 66d25eacc7ad4b..92423cf32b2fe8 100644 --- a/torch/nested/_internal/nested_tensor.py +++ b/torch/nested/_internal/nested_tensor.py @@ -27,6 +27,15 @@ def _get_sdpa_extreme_seqlen(func, tensor): return int(func(tensor).item()) +def _store_val_in_tensor(val) -> torch.Tensor: + # hack to get dynamic shapes support: store in a (val, 0) shaped tensor + return torch.zeros(val, 0) + + +def _load_val_from_tensor(t: torch.Tensor): + return t.shape[0] + + class NestedTensor(torch.Tensor): _values: torch.Tensor # type: ignore[assignment] _offsets: torch.Tensor @@ -122,6 +131,14 @@ def __init__(self, values, offsets, *, lengths=None, **kwargs): torch._dynamo.maybe_mark_dynamic(self, self._ragged_idx) torch._dynamo.maybe_mark_dynamic(self._values, self._ragged_idx - 1) + # min / max sequence length should be dynamic if present + max_seqlen_tensor = self._metadata_cache.get("max_seqlen", None) + if max_seqlen_tensor is not None: + torch._dynamo.mark_dynamic(max_seqlen_tensor, 0) + min_seqlen_tensor = self._metadata_cache.get("min_seqlen", None) + if min_seqlen_tensor is not None: + torch._dynamo.mark_dynamic(min_seqlen_tensor, 0) + def values(self): # dispatch to get proper view relationship return torch._nested_get_values(self) # type: ignore[attr-defined] @@ -132,25 +149,56 @@ def offsets(self): def lengths(self): return self._lengths - @property - def _max_seqlen(self): - if "max_seqlen" not in self._metadata_cache: + # Private accessor functions for min / max sequence length. They're + # purposefully not @properties because those don't work with PT2 (yet). + # These compute / cache if not present. + # TODO: Revisit this when @properties are better supported by PT2. I think the ideal + # state would be to have public @properties for min / max sequence length that compile + # (including setters). + def _get_max_seqlen(self): + max_seqlen_tensor = self._max_seqlen_tensor + if max_seqlen_tensor is None: # compute & cache - self._metadata_cache["max_seqlen"] = _get_sdpa_extreme_seqlen( + max_val = _get_sdpa_extreme_seqlen( torch.max, self._offsets.diff() if self._lengths is None else self._lengths, ) - return self._metadata_cache["max_seqlen"] + max_seqlen_tensor = _store_val_in_tensor(max_val) + self._metadata_cache["max_seqlen"] = max_seqlen_tensor + return _load_val_from_tensor(max_seqlen_tensor) - @property - def _min_seqlen(self): - if "min_seqlen" not in self._metadata_cache: + def _get_min_seqlen(self): + min_seqlen_tensor = self._min_seqlen_tensor + if min_seqlen_tensor is None: # compute & cache - self._metadata_cache["min_seqlen"] = _get_sdpa_extreme_seqlen( + min_val = _get_sdpa_extreme_seqlen( torch.min, self._offsets.diff() if self._lengths is None else self._lengths, ) - return self._metadata_cache["min_seqlen"] + min_seqlen_tensor = _store_val_in_tensor(min_val) + self._metadata_cache["min_seqlen"] = min_seqlen_tensor + return _load_val_from_tensor(min_seqlen_tensor) + + # Private accessors used for treating min / max seqlen as inner tensors for + # flatten / unflatten. These must be properties to work with the traceable wrapper + # subclass logic. These do not compute / cache if not present. + @property + def _max_seqlen_tensor(self) -> Optional[torch.Tensor]: + return self._metadata_cache.get("max_seqlen", None) + + @property + def _min_seqlen_tensor(self) -> Optional[torch.Tensor]: + return self._metadata_cache.get("min_seqlen", None) + + # These are old private @property accessors that are kept around for internal BC + # reasons. TODO: Remove these! + @property + def _max_seqlen(self): + return self._get_max_seqlen() + + @property + def _min_seqlen(self): + return self._get_min_seqlen() def __repr__(self): # We should implement this in torch/_tensor_str.py instead @@ -170,6 +218,7 @@ def __reduce_ex__(self, proto): del state["_size"] del state["_strides"] + # TODO: Update this to handle the other inner tensors func = NestedTensor args = (self._values, self._offsets) return (torch._tensor._rebuild_from_type_v2, (func, type(self), args, state)) @@ -177,22 +226,33 @@ def __reduce_ex__(self, proto): def __tensor_flatten__(self): ctx = { "requires_grad": self.requires_grad, - # TODO: Don't guard on this! - "metadata_cache": self._metadata_cache, "ragged_idx": self._ragged_idx, } inner_tensors = ["_values", "_offsets"] if self._lengths is not None: inner_tensors.append("_lengths") + if self._min_seqlen_tensor is not None: + inner_tensors.append("_min_seqlen_tensor") + if self._max_seqlen_tensor is not None: + inner_tensors.append("_max_seqlen_tensor") return inner_tensors, ctx @staticmethod def __tensor_unflatten__(inner_tensors: Dict, meta, outer_size, outer_stride): - # inner tensors: _values, _offsets, [_lengths] - assert len(inner_tensors) >= 2 and len(inner_tensors) <= 3 + # inner tensors: _values, _offsets, [_lengths], [_min_seqlen], [_max_seqlen] + assert len(inner_tensors) >= 2 and len(inner_tensors) <= 5 values = inner_tensors["_values"] offsets = inner_tensors["_offsets"] lengths = inner_tensors.get("_lengths", None) + min_seqlen_tensor = inner_tensors.get("_min_seqlen_tensor", None) + max_seqlen_tensor = inner_tensors.get("_max_seqlen_tensor", None) + + metadata_cache = {} + if min_seqlen_tensor is not None: + metadata_cache["min_seqlen"] = min_seqlen_tensor + if max_seqlen_tensor is not None: + metadata_cache["max_seqlen"] = max_seqlen_tensor + ragged_idx = meta["ragged_idx"] # Note that we cannot simply check if is_fake(values) because @@ -211,7 +271,7 @@ def __tensor_unflatten__(inner_tensors: Dict, meta, outer_size, outer_stride): lengths=lengths, requires_grad=meta["requires_grad"], _ragged_idx=ragged_idx, - _metadata_cache=meta["metadata_cache"], + _metadata_cache=metadata_cache, ) @classmethod @@ -276,6 +336,15 @@ def forward( offsets: torch.Tensor, metadata_cache: Optional[Dict[str, Any]] = None, ): # type: ignore[override] + # maintain BC with this usages of this where the seqlens are stuffed + # directly into the metadata cache as non-Tensors / ints + if metadata_cache is not None: + min_seqlen = metadata_cache.get("min_seqlen", None) + max_seqlen = metadata_cache.get("max_seqlen", None) + if min_seqlen is not None and not isinstance(min_seqlen, torch.Tensor): + metadata_cache["min_seqlen"] = _store_val_in_tensor(min_seqlen) + if max_seqlen is not None and not isinstance(max_seqlen, torch.Tensor): + metadata_cache["max_seqlen"] = _store_val_in_tensor(max_seqlen) return NestedTensor( values.detach(), offsets=offsets, @@ -343,12 +412,12 @@ def jagged_from_list( ] ) - ret_nt = nested_view_from_values_offsets(values, offsets) - ret_nt._metadata_cache = { - # compute this now since it's easy - "max_seqlen": max(t.shape[0] for t in tensors), - "min_seqlen": min(t.shape[0] for t in tensors), - } + # compute this now since it's easy + min_seqlen = min([t.shape[0] for t in tensors]) + max_seqlen = max([t.shape[0] for t in tensors]) + ret_nt = nested_view_from_values_offsets( + values, offsets, min_seqlen=min_seqlen, max_seqlen=max_seqlen + ) return (ret_nt, offsets) # type: ignore[return-value] @@ -405,16 +474,19 @@ def jagged_from_tensor_and_lengths( if is_contiguous: ret_nt = nested_view_from_values_offsets( - values[offsets[0] : offsets[-1]], offsets - offsets[0] + values[offsets[0] : offsets[-1]], + offsets - offsets[0], + min_seqlen=min_seqlen, + max_seqlen=actual_max_seqlen, ) else: - ret_nt = nested_view_from_values_offsets_lengths(values, offsets, length_list) - - # populate metadata cache with computed seqlen extremes - ret_nt._metadata_cache = { - "max_seqlen": actual_max_seqlen, - "min_seqlen": min_seqlen, - } + ret_nt = nested_view_from_values_offsets_lengths( + values, + offsets, + length_list, + min_seqlen=min_seqlen, + max_seqlen=actual_max_seqlen, + ) return (ret_nt, offsets, None if is_contiguous else length_list) @@ -436,13 +508,45 @@ def _nt_view_dummy() -> torch.Tensor: return _dummy_instance -def nested_view_from_values_offsets(values, offsets, ragged_idx=1): +def nested_view_from_values_offsets( + values, offsets, ragged_idx=1, min_seqlen=None, max_seqlen=None +): + min_seqlen_tensor = None + if min_seqlen is not None: + min_seqlen_tensor = _store_val_in_tensor(min_seqlen) + + max_seqlen_tensor = None + if max_seqlen is not None: + max_seqlen_tensor = _store_val_in_tensor(max_seqlen) + return torch._nested_view_from_jagged( # type: ignore[attr-defined] - values, offsets, _nt_view_dummy(), None, ragged_idx - ) + values, + offsets, + _nt_view_dummy(), + None, + ragged_idx, + min_seqlen_tensor, + max_seqlen_tensor, + ) # type: ignore[return-value] + +def nested_view_from_values_offsets_lengths( + values, offsets, lengths, ragged_idx=1, min_seqlen=None, max_seqlen=None +): + min_seqlen_tensor = None + if min_seqlen is not None: + min_seqlen_tensor = _store_val_in_tensor(min_seqlen) + + max_seqlen_tensor = None + if max_seqlen is not None: + max_seqlen_tensor = _store_val_in_tensor(max_seqlen) -def nested_view_from_values_offsets_lengths(values, offsets, lengths, ragged_idx=1): return torch._nested_view_from_jagged( # type: ignore[attr-defined] - values, offsets, _nt_view_dummy(), lengths, ragged_idx - ) + values, + offsets, + _nt_view_dummy(), + lengths, + ragged_idx, + min_seqlen_tensor, + max_seqlen_tensor, + ) # type: ignore[return-value] diff --git a/torch/nested/_internal/ops.py b/torch/nested/_internal/ops.py index 6ec3ba538f9772..6f1c47dd694712 100644 --- a/torch/nested/_internal/ops.py +++ b/torch/nested/_internal/ops.py @@ -1088,7 +1088,7 @@ def values_default(func, *args, **kwargs): @register_jagged_func( torch.ops.aten._nested_view_from_jagged.default, - "values: t, offsets: t, dummy: jt_all, lengths: t?, ragged_idx: any?", + "values: t, offsets: t, dummy: jt_all, lengths: t?, ragged_idx: any?, min_seqlen: t?, max_seqlen: t?", ) def _nested_view_from_jagged_default(func, *args, **kwargs): _, new_kwargs = normalize_function( @@ -1101,8 +1101,21 @@ def _nested_view_from_jagged_default(func, *args, **kwargs): new_kwargs["lengths"], ) ragged_idx = new_kwargs["ragged_idx"] + min_seqlen = new_kwargs["min_seqlen"] + max_seqlen = new_kwargs["max_seqlen"] + metadata_cache = {} + if min_seqlen is not None: + metadata_cache["min_seqlen"] = min_seqlen + if max_seqlen is not None: + metadata_cache["max_seqlen"] = max_seqlen - return NestedTensor(values, offsets, lengths=lengths, _ragged_idx=ragged_idx) + return NestedTensor( + values, + offsets, + lengths=lengths, + _ragged_idx=ragged_idx, + _metadata_cache=metadata_cache, + ) @register_jagged_func(torch.ops.aten._nested_get_offsets.default, "self: jt_all") @@ -1135,6 +1148,26 @@ def _nested_get_ragged_idx(func, *args, **kwargs): return inp._ragged_idx +@register_jagged_func(torch.ops.aten._nested_get_min_seqlen.default, "self: jt_all") +def _nested_get_min_seqlen(func, *args, **kwargs): + _, new_kwargs = normalize_function( + func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True + ) + + inp = new_kwargs.pop("input") + return inp._metadata_cache.get("min_seqlen", None) + + +@register_jagged_func(torch.ops.aten._nested_get_max_seqlen.default, "self: jt_all") +def _nested_get_max_seqlen(func, *args, **kwargs): + _, new_kwargs = normalize_function( + func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True + ) + + inp = new_kwargs.pop("input") + return inp._metadata_cache.get("max_seqlen", None) + + # Make the dummy available on the C++ side. @register_jagged_func(torch.ops.aten._nested_get_jagged_dummy.default, "self: any") def _nested_get_jagged_dummy(func, *args, **kwargs): diff --git a/torch/nested/_internal/sdpa.py b/torch/nested/_internal/sdpa.py index b7c69c905e9a86..8f2eba4db3e463 100644 --- a/torch/nested/_internal/sdpa.py +++ b/torch/nested/_internal/sdpa.py @@ -15,7 +15,7 @@ ) from torch.nn.attention import SDPBackend -from .nested_tensor import buffer_from_jagged, NestedTensor, ViewNestedFromBuffer +from .nested_tensor import NestedTensor log = logging.getLogger(__name__) @@ -125,7 +125,7 @@ def _check_for_seq_len_0_and_consistent_head_dim_nested_helper( return False # This is being called inside sdp with shape [batch, heads, {seq_len}, dim] - if param._min_seqlen == 0: + if param._get_min_seqlen() == 0: if debug: log.warning( "Fused kernels do not support seq_len == 0, %s has a seq len of 0.", @@ -315,7 +315,7 @@ def _cumulative_and_max_seq_len_nnz(qkv: torch.Tensor) -> Tuple[torch.Tensor, in if qkv.lengths() is None: # TODO: Explore performance impact of copying cumulative_seqlen = qkv.offsets().to(dtype=torch.int32, device=qkv.device) - max_seqlen = qkv._max_seqlen + max_seqlen = qkv._get_max_seqlen() n_elem = qkv.values().shape[0] else: # TODO: Explore performance impact of copying @@ -323,7 +323,7 @@ def _cumulative_and_max_seq_len_nnz(qkv: torch.Tensor) -> Tuple[torch.Tensor, in qkv.lengths().cumsum(0).to(dtype=torch.int32, device=qkv.device) ) batch_size = qkv.size(0) - max_seqlen = qkv._max_seqlen + max_seqlen = qkv._get_max_seqlen() # TODO: Explore performance impact when compiling n_elem = int(cumulative_seqlen[-1].item()) return cumulative_seqlen, max_seqlen, n_elem @@ -364,7 +364,7 @@ def _view_as_dense( tensor: torch.Tensor, Nnz: int, num_heads: int, head_dim: int ) -> torch.Tensor: if tensor.is_nested: - return buffer_from_jagged(tensor) + return tensor.values() return tensor.view(Nnz, num_heads, head_dim) @@ -567,8 +567,8 @@ def _sdpa_nested_preprocessing(query, key, value): output_nt_info = { "offsets": q_t.offsets(), - "_max_seqlen": q_t._max_seqlen, - "_min_seqlen": q_t._min_seqlen, + "_max_seqlen": q_t._get_max_seqlen(), + "_min_seqlen": q_t._get_min_seqlen(), } return ( @@ -694,9 +694,14 @@ def jagged_scaled_dot_product_attention( False, scale=og_scale, ) + from torch.nested._internal.nested_tensor import nested_view_from_values_offsets + # Reshape output to convert nnz to batch_size and seq_len - attention = ViewNestedFromBuffer.apply( - attention.squeeze(0), output_nt_info["offsets"] + attention = nested_view_from_values_offsets( + attention.squeeze(0), + output_nt_info["offsets"], + min_seqlen=output_nt_info["_min_seqlen"], + max_seqlen=output_nt_info["_max_seqlen"], ).transpose(1, 2) return _post_process_flash_output(attention, og_size) elif backend_choice == SDPBackend.EFFICIENT_ATTENTION: @@ -732,9 +737,14 @@ def jagged_scaled_dot_product_attention( scale=scale, ) + from torch.nested._internal.nested_tensor import nested_view_from_values_offsets + # Reshape output to convert nnz to batch_size and seq_len - return ViewNestedFromBuffer.apply( - attention.squeeze(0), output_nt_info["offsets"] + return nested_view_from_values_offsets( + attention.squeeze(0), + output_nt_info["offsets"], + min_seqlen=output_nt_info["_min_seqlen"], + max_seqlen=output_nt_info["_max_seqlen"], ).transpose(1, 2) elif backend_choice == SDPBackend.MATH: # save the offsets and shape of the inputs, so we can reshape the final output @@ -744,12 +754,19 @@ def jagged_scaled_dot_product_attention( d1 = query._size[1] d2 = value._size[-1] + min_seqlen_tensor = query._metadata_cache.get( + "min_seqlen", None + ) # type: ignore[attr-defined] + max_seqlen_tensor = query._metadata_cache.get( + "max_seqlen", None + ) # type: ignore[attr-defined] + # convert jagged layout Nested Tensor to strided layout Nested Tensor # which support the math implementation of SDPA def get_strided_layout_nested_tensor(jagged_layout_nt): lengths = jagged_layout_nt._offsets[1:] - jagged_layout_nt._offsets[:-1] transpose = torch.transpose(jagged_layout_nt, 1, 2) - tensor_list = buffer_from_jagged(transpose).split(list(lengths), dim=0) + tensor_list = transpose.values().split(list(lengths), dim=0) strided_nt = torch.nested.as_nested_tensor(list(tensor_list)) strided_nt = strided_nt.transpose(1, 2).contiguous() return strided_nt @@ -762,11 +779,28 @@ def get_strided_layout_nested_tensor(jagged_layout_nt): query, key, value, attn_mask, dropout_p, is_causal, scale=scale )[0] + from torch.nested._internal.nested_tensor import ( + _load_val_from_tensor, + nested_view_from_values_offsets, + ) + # convert strided layout Nested Tensor back to jagged layout Nested Tensor attn_out = attn_out.transpose(1, 2).contiguous().values() attn_out = attn_out.view(-1, d1, d2) - attn_out = ViewNestedFromBuffer.apply(attn_out, offsets) - attn_out = attn_out.transpose(1, 2) + attn_out = nested_view_from_values_offsets( + attn_out, + offsets, + min_seqlen=( + None + if min_seqlen_tensor is None + else _load_val_from_tensor(min_seqlen_tensor) + ), + max_seqlen=( + None + if max_seqlen_tensor is None + else _load_val_from_tensor(max_seqlen_tensor) + ), + ).transpose(1, 2) return attn_out else: From bfad0aee446b710c70fddc31fca34c8d4dda1bec Mon Sep 17 00:00:00 2001 From: Zhengxu Chen Date: Mon, 17 Jun 2024 16:26:08 +0000 Subject: [PATCH 080/171] [export] Preserve requires_grad for export inputs. (#128656) Summary: Today meta['val'] on placeholder nodes doesn't preserve the consistent requires_grad information with the original inputs. Seems there's no easy way to fix this directly at proxy tensor layer. This is useful for reexporting joint graph. Test Plan: test_preserve_requires_grad_placeholders Differential Revision: D58555651 Pull Request resolved: https://github.com/pytorch/pytorch/pull/128656 Approved by: https://github.com/tugsbayasgalan --- test/export/test_export.py | 20 ++++++++++++++++ test/export/test_serialize.py | 1 - torch/_export/__init__.py | 3 --- torch/export/_trace.py | 45 ++++++++++++++++++++++++++++++++--- 4 files changed, 62 insertions(+), 7 deletions(-) diff --git a/test/export/test_export.py b/test/export/test_export.py index 528079bcfe1421..aff3cc1c96dc34 100644 --- a/test/export/test_export.py +++ b/test/export/test_export.py @@ -5382,6 +5382,26 @@ def forward(self, w, x, y, z): _disable_forced_specializations=True, ) + # TODO requires_grad doesn't seem to work with serialization. + @testing.expectedFailureSerDer + def test_preserve_requires_grad_placeholders(self): + class Module(torch.nn.Module): + def __init__(self): + super().__init__() + self.p = torch.nn.Parameter(torch.randn(3, 3)) + + def forward(self, x, y): + return self.p + x + y + + m = Module() + ep = export(m, (torch.randn(3, 3), torch.randn(3, 3, requires_grad=True))) + placeholders = [ + node for node in ep.graph_module.graph.nodes if node.op == "placeholder" + ] + self.assertTrue(placeholders[0].meta["val"].requires_grad) + self.assertFalse(placeholders[1].meta["val"].requires_grad) + self.assertTrue(placeholders[2].meta["val"].requires_grad) + def test_reshape_view_helper(self): # see: https://github.com/pytorch/pytorch/issues/126607 class Model(torch.nn.Module): diff --git a/test/export/test_serialize.py b/test/export/test_serialize.py index 012b35c910b5aa..1e0a9edf238727 100644 --- a/test/export/test_serialize.py +++ b/test/export/test_serialize.py @@ -253,7 +253,6 @@ def forward(self, x): return torch.split(x, 2) input = torch.arange(10.0).reshape(5, 2) - input.requires_grad = True exported_module = export(MyModule(), (input,)).run_decompositions() serialized = ExportedProgramSerializer().serialize(exported_module) diff --git a/torch/_export/__init__.py b/torch/_export/__init__.py index d9a514232569d3..d257662a133246 100644 --- a/torch/_export/__init__.py +++ b/torch/_export/__init__.py @@ -73,9 +73,6 @@ from torch.fx.graph import _PyTreeCodeGen, _PyTreeInfo from torch.utils._sympy.value_ranges import ValueRangeError, ValueRanges -from .passes.add_runtime_assertions_for_constraints_pass import ( - _AddRuntimeAssertionsForInlineConstraintsPass, -) from .wrappers import _wrap_submodules log = logging.getLogger(__name__) diff --git a/torch/export/_trace.py b/torch/export/_trace.py index d5f0851c87694e..8ec992f1058424 100644 --- a/torch/export/_trace.py +++ b/torch/export/_trace.py @@ -344,6 +344,41 @@ def _get_param_buffer_mapping( return param_buffer_table +def _preserve_requires_grad_pass( + gm: torch.fx.GraphModule, + sig: ExportGraphSignature, + fake_params_buffers: Dict[str, torch.Tensor], + constants: Dict[str, Union[torch.Tensor, FakeScriptObject, torch.ScriptObject]], + flat_fake_args: List[Any], +): + placeholders = [node for node in gm.graph.nodes if node.op == "placeholder"] + assert len(sig.input_specs) == len(placeholders) + i = 0 + for node, spec in zip(placeholders, sig.input_specs): + if spec.kind in ( + InputKind.PARAMETER, + InputKind.BUFFER, + ): + assert spec.target is not None + node.meta["val"].requires_grad = fake_params_buffers[ + spec.target + ].requires_grad + elif spec.kind == InputKind.USER_INPUT: + fake_arg = flat_fake_args[i] + if isinstance(fake_arg, torch.Tensor): + node.meta["val"].requires_grad = fake_arg.requires_grad + i += 1 + elif spec.kind == InputKind.CONSTANT_TENSOR: + assert spec.target is not None + constant = constants[spec.target] + if isinstance(constant, torch.Tensor): + node.meta["val"].requires_grad = constant.requires_grad + elif spec.kind in (InputKind.CUSTOM_OBJ, InputKind.TOKEN): + continue + else: + raise AssertionError(spec.kind) + + def _remap_constants( orig_constant_attrs: ConstantAttrMap, graph_signature: ExportGraphSignature, @@ -632,7 +667,7 @@ def make_argument_spec(i, node) -> ArgumentSpec: # NOTE: aot_export adds symint metadata for placeholders with int values; # since these become specialized, we replace such metadata with the original values - flat_args = pytree.tree_leaves((fake_args, fake_kwargs)) + flat_fake_args = pytree.tree_leaves((fake_args, fake_kwargs)) index = 0 total_non_user_inputs = ( len(graph_signature.parameters) @@ -642,7 +677,7 @@ def make_argument_spec(i, node) -> ArgumentSpec: for node in gm.graph.nodes: if node.op == "placeholder": if index >= total_non_user_inputs: - user_arg = flat_args[index - total_non_user_inputs] + user_arg = flat_fake_args[index - total_non_user_inputs] if not isinstance(user_arg, torch.Tensor): node.meta["val"] = user_arg index += 1 @@ -677,7 +712,7 @@ def make_argument_spec(i, node) -> ArgumentSpec: from torch._guards import detect_fake_mode - fake_mode = detect_fake_mode(flat_args) + fake_mode = detect_fake_mode(flat_fake_args) from torch._dynamo import config as _dynamo_config @@ -726,6 +761,10 @@ def make_argument_spec(i, node) -> ArgumentSpec: constants, ) + _preserve_requires_grad_pass( + gm, export_graph_signature, fake_params_buffers, constants, flat_fake_args + ) + return ATenExportArtifact( gm, export_graph_signature, From dff6342a0b6c70f343fdc894928d10c73dd05ae5 Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Mon, 17 Jun 2024 03:58:18 +0800 Subject: [PATCH 081/171] [BE][Easy] enable UFMT for `torch/nn/parallel` (#128596) Part of #123062 - #123062 Pull Request resolved: https://github.com/pytorch/pytorch/pull/128596 Approved by: https://github.com/mikaylagawarecki --- .lintrunner.toml | 9 ----- torch/nn/attention/__init__.py | 4 +- torch/nn/attention/_utils.py | 1 + torch/nn/attention/bias.py | 7 ++-- torch/nn/backends/thnn.py | 1 + torch/nn/parallel/__init__.py | 20 +++++++--- torch/nn/parallel/_functions.py | 60 +++++++++++++++++------------ torch/nn/parallel/comm.py | 48 ++++++++++++++++------- torch/nn/parallel/data_parallel.py | 54 +++++++++++++++++--------- torch/nn/parallel/distributed.py | 12 +++--- torch/nn/parallel/parallel_apply.py | 48 ++++++++++++++++------- torch/nn/parallel/replicate.py | 51 ++++++++++++++++++------ torch/nn/parallel/scatter_gather.py | 36 ++++++++++++----- 13 files changed, 231 insertions(+), 120 deletions(-) diff --git a/.lintrunner.toml b/.lintrunner.toml index 3cd46419798b89..5f5b91424757e4 100644 --- a/.lintrunner.toml +++ b/.lintrunner.toml @@ -1645,8 +1645,6 @@ exclude_patterns = [ 'torch/nested/__init__.py', 'torch/nn/__init__.py', 'torch/nn/_reduction.py', - 'torch/nn/backends/__init__.py', - 'torch/nn/backends/thnn.py', 'torch/nn/common_types.py', 'torch/nn/cpp.py', 'torch/nn/functional.py', @@ -1694,13 +1692,6 @@ exclude_patterns = [ 'torch/nn/modules/transformer.py', 'torch/nn/modules/upsampling.py', 'torch/nn/modules/utils.py', - 'torch/nn/parallel/__init__.py', - 'torch/nn/parallel/_functions.py', - 'torch/nn/parallel/comm.py', - 'torch/nn/parallel/data_parallel.py', - 'torch/nn/parallel/parallel_apply.py', - 'torch/nn/parallel/replicate.py', - 'torch/nn/parallel/scatter_gather.py', 'torch/nn/parameter.py', 'torch/nn/qat/__init__.py', 'torch/nn/qat/dynamic/__init__.py', diff --git a/torch/nn/attention/__init__.py b/torch/nn/attention/__init__.py index 6bf1ffb68e6960..48e49c884398f6 100644 --- a/torch/nn/attention/__init__.py +++ b/torch/nn/attention/__init__.py @@ -4,6 +4,7 @@ from typing import List, Union from warnings import warn +from torch._C import _SDPBackend as SDPBackend from torch.backends.cuda import ( can_use_efficient_attention, can_use_flash_attention, @@ -18,6 +19,7 @@ SDPAParams, ) + __all__: List[str] = ["SDPBackend", "sdpa_kernel", "WARN_FOR_UNFUSED_KERNELS"] # Note: [SDPA warnings] @@ -30,8 +32,6 @@ WARN_FOR_UNFUSED_KERNELS = False -from torch._C import _SDPBackend as SDPBackend - # Hacks for Sphinx documentation: # https://stackoverflow.com/questions/38765577/overriding-sphinx-autodoc-alias-of-for-import-of-private-class SDPBackend = SDPBackend diff --git a/torch/nn/attention/_utils.py b/torch/nn/attention/_utils.py index 9785f74c668362..25b5a0774173a1 100644 --- a/torch/nn/attention/_utils.py +++ b/torch/nn/attention/_utils.py @@ -5,6 +5,7 @@ import torch + __all__: List[str] = [] diff --git a/torch/nn/attention/bias.py b/torch/nn/attention/bias.py index 773ed38f82e867..8181234b5e938a 100644 --- a/torch/nn/attention/bias.py +++ b/torch/nn/attention/bias.py @@ -5,6 +5,7 @@ from warnings import warn import torch +import torch.nn.functional as F from torch.backends.cuda import ( can_use_efficient_attention, can_use_flash_attention, @@ -17,7 +18,7 @@ _postprocess_flash_output, _validate_sdpa_input, ) -from torch.nn.functional import scaled_dot_product_attention + __all__ = ["causal_upper_left", "causal_lower_right", "CausalVariant", "CausalBias"] @@ -203,7 +204,7 @@ def _dispatch( attn_mask.seq_len_q == attn_mask.seq_len_kv or attn_mask.variant == CausalVariant.UPPER_LEFT ): - return scaled_dot_product_attention( + return F.scaled_dot_product_attention( query, key, value, @@ -255,7 +256,7 @@ def _dispatch( else: _raise_kernel_warnings(sdpa_params) # We cant use efficient attention the only support for lower right is via materialization - return scaled_dot_product_attention( + return F.scaled_dot_product_attention( query, key, value, diff --git a/torch/nn/backends/thnn.py b/torch/nn/backends/thnn.py index 3cb0f3ff57e270..8564153ece2331 100644 --- a/torch/nn/backends/thnn.py +++ b/torch/nn/backends/thnn.py @@ -1,5 +1,6 @@ # mypy: allow-untyped-defs # this is for historical pickle deserialization, it is not used otherwise + def _get_thnn_function_backend(): pass diff --git a/torch/nn/parallel/__init__.py b/torch/nn/parallel/__init__.py index 8f08e5099d8b5d..fe898f8c896743 100644 --- a/torch/nn/parallel/__init__.py +++ b/torch/nn/parallel/__init__.py @@ -1,14 +1,22 @@ # mypy: allow-untyped-defs from typing_extensions import deprecated +from .data_parallel import data_parallel, DataParallel +from .distributed import DistributedDataParallel from .parallel_apply import parallel_apply from .replicate import replicate -from .data_parallel import DataParallel, data_parallel from .scatter_gather import gather, scatter -from .distributed import DistributedDataParallel -__all__ = ['replicate', 'scatter', 'parallel_apply', 'gather', 'data_parallel', - 'DataParallel', 'DistributedDataParallel'] + +__all__ = [ + "replicate", + "scatter", + "parallel_apply", + "gather", + "data_parallel", + "DataParallel", + "DistributedDataParallel", +] @deprecated( @@ -16,5 +24,5 @@ "please use `torch.nn.parallel.DistributedDataParallel` instead.", category=FutureWarning, ) -def DistributedDataParallelCPU(*args, **kwargs): - return DistributedDataParallel(*args, **kwargs) +class DistributedDataParallelCPU(DistributedDataParallel): + pass diff --git a/torch/nn/parallel/_functions.py b/torch/nn/parallel/_functions.py index d987ed2bc42746..fa04bbc353ddc7 100644 --- a/torch/nn/parallel/_functions.py +++ b/torch/nn/parallel/_functions.py @@ -1,19 +1,19 @@ import warnings +from typing import List, Optional import torch -from . import comm -from torch.autograd import Function from torch._utils import _get_device_index -from typing import List, Optional +from torch.autograd import Function + +from . import comm class Broadcast(Function): - @staticmethod def forward(ctx, target_gpus, *inputs): - assert all(i.device.type != 'cpu' for i in inputs), ( - 'Broadcast function not implemented for CPU tensors' - ) + assert all( + i.device.type != "cpu" for i in inputs + ), "Broadcast function not implemented for CPU tensors" target_gpus = [_get_device_index(x, True) for x in target_gpus] ctx.target_gpus = target_gpus if len(inputs) == 0: @@ -31,33 +31,37 @@ def forward(ctx, target_gpus, *inputs): @staticmethod def backward(ctx, *grad_outputs): - return (None,) + ReduceAddCoalesced.apply(ctx.input_device, ctx.num_inputs, *grad_outputs) + return (None,) + ReduceAddCoalesced.apply( + ctx.input_device, ctx.num_inputs, *grad_outputs + ) class ReduceAddCoalesced(Function): - @staticmethod def forward(ctx, destination, num_inputs, *grads): - ctx.target_gpus = [grads[i].get_device() for i in range(0, len(grads), num_inputs)] + ctx.target_gpus = [ + grads[i].get_device() for i in range(0, len(grads), num_inputs) + ] - grads_ = [grads[i:i + num_inputs] - for i in range(0, len(grads), num_inputs)] + grads_ = [grads[i : i + num_inputs] for i in range(0, len(grads), num_inputs)] return comm.reduce_add_coalesced(grads_, destination) @staticmethod def backward(ctx, *grad_outputs): - return (None, None,) + Broadcast.apply(ctx.target_gpus, *grad_outputs) + return ( + None, + None, + ) + Broadcast.apply(ctx.target_gpus, *grad_outputs) class Gather(Function): - @staticmethod def forward(ctx, target_device, dim, *inputs): - assert all(i.device.type != 'cpu' for i in inputs), ( - 'Gather function not implemented for CPU tensors' - ) - if (target_device == 'cpu'): - ctx.target_device = 'cpu' + assert all( + i.device.type != "cpu" for i in inputs + ), "Gather function not implemented for CPU tensors" + if target_device == "cpu": + ctx.target_device = "cpu" else: target_device = _get_device_index(target_device, True) ctx.target_device = target_device @@ -65,9 +69,11 @@ def forward(ctx, target_device, dim, *inputs): ctx.input_gpus = tuple(i.get_device() for i in inputs) if all(t.dim() == 0 for t in inputs) and dim == 0: inputs = tuple(t.view(1) for t in inputs) - warnings.warn('Was asked to gather along dimension 0, but all ' - 'input tensors were scalars; will instead unsqueeze ' - 'and return a vector.') + warnings.warn( + "Was asked to gather along dimension 0, but all " + "input tensors were scalars; will instead unsqueeze " + "and return a vector." + ) ctx.unsqueezed_scalar = True else: ctx.unsqueezed_scalar = False @@ -76,14 +82,15 @@ def forward(ctx, target_device, dim, *inputs): @staticmethod def backward(ctx, grad_output): - scattered_grads = Scatter.apply(ctx.input_gpus, ctx.input_sizes, ctx.dim, grad_output) + scattered_grads = Scatter.apply( + ctx.input_gpus, ctx.input_sizes, ctx.dim, grad_output + ) if ctx.unsqueezed_scalar: scattered_grads = tuple(g[0] for g in scattered_grads) return (None, None) + scattered_grads class Scatter(Function): - @staticmethod def forward(ctx, target_gpus, chunk_sizes, dim, input): target_gpus = [_get_device_index(x, True) for x in target_gpus] @@ -92,7 +99,9 @@ def forward(ctx, target_gpus, chunk_sizes, dim, input): streams = None if torch.cuda.is_available() and ctx.input_device == -1: # Perform CPU to GPU copies in a background stream - streams = [_get_stream(torch.device("cuda", device)) for device in target_gpus] + streams = [ + _get_stream(torch.device("cuda", device)) for device in target_gpus + ] outputs = comm.scatter(input, target_gpus, chunk_sizes, ctx.dim, streams) # Synchronize with the copy stream if streams is not None: @@ -111,6 +120,7 @@ def backward(ctx, *grad_output): # background streams used for copying _streams: Optional[List[Optional[torch.Stream]]] = None + def _get_stream(device: torch.device): """Get a background stream for copying between CPU and target device.""" global _streams diff --git a/torch/nn/parallel/comm.py b/torch/nn/parallel/comm.py index b907de4004b14f..4150ce1a615b33 100644 --- a/torch/nn/parallel/comm.py +++ b/torch/nn/parallel/comm.py @@ -1,10 +1,18 @@ # mypy: allow-untyped-defs import warnings +from typing import List + import torch +from torch._utils import ( + _flatten_dense_tensors, + _get_device_index, + _handle_complex, + _reorder_tensors_as, + _take_tensors, + _unflatten_dense_tensors, +) from torch.cuda import nccl -from torch._utils import _take_tensors, _flatten_dense_tensors, \ - _unflatten_dense_tensors, _reorder_tensors_as, _get_device_index, _handle_complex -from typing import List + def broadcast(tensor, devices=None, *, out=None): r"""Broadcasts a tensor to specified GPU devices. @@ -30,7 +38,8 @@ def broadcast(tensor, devices=None, *, out=None): tensor = _handle_complex(tensor) if not ((devices is None) ^ (out is None)): raise RuntimeError( - f"Exactly one of 'devices' and 'out' must be specified, but got devices={devices} and out={out}") + f"Exactly one of 'devices' and 'out' must be specified, but got devices={devices} and out={out}" + ) if devices is not None: devices = [_get_device_index(d) for d in devices] return torch._C._broadcast(tensor, devices) @@ -81,11 +90,15 @@ def reduce_add(inputs, destination=None): if inp.get_device() == destination: root_index = i if inp.size() != input_size: - got = 'x'.join(str(x) for x in inp.size()) - expected = 'x'.join(str(x) for x in input_size) - raise ValueError(f"input {i} has invalid size: got {got}, but expected {expected}") + got = "x".join(str(x) for x in inp.size()) + expected = "x".join(str(x) for x in input_size) + raise ValueError( + f"input {i} has invalid size: got {got}, but expected {expected}" + ) if root_index is None: - raise RuntimeError("reduce_add expects destination to be on the same GPU with one of the tensors") + raise RuntimeError( + "reduce_add expects destination to be on the same GPU with one of the tensors" + ) if len(inputs) == 1: return inputs[0] @@ -97,7 +110,9 @@ def reduce_add(inputs, destination=None): destination_device = torch.device(inputs[root_index].device.type, destination) nonroot = [t for i, t in enumerate(inputs) if i != root_index] # make a new tensor w/o clone - result = inputs[root_index] + nonroot[0].to(device=destination_device, non_blocking=True) + result = inputs[root_index] + nonroot[0].to( + device=destination_device, non_blocking=True + ) for other in nonroot[1:]: result.add_(other.to(device=destination_device, non_blocking=True)) return result @@ -138,7 +153,9 @@ def reduce_add_coalesced(inputs, destination=None, buffer_size=10485760): itrs = [_take_tensors(tensors, buffer_size) for tensors in dense_tensors] # now the dense ones, which have consistent sizes for chunks in zip(*itrs): - flat_tensors = [_flatten_dense_tensors(chunk) for chunk in chunks] # (num_gpus,) + flat_tensors = [ + _flatten_dense_tensors(chunk) for chunk in chunks + ] # (num_gpus,) flat_result = reduce_add(flat_tensors, destination) for t in _unflatten_dense_tensors(flat_result, chunks[0]): # The unflattened tensors do not share storage, and we don't expose @@ -189,10 +206,12 @@ def scatter(tensor, devices=None, chunk_sizes=None, dim=0, streams=None, *, out= else: if devices is not None: raise RuntimeError( - f"'devices' must not be specified when 'out' is specified, but got devices={devices}") + f"'devices' must not be specified when 'out' is specified, but got devices={devices}" + ) if chunk_sizes is not None: raise RuntimeError( - f"'chunk_sizes' must not be specified when 'out' is specified, but got chunk_sizes={chunk_sizes}") + f"'chunk_sizes' must not be specified when 'out' is specified, but got chunk_sizes={chunk_sizes}" + ) return tuple(torch._C._scatter_out(tensor, out, dim, streams)) @@ -226,7 +245,7 @@ def gather(tensors, dim=0, destination=None, *, out=None): if out is None: if destination == -1: warnings.warn( - 'Using -1 to represent CPU tensor is deprecated. Please use a ' + "Using -1 to represent CPU tensor is deprecated. Please use a " 'device object or string instead, e.g., "cpu".', FutureWarning, stacklevel=2, @@ -236,5 +255,6 @@ def gather(tensors, dim=0, destination=None, *, out=None): else: if destination is not None: raise RuntimeError( - f"'destination' must not be specified when 'out' is specified, but got destination={destination}") + f"'destination' must not be specified when 'out' is specified, but got destination={destination}" + ) return torch._C._gather_out(tensors, out, dim) diff --git a/torch/nn/parallel/data_parallel.py b/torch/nn/parallel/data_parallel.py index 3980706a932a19..524e9144d33867 100644 --- a/torch/nn/parallel/data_parallel.py +++ b/torch/nn/parallel/data_parallel.py @@ -1,21 +1,25 @@ # mypy: allow-untyped-defs import operator -import torch import warnings from itertools import chain from typing import Any, Dict, Generic, List, Optional, Sequence, Tuple, TypeVar, Union -from ..modules import Module -from .scatter_gather import scatter_kwargs, gather -from .replicate import replicate -from .parallel_apply import parallel_apply + +import torch from torch._utils import ( _get_all_device_indices, _get_available_device_type, _get_device_index, - _get_devices_properties + _get_devices_properties, ) -__all__ = ['DataParallel', 'data_parallel'] +from ..modules import Module +from .parallel_apply import parallel_apply +from .replicate import replicate +from .scatter_gather import gather, scatter_kwargs + + +__all__ = ["DataParallel", "data_parallel"] + def _check_balance(device_ids: Sequence[Union[int, torch.device]]) -> None: imbalance_warn = """ @@ -31,7 +35,9 @@ def warn_imbalance(get_prop): min_pos, min_val = min(enumerate(values), key=operator.itemgetter(1)) max_pos, max_val = max(enumerate(values), key=operator.itemgetter(1)) if min_val / max_val < 0.75: - warnings.warn(imbalance_warn.format(device_ids[min_pos], device_ids[max_pos])) + warnings.warn( + imbalance_warn.format(device_ids[min_pos], device_ids[max_pos]) + ) return True return False @@ -169,9 +175,11 @@ def forward(self, *inputs: Any, **kwargs: Any) -> Any: for t in chain(self.module.parameters(), self.module.buffers()): if t.device != self.src_device_obj: - raise RuntimeError("module must have its parameters and buffers " - f"on device {self.src_device_obj} (device_ids[0]) but found one of " - f"them on device: {t.device}") + raise RuntimeError( + "module must have its parameters and buffers " + f"on device {self.src_device_obj} (device_ids[0]) but found one of " + f"them on device: {t.device}" + ) inputs, module_kwargs = self.scatter(inputs, kwargs, self.device_ids) # for forward function without any inputs, empty list and dict will be created @@ -182,11 +190,13 @@ def forward(self, *inputs: Any, **kwargs: Any) -> Any: if len(self.device_ids) == 1: return self.module(*inputs[0], **module_kwargs[0]) - replicas = self.replicate(self.module, self.device_ids[:len(inputs)]) + replicas = self.replicate(self.module, self.device_ids[: len(inputs)]) outputs = self.parallel_apply(replicas, inputs, module_kwargs) return self.gather(outputs, self.output_device) - def replicate(self, module: T, device_ids: Sequence[Union[int, torch.device]]) -> List[T]: + def replicate( + self, module: T, device_ids: Sequence[Union[int, torch.device]] + ) -> List[T]: return replicate(module, device_ids, not torch.is_grad_enabled()) def scatter( @@ -197,8 +207,12 @@ def scatter( ) -> Any: return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim) - def parallel_apply(self, replicas: Sequence[T], inputs: Sequence[Any], kwargs: Any) -> List[Any]: - return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)]) + def parallel_apply( + self, replicas: Sequence[T], inputs: Sequence[Any], kwargs: Any + ) -> List[Any]: + return parallel_apply( + replicas, inputs, kwargs, self.device_ids[: len(replicas)] + ) def gather(self, outputs: Any, output_device: Union[int, torch.device]) -> Any: return gather(outputs, output_device, dim=self.dim) @@ -249,9 +263,11 @@ def data_parallel( for t in chain(module.parameters(), module.buffers()): if t.device != src_device_obj: - raise RuntimeError("module must have its parameters and buffers " - f"on device {src_device_obj} (device_ids[0]) but found one of " - f"them on device: {t.device}") + raise RuntimeError( + "module must have its parameters and buffers " + f"on device {src_device_obj} (device_ids[0]) but found one of " + f"them on device: {t.device}" + ) inputs, module_kwargs = scatter_kwargs(inputs, module_kwargs, device_ids, dim) # for module without any inputs, empty list and dict will be created @@ -264,7 +280,7 @@ def data_parallel( if len(device_ids) == 1: return module(*inputs[0], **module_kwargs[0]) - used_device_ids = device_ids[:len(inputs)] + used_device_ids = device_ids[: len(inputs)] replicas = replicate(module, used_device_ids) outputs = parallel_apply(replicas, inputs, module_kwargs, used_device_ids) return gather(outputs, output_device, dim) diff --git a/torch/nn/parallel/distributed.py b/torch/nn/parallel/distributed.py index 80ed52d9a0b67e..3fdcbe70c9ac56 100644 --- a/torch/nn/parallel/distributed.py +++ b/torch/nn/parallel/distributed.py @@ -16,10 +16,15 @@ import torch import torch.distributed as dist +from torch._utils import _get_device_index from torch.autograd import Function, Variable from torch.distributed.algorithms.join import Join, Joinable, JoinHook from torch.utils._pytree import tree_flatten, tree_unflatten +from ..modules import Module +from .scatter_gather import gather, scatter_kwargs + + RPC_AVAILABLE = False if dist.is_available(): from torch.distributed.distributed_c10d import ( @@ -35,15 +40,10 @@ _to_kwargs, _verify_param_shape_across_processes, ) -if torch.distributed.rpc.is_available(): +if dist.rpc.is_available(): RPC_AVAILABLE = True from torch.distributed.rpc import RRef -from torch._utils import _get_device_index - -from ..modules import Module -from .scatter_gather import gather, scatter_kwargs # noqa: F401 - if TYPE_CHECKING: from torch.utils.hooks import RemovableHandle diff --git a/torch/nn/parallel/parallel_apply.py b/torch/nn/parallel/parallel_apply.py index 6a90f897fa8ada..56ab7db0f4e109 100644 --- a/torch/nn/parallel/parallel_apply.py +++ b/torch/nn/parallel/parallel_apply.py @@ -1,14 +1,20 @@ import threading +from typing import Any, cast, Dict, List, Optional, Sequence, Tuple, Union + import torch -from typing import Any, Dict, List, Optional, Sequence, Tuple, Union, cast -from ..modules import Module +from torch._utils import ExceptionWrapper from torch.cuda._utils import _get_device_index from torch.cuda.amp import autocast -from torch._utils import ExceptionWrapper -__all__ = ['get_a_var', 'parallel_apply'] +from ..modules import Module -def get_a_var(obj: Union[torch.Tensor, List[Any], Tuple[Any, ...], Dict[Any, Any]]) -> Optional[torch.Tensor]: + +__all__ = ["get_a_var", "parallel_apply"] + + +def get_a_var( + obj: Union[torch.Tensor, List[Any], Tuple[Any, ...], Dict[Any, Any]], +) -> Optional[torch.Tensor]: if isinstance(obj, torch.Tensor): return obj @@ -22,6 +28,7 @@ def get_a_var(obj: Union[torch.Tensor, List[Any], Tuple[Any, ...], Dict[Any, Any return result return None + def parallel_apply( modules: Sequence[Module], inputs: Sequence[Any], @@ -40,7 +47,9 @@ def parallel_apply( element of :attr:`inputs` can either be a single object as the only argument to a module, or a collection of positional arguments. """ - assert len(modules) == len(inputs), f'The number of modules {len(modules)} is not equal to the number of inputs {len(inputs)}' + assert len(modules) == len( + inputs + ), f"The number of modules {len(modules)} is not equal to the number of inputs {len(inputs)}" if kwargs_tup is not None: assert len(modules) == len(kwargs_tup) else: @@ -53,7 +62,10 @@ def parallel_apply( streams = [torch.cuda.current_stream(x) for x in devices] lock = threading.Lock() results = {} - grad_enabled, autocast_enabled = torch.is_grad_enabled(), torch.is_autocast_enabled() + grad_enabled, autocast_enabled = ( + torch.is_grad_enabled(), + torch.is_autocast_enabled(), + ) def _worker( i: int, @@ -70,13 +82,16 @@ def _worker( with lock: results[i] = ExceptionWrapper( where=f"in replica {i}, no device was provided and no tensor input was found; " - "device cannot be resolved") + "device cannot be resolved" + ) return device = t.get_device() if stream is None: stream = torch.cuda.current_stream(device) try: - with torch.cuda.device(device), torch.cuda.stream(stream), autocast(enabled=autocast_enabled): + with torch.cuda.device(device), torch.cuda.stream(stream), autocast( + enabled=autocast_enabled + ): # this also avoids accidental slicing of `input` if it is a Tensor if not isinstance(input, (list, tuple)): input = (input,) @@ -86,13 +101,18 @@ def _worker( except Exception: with lock: results[i] = ExceptionWrapper( - where=f"in replica {i} on device {device}") + where=f"in replica {i} on device {device}" + ) if len(modules) > 1: - threads = [threading.Thread(target=_worker, - args=(i, module, input, kwargs, device, stream)) - for i, (module, input, kwargs, device, stream) in - enumerate(zip(modules, inputs, kwargs_tup, devices, streams))] + threads = [ + threading.Thread( + target=_worker, args=(i, module, input, kwargs, device, stream) + ) + for i, (module, input, kwargs, device, stream) in enumerate( + zip(modules, inputs, kwargs_tup, devices, streams) + ) + ] for thread in threads: thread.start() diff --git a/torch/nn/parallel/replicate.py b/torch/nn/parallel/replicate.py index fbe12d23ee8bec..72836c18de5397 100644 --- a/torch/nn/parallel/replicate.py +++ b/torch/nn/parallel/replicate.py @@ -1,34 +1,53 @@ +from collections import OrderedDict +from typing import ( + cast, + Dict, + Iterator, + List, + Optional, + Sequence, + Set, + TYPE_CHECKING, + TypeVar, + Union, +) + import torch +from torch._utils import _get_device_index + from ..modules import Module from . import comm -from typing import TYPE_CHECKING, Dict, Iterator, List, Optional, Sequence, Set, TypeVar, Union, cast -from torch._utils import _get_device_index -from collections import OrderedDict if TYPE_CHECKING: from torch.jit import ScriptModule from torch.jit._state import EnabledProxy -__all__ = ['replicate'] + +__all__ = ["replicate"] + def _is_script_module(module: Module) -> bool: import torch.jit + return isinstance(module, torch.jit.ScriptModule) def _is_script_method(module: Module) -> bool: import torch.jit + return isinstance(module, torch._C.ScriptMethod) def _init_script_module() -> "ScriptModule": import torch.jit + return torch.jit.ScriptModule() def _is_jit_enabled() -> "EnabledProxy": import torch.jit._state + return torch.jit._state._enabled @@ -40,7 +59,6 @@ def _is_jit_enabled() -> "EnabledProxy": # currently a module cannot be replicated properly if the descendants of # any ScriptModule contains python module (type 1 above) def _replicatable_module(module: Module, memo: Optional[Set[Module]] = None) -> bool: - # module.modules() contains module itself as the first element def descendant_modules(module: Module) -> Iterator[Module]: gen = module.modules() @@ -56,8 +74,9 @@ def descendant_modules(module: Module) -> Iterator[Module]: memo.add(module) if _is_script_module(module): memo.update(descendant_modules(module)) - return all(_is_script_module(descendant) for - descendant in descendant_modules(module)) + return all( + _is_script_module(descendant) for descendant in descendant_modules(module) + ) for child in module.children(): # since any unreplicatable module will cause the check to return @@ -69,20 +88,24 @@ def descendant_modules(module: Module) -> Iterator[Module]: return True + def _broadcast_coalesced_reshape( tensors: Sequence[torch.Tensor], devices: Sequence[Union[int, torch.device]], detach: bool = False, ) -> List[List[torch.Tensor]]: from ._functions import Broadcast + if detach: return comm.broadcast_coalesced(tensors, devices) else: # Use the autograd function to broadcast if not detach if len(tensors) > 0: tensor_copies = Broadcast.apply(devices, *tensors) - return [tensor_copies[i:i + len(tensors)] - for i in range(0, len(tensor_copies), len(tensors))] + return [ + tensor_copies[i : i + len(tensors)] + for i in range(0, len(tensor_copies), len(tensors)) + ] else: return [] @@ -96,8 +119,10 @@ def replicate( detach: bool = False, ) -> List[T]: if not _replicatable_module(network): - raise RuntimeError("Cannot replicate network where python modules are " - "childrens of ScriptModule") + raise RuntimeError( + "Cannot replicate network where python modules are " + "childrens of ScriptModule" + ) if not devices: return [] @@ -122,7 +147,9 @@ def replicate( buffer_indices_not_rg = {buf: idx for idx, buf in enumerate(buffers_not_rg)} buffer_copies_rg = _broadcast_coalesced_reshape(buffers_rg, devices, detach=detach) - buffer_copies_not_rg = _broadcast_coalesced_reshape(buffers_not_rg, devices, detach=True) + buffer_copies_not_rg = _broadcast_coalesced_reshape( + buffers_not_rg, devices, detach=True + ) modules = list(network.modules()) module_copies: List[List[Module]] = [[] for _ in devices] diff --git a/torch/nn/parallel/scatter_gather.py b/torch/nn/parallel/scatter_gather.py index 73e753760e72b6..690f81f227855b 100644 --- a/torch/nn/parallel/scatter_gather.py +++ b/torch/nn/parallel/scatter_gather.py @@ -1,10 +1,13 @@ # mypy: allow-untyped-defs -import torch -from typing import Any, Dict, List, Optional, Sequence, Tuple, TypeVar, Union, overload +from typing import Any, Dict, List, Optional, overload, Sequence, Tuple, TypeVar, Union from typing_extensions import deprecated -from ._functions import Scatter, Gather -__all__ = ['scatter', 'scatter_kwargs', 'gather'] +import torch + +from ._functions import Gather, Scatter + + +__all__ = ["scatter", "scatter_kwargs", "gather"] @deprecated( @@ -15,6 +18,7 @@ def is_namedtuple(obj: Any) -> bool: # Check if type was created from collections.namedtuple or a typing.NamedTuple. return _is_namedtuple(obj) + def _is_namedtuple(obj: Any) -> bool: # Check if type was created from collections.namedtuple or a typing.NamedTuple. return ( @@ -24,6 +28,7 @@ def _is_namedtuple(obj: Any) -> bool: T = TypeVar("T", dict, list, tuple) + # For some reason, 'scatter' returns a tuple when given a single Tensor input but a list otherwise. @overload def scatter( @@ -33,15 +38,22 @@ def scatter( ) -> Tuple[torch.Tensor, ...]: ... + @overload -def scatter(inputs: T, target_gpus: Sequence[Union[int, torch.device]], dim: int = ...) -> List[T]: +def scatter( + inputs: T, + target_gpus: Sequence[Union[int, torch.device]], + dim: int = ..., +) -> List[T]: ... + def scatter(inputs, target_gpus, dim=0): r"""Slice tensors into approximately equal chunks and distributes them across given GPUs. Duplicates references to objects that are not tensors. """ + def scatter_map(obj): if isinstance(obj, torch.Tensor): return Scatter.apply(target_gpus, None, dim, obj) @@ -77,9 +89,13 @@ def scatter_kwargs( scattered_inputs = scatter(inputs, target_gpus, dim) if inputs else [] scattered_kwargs = scatter(kwargs, target_gpus, dim) if kwargs else [] if len(scattered_inputs) < len(scattered_kwargs): - scattered_inputs.extend(() for _ in range(len(scattered_kwargs) - len(scattered_inputs))) + scattered_inputs.extend( + () for _ in range(len(scattered_kwargs) - len(scattered_inputs)) + ) elif len(scattered_kwargs) < len(inputs): - scattered_kwargs.extend({} for _ in range(len(scattered_inputs) - len(scattered_kwargs))) + scattered_kwargs.extend( + {} for _ in range(len(scattered_inputs) - len(scattered_kwargs)) + ) return tuple(scattered_inputs), tuple(scattered_kwargs) @@ -88,6 +104,7 @@ def gather(outputs: Any, target_device: Union[int, torch.device], dim: int = 0) Use 'cpu' for CPU to avoid a deprecation warning. """ + def gather_map(outputs): out = outputs[0] if isinstance(out, torch.Tensor): @@ -96,9 +113,8 @@ def gather_map(outputs): return None if isinstance(out, dict): if not all(len(out) == len(d) for d in outputs): - raise ValueError('All dicts must have the same number of keys') - return type(out)((k, gather_map([d[k] for d in outputs])) - for k in out) + raise ValueError("All dicts must have the same number of keys") + return type(out)((k, gather_map([d[k] for d in outputs])) for k in out) if _is_namedtuple(out): return type(out)._make(map(gather_map, zip(*outputs))) return type(out)(map(gather_map, zip(*outputs))) From 95ac2d648279ebc73feccf6d8eccafa4b2759de8 Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Mon, 17 Jun 2024 03:58:19 +0800 Subject: [PATCH 082/171] [BE] enable UFMT for `torch/nn/modules` (#128594) Part of #123062 - #123062 Pull Request resolved: https://github.com/pytorch/pytorch/pull/128594 Approved by: https://github.com/mikaylagawarecki ghstack dependencies: #128596 --- .lintrunner.toml | 26 - torch/nn/modules/__init__.py | 389 +++++++++++--- torch/nn/modules/_functions.py | 102 ++-- torch/nn/modules/activation.py | 493 ++++++++++------- torch/nn/modules/adaptive.py | 103 ++-- torch/nn/modules/batchnorm.py | 111 ++-- torch/nn/modules/channelshuffle.py | 12 +- torch/nn/modules/container.py | 246 +++++---- torch/nn/modules/conv.py | 637 +++++++++++++++------- torch/nn/modules/distance.py | 16 +- torch/nn/modules/dropout.py | 23 +- torch/nn/modules/flatten.py | 47 +- torch/nn/modules/fold.py | 46 +- torch/nn/modules/instancenorm.py | 93 ++-- torch/nn/modules/lazy.py | 57 +- torch/nn/modules/linear.py | 87 ++- torch/nn/modules/loss.py | 386 ++++++++++---- torch/nn/modules/module.py | 745 +++++++++++++++++--------- torch/nn/modules/normalization.py | 129 +++-- torch/nn/modules/padding.py | 83 +-- torch/nn/modules/pixelshuffle.py | 16 +- torch/nn/modules/pooling.py | 386 ++++++++++---- torch/nn/modules/rnn.py | 823 ++++++++++++++++++++--------- torch/nn/modules/sparse.py | 216 +++++--- torch/nn/modules/transformer.py | 545 +++++++++++++------ torch/nn/modules/upsampling.py | 68 ++- torch/nn/modules/utils.py | 15 +- 27 files changed, 4049 insertions(+), 1851 deletions(-) diff --git a/.lintrunner.toml b/.lintrunner.toml index 5f5b91424757e4..2093f258849019 100644 --- a/.lintrunner.toml +++ b/.lintrunner.toml @@ -1666,32 +1666,6 @@ exclude_patterns = [ 'torch/nn/intrinsic/quantized/modules/bn_relu.py', 'torch/nn/intrinsic/quantized/modules/conv_relu.py', 'torch/nn/intrinsic/quantized/modules/linear_relu.py', - 'torch/nn/modules/__init__.py', - 'torch/nn/modules/_functions.py', - 'torch/nn/modules/activation.py', - 'torch/nn/modules/adaptive.py', - 'torch/nn/modules/batchnorm.py', - 'torch/nn/modules/channelshuffle.py', - 'torch/nn/modules/container.py', - 'torch/nn/modules/conv.py', - 'torch/nn/modules/distance.py', - 'torch/nn/modules/dropout.py', - 'torch/nn/modules/flatten.py', - 'torch/nn/modules/fold.py', - 'torch/nn/modules/instancenorm.py', - 'torch/nn/modules/lazy.py', - 'torch/nn/modules/linear.py', - 'torch/nn/modules/loss.py', - 'torch/nn/modules/module.py', - 'torch/nn/modules/normalization.py', - 'torch/nn/modules/padding.py', - 'torch/nn/modules/pixelshuffle.py', - 'torch/nn/modules/pooling.py', - 'torch/nn/modules/rnn.py', - 'torch/nn/modules/sparse.py', - 'torch/nn/modules/transformer.py', - 'torch/nn/modules/upsampling.py', - 'torch/nn/modules/utils.py', 'torch/nn/parameter.py', 'torch/nn/qat/__init__.py', 'torch/nn/qat/dynamic/__init__.py', diff --git a/torch/nn/modules/__init__.py b/torch/nn/modules/__init__.py index 403d0d547e2bb5..af846e4f1d8d5e 100644 --- a/torch/nn/modules/__init__.py +++ b/torch/nn/modules/__init__.py @@ -1,68 +1,331 @@ -from .module import Module -from .linear import Identity, Linear, Bilinear, LazyLinear -from .conv import Conv1d, Conv2d, Conv3d, \ - ConvTranspose1d, ConvTranspose2d, ConvTranspose3d, \ - LazyConv1d, LazyConv2d, LazyConv3d, LazyConvTranspose1d, LazyConvTranspose2d, LazyConvTranspose3d -from .activation import Threshold, ReLU, Hardtanh, ReLU6, Sigmoid, Tanh, \ - Softmax, Softmax2d, LogSoftmax, ELU, SELU, CELU, GELU, Hardshrink, LeakyReLU, LogSigmoid, \ - Softplus, Softshrink, MultiheadAttention, PReLU, Softsign, Softmin, Tanhshrink, RReLU, GLU, \ - Hardsigmoid, Hardswish, SiLU, Mish -from .loss import L1Loss, NLLLoss, KLDivLoss, MSELoss, BCELoss, BCEWithLogitsLoss, NLLLoss2d, \ - CosineEmbeddingLoss, CTCLoss, HingeEmbeddingLoss, MarginRankingLoss, \ - MultiLabelMarginLoss, MultiLabelSoftMarginLoss, MultiMarginLoss, SmoothL1Loss, HuberLoss, \ - SoftMarginLoss, CrossEntropyLoss, TripletMarginLoss, TripletMarginWithDistanceLoss, PoissonNLLLoss, GaussianNLLLoss -from .container import Container, Sequential, ModuleList, ModuleDict, ParameterList, ParameterDict -from .pooling import AvgPool1d, AvgPool2d, AvgPool3d, MaxPool1d, MaxPool2d, MaxPool3d, \ - MaxUnpool1d, MaxUnpool2d, MaxUnpool3d, FractionalMaxPool2d, FractionalMaxPool3d, LPPool1d, LPPool2d, LPPool3d, \ - AdaptiveMaxPool1d, AdaptiveMaxPool2d, AdaptiveMaxPool3d, AdaptiveAvgPool1d, AdaptiveAvgPool2d, AdaptiveAvgPool3d -from .batchnorm import BatchNorm1d, BatchNorm2d, BatchNorm3d, SyncBatchNorm, \ - LazyBatchNorm1d, LazyBatchNorm2d, LazyBatchNorm3d -from .instancenorm import InstanceNorm1d, InstanceNorm2d, InstanceNorm3d, \ - LazyInstanceNorm1d, LazyInstanceNorm2d, LazyInstanceNorm3d -from .normalization import LocalResponseNorm, CrossMapLRN2d, LayerNorm, GroupNorm, RMSNorm -from .dropout import Dropout, Dropout1d, Dropout2d, Dropout3d, AlphaDropout, FeatureAlphaDropout -from .padding import ReflectionPad1d, ReflectionPad2d, ReflectionPad3d, ReplicationPad1d, ReplicationPad2d, \ - ReplicationPad3d, ZeroPad1d, ZeroPad2d, ZeroPad3d, ConstantPad1d, ConstantPad2d, ConstantPad3d, \ - CircularPad1d, CircularPad2d, CircularPad3d -from .sparse import Embedding, EmbeddingBag -from .rnn import RNNBase, RNN, LSTM, GRU, \ - RNNCellBase, RNNCell, LSTMCell, GRUCell -from .pixelshuffle import PixelShuffle, PixelUnshuffle -from .upsampling import UpsamplingNearest2d, UpsamplingBilinear2d, Upsample -from .distance import PairwiseDistance, CosineSimilarity -from .fold import Fold, Unfold +from .activation import ( + CELU, + ELU, + GELU, + GLU, + Hardshrink, + Hardsigmoid, + Hardswish, + Hardtanh, + LeakyReLU, + LogSigmoid, + LogSoftmax, + Mish, + MultiheadAttention, + PReLU, + ReLU, + ReLU6, + RReLU, + SELU, + Sigmoid, + SiLU, + Softmax, + Softmax2d, + Softmin, + Softplus, + Softshrink, + Softsign, + Tanh, + Tanhshrink, + Threshold, +) from .adaptive import AdaptiveLogSoftmaxWithLoss -from .transformer import TransformerEncoder, TransformerDecoder, \ - TransformerEncoderLayer, TransformerDecoderLayer, Transformer -from .flatten import Flatten, Unflatten +from .batchnorm import ( + BatchNorm1d, + BatchNorm2d, + BatchNorm3d, + LazyBatchNorm1d, + LazyBatchNorm2d, + LazyBatchNorm3d, + SyncBatchNorm, +) from .channelshuffle import ChannelShuffle +from .container import ( + Container, + ModuleDict, + ModuleList, + ParameterDict, + ParameterList, + Sequential, +) +from .conv import ( + Conv1d, + Conv2d, + Conv3d, + ConvTranspose1d, + ConvTranspose2d, + ConvTranspose3d, + LazyConv1d, + LazyConv2d, + LazyConv3d, + LazyConvTranspose1d, + LazyConvTranspose2d, + LazyConvTranspose3d, +) +from .distance import CosineSimilarity, PairwiseDistance +from .dropout import ( + AlphaDropout, + Dropout, + Dropout1d, + Dropout2d, + Dropout3d, + FeatureAlphaDropout, +) +from .flatten import Flatten, Unflatten +from .fold import Fold, Unfold +from .instancenorm import ( + InstanceNorm1d, + InstanceNorm2d, + InstanceNorm3d, + LazyInstanceNorm1d, + LazyInstanceNorm2d, + LazyInstanceNorm3d, +) +from .linear import Bilinear, Identity, LazyLinear, Linear +from .loss import ( + BCELoss, + BCEWithLogitsLoss, + CosineEmbeddingLoss, + CrossEntropyLoss, + CTCLoss, + GaussianNLLLoss, + HingeEmbeddingLoss, + HuberLoss, + KLDivLoss, + L1Loss, + MarginRankingLoss, + MSELoss, + MultiLabelMarginLoss, + MultiLabelSoftMarginLoss, + MultiMarginLoss, + NLLLoss, + NLLLoss2d, + PoissonNLLLoss, + SmoothL1Loss, + SoftMarginLoss, + TripletMarginLoss, + TripletMarginWithDistanceLoss, +) +from .module import Module +from .normalization import ( + CrossMapLRN2d, + GroupNorm, + LayerNorm, + LocalResponseNorm, + RMSNorm, +) +from .padding import ( + CircularPad1d, + CircularPad2d, + CircularPad3d, + ConstantPad1d, + ConstantPad2d, + ConstantPad3d, + ReflectionPad1d, + ReflectionPad2d, + ReflectionPad3d, + ReplicationPad1d, + ReplicationPad2d, + ReplicationPad3d, + ZeroPad1d, + ZeroPad2d, + ZeroPad3d, +) +from .pixelshuffle import PixelShuffle, PixelUnshuffle +from .pooling import ( + AdaptiveAvgPool1d, + AdaptiveAvgPool2d, + AdaptiveAvgPool3d, + AdaptiveMaxPool1d, + AdaptiveMaxPool2d, + AdaptiveMaxPool3d, + AvgPool1d, + AvgPool2d, + AvgPool3d, + FractionalMaxPool2d, + FractionalMaxPool3d, + LPPool1d, + LPPool2d, + LPPool3d, + MaxPool1d, + MaxPool2d, + MaxPool3d, + MaxUnpool1d, + MaxUnpool2d, + MaxUnpool3d, +) +from .rnn import GRU, GRUCell, LSTM, LSTMCell, RNN, RNNBase, RNNCell, RNNCellBase +from .sparse import Embedding, EmbeddingBag +from .transformer import ( + Transformer, + TransformerDecoder, + TransformerDecoderLayer, + TransformerEncoder, + TransformerEncoderLayer, +) +from .upsampling import Upsample, UpsamplingBilinear2d, UpsamplingNearest2d + __all__ = [ - 'Module', 'Identity', 'Linear', 'Conv1d', 'Conv2d', 'Conv3d', 'ConvTranspose1d', - 'ConvTranspose2d', 'ConvTranspose3d', 'Threshold', 'ReLU', 'Hardtanh', 'ReLU6', - 'Sigmoid', 'Tanh', 'Softmax', 'Softmax2d', 'LogSoftmax', 'ELU', 'SELU', 'CELU', 'GLU', 'GELU', 'Hardshrink', - 'LeakyReLU', 'LogSigmoid', 'Softplus', 'Softshrink', 'MultiheadAttention', 'PReLU', 'Softsign', 'Softmin', - 'Tanhshrink', 'RReLU', 'L1Loss', 'NLLLoss', 'KLDivLoss', 'MSELoss', 'BCELoss', 'BCEWithLogitsLoss', - 'NLLLoss2d', 'PoissonNLLLoss', 'CosineEmbeddingLoss', 'CTCLoss', 'HingeEmbeddingLoss', 'MarginRankingLoss', - 'MultiLabelMarginLoss', 'MultiLabelSoftMarginLoss', 'MultiMarginLoss', 'SmoothL1Loss', 'GaussianNLLLoss', - 'HuberLoss', 'SoftMarginLoss', 'CrossEntropyLoss', 'Container', 'Sequential', 'ModuleList', 'ModuleDict', - 'ParameterList', 'ParameterDict', 'AvgPool1d', 'AvgPool2d', 'AvgPool3d', 'MaxPool1d', 'MaxPool2d', - 'MaxPool3d', 'MaxUnpool1d', 'MaxUnpool2d', 'MaxUnpool3d', 'FractionalMaxPool2d', "FractionalMaxPool3d", - 'LPPool1d', 'LPPool2d', 'LPPool3d', 'LocalResponseNorm', 'BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d', - 'InstanceNorm1d', 'InstanceNorm2d', 'InstanceNorm3d', 'LayerNorm', 'GroupNorm', 'RMSNorm', 'SyncBatchNorm', - 'Dropout', 'Dropout1d', 'Dropout2d', 'Dropout3d', 'AlphaDropout', 'FeatureAlphaDropout', - 'ReflectionPad1d', 'ReflectionPad2d', 'ReflectionPad3d', 'ReplicationPad2d', 'ReplicationPad1d', 'ReplicationPad3d', - 'CrossMapLRN2d', 'Embedding', 'EmbeddingBag', 'RNNBase', 'RNN', 'LSTM', 'GRU', 'RNNCellBase', 'RNNCell', - 'LSTMCell', 'GRUCell', 'PixelShuffle', 'PixelUnshuffle', 'Upsample', 'UpsamplingNearest2d', 'UpsamplingBilinear2d', - 'PairwiseDistance', 'AdaptiveMaxPool1d', 'AdaptiveMaxPool2d', 'AdaptiveMaxPool3d', 'AdaptiveAvgPool1d', - 'AdaptiveAvgPool2d', 'AdaptiveAvgPool3d', 'TripletMarginLoss', 'ZeroPad1d', 'ZeroPad2d', 'ZeroPad3d', - 'ConstantPad1d', 'ConstantPad2d', 'ConstantPad3d', 'Bilinear', 'CosineSimilarity', 'Unfold', 'Fold', - 'AdaptiveLogSoftmaxWithLoss', 'TransformerEncoder', 'TransformerDecoder', - 'TransformerEncoderLayer', 'TransformerDecoderLayer', 'Transformer', - 'LazyLinear', 'LazyConv1d', 'LazyConv2d', 'LazyConv3d', - 'LazyConvTranspose1d', 'LazyConvTranspose2d', 'LazyConvTranspose3d', - 'LazyBatchNorm1d', 'LazyBatchNorm2d', 'LazyBatchNorm3d', - 'LazyInstanceNorm1d', 'LazyInstanceNorm2d', 'LazyInstanceNorm3d', - 'Flatten', 'Unflatten', 'Hardsigmoid', 'Hardswish', 'SiLU', 'Mish', 'TripletMarginWithDistanceLoss', 'ChannelShuffle', - 'CircularPad1d', 'CircularPad2d', 'CircularPad3d' + "Module", + "Identity", + "Linear", + "Conv1d", + "Conv2d", + "Conv3d", + "ConvTranspose1d", + "ConvTranspose2d", + "ConvTranspose3d", + "Threshold", + "ReLU", + "Hardtanh", + "ReLU6", + "Sigmoid", + "Tanh", + "Softmax", + "Softmax2d", + "LogSoftmax", + "ELU", + "SELU", + "CELU", + "GLU", + "GELU", + "Hardshrink", + "LeakyReLU", + "LogSigmoid", + "Softplus", + "Softshrink", + "MultiheadAttention", + "PReLU", + "Softsign", + "Softmin", + "Tanhshrink", + "RReLU", + "L1Loss", + "NLLLoss", + "KLDivLoss", + "MSELoss", + "BCELoss", + "BCEWithLogitsLoss", + "NLLLoss2d", + "PoissonNLLLoss", + "CosineEmbeddingLoss", + "CTCLoss", + "HingeEmbeddingLoss", + "MarginRankingLoss", + "MultiLabelMarginLoss", + "MultiLabelSoftMarginLoss", + "MultiMarginLoss", + "SmoothL1Loss", + "GaussianNLLLoss", + "HuberLoss", + "SoftMarginLoss", + "CrossEntropyLoss", + "Container", + "Sequential", + "ModuleList", + "ModuleDict", + "ParameterList", + "ParameterDict", + "AvgPool1d", + "AvgPool2d", + "AvgPool3d", + "MaxPool1d", + "MaxPool2d", + "MaxPool3d", + "MaxUnpool1d", + "MaxUnpool2d", + "MaxUnpool3d", + "FractionalMaxPool2d", + "FractionalMaxPool3d", + "LPPool1d", + "LPPool2d", + "LPPool3d", + "LocalResponseNorm", + "BatchNorm1d", + "BatchNorm2d", + "BatchNorm3d", + "InstanceNorm1d", + "InstanceNorm2d", + "InstanceNorm3d", + "LayerNorm", + "GroupNorm", + "RMSNorm", + "SyncBatchNorm", + "Dropout", + "Dropout1d", + "Dropout2d", + "Dropout3d", + "AlphaDropout", + "FeatureAlphaDropout", + "ReflectionPad1d", + "ReflectionPad2d", + "ReflectionPad3d", + "ReplicationPad2d", + "ReplicationPad1d", + "ReplicationPad3d", + "CrossMapLRN2d", + "Embedding", + "EmbeddingBag", + "RNNBase", + "RNN", + "LSTM", + "GRU", + "RNNCellBase", + "RNNCell", + "LSTMCell", + "GRUCell", + "PixelShuffle", + "PixelUnshuffle", + "Upsample", + "UpsamplingNearest2d", + "UpsamplingBilinear2d", + "PairwiseDistance", + "AdaptiveMaxPool1d", + "AdaptiveMaxPool2d", + "AdaptiveMaxPool3d", + "AdaptiveAvgPool1d", + "AdaptiveAvgPool2d", + "AdaptiveAvgPool3d", + "TripletMarginLoss", + "ZeroPad1d", + "ZeroPad2d", + "ZeroPad3d", + "ConstantPad1d", + "ConstantPad2d", + "ConstantPad3d", + "Bilinear", + "CosineSimilarity", + "Unfold", + "Fold", + "AdaptiveLogSoftmaxWithLoss", + "TransformerEncoder", + "TransformerDecoder", + "TransformerEncoderLayer", + "TransformerDecoderLayer", + "Transformer", + "LazyLinear", + "LazyConv1d", + "LazyConv2d", + "LazyConv3d", + "LazyConvTranspose1d", + "LazyConvTranspose2d", + "LazyConvTranspose3d", + "LazyBatchNorm1d", + "LazyBatchNorm2d", + "LazyBatchNorm3d", + "LazyInstanceNorm1d", + "LazyInstanceNorm2d", + "LazyInstanceNorm3d", + "Flatten", + "Unflatten", + "Hardsigmoid", + "Hardswish", + "SiLU", + "Mish", + "TripletMarginWithDistanceLoss", + "ChannelShuffle", + "CircularPad1d", + "CircularPad2d", + "CircularPad3d", ] diff --git a/torch/nn/modules/_functions.py b/torch/nn/modules/_functions.py index 0e19faa99e5c45..847afcef4da2ee 100644 --- a/torch/nn/modules/_functions.py +++ b/torch/nn/modules/_functions.py @@ -1,16 +1,26 @@ # mypy: allow-untyped-defs import torch import torch.distributed as dist - from torch.autograd.function import Function -class SyncBatchNorm(Function): +class SyncBatchNorm(Function): @staticmethod - def forward(self, input, weight, bias, running_mean, running_var, eps, momentum, process_group, world_size): + def forward( + self, + input, + weight, + bias, + running_mean, + running_var, + eps, + momentum, + process_group, + world_size, + ): if not ( - input.is_contiguous(memory_format=torch.channels_last) or - input.is_contiguous(memory_format=torch.channels_last_3d) + input.is_contiguous(memory_format=torch.channels_last) + or input.is_contiguous(memory_format=torch.channels_last_3d) ): input = input.contiguous() if weight is not None: @@ -18,7 +28,9 @@ def forward(self, input, weight, bias, running_mean, running_var, eps, momentum, size = int(input.numel() // input.size(1)) if size == 1 and world_size < 2: - raise ValueError(f'Expected more than 1 value per channel when training, got input size {size}') + raise ValueError( + f"Expected more than 1 value per channel when training, got input size {size}" + ) num_channels = input.shape[1] if input.numel() > 0: @@ -29,7 +41,7 @@ def forward(self, input, weight, bias, running_mean, running_var, eps, momentum, (1,), input.numel() // input.size(1), dtype=mean.dtype, - device=mean.device + device=mean.device, ) # C, C, 1 -> (2C + 1) @@ -40,9 +52,7 @@ def forward(self, input, weight, bias, running_mean, running_var, eps, momentum, # & invstd, but they still needs to participate the all_gather # collective communication to unblock other peer processes. combined = torch.zeros( - 2 * num_channels + 1, - dtype=input.dtype, - device=input.device + 2 * num_channels + 1, dtype=input.dtype, device=input.device ) # Use allgather instead of allreduce because count could be different across @@ -54,19 +64,21 @@ def forward(self, input, weight, bias, running_mean, running_var, eps, momentum, if process_group._get_backend_name() != "gloo": # world_size * (2C + 1) combined_size = combined.numel() - combined_flat = torch.empty(1, - combined_size * world_size, - dtype=combined.dtype, - device=combined.device) - dist.all_gather_into_tensor(combined_flat, combined, process_group, async_op=False) + combined_flat = torch.empty( + 1, + combined_size * world_size, + dtype=combined.dtype, + device=combined.device, + ) + dist.all_gather_into_tensor( + combined_flat, combined, process_group, async_op=False + ) combined = torch.reshape(combined_flat, (world_size, combined_size)) # world_size * (2C + 1) -> world_size * C, world_size * C, world_size * 1 mean_all, invstd_all, count_all = torch.split(combined, num_channels, dim=1) else: # world_size * (2C + 1) - combined_list = [ - torch.empty_like(combined) for _ in range(world_size) - ] + combined_list = [torch.empty_like(combined) for _ in range(world_size)] dist.all_gather(combined_list, combined, process_group, async_op=False) combined = torch.stack(combined_list, dim=0) # world_size * (2C + 1) -> world_size * C, world_size * C, world_size * 1 @@ -113,8 +125,8 @@ def forward(self, input, weight, bias, running_mean, running_var, eps, momentum, @staticmethod def backward(self, grad_output): if not ( - grad_output.is_contiguous(memory_format=torch.channels_last) or - grad_output.is_contiguous(memory_format=torch.channels_last_3d) + grad_output.is_contiguous(memory_format=torch.channels_last) + or grad_output.is_contiguous(memory_format=torch.channels_last_3d) ): grad_output = grad_output.contiguous() saved_input, weight, mean, invstd, count_tensor = self.saved_tensors @@ -123,7 +135,12 @@ def backward(self, grad_output): if saved_input.numel() > 0: # calculate local stats as well as grad_weight / grad_bias - sum_dy, sum_dy_xmu, grad_weight, grad_bias = torch.batch_norm_backward_reduce( + ( + sum_dy, + sum_dy_xmu, + grad_weight, + grad_bias, + ) = torch.batch_norm_backward_reduce( grad_output, saved_input, mean, @@ -131,7 +148,7 @@ def backward(self, grad_output): weight, self.needs_input_grad[0], self.needs_input_grad[1], - self.needs_input_grad[2] + self.needs_input_grad[2], ) if self.needs_input_grad[0]: @@ -139,7 +156,11 @@ def backward(self, grad_output): num_channels = sum_dy.shape[0] combined = torch.cat([sum_dy, sum_dy_xmu], dim=0) torch.distributed.all_reduce( - combined, torch.distributed.ReduceOp.SUM, process_group, async_op=False) + combined, + torch.distributed.ReduceOp.SUM, + process_group, + async_op=False, + ) sum_dy, sum_dy_xmu = torch.split(combined, num_channels) # backward pass for gradient calculation @@ -153,7 +174,7 @@ def backward(self, grad_output): weight, sum_dy, sum_dy_xmu, - count_tensor + count_tensor, ) # synchronizing of grad_weight / grad_bias is not needed as distributed # training would handle all reduce. @@ -172,20 +193,22 @@ def backward(self, grad_output): if self.needs_input_grad[0]: # launch all_reduce to unblock other peer processes combined = torch.zeros( - 2 * num_channels, - dtype=saved_input.dtype, - device=saved_input.device + 2 * num_channels, dtype=saved_input.dtype, device=saved_input.device ) torch.distributed.all_reduce( - combined, torch.distributed.ReduceOp.SUM, process_group, async_op=False) + combined, + torch.distributed.ReduceOp.SUM, + process_group, + async_op=False, + ) # Leave grad_input, grad_weight and grad_bias as None, which will be # interpreted by the autograd engine as Tensors full of zeros. return grad_input, grad_weight, grad_bias, None, None, None, None, None, None -class CrossMapLRN2d(Function): +class CrossMapLRN2d(Function): @staticmethod def forward(ctx, input, size, alpha=1e-4, beta=0.75, k=1): ctx.size = size @@ -195,7 +218,9 @@ def forward(ctx, input, size, alpha=1e-4, beta=0.75, k=1): ctx.scale = None if input.dim() != 4: - raise ValueError(f"CrossMapLRN2d: Expected input to be 4D, got {input.dim()}D instead.") + raise ValueError( + f"CrossMapLRN2d: Expected input to be 4D, got {input.dim()}D instead." + ) ctx.scale = ctx.scale or input.new() output = input.new() @@ -253,8 +278,7 @@ def backward(ctx, grad_output): input_height = input.size(2) input_width = input.size(3) - paddded_ratio = input.new(channels + ctx.size - 1, input_height, - input_width) + paddded_ratio = input.new(channels + ctx.size - 1, input_height, input_width) accum_ratio = input.new(input_height, input_width) cache_ratio_value = 2 * ctx.alpha * ctx.beta / ctx.size @@ -264,20 +288,26 @@ def backward(ctx, grad_output): torch.pow(ctx.scale, -ctx.beta, out=grad_input).mul_(grad_output) paddded_ratio.zero_() - padded_ratio_center = paddded_ratio.narrow(0, inversePrePad, - channels) + padded_ratio_center = paddded_ratio.narrow(0, inversePrePad, channels) for n in range(batch_size): torch.mul(grad_output[n], output[n], out=padded_ratio_center) padded_ratio_center.div_(ctx.scale[n]) torch.sum( - paddded_ratio.narrow(0, 0, ctx.size - 1), 0, keepdim=False, out=accum_ratio) + paddded_ratio.narrow(0, 0, ctx.size - 1), + 0, + keepdim=False, + out=accum_ratio, + ) for c in range(channels): accum_ratio.add_(paddded_ratio[c + ctx.size - 1]) - grad_input[n][c].addcmul_(input[n][c], accum_ratio, value=-cache_ratio_value) + grad_input[n][c].addcmul_( + input[n][c], accum_ratio, value=-cache_ratio_value + ) accum_ratio.add_(paddded_ratio[c], alpha=-1) return grad_input, None, None, None, None + class BackwardHookFunction(torch.autograd.Function): @staticmethod def forward(ctx, *args): diff --git a/torch/nn/modules/activation.py b/torch/nn/modules/activation.py index 3d8b65175956df..4889a4485c4901 100644 --- a/torch/nn/modules/activation.py +++ b/torch/nn/modules/activation.py @@ -3,17 +3,46 @@ from typing import Optional, Tuple import torch +import torch.nn.functional as F from torch import Tensor -from .linear import NonDynamicallyQuantizableLinear from torch.nn.init import constant_, xavier_normal_, xavier_uniform_ from torch.nn.parameter import Parameter + +from .linear import NonDynamicallyQuantizableLinear from .module import Module -from .. import functional as F -__all__ = ['Threshold', 'ReLU', 'RReLU', 'Hardtanh', 'ReLU6', 'Sigmoid', 'Hardsigmoid', 'Tanh', - 'SiLU', 'Mish', 'Hardswish', 'ELU', 'CELU', 'SELU', 'GLU', 'GELU', 'Hardshrink', 'LeakyReLU', - 'LogSigmoid', 'Softplus', 'Softshrink', 'MultiheadAttention', 'PReLU', 'Softsign', 'Tanhshrink', - 'Softmin', 'Softmax', 'Softmax2d', 'LogSoftmax'] + +__all__ = [ + "Threshold", + "ReLU", + "RReLU", + "Hardtanh", + "ReLU6", + "Sigmoid", + "Hardsigmoid", + "Tanh", + "SiLU", + "Mish", + "Hardswish", + "ELU", + "CELU", + "SELU", + "GLU", + "GELU", + "Hardshrink", + "LeakyReLU", + "LogSigmoid", + "Softplus", + "Softshrink", + "MultiheadAttention", + "PReLU", + "Softsign", + "Tanhshrink", + "Softmin", + "Softmax", + "Softmax2d", + "LogSoftmax", +] class Threshold(Module): @@ -44,7 +73,7 @@ class Threshold(Module): >>> output = m(input) """ - __constants__ = ['threshold', 'value', 'inplace'] + __constants__ = ["threshold", "value", "inplace"] threshold: float value: float @@ -61,8 +90,8 @@ def forward(self, input: Tensor) -> Tensor: return F.threshold(input, self.threshold, self.value, self.inplace) def extra_repr(self): - inplace_str = ', inplace=True' if self.inplace else '' - return f'threshold={self.threshold}, value={self.value}{inplace_str}' + inplace_str = ", inplace=True" if self.inplace else "" + return f"threshold={self.threshold}, value={self.value}{inplace_str}" class ReLU(Module): @@ -93,7 +122,7 @@ class ReLU(Module): >>> output = torch.cat((m(input), m(-input))) """ - __constants__ = ['inplace'] + __constants__ = ["inplace"] inplace: bool def __init__(self, inplace: bool = False): @@ -104,7 +133,7 @@ def forward(self, input: Tensor) -> Tensor: return F.relu(input, inplace=self.inplace) def extra_repr(self) -> str: - inplace_str = 'inplace=True' if self.inplace else '' + inplace_str = "inplace=True" if self.inplace else "" return inplace_str @@ -146,17 +175,14 @@ class RReLU(Module): """ - __constants__ = ['lower', 'upper', 'inplace'] + __constants__ = ["lower", "upper", "inplace"] lower: float upper: float inplace: bool def __init__( - self, - lower: float = 1. / 8, - upper: float = 1. / 3, - inplace: bool = False + self, lower: float = 1.0 / 8, upper: float = 1.0 / 3, inplace: bool = False ): super().__init__() self.lower = lower @@ -167,8 +193,8 @@ def forward(self, input: Tensor) -> Tensor: return F.rrelu(input, self.lower, self.upper, self.training, self.inplace) def extra_repr(self): - inplace_str = ', inplace=True' if self.inplace else '' - return f'lower={self.lower}, upper={self.upper}{inplace_str}' + inplace_str = ", inplace=True" if self.inplace else "" + return f"lower={self.lower}, upper={self.upper}{inplace_str}" class Hardtanh(Module): @@ -204,7 +230,7 @@ class Hardtanh(Module): >>> output = m(input) """ - __constants__ = ['min_val', 'max_val', 'inplace'] + __constants__ = ["min_val", "max_val", "inplace"] min_val: float max_val: float @@ -212,11 +238,11 @@ class Hardtanh(Module): def __init__( self, - min_val: float = -1., - max_val: float = 1., + min_val: float = -1.0, + max_val: float = 1.0, inplace: bool = False, min_value: Optional[float] = None, - max_value: Optional[float] = None + max_value: Optional[float] = None, ) -> None: super().__init__() if min_value is not None: @@ -243,8 +269,8 @@ def forward(self, input: Tensor) -> Tensor: return F.hardtanh(input, self.min_val, self.max_val, self.inplace) def extra_repr(self) -> str: - inplace_str = ', inplace=True' if self.inplace else '' - return f'min_val={self.min_val}, max_val={self.max_val}{inplace_str}' + inplace_str = ", inplace=True" if self.inplace else "" + return f"min_val={self.min_val}, max_val={self.max_val}{inplace_str}" class ReLU6(Hardtanh): @@ -270,10 +296,10 @@ class ReLU6(Hardtanh): """ def __init__(self, inplace: bool = False): - super().__init__(0., 6., inplace) + super().__init__(0.0, 6.0, inplace) def extra_repr(self) -> str: - inplace_str = 'inplace=True' if self.inplace else '' + inplace_str = "inplace=True" if self.inplace else "" return inplace_str @@ -329,11 +355,11 @@ class Hardsigmoid(Module): >>> output = m(input) """ - __constants__ = ['inplace'] + __constants__ = ["inplace"] inplace: bool - def __init__(self, inplace : bool = False) -> None: + def __init__(self, inplace: bool = False) -> None: super().__init__() self.inplace = inplace @@ -365,6 +391,7 @@ class Tanh(Module): def forward(self, input: Tensor) -> Tensor: return torch.tanh(input) + class SiLU(Module): r"""Applies the Sigmoid Linear Unit (SiLU) function, element-wise. @@ -394,7 +421,7 @@ class SiLU(Module): >>> output = m(input) """ - __constants__ = ['inplace'] + __constants__ = ["inplace"] inplace: bool def __init__(self, inplace: bool = False): @@ -405,9 +432,10 @@ def forward(self, input: Tensor) -> Tensor: return F.silu(input, inplace=self.inplace) def extra_repr(self) -> str: - inplace_str = 'inplace=True' if self.inplace else '' + inplace_str = "inplace=True" if self.inplace else "" return inplace_str + class Mish(Module): r"""Applies the Mish function, element-wise. @@ -432,7 +460,7 @@ class Mish(Module): >>> output = m(input) """ - __constants__ = ['inplace'] + __constants__ = ["inplace"] inplace: bool def __init__(self, inplace: bool = False): @@ -443,9 +471,10 @@ def forward(self, input: Tensor) -> Tensor: return F.mish(input, inplace=self.inplace) def extra_repr(self) -> str: - inplace_str = 'inplace=True' if self.inplace else '' + inplace_str = "inplace=True" if self.inplace else "" return inplace_str + class Hardswish(Module): r"""Applies the Hardswish function, element-wise. @@ -476,11 +505,11 @@ class Hardswish(Module): >>> output = m(input) """ - __constants__ = ['inplace'] + __constants__ = ["inplace"] inplace: bool - def __init__(self, inplace : bool = False) -> None: + def __init__(self, inplace: bool = False) -> None: super().__init__() self.inplace = inplace @@ -519,11 +548,11 @@ class ELU(Module): >>> output = m(input) """ - __constants__ = ['alpha', 'inplace'] + __constants__ = ["alpha", "inplace"] alpha: float inplace: bool - def __init__(self, alpha: float = 1., inplace: bool = False) -> None: + def __init__(self, alpha: float = 1.0, inplace: bool = False) -> None: super().__init__() self.alpha = alpha self.inplace = inplace @@ -532,8 +561,8 @@ def forward(self, input: Tensor) -> Tensor: return F.elu(input, self.alpha, self.inplace) def extra_repr(self) -> str: - inplace_str = ', inplace=True' if self.inplace else '' - return f'alpha={self.alpha}{inplace_str}' + inplace_str = ", inplace=True" if self.inplace else "" + return f"alpha={self.alpha}{inplace_str}" class CELU(Module): @@ -564,11 +593,11 @@ class CELU(Module): https://arxiv.org/abs/1704.07483 """ - __constants__ = ['alpha', 'inplace'] + __constants__ = ["alpha", "inplace"] alpha: float inplace: bool - def __init__(self, alpha: float = 1., inplace: bool = False) -> None: + def __init__(self, alpha: float = 1.0, inplace: bool = False) -> None: super().__init__() self.alpha = alpha self.inplace = inplace @@ -577,8 +606,8 @@ def forward(self, input: Tensor) -> Tensor: return F.celu(input, self.alpha, self.inplace) def extra_repr(self) -> str: - inplace_str = ', inplace=True' if self.inplace else '' - return f'alpha={self.alpha}{inplace_str}' + inplace_str = ", inplace=True" if self.inplace else "" + return f"alpha={self.alpha}{inplace_str}" class SELU(Module): @@ -616,7 +645,7 @@ class SELU(Module): .. _Self-Normalizing Neural Networks: https://arxiv.org/abs/1706.02515 """ - __constants__ = ['inplace'] + __constants__ = ["inplace"] inplace: bool def __init__(self, inplace: bool = False) -> None: @@ -627,7 +656,7 @@ def forward(self, input: Tensor) -> Tensor: return F.selu(input, self.inplace) def extra_repr(self) -> str: - inplace_str = 'inplace=True' if self.inplace else '' + inplace_str = "inplace=True" if self.inplace else "" return inplace_str @@ -652,7 +681,7 @@ class GLU(Module): >>> output = m(input) """ - __constants__ = ['dim'] + __constants__ = ["dim"] dim: int def __init__(self, dim: int = -1) -> None: @@ -663,7 +692,7 @@ def forward(self, input: Tensor) -> Tensor: return F.glu(input, self.dim) def extra_repr(self) -> str: - return f'dim={self.dim}' + return f"dim={self.dim}" class GELU(Module): @@ -694,10 +723,10 @@ class GELU(Module): >>> output = m(input) """ - __constants__ = ['approximate'] + __constants__ = ["approximate"] approximate: str - def __init__(self, approximate: str = 'none') -> None: + def __init__(self, approximate: str = "none") -> None: super().__init__() self.approximate = approximate @@ -705,7 +734,7 @@ def forward(self, input: Tensor) -> Tensor: return F.gelu(input, approximate=self.approximate) def extra_repr(self) -> str: - return f'approximate={repr(self.approximate)}' + return f"approximate={repr(self.approximate)}" class Hardshrink(Module): @@ -737,7 +766,7 @@ class Hardshrink(Module): >>> output = m(input) """ - __constants__ = ['lambd'] + __constants__ = ["lambd"] lambd: float def __init__(self, lambd: float = 0.5) -> None: @@ -748,7 +777,7 @@ def forward(self, input: Tensor) -> Tensor: return F.hardshrink(input, self.lambd) def extra_repr(self) -> str: - return f'{self.lambd}' + return f"{self.lambd}" class LeakyReLU(Module): @@ -786,7 +815,7 @@ class LeakyReLU(Module): >>> output = m(input) """ - __constants__ = ['inplace', 'negative_slope'] + __constants__ = ["inplace", "negative_slope"] inplace: bool negative_slope: float @@ -799,8 +828,8 @@ def forward(self, input: Tensor) -> Tensor: return F.leaky_relu(input, self.negative_slope, self.inplace) def extra_repr(self) -> str: - inplace_str = ', inplace=True' if self.inplace else '' - return f'negative_slope={self.negative_slope}{inplace_str}' + inplace_str = ", inplace=True" if self.inplace else "" + return f"negative_slope={self.negative_slope}{inplace_str}" class LogSigmoid(Module): @@ -855,7 +884,7 @@ class Softplus(Module): >>> output = m(input) """ - __constants__ = ['beta', 'threshold'] + __constants__ = ["beta", "threshold"] beta: float threshold: float @@ -868,7 +897,7 @@ def forward(self, input: Tensor) -> Tensor: return F.softplus(input, self.beta, self.threshold) def extra_repr(self) -> str: - return f'beta={self.beta}, threshold={self.threshold}' + return f"beta={self.beta}, threshold={self.threshold}" class Softshrink(Module): @@ -898,7 +927,7 @@ class Softshrink(Module): >>> output = m(input) """ - __constants__ = ['lambd'] + __constants__ = ["lambd"] lambd: float def __init__(self, lambd: float = 0.5) -> None: @@ -914,7 +943,11 @@ def extra_repr(self) -> str: def _check_arg_device(x: Optional[torch.Tensor]) -> bool: if x is not None: - return x.device.type in ["cpu", "cuda", torch.utils.backend_registration._privateuse1_backend_name] + return x.device.type in [ + "cpu", + "cuda", + torch.utils.backend_registration._privateuse1_backend_name, + ] return True @@ -926,8 +959,13 @@ def _arg_requires_grad(x: Optional[torch.Tensor]) -> bool: def _is_make_fx_tracing(): if not torch.jit.is_scripting(): - torch_dispatch_mode_stack = torch.utils._python_dispatch._get_current_dispatch_mode_stack() - return any(type(x) == torch.fx.experimental.proxy_tensor.ProxyTorchDispatchMode for x in torch_dispatch_mode_stack) + torch_dispatch_mode_stack = ( + torch.utils._python_dispatch._get_current_dispatch_mode_stack() + ) + return any( + type(x) == torch.fx.experimental.proxy_tensor.ProxyTorchDispatchMode + for x in torch_dispatch_mode_stack + ) else: return False @@ -995,18 +1033,30 @@ class MultiheadAttention(Module): """ - __constants__ = ['batch_first'] + __constants__ = ["batch_first"] bias_k: Optional[torch.Tensor] bias_v: Optional[torch.Tensor] - def __init__(self, embed_dim, num_heads, dropout=0., bias=True, add_bias_kv=False, add_zero_attn=False, - kdim=None, vdim=None, batch_first=False, device=None, dtype=None) -> None: + def __init__( + self, + embed_dim, + num_heads, + dropout=0.0, + bias=True, + add_bias_kv=False, + add_zero_attn=False, + kdim=None, + vdim=None, + batch_first=False, + device=None, + dtype=None, + ) -> None: if embed_dim <= 0 or num_heads <= 0: raise ValueError( f"embed_dim and num_heads must be greater than 0," f" got embed_dim={embed_dim} and num_heads={num_heads} instead" ) - factory_kwargs = {'device': device, 'dtype': dtype} + factory_kwargs = {"device": device, "dtype": dtype} super().__init__() self.embed_dim = embed_dim self.kdim = kdim if kdim is not None else embed_dim @@ -1017,24 +1067,36 @@ def __init__(self, embed_dim, num_heads, dropout=0., bias=True, add_bias_kv=Fals self.dropout = dropout self.batch_first = batch_first self.head_dim = embed_dim // num_heads - assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads" + assert ( + self.head_dim * num_heads == self.embed_dim + ), "embed_dim must be divisible by num_heads" if not self._qkv_same_embed_dim: - self.q_proj_weight = Parameter(torch.empty((embed_dim, embed_dim), **factory_kwargs)) - self.k_proj_weight = Parameter(torch.empty((embed_dim, self.kdim), **factory_kwargs)) - self.v_proj_weight = Parameter(torch.empty((embed_dim, self.vdim), **factory_kwargs)) - self.register_parameter('in_proj_weight', None) + self.q_proj_weight = Parameter( + torch.empty((embed_dim, embed_dim), **factory_kwargs) + ) + self.k_proj_weight = Parameter( + torch.empty((embed_dim, self.kdim), **factory_kwargs) + ) + self.v_proj_weight = Parameter( + torch.empty((embed_dim, self.vdim), **factory_kwargs) + ) + self.register_parameter("in_proj_weight", None) else: - self.in_proj_weight = Parameter(torch.empty((3 * embed_dim, embed_dim), **factory_kwargs)) - self.register_parameter('q_proj_weight', None) - self.register_parameter('k_proj_weight', None) - self.register_parameter('v_proj_weight', None) + self.in_proj_weight = Parameter( + torch.empty((3 * embed_dim, embed_dim), **factory_kwargs) + ) + self.register_parameter("q_proj_weight", None) + self.register_parameter("k_proj_weight", None) + self.register_parameter("v_proj_weight", None) if bias: self.in_proj_bias = Parameter(torch.empty(3 * embed_dim, **factory_kwargs)) else: - self.register_parameter('in_proj_bias', None) - self.out_proj = NonDynamicallyQuantizableLinear(embed_dim, embed_dim, bias=bias, **factory_kwargs) + self.register_parameter("in_proj_bias", None) + self.out_proj = NonDynamicallyQuantizableLinear( + embed_dim, embed_dim, bias=bias, **factory_kwargs + ) if add_bias_kv: self.bias_k = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs)) @@ -1055,8 +1117,8 @@ def _reset_parameters(self): xavier_uniform_(self.v_proj_weight) if self.in_proj_bias is not None: - constant_(self.in_proj_bias, 0.) - constant_(self.out_proj.bias, 0.) + constant_(self.in_proj_bias, 0.0) + constant_(self.out_proj.bias, 0.0) if self.bias_k is not None: xavier_normal_(self.bias_k) if self.bias_v is not None: @@ -1064,84 +1126,88 @@ def _reset_parameters(self): def __setstate__(self, state): # Support loading old MultiheadAttention checkpoints generated by v1.1.0 - if '_qkv_same_embed_dim' not in state: - state['_qkv_same_embed_dim'] = True + if "_qkv_same_embed_dim" not in state: + state["_qkv_same_embed_dim"] = True super().__setstate__(state) def forward( - self, - query: Tensor, - key: Tensor, - value: Tensor, - key_padding_mask: Optional[Tensor] = None, - need_weights: bool = True, - attn_mask: Optional[Tensor] = None, - average_attn_weights: bool = True, - is_causal : bool = False) -> Tuple[Tensor, Optional[Tensor]]: + self, + query: Tensor, + key: Tensor, + value: Tensor, + key_padding_mask: Optional[Tensor] = None, + need_weights: bool = True, + attn_mask: Optional[Tensor] = None, + average_attn_weights: bool = True, + is_causal: bool = False, + ) -> Tuple[Tensor, Optional[Tensor]]: r"""Compute attention outputs using query, key, and value embeddings. - Supports optional parameters for padding, masks and attention weights. + Supports optional parameters for padding, masks and attention weights. - Args: - query: Query embeddings of shape :math:`(L, E_q)` for unbatched input, :math:`(L, N, E_q)` when ``batch_first=False`` - or :math:`(N, L, E_q)` when ``batch_first=True``, where :math:`L` is the target sequence length, - :math:`N` is the batch size, and :math:`E_q` is the query embedding dimension ``embed_dim``. - Queries are compared against key-value pairs to produce the output. - See "Attention Is All You Need" for more details. - key: Key embeddings of shape :math:`(S, E_k)` for unbatched input, :math:`(S, N, E_k)` when ``batch_first=False`` - or :math:`(N, S, E_k)` when ``batch_first=True``, where :math:`S` is the source sequence length, - :math:`N` is the batch size, and :math:`E_k` is the key embedding dimension ``kdim``. - See "Attention Is All You Need" for more details. - value: Value embeddings of shape :math:`(S, E_v)` for unbatched input, :math:`(S, N, E_v)` when - ``batch_first=False`` or :math:`(N, S, E_v)` when ``batch_first=True``, where :math:`S` is the source - sequence length, :math:`N` is the batch size, and :math:`E_v` is the value embedding dimension ``vdim``. - See "Attention Is All You Need" for more details. - key_padding_mask: If specified, a mask of shape :math:`(N, S)` indicating which elements within ``key`` - to ignore for the purpose of attention (i.e. treat as "padding"). For unbatched `query`, shape should be :math:`(S)`. - Binary and float masks are supported. - For a binary mask, a ``True`` value indicates that the corresponding ``key`` value will be ignored for - the purpose of attention. For a float mask, it will be directly added to the corresponding ``key`` value. - need_weights: If specified, returns ``attn_output_weights`` in addition to ``attn_outputs``. - Set ``need_weights=False`` to use the optimized ``scaled_dot_product_attention`` - and achieve the best performance for MHA. - Default: ``True``. - attn_mask: If specified, a 2D or 3D mask preventing attention to certain positions. Must be of shape - :math:`(L, S)` or :math:`(N\cdot\text{num\_heads}, L, S)`, where :math:`N` is the batch size, - :math:`L` is the target sequence length, and :math:`S` is the source sequence length. A 2D mask will be - broadcasted across the batch while a 3D mask allows for a different mask for each entry in the batch. - Binary and float masks are supported. For a binary mask, a ``True`` value indicates that the - corresponding position is not allowed to attend. For a float mask, the mask values will be added to - the attention weight. - If both attn_mask and key_padding_mask are supplied, their types should match. - average_attn_weights: If true, indicates that the returned ``attn_weights`` should be averaged across - heads. Otherwise, ``attn_weights`` are provided separately per head. Note that this flag only has an - effect when ``need_weights=True``. Default: ``True`` (i.e. average weights across heads) - is_causal: If specified, applies a causal mask as attention mask. - Default: ``False``. - Warning: - ``is_causal`` provides a hint that ``attn_mask`` is the - causal mask. Providing incorrect hints can result in - incorrect execution, including forward and backward - compatibility. - - Outputs: - - **attn_output** - Attention outputs of shape :math:`(L, E)` when input is unbatched, - :math:`(L, N, E)` when ``batch_first=False`` or :math:`(N, L, E)` when ``batch_first=True``, - where :math:`L` is the target sequence length, :math:`N` is the batch size, and :math:`E` is the - embedding dimension ``embed_dim``. - - **attn_output_weights** - Only returned when ``need_weights=True``. If ``average_attn_weights=True``, - returns attention weights averaged across heads of shape :math:`(L, S)` when input is unbatched or - :math:`(N, L, S)`, where :math:`N` is the batch size, :math:`L` is the target sequence length, and - :math:`S` is the source sequence length. If ``average_attn_weights=False``, returns attention weights per - head of shape :math:`(\text{num\_heads}, L, S)` when input is unbatched or :math:`(N, \text{num\_heads}, L, S)`. - - .. note:: - `batch_first` argument is ignored for unbatched inputs. - """ - why_not_fast_path = '' - if ((attn_mask is not None and torch.is_floating_point(attn_mask)) - or (key_padding_mask is not None) and torch.is_floating_point(key_padding_mask)): + Args: + query: Query embeddings of shape :math:`(L, E_q)` for unbatched input, :math:`(L, N, E_q)` when ``batch_first=False`` + or :math:`(N, L, E_q)` when ``batch_first=True``, where :math:`L` is the target sequence length, + :math:`N` is the batch size, and :math:`E_q` is the query embedding dimension ``embed_dim``. + Queries are compared against key-value pairs to produce the output. + See "Attention Is All You Need" for more details. + key: Key embeddings of shape :math:`(S, E_k)` for unbatched input, :math:`(S, N, E_k)` when ``batch_first=False`` + or :math:`(N, S, E_k)` when ``batch_first=True``, where :math:`S` is the source sequence length, + :math:`N` is the batch size, and :math:`E_k` is the key embedding dimension ``kdim``. + See "Attention Is All You Need" for more details. + value: Value embeddings of shape :math:`(S, E_v)` for unbatched input, :math:`(S, N, E_v)` when + ``batch_first=False`` or :math:`(N, S, E_v)` when ``batch_first=True``, where :math:`S` is the source + sequence length, :math:`N` is the batch size, and :math:`E_v` is the value embedding dimension ``vdim``. + See "Attention Is All You Need" for more details. + key_padding_mask: If specified, a mask of shape :math:`(N, S)` indicating which elements within ``key`` + to ignore for the purpose of attention (i.e. treat as "padding"). For unbatched `query`, shape should be :math:`(S)`. + Binary and float masks are supported. + For a binary mask, a ``True`` value indicates that the corresponding ``key`` value will be ignored for + the purpose of attention. For a float mask, it will be directly added to the corresponding ``key`` value. + need_weights: If specified, returns ``attn_output_weights`` in addition to ``attn_outputs``. + Set ``need_weights=False`` to use the optimized ``scaled_dot_product_attention`` + and achieve the best performance for MHA. + Default: ``True``. + attn_mask: If specified, a 2D or 3D mask preventing attention to certain positions. Must be of shape + :math:`(L, S)` or :math:`(N\cdot\text{num\_heads}, L, S)`, where :math:`N` is the batch size, + :math:`L` is the target sequence length, and :math:`S` is the source sequence length. A 2D mask will be + broadcasted across the batch while a 3D mask allows for a different mask for each entry in the batch. + Binary and float masks are supported. For a binary mask, a ``True`` value indicates that the + corresponding position is not allowed to attend. For a float mask, the mask values will be added to + the attention weight. + If both attn_mask and key_padding_mask are supplied, their types should match. + average_attn_weights: If true, indicates that the returned ``attn_weights`` should be averaged across + heads. Otherwise, ``attn_weights`` are provided separately per head. Note that this flag only has an + effect when ``need_weights=True``. Default: ``True`` (i.e. average weights across heads) + is_causal: If specified, applies a causal mask as attention mask. + Default: ``False``. + Warning: + ``is_causal`` provides a hint that ``attn_mask`` is the + causal mask. Providing incorrect hints can result in + incorrect execution, including forward and backward + compatibility. + + Outputs: + - **attn_output** - Attention outputs of shape :math:`(L, E)` when input is unbatched, + :math:`(L, N, E)` when ``batch_first=False`` or :math:`(N, L, E)` when ``batch_first=True``, + where :math:`L` is the target sequence length, :math:`N` is the batch size, and :math:`E` is the + embedding dimension ``embed_dim``. + - **attn_output_weights** - Only returned when ``need_weights=True``. If ``average_attn_weights=True``, + returns attention weights averaged across heads of shape :math:`(L, S)` when input is unbatched or + :math:`(N, L, S)`, where :math:`N` is the batch size, :math:`L` is the target sequence length, and + :math:`S` is the source sequence length. If ``average_attn_weights=False``, returns attention weights per + head of shape :math:`(\text{num\_heads}, L, S)` when input is unbatched or :math:`(N, \text{num\_heads}, L, S)`. + + .. note:: + `batch_first` argument is ignored for unbatched inputs. + """ # noqa: B950 + why_not_fast_path = "" + if ( + (attn_mask is not None and torch.is_floating_point(attn_mask)) + or (key_padding_mask is not None) + and torch.is_floating_point(key_padding_mask) + ): why_not_fast_path = "floating-point masks are not supported for fast path." is_batched = query.dim() == 3 @@ -1151,7 +1217,7 @@ def forward( mask_name="key_padding_mask", other_type=F._none_or_dtype(attn_mask), other_name="attn_mask", - target_type=query.dtype + target_type=query.dtype, ) attn_mask = F._canonical_mask( @@ -1168,7 +1234,9 @@ def forward( if not is_fastpath_enabled: why_not_fast_path = "torch.backends.mha.get_fastpath_enabled() was not True" elif not is_batched: - why_not_fast_path = f"input not batched; expected query.dim() of 3 but got {query.dim()}" + why_not_fast_path = ( + f"input not batched; expected query.dim() of 3 but got {query.dim()}" + ) elif query is not key or key is not value: # When lifting this restriction, don't forget to either # enforce that the dtypes all match or test cases where @@ -1195,7 +1263,9 @@ def forward( why_not_fast_path = "add_zero_attn was enabled" elif not self._qkv_same_embed_dim: why_not_fast_path = "_qkv_same_embed_dim was not True" - elif query.is_nested and (key_padding_mask is not None or attn_mask is not None): + elif query.is_nested and ( + key_padding_mask is not None or attn_mask is not None + ): why_not_fast_path = "supplying both src_key_padding_mask and src_mask at the same time \ is not supported with NestedTensor input" elif torch.is_autocast_enabled(): @@ -1218,13 +1288,21 @@ def forward( elif _is_make_fx_tracing(): why_not_fast_path = "we are running make_fx tracing" elif not all(_check_arg_device(x) for x in tensor_args): - why_not_fast_path = ("some Tensor argument's device is neither one of " - f"cpu, cuda or {torch.utils.backend_registration._privateuse1_backend_name}") - elif torch.is_grad_enabled() and any(_arg_requires_grad(x) for x in tensor_args): - why_not_fast_path = ("grad is enabled and at least one of query or the " - "input/output projection weights or biases requires_grad") + why_not_fast_path = ( + "some Tensor argument's device is neither one of " + f"cpu, cuda or {torch.utils.backend_registration._privateuse1_backend_name}" + ) + elif torch.is_grad_enabled() and any( + _arg_requires_grad(x) for x in tensor_args + ): + why_not_fast_path = ( + "grad is enabled and at least one of query or the " + "input/output projection weights or biases requires_grad" + ) if not why_not_fast_path: - merged_mask, mask_type = self.merge_masks(attn_mask, key_padding_mask, query) + merged_mask, mask_type = self.merge_masks( + attn_mask, key_padding_mask, query + ) if self.in_proj_bias is not None and self.in_proj_weight is not None: return torch._native_multi_head_attention( @@ -1240,11 +1318,14 @@ def forward( merged_mask, need_weights, average_attn_weights, - mask_type) + mask_type, + ) any_nested = query.is_nested or key.is_nested or value.is_nested - assert not any_nested, ("MultiheadAttention does not support NestedTensor outside of its fast path. " + - f"The fast path was not hit because {why_not_fast_path}") + assert not any_nested, ( + "MultiheadAttention does not support NestedTensor outside of its fast path. " + + f"The fast path was not hit because {why_not_fast_path}" + ) if self.batch_first and is_batched: # make sure that the transpose op does not affect the "is" property @@ -1259,37 +1340,63 @@ def forward( if not self._qkv_same_embed_dim: attn_output, attn_output_weights = F.multi_head_attention_forward( - query, key, value, self.embed_dim, self.num_heads, - self.in_proj_weight, self.in_proj_bias, - self.bias_k, self.bias_v, self.add_zero_attn, - self.dropout, self.out_proj.weight, self.out_proj.bias, + query, + key, + value, + self.embed_dim, + self.num_heads, + self.in_proj_weight, + self.in_proj_bias, + self.bias_k, + self.bias_v, + self.add_zero_attn, + self.dropout, + self.out_proj.weight, + self.out_proj.bias, training=self.training, - key_padding_mask=key_padding_mask, need_weights=need_weights, + key_padding_mask=key_padding_mask, + need_weights=need_weights, attn_mask=attn_mask, use_separate_proj_weight=True, - q_proj_weight=self.q_proj_weight, k_proj_weight=self.k_proj_weight, + q_proj_weight=self.q_proj_weight, + k_proj_weight=self.k_proj_weight, v_proj_weight=self.v_proj_weight, average_attn_weights=average_attn_weights, - is_causal=is_causal) + is_causal=is_causal, + ) else: attn_output, attn_output_weights = F.multi_head_attention_forward( - query, key, value, self.embed_dim, self.num_heads, - self.in_proj_weight, self.in_proj_bias, - self.bias_k, self.bias_v, self.add_zero_attn, - self.dropout, self.out_proj.weight, self.out_proj.bias, + query, + key, + value, + self.embed_dim, + self.num_heads, + self.in_proj_weight, + self.in_proj_bias, + self.bias_k, + self.bias_v, + self.add_zero_attn, + self.dropout, + self.out_proj.weight, + self.out_proj.bias, training=self.training, key_padding_mask=key_padding_mask, need_weights=need_weights, attn_mask=attn_mask, average_attn_weights=average_attn_weights, - is_causal=is_causal) + is_causal=is_causal, + ) if self.batch_first and is_batched: return attn_output.transpose(1, 0), attn_output_weights else: return attn_output, attn_output_weights - def merge_masks(self, attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor], - query: Tensor) -> Tuple[Optional[Tensor], Optional[int]]: + def merge_masks( + self, + attn_mask: Optional[Tensor], + key_padding_mask: Optional[Tensor], + query: Tensor, + ) -> Tuple[Optional[Tensor], Optional[int]]: r"""Determine mask type and combine masks if necessary. If only one mask is provided, that mask @@ -1320,11 +1427,15 @@ def merge_masks(self, attn_mask: Optional[Tensor], key_padding_mask: Optional[Te if attn_mask.dim() == 3: attn_mask_expanded = attn_mask.view(batch_size, -1, seq_len, seq_len) else: # attn_mask.dim() == 2: - attn_mask_expanded = attn_mask.view(1, 1, seq_len, seq_len).expand(batch_size, self.num_heads, -1, -1) + attn_mask_expanded = attn_mask.view(1, 1, seq_len, seq_len).expand( + batch_size, self.num_heads, -1, -1 + ) merged_mask = attn_mask_expanded if key_padding_mask is not None: - key_padding_mask_expanded = key_padding_mask.view(batch_size, 1, 1, seq_len).expand(-1, self.num_heads, -1, -1) + key_padding_mask_expanded = key_padding_mask.view( + batch_size, 1, 1, seq_len + ).expand(-1, self.num_heads, -1, -1) merged_mask = attn_mask_expanded + key_padding_mask_expanded # no attn_mask and no key_padding_mask, returns None, None @@ -1381,12 +1492,13 @@ class PReLU(Module): >>> output = m(input) """ - __constants__ = ['num_parameters'] + __constants__ = ["num_parameters"] num_parameters: int - def __init__(self, num_parameters: int = 1, init: float = 0.25, - device=None, dtype=None) -> None: - factory_kwargs = {'device': device, 'dtype': dtype} + def __init__( + self, num_parameters: int = 1, init: float = 0.25, device=None, dtype=None + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} self.num_parameters = num_parameters super().__init__() self.init = init @@ -1400,7 +1512,7 @@ def forward(self, input: Tensor) -> Tensor: return F.prelu(input, self.weight) def extra_repr(self) -> str: - return f'num_parameters={self.num_parameters}' + return f"num_parameters={self.num_parameters}" class Softsign(Module): @@ -1480,7 +1592,7 @@ class Softmin(Module): >>> output = m(input) """ - __constants__ = ['dim'] + __constants__ = ["dim"] dim: Optional[int] def __init__(self, dim: Optional[int] = None) -> None: @@ -1489,14 +1601,15 @@ def __init__(self, dim: Optional[int] = None) -> None: def __setstate__(self, state): super().__setstate__(state) - if not hasattr(self, 'dim'): + if not hasattr(self, "dim"): self.dim = None def forward(self, input: Tensor) -> Tensor: return F.softmin(input, self.dim, _stacklevel=5) def extra_repr(self): - return f'dim={self.dim}' + return f"dim={self.dim}" + class Softmax(Module): r"""Applies the Softmax function to an n-dimensional input Tensor. @@ -1538,7 +1651,7 @@ class Softmax(Module): """ - __constants__ = ['dim'] + __constants__ = ["dim"] dim: Optional[int] def __init__(self, dim: Optional[int] = None) -> None: @@ -1547,14 +1660,14 @@ def __init__(self, dim: Optional[int] = None) -> None: def __setstate__(self, state): super().__setstate__(state) - if not hasattr(self, 'dim'): + if not hasattr(self, "dim"): self.dim = None def forward(self, input: Tensor) -> Tensor: return F.softmax(input, self.dim, _stacklevel=5) def extra_repr(self) -> str: - return f'dim={self.dim}' + return f"dim={self.dim}" class Softmax2d(Module): @@ -1614,7 +1727,7 @@ class LogSoftmax(Module): >>> output = m(input) """ - __constants__ = ['dim'] + __constants__ = ["dim"] dim: Optional[int] def __init__(self, dim: Optional[int] = None) -> None: @@ -1623,11 +1736,11 @@ def __init__(self, dim: Optional[int] = None) -> None: def __setstate__(self, state): super().__setstate__(state) - if not hasattr(self, 'dim'): + if not hasattr(self, "dim"): self.dim = None def forward(self, input: Tensor) -> Tensor: return F.log_softmax(input, self.dim, _stacklevel=5) def extra_repr(self): - return f'dim={self.dim}' + return f"dim={self.dim}" diff --git a/torch/nn/modules/adaptive.py b/torch/nn/modules/adaptive.py index a6c2da5f596f8d..2a3b24b9b280bf 100644 --- a/torch/nn/modules/adaptive.py +++ b/torch/nn/modules/adaptive.py @@ -1,19 +1,20 @@ # mypy: allow-untyped-defs from collections import namedtuple +from typing import List, Sequence import torch - +import torch.nn.functional as F from torch import Tensor -from typing import List, Sequence -from . import Sequential, ModuleList, Linear +from .container import ModuleList, Sequential +from .linear import Linear from .module import Module -from ..functional import log_softmax -__all__ = ['AdaptiveLogSoftmaxWithLoss'] -_ASMoutput = namedtuple('_ASMoutput', ['output', 'loss']) +__all__ = ["AdaptiveLogSoftmaxWithLoss"] + +_ASMoutput = namedtuple("_ASMoutput", ["output", "loss"]) class AdaptiveLogSoftmaxWithLoss(Module): @@ -117,28 +118,31 @@ def __init__( in_features: int, n_classes: int, cutoffs: Sequence[int], - div_value: float = 4., + div_value: float = 4.0, head_bias: bool = False, device=None, - dtype=None + dtype=None, ) -> None: - factory_kwargs = {'device': device, 'dtype': dtype} + factory_kwargs = {"device": device, "dtype": dtype} super().__init__() cutoffs = list(cutoffs) - if (len(cutoffs) == 0): + if len(cutoffs) == 0: raise ValueError("cutoffs should be a sequence of length larger than 0") - if (cutoffs != sorted(cutoffs)) \ - or (min(cutoffs) <= 0) \ - or (max(cutoffs) > (n_classes - 1)) \ - or (len(set(cutoffs)) != len(cutoffs)) \ - or any(int(c) != c for c in cutoffs): - - raise ValueError("cutoffs should be a sequence of unique, positive " - "integers sorted in an increasing order, where " - "each value is between 1 and n_classes-1") + if ( + (cutoffs != sorted(cutoffs)) + or (min(cutoffs) <= 0) + or (max(cutoffs) > (n_classes - 1)) + or (len(set(cutoffs)) != len(cutoffs)) + or any(int(c) != c for c in cutoffs) + ): + raise ValueError( + "cutoffs should be a sequence of unique, positive " + "integers sorted in an increasing order, where " + "each value is between 1 and n_classes-1" + ) self.in_features = in_features self.n_classes = n_classes @@ -150,12 +154,12 @@ def __init__( self.n_clusters = len(self.cutoffs) - 1 self.head_size = self.shortlist_size + self.n_clusters - self.head = Linear(self.in_features, self.head_size, bias=self.head_bias, - **factory_kwargs) + self.head = Linear( + self.in_features, self.head_size, bias=self.head_bias, **factory_kwargs + ) self.tail = ModuleList() for i in range(self.n_clusters): - hsz = int(self.in_features // (self.div_value ** (i + 1))) osz = self.cutoffs[i + 1] - self.cutoffs[i] @@ -177,18 +181,27 @@ def forward(self, input_: Tensor, target_: Tensor) -> _ASMoutput: if targ_dim == 1: if input_.size(0) != target_.size(0): - raise RuntimeError('Input and target should have the same size ' - 'in the batch dimension.') + raise RuntimeError( + "Input and target should have the same size " + "in the batch dimension." + ) if input_.dim() != 2: - raise RuntimeError('1D target tensor expects 2D input tensors, ' - 'but found inputs with size', input_.size()) + raise RuntimeError( + "1D target tensor expects 2D input tensors, " + "but found inputs with size", + input_.size(), + ) elif targ_dim == 0: if input_.dim() != 1: - raise RuntimeError('0D target tensor expects 1D input tensors, ' - 'but found inputs with size', input_.size()) + raise RuntimeError( + "0D target tensor expects 1D input tensors, " + "but found inputs with size", + input_.size(), + ) else: - raise RuntimeError('0D or 1D target tensor expected, ' - 'multi-target not supported') + raise RuntimeError( + "0D or 1D target tensor expected, " "multi-target not supported" + ) is_batched = targ_dim > 0 input = input_ if is_batched else input_.unsqueeze(0) @@ -202,7 +215,6 @@ def forward(self, input_: Tensor, target_: Tensor) -> _ASMoutput: cutoff_values = [0] + self.cutoffs for i in range(len(cutoff_values) - 1): - low_idx = cutoff_values[i] high_idx = cutoff_values[i + 1] @@ -223,19 +235,21 @@ def forward(self, input_: Tensor, target_: Tensor) -> _ASMoutput: cluster_index = self.shortlist_size + i - 1 gather_inds.index_fill_(0, row_indices, cluster_index) - cluster_logprob = log_softmax(cluster_output, dim=1) + cluster_logprob = F.log_softmax(cluster_output, dim=1) local_logprob = cluster_logprob.gather(1, relative_target.unsqueeze(1)) output.index_copy_(0, row_indices, local_logprob.squeeze(1)) used_rows += row_indices.numel() if used_rows != batch_size: - raise RuntimeError(f"Target values should be in [0, {self.n_classes - 1}], " - f"but values in range [{target.min().item()}, {target.max().item()}] " - "were found. ") + raise RuntimeError( + f"Target values should be in [0, {self.n_classes - 1}], " + f"but values in range [{target.min().item()}, {target.max().item()}] " + "were found. " + ) head_output = self.head(input) - head_logprob = log_softmax(head_output, dim=1) + head_logprob = F.log_softmax(head_output, dim=1) output += head_logprob.gather(1, gather_inds.unsqueeze(1)).squeeze() loss = (-output).mean() @@ -247,14 +261,16 @@ def forward(self, input_: Tensor, target_: Tensor) -> _ASMoutput: def _get_full_log_prob(self, input, head_output): """Given input tensor, and output of ``self.head``, compute the log of the full distribution.""" out = input.new_empty((head_output.size(0), self.n_classes)) - head_logprob = log_softmax(head_output, dim=1) + head_logprob = F.log_softmax(head_output, dim=1) - out[:, :self.shortlist_size] = head_logprob[:, :self.shortlist_size] + out[:, : self.shortlist_size] = head_logprob[:, : self.shortlist_size] for i, (start_idx, stop_idx) in enumerate(zip(self.cutoffs, self.cutoffs[1:])): cluster_output = self.tail[i](input) - cluster_logprob = log_softmax(cluster_output, dim=1) - output_logprob = cluster_logprob + head_logprob[:, self.shortlist_size + i].unsqueeze(1) + cluster_logprob = F.log_softmax(cluster_output, dim=1) + output_logprob = cluster_logprob + head_logprob[ + :, self.shortlist_size + i + ].unsqueeze(1) out[:, start_idx:stop_idx] = output_logprob @@ -296,7 +312,7 @@ def predict(self, input: Tensor) -> Tensor: """ head_output = self.head(input) output = torch.argmax(head_output, dim=1) - not_in_shortlist = (output >= self.shortlist_size) + not_in_shortlist = output >= self.shortlist_size all_in_shortlist = not (not_in_shortlist.any()) if all_in_shortlist: @@ -307,7 +323,8 @@ def predict(self, input: Tensor) -> Tensor: return torch.argmax(log_prob, dim=1) else: - log_prob = self._get_full_log_prob(input[not_in_shortlist], - head_output[not_in_shortlist]) + log_prob = self._get_full_log_prob( + input[not_in_shortlist], head_output[not_in_shortlist] + ) output[not_in_shortlist] = torch.argmax(log_prob, dim=1) return output diff --git a/torch/nn/modules/batchnorm.py b/torch/nn/modules/batchnorm.py index 75c8b5504d46b2..8ba9ad24f1165f 100644 --- a/torch/nn/modules/batchnorm.py +++ b/torch/nn/modules/batchnorm.py @@ -1,18 +1,25 @@ # mypy: allow-untyped-defs -from typing import Optional, Any +from typing import Any, Optional import torch from torch import Tensor -from torch.nn.parameter import Parameter, UninitializedParameter, UninitializedBuffer +from torch.nn import functional as F, init +from torch.nn.parameter import Parameter, UninitializedBuffer, UninitializedParameter -from .. import functional as F -from .. import init from ._functions import SyncBatchNorm as sync_batch_norm from .lazy import LazyModuleMixin from .module import Module -__all__ = ['BatchNorm1d', 'LazyBatchNorm1d', 'BatchNorm2d', 'LazyBatchNorm2d', 'BatchNorm3d', - 'LazyBatchNorm3d', 'SyncBatchNorm'] + +__all__ = [ + "BatchNorm1d", + "LazyBatchNorm1d", + "BatchNorm2d", + "LazyBatchNorm2d", + "BatchNorm3d", + "LazyBatchNorm3d", + "SyncBatchNorm", +] class _NormBase(Module): @@ -36,9 +43,9 @@ def __init__( affine: bool = True, track_running_stats: bool = True, device=None, - dtype=None + dtype=None, ) -> None: - factory_kwargs = {'device': device, 'dtype': dtype} + factory_kwargs = {"device": device, "dtype": dtype} super().__init__() self.num_features = num_features self.eps = eps @@ -52,13 +59,22 @@ def __init__( self.register_parameter("weight", None) self.register_parameter("bias", None) if self.track_running_stats: - self.register_buffer('running_mean', torch.zeros(num_features, **factory_kwargs)) - self.register_buffer('running_var', torch.ones(num_features, **factory_kwargs)) + self.register_buffer( + "running_mean", torch.zeros(num_features, **factory_kwargs) + ) + self.register_buffer( + "running_var", torch.ones(num_features, **factory_kwargs) + ) self.running_mean: Optional[Tensor] self.running_var: Optional[Tensor] - self.register_buffer('num_batches_tracked', - torch.tensor(0, dtype=torch.long, - **{k: v for k, v in factory_kwargs.items() if k != 'dtype'})) + self.register_buffer( + "num_batches_tracked", + torch.tensor( + 0, + dtype=torch.long, + **{k: v for k, v in factory_kwargs.items() if k != "dtype"}, + ), + ) self.num_batches_tracked: Optional[Tensor] else: self.register_buffer("running_mean", None) @@ -108,7 +124,8 @@ def _load_from_state_dict( if num_batches_tracked_key not in state_dict: state_dict[num_batches_tracked_key] = ( self.num_batches_tracked - if self.num_batches_tracked is not None and self.num_batches_tracked.device != torch.device('meta') + if self.num_batches_tracked is not None + and self.num_batches_tracked.device != torch.device("meta") else torch.tensor(0, dtype=torch.long) ) @@ -132,9 +149,9 @@ def __init__( affine: bool = True, track_running_stats: bool = True, device=None, - dtype=None + dtype=None, ) -> None: - factory_kwargs = {'device': device, 'dtype': dtype} + factory_kwargs = {"device": device, "dtype": dtype} super().__init__( num_features, eps, momentum, affine, track_running_stats, **factory_kwargs ) @@ -189,13 +206,19 @@ def forward(self, input: Tensor) -> Tensor: class _LazyNormBase(LazyModuleMixin, _NormBase): - weight: UninitializedParameter # type: ignore[assignment] bias: UninitializedParameter # type: ignore[assignment] - def __init__(self, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True, - device=None, dtype=None) -> None: - factory_kwargs = {'device': device, 'dtype': dtype} + def __init__( + self, + eps=1e-5, + momentum=0.1, + affine=True, + track_running_stats=True, + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} super().__init__( # affine and track_running_stats are hardcoded to False to # avoid creating tensors that will soon be overwritten. @@ -215,7 +238,10 @@ def __init__(self, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True self.running_mean = UninitializedBuffer(**factory_kwargs) self.running_var = UninitializedBuffer(**factory_kwargs) self.num_batches_tracked = torch.tensor( - 0, dtype=torch.long, **{k: v for k, v in factory_kwargs.items() if k != 'dtype'}) + 0, + dtype=torch.long, + **{k: v for k, v in factory_kwargs.items() if k != "dtype"}, + ) def reset_parameters(self) -> None: if not self.has_uninitialized_params() and self.num_features != 0: @@ -230,8 +256,12 @@ def initialize_parameters(self, input) -> None: # type: ignore[override] self.weight.materialize((self.num_features,)) self.bias.materialize((self.num_features,)) if self.track_running_stats: - self.running_mean.materialize((self.num_features,)) # type:ignore[union-attr] - self.running_var.materialize((self.num_features,)) # type:ignore[union-attr] + self.running_mean.materialize( # type:ignore[union-attr] + (self.num_features,) + ) + self.running_var.materialize( # type:ignore[union-attr] + (self.num_features,) + ) self.reset_parameters() @@ -308,9 +338,7 @@ class BatchNorm1d(_BatchNorm): def _check_input_dim(self, input): if input.dim() != 2 and input.dim() != 3: - raise ValueError( - f"expected 2D or 3D input (got {input.dim()}D input)" - ) + raise ValueError(f"expected 2D or 3D input (got {input.dim()}D input)") class LazyBatchNorm1d(_LazyNormBase, _BatchNorm): @@ -344,9 +372,7 @@ class LazyBatchNorm1d(_LazyNormBase, _BatchNorm): def _check_input_dim(self, input): if input.dim() != 2 and input.dim() != 3: - raise ValueError( - f"expected 2D or 3D input (got {input.dim()}D input)" - ) + raise ValueError(f"expected 2D or 3D input (got {input.dim()}D input)") class BatchNorm2d(_BatchNorm): @@ -683,9 +709,9 @@ def __init__( track_running_stats: bool = True, process_group: Optional[Any] = None, device=None, - dtype=None + dtype=None, ) -> None: - factory_kwargs = {'device': device, 'dtype': dtype} + factory_kwargs = {"device": device, "dtype": dtype} super().__init__( num_features, eps, momentum, affine, track_running_stats, **factory_kwargs ) @@ -693,9 +719,7 @@ def __init__( def _check_input_dim(self, input): if input.dim() < 2: - raise ValueError( - f"expected at least 2D input (got {input.dim()}D input)" - ) + raise ValueError(f"expected at least 2D input (got {input.dim()}D input)") def _check_non_zero_input_channels(self, input): if input.size(1) == 0: @@ -746,13 +770,22 @@ def forward(self, input: Tensor) -> Tensor: ) # Don't sync batchnorm stats in inference mode (model.eval()). - need_sync = (bn_training and self.training and - torch.distributed.is_available() and torch.distributed.is_initialized()) + need_sync = ( + bn_training + and self.training + and torch.distributed.is_available() + and torch.distributed.is_initialized() + ) if need_sync: # currently only GPU/PrivateUse1 input is supported - if input.device.type not in ["cuda", torch._C._get_privateuse1_backend_name()]: - raise ValueError("SyncBatchNorm expected input tensor to be on GPU or " - f"{torch._C._get_privateuse1_backend_name()}") + if input.device.type not in [ + "cuda", + torch._C._get_privateuse1_backend_name(), + ]: + raise ValueError( + "SyncBatchNorm expected input tensor to be on GPU or " + f"{torch._C._get_privateuse1_backend_name()}" + ) process_group = torch.distributed.group.WORLD if self.process_group: diff --git a/torch/nn/modules/channelshuffle.py b/torch/nn/modules/channelshuffle.py index ff4c8c28d194da..12f31f14568bba 100644 --- a/torch/nn/modules/channelshuffle.py +++ b/torch/nn/modules/channelshuffle.py @@ -1,9 +1,11 @@ +import torch.nn.functional as F +from torch import Tensor + from .module import Module -from .. import functional as F -from torch import Tensor -__all__ = ['ChannelShuffle'] +__all__ = ["ChannelShuffle"] + class ChannelShuffle(Module): r"""Divides and rearranges the channels in a tensor. @@ -40,7 +42,7 @@ class ChannelShuffle(Module): [15., 16.]]]]) """ - __constants__ = ['groups'] + __constants__ = ["groups"] groups: int def __init__(self, groups: int) -> None: @@ -51,4 +53,4 @@ def forward(self, input: Tensor) -> Tensor: return F.channel_shuffle(input, self.groups) def extra_repr(self) -> str: - return f'groups={self.groups}' + return f"groups={self.groups}" diff --git a/torch/nn/modules/container.py b/torch/nn/modules/container.py index c82d8d7d3037a5..0b4c80affd5f9b 100644 --- a/torch/nn/modules/container.py +++ b/torch/nn/modules/container.py @@ -1,32 +1,50 @@ # mypy: allow-untyped-defs -from collections import OrderedDict, abc as container_abcs -from itertools import chain, islice import operator +from collections import abc as container_abcs, OrderedDict +from itertools import chain, islice +from typing import ( + Any, + Dict, + Iterable, + Iterator, + Mapping, + Optional, + overload, + Tuple, + TypeVar, + Union, +) +from typing_extensions import deprecated, Self import torch -from .module import Module -from ..parameter import Parameter from torch._jit_internal import _copy_to_script_wrapper +from torch.nn.parameter import Parameter + +from .module import Module -from typing import Any, Dict, Iterable, Iterator, Mapping, Optional, overload, Tuple, TypeVar, Union -from typing_extensions import Self -from typing_extensions import deprecated -__all__ = ['Container', 'Sequential', 'ModuleList', 'ModuleDict', 'ParameterList', 'ParameterDict'] +__all__ = [ + "Container", + "Sequential", + "ModuleList", + "ModuleDict", + "ParameterList", + "ParameterDict", +] -T = TypeVar('T', bound=Module) +T = TypeVar("T", bound=Module) # Copied from torch.nn.modules.module, required for a custom __repr__ for ModuleList def _addindent(s_, numSpaces): - s = s_.split('\n') + s = s_.split("\n") # don't do anything for single-line stuff if len(s) == 1: return s_ first = s.pop(0) - s = [(numSpaces * ' ') + line for line in s] - s = '\n'.join(s) - s = first + '\n' + s + s = [(numSpaces * " ") + line for line in s] + s = "\n".join(s) + s = first + "\n" + s return s @@ -95,7 +113,7 @@ def __init__(self, *args: Module) -> None: ... @overload - def __init__(self, arg: 'OrderedDict[str, Module]') -> None: + def __init__(self, arg: "OrderedDict[str, Module]") -> None: ... def __init__(self, *args): @@ -112,12 +130,12 @@ def _get_item_by_idx(self, iterator, idx) -> T: # type: ignore[misc, type-var] size = len(self) idx = operator.index(idx) if not -size <= idx < size: - raise IndexError(f'index {idx} is out of range') + raise IndexError(f"index {idx} is out of range") idx %= size return next(islice(iterator, idx, None)) @_copy_to_script_wrapper - def __getitem__(self, idx: Union[slice, int]) -> Union['Sequential', T]: + def __getitem__(self, idx: Union[slice, int]) -> Union["Sequential", T]: if isinstance(idx, slice): return self.__class__(OrderedDict(list(self._modules.items())[idx])) else: @@ -142,7 +160,7 @@ def __delitem__(self, idx: Union[slice, int]) -> None: def __len__(self) -> int: return len(self._modules) - def __add__(self, other) -> 'Sequential': + def __add__(self, other) -> "Sequential": if isinstance(other, Sequential): ret = Sequential() for layer in self: @@ -151,8 +169,10 @@ def __add__(self, other) -> 'Sequential': ret.append(layer) return ret else: - raise ValueError('add operator supports only objects ' - f'of Sequential class, but {str(type(other))} is given.') + raise ValueError( + "add operator supports only objects " + f"of Sequential class, but {str(type(other))} is given." + ) def pop(self, key: Union[int, slice]) -> Module: v = self[key] @@ -166,14 +186,20 @@ def __iadd__(self, other) -> Self: self.add_module(str(i + offset), module) return self else: - raise ValueError('add operator supports only objects ' - f'of Sequential class, but {str(type(other))} is given.') + raise ValueError( + "add operator supports only objects " + f"of Sequential class, but {str(type(other))} is given." + ) - def __mul__(self, other: int) -> 'Sequential': + def __mul__(self, other: int) -> "Sequential": if not isinstance(other, int): - raise TypeError(f"unsupported operand type(s) for *: {type(self)} and {type(other)}") - elif (other <= 0): - raise ValueError(f"Non-positive multiplication factor {other} for {type(self)}") + raise TypeError( + f"unsupported operand type(s) for *: {type(self)} and {type(other)}" + ) + elif other <= 0: + raise ValueError( + f"Non-positive multiplication factor {other} for {type(self)}" + ) else: combined = Sequential() offset = 0 @@ -183,14 +209,18 @@ def __mul__(self, other: int) -> 'Sequential': offset += 1 return combined - def __rmul__(self, other: int) -> 'Sequential': + def __rmul__(self, other: int) -> "Sequential": return self.__mul__(other) def __imul__(self, other: int) -> Self: if not isinstance(other, int): - raise TypeError(f"unsupported operand type(s) for *: {type(self)} and {type(other)}") - elif (other <= 0): - raise ValueError(f"Non-positive multiplication factor {other} for {type(self)}") + raise TypeError( + f"unsupported operand type(s) for *: {type(self)} and {type(other)}" + ) + elif other <= 0: + raise ValueError( + f"Non-positive multiplication factor {other} for {type(self)}" + ) else: len_original = len(self) offset = len(self) @@ -219,7 +249,7 @@ def forward(self, input): input = module(input) return input - def append(self, module: Module) -> 'Sequential': + def append(self, module: Module) -> "Sequential": r"""Append a given module to the end. Args: @@ -228,14 +258,12 @@ def append(self, module: Module) -> 'Sequential': self.add_module(str(len(self)), module) return self - def insert(self, index: int, module: Module) -> 'Sequential': + def insert(self, index: int, module: Module) -> "Sequential": if not isinstance(module, Module): - raise AssertionError( - f'module should be of type: {Module}') + raise AssertionError(f"module should be of type: {Module}") n = len(self._modules) if not (-n <= index <= n): - raise IndexError( - f'Index out of range: {index}') + raise IndexError(f"Index out of range: {index}") if index < 0: index += n for i in range(n, index, -1): @@ -243,7 +271,7 @@ def insert(self, index: int, module: Module) -> 'Sequential': self._modules[str(index)] = module return self - def extend(self, sequential) -> 'Sequential': + def extend(self, sequential) -> "Sequential": for layer in sequential: self.append(layer) return self @@ -284,13 +312,13 @@ def _get_abs_string_index(self, idx): """Get the absolute index for the list of modules.""" idx = operator.index(idx) if not (-len(self) <= idx < len(self)): - raise IndexError(f'index {idx} is out of range') + raise IndexError(f"index {idx} is out of range") if idx < 0: idx += len(self) return str(idx) @_copy_to_script_wrapper - def __getitem__(self, idx: Union[int, slice]) -> Union[Module, 'ModuleList']: + def __getitem__(self, idx: Union[int, slice]) -> Union[Module, "ModuleList"]: if isinstance(idx, slice): return self.__class__(list(self._modules.values())[idx]) else: @@ -321,7 +349,7 @@ def __iter__(self) -> Iterator[Module]: def __iadd__(self, modules: Iterable[Module]) -> Self: return self.extend(modules) - def __add__(self, other: Iterable[Module]) -> 'ModuleList': + def __add__(self, other: Iterable[Module]) -> "ModuleList": combined = ModuleList() for i, module in enumerate(chain(self, other)): combined.add_module(str(i), module) @@ -331,7 +359,7 @@ def __repr__(self): """Return a custom repr for ModuleList that compresses repeated module representations.""" list_of_reprs = [repr(item) for item in self] if len(list_of_reprs) == 0: - return self._get_name() + '()' + return self._get_name() + "()" start_end_indices = [[0, 0]] repeated_blocks = [list_of_reprs[0]] @@ -344,7 +372,7 @@ def __repr__(self): repeated_blocks.append(r) lines = [] - main_str = self._get_name() + '(' + main_str = self._get_name() + "(" for (start_id, end_id), b in zip(start_end_indices, repeated_blocks): local_repr = f"({start_id}): {b}" # default repr @@ -355,8 +383,8 @@ def __repr__(self): local_repr = _addindent(local_repr, 2) lines.append(local_repr) - main_str += '\n ' + '\n '.join(lines) + '\n' - main_str += ')' + main_str += "\n " + "\n ".join(lines) + "\n" + main_str += ")" return main_str @_copy_to_script_wrapper @@ -376,7 +404,7 @@ def insert(self, index: int, module: Module) -> None: self._modules[str(i)] = self._modules[str(i - 1)] self._modules[str(index)] = module - def append(self, module: Module) -> 'ModuleList': + def append(self, module: Module) -> "ModuleList": r"""Append a given module to the end of the list. Args: @@ -397,8 +425,10 @@ def extend(self, modules: Iterable[Module]) -> Self: modules (iterable): iterable of modules to append """ if not isinstance(modules, container_abcs.Iterable): - raise TypeError("ModuleList.extend should be called with an " - "iterable, but got " + type(modules).__name__) + raise TypeError( + "ModuleList.extend should be called with an " + "iterable, but got " + type(modules).__name__ + ) offset = len(self) for i, module in enumerate(modules): self.add_module(str(offset + i), module) @@ -521,9 +551,10 @@ def update(self, modules: Mapping[str, Module]) -> None: or an iterable of key-value pairs of type (string, :class:`~torch.nn.Module`) """ if not isinstance(modules, container_abcs.Iterable): - raise TypeError("ModuleDict.update should be called with an " - "iterable of key/value pairs, but got " + - type(modules).__name__) + raise TypeError( + "ModuleDict.update should be called with an " + "iterable of key/value pairs, but got " + type(modules).__name__ + ) if isinstance(modules, (OrderedDict, ModuleDict, container_abcs.Mapping)): for key, module in modules.items(): @@ -532,13 +563,15 @@ def update(self, modules: Mapping[str, Module]) -> None: # modules here can be a list with two items for j, m in enumerate(modules): if not isinstance(m, container_abcs.Iterable): - raise TypeError("ModuleDict update sequence element " - "#" + str(j) + " should be Iterable; is" + - type(m).__name__) + raise TypeError( + "ModuleDict update sequence element " + "#" + str(j) + " should be Iterable; is" + type(m).__name__ + ) if not len(m) == 2: - raise ValueError("ModuleDict update sequence element " - "#" + str(j) + " has length " + str(len(m)) + - "; 2 is required") + raise ValueError( + "ModuleDict update sequence element " + "#" + str(j) + " has length " + str(len(m)) + "; 2 is required" + ) # modules can be Mapping (what it's typed at), or a list: [(name1, module1), (name2, module2)] # that's too cumbersome to type correctly with overloads, so we add an ignore here self[m[0]] = m[1] # type: ignore[assignment] @@ -584,7 +617,7 @@ def _get_abs_string_index(self, idx): """Get the absolute index for the list of modules.""" idx = operator.index(idx) if not (-len(self) <= idx < len(self)): - raise IndexError(f'index {idx} is out of range') + raise IndexError(f"index {idx} is out of range") if idx < 0: idx += len(self) return str(idx) @@ -633,7 +666,7 @@ def __dir__(self): keys = [key for key in keys if not key.isdigit()] return keys - def append(self, value: Any) -> 'ParameterList': + def append(self, value: Any) -> "ParameterList": """Append a given value at the end of the list. Args: @@ -651,9 +684,13 @@ def extend(self, values: Iterable[Any]) -> Self: values (iterable): iterable of values to append """ # Tensor is an iterable but we never want to unpack it here - if not isinstance(values, container_abcs.Iterable) or isinstance(values, torch.Tensor): - raise TypeError("ParameterList.extend should be called with an " - "iterable, but got " + type(values).__name__) + if not isinstance(values, container_abcs.Iterable) or isinstance( + values, torch.Tensor + ): + raise TypeError( + "ParameterList.extend should be called with an " + "iterable, but got " + type(values).__name__ + ) for value in values: self.append(value) return self @@ -662,23 +699,28 @@ def extra_repr(self) -> str: child_lines = [] for k, p in enumerate(self): if isinstance(p, torch.Tensor): - size_str = 'x'.join(str(size) for size in p.size()) + size_str = "x".join(str(size) for size in p.size()) if p.device.type in ["cuda", torch._C._get_privateuse1_backend_name()]: - device_str = f' ({p.device})' + device_str = f" ({p.device})" else: - device_str = '' - parastr = '{} containing: [{} of size {}{}]'.format( + device_str = "" + parastr = "{} containing: [{} of size {}{}]".format( "Parameter" if isinstance(p, Parameter) else "Tensor", - p.dtype, size_str, device_str) - child_lines.append(' (' + str(k) + '): ' + parastr) + p.dtype, + size_str, + device_str, + ) + child_lines.append(" (" + str(k) + "): " + parastr) else: - child_lines.append(' (' + str(k) + '): Object of type: ' + type(p).__name__) + child_lines.append( + " (" + str(k) + "): Object of type: " + type(p).__name__ + ) - tmpstr = '\n'.join(child_lines) + tmpstr = "\n".join(child_lines) return tmpstr def __call__(self, *args, **kwargs): - raise RuntimeError('ParameterList should not be called.') + raise RuntimeError("ParameterList should not be called.") class ParameterDict(Module): @@ -726,9 +768,11 @@ def __init__(self, parameters: Any = None) -> None: def _key_to_attr(self, key: str) -> str: if not isinstance(key, str): - raise TypeError("Index given to ParameterDict cannot be used as a key as it is " - f"not a string (type is '{type(key).__name__}'). Open an issue on " - "github if you need non-string keys.") + raise TypeError( + "Index given to ParameterDict cannot be used as a key as it is " + f"not a string (type is '{type(key).__name__}'). Open an issue on " + "github if you need non-string keys." + ) else: # Use the key as-is so that `.named_parameters()` returns the right thing return key @@ -763,7 +807,7 @@ def __iter__(self) -> Iterator[str]: def __reversed__(self) -> Iterator[str]: return reversed(list(self._keys)) - def copy(self) -> 'ParameterDict': + def copy(self) -> "ParameterDict": """Return a copy of this :class:`~torch.nn.ParameterDict` instance.""" # We have to use an OrderedDict because the ParameterDict constructor # behaves differently on plain dict vs OrderedDict @@ -820,7 +864,9 @@ def get(self, key: str, default: Optional[Any] = None) -> Any: """ return self[key] if key in self else default - def fromkeys(self, keys: Iterable[str], default: Optional[Any] = None) -> 'ParameterDict': + def fromkeys( + self, keys: Iterable[str], default: Optional[Any] = None + ) -> "ParameterDict": r"""Return a new ParameterDict with the keys provided. Args: @@ -841,7 +887,7 @@ def values(self) -> Iterable[Any]: r"""Return an iterable of the ParameterDict values.""" return (self[k] for k in self._keys) - def update(self, parameters: Union[Mapping[str, Any], 'ParameterDict']) -> None: + def update(self, parameters: Union[Mapping[str, Any], "ParameterDict"]) -> None: r"""Update the :class:`~torch.nn.ParameterDict` with key-value pairs from ``parameters``, overwriting existing keys. .. note:: @@ -854,9 +900,10 @@ def update(self, parameters: Union[Mapping[str, Any], 'ParameterDict']) -> None: key-value pairs of type (string, :class:`~torch.nn.Parameter`) """ if not isinstance(parameters, container_abcs.Iterable): - raise TypeError("ParametersDict.update should be called with an " - "iterable of key/value pairs, but got " + - type(parameters).__name__) + raise TypeError( + "ParametersDict.update should be called with an " + "iterable of key/value pairs, but got " + type(parameters).__name__ + ) if isinstance(parameters, (OrderedDict, ParameterDict)): for key, parameter in parameters.items(): @@ -867,13 +914,15 @@ def update(self, parameters: Union[Mapping[str, Any], 'ParameterDict']) -> None: else: for j, p in enumerate(parameters): if not isinstance(p, container_abcs.Iterable): - raise TypeError("ParameterDict update sequence element " - "#" + str(j) + " should be Iterable; is" + - type(p).__name__) + raise TypeError( + "ParameterDict update sequence element " + "#" + str(j) + " should be Iterable; is" + type(p).__name__ + ) if not len(p) == 2: - raise ValueError("ParameterDict update sequence element " - "#" + str(j) + " has length " + str(len(p)) + - "; 2 is required") + raise ValueError( + "ParameterDict update sequence element " + "#" + str(j) + " has length " + str(len(p)) + "; 2 is required" + ) # parameters as length-2 list too cumbersome to type, see ModuleDict.update comment self[p[0]] = p[1] # type: ignore[assignment] @@ -881,33 +930,38 @@ def extra_repr(self) -> str: child_lines = [] for k, p in self.items(): if isinstance(p, torch.Tensor): - size_str = 'x'.join(str(size) for size in p.size()) + size_str = "x".join(str(size) for size in p.size()) if p.device.type in ["cuda", torch._C._get_privateuse1_backend_name()]: - device_str = f' ({p.device})' + device_str = f" ({p.device})" else: - device_str = '' - parastr = '{} containing: [{} of size {}{}]'.format( + device_str = "" + parastr = "{} containing: [{} of size {}{}]".format( "Parameter" if isinstance(p, Parameter) else "Tensor", - torch.typename(p), size_str, device_str) - child_lines.append(' (' + str(k) + '): ' + parastr) + torch.typename(p), + size_str, + device_str, + ) + child_lines.append(" (" + str(k) + "): " + parastr) else: - child_lines.append(' (' + str(k) + '): Object of type: ' + type(p).__name__) - tmpstr = '\n'.join(child_lines) + child_lines.append( + " (" + str(k) + "): Object of type: " + type(p).__name__ + ) + tmpstr = "\n".join(child_lines) return tmpstr def __call__(self, input): - raise RuntimeError('ParameterDict should not be called.') + raise RuntimeError("ParameterDict should not be called.") - def __or__(self, other: 'ParameterDict') -> 'ParameterDict': + def __or__(self, other: "ParameterDict") -> "ParameterDict": copy = self.copy() copy.update(other) return copy - def __ror__(self, other: 'ParameterDict') -> 'ParameterDict': + def __ror__(self, other: "ParameterDict") -> "ParameterDict": copy = other.copy() copy.update(self) return copy - def __ior__(self, other : 'ParameterDict') -> Self: + def __ior__(self, other: "ParameterDict") -> Self: self.update(other) return self diff --git a/torch/nn/modules/conv.py b/torch/nn/modules/conv.py index fb6a1557aa71bb..ccb628dff6a313 100644 --- a/torch/nn/modules/conv.py +++ b/torch/nn/modules/conv.py @@ -1,26 +1,37 @@ # mypy: allow-untyped-defs import math +from typing import List, Optional, Tuple, Union +from typing_extensions import deprecated import torch from torch import Tensor +from torch._torch_docs import reproducibility_notes +from torch.nn import functional as F, init +from torch.nn.common_types import _size_1_t, _size_2_t, _size_3_t from torch.nn.parameter import Parameter, UninitializedParameter -from .. import functional as F -from .. import init + from .lazy import LazyModuleMixin from .module import Module -from .utils import _single, _pair, _triple, _reverse_repeat_tuple -from torch._torch_docs import reproducibility_notes - -from ..common_types import _size_1_t, _size_2_t, _size_3_t -from typing import Optional, List, Tuple, Union -from typing_extensions import deprecated - -__all__ = ['Conv1d', 'Conv2d', 'Conv3d', 'ConvTranspose1d', 'ConvTranspose2d', 'ConvTranspose3d', - 'LazyConv1d', 'LazyConv2d', 'LazyConv3d', 'LazyConvTranspose1d', 'LazyConvTranspose2d', - 'LazyConvTranspose3d'] - -convolution_notes = \ - {"groups_note": r"""* :attr:`groups` controls the connections between inputs and outputs. +from .utils import _pair, _reverse_repeat_tuple, _single, _triple + + +__all__ = [ + "Conv1d", + "Conv2d", + "Conv3d", + "ConvTranspose1d", + "ConvTranspose2d", + "ConvTranspose3d", + "LazyConv1d", + "LazyConv2d", + "LazyConv3d", + "LazyConvTranspose1d", + "LazyConvTranspose2d", + "LazyConvTranspose3d", +] + +convolution_notes = { + "groups_note": r"""* :attr:`groups` controls the connections between inputs and outputs. :attr:`in_channels` and :attr:`out_channels` must both be divisible by :attr:`groups`. For example, @@ -32,21 +43,28 @@ * At groups= :attr:`in_channels`, each input channel is convolved with its own set of filters (of size :math:`\frac{\text{out\_channels}}{\text{in\_channels}}`).""", - - "depthwise_separable_note": r"""When `groups == in_channels` and `out_channels == K * in_channels`, + "depthwise_separable_note": r"""When `groups == in_channels` and `out_channels == K * in_channels`, where `K` is a positive integer, this operation is also known as a "depthwise convolution". In other words, for an input of size :math:`(N, C_{in}, L_{in})`, a depthwise convolution with a depthwise multiplier `K` can be performed with the arguments - :math:`(C_\text{in}=C_\text{in}, C_\text{out}=C_\text{in} \times \text{K}, ..., \text{groups}=C_\text{in})`."""} # noqa: B950 + :math:`(C_\text{in}=C_\text{in}, C_\text{out}=C_\text{in} \times \text{K}, ..., \text{groups}=C_\text{in})`.""", +} # noqa: B950 class _ConvNd(Module): - - __constants__ = ['stride', 'padding', 'dilation', 'groups', - 'padding_mode', 'output_padding', 'in_channels', - 'out_channels', 'kernel_size'] - __annotations__ = {'bias': Optional[torch.Tensor]} + __constants__ = [ + "stride", + "padding", + "dilation", + "groups", + "padding_mode", + "output_padding", + "in_channels", + "out_channels", + "kernel_size", + ] + __annotations__ = {"bias": Optional[torch.Tensor]} def _conv_forward(self, input: Tensor, weight: Tensor, bias: Optional[Tensor]) -> Tensor: # type: ignore[empty-body] ... @@ -65,39 +83,46 @@ def _conv_forward(self, input: Tensor, weight: Tensor, bias: Optional[Tensor]) - weight: Tensor bias: Optional[Tensor] - def __init__(self, - in_channels: int, - out_channels: int, - kernel_size: Tuple[int, ...], - stride: Tuple[int, ...], - padding: Tuple[int, ...], - dilation: Tuple[int, ...], - transposed: bool, - output_padding: Tuple[int, ...], - groups: int, - bias: bool, - padding_mode: str, - device=None, - dtype=None) -> None: - factory_kwargs = {'device': device, 'dtype': dtype} + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: Tuple[int, ...], + stride: Tuple[int, ...], + padding: Tuple[int, ...], + dilation: Tuple[int, ...], + transposed: bool, + output_padding: Tuple[int, ...], + groups: int, + bias: bool, + padding_mode: str, + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} super().__init__() if groups <= 0: - raise ValueError('groups must be a positive integer') + raise ValueError("groups must be a positive integer") if in_channels % groups != 0: - raise ValueError('in_channels must be divisible by groups') + raise ValueError("in_channels must be divisible by groups") if out_channels % groups != 0: - raise ValueError('out_channels must be divisible by groups') - valid_padding_strings = {'same', 'valid'} + raise ValueError("out_channels must be divisible by groups") + valid_padding_strings = {"same", "valid"} if isinstance(padding, str): if padding not in valid_padding_strings: raise ValueError( - f"Invalid padding string {padding!r}, should be one of {valid_padding_strings}") - if padding == 'same' and any(s != 1 for s in stride): - raise ValueError("padding='same' is not supported for strided convolutions") + f"Invalid padding string {padding!r}, should be one of {valid_padding_strings}" + ) + if padding == "same" and any(s != 1 for s in stride): + raise ValueError( + "padding='same' is not supported for strided convolutions" + ) - valid_padding_modes = {'zeros', 'reflect', 'replicate', 'circular'} + valid_padding_modes = {"zeros", "reflect", "replicate", "circular"} if padding_mode not in valid_padding_modes: - raise ValueError(f"padding_mode must be one of {valid_padding_modes}, but got padding_mode='{padding_mode}'") + raise ValueError( + f"padding_mode must be one of {valid_padding_modes}, but got padding_mode='{padding_mode}'" + ) self.in_channels = in_channels self.out_channels = out_channels self.kernel_size = kernel_size @@ -114,27 +139,39 @@ def __init__(self, # reverse order than the dimension. if isinstance(self.padding, str): self._reversed_padding_repeated_twice = [0, 0] * len(kernel_size) - if padding == 'same': - for d, k, i in zip(dilation, kernel_size, - range(len(kernel_size) - 1, -1, -1)): + if padding == "same": + for d, k, i in zip( + dilation, kernel_size, range(len(kernel_size) - 1, -1, -1) + ): total_padding = d * (k - 1) left_pad = total_padding // 2 self._reversed_padding_repeated_twice[2 * i] = left_pad self._reversed_padding_repeated_twice[2 * i + 1] = ( - total_padding - left_pad) + total_padding - left_pad + ) else: - self._reversed_padding_repeated_twice = _reverse_repeat_tuple(self.padding, 2) + self._reversed_padding_repeated_twice = _reverse_repeat_tuple( + self.padding, 2 + ) if transposed: - self.weight = Parameter(torch.empty( - (in_channels, out_channels // groups, *kernel_size), **factory_kwargs)) + self.weight = Parameter( + torch.empty( + (in_channels, out_channels // groups, *kernel_size), + **factory_kwargs, + ) + ) else: - self.weight = Parameter(torch.empty( - (out_channels, in_channels // groups, *kernel_size), **factory_kwargs)) + self.weight = Parameter( + torch.empty( + (out_channels, in_channels // groups, *kernel_size), + **factory_kwargs, + ) + ) if bias: self.bias = Parameter(torch.empty(out_channels, **factory_kwargs)) else: - self.register_parameter('bias', None) + self.register_parameter("bias", None) self.reset_parameters() @@ -150,30 +187,33 @@ def reset_parameters(self) -> None: init.uniform_(self.bias, -bound, bound) def extra_repr(self): - s = ('{in_channels}, {out_channels}, kernel_size={kernel_size}' - ', stride={stride}') + s = ( + "{in_channels}, {out_channels}, kernel_size={kernel_size}" + ", stride={stride}" + ) if self.padding != (0,) * len(self.padding): - s += ', padding={padding}' + s += ", padding={padding}" if self.dilation != (1,) * len(self.dilation): - s += ', dilation={dilation}' + s += ", dilation={dilation}" if self.output_padding != (0,) * len(self.output_padding): - s += ', output_padding={output_padding}' + s += ", output_padding={output_padding}" if self.groups != 1: - s += ', groups={groups}' + s += ", groups={groups}" if self.bias is None: - s += ', bias=False' - if self.padding_mode != 'zeros': - s += ', padding_mode={padding_mode}' + s += ", bias=False" + if self.padding_mode != "zeros": + s += ", padding_mode={padding_mode}" return s.format(**self.__dict__) def __setstate__(self, state): super().__setstate__(state) - if not hasattr(self, 'padding_mode'): - self.padding_mode = 'zeros' + if not hasattr(self, "padding_mode"): + self.padding_mode = "zeros" class Conv1d(_ConvNd): - __doc__ = r"""Applies a 1D convolution over an input signal composed of several input + __doc__ = ( + r"""Applies a 1D convolution over an input signal composed of several input planes. In the simplest case, the output value of the layer with input size @@ -188,7 +228,8 @@ class Conv1d(_ConvNd): where :math:`\star` is the valid `cross-correlation`_ operator, :math:`N` is a batch size, :math:`C` denotes a number of channels, :math:`L` is a length of signal sequence. - """ + r""" + """ + + r""" This module supports :ref:`TensorFloat32`. @@ -236,7 +277,10 @@ class Conv1d(_ConvNd): bias (bool, optional): If ``True``, adds a learnable bias to the output. Default: ``True`` - """.format(**reproducibility_notes, **convolution_notes) + r""" + """.format( + **reproducibility_notes, **convolution_notes + ) + + r""" Shape: - Input: :math:`(N, C_{in}, L_{in})` or :math:`(C_{in}, L_{in})` @@ -270,6 +314,7 @@ class Conv1d(_ConvNd): .. _link: https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md """ + ) def __init__( self, @@ -281,11 +326,11 @@ def __init__( dilation: _size_1_t = 1, groups: int = 1, bias: bool = True, - padding_mode: str = 'zeros', # TODO: refine this type + padding_mode: str = "zeros", # TODO: refine this type device=None, - dtype=None + dtype=None, ) -> None: - factory_kwargs = {'device': device, 'dtype': dtype} + factory_kwargs = {"device": device, "dtype": dtype} # we create new variables below to make mypy happy since kernel_size has # type Union[int, Tuple[int]] and kernel_size_ has type Tuple[int] kernel_size_ = _single(kernel_size) @@ -293,23 +338,44 @@ def __init__( padding_ = padding if isinstance(padding, str) else _single(padding) dilation_ = _single(dilation) super().__init__( - in_channels, out_channels, kernel_size_, stride_, padding_, dilation_, - False, _single(0), groups, bias, padding_mode, **factory_kwargs) + in_channels, + out_channels, + kernel_size_, + stride_, + padding_, + dilation_, + False, + _single(0), + groups, + bias, + padding_mode, + **factory_kwargs, + ) def _conv_forward(self, input: Tensor, weight: Tensor, bias: Optional[Tensor]): - if self.padding_mode != 'zeros': - return F.conv1d(F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode), - weight, bias, self.stride, - _single(0), self.dilation, self.groups) - return F.conv1d(input, weight, bias, self.stride, - self.padding, self.dilation, self.groups) + if self.padding_mode != "zeros": + return F.conv1d( + F.pad( + input, self._reversed_padding_repeated_twice, mode=self.padding_mode + ), + weight, + bias, + self.stride, + _single(0), + self.dilation, + self.groups, + ) + return F.conv1d( + input, weight, bias, self.stride, self.padding, self.dilation, self.groups + ) def forward(self, input: Tensor) -> Tensor: return self._conv_forward(input, self.weight, self.bias) class Conv2d(_ConvNd): - __doc__ = r"""Applies a 2D convolution over an input signal composed of several input + __doc__ = ( + r"""Applies a 2D convolution over an input signal composed of several input planes. In the simplest case, the output value of the layer with input size @@ -325,7 +391,8 @@ class Conv2d(_ConvNd): :math:`N` is a batch size, :math:`C` denotes a number of channels, :math:`H` is a height of input planes in pixels, and :math:`W` is width in pixels. - """ + r""" + """ + + r""" This module supports :ref:`TensorFloat32`. @@ -378,7 +445,10 @@ class Conv2d(_ConvNd): channels to output channels. Default: 1 bias (bool, optional): If ``True``, adds a learnable bias to the output. Default: ``True`` - """.format(**reproducibility_notes, **convolution_notes) + r""" + """.format( + **reproducibility_notes, **convolution_notes + ) + + r""" Shape: - Input: :math:`(N, C_{in}, H_{in}, W_{in})` or :math:`(C_{in}, H_{in}, W_{in})` @@ -422,6 +492,7 @@ class Conv2d(_ConvNd): .. _link: https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md """ + ) def __init__( self, @@ -433,32 +504,54 @@ def __init__( dilation: _size_2_t = 1, groups: int = 1, bias: bool = True, - padding_mode: str = 'zeros', # TODO: refine this type + padding_mode: str = "zeros", # TODO: refine this type device=None, - dtype=None + dtype=None, ) -> None: - factory_kwargs = {'device': device, 'dtype': dtype} + factory_kwargs = {"device": device, "dtype": dtype} kernel_size_ = _pair(kernel_size) stride_ = _pair(stride) padding_ = padding if isinstance(padding, str) else _pair(padding) dilation_ = _pair(dilation) super().__init__( - in_channels, out_channels, kernel_size_, stride_, padding_, dilation_, - False, _pair(0), groups, bias, padding_mode, **factory_kwargs) + in_channels, + out_channels, + kernel_size_, + stride_, + padding_, + dilation_, + False, + _pair(0), + groups, + bias, + padding_mode, + **factory_kwargs, + ) def _conv_forward(self, input: Tensor, weight: Tensor, bias: Optional[Tensor]): - if self.padding_mode != 'zeros': - return F.conv2d(F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode), - weight, bias, self.stride, - _pair(0), self.dilation, self.groups) - return F.conv2d(input, weight, bias, self.stride, - self.padding, self.dilation, self.groups) + if self.padding_mode != "zeros": + return F.conv2d( + F.pad( + input, self._reversed_padding_repeated_twice, mode=self.padding_mode + ), + weight, + bias, + self.stride, + _pair(0), + self.dilation, + self.groups, + ) + return F.conv2d( + input, weight, bias, self.stride, self.padding, self.dilation, self.groups + ) def forward(self, input: Tensor) -> Tensor: return self._conv_forward(input, self.weight, self.bias) + class Conv3d(_ConvNd): - __doc__ = r"""Applies a 3D convolution over an input signal composed of several input + __doc__ = ( + r"""Applies a 3D convolution over an input signal composed of several input planes. In the simplest case, the output value of the layer with input size :math:`(N, C_{in}, D, H, W)` @@ -469,7 +562,8 @@ class Conv3d(_ConvNd): \sum_{k = 0}^{C_{in} - 1} weight(C_{out_j}, k) \star input(N_i, k) where :math:`\star` is the valid 3D `cross-correlation`_ operator - """ + r""" + """ + + r""" This module supports :ref:`TensorFloat32`. @@ -517,7 +611,10 @@ class Conv3d(_ConvNd): dilation (int or tuple, optional): Spacing between kernel elements. Default: 1 groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1 bias (bool, optional): If ``True``, adds a learnable bias to the output. Default: ``True`` - """.format(**reproducibility_notes, **convolution_notes) + r""" + """.format( + **reproducibility_notes, **convolution_notes + ) + + r""" Shape: - Input: :math:`(N, C_{in}, D_{in}, H_{in}, W_{in})` or :math:`(C_{in}, D_{in}, H_{in}, W_{in})` @@ -563,6 +660,7 @@ class Conv3d(_ConvNd): .. _link: https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md """ + ) def __init__( self, @@ -574,18 +672,29 @@ def __init__( dilation: _size_3_t = 1, groups: int = 1, bias: bool = True, - padding_mode: str = 'zeros', + padding_mode: str = "zeros", device=None, - dtype=None + dtype=None, ) -> None: - factory_kwargs = {'device': device, 'dtype': dtype} + factory_kwargs = {"device": device, "dtype": dtype} kernel_size_ = _triple(kernel_size) stride_ = _triple(stride) padding_ = padding if isinstance(padding, str) else _triple(padding) dilation_ = _triple(dilation) super().__init__( - in_channels, out_channels, kernel_size_, stride_, padding_, dilation_, - False, _triple(0), groups, bias, padding_mode, **factory_kwargs) + in_channels, + out_channels, + kernel_size_, + stride_, + padding_, + dilation_, + False, + _triple(0), + groups, + bias, + padding_mode, + **factory_kwargs, + ) def _conv_forward(self, input: Tensor, weight: Tensor, bias: Optional[Tensor]): if self.padding_mode != "zeros": @@ -609,23 +718,55 @@ def forward(self, input: Tensor) -> Tensor: class _ConvTransposeNd(_ConvNd): - def __init__(self, in_channels, out_channels, kernel_size, stride, - padding, dilation, transposed, output_padding, - groups, bias, padding_mode, device=None, dtype=None) -> None: - if padding_mode != 'zeros': - raise ValueError(f'Only "zeros" padding mode is supported for {self.__class__.__name__}') + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + transposed, + output_padding, + groups, + bias, + padding_mode, + device=None, + dtype=None, + ) -> None: + if padding_mode != "zeros": + raise ValueError( + f'Only "zeros" padding mode is supported for {self.__class__.__name__}' + ) - factory_kwargs = {'device': device, 'dtype': dtype} + factory_kwargs = {"device": device, "dtype": dtype} super().__init__( - in_channels, out_channels, kernel_size, stride, - padding, dilation, transposed, output_padding, - groups, bias, padding_mode, **factory_kwargs) + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + transposed, + output_padding, + groups, + bias, + padding_mode, + **factory_kwargs, + ) # dilation being an optional parameter is for backwards # compatibility - def _output_padding(self, input: Tensor, output_size: Optional[List[int]], - stride: List[int], padding: List[int], kernel_size: List[int], - num_spatial_dims: int, dilation: Optional[List[int]] = None) -> List[int]: + def _output_padding( + self, + input: Tensor, + output_size: Optional[List[int]], + stride: List[int], + padding: List[int], + kernel_size: List[int], + num_spatial_dims: int, + dilation: Optional[List[int]] = None, + ) -> List[int]: if output_size is None: ret = _single(self.output_padding) # converting to list if was not already else: @@ -636,14 +777,19 @@ def _output_padding(self, input: Tensor, output_size: Optional[List[int]], if len(output_size) != num_spatial_dims: raise ValueError( f"ConvTranspose{num_spatial_dims}D: for {input.dim()}D input, output_size must have {num_spatial_dims} " - f"or {num_non_spatial_dims + num_spatial_dims} elements (got {len(output_size)})") + f"or {num_non_spatial_dims + num_spatial_dims} elements (got {len(output_size)})" + ) min_sizes = torch.jit.annotate(List[int], []) max_sizes = torch.jit.annotate(List[int], []) for d in range(num_spatial_dims): - dim_size = ((input.size(d + num_non_spatial_dims) - 1) * stride[d] - - 2 * padding[d] + - (dilation[d] if dilation is not None else 1) * (kernel_size[d] - 1) + 1) + dim_size = ( + (input.size(d + num_non_spatial_dims) - 1) * stride[d] + - 2 * padding[d] + + (dilation[d] if dilation is not None else 1) + * (kernel_size[d] - 1) + + 1 + ) min_sizes.append(dim_size) max_sizes.append(min_sizes[d] + stride[d] - 1) @@ -654,7 +800,8 @@ def _output_padding(self, input: Tensor, output_size: Optional[List[int]], if size < min_size or size > max_size: raise ValueError( f"requested an output size of {output_size}, but valid sizes range " - f"from {min_sizes} to {max_sizes} (for an input of {input.size()[2:]})") + f"from {min_sizes} to {max_sizes} (for an input of {input.size()[2:]})" + ) res = torch.jit.annotate(List[int], []) for d in range(num_spatial_dims): @@ -665,7 +812,8 @@ def _output_padding(self, input: Tensor, output_size: Optional[List[int]], class ConvTranspose1d(_ConvTransposeNd): - __doc__ = r"""Applies a 1D transposed convolution operator over an input image + __doc__ = ( + r"""Applies a 1D transposed convolution operator over an input image composed of several input planes. This module can be seen as the gradient of Conv1d with respect to its input. @@ -725,7 +873,10 @@ class ConvTranspose1d(_ConvTransposeNd): groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1 bias (bool, optional): If ``True``, adds a learnable bias to the output. Default: ``True`` dilation (int or tuple, optional): Spacing between kernel elements. Default: 1 - """.format(**reproducibility_notes, **convolution_notes) + r""" + """.format( + **reproducibility_notes, **convolution_notes + ) + + r""" Shape: - Input: :math:`(N, C_{in}, L_{in})` or :math:`(C_{in}, L_{in})` @@ -753,6 +904,7 @@ class ConvTranspose1d(_ConvTransposeNd): .. _`Deconvolutional Networks`: https://www.matthewzeiler.com/mattzeiler/deconvolutionalnetworks.pdf """ + ) def __init__( self, @@ -765,38 +917,65 @@ def __init__( groups: int = 1, bias: bool = True, dilation: _size_1_t = 1, - padding_mode: str = 'zeros', + padding_mode: str = "zeros", device=None, - dtype=None + dtype=None, ) -> None: - factory_kwargs = {'device': device, 'dtype': dtype} + factory_kwargs = {"device": device, "dtype": dtype} kernel_size = _single(kernel_size) stride = _single(stride) padding = _single(padding) dilation = _single(dilation) output_padding = _single(output_padding) super().__init__( - in_channels, out_channels, kernel_size, stride, padding, dilation, - True, output_padding, groups, bias, padding_mode, **factory_kwargs) + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + True, + output_padding, + groups, + bias, + padding_mode, + **factory_kwargs, + ) def forward(self, input: Tensor, output_size: Optional[List[int]] = None) -> Tensor: - if self.padding_mode != 'zeros': - raise ValueError('Only `zeros` padding mode is supported for ConvTranspose1d') + if self.padding_mode != "zeros": + raise ValueError( + "Only `zeros` padding mode is supported for ConvTranspose1d" + ) assert isinstance(self.padding, tuple) # One cannot replace List by Tuple or Sequence in "_output_padding" because # TorchScript does not support `Sequence[T]` or `Tuple[T, ...]`. num_spatial_dims = 1 output_padding = self._output_padding( - input, output_size, self.stride, self.padding, self.kernel_size, # type: ignore[arg-type] - num_spatial_dims, self.dilation) # type: ignore[arg-type] + input, + output_size, + self.stride, # type: ignore[arg-type] + self.padding, # type: ignore[arg-type] + self.kernel_size, # type: ignore[arg-type] + num_spatial_dims, + self.dilation, # type: ignore[arg-type] + ) return F.conv_transpose1d( - input, self.weight, self.bias, self.stride, self.padding, - output_padding, self.groups, self.dilation) + input, + self.weight, + self.bias, + self.stride, + self.padding, + output_padding, + self.groups, + self.dilation, + ) class ConvTranspose2d(_ConvTransposeNd): - __doc__ = r"""Applies a 2D transposed convolution operator over an input image + __doc__ = ( + r"""Applies a 2D transposed convolution operator over an input image composed of several input planes. This module can be seen as the gradient of Conv2d with respect to its input. @@ -857,7 +1036,10 @@ class ConvTranspose2d(_ConvTransposeNd): groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1 bias (bool, optional): If ``True``, adds a learnable bias to the output. Default: ``True`` dilation (int or tuple, optional): Spacing between kernel elements. Default: 1 - """.format(**reproducibility_notes, **convolution_notes) + r""" + """.format( + **reproducibility_notes, **convolution_notes + ) + + r""" Shape: - Input: :math:`(N, C_{in}, H_{in}, W_{in})` or :math:`(C_{in}, H_{in}, W_{in})` @@ -907,6 +1089,7 @@ class ConvTranspose2d(_ConvTransposeNd): .. _`Deconvolutional Networks`: https://www.matthewzeiler.com/mattzeiler/deconvolutionalnetworks.pdf """ + ) def __init__( self, @@ -919,39 +1102,66 @@ def __init__( groups: int = 1, bias: bool = True, dilation: _size_2_t = 1, - padding_mode: str = 'zeros', + padding_mode: str = "zeros", device=None, - dtype=None + dtype=None, ) -> None: - factory_kwargs = {'device': device, 'dtype': dtype} + factory_kwargs = {"device": device, "dtype": dtype} kernel_size = _pair(kernel_size) stride = _pair(stride) padding = _pair(padding) dilation = _pair(dilation) output_padding = _pair(output_padding) super().__init__( - in_channels, out_channels, kernel_size, stride, padding, dilation, - True, output_padding, groups, bias, padding_mode, **factory_kwargs) + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + True, + output_padding, + groups, + bias, + padding_mode, + **factory_kwargs, + ) def forward(self, input: Tensor, output_size: Optional[List[int]] = None) -> Tensor: - if self.padding_mode != 'zeros': - raise ValueError('Only `zeros` padding mode is supported for ConvTranspose2d') + if self.padding_mode != "zeros": + raise ValueError( + "Only `zeros` padding mode is supported for ConvTranspose2d" + ) assert isinstance(self.padding, tuple) # One cannot replace List by Tuple or Sequence in "_output_padding" because # TorchScript does not support `Sequence[T]` or `Tuple[T, ...]`. num_spatial_dims = 2 output_padding = self._output_padding( - input, output_size, self.stride, self.padding, self.kernel_size, # type: ignore[arg-type] - num_spatial_dims, self.dilation) # type: ignore[arg-type] + input, + output_size, + self.stride, # type: ignore[arg-type] + self.padding, # type: ignore[arg-type] + self.kernel_size, # type: ignore[arg-type] + num_spatial_dims, + self.dilation, # type: ignore[arg-type] + ) return F.conv_transpose2d( - input, self.weight, self.bias, self.stride, self.padding, - output_padding, self.groups, self.dilation) + input, + self.weight, + self.bias, + self.stride, + self.padding, + output_padding, + self.groups, + self.dilation, + ) class ConvTranspose3d(_ConvTransposeNd): - __doc__ = r"""Applies a 3D transposed convolution operator over an input image composed of several input + __doc__ = ( + r"""Applies a 3D transposed convolution operator over an input image composed of several input planes. The transposed convolution operator multiplies each input value element-wise by a learnable kernel, and sums over the outputs from all input feature planes. @@ -1014,7 +1224,10 @@ class ConvTranspose3d(_ConvTransposeNd): groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1 bias (bool, optional): If ``True``, adds a learnable bias to the output. Default: ``True`` dilation (int or tuple, optional): Spacing between kernel elements. Default: 1 - """.format(**reproducibility_notes, **convolution_notes) + r""" + """.format( + **reproducibility_notes, **convolution_notes + ) + + r""" Shape: - Input: :math:`(N, C_{in}, D_{in}, H_{in}, W_{in})` or :math:`(C_{in}, D_{in}, H_{in}, W_{in})` @@ -1059,6 +1272,7 @@ class ConvTranspose3d(_ConvTransposeNd): .. _`Deconvolutional Networks`: https://www.matthewzeiler.com/mattzeiler/deconvolutionalnetworks.pdf """ + ) def __init__( self, @@ -1071,35 +1285,61 @@ def __init__( groups: int = 1, bias: bool = True, dilation: _size_3_t = 1, - padding_mode: str = 'zeros', + padding_mode: str = "zeros", device=None, - dtype=None + dtype=None, ) -> None: - factory_kwargs = {'device': device, 'dtype': dtype} + factory_kwargs = {"device": device, "dtype": dtype} kernel_size = _triple(kernel_size) stride = _triple(stride) padding = _triple(padding) dilation = _triple(dilation) output_padding = _triple(output_padding) super().__init__( - in_channels, out_channels, kernel_size, stride, padding, dilation, - True, output_padding, groups, bias, padding_mode, **factory_kwargs) + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + True, + output_padding, + groups, + bias, + padding_mode, + **factory_kwargs, + ) def forward(self, input: Tensor, output_size: Optional[List[int]] = None) -> Tensor: - if self.padding_mode != 'zeros': - raise ValueError('Only `zeros` padding mode is supported for ConvTranspose3d') + if self.padding_mode != "zeros": + raise ValueError( + "Only `zeros` padding mode is supported for ConvTranspose3d" + ) assert isinstance(self.padding, tuple) # One cannot replace List by Tuple or Sequence in "_output_padding" because # TorchScript does not support `Sequence[T]` or `Tuple[T, ...]`. num_spatial_dims = 3 output_padding = self._output_padding( - input, output_size, self.stride, self.padding, self.kernel_size, # type: ignore[arg-type] - num_spatial_dims, self.dilation) # type: ignore[arg-type] + input, + output_size, + self.stride, # type: ignore[arg-type] + self.padding, # type: ignore[arg-type] + self.kernel_size, # type: ignore[arg-type] + num_spatial_dims, + self.dilation, # type: ignore[arg-type] + ) return F.conv_transpose3d( - input, self.weight, self.bias, self.stride, self.padding, - output_padding, self.groups, self.dilation) + input, + self.weight, + self.bias, + self.stride, + self.padding, + output_padding, + self.groups, + self.dilation, + ) # TODO: Deprecate and remove the following alias `_ConvTransposeMixin`. @@ -1118,7 +1358,6 @@ def forward(self, input: Tensor, output_size: Optional[List[int]] = None) -> Ten # `_ConvTransposeNd` is really not a mixin anymore (but multiple inheritance as # above would still work). class _ConvTransposeMixin(_ConvTransposeNd): - @deprecated( "`_ConvTransposeMixin` is a deprecated internal class. " "Please consider using public APIs.", @@ -1156,14 +1395,24 @@ def initialize_parameters(self, input: Tensor, *args, **kwargs) -> None: # type if self.has_uninitialized_params(): # type: ignore[misc] self.in_channels = self._get_in_channels(input) if self.in_channels % self.groups != 0: - raise ValueError('in_channels must be divisible by groups') + raise ValueError("in_channels must be divisible by groups") assert isinstance(self.weight, UninitializedParameter) if self.transposed: - self.weight.materialize(( - self.in_channels, self.out_channels // self.groups, *self.kernel_size)) + self.weight.materialize( + ( + self.in_channels, + self.out_channels // self.groups, + *self.kernel_size, + ) + ) else: - self.weight.materialize(( - self.out_channels, self.in_channels // self.groups, *self.kernel_size)) + self.weight.materialize( + ( + self.out_channels, + self.in_channels // self.groups, + *self.kernel_size, + ) + ) if self.bias is not None: assert isinstance(self.bias, UninitializedParameter) self.bias.materialize((self.out_channels,)) @@ -1175,9 +1424,11 @@ def _get_in_channels(self, input: Tensor) -> int: num_dims_no_batch = num_spatial_dims + 1 # +1 for channels dim num_dims_batch = num_dims_no_batch + 1 if input.dim() not in (num_dims_no_batch, num_dims_batch): - raise RuntimeError(f"Expected {num_dims_no_batch}D (unbatched) or {num_dims_batch}D (batched) input " - f"to {self.__class__.__name__}, but " - f"got input of size: {input.shape}") + raise RuntimeError( + f"Expected {num_dims_no_batch}D (unbatched) or {num_dims_batch}D (batched) input " + f"to {self.__class__.__name__}, but " + f"got input of size: {input.shape}" + ) return input.shape[1] if input.dim() == num_dims_batch else input.shape[0] # Function to return the number of spatial dims expected for inputs to the module. @@ -1227,11 +1478,11 @@ def __init__( dilation: _size_1_t = 1, groups: int = 1, bias: bool = True, - padding_mode: str = 'zeros', + padding_mode: str = "zeros", device=None, - dtype=None + dtype=None, ) -> None: - factory_kwargs = {'device': device, 'dtype': dtype} + factory_kwargs = {"device": device, "dtype": dtype} super().__init__( 0, 0, @@ -1244,7 +1495,7 @@ def __init__( # that will soon be overwritten. False, padding_mode, - **factory_kwargs + **factory_kwargs, ) self.weight = UninitializedParameter(**factory_kwargs) self.out_channels = out_channels @@ -1296,11 +1547,11 @@ def __init__( dilation: _size_2_t = 1, groups: int = 1, bias: bool = True, - padding_mode: str = 'zeros', # TODO: refine this type + padding_mode: str = "zeros", # TODO: refine this type device=None, - dtype=None + dtype=None, ) -> None: - factory_kwargs = {'device': device, 'dtype': dtype} + factory_kwargs = {"device": device, "dtype": dtype} super().__init__( 0, 0, @@ -1313,7 +1564,7 @@ def __init__( # that will soon be overwritten. False, padding_mode, - **factory_kwargs + **factory_kwargs, ) self.weight = UninitializedParameter(**factory_kwargs) self.out_channels = out_channels @@ -1366,11 +1617,11 @@ def __init__( dilation: _size_3_t = 1, groups: int = 1, bias: bool = True, - padding_mode: str = 'zeros', + padding_mode: str = "zeros", device=None, - dtype=None + dtype=None, ) -> None: - factory_kwargs = {'device': device, 'dtype': dtype} + factory_kwargs = {"device": device, "dtype": dtype} super().__init__( 0, 0, @@ -1383,7 +1634,7 @@ def __init__( # that will soon be overwritten. False, padding_mode, - **factory_kwargs + **factory_kwargs, ) self.weight = UninitializedParameter(**factory_kwargs) self.out_channels = out_channels @@ -1434,11 +1685,11 @@ def __init__( groups: int = 1, bias: bool = True, dilation: _size_1_t = 1, - padding_mode: str = 'zeros', + padding_mode: str = "zeros", device=None, - dtype=None + dtype=None, ) -> None: - factory_kwargs = {'device': device, 'dtype': dtype} + factory_kwargs = {"device": device, "dtype": dtype} super().__init__( 0, 0, @@ -1452,7 +1703,7 @@ def __init__( False, dilation, padding_mode, - **factory_kwargs + **factory_kwargs, ) self.weight = UninitializedParameter(**factory_kwargs) self.out_channels = out_channels @@ -1503,11 +1754,11 @@ def __init__( groups: int = 1, bias: bool = True, dilation: int = 1, - padding_mode: str = 'zeros', + padding_mode: str = "zeros", device=None, - dtype=None + dtype=None, ) -> None: - factory_kwargs = {'device': device, 'dtype': dtype} + factory_kwargs = {"device": device, "dtype": dtype} super().__init__( 0, 0, @@ -1521,7 +1772,7 @@ def __init__( False, dilation, padding_mode, - **factory_kwargs + **factory_kwargs, ) self.weight = UninitializedParameter(**factory_kwargs) self.out_channels = out_channels @@ -1572,11 +1823,11 @@ def __init__( groups: int = 1, bias: bool = True, dilation: _size_3_t = 1, - padding_mode: str = 'zeros', + padding_mode: str = "zeros", device=None, - dtype=None + dtype=None, ) -> None: - factory_kwargs = {'device': device, 'dtype': dtype} + factory_kwargs = {"device": device, "dtype": dtype} super().__init__( 0, 0, @@ -1590,7 +1841,7 @@ def __init__( False, dilation, padding_mode, - **factory_kwargs + **factory_kwargs, ) self.weight = UninitializedParameter(**factory_kwargs) self.out_channels = out_channels diff --git a/torch/nn/modules/distance.py b/torch/nn/modules/distance.py index cbf98665799e3d..dfec05da1172e7 100644 --- a/torch/nn/modules/distance.py +++ b/torch/nn/modules/distance.py @@ -1,9 +1,11 @@ +import torch.nn.functional as F +from torch import Tensor + from .module import Module -from .. import functional as F -from torch import Tensor -__all__ = ['PairwiseDistance', 'CosineSimilarity'] +__all__ = ["PairwiseDistance", "CosineSimilarity"] + class PairwiseDistance(Module): r""" @@ -39,12 +41,14 @@ class PairwiseDistance(Module): >>> output = pdist(input1, input2) """ - __constants__ = ['norm', 'eps', 'keepdim'] + __constants__ = ["norm", "eps", "keepdim"] norm: float eps: float keepdim: bool - def __init__(self, p: float = 2., eps: float = 1e-6, keepdim: bool = False) -> None: + def __init__( + self, p: float = 2.0, eps: float = 1e-6, keepdim: bool = False + ) -> None: super().__init__() self.norm = p self.eps = eps @@ -76,7 +80,7 @@ class CosineSimilarity(Module): >>> output = cos(input1, input2) """ - __constants__ = ['dim', 'eps'] + __constants__ = ["dim", "eps"] dim: int eps: float diff --git a/torch/nn/modules/dropout.py b/torch/nn/modules/dropout.py index f4e151879d7de7..c04f66f7a25a8f 100644 --- a/torch/nn/modules/dropout.py +++ b/torch/nn/modules/dropout.py @@ -1,24 +1,35 @@ +import torch.nn.functional as F +from torch import Tensor + from .module import Module -from .. import functional as F -from torch import Tensor -__all__ = ['Dropout', 'Dropout1d', 'Dropout2d', 'Dropout3d', 'AlphaDropout', 'FeatureAlphaDropout'] +__all__ = [ + "Dropout", + "Dropout1d", + "Dropout2d", + "Dropout3d", + "AlphaDropout", + "FeatureAlphaDropout", +] + class _DropoutNd(Module): - __constants__ = ['p', 'inplace'] + __constants__ = ["p", "inplace"] p: float inplace: bool def __init__(self, p: float = 0.5, inplace: bool = False) -> None: super().__init__() if p < 0 or p > 1: - raise ValueError(f"dropout probability has to be between 0 and 1, but got {p}") + raise ValueError( + f"dropout probability has to be between 0 and 1, but got {p}" + ) self.p = p self.inplace = inplace def extra_repr(self) -> str: - return f'p={self.p}, inplace={self.inplace}' + return f"p={self.p}, inplace={self.inplace}" class Dropout(_DropoutNd): diff --git a/torch/nn/modules/flatten.py b/torch/nn/modules/flatten.py index f1c44fd350d151..dd4aa4799e7fa8 100644 --- a/torch/nn/modules/flatten.py +++ b/torch/nn/modules/flatten.py @@ -1,11 +1,14 @@ # mypy: allow-untyped-defs -from .module import Module - from typing import Tuple, Union + from torch import Tensor from torch.types import _size -__all__ = ['Flatten', 'Unflatten'] +from .module import Module + + +__all__ = ["Flatten", "Unflatten"] + class Flatten(Module): r""" @@ -37,7 +40,7 @@ class Flatten(Module): torch.Size([160, 5]) """ - __constants__ = ['start_dim', 'end_dim'] + __constants__ = ["start_dim", "end_dim"] start_dim: int end_dim: int @@ -50,7 +53,7 @@ def forward(self, input: Tensor) -> Tensor: return input.flatten(self.start_dim, self.end_dim) def extra_repr(self) -> str: - return f'start_dim={self.start_dim}, end_dim={self.end_dim}' + return f"start_dim={self.start_dim}, end_dim={self.end_dim}" class Unflatten(Module): @@ -102,11 +105,13 @@ class Unflatten(Module): NamedShape = Tuple[Tuple[str, int]] - __constants__ = ['dim', 'unflattened_size'] + __constants__ = ["dim", "unflattened_size"] dim: Union[int, str] unflattened_size: Union[_size, NamedShape] - def __init__(self, dim: Union[int, str], unflattened_size: Union[_size, NamedShape]) -> None: + def __init__( + self, dim: Union[int, str], unflattened_size: Union[_size, NamedShape] + ) -> None: super().__init__() if isinstance(dim, int): @@ -120,26 +125,34 @@ def __init__(self, dim: Union[int, str], unflattened_size: Union[_size, NamedSha self.unflattened_size = unflattened_size def _require_tuple_tuple(self, input): - if (isinstance(input, tuple)): + if isinstance(input, tuple): for idx, elem in enumerate(input): if not isinstance(elem, tuple): - raise TypeError("unflattened_size must be tuple of tuples, " + - f"but found element of type {type(elem).__name__} at pos {idx}") + raise TypeError( + "unflattened_size must be tuple of tuples, " + + f"but found element of type {type(elem).__name__} at pos {idx}" + ) return - raise TypeError("unflattened_size must be a tuple of tuples, " + - f"but found type {type(input).__name__}") + raise TypeError( + "unflattened_size must be a tuple of tuples, " + + f"but found type {type(input).__name__}" + ) def _require_tuple_int(self, input): - if (isinstance(input, (tuple, list))): + if isinstance(input, (tuple, list)): for idx, elem in enumerate(input): if not isinstance(elem, int): - raise TypeError("unflattened_size must be tuple of ints, " + - f"but found element of type {type(elem).__name__} at pos {idx}") + raise TypeError( + "unflattened_size must be tuple of ints, " + + f"but found element of type {type(elem).__name__} at pos {idx}" + ) return - raise TypeError(f"unflattened_size must be a tuple of ints, but found type {type(input).__name__}") + raise TypeError( + f"unflattened_size must be a tuple of ints, but found type {type(input).__name__}" + ) def forward(self, input: Tensor) -> Tensor: return input.unflatten(self.dim, self.unflattened_size) def extra_repr(self) -> str: - return f'dim={self.dim}, unflattened_size={self.unflattened_size}' + return f"dim={self.dim}, unflattened_size={self.unflattened_size}" diff --git a/torch/nn/modules/fold.py b/torch/nn/modules/fold.py index f8cb08362362fc..54d97d5a58bf10 100644 --- a/torch/nn/modules/fold.py +++ b/torch/nn/modules/fold.py @@ -1,10 +1,12 @@ +import torch.nn.functional as F +from torch import Tensor +from torch.nn.common_types import _size_any_t + from .module import Module -from .. import functional as F -from torch import Tensor -from ..common_types import _size_any_t -__all__ = ['Fold', 'Unfold'] +__all__ = ["Fold", "Unfold"] + class Fold(Module): r"""Combines an array of sliding local blocks into a large containing tensor. @@ -118,8 +120,7 @@ class Fold(Module): """ - __constants__ = ['output_size', 'kernel_size', 'dilation', 'padding', - 'stride'] + __constants__ = ["output_size", "kernel_size", "dilation", "padding", "stride"] output_size: _size_any_t kernel_size: _size_any_t dilation: _size_any_t @@ -132,7 +133,7 @@ def __init__( kernel_size: _size_any_t, dilation: _size_any_t = 1, padding: _size_any_t = 0, - stride: _size_any_t = 1 + stride: _size_any_t = 1, ) -> None: super().__init__() self.output_size = output_size @@ -142,14 +143,22 @@ def __init__( self.stride = stride def forward(self, input: Tensor) -> Tensor: - return F.fold(input, self.output_size, self.kernel_size, self.dilation, - self.padding, self.stride) + return F.fold( + input, + self.output_size, + self.kernel_size, + self.dilation, + self.padding, + self.stride, + ) def extra_repr(self) -> str: - return 'output_size={output_size}, kernel_size={kernel_size}, ' \ - 'dilation={dilation}, padding={padding}, stride={stride}'.format( + return ( + "output_size={output_size}, kernel_size={kernel_size}, " + "dilation={dilation}, padding={padding}, stride={stride}".format( **self.__dict__ ) + ) class Unfold(Module): @@ -275,7 +284,7 @@ class Unfold(Module): """ - __constants__ = ['kernel_size', 'dilation', 'padding', 'stride'] + __constants__ = ["kernel_size", "dilation", "padding", "stride"] kernel_size: _size_any_t dilation: _size_any_t padding: _size_any_t @@ -286,7 +295,7 @@ def __init__( kernel_size: _size_any_t, dilation: _size_any_t = 1, padding: _size_any_t = 0, - stride: _size_any_t = 1 + stride: _size_any_t = 1, ) -> None: super().__init__() self.kernel_size = kernel_size @@ -295,9 +304,12 @@ def __init__( self.stride = stride def forward(self, input: Tensor) -> Tensor: - return F.unfold(input, self.kernel_size, self.dilation, - self.padding, self.stride) + return F.unfold( + input, self.kernel_size, self.dilation, self.padding, self.stride + ) def extra_repr(self) -> str: - return 'kernel_size={kernel_size}, dilation={dilation}, padding={padding},' \ - ' stride={stride}'.format(**self.__dict__) + return ( + "kernel_size={kernel_size}, dilation={dilation}, padding={padding}," + " stride={stride}".format(**self.__dict__) + ) diff --git a/torch/nn/modules/instancenorm.py b/torch/nn/modules/instancenorm.py index e6a3e1c0a3a1be..89bdda34f93bb2 100644 --- a/torch/nn/modules/instancenorm.py +++ b/torch/nn/modules/instancenorm.py @@ -1,13 +1,22 @@ # mypy: allow-untyped-defs import warnings + +import torch.nn.functional as F from torch import Tensor from .batchnorm import _LazyNormBase, _NormBase -from .. import functional as F -__all__ = ['InstanceNorm1d', 'InstanceNorm2d', 'InstanceNorm3d', 'LazyInstanceNorm1d', - 'LazyInstanceNorm2d', 'LazyInstanceNorm3d'] + +__all__ = [ + "InstanceNorm1d", + "InstanceNorm2d", + "InstanceNorm3d", + "LazyInstanceNorm1d", + "LazyInstanceNorm2d", + "LazyInstanceNorm3d", +] + class _InstanceNorm(_NormBase): def __init__( @@ -18,11 +27,12 @@ def __init__( affine: bool = False, track_running_stats: bool = False, device=None, - dtype=None + dtype=None, ) -> None: - factory_kwargs = {'device': device, 'dtype': dtype} + factory_kwargs = {"device": device, "dtype": dtype} super().__init__( - num_features, eps, momentum, affine, track_running_stats, **factory_kwargs) + num_features, eps, momentum, affine, track_running_stats, **factory_kwargs + ) def _check_input_dim(self, input): raise NotImplementedError @@ -45,35 +55,51 @@ def _apply_instance_norm(self, input): self.eps, ) - def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, - missing_keys, unexpected_keys, error_msgs): - version = local_metadata.get('version', None) + def _load_from_state_dict( + self, + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ): + version = local_metadata.get("version", None) # at version 1: removed running_mean and running_var when # track_running_stats=False (default) if version is None and not self.track_running_stats: running_stats_keys = [] - for name in ('running_mean', 'running_var'): + for name in ("running_mean", "running_var"): key = prefix + name if key in state_dict: running_stats_keys.append(key) if len(running_stats_keys) > 0: error_msgs.append( - 'Unexpected running stats buffer(s) {names} for {klass} ' - 'with track_running_stats=False. If state_dict is a ' - 'checkpoint saved before 0.4.0, this may be expected ' - 'because {klass} does not track running stats by default ' - 'since 0.4.0. Please remove these keys from state_dict. If ' - 'the running stats are actually needed, instead set ' - 'track_running_stats=True in {klass} to enable them. See ' - 'the documentation of {klass} for details.' - .format(names=" and ".join(f'"{k}"' for k in running_stats_keys), - klass=self.__class__.__name__)) + "Unexpected running stats buffer(s) {names} for {klass} " + "with track_running_stats=False. If state_dict is a " + "checkpoint saved before 0.4.0, this may be expected " + "because {klass} does not track running stats by default " + "since 0.4.0. Please remove these keys from state_dict. If " + "the running stats are actually needed, instead set " + "track_running_stats=True in {klass} to enable them. See " + "the documentation of {klass} for details.".format( + names=" and ".join(f'"{k}"' for k in running_stats_keys), + klass=self.__class__.__name__, + ) + ) for key in running_stats_keys: state_dict.pop(key) super()._load_from_state_dict( - state_dict, prefix, local_metadata, strict, - missing_keys, unexpected_keys, error_msgs) + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ) def forward(self, input: Tensor) -> Tensor: self._check_input_dim(input) @@ -83,11 +109,14 @@ def forward(self, input: Tensor) -> Tensor: if self.affine: raise ValueError( f"expected input's size at dim={feature_dim} to match num_features" - f" ({self.num_features}), but got: {input.size(feature_dim)}.") + f" ({self.num_features}), but got: {input.size(feature_dim)}." + ) else: - warnings.warn(f"input's size at dim={feature_dim} does not match num_features. " - "You can silence this warning by not passing in num_features, " - "which is not used because affine=False") + warnings.warn( + f"input's size at dim={feature_dim} does not match num_features. " + "You can silence this warning by not passing in num_features, " + "which is not used because affine=False" + ) if input.dim() == self._get_no_batch_dim(): return self._handle_no_batch_input(input) @@ -169,7 +198,7 @@ def _get_no_batch_dim(self): def _check_input_dim(self, input): if input.dim() not in (2, 3): - raise ValueError(f'expected 2D or 3D input (got {input.dim()}D input)') + raise ValueError(f"expected 2D or 3D input (got {input.dim()}D input)") class LazyInstanceNorm1d(_LazyNormBase, _InstanceNorm): @@ -206,7 +235,7 @@ def _get_no_batch_dim(self): def _check_input_dim(self, input): if input.dim() not in (2, 3): - raise ValueError(f'expected 2D or 3D input (got {input.dim()}D input)') + raise ValueError(f"expected 2D or 3D input (got {input.dim()}D input)") class InstanceNorm2d(_InstanceNorm): @@ -285,7 +314,7 @@ def _get_no_batch_dim(self): def _check_input_dim(self, input): if input.dim() not in (3, 4): - raise ValueError(f'expected 3D or 4D input (got {input.dim()}D input)') + raise ValueError(f"expected 3D or 4D input (got {input.dim()}D input)") class LazyInstanceNorm2d(_LazyNormBase, _InstanceNorm): @@ -323,7 +352,7 @@ def _get_no_batch_dim(self): def _check_input_dim(self, input): if input.dim() not in (3, 4): - raise ValueError(f'expected 3D or 4D input (got {input.dim()}D input)') + raise ValueError(f"expected 3D or 4D input (got {input.dim()}D input)") class InstanceNorm3d(_InstanceNorm): @@ -401,7 +430,7 @@ def _get_no_batch_dim(self): def _check_input_dim(self, input): if input.dim() not in (4, 5): - raise ValueError(f'expected 4D or 5D input (got {input.dim()}D input)') + raise ValueError(f"expected 4D or 5D input (got {input.dim()}D input)") class LazyInstanceNorm3d(_LazyNormBase, _InstanceNorm): @@ -439,4 +468,4 @@ def _get_no_batch_dim(self): def _check_input_dim(self, input): if input.dim() not in (4, 5): - raise ValueError(f'expected 4D or 5D input (got {input.dim()}D input)') + raise ValueError(f"expected 4D or 5D input (got {input.dim()}D input)") diff --git a/torch/nn/modules/lazy.py b/torch/nn/modules/lazy.py index f4be1b7db7069b..453b8ec7d122c9 100644 --- a/torch/nn/modules/lazy.py +++ b/torch/nn/modules/lazy.py @@ -1,11 +1,13 @@ # mypy: allow-untyped-defs import itertools -from typing import Protocol, Optional, Type, Any +from typing import Any, Optional, Protocol, Type import torch -from ..parameter import is_lazy +from torch.nn.parameter import is_lazy + + +__all__ = ["LazyModuleMixin"] -__all__ = ['LazyModuleMixin'] class _LazyProtocol(Protocol): """This class is used to avoid errors with mypy checks for the attributes in a mixin. @@ -20,8 +22,15 @@ def register_forward_pre_hook(self, hook, *, prepend=False, with_kwargs=False): ... def _lazy_load_hook( - self, state_dict, prefix, local_metadata, strict, - missing_keys, unexpected_keys, error_msgs): + self, + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ): ... def _get_name(self): @@ -177,7 +186,9 @@ def __init__(self: _LazyProtocol, *args, **kwargs): # Mypy doesnt like this super call in a mixin super().__init__(*args, **kwargs) # type: ignore[misc] self._load_hook = self._register_load_state_dict_pre_hook(self._lazy_load_hook) - self._initialize_hook = self.register_forward_pre_hook(self._infer_parameters, with_kwargs=True) + self._initialize_hook = self.register_forward_pre_hook( + self._infer_parameters, with_kwargs=True + ) def _save_to_state_dict(self: _LazyProtocol, destination, prefix, keep_vars): # This should be ideally implemented as a hook, @@ -195,8 +206,15 @@ def _save_to_state_dict(self: _LazyProtocol, destination, prefix, keep_vars): destination[prefix + name] = buf def _lazy_load_hook( - self: _LazyProtocol, state_dict, prefix, local_metadata, strict, - missing_keys, unexpected_keys, error_msgs): + self: _LazyProtocol, + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ): """load_state_dict pre-hook function for lazy buffers and parameters. The purpose of this hook is to adjust the current state and/or @@ -206,7 +224,9 @@ def _lazy_load_hook( See comment in ``torch.nn.Module._register_load_state_dict_pre_hook`` for the details of the hook specification. """ - for name, param in itertools.chain(self._parameters.items(), self._buffers.items()): + for name, param in itertools.chain( + self._parameters.items(), self._buffers.items() + ): key = prefix + name if key in state_dict and param is not None: input_param = state_dict[key] @@ -223,7 +243,9 @@ def initialize_parameters(self: _LazyProtocol, *args, **kwargs): This adds an interface to isolate parameter initialization from the forward pass when doing parameter shape inference. """ - raise NotImplementedError(f'initialize_parameters is not implemented for {self.__class__.__name__}') + raise NotImplementedError( + f"initialize_parameters is not implemented for {self.__class__.__name__}" + ) def has_uninitialized_params(self: _LazyProtocol): r"""Check if a module has parameters that are not initialized.""" @@ -249,15 +271,18 @@ def _infer_parameters(self: _LazyProtocol, module, args, kwargs=None): kwargs = kwargs if kwargs else {} module.initialize_parameters(*args, **kwargs) if module.has_uninitialized_params(): - raise RuntimeError(f'module {self._get_name()} has not been fully initialized') + raise RuntimeError( + f"module {self._get_name()} has not been fully initialized" + ) module._initialize_hook.remove() module._load_hook.remove() - delattr(module, '_initialize_hook') - delattr(module, '_load_hook') + delattr(module, "_initialize_hook") + delattr(module, "_load_hook") if module.cls_to_become is not None: module.__class__ = module.cls_to_become - def _replicate_for_data_parallel(self: _LazyProtocol): - raise RuntimeError('Modules with uninitialized parameters can\'t be used with `DataParallel`. ' - 'Run a dummy forward pass to correctly initialize the modules') + raise RuntimeError( + "Modules with uninitialized parameters can't be used with `DataParallel`. " + "Run a dummy forward pass to correctly initialize the modules" + ) diff --git a/torch/nn/modules/linear.py b/torch/nn/modules/linear.py index be27394623994a..dc5185b7eec0bf 100644 --- a/torch/nn/modules/linear.py +++ b/torch/nn/modules/linear.py @@ -4,18 +4,18 @@ import torch from torch import Tensor +from torch.nn import functional as F, init from torch.nn.parameter import Parameter, UninitializedParameter -from .. import functional as F -from .. import init -from .module import Module + from .lazy import LazyModuleMixin +from .module import Module __all__ = [ - 'Bilinear', - 'Identity', - 'LazyLinear', - 'Linear', + "Bilinear", + "Identity", + "LazyLinear", + "Linear", ] @@ -85,22 +85,30 @@ class Linear(Module): torch.Size([128, 30]) """ - __constants__ = ['in_features', 'out_features'] + __constants__ = ["in_features", "out_features"] in_features: int out_features: int weight: Tensor - def __init__(self, in_features: int, out_features: int, bias: bool = True, - device=None, dtype=None) -> None: - factory_kwargs = {'device': device, 'dtype': dtype} + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = True, + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} super().__init__() self.in_features = in_features self.out_features = out_features - self.weight = Parameter(torch.empty((out_features, in_features), **factory_kwargs)) + self.weight = Parameter( + torch.empty((out_features, in_features), **factory_kwargs) + ) if bias: self.bias = Parameter(torch.empty(out_features, **factory_kwargs)) else: - self.register_parameter('bias', None) + self.register_parameter("bias", None) self.reset_parameters() def reset_parameters(self) -> None: @@ -117,7 +125,7 @@ def forward(self, input: Tensor) -> Tensor: return F.linear(input, self.weight, self.bias) def extra_repr(self) -> str: - return f'in_features={self.in_features}, out_features={self.out_features}, bias={self.bias is not None}' + return f"in_features={self.in_features}, out_features={self.out_features}, bias={self.bias is not None}" # This class exists solely to avoid triggering an obscure error when scripting @@ -126,10 +134,17 @@ def extra_repr(self) -> str: # TODO: fail fast on quantization API usage error, then remove this class # and replace uses of it with plain Linear class NonDynamicallyQuantizableLinear(Linear): - def __init__(self, in_features: int, out_features: int, bias: bool = True, - device=None, dtype=None) -> None: - super().__init__(in_features, out_features, bias=bias, - device=device, dtype=dtype) + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = True, + device=None, + dtype=None, + ) -> None: + super().__init__( + in_features, out_features, bias=bias, device=device, dtype=dtype + ) class Bilinear(Module): @@ -170,25 +185,34 @@ class Bilinear(Module): torch.Size([128, 40]) """ - __constants__ = ['in1_features', 'in2_features', 'out_features'] + __constants__ = ["in1_features", "in2_features", "out_features"] in1_features: int in2_features: int out_features: int weight: Tensor - def __init__(self, in1_features: int, in2_features: int, out_features: int, bias: bool = True, - device=None, dtype=None) -> None: - factory_kwargs = {'device': device, 'dtype': dtype} + def __init__( + self, + in1_features: int, + in2_features: int, + out_features: int, + bias: bool = True, + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} super().__init__() self.in1_features = in1_features self.in2_features = in2_features self.out_features = out_features - self.weight = Parameter(torch.empty((out_features, in1_features, in2_features), **factory_kwargs)) + self.weight = Parameter( + torch.empty((out_features, in1_features, in2_features), **factory_kwargs) + ) if bias: self.bias = Parameter(torch.empty(out_features, **factory_kwargs)) else: - self.register_parameter('bias', None) + self.register_parameter("bias", None) self.reset_parameters() def reset_parameters(self) -> None: @@ -201,8 +225,10 @@ def forward(self, input1: Tensor, input2: Tensor) -> Tensor: return F.bilinear(input1, input2, self.weight, self.bias) def extra_repr(self) -> str: - return (f'in1_features={self.in1_features}, in2_features={self.in2_features}, ' - f'out_features={self.out_features}, bias={self.bias is not None}') + return ( + f"in1_features={self.in1_features}, in2_features={self.in2_features}, " + f"out_features={self.out_features}, bias={self.bias is not None}" + ) class LazyLinear(LazyModuleMixin, Linear): @@ -238,9 +264,10 @@ class LazyLinear(LazyModuleMixin, Linear): weight: UninitializedParameter bias: UninitializedParameter # type: ignore[assignment] - def __init__(self, out_features: int, bias: bool = True, - device=None, dtype=None) -> None: - factory_kwargs = {'device': device, 'dtype': dtype} + def __init__( + self, out_features: int, bias: bool = True, device=None, dtype=None + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} # bias is hardcoded to False to avoid creating tensor # that will soon be overwritten. super().__init__(0, 0, False) @@ -261,4 +288,6 @@ def initialize_parameters(self, input) -> None: # type: ignore[override] if self.bias is not None: self.bias.materialize((self.out_features,)) self.reset_parameters() + + # TODO: PartialLinear - maybe in sparse? diff --git a/torch/nn/modules/loss.py b/torch/nn/modules/loss.py index 497da821850626..9f6087f7828aee 100644 --- a/torch/nn/modules/loss.py +++ b/torch/nn/modules/loss.py @@ -1,23 +1,44 @@ # mypy: allow-untyped-defs +from typing import Callable, Optional +from typing_extensions import deprecated + +from torch import Tensor +from torch.nn import _reduction as _Reduction, functional as F + from .distance import PairwiseDistance from .module import Module -from .. import functional as F -from .. import _reduction as _Reduction -from torch import Tensor -from typing import Callable, Optional -from typing_extensions import deprecated -__all__ = ['L1Loss', 'NLLLoss', 'NLLLoss2d', 'PoissonNLLLoss', 'GaussianNLLLoss', 'KLDivLoss', - 'MSELoss', 'BCELoss', 'BCEWithLogitsLoss', 'HingeEmbeddingLoss', 'MultiLabelMarginLoss', - 'SmoothL1Loss', 'HuberLoss', 'SoftMarginLoss', 'CrossEntropyLoss', 'MultiLabelSoftMarginLoss', - 'CosineEmbeddingLoss', 'MarginRankingLoss', 'MultiMarginLoss', 'TripletMarginLoss', - 'TripletMarginWithDistanceLoss', 'CTCLoss'] +__all__ = [ + "L1Loss", + "NLLLoss", + "NLLLoss2d", + "PoissonNLLLoss", + "GaussianNLLLoss", + "KLDivLoss", + "MSELoss", + "BCELoss", + "BCEWithLogitsLoss", + "HingeEmbeddingLoss", + "MultiLabelMarginLoss", + "SmoothL1Loss", + "HuberLoss", + "SoftMarginLoss", + "CrossEntropyLoss", + "MultiLabelSoftMarginLoss", + "CosineEmbeddingLoss", + "MarginRankingLoss", + "MultiMarginLoss", + "TripletMarginLoss", + "TripletMarginWithDistanceLoss", + "CTCLoss", +] + class _Loss(Module): reduction: str - def __init__(self, size_average=None, reduce=None, reduction: str = 'mean') -> None: + def __init__(self, size_average=None, reduce=None, reduction: str = "mean") -> None: super().__init__() if size_average is not None or reduce is not None: self.reduction: str = _Reduction.legacy_get_string(size_average, reduce) @@ -26,9 +47,15 @@ def __init__(self, size_average=None, reduce=None, reduction: str = 'mean') -> N class _WeightedLoss(_Loss): - def __init__(self, weight: Optional[Tensor] = None, size_average=None, reduce=None, reduction: str = 'mean') -> None: + def __init__( + self, + weight: Optional[Tensor] = None, + size_average=None, + reduce=None, + reduction: str = "mean", + ) -> None: super().__init__(size_average, reduce, reduction) - self.register_buffer('weight', weight) + self.register_buffer("weight", weight) self.weight: Optional[Tensor] @@ -92,9 +119,9 @@ class L1Loss(_Loss): >>> output = loss(input, target) >>> output.backward() """ - __constants__ = ['reduction'] + __constants__ = ["reduction"] - def __init__(self, size_average=None, reduce=None, reduction: str = 'mean') -> None: + def __init__(self, size_average=None, reduce=None, reduction: str = "mean") -> None: super().__init__(size_average, reduce, reduction) def forward(self, input: Tensor, target: Tensor) -> Tensor: @@ -206,16 +233,28 @@ class NLLLoss(_WeightedLoss): >>> loss = loss_fn(output, target) >>> loss.backward() """ - __constants__ = ['ignore_index', 'reduction'] + __constants__ = ["ignore_index", "reduction"] ignore_index: int - def __init__(self, weight: Optional[Tensor] = None, size_average=None, ignore_index: int = -100, - reduce=None, reduction: str = 'mean') -> None: + def __init__( + self, + weight: Optional[Tensor] = None, + size_average=None, + ignore_index: int = -100, + reduce=None, + reduction: str = "mean", + ) -> None: super().__init__(weight, size_average, reduce, reduction) self.ignore_index = ignore_index def forward(self, input: Tensor, target: Tensor) -> Tensor: - return F.nll_loss(input, target, weight=self.weight, ignore_index=self.ignore_index, reduction=self.reduction) + return F.nll_loss( + input, + target, + weight=self.weight, + ignore_index=self.ignore_index, + reduction=self.reduction, + ) @deprecated( @@ -225,8 +264,14 @@ def forward(self, input: Tensor, target: Tensor) -> Tensor: category=FutureWarning, ) class NLLLoss2d(NLLLoss): - def __init__(self, weight: Optional[Tensor] = None, size_average=None, ignore_index: int = -100, - reduce=None, reduction: str = 'mean') -> None: + def __init__( + self, + weight: Optional[Tensor] = None, + size_average=None, + ignore_index: int = -100, + reduce=None, + reduction: str = "mean", + ) -> None: super().__init__(weight, size_average, ignore_index, reduce, reduction) @@ -286,21 +331,34 @@ class PoissonNLLLoss(_Loss): - Output: scalar by default. If :attr:`reduction` is ``'none'``, then :math:`(*)`, the same shape as the input. """ - __constants__ = ['log_input', 'full', 'eps', 'reduction'] + __constants__ = ["log_input", "full", "eps", "reduction"] log_input: bool full: bool eps: float - def __init__(self, log_input: bool = True, full: bool = False, size_average=None, - eps: float = 1e-8, reduce=None, reduction: str = 'mean') -> None: + def __init__( + self, + log_input: bool = True, + full: bool = False, + size_average=None, + eps: float = 1e-8, + reduce=None, + reduction: str = "mean", + ) -> None: super().__init__(size_average, reduce, reduction) self.log_input = log_input self.full = full self.eps = eps def forward(self, log_input: Tensor, target: Tensor) -> Tensor: - return F.poisson_nll_loss(log_input, target, log_input=self.log_input, full=self.full, - eps=self.eps, reduction=self.reduction) + return F.poisson_nll_loss( + log_input, + target, + log_input=self.log_input, + full=self.full, + eps=self.eps, + reduction=self.reduction, + ) class GaussianNLLLoss(_Loss): @@ -369,17 +427,21 @@ class GaussianNLLLoss(_Loss): Conference on Neural Networks (ICNN'94), Orlando, FL, USA, 1994, pp. 55-60 vol.1, doi: 10.1109/ICNN.1994.374138. """ - __constants__ = ['full', 'eps', 'reduction'] + __constants__ = ["full", "eps", "reduction"] full: bool eps: float - def __init__(self, *, full: bool = False, eps: float = 1e-6, reduction: str = 'mean') -> None: + def __init__( + self, *, full: bool = False, eps: float = 1e-6, reduction: str = "mean" + ) -> None: super().__init__(None, None, reduction) self.full = full self.eps = eps def forward(self, input: Tensor, target: Tensor, var: Tensor) -> Tensor: - return F.gaussian_nll_loss(input, target, var, full=self.full, eps=self.eps, reduction=self.reduction) + return F.gaussian_nll_loss( + input, target, var, full=self.full, eps=self.eps, reduction=self.reduction + ) class KLDivLoss(_Loss): @@ -463,14 +525,22 @@ class KLDivLoss(_Loss): >>> log_target = F.log_softmax(torch.rand(3, 5), dim=1) >>> output = kl_loss(input, log_target) """ - __constants__ = ['reduction'] - - def __init__(self, size_average=None, reduce=None, reduction: str = 'mean', log_target: bool = False) -> None: + __constants__ = ["reduction"] + + def __init__( + self, + size_average=None, + reduce=None, + reduction: str = "mean", + log_target: bool = False, + ) -> None: super().__init__(size_average, reduce, reduction) self.log_target = log_target def forward(self, input: Tensor, target: Tensor) -> Tensor: - return F.kl_div(input, target, reduction=self.reduction, log_target=self.log_target) + return F.kl_div( + input, target, reduction=self.reduction, log_target=self.log_target + ) class MSELoss(_Loss): @@ -529,9 +599,9 @@ class MSELoss(_Loss): >>> output = loss(input, target) >>> output.backward() """ - __constants__ = ['reduction'] + __constants__ = ["reduction"] - def __init__(self, size_average=None, reduce=None, reduction: str = 'mean') -> None: + def __init__(self, size_average=None, reduce=None, reduction: str = "mean") -> None: super().__init__(size_average, reduce, reduction) def forward(self, input: Tensor, target: Tensor) -> Tensor: @@ -612,13 +682,21 @@ class BCELoss(_WeightedLoss): >>> output = loss(m(input), target) >>> output.backward() """ - __constants__ = ['reduction'] - - def __init__(self, weight: Optional[Tensor] = None, size_average=None, reduce=None, reduction: str = 'mean') -> None: + __constants__ = ["reduction"] + + def __init__( + self, + weight: Optional[Tensor] = None, + size_average=None, + reduce=None, + reduction: str = "mean", + ) -> None: super().__init__(weight, size_average, reduce, reduction) def forward(self, input: Tensor, target: Tensor) -> Tensor: - return F.binary_cross_entropy(input, target, weight=self.weight, reduction=self.reduction) + return F.binary_cross_entropy( + input, target, weight=self.weight, reduction=self.reduction + ) class BCEWithLogitsLoss(_Loss): @@ -722,19 +800,29 @@ class BCEWithLogitsLoss(_Loss): >>> output = loss(input, target) >>> output.backward() """ - def __init__(self, weight: Optional[Tensor] = None, size_average=None, reduce=None, reduction: str = 'mean', - pos_weight: Optional[Tensor] = None) -> None: + + def __init__( + self, + weight: Optional[Tensor] = None, + size_average=None, + reduce=None, + reduction: str = "mean", + pos_weight: Optional[Tensor] = None, + ) -> None: super().__init__(size_average, reduce, reduction) - self.register_buffer('weight', weight) - self.register_buffer('pos_weight', pos_weight) + self.register_buffer("weight", weight) + self.register_buffer("pos_weight", pos_weight) self.weight: Optional[Tensor] self.pos_weight: Optional[Tensor] def forward(self, input: Tensor, target: Tensor) -> Tensor: - return F.binary_cross_entropy_with_logits(input, target, - self.weight, - pos_weight=self.pos_weight, - reduction=self.reduction) + return F.binary_cross_entropy_with_logits( + input, + target, + self.weight, + pos_weight=self.pos_weight, + reduction=self.reduction, + ) class HingeEmbeddingLoss(_Loss): @@ -786,15 +874,23 @@ class HingeEmbeddingLoss(_Loss): - Target: :math:`(*)`, same shape as the input - Output: scalar. If :attr:`reduction` is ``'none'``, then same shape as the input """ - __constants__ = ['margin', 'reduction'] + __constants__ = ["margin", "reduction"] margin: float - def __init__(self, margin: float = 1.0, size_average=None, reduce=None, reduction: str = 'mean') -> None: + def __init__( + self, + margin: float = 1.0, + size_average=None, + reduce=None, + reduction: str = "mean", + ) -> None: super().__init__(size_average, reduce, reduction) self.margin = margin def forward(self, input: Tensor, target: Tensor) -> Tensor: - return F.hinge_embedding_loss(input, target, margin=self.margin, reduction=self.reduction) + return F.hinge_embedding_loss( + input, target, margin=self.margin, reduction=self.reduction + ) class MultiLabelMarginLoss(_Loss): @@ -852,9 +948,9 @@ class MultiLabelMarginLoss(_Loss): tensor(0.85...) """ - __constants__ = ['reduction'] + __constants__ = ["reduction"] - def __init__(self, size_average=None, reduce=None, reduction: str = 'mean') -> None: + def __init__(self, size_average=None, reduce=None, reduction: str = "mean") -> None: super().__init__(size_average, reduce, reduction) def forward(self, input: Tensor, target: Tensor) -> Tensor: @@ -932,9 +1028,11 @@ class SmoothL1Loss(_Loss): - Target: :math:`(*)`, same shape as the input. - Output: scalar. If :attr:`reduction` is ``'none'``, then :math:`(*)`, same shape as the input. """ - __constants__ = ['reduction'] + __constants__ = ["reduction"] - def __init__(self, size_average=None, reduce=None, reduction: str = 'mean', beta: float = 1.0) -> None: + def __init__( + self, size_average=None, reduce=None, reduction: str = "mean", beta: float = 1.0 + ) -> None: super().__init__(size_average, reduce, reduction) self.beta = beta @@ -992,9 +1090,9 @@ class HuberLoss(_Loss): - Target: :math:`(*)`, same shape as the input. - Output: scalar. If :attr:`reduction` is ``'none'``, then :math:`(*)`, same shape as the input. """ - __constants__ = ['reduction', 'delta'] + __constants__ = ["reduction", "delta"] - def __init__(self, reduction: str = 'mean', delta: float = 1.0) -> None: + def __init__(self, reduction: str = "mean", delta: float = 1.0) -> None: super().__init__(reduction=reduction) self.delta = delta @@ -1034,9 +1132,9 @@ class SoftMarginLoss(_Loss): shape as input. """ - __constants__ = ['reduction'] + __constants__ = ["reduction"] - def __init__(self, size_average=None, reduce=None, reduction: str = 'mean') -> None: + def __init__(self, size_average=None, reduce=None, reduction: str = "mean") -> None: super().__init__(size_average, reduce, reduction) def forward(self, input: Tensor, target: Tensor) -> Tensor: @@ -1174,20 +1272,32 @@ class probabilities only when a single class label per minibatch item is too res >>> output = loss(input, target) >>> output.backward() """ - __constants__ = ['ignore_index', 'reduction', 'label_smoothing'] + __constants__ = ["ignore_index", "reduction", "label_smoothing"] ignore_index: int label_smoothing: float - def __init__(self, weight: Optional[Tensor] = None, size_average=None, ignore_index: int = -100, - reduce=None, reduction: str = 'mean', label_smoothing: float = 0.0) -> None: + def __init__( + self, + weight: Optional[Tensor] = None, + size_average=None, + ignore_index: int = -100, + reduce=None, + reduction: str = "mean", + label_smoothing: float = 0.0, + ) -> None: super().__init__(weight, size_average, reduce, reduction) self.ignore_index = ignore_index self.label_smoothing = label_smoothing def forward(self, input: Tensor, target: Tensor) -> Tensor: - return F.cross_entropy(input, target, weight=self.weight, - ignore_index=self.ignore_index, reduction=self.reduction, - label_smoothing=self.label_smoothing) + return F.cross_entropy( + input, + target, + weight=self.weight, + ignore_index=self.ignore_index, + reduction=self.reduction, + label_smoothing=self.label_smoothing, + ) class MultiLabelSoftMarginLoss(_WeightedLoss): @@ -1228,13 +1338,21 @@ class MultiLabelSoftMarginLoss(_WeightedLoss): - Target: :math:`(N, C)`, label targets must have the same shape as the input. - Output: scalar. If :attr:`reduction` is ``'none'``, then :math:`(N)`. """ - __constants__ = ['reduction'] - - def __init__(self, weight: Optional[Tensor] = None, size_average=None, reduce=None, reduction: str = 'mean') -> None: + __constants__ = ["reduction"] + + def __init__( + self, + weight: Optional[Tensor] = None, + size_average=None, + reduce=None, + reduction: str = "mean", + ) -> None: super().__init__(weight, size_average, reduce, reduction) def forward(self, input: Tensor, target: Tensor) -> Tensor: - return F.multilabel_soft_margin_loss(input, target, weight=self.weight, reduction=self.reduction) + return F.multilabel_soft_margin_loss( + input, target, weight=self.weight, reduction=self.reduction + ) class CosineEmbeddingLoss(_Loss): @@ -1288,15 +1406,23 @@ class CosineEmbeddingLoss(_Loss): >>> output = loss(input1, input2, target) >>> output.backward() """ - __constants__ = ['margin', 'reduction'] + __constants__ = ["margin", "reduction"] margin: float - def __init__(self, margin: float = 0., size_average=None, reduce=None, reduction: str = 'mean') -> None: + def __init__( + self, + margin: float = 0.0, + size_average=None, + reduce=None, + reduction: str = "mean", + ) -> None: super().__init__(size_average, reduce, reduction) self.margin = margin def forward(self, input1: Tensor, input2: Tensor, target: Tensor) -> Tensor: - return F.cosine_embedding_loss(input1, input2, target, margin=self.margin, reduction=self.reduction) + return F.cosine_embedding_loss( + input1, input2, target, margin=self.margin, reduction=self.reduction + ) class MarginRankingLoss(_Loss): @@ -1345,15 +1471,23 @@ class MarginRankingLoss(_Loss): >>> output = loss(input1, input2, target) >>> output.backward() """ - __constants__ = ['margin', 'reduction'] + __constants__ = ["margin", "reduction"] margin: float - def __init__(self, margin: float = 0., size_average=None, reduce=None, reduction: str = 'mean') -> None: + def __init__( + self, + margin: float = 0.0, + size_average=None, + reduce=None, + reduction: str = "mean", + ) -> None: super().__init__(size_average, reduce, reduction) self.margin = margin def forward(self, input1: Tensor, input2: Tensor, target: Tensor) -> Tensor: - return F.margin_ranking_loss(input1, input2, target, margin=self.margin, reduction=self.reduction) + return F.margin_ranking_loss( + input1, input2, target, margin=self.margin, reduction=self.reduction + ) class MultiMarginLoss(_WeightedLoss): @@ -1416,16 +1550,23 @@ class MultiMarginLoss(_WeightedLoss): >>> loss(x, y) tensor(0.32...) """ - __constants__ = ['p', 'margin', 'reduction'] + __constants__ = ["p", "margin", "reduction"] margin: float p: int - def __init__(self, p: int = 1, margin: float = 1., weight: Optional[Tensor] = None, size_average=None, - reduce=None, reduction: str = 'mean') -> None: + def __init__( + self, + p: int = 1, + margin: float = 1.0, + weight: Optional[Tensor] = None, + size_average=None, + reduce=None, + reduction: str = "mean", + ) -> None: super().__init__(weight, size_average, reduce, reduction) if p != 1 and p != 2: raise ValueError("only p == 1 and p == 2 supported") - if weight is not None and weight.dim() != 1 : + if weight is not None and weight.dim() != 1: raise ValueError( f"MultiMarginLoss: expected weight to be None or 1D tensor, got {weight.dim()}D instead" ) @@ -1433,8 +1574,14 @@ def __init__(self, p: int = 1, margin: float = 1., weight: Optional[Tensor] = No self.margin = margin def forward(self, input: Tensor, target: Tensor) -> Tensor: - return F.multi_margin_loss(input, target, p=self.p, margin=self.margin, - weight=self.weight, reduction=self.reduction) + return F.multi_margin_loss( + input, + target, + p=self.p, + margin=self.margin, + weight=self.weight, + reduction=self.reduction, + ) class TripletMarginLoss(_Loss): @@ -1506,14 +1653,22 @@ class TripletMarginLoss(_Loss): .. _Learning shallow convolutional feature descriptors with triplet losses: http://www.bmva.org/bmvc/2016/papers/paper119/index.html """ - __constants__ = ['margin', 'p', 'eps', 'swap', 'reduction'] + __constants__ = ["margin", "p", "eps", "swap", "reduction"] margin: float p: float eps: float swap: bool - def __init__(self, margin: float = 1.0, p: float = 2., eps: float = 1e-6, swap: bool = False, size_average=None, - reduce=None, reduction: str = 'mean'): + def __init__( + self, + margin: float = 1.0, + p: float = 2.0, + eps: float = 1e-6, + swap: bool = False, + size_average=None, + reduce=None, + reduction: str = "mean", + ): super().__init__(size_average, reduce, reduction) if margin <= 0: raise ValueError( @@ -1525,8 +1680,16 @@ def __init__(self, margin: float = 1.0, p: float = 2., eps: float = 1e-6, swap: self.swap = swap def forward(self, anchor: Tensor, positive: Tensor, negative: Tensor) -> Tensor: - return F.triplet_margin_loss(anchor, positive, negative, margin=self.margin, p=self.p, - eps=self.eps, swap=self.swap, reduction=self.reduction) + return F.triplet_margin_loss( + anchor, + positive, + negative, + margin=self.margin, + p=self.p, + eps=self.eps, + swap=self.swap, + reduction=self.reduction, + ) class TripletMarginWithDistanceLoss(_Loss): @@ -1627,26 +1790,39 @@ class TripletMarginWithDistanceLoss(_Loss): V. Balntas, et al.: Learning shallow convolutional feature descriptors with triplet losses: http://www.bmva.org/bmvc/2016/papers/paper119/index.html """ - __constants__ = ['margin', 'swap', 'reduction'] + __constants__ = ["margin", "swap", "reduction"] margin: float swap: bool - def __init__(self, *, distance_function: Optional[Callable[[Tensor, Tensor], Tensor]] = None, - margin: float = 1.0, swap: bool = False, reduction: str = 'mean'): + def __init__( + self, + *, + distance_function: Optional[Callable[[Tensor, Tensor], Tensor]] = None, + margin: float = 1.0, + swap: bool = False, + reduction: str = "mean", + ): super().__init__(size_average=None, reduce=None, reduction=reduction) if margin <= 0: raise ValueError( f"TripletMarginWithDistanceLoss: expected margin to be greater than 0, got {margin} instead" ) - self.distance_function: Optional[Callable[[Tensor, Tensor], Tensor]] = \ + self.distance_function: Optional[Callable[[Tensor, Tensor], Tensor]] = ( distance_function if distance_function is not None else PairwiseDistance() + ) self.margin = margin self.swap = swap def forward(self, anchor: Tensor, positive: Tensor, negative: Tensor) -> Tensor: - return F.triplet_margin_with_distance_loss(anchor, positive, negative, - distance_function=self.distance_function, - margin=self.margin, swap=self.swap, reduction=self.reduction) + return F.triplet_margin_with_distance_loss( + anchor, + positive, + negative, + distance_function=self.distance_function, + margin=self.margin, + swap=self.swap, + reduction=self.reduction, + ) class CTCLoss(_Loss): @@ -1783,18 +1959,34 @@ class CTCLoss(_Loss): True``. Please see the notes on :doc:`/notes/randomness` for background. """ - __constants__ = ['blank', 'reduction'] + __constants__ = ["blank", "reduction"] blank: int zero_infinity: bool - def __init__(self, blank: int = 0, reduction: str = 'mean', zero_infinity: bool = False): + def __init__( + self, blank: int = 0, reduction: str = "mean", zero_infinity: bool = False + ): super().__init__(reduction=reduction) self.blank = blank self.zero_infinity = zero_infinity - def forward(self, log_probs: Tensor, targets: Tensor, input_lengths: Tensor, target_lengths: Tensor) -> Tensor: - return F.ctc_loss(log_probs, targets, input_lengths, target_lengths, self.blank, self.reduction, - self.zero_infinity) + def forward( + self, + log_probs: Tensor, + targets: Tensor, + input_lengths: Tensor, + target_lengths: Tensor, + ) -> Tensor: + return F.ctc_loss( + log_probs, + targets, + input_lengths, + target_lengths, + self.blank, + self.reduction, + self.zero_infinity, + ) + # TODO: L1HingeEmbeddingCriterion # TODO: MSECriterion weight diff --git a/torch/nn/modules/module.py b/torch/nn/modules/module.py index f96c705208dc45..ceb6d4b76492e4 100644 --- a/torch/nn/modules/module.py +++ b/torch/nn/modules/module.py @@ -29,44 +29,55 @@ from torch.utils.hooks import BackwardHook, RemovableHandle -__all__ = ['register_module_forward_pre_hook', 'register_module_forward_hook', - 'register_module_full_backward_pre_hook', 'register_module_backward_hook', - 'register_module_full_backward_hook', 'register_module_buffer_registration_hook', - 'register_module_module_registration_hook', 'register_module_parameter_registration_hook', 'Module'] +__all__ = [ + "register_module_forward_pre_hook", + "register_module_forward_hook", + "register_module_full_backward_pre_hook", + "register_module_backward_hook", + "register_module_full_backward_hook", + "register_module_buffer_registration_hook", + "register_module_module_registration_hook", + "register_module_parameter_registration_hook", + "Module", +] _grad_t = Union[Tuple[Tensor, ...], Tensor] # See https://mypy.readthedocs.io/en/latest/generics.html#generic-methods-and-generic-self for the use # of `T` to annotate `self`. Many methods of `Module` return `self` and we want those return values to be # the type of the subclass, not the looser type of `Module`. -T = TypeVar('T', bound='Module') +T = TypeVar("T", bound="Module") -class _IncompatibleKeys(namedtuple('IncompatibleKeys', ['missing_keys', 'unexpected_keys'])): +class _IncompatibleKeys( + namedtuple("IncompatibleKeys", ["missing_keys", "unexpected_keys"]), +): def __repr__(self): if not self.missing_keys and not self.unexpected_keys: - return '' + return "" return super().__repr__() __str__ = __repr__ def _addindent(s_, numSpaces): - s = s_.split('\n') + s = s_.split("\n") # don't do anything for single-line stuff if len(s) == 1: return s_ first = s.pop(0) - s = [(numSpaces * ' ') + line for line in s] - s = '\n'.join(s) - s = first + '\n' + s + s = [(numSpaces * " ") + line for line in s] + s = "\n".join(s) + s = first + "\n" + s return s + r"""This tracks hooks common to all modules that are executed immediately before .registering the buffer/module/parameter""" _global_buffer_registration_hooks: Dict[int, Callable] = OrderedDict() _global_module_registration_hooks: Dict[int, Callable] = OrderedDict() _global_parameter_registration_hooks: Dict[int, Callable] = OrderedDict() + class _WrappedHook: def __init__(self, hook: Callable, module: Optional["Module"] = None): self.hook: Callable = hook @@ -99,7 +110,9 @@ def __setstate__(self, state: Dict): if self.with_module: if state["module"] is None: - raise RuntimeError("You are trying to revive the hook of a dead Module!") + raise RuntimeError( + "You are trying to revive the hook of a dead Module!" + ) self.module = weakref.ref(state["module"]) @@ -113,10 +126,12 @@ def __setstate__(self, state: Dict): _global_forward_hooks: Dict[int, Callable] = OrderedDict() _global_forward_hooks_always_called: Dict[int, bool] = OrderedDict() -_EXTRA_STATE_KEY_SUFFIX = '_extra_state' +_EXTRA_STATE_KEY_SUFFIX = "_extra_state" -def register_module_buffer_registration_hook(hook: Callable[..., None]) -> RemovableHandle: +def register_module_buffer_registration_hook( + hook: Callable[..., None], +) -> RemovableHandle: r"""Register a buffer registration hook common to all modules. .. warning :: @@ -140,7 +155,9 @@ def register_module_buffer_registration_hook(hook: Callable[..., None]) -> Remov return handle -def register_module_module_registration_hook(hook: Callable[..., None]) -> RemovableHandle: +def register_module_module_registration_hook( + hook: Callable[..., None], +) -> RemovableHandle: r"""Register a module registration hook common to all modules. .. warning :: @@ -164,7 +181,9 @@ def register_module_module_registration_hook(hook: Callable[..., None]) -> Remov return handle -def register_module_parameter_registration_hook(hook: Callable[..., None]) -> RemovableHandle: +def register_module_parameter_registration_hook( + hook: Callable[..., None], +) -> RemovableHandle: r"""Register a parameter registration hook common to all modules. .. warning :: @@ -220,7 +239,11 @@ def register_module_forward_pre_hook(hook: Callable[..., None]) -> RemovableHand return handle -def register_module_forward_hook(hook: Callable[..., None], *, always_call: bool = False) -> RemovableHandle: +def register_module_forward_hook( + hook: Callable[..., None], + *, + always_call: bool = False, +) -> RemovableHandle: r"""Register a global forward hook for all the modules. .. warning :: @@ -252,8 +275,9 @@ def register_module_forward_hook(hook: Callable[..., None], *, always_call: bool This hook will be executed before specific module hooks registered with ``register_forward_hook``. """ - handle = RemovableHandle(_global_forward_hooks, - extra_dict=_global_forward_hooks_always_called) + handle = RemovableHandle( + _global_forward_hooks, extra_dict=_global_forward_hooks_always_called + ) _global_forward_hooks[handle.id] = hook if always_call: _global_forward_hooks_always_called[handle.id] = True @@ -261,7 +285,7 @@ def register_module_forward_hook(hook: Callable[..., None], *, always_call: bool def register_module_backward_hook( - hook: Callable[['Module', _grad_t, _grad_t], Union[None, _grad_t]] + hook: Callable[["Module", _grad_t, _grad_t], Union[None, _grad_t]], ) -> RemovableHandle: r"""Register a backward hook common to all the modules. @@ -277,8 +301,10 @@ def register_module_backward_hook( """ global _global_is_full_backward_hook if _global_is_full_backward_hook is True: - raise RuntimeError("Cannot use both regular backward hooks and full backward hooks as a " - "global Module hook. Please use only one of them.") + raise RuntimeError( + "Cannot use both regular backward hooks and full backward hooks as a " + "global Module hook. Please use only one of them." + ) _global_is_full_backward_hook = False @@ -288,7 +314,7 @@ def register_module_backward_hook( def register_module_full_backward_pre_hook( - hook: Callable[['Module', _grad_t], Union[None, _grad_t]] + hook: Callable[["Module", _grad_t], Union[None, _grad_t]], ) -> RemovableHandle: r"""Register a backward pre-hook common to all the modules. @@ -315,7 +341,7 @@ def register_module_full_backward_pre_hook( def register_module_full_backward_hook( - hook: Callable[['Module', _grad_t, _grad_t], Union[None, _grad_t]] + hook: Callable[["Module", _grad_t, _grad_t], Union[None, _grad_t]], ) -> RemovableHandle: r"""Register a backward hook common to all the modules. @@ -338,8 +364,10 @@ def register_module_full_backward_hook( """ global _global_is_full_backward_hook if _global_is_full_backward_hook is False: - raise RuntimeError("Cannot use both regular backward hooks and full backward hooks as a " - "global Module hook. Please use only one of them.") + raise RuntimeError( + "Cannot use both regular backward hooks and full backward hooks as a " + "global Module hook. Please use only one of them." + ) _global_is_full_backward_hook = True @@ -362,7 +390,9 @@ def _forward_unimplemented(self, *input: Any) -> None: instead of this since the former takes care of running the registered hooks while the latter silently ignores them. """ - raise NotImplementedError(f'Module [{type(self).__name__}] is missing the required "forward" function') + raise NotImplementedError( + f'Module [{type(self).__name__}] is missing the required "forward" function' + ) class Module: @@ -435,9 +465,9 @@ def forward(self, x): _load_state_dict_pre_hooks: Dict[int, Callable] _state_dict_pre_hooks: Dict[int, Callable] _load_state_dict_post_hooks: Dict[int, Callable] - _modules: Dict[str, Optional['Module']] + _modules: Dict[str, Optional["Module"]] call_super_init: bool = False - _compiled_call_impl : Optional[Callable] = None + _compiled_call_impl: Optional[Callable] = None def __init__(self, *args, **kwargs) -> None: """Initialize internal Module state, shared by both nn.Module and ScriptModule.""" @@ -445,12 +475,16 @@ def __init__(self, *args, **kwargs) -> None: # Backward compatibility: no args used to be allowed when call_super_init=False if self.call_super_init is False and bool(kwargs): - raise TypeError(f"{type(self).__name__}.__init__() got an unexpected keyword argument '{next(iter(kwargs))}'" - "") + raise TypeError( + f"{type(self).__name__}.__init__() got an unexpected keyword argument '{next(iter(kwargs))}'" + "" + ) if self.call_super_init is False and bool(args): - raise TypeError(f"{type(self).__name__}.__init__() takes 1 positional argument but {len(args) + 1} were" - " given") + raise TypeError( + f"{type(self).__name__}.__init__() takes 1 positional argument but {len(args) + 1} were" + " given" + ) """ Calls super().__setattr__('a', a) instead of the typical self.a = a @@ -458,30 +492,32 @@ def __init__(self, *args, **kwargs) -> None: handling for parameters, submodules, and buffers but simply calls into super().__setattr__ for all other attributes. """ - super().__setattr__('training', True) - super().__setattr__('_parameters', OrderedDict()) - super().__setattr__('_buffers', OrderedDict()) - super().__setattr__('_non_persistent_buffers_set', set()) - super().__setattr__('_backward_pre_hooks', OrderedDict()) - super().__setattr__('_backward_hooks', OrderedDict()) - super().__setattr__('_is_full_backward_hook', None) - super().__setattr__('_forward_hooks', OrderedDict()) - super().__setattr__('_forward_hooks_with_kwargs', OrderedDict()) - super().__setattr__('_forward_hooks_always_called', OrderedDict()) - super().__setattr__('_forward_pre_hooks', OrderedDict()) - super().__setattr__('_forward_pre_hooks_with_kwargs', OrderedDict()) - super().__setattr__('_state_dict_hooks', OrderedDict()) - super().__setattr__('_state_dict_pre_hooks', OrderedDict()) - super().__setattr__('_load_state_dict_pre_hooks', OrderedDict()) - super().__setattr__('_load_state_dict_post_hooks', OrderedDict()) - super().__setattr__('_modules', OrderedDict()) + super().__setattr__("training", True) + super().__setattr__("_parameters", OrderedDict()) + super().__setattr__("_buffers", OrderedDict()) + super().__setattr__("_non_persistent_buffers_set", set()) + super().__setattr__("_backward_pre_hooks", OrderedDict()) + super().__setattr__("_backward_hooks", OrderedDict()) + super().__setattr__("_is_full_backward_hook", None) + super().__setattr__("_forward_hooks", OrderedDict()) + super().__setattr__("_forward_hooks_with_kwargs", OrderedDict()) + super().__setattr__("_forward_hooks_always_called", OrderedDict()) + super().__setattr__("_forward_pre_hooks", OrderedDict()) + super().__setattr__("_forward_pre_hooks_with_kwargs", OrderedDict()) + super().__setattr__("_state_dict_hooks", OrderedDict()) + super().__setattr__("_state_dict_pre_hooks", OrderedDict()) + super().__setattr__("_load_state_dict_pre_hooks", OrderedDict()) + super().__setattr__("_load_state_dict_post_hooks", OrderedDict()) + super().__setattr__("_modules", OrderedDict()) if self.call_super_init: super().__init__(*args, **kwargs) forward: Callable[..., Any] = _forward_unimplemented - def register_buffer(self, name: str, tensor: Optional[Tensor], persistent: bool = True) -> None: + def register_buffer( + self, name: str, tensor: Optional[Tensor], persistent: bool = True + ) -> None: r"""Add a buffer to the module. This is typically used to register a buffer that should not to be @@ -513,21 +549,23 @@ def register_buffer(self, name: str, tensor: Optional[Tensor], persistent: bool if persistent is False and isinstance(self, torch.jit.ScriptModule): raise RuntimeError("ScriptModule does not support non-persistent buffers") - if '_buffers' not in self.__dict__: - raise AttributeError( - "cannot assign buffer before Module.__init__() call") + if "_buffers" not in self.__dict__: + raise AttributeError("cannot assign buffer before Module.__init__() call") elif not isinstance(name, str): - raise TypeError(f"buffer name should be a string. Got {torch.typename(name)}") - elif '.' in name: - raise KeyError("buffer name can't contain \".\"") - elif name == '': - raise KeyError("buffer name can't be empty string \"\"") + raise TypeError( + f"buffer name should be a string. Got {torch.typename(name)}" + ) + elif "." in name: + raise KeyError('buffer name can\'t contain "."') + elif name == "": + raise KeyError('buffer name can\'t be empty string ""') elif hasattr(self, name) and name not in self._buffers: raise KeyError(f"attribute '{name}' already exists") elif tensor is not None and not isinstance(tensor, torch.Tensor): - raise TypeError(f"cannot assign '{torch.typename(tensor)}' object to buffer '{name}' " - "(torch Tensor or None required)" - ) + raise TypeError( + f"cannot assign '{torch.typename(tensor)}' object to buffer '{name}' " + "(torch Tensor or None required)" + ) else: for hook in _global_buffer_registration_hooks.values(): output = hook(self, name, tensor) @@ -552,31 +590,36 @@ def register_parameter(self, name: str, param: Optional[Parameter]) -> None: are ignored. If ``None``, the parameter is **not** included in the module's :attr:`state_dict`. """ - if '_parameters' not in self.__dict__: + if "_parameters" not in self.__dict__: raise AttributeError( - "cannot assign parameter before Module.__init__() call") + "cannot assign parameter before Module.__init__() call" + ) elif not isinstance(name, str): - raise TypeError(f"parameter name should be a string. Got {torch.typename(name)}") - elif '.' in name: - raise KeyError("parameter name can't contain \".\"") - elif name == '': - raise KeyError("parameter name can't be empty string \"\"") + raise TypeError( + f"parameter name should be a string. Got {torch.typename(name)}" + ) + elif "." in name: + raise KeyError('parameter name can\'t contain "."') + elif name == "": + raise KeyError('parameter name can\'t be empty string ""') elif hasattr(self, name) and name not in self._parameters: raise KeyError(f"attribute '{name}' already exists") if param is None: self._parameters[name] = None elif not isinstance(param, Parameter): - raise TypeError(f"cannot assign '{torch.typename(param)}' object to parameter '{name}' " - "(torch.nn.Parameter or None required)" - ) + raise TypeError( + f"cannot assign '{torch.typename(param)}' object to parameter '{name}' " + "(torch.nn.Parameter or None required)" + ) elif param.grad_fn: raise ValueError( f"Cannot assign non-leaf Tensor to parameter '{name}'. Model " f"parameters must be created explicitly. To express '{name}' " "as a function of another Tensor, compute the value in " - "the forward() method.") + "the forward() method." + ) else: for hook in _global_parameter_registration_hooks.values(): output = hook(self, name, param) @@ -584,7 +627,7 @@ def register_parameter(self, name: str, param: Optional[Parameter]) -> None: param = output self._parameters[name] = param - def add_module(self, name: str, module: Optional['Module']) -> None: + def add_module(self, name: str, module: Optional["Module"]) -> None: r"""Add a child module to the current module. The module can be accessed as an attribute using the given name. @@ -597,20 +640,22 @@ def add_module(self, name: str, module: Optional['Module']) -> None: if not isinstance(module, Module) and module is not None: raise TypeError(f"{torch.typename(module)} is not a Module subclass") elif not isinstance(name, str): - raise TypeError(f"module name should be a string. Got {torch.typename(name)}") + raise TypeError( + f"module name should be a string. Got {torch.typename(name)}" + ) elif hasattr(self, name) and name not in self._modules: raise KeyError(f"attribute '{name}' already exists") - elif '.' in name: - raise KeyError(f"module name can't contain \".\", got: {name}") - elif name == '': - raise KeyError("module name can't be empty string \"\"") + elif "." in name: + raise KeyError(f'module name can\'t contain ".", got: {name}') + elif name == "": + raise KeyError('module name can\'t be empty string ""') for hook in _global_module_registration_hooks.values(): output = hook(self, name, module) if output is not None: module = output self._modules[name] = module - def register_module(self, name: str, module: Optional['Module']) -> None: + def register_module(self, name: str, module: Optional["Module"]) -> None: r"""Alias for :func:`add_module`.""" self.add_module(name, module) @@ -667,16 +712,15 @@ def get_submodule(self, target: str) -> "Module": mod: torch.nn.Module = self for item in atoms: - if not hasattr(mod, item): - raise AttributeError(mod._get_name() + " has no " - "attribute `" + item + "`") + raise AttributeError( + mod._get_name() + " has no " "attribute `" + item + "`" + ) mod = getattr(mod, item) if not isinstance(mod, torch.nn.Module): - raise AttributeError("`" + item + "` is not " - "an nn.Module") + raise AttributeError("`" + item + "` is not " "an nn.Module") return mod @@ -705,14 +749,14 @@ def get_parameter(self, target: str) -> "Parameter": mod: torch.nn.Module = self.get_submodule(module_path) if not hasattr(mod, param_name): - raise AttributeError(mod._get_name() + " has no attribute `" - + param_name + "`") + raise AttributeError( + mod._get_name() + " has no attribute `" + param_name + "`" + ) param: torch.nn.Parameter = getattr(mod, param_name) if not isinstance(param, torch.nn.Parameter): - raise AttributeError("`" + param_name + "` is not an " - "nn.Parameter") + raise AttributeError("`" + param_name + "` is not an " "nn.Parameter") return param @@ -741,8 +785,9 @@ def get_buffer(self, target: str) -> "Tensor": mod: torch.nn.Module = self.get_submodule(module_path) if not hasattr(mod, buffer_name): - raise AttributeError(mod._get_name() + " has no attribute `" - + buffer_name + "`") + raise AttributeError( + mod._get_name() + " has no attribute `" + buffer_name + "`" + ) buffer: torch.Tensor = getattr(mod, buffer_name) @@ -769,7 +814,8 @@ def get_extra_state(self) -> Any: raise RuntimeError( "Reached a code path in Module.get_extra_state() that should never be called. " "Please file an issue at https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml " - "to report this bug.") + "to report this bug." + ) def set_extra_state(self, state: Any) -> None: """Set extra state contained in the loaded `state_dict`. @@ -785,7 +831,8 @@ def set_extra_state(self, state: Any) -> None: raise RuntimeError( "Reached a code path in Module.set_extra_state() that should never be called. " "Please file an issue at https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml " - "to report this bug.") + "to report this bug." + ) def _apply(self, fn, recurse=True): if recurse: @@ -806,7 +853,9 @@ def compute_should_use_set_data(tensor, tensor_applied): else: return False - should_use_swap_tensors = torch.__future__.get_swap_module_params_on_conversion() + should_use_swap_tensors = ( + torch.__future__.get_swap_module_params_on_conversion() + ) for key, param in self._parameters.items(): if param is None: @@ -819,7 +868,9 @@ def compute_should_use_set_data(tensor, tensor_applied): p_should_use_set_data = compute_should_use_set_data(param, param_applied) # subclasses may have multiple child tensors so we need to use swap_tensors - p_should_use_swap_tensors = should_use_swap_tensors or is_traceable_wrapper_subclass(param_applied) + p_should_use_swap_tensors = ( + should_use_swap_tensors or is_traceable_wrapper_subclass(param_applied) + ) param_grad = param.grad if p_should_use_swap_tensors: @@ -828,12 +879,16 @@ def compute_should_use_set_data(tensor, tensor_applied): # Accessing param.grad makes its at::Tensor's use_count 2, which will prevent swapping. # Decrement use count of the gradient by setting to None param.grad = None - param_applied = torch.nn.Parameter(param_applied, requires_grad=param.requires_grad) + param_applied = torch.nn.Parameter( + param_applied, requires_grad=param.requires_grad + ) torch.utils.swap_tensors(param, param_applied) except Exception as e: if param_grad is not None: param.grad = param_grad - raise RuntimeError(f"_apply(): Couldn't swap {self._get_name()}.{key}") from e + raise RuntimeError( + f"_apply(): Couldn't swap {self._get_name()}.{key}" + ) from e out_param = param elif p_should_use_set_data: param.data = param_applied @@ -847,20 +902,26 @@ def compute_should_use_set_data(tensor, tensor_applied): if param_grad is not None: with torch.no_grad(): grad_applied = fn(param_grad) - g_should_use_set_data = compute_should_use_set_data(param_grad, grad_applied) + g_should_use_set_data = compute_should_use_set_data( + param_grad, grad_applied + ) if p_should_use_swap_tensors: grad_applied.requires_grad_(param_grad.requires_grad) try: torch.utils.swap_tensors(param_grad, grad_applied) except Exception as e: - raise RuntimeError(f"_apply(): Couldn't swap {self._get_name()}.{key}.grad") from e + raise RuntimeError( + f"_apply(): Couldn't swap {self._get_name()}.{key}.grad" + ) from e out_param.grad = param_grad elif g_should_use_set_data: assert out_param.grad is not None out_param.grad.data = grad_applied else: assert param_grad.is_leaf - out_param.grad = grad_applied.requires_grad_(param_grad.requires_grad) + out_param.grad = grad_applied.requires_grad_( + param_grad.requires_grad + ) for key, buf in self._buffers.items(): if buf is not None: @@ -868,7 +929,7 @@ def compute_should_use_set_data(tensor, tensor_applied): return self - def apply(self: T, fn: Callable[['Module'], None]) -> T: + def apply(self: T, fn: Callable[["Module"], None]) -> T: r"""Apply ``fn`` recursively to every submodule (as returned by ``.children()``) as well as self. Typical use includes initializing the parameters of a model @@ -1035,7 +1096,9 @@ def bfloat16(self: T) -> T: """ return self._apply(lambda t: t.bfloat16() if t.is_floating_point() else t) - def to_empty(self: T, *, device: Optional[DeviceLikeType], recurse: bool = True) -> T: + def to_empty( + self: T, *, device: Optional[DeviceLikeType], recurse: bool = True + ) -> T: r"""Move the parameters and buffers to the specified device without copying storage. Args: @@ -1047,11 +1110,17 @@ def to_empty(self: T, *, device: Optional[DeviceLikeType], recurse: bool = True) Returns: Module: self """ - return self._apply(lambda t: torch.empty_like(t, device=device), recurse=recurse) + return self._apply( + lambda t: torch.empty_like(t, device=device), recurse=recurse + ) @overload - def to(self, device: Optional[DeviceLikeType] = ..., dtype: Optional[dtype] = ..., - non_blocking: bool = ...) -> Self: + def to( + self, + device: Optional[DeviceLikeType] = ..., + dtype: Optional[dtype] = ..., + non_blocking: bool = ..., + ) -> Self: ... @overload @@ -1148,18 +1217,23 @@ def to(self, *args, **kwargs): [0.6122+0.j, 0.1150+0.j]], dtype=torch.complex128) """ - device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs) + device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to( + *args, **kwargs + ) if dtype is not None: if not (dtype.is_floating_point or dtype.is_complex): - raise TypeError('nn.Module.to only accepts floating point or complex ' - f'dtypes, but got desired dtype={dtype}') + raise TypeError( + "nn.Module.to only accepts floating point or complex " + f"dtypes, but got desired dtype={dtype}" + ) if dtype.is_complex: warnings.warn( "Complex modules are a new feature under active development whose design may change, " "and some modules might not work as expected when using complex tensors as parameters or buffers. " "Please file an issue at https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml " - "if a complex module does not work as expected.") + "if a complex module does not work as expected." + ) def convert(t): try: @@ -1236,7 +1310,7 @@ def register_full_backward_pre_hook( return handle def register_backward_hook( - self, hook: Callable[['Module', _grad_t, _grad_t], Union[None, _grad_t]] + self, hook: Callable[["Module", _grad_t, _grad_t], Union[None, _grad_t]] ) -> RemovableHandle: r"""Register a backward hook on the module. @@ -1250,8 +1324,10 @@ def register_backward_hook( """ if self._is_full_backward_hook is True: - raise RuntimeError("Cannot use both regular backward hooks and full backward hooks on a " - "single Module. Please use only one of them.") + raise RuntimeError( + "Cannot use both regular backward hooks and full backward hooks on a " + "single Module. Please use only one of them." + ) self._is_full_backward_hook = False @@ -1308,8 +1384,10 @@ def register_full_backward_hook( """ if self._is_full_backward_hook is False: - raise RuntimeError("Cannot use both regular backward hooks and full backward hooks on a " - "single Module. Please use only one of them.") + raise RuntimeError( + "Cannot use both regular backward hooks and full backward hooks on a " + "single Module. Please use only one of them." + ) self._is_full_backward_hook = True @@ -1326,15 +1404,15 @@ def _get_backward_hooks(self): backward hooks. """ full_backward_hooks: List[Callable] = [] - if (_global_is_full_backward_hook is True): + if _global_is_full_backward_hook is True: full_backward_hooks += _global_backward_hooks.values() - if (self._is_full_backward_hook is True): + if self._is_full_backward_hook is True: full_backward_hooks += self._backward_hooks.values() non_full_backward_hooks: List[Callable] = [] - if (_global_is_full_backward_hook is False): + if _global_is_full_backward_hook is False: non_full_backward_hooks += _global_backward_hooks.values() - if (self._is_full_backward_hook is False): + if self._is_full_backward_hook is False: non_full_backward_hooks += self._backward_hooks.values() return full_backward_hooks, non_full_backward_hooks @@ -1348,7 +1426,10 @@ def _get_backward_pre_hooks(self): def _maybe_warn_non_full_backward_hook(self, inputs, result, grad_fn): if not isinstance(result, torch.Tensor): - if not (isinstance(result, tuple) and all(isinstance(r, torch.Tensor) for r in result)): + if not ( + isinstance(result, tuple) + and all(isinstance(r, torch.Tensor) for r in result) + ): warnings.warn( "Using non-full backward hooks on a Module that does not return a " "single Tensor or a tuple of Tensors is deprecated and will be removed " @@ -1362,7 +1443,10 @@ def _maybe_warn_non_full_backward_hook(self, inputs, result, grad_fn): result = (result,) if not isinstance(inputs, torch.Tensor): - if not (isinstance(inputs, tuple) and all(isinstance(i, torch.Tensor) for i in inputs)): + if not ( + isinstance(inputs, tuple) + and all(isinstance(i, torch.Tensor) for i in inputs) + ): warnings.warn( "Using non-full backward hooks on a Module that does not take as input a " "single Tensor or a tuple of Tensors is deprecated and will be removed " @@ -1377,7 +1461,9 @@ def _maybe_warn_non_full_backward_hook(self, inputs, result, grad_fn): # At this point we are sure that inputs and result are tuple of Tensors out_grad_fn = {r.grad_fn for r in result if r.grad_fn is not None} - if len(out_grad_fn) == 0 or (len(out_grad_fn) == 1 and grad_fn not in out_grad_fn): + if len(out_grad_fn) == 0 or ( + len(out_grad_fn) == 1 and grad_fn not in out_grad_fn + ): warnings.warn( "Using a non-full backward hook when outputs are nested in python data structure " "is deprecated and will be removed in future versions. This hook will be missing " @@ -1413,7 +1499,10 @@ def register_forward_pre_hook( self, hook: Union[ Callable[[T, Tuple[Any, ...]], Optional[Any]], - Callable[[T, Tuple[Any, ...], Dict[str, Any]], Optional[Tuple[Any, Dict[str, Any]]]], + Callable[ + [T, Tuple[Any, ...], Dict[str, Any]], + Optional[Tuple[Any, Dict[str, Any]]], + ], ], *, prepend: bool = False, @@ -1462,8 +1551,7 @@ def register_forward_pre_hook( ``handle.remove()`` """ handle = RemovableHandle( - self._forward_pre_hooks, - extra_dict=self._forward_pre_hooks_with_kwargs + self._forward_pre_hooks, extra_dict=self._forward_pre_hooks_with_kwargs ) self._forward_pre_hooks[handle.id] = hook if with_kwargs: @@ -1528,7 +1616,10 @@ def register_forward_hook( """ handle = RemovableHandle( self._forward_hooks, - extra_dict=[self._forward_hooks_with_kwargs, self._forward_hooks_always_called], + extra_dict=[ + self._forward_hooks_with_kwargs, + self._forward_hooks_always_called, + ], ) self._forward_hooks[handle.id] = hook if with_kwargs: @@ -1566,12 +1657,21 @@ def _wrapped_call_impl(self, *args, **kwargs): return self._call_impl(*args, **kwargs) def _call_impl(self, *args, **kwargs): - forward_call = (self._slow_forward if torch._C._get_tracing_state() else self.forward) + forward_call = ( + self._slow_forward if torch._C._get_tracing_state() else self.forward + ) # If we don't have any hooks, we want to skip the rest of the logic in # this function, and just call forward. - if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks - or _global_backward_pre_hooks or _global_backward_hooks - or _global_forward_hooks or _global_forward_pre_hooks): + if not ( + self._backward_hooks + or self._backward_pre_hooks + or self._forward_hooks + or self._forward_pre_hooks + or _global_backward_pre_hooks + or _global_backward_hooks + or _global_forward_hooks + or _global_forward_pre_hooks + ): return forward_call(*args, **kwargs) try: @@ -1584,7 +1684,10 @@ def _call_impl(self, *args, **kwargs): backward_pre_hooks = self._get_backward_pre_hooks() if self._backward_hooks or _global_backward_hooks: - full_backward_hooks, non_full_backward_hooks = self._get_backward_hooks() + ( + full_backward_hooks, + non_full_backward_hooks, + ) = self._get_backward_hooks() if _global_forward_pre_hooks or self._forward_pre_hooks: for hook_id, hook in ( @@ -1594,7 +1697,10 @@ def _call_impl(self, *args, **kwargs): if hook_id in self._forward_pre_hooks_with_kwargs: args_kwargs_result = hook(self, args, kwargs) # type: ignore[misc] if args_kwargs_result is not None: - if isinstance(args_kwargs_result, tuple) and len(args_kwargs_result) == 2: + if ( + isinstance(args_kwargs_result, tuple) + and len(args_kwargs_result) == 2 + ): args, kwargs = args_kwargs_result else: raise RuntimeError( @@ -1620,7 +1726,10 @@ def _call_impl(self, *args, **kwargs): *self._forward_hooks.items(), ): # mark that always called hook is run - if hook_id in self._forward_hooks_always_called or hook_id in _global_forward_hooks_always_called: + if ( + hook_id in self._forward_hooks_always_called + or hook_id in _global_forward_hooks_always_called + ): called_always_called_hooks.add(hook_id) if hook_id in self._forward_hooks_with_kwargs: @@ -1633,9 +1742,11 @@ def _call_impl(self, *args, **kwargs): if bw_hook: if not isinstance(result, (torch.Tensor, tuple)): - warnings.warn("For backward hooks to be called," - " module output should be a Tensor or a tuple of Tensors" - f" but received {type(result)}") + warnings.warn( + "For backward hooks to be called," + " module output should be a Tensor or a tuple of Tensors" + f" but received {type(result)}" + ) result = bw_hook.setup_output_hook(result) # Handle the non-full backward hooks @@ -1643,7 +1754,9 @@ def _call_impl(self, *args, **kwargs): var = result while not isinstance(var, torch.Tensor): if isinstance(var, dict): - var = next(v for v in var.values() if isinstance(v, torch.Tensor)) + var = next( + v for v in var.values() if isinstance(v, torch.Tensor) + ) else: var = var[0] grad_fn = var.grad_fn @@ -1665,8 +1778,10 @@ def _call_impl(self, *args, **kwargs): if hook_result is not None: result = hook_result except Exception as e: - warnings.warn("global module forward hook with ``always_call=True`` raised an exception " - f"that was silenced as another error was raised in forward: {str(e)}") + warnings.warn( + "global module forward hook with ``always_call=True`` raised an exception " + f"that was silenced as another error was raised in forward: {str(e)}" + ) continue for hook_id, hook in self._forward_hooks.items(): @@ -1679,13 +1794,15 @@ def _call_impl(self, *args, **kwargs): if hook_result is not None: result = hook_result except Exception as e: - warnings.warn("module forward hook with ``always_call=True`` raised an exception " - f"that was silenced as another error was raised in forward: {str(e)}") + warnings.warn( + "module forward hook with ``always_call=True`` raised an exception " + f"that was silenced as another error was raised in forward: {str(e)}" + ) continue # raise exception raised in try block raise - __call__ : Callable[..., Any] = _wrapped_call_impl + __call__: Callable[..., Any] = _wrapped_call_impl def __getstate__(self): state = self.__dict__.copy() @@ -1696,27 +1813,27 @@ def __setstate__(self, state): self.__dict__.update(state) # Support loading old checkpoints that don't have the following attrs: - if '_forward_pre_hooks' not in self.__dict__: + if "_forward_pre_hooks" not in self.__dict__: self._forward_pre_hooks = OrderedDict() - if '_forward_pre_hooks_with_kwargs' not in self.__dict__: + if "_forward_pre_hooks_with_kwargs" not in self.__dict__: self._forward_pre_hooks_with_kwargs = OrderedDict() - if '_forward_hooks_with_kwargs' not in self.__dict__: + if "_forward_hooks_with_kwargs" not in self.__dict__: self._forward_hooks_with_kwargs = OrderedDict() - if '_forward_hooks_always_called' not in self.__dict__: + if "_forward_hooks_always_called" not in self.__dict__: self._forward_hooks_always_called = OrderedDict() - if '_state_dict_hooks' not in self.__dict__: + if "_state_dict_hooks" not in self.__dict__: self._state_dict_hooks = OrderedDict() - if '_state_dict_pre_hooks' not in self.__dict__: + if "_state_dict_pre_hooks" not in self.__dict__: self._state_dict_pre_hooks = OrderedDict() - if '_load_state_dict_pre_hooks' not in self.__dict__: + if "_load_state_dict_pre_hooks" not in self.__dict__: self._load_state_dict_pre_hooks = OrderedDict() - if '_load_state_dict_post_hooks' not in self.__dict__: + if "_load_state_dict_post_hooks" not in self.__dict__: self._load_state_dict_post_hooks = OrderedDict() - if '_non_persistent_buffers_set' not in self.__dict__: + if "_non_persistent_buffers_set" not in self.__dict__: self._non_persistent_buffers_set = set() - if '_is_full_backward_hook' not in self.__dict__: + if "_is_full_backward_hook" not in self.__dict__: self._is_full_backward_hook = None - if '_backward_pre_hooks' not in self.__dict__: + if "_backward_pre_hooks" not in self.__dict__: self._backward_pre_hooks = OrderedDict() # On the return type: @@ -1727,21 +1844,23 @@ def __setstate__(self, state): # See full discussion on the problems with returning `Union` here # https://github.com/microsoft/pyright/issues/4213 def __getattr__(self, name: str) -> Any: - if '_parameters' in self.__dict__: - _parameters = self.__dict__['_parameters'] + if "_parameters" in self.__dict__: + _parameters = self.__dict__["_parameters"] if name in _parameters: return _parameters[name] - if '_buffers' in self.__dict__: - _buffers = self.__dict__['_buffers'] + if "_buffers" in self.__dict__: + _buffers = self.__dict__["_buffers"] if name in _buffers: return _buffers[name] - if '_modules' in self.__dict__: - modules = self.__dict__['_modules'] + if "_modules" in self.__dict__: + modules = self.__dict__["_modules"] if name in modules: return modules[name] - raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'") + raise AttributeError( + f"'{type(self).__name__}' object has no attribute '{name}'" + ) - def __setattr__(self, name: str, value: Union[Tensor, 'Module']) -> None: + def __setattr__(self, name: str, value: Union[Tensor, "Module"]) -> None: def remove_from(*dicts_or_sets): for d in dicts_or_sets: if name in d: @@ -1750,26 +1869,39 @@ def remove_from(*dicts_or_sets): else: d.discard(name) - params = self.__dict__.get('_parameters') + params = self.__dict__.get("_parameters") if isinstance(value, Parameter): if params is None: raise AttributeError( - "cannot assign parameters before Module.__init__() call") - remove_from(self.__dict__, self._buffers, self._modules, self._non_persistent_buffers_set) + "cannot assign parameters before Module.__init__() call" + ) + remove_from( + self.__dict__, + self._buffers, + self._modules, + self._non_persistent_buffers_set, + ) self.register_parameter(name, value) elif params is not None and name in params: if value is not None: - raise TypeError(f"cannot assign '{torch.typename(value)}' as parameter '{name}' " - "(torch.nn.Parameter or None expected)" - ) + raise TypeError( + f"cannot assign '{torch.typename(value)}' as parameter '{name}' " + "(torch.nn.Parameter or None expected)" + ) self.register_parameter(name, value) else: - modules = self.__dict__.get('_modules') + modules = self.__dict__.get("_modules") if isinstance(value, Module): if modules is None: raise AttributeError( - "cannot assign module before Module.__init__() call") - remove_from(self.__dict__, self._parameters, self._buffers, self._non_persistent_buffers_set) + "cannot assign module before Module.__init__() call" + ) + remove_from( + self.__dict__, + self._parameters, + self._buffers, + self._non_persistent_buffers_set, + ) for hook in _global_module_registration_hooks.values(): output = hook(self, name, value) if output is not None: @@ -1777,21 +1909,23 @@ def remove_from(*dicts_or_sets): modules[name] = value elif modules is not None and name in modules: if value is not None: - raise TypeError(f"cannot assign '{torch.typename(value)}' as child module '{name}' " - "(torch.nn.Module or None expected)" - ) + raise TypeError( + f"cannot assign '{torch.typename(value)}' as child module '{name}' " + "(torch.nn.Module or None expected)" + ) for hook in _global_module_registration_hooks.values(): output = hook(self, name, value) if output is not None: value = output modules[name] = value else: - buffers = self.__dict__.get('_buffers') + buffers = self.__dict__.get("_buffers") if buffers is not None and name in buffers: if value is not None and not isinstance(value, torch.Tensor): - raise TypeError(f"cannot assign '{torch.typename(value)}' as buffer '{name}' " - "(torch.Tensor or None expected)" - ) + raise TypeError( + f"cannot assign '{torch.typename(value)}' as buffer '{name}' " + "(torch.Tensor or None expected)" + ) for hook in _global_buffer_registration_hooks.values(): output = hook(self, name, value) if output is not None: @@ -1858,15 +1992,20 @@ def _save_to_state_dict(self, destination, prefix, keep_vars): if buf is not None and name not in self._non_persistent_buffers_set: destination[prefix + name] = buf if keep_vars else buf.detach() extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX - if getattr(self.__class__, "get_extra_state", Module.get_extra_state) is not Module.get_extra_state: + if ( + getattr(self.__class__, "get_extra_state", Module.get_extra_state) + is not Module.get_extra_state + ): destination[extra_state_key] = self.get_extra_state() # The user can pass an optional arbitrary mappable object to `state_dict`, in which case `state_dict` returns # back that same object. But if they pass nothing, an `OrderedDict` is created and returned. - T_destination = TypeVar('T_destination', bound=Dict[str, Any]) + T_destination = TypeVar("T_destination", bound=Dict[str, Any]) @overload - def state_dict(self, *, destination: T_destination, prefix: str = ..., keep_vars: bool = ...) -> T_destination: + def state_dict( + self, *, destination: T_destination, prefix: str = ..., keep_vars: bool = ... + ) -> T_destination: ... @overload @@ -1875,7 +2014,7 @@ def state_dict(self, *, prefix: str = ..., keep_vars: bool = ...) -> Dict[str, A # TODO: Change `*args` to `*` and remove the corresponding warning in docs when BC allows. # Also remove the logic for arg parsing together. - def state_dict(self, *args, destination=None, prefix='', keep_vars=False): + def state_dict(self, *args, destination=None, prefix="", keep_vars=False): r"""Return a dictionary containing references to the whole state of the module. Both parameters and persistent buffers (e.g. running averages) are @@ -1931,7 +2070,7 @@ def state_dict(self, *args, destination=None, prefix='', keep_vars=False): ) if destination is None: destination = args[0] - if len(args) > 1 and prefix == '': + if len(args) > 1 and prefix == "": prefix = args[1] if len(args) > 2 and keep_vars is False: keep_vars = args[2] @@ -1949,7 +2088,11 @@ def state_dict(self, *args, destination=None, prefix='', keep_vars=False): self._save_to_state_dict(destination, prefix, keep_vars) for name, module in self._modules.items(): if module is not None: - module.state_dict(destination=destination, prefix=prefix + name + '.', keep_vars=keep_vars) + module.state_dict( + destination=destination, + prefix=prefix + name + ".", + keep_vars=keep_vars, + ) for hook in self._state_dict_hooks.values(): hook_result = hook(self, destination, prefix, local_metadata) if hook_result is not None: @@ -1974,7 +2117,9 @@ def _register_load_state_dict_pre_hook(self, hook, with_module=False): instance to the hook as the first parameter. """ handle = RemovableHandle(self._load_state_dict_pre_hooks) - self._load_state_dict_pre_hooks[handle.id] = _WrappedHook(hook, self if with_module else None) + self._load_state_dict_pre_hooks[handle.id] = _WrappedHook( + hook, self if with_module else None + ) return handle def register_load_state_dict_post_hook(self, hook): @@ -2006,8 +2151,16 @@ def register_load_state_dict_post_hook(self, hook): self._load_state_dict_post_hooks[handle.id] = hook return handle - def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, - missing_keys, unexpected_keys, error_msgs): + def _load_from_state_dict( + self, + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ): r"""Copy parameters and buffers from :attr:`state_dict` into only this module, but not its descendants. This is called on every submodule @@ -2044,10 +2197,24 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, :meth:`~torch.nn.Module.load_state_dict` """ for hook in self._load_state_dict_pre_hooks.values(): - hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) + hook( + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ) - persistent_buffers = {k: v for k, v in self._buffers.items() if k not in self._non_persistent_buffers_set} - local_name_params = itertools.chain(self._parameters.items(), persistent_buffers.items()) + persistent_buffers = { + k: v + for k, v in self._buffers.items() + if k not in self._non_persistent_buffers_set + } + local_name_params = itertools.chain( + self._parameters.items(), persistent_buffers.items() + ) local_state = {k: v for k, v in local_name_params if v is not None} assign_to_params_buffers = local_metadata.get("assign_to_params_buffers", False) use_swap_tensors = torch.__future__.get_swap_module_params_on_conversion() @@ -2057,10 +2224,11 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, if key in state_dict: input_param = state_dict[key] if not torch.overrides.is_tensor_like(input_param): - error_msgs.append(f'While copying the parameter named "{key}", ' - 'expected torch.Tensor or Tensor-like object from checkpoint but ' - f'received {type(input_param)}' - ) + error_msgs.append( + f'While copying the parameter named "{key}", ' + "expected torch.Tensor or Tensor-like object from checkpoint but " + f"received {type(input_param)}" + ) continue # This is used to avoid copying uninitialized parameters into @@ -2068,40 +2236,63 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, # in such case, it will error when accessing the .shape attribute. is_param_lazy = torch.nn.parameter.is_lazy(param) # Backward compatibility: loading 1-dim tensor from 0.3.* to version 0.4+ - if not is_param_lazy and len(param.shape) == 0 and len(input_param.shape) == 1: + if ( + not is_param_lazy + and len(param.shape) == 0 + and len(input_param.shape) == 1 + ): input_param = input_param[0] if not is_param_lazy and input_param.shape != param.shape: # local shape should match the one in checkpoint - error_msgs.append(f'size mismatch for {key}: copying a param with shape {input_param.shape} from checkpoint, ' - f'the shape in current model is {param.shape}.') + error_msgs.append( + f"size mismatch for {key}: copying a param with shape {input_param.shape} from checkpoint, " + f"the shape in current model is {param.shape}." + ) continue - if param.is_meta and not input_param.is_meta and not assign_to_params_buffers: - warnings.warn(f'for {key}: copying from a non-meta parameter in the checkpoint to a meta ' - 'parameter in the current model, which is a no-op. (Did you mean to ' - 'pass `assign=True` to assign items in the state dictionary to their ' - 'corresponding key in the module instead of copying them in place?)') + if ( + param.is_meta + and not input_param.is_meta + and not assign_to_params_buffers + ): + warnings.warn( + f"for {key}: copying from a non-meta parameter in the checkpoint to a meta " + "parameter in the current model, which is a no-op. (Did you mean to " + "pass `assign=True` to assign items in the state dictionary to their " + "corresponding key in the module instead of copying them in place?)" + ) try: with torch.no_grad(): if use_swap_tensors: - new_input_param = param.module_load(input_param, assign=assign_to_params_buffers) - if id(new_input_param) == id(input_param) or id(new_input_param) == id(param): - raise RuntimeError("module_load returned one of self or other, please .detach() " - "the result if returning one of the inputs in module_load") - if (isinstance(param, torch.nn.Parameter)): + new_input_param = param.module_load( + input_param, assign=assign_to_params_buffers + ) + if id(new_input_param) == id(input_param) or id( + new_input_param + ) == id(param): + raise RuntimeError( + "module_load returned one of self or other, please .detach() " + "the result if returning one of the inputs in module_load" + ) + if isinstance(param, torch.nn.Parameter): if not isinstance(new_input_param, torch.nn.Parameter): - new_input_param = torch.nn.Parameter(new_input_param, requires_grad=param.requires_grad) + new_input_param = torch.nn.Parameter( + new_input_param, + requires_grad=param.requires_grad, + ) else: new_input_param.requires_grad_(param.requires_grad) torch.utils.swap_tensors(param, new_input_param) del new_input_param elif assign_to_params_buffers: # Shape checks are already done above - if (isinstance(param, torch.nn.Parameter)): + if isinstance(param, torch.nn.Parameter): if not isinstance(input_param, torch.nn.Parameter): - input_param = torch.nn.Parameter(input_param, requires_grad=param.requires_grad) + input_param = torch.nn.Parameter( + input_param, requires_grad=param.requires_grad + ) else: input_param.requires_grad_(param.requires_grad) setattr(self, name, input_param) @@ -2109,16 +2300,20 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, param.copy_(input_param) except Exception as ex: action = "swapping" if use_swap_tensors else "copying" - error_msgs.append(f'While {action} the parameter named "{key}", ' - f'whose dimensions in the model are {param.size()} and ' - f'whose dimensions in the checkpoint are {input_param.size()}, ' - f'an exception occurred : {ex.args}.' - ) + error_msgs.append( + f'While {action} the parameter named "{key}", ' + f"whose dimensions in the model are {param.size()} and " + f"whose dimensions in the checkpoint are {input_param.size()}, " + f"an exception occurred : {ex.args}." + ) elif strict: missing_keys.append(key) extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX - if getattr(self.__class__, "set_extra_state", Module.set_extra_state) is not Module.set_extra_state: + if ( + getattr(self.__class__, "set_extra_state", Module.set_extra_state) + is not Module.set_extra_state + ): if extra_state_key in state_dict: self.set_extra_state(state_dict[extra_state_key]) elif strict: @@ -2129,7 +2324,7 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, if strict: for key in state_dict.keys(): if key.startswith(prefix) and key != extra_state_key: - input_name = key[len(prefix):].split(".", 1) + input_name = key[len(prefix) :].split(".", 1) # Must be Module if it have attributes if len(input_name) > 1: if input_name[0] not in self._modules: @@ -2137,8 +2332,9 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, elif input_name[0] not in local_state: unexpected_keys.append(key) - def load_state_dict(self, state_dict: Mapping[str, Any], - strict: bool = True, assign: bool = False): + def load_state_dict( + self, state_dict: Mapping[str, Any], strict: bool = True, assign: bool = False + ): r"""Copy parameters and buffers from :attr:`state_dict` into this module and its descendants. If :attr:`strict` is ``True``, then @@ -2176,29 +2372,42 @@ def load_state_dict(self, state_dict: Mapping[str, Any], ``RuntimeError``. """ if not isinstance(state_dict, Mapping): - raise TypeError(f"Expected state_dict to be dict-like, got {type(state_dict)}.") + raise TypeError( + f"Expected state_dict to be dict-like, got {type(state_dict)}." + ) missing_keys: List[str] = [] unexpected_keys: List[str] = [] error_msgs: List[str] = [] # copy state_dict so _load_from_state_dict can modify it - metadata = getattr(state_dict, '_metadata', None) + metadata = getattr(state_dict, "_metadata", None) state_dict = OrderedDict(state_dict) if metadata is not None: # mypy isn't aware that "_metadata" exists in state_dict state_dict._metadata = metadata # type: ignore[attr-defined] - def load(module, local_state_dict, prefix=''): + def load(module, local_state_dict, prefix=""): local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {}) if assign: - local_metadata['assign_to_params_buffers'] = assign + local_metadata["assign_to_params_buffers"] = assign module._load_from_state_dict( - local_state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs) + local_state_dict, + prefix, + local_metadata, + True, + missing_keys, + unexpected_keys, + error_msgs, + ) for name, child in module._modules.items(): if child is not None: - child_prefix = prefix + name + '.' - child_state_dict = {k: v for k, v in local_state_dict.items() if k.startswith(child_prefix)} + child_prefix = prefix + name + "." + child_state_dict = { + k: v + for k, v in local_state_dict.items() + if k.startswith(child_prefix) + } load(child, child_state_dict, child_prefix) # noqa: F821 # Note that the hook can modify missing_keys and unexpected_keys. @@ -2217,22 +2426,37 @@ def load(module, local_state_dict, prefix=''): if strict: if len(unexpected_keys) > 0: error_msgs.insert( - 0, 'Unexpected key(s) in state_dict: {}. '.format( - ', '.join(f'"{k}"' for k in unexpected_keys))) + 0, + "Unexpected key(s) in state_dict: {}. ".format( + ", ".join(f'"{k}"' for k in unexpected_keys) + ), + ) if len(missing_keys) > 0: error_msgs.insert( - 0, 'Missing key(s) in state_dict: {}. '.format( - ', '.join(f'"{k}"' for k in missing_keys))) + 0, + "Missing key(s) in state_dict: {}. ".format( + ", ".join(f'"{k}"' for k in missing_keys) + ), + ) if len(error_msgs) > 0: - raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format( - self.__class__.__name__, "\n\t".join(error_msgs))) + raise RuntimeError( + "Error(s) in loading state_dict for {}:\n\t{}".format( + self.__class__.__name__, "\n\t".join(error_msgs) + ) + ) return _IncompatibleKeys(missing_keys, unexpected_keys) - def _named_members(self, get_members_fn, prefix='', recurse=True, remove_duplicate: bool = True): + def _named_members( + self, get_members_fn, prefix="", recurse=True, remove_duplicate: bool = True + ): r"""Help yield various names + members of modules.""" memo = set() - modules = self.named_modules(prefix=prefix, remove_duplicate=remove_duplicate) if recurse else [(prefix, self)] + modules = ( + self.named_modules(prefix=prefix, remove_duplicate=remove_duplicate) + if recurse + else [(prefix, self)] + ) for module_prefix, module in modules: members = get_members_fn(module) for k, v in members: @@ -2240,7 +2464,7 @@ def _named_members(self, get_members_fn, prefix='', recurse=True, remove_duplica continue if remove_duplicate: memo.add(v) - name = module_prefix + ('.' if module_prefix else '') + k + name = module_prefix + ("." if module_prefix else "") + k yield name, v def parameters(self, recurse: bool = True) -> Iterator[Parameter]: @@ -2269,10 +2493,7 @@ def parameters(self, recurse: bool = True) -> Iterator[Parameter]: yield param def named_parameters( - self, - prefix: str = '', - recurse: bool = True, - remove_duplicate: bool = True + self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True ) -> Iterator[Tuple[str, Parameter]]: r"""Return an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself. @@ -2297,7 +2518,10 @@ def named_parameters( """ gen = self._named_members( lambda module: module._parameters.items(), - prefix=prefix, recurse=recurse, remove_duplicate=remove_duplicate) + prefix=prefix, + recurse=recurse, + remove_duplicate=remove_duplicate, + ) yield from gen def buffers(self, recurse: bool = True) -> Iterator[Tensor]: @@ -2323,7 +2547,9 @@ def buffers(self, recurse: bool = True) -> Iterator[Tensor]: for _, buf in self.named_buffers(recurse=recurse): yield buf - def named_buffers(self, prefix: str = '', recurse: bool = True, remove_duplicate: bool = True) -> Iterator[Tuple[str, Tensor]]: + def named_buffers( + self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True + ) -> Iterator[Tuple[str, Tensor]]: r"""Return an iterator over module buffers, yielding both the name of the buffer as well as the buffer itself. Args: @@ -2346,10 +2572,13 @@ def named_buffers(self, prefix: str = '', recurse: bool = True, remove_duplicate """ gen = self._named_members( lambda module: module._buffers.items(), - prefix=prefix, recurse=recurse, remove_duplicate=remove_duplicate) + prefix=prefix, + recurse=recurse, + remove_duplicate=remove_duplicate, + ) yield from gen - def children(self) -> Iterator['Module']: + def children(self) -> Iterator["Module"]: r"""Return an iterator over immediate children modules. Yields: @@ -2358,7 +2587,7 @@ def children(self) -> Iterator['Module']: for name, module in self.named_children(): yield module - def named_children(self) -> Iterator[Tuple[str, 'Module']]: + def named_children(self) -> Iterator[Tuple[str, "Module"]]: r"""Return an iterator over immediate children modules, yielding both the name of the module as well as the module itself. Yields: @@ -2378,7 +2607,7 @@ def named_children(self) -> Iterator[Tuple[str, 'Module']]: memo.add(module) yield name, module - def modules(self) -> Iterator['Module']: + def modules(self) -> Iterator["Module"]: r"""Return an iterator over all modules in the network. Yields: @@ -2405,7 +2634,12 @@ def modules(self) -> Iterator['Module']: for _, module in self.named_modules(): yield module - def named_modules(self, memo: Optional[Set['Module']] = None, prefix: str = '', remove_duplicate: bool = True): + def named_modules( + self, + memo: Optional[Set["Module"]] = None, + prefix: str = "", + remove_duplicate: bool = True, + ): r"""Return an iterator over all modules in the network, yielding both the name of the module as well as the module itself. Args: @@ -2444,8 +2678,10 @@ def named_modules(self, memo: Optional[Set['Module']] = None, prefix: str = '', for name, module in self._modules.items(): if module is None: continue - submodule_prefix = prefix + ('.' if prefix else '') + name - yield from module.named_modules(memo, submodule_prefix, remove_duplicate) + submodule_prefix = prefix + ("." if prefix else "") + name + yield from module.named_modules( + memo, submodule_prefix, remove_duplicate + ) def train(self: T, mode: bool = True) -> T: r"""Set the module in training mode. @@ -2519,12 +2755,13 @@ def zero_grad(self, set_to_none: bool = True) -> None: set_to_none (bool): instead of setting to zero, set the grads to None. See :meth:`torch.optim.Optimizer.zero_grad` for details. """ - if getattr(self, '_is_replica', False): + if getattr(self, "_is_replica", False): warnings.warn( "Calling .zero_grad() from a module created with nn.DataParallel() has no effect. " "The parameters are copied (in a differentiable manner) from the original module. " "This means they are not leaf nodes in autograd and so don't accumulate gradients. " - "If you need gradients in your forward method, consider using autograd.grad instead.") + "If you need gradients in your forward method, consider using autograd.grad instead." + ) for p in self.parameters(): if p.grad is not None: @@ -2551,7 +2788,7 @@ def extra_repr(self) -> str: this method in your own modules. Both single-line and multi-line strings are acceptable. """ - return '' + return "" def __repr__(self): # We treat the extra repr like the sub-module, one item per line @@ -2559,23 +2796,23 @@ def __repr__(self): extra_repr = self.extra_repr() # empty string will be split into list [''] if extra_repr: - extra_lines = extra_repr.split('\n') + extra_lines = extra_repr.split("\n") child_lines = [] for key, module in self._modules.items(): mod_str = repr(module) mod_str = _addindent(mod_str, 2) - child_lines.append('(' + key + '): ' + mod_str) + child_lines.append("(" + key + "): " + mod_str) lines = extra_lines + child_lines - main_str = self._get_name() + '(' + main_str = self._get_name() + "(" if lines: # simple one-liner info, which most builtin Modules will use if len(extra_lines) == 1 and not child_lines: main_str += extra_lines[0] else: - main_str += '\n ' + '\n '.join(lines) + '\n' + main_str += "\n " + "\n ".join(lines) + "\n" - main_str += ')' + main_str += ")" return main_str def __dir__(self): diff --git a/torch/nn/modules/normalization.py b/torch/nn/modules/normalization.py index d503409d53a133..ea37166c11bd65 100644 --- a/torch/nn/modules/normalization.py +++ b/torch/nn/modules/normalization.py @@ -1,16 +1,18 @@ # mypy: allow-untyped-defs -import torch import numbers +from typing import List, Optional, Tuple, Union + +import torch +from torch import Size, Tensor +from torch.nn import functional as F, init from torch.nn.parameter import Parameter -from .module import Module + from ._functions import CrossMapLRN2d as _cross_map_lrn2d -from .. import functional as F -from .. import init +from .module import Module + -from torch import Tensor, Size -from typing import Union, List, Optional, Tuple +__all__ = ["LocalResponseNorm", "CrossMapLRN2d", "LayerNorm", "GroupNorm", "RMSNorm"] -__all__ = ['LocalResponseNorm', 'CrossMapLRN2d', 'LayerNorm', 'GroupNorm', 'RMSNorm'] class LocalResponseNorm(Module): r"""Applies local response normalization over an input signal. @@ -42,13 +44,15 @@ class LocalResponseNorm(Module): """ - __constants__ = ['size', 'alpha', 'beta', 'k'] + __constants__ = ["size", "alpha", "beta", "k"] size: int alpha: float beta: float k: float - def __init__(self, size: int, alpha: float = 1e-4, beta: float = 0.75, k: float = 1.) -> None: + def __init__( + self, size: int, alpha: float = 1e-4, beta: float = 0.75, k: float = 1.0 + ) -> None: super().__init__() self.size = size self.alpha = alpha @@ -56,11 +60,10 @@ def __init__(self, size: int, alpha: float = 1e-4, beta: float = 0.75, k: float self.k = k def forward(self, input: Tensor) -> Tensor: - return F.local_response_norm(input, self.size, self.alpha, self.beta, - self.k) + return F.local_response_norm(input, self.size, self.alpha, self.beta, self.k) def extra_repr(self): - return '{size}, alpha={alpha}, beta={beta}, k={k}'.format(**self.__dict__) + return "{size}, alpha={alpha}, beta={beta}, k={k}".format(**self.__dict__) class CrossMapLRN2d(Module): @@ -69,7 +72,9 @@ class CrossMapLRN2d(Module): beta: float k: float - def __init__(self, size: int, alpha: float = 1e-4, beta: float = 0.75, k: float = 1) -> None: + def __init__( + self, size: int, alpha: float = 1e-4, beta: float = 0.75, k: float = 1 + ) -> None: super().__init__() self.size = size self.alpha = alpha @@ -77,11 +82,10 @@ def __init__(self, size: int, alpha: float = 1e-4, beta: float = 0.75, k: float self.k = k def forward(self, input: Tensor) -> Tensor: - return _cross_map_lrn2d.apply(input, self.size, self.alpha, self.beta, - self.k) + return _cross_map_lrn2d.apply(input, self.size, self.alpha, self.beta, self.k) def extra_repr(self) -> str: - return '{size}, alpha={alpha}, beta={beta}, k={k}'.format(**self.__dict__) + return "{size}, alpha={alpha}, beta={beta}, k={k}".format(**self.__dict__) _shape_t = Union[int, List[int], Size] @@ -165,14 +169,21 @@ class LayerNorm(Module): """ - __constants__ = ['normalized_shape', 'eps', 'elementwise_affine'] + __constants__ = ["normalized_shape", "eps", "elementwise_affine"] normalized_shape: Tuple[int, ...] eps: float elementwise_affine: bool - def __init__(self, normalized_shape: _shape_t, eps: float = 1e-5, elementwise_affine: bool = True, - bias: bool = True, device=None, dtype=None) -> None: - factory_kwargs = {'device': device, 'dtype': dtype} + def __init__( + self, + normalized_shape: _shape_t, + eps: float = 1e-5, + elementwise_affine: bool = True, + bias: bool = True, + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} super().__init__() if isinstance(normalized_shape, numbers.Integral): # mypy error: incompatible types in assignment @@ -181,14 +192,18 @@ def __init__(self, normalized_shape: _shape_t, eps: float = 1e-5, elementwise_af self.eps = eps self.elementwise_affine = elementwise_affine if self.elementwise_affine: - self.weight = Parameter(torch.empty(self.normalized_shape, **factory_kwargs)) + self.weight = Parameter( + torch.empty(self.normalized_shape, **factory_kwargs) + ) if bias: - self.bias = Parameter(torch.empty(self.normalized_shape, **factory_kwargs)) + self.bias = Parameter( + torch.empty(self.normalized_shape, **factory_kwargs) + ) else: - self.register_parameter('bias', None) + self.register_parameter("bias", None) else: - self.register_parameter('weight', None) - self.register_parameter('bias', None) + self.register_parameter("weight", None) + self.register_parameter("bias", None) self.reset_parameters() @@ -200,11 +215,14 @@ def reset_parameters(self) -> None: def forward(self, input: Tensor) -> Tensor: return F.layer_norm( - input, self.normalized_shape, self.weight, self.bias, self.eps) + input, self.normalized_shape, self.weight, self.bias, self.eps + ) def extra_repr(self) -> str: - return '{normalized_shape}, eps={eps}, ' \ - 'elementwise_affine={elementwise_affine}'.format(**self.__dict__) + return ( + "{normalized_shape}, eps={eps}, " + "elementwise_affine={elementwise_affine}".format(**self.__dict__) + ) class GroupNorm(Module): @@ -253,18 +271,25 @@ class GroupNorm(Module): >>> output = m(input) """ - __constants__ = ['num_groups', 'num_channels', 'eps', 'affine'] + __constants__ = ["num_groups", "num_channels", "eps", "affine"] num_groups: int num_channels: int eps: float affine: bool - def __init__(self, num_groups: int, num_channels: int, eps: float = 1e-5, affine: bool = True, - device=None, dtype=None) -> None: - factory_kwargs = {'device': device, 'dtype': dtype} + def __init__( + self, + num_groups: int, + num_channels: int, + eps: float = 1e-5, + affine: bool = True, + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} super().__init__() if num_channels % num_groups != 0: - raise ValueError('num_channels must be divisible by num_groups') + raise ValueError("num_channels must be divisible by num_groups") self.num_groups = num_groups self.num_channels = num_channels @@ -274,8 +299,8 @@ def __init__(self, num_groups: int, num_channels: int, eps: float = 1e-5, affine self.weight = Parameter(torch.empty(num_channels, **factory_kwargs)) self.bias = Parameter(torch.empty(num_channels, **factory_kwargs)) else: - self.register_parameter('weight', None) - self.register_parameter('bias', None) + self.register_parameter("weight", None) + self.register_parameter("bias", None) self.reset_parameters() @@ -285,12 +310,12 @@ def reset_parameters(self) -> None: init.zeros_(self.bias) def forward(self, input: Tensor) -> Tensor: - return F.group_norm( - input, self.num_groups, self.weight, self.bias, self.eps) + return F.group_norm(input, self.num_groups, self.weight, self.bias, self.eps) def extra_repr(self) -> str: - return '{num_groups}, {num_channels}, eps={eps}, ' \ - 'affine={affine}'.format(**self.__dict__) + return "{num_groups}, {num_channels}, eps={eps}, " "affine={affine}".format( + **self.__dict__ + ) class RMSNorm(Module): @@ -333,14 +358,20 @@ class RMSNorm(Module): >>> rms_norm(input) """ - __constants__ = ['normalized_shape', 'eps', 'elementwise_affine'] + __constants__ = ["normalized_shape", "eps", "elementwise_affine"] normalized_shape: Tuple[int, ...] eps: Optional[float] elementwise_affine: bool - def __init__(self, normalized_shape: _shape_t, eps: Optional[float] = None, elementwise_affine: bool = True, - device=None, dtype=None) -> None: - factory_kwargs = {'device': device, 'dtype': dtype} + def __init__( + self, + normalized_shape: _shape_t, + eps: Optional[float] = None, + elementwise_affine: bool = True, + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} super().__init__() if isinstance(normalized_shape, numbers.Integral): # mypy error: incompatible types in assignment @@ -349,9 +380,11 @@ def __init__(self, normalized_shape: _shape_t, eps: Optional[float] = None, elem self.eps = eps self.elementwise_affine = elementwise_affine if self.elementwise_affine: - self.weight = Parameter(torch.empty(self.normalized_shape, **factory_kwargs)) + self.weight = Parameter( + torch.empty(self.normalized_shape, **factory_kwargs) + ) else: - self.register_parameter('weight', None) + self.register_parameter("weight", None) self.reset_parameters() def reset_parameters(self) -> None: @@ -371,8 +404,10 @@ def extra_repr(self) -> str: """ Extra information about the module. """ - return '{normalized_shape}, eps={eps}, ' \ - 'elementwise_affine={elementwise_affine}'.format(**self.__dict__) + return ( + "{normalized_shape}, eps={eps}, " + "elementwise_affine={elementwise_affine}".format(**self.__dict__) + ) # TODO: ContrastiveNorm2d diff --git a/torch/nn/modules/padding.py b/torch/nn/modules/padding.py index 4b29fbf1c8f497..cfce075e5dfe8d 100644 --- a/torch/nn/modules/padding.py +++ b/torch/nn/modules/padding.py @@ -1,22 +1,37 @@ # mypy: allow-untyped-defs -from .module import Module -from .utils import _pair, _quadruple, _ntuple -from .. import functional as F +from typing import Sequence, Tuple +import torch.nn.functional as F from torch import Tensor -from ..common_types import _size_2_t, _size_4_t, _size_6_t -from typing import Sequence, Tuple +from torch.nn.common_types import _size_2_t, _size_4_t, _size_6_t + +from .module import Module +from .utils import _ntuple, _pair, _quadruple # TODO: grad_output size asserts in THNN -__all__ = ['CircularPad1d', 'CircularPad2d', 'CircularPad3d', 'ConstantPad1d', 'ConstantPad2d', - 'ConstantPad3d', 'ReflectionPad1d', 'ReflectionPad2d', 'ReflectionPad3d', - 'ReplicationPad1d', 'ReplicationPad2d', 'ReplicationPad3d', 'ZeroPad1d', 'ZeroPad2d', 'ZeroPad3d'] +__all__ = [ + "CircularPad1d", + "CircularPad2d", + "CircularPad3d", + "ConstantPad1d", + "ConstantPad2d", + "ConstantPad3d", + "ReflectionPad1d", + "ReflectionPad2d", + "ReflectionPad3d", + "ReplicationPad1d", + "ReplicationPad2d", + "ReplicationPad3d", + "ZeroPad1d", + "ZeroPad2d", + "ZeroPad3d", +] class _CircularPadNd(Module): - __constants__ = ['padding'] + __constants__ = ["padding"] padding: Sequence[int] def _check_input_dim(self, input): @@ -24,10 +39,10 @@ def _check_input_dim(self, input): def forward(self, input: Tensor) -> Tensor: self._check_input_dim(input) - return F.pad(input, self.padding, 'circular') + return F.pad(input, self.padding, "circular") def extra_repr(self) -> str: - return f'{self.padding}' + return f"{self.padding}" class CircularPad1d(_CircularPadNd): @@ -76,9 +91,7 @@ def __init__(self, padding: _size_2_t) -> None: def _check_input_dim(self, input): if input.dim() != 2 and input.dim() != 3: - raise ValueError( - f"expected 2D or 3D input (got {input.dim()}D input)" - ) + raise ValueError(f"expected 2D or 3D input (got {input.dim()}D input)") class CircularPad2d(_CircularPadNd): @@ -137,9 +150,7 @@ def __init__(self, padding: _size_4_t) -> None: def _check_input_dim(self, input): if input.dim() != 3 and input.dim() != 4: - raise ValueError( - f"expected 3D or 4D input (got {input.dim()}D input)" - ) + raise ValueError(f"expected 3D or 4D input (got {input.dim()}D input)") class CircularPad3d(_CircularPadNd): @@ -188,13 +199,11 @@ def __init__(self, padding: _size_6_t) -> None: def _check_input_dim(self, input): if input.dim() != 4 and input.dim() != 5: - raise ValueError( - f"expected 4D or 5D input (got {input.dim()}D input)" - ) + raise ValueError(f"expected 4D or 5D input (got {input.dim()}D input)") class _ConstantPadNd(Module): - __constants__ = ['padding', 'value'] + __constants__ = ["padding", "value"] value: float padding: Sequence[int] @@ -203,10 +212,10 @@ def __init__(self, value: float) -> None: self.value = value def forward(self, input: Tensor) -> Tensor: - return F.pad(input, self.padding, 'constant', self.value) + return F.pad(input, self.padding, "constant", self.value) def extra_repr(self) -> str: - return f'padding={self.padding}, value={self.value}' + return f"padding={self.padding}, value={self.value}" class ConstantPad1d(_ConstantPadNd): @@ -303,7 +312,7 @@ class ConstantPad2d(_ConstantPadNd): [ 3.5000, 3.5000, 3.5000, 3.5000, 3.5000]]]) """ - __constants__ = ['padding', 'value'] + __constants__ = ["padding", "value"] padding: Tuple[int, int, int, int] def __init__(self, padding: _size_4_t, value: float) -> None: @@ -352,14 +361,14 @@ def __init__(self, padding: _size_6_t, value: float) -> None: class _ReflectionPadNd(Module): - __constants__ = ['padding'] + __constants__ = ["padding"] padding: Sequence[int] def forward(self, input: Tensor) -> Tensor: - return F.pad(input, self.padding, 'reflect') + return F.pad(input, self.padding, "reflect") def extra_repr(self) -> str: - return f'{self.padding}' + return f"{self.padding}" class ReflectionPad1d(_ReflectionPadNd): @@ -511,14 +520,14 @@ def __init__(self, padding: _size_6_t) -> None: class _ReplicationPadNd(Module): - __constants__ = ['padding'] + __constants__ = ["padding"] padding: Sequence[int] def forward(self, input: Tensor) -> Tensor: - return F.pad(input, self.padding, 'replicate') + return F.pad(input, self.padding, "replicate") def extra_repr(self) -> str: - return f'{self.padding}' + return f"{self.padding}" class ReplicationPad1d(_ReplicationPadNd): @@ -702,10 +711,11 @@ class ZeroPad1d(ConstantPad1d): padding: Tuple[int, int] def __init__(self, padding: _size_2_t) -> None: - super().__init__(padding, 0.) + super().__init__(padding, 0.0) def extra_repr(self) -> str: - return f'{self.padding}' + return f"{self.padding}" + class ZeroPad2d(ConstantPad2d): r"""Pads the input tensor boundaries with zero. @@ -755,10 +765,11 @@ class ZeroPad2d(ConstantPad2d): padding: Tuple[int, int, int, int] def __init__(self, padding: _size_4_t) -> None: - super().__init__(padding, 0.) + super().__init__(padding, 0.0) def extra_repr(self) -> str: - return f'{self.padding}' + return f"{self.padding}" + class ZeroPad3d(ConstantPad3d): r"""Pads the input tensor boundaries with zero. @@ -796,7 +807,7 @@ class ZeroPad3d(ConstantPad3d): padding: Tuple[int, int, int, int, int, int] def __init__(self, padding: _size_6_t) -> None: - super().__init__(padding, 0.) + super().__init__(padding, 0.0) def extra_repr(self) -> str: - return f'{self.padding}' + return f"{self.padding}" diff --git a/torch/nn/modules/pixelshuffle.py b/torch/nn/modules/pixelshuffle.py index e6136350b3a409..26ca996d86518a 100644 --- a/torch/nn/modules/pixelshuffle.py +++ b/torch/nn/modules/pixelshuffle.py @@ -1,9 +1,11 @@ +import torch.nn.functional as F +from torch import Tensor + from .module import Module -from .. import functional as F -from torch import Tensor -__all__ = ['PixelShuffle', 'PixelUnshuffle'] +__all__ = ["PixelShuffle", "PixelUnshuffle"] + class PixelShuffle(Module): r"""Rearrange elements in a tensor according to an upscaling factor. @@ -46,7 +48,7 @@ class PixelShuffle(Module): https://arxiv.org/abs/1609.05158 """ - __constants__ = ['upscale_factor'] + __constants__ = ["upscale_factor"] upscale_factor: int def __init__(self, upscale_factor: int) -> None: @@ -57,7 +59,7 @@ def forward(self, input: Tensor) -> Tensor: return F.pixel_shuffle(input, self.upscale_factor) def extra_repr(self) -> str: - return f'upscale_factor={self.upscale_factor}' + return f"upscale_factor={self.upscale_factor}" class PixelUnshuffle(Module): @@ -99,7 +101,7 @@ class PixelUnshuffle(Module): https://arxiv.org/abs/1609.05158 """ - __constants__ = ['downscale_factor'] + __constants__ = ["downscale_factor"] downscale_factor: int def __init__(self, downscale_factor: int) -> None: @@ -110,4 +112,4 @@ def forward(self, input: Tensor) -> Tensor: return F.pixel_unshuffle(input, self.downscale_factor) def extra_repr(self) -> str: - return f'downscale_factor={self.downscale_factor}' + return f"downscale_factor={self.downscale_factor}" diff --git a/torch/nn/modules/pooling.py b/torch/nn/modules/pooling.py index 61ce56390981ab..1913afa145d42d 100644 --- a/torch/nn/modules/pooling.py +++ b/torch/nn/modules/pooling.py @@ -1,27 +1,68 @@ from typing import List, Optional +import torch.nn.functional as F from torch import Tensor -from .module import Module -from .utils import _single, _pair, _triple -from .. import functional as F +from torch.nn.common_types import ( + _ratio_2_t, + _ratio_3_t, + _size_1_t, + _size_2_opt_t, + _size_2_t, + _size_3_opt_t, + _size_3_t, + _size_any_opt_t, + _size_any_t, +) -from ..common_types import (_size_any_t, _size_1_t, _size_2_t, _size_3_t, - _ratio_3_t, _ratio_2_t, _size_any_opt_t, _size_2_opt_t, _size_3_opt_t) +from .module import Module +from .utils import _pair, _single, _triple + + +__all__ = [ + "MaxPool1d", + "MaxPool2d", + "MaxPool3d", + "MaxUnpool1d", + "MaxUnpool2d", + "MaxUnpool3d", + "AvgPool1d", + "AvgPool2d", + "AvgPool3d", + "FractionalMaxPool2d", + "FractionalMaxPool3d", + "LPPool1d", + "LPPool2d", + "LPPool3d", + "AdaptiveMaxPool1d", + "AdaptiveMaxPool2d", + "AdaptiveMaxPool3d", + "AdaptiveAvgPool1d", + "AdaptiveAvgPool2d", + "AdaptiveAvgPool3d", +] -__all__ = ['MaxPool1d', 'MaxPool2d', 'MaxPool3d', 'MaxUnpool1d', 'MaxUnpool2d', 'MaxUnpool3d', - 'AvgPool1d', 'AvgPool2d', 'AvgPool3d', 'FractionalMaxPool2d', 'FractionalMaxPool3d', 'LPPool1d', - 'LPPool2d', 'LPPool3d', 'AdaptiveMaxPool1d', 'AdaptiveMaxPool2d', 'AdaptiveMaxPool3d', - 'AdaptiveAvgPool1d', 'AdaptiveAvgPool2d', 'AdaptiveAvgPool3d'] class _MaxPoolNd(Module): - __constants__ = ['kernel_size', 'stride', 'padding', 'dilation', - 'return_indices', 'ceil_mode'] + __constants__ = [ + "kernel_size", + "stride", + "padding", + "dilation", + "return_indices", + "ceil_mode", + ] return_indices: bool ceil_mode: bool - def __init__(self, kernel_size: _size_any_t, stride: Optional[_size_any_t] = None, - padding: _size_any_t = 0, dilation: _size_any_t = 1, - return_indices: bool = False, ceil_mode: bool = False) -> None: + def __init__( + self, + kernel_size: _size_any_t, + stride: Optional[_size_any_t] = None, + padding: _size_any_t = 0, + dilation: _size_any_t = 1, + return_indices: bool = False, + ceil_mode: bool = False, + ) -> None: super().__init__() self.kernel_size = kernel_size self.stride = stride if (stride is not None) else kernel_size @@ -31,8 +72,10 @@ def __init__(self, kernel_size: _size_any_t, stride: Optional[_size_any_t] = Non self.ceil_mode = ceil_mode def extra_repr(self) -> str: - return 'kernel_size={kernel_size}, stride={stride}, padding={padding}' \ - ', dilation={dilation}, ceil_mode={ceil_mode}'.format(**self.__dict__) + return ( + "kernel_size={kernel_size}, stride={stride}, padding={padding}" + ", dilation={dilation}, ceil_mode={ceil_mode}".format(**self.__dict__) + ) class MaxPool1d(_MaxPoolNd): @@ -88,9 +131,15 @@ class MaxPool1d(_MaxPoolNd): dilation: _size_1_t def forward(self, input: Tensor): - return F.max_pool1d(input, self.kernel_size, self.stride, - self.padding, self.dilation, ceil_mode=self.ceil_mode, - return_indices=self.return_indices) + return F.max_pool1d( + input, + self.kernel_size, + self.stride, + self.padding, + self.dilation, + ceil_mode=self.ceil_mode, + return_indices=self.return_indices, + ) class MaxPool2d(_MaxPoolNd): @@ -161,9 +210,15 @@ class MaxPool2d(_MaxPoolNd): dilation: _size_2_t def forward(self, input: Tensor): - return F.max_pool2d(input, self.kernel_size, self.stride, - self.padding, self.dilation, ceil_mode=self.ceil_mode, - return_indices=self.return_indices) + return F.max_pool2d( + input, + self.kernel_size, + self.stride, + self.padding, + self.dilation, + ceil_mode=self.ceil_mode, + return_indices=self.return_indices, + ) class MaxPool3d(_MaxPoolNd): @@ -238,15 +293,20 @@ class MaxPool3d(_MaxPoolNd): dilation: _size_3_t def forward(self, input: Tensor): - return F.max_pool3d(input, self.kernel_size, self.stride, - self.padding, self.dilation, ceil_mode=self.ceil_mode, - return_indices=self.return_indices) + return F.max_pool3d( + input, + self.kernel_size, + self.stride, + self.padding, + self.dilation, + ceil_mode=self.ceil_mode, + return_indices=self.return_indices, + ) class _MaxUnpoolNd(Module): - def extra_repr(self) -> str: - return f'kernel_size={self.kernel_size}, stride={self.stride}, padding={self.padding}' + return f"kernel_size={self.kernel_size}, stride={self.stride}, padding={self.padding}" class MaxUnpool1d(_MaxUnpoolNd): @@ -312,15 +372,23 @@ class MaxUnpool1d(_MaxUnpoolNd): stride: _size_1_t padding: _size_1_t - def __init__(self, kernel_size: _size_1_t, stride: Optional[_size_1_t] = None, padding: _size_1_t = 0) -> None: + def __init__( + self, + kernel_size: _size_1_t, + stride: Optional[_size_1_t] = None, + padding: _size_1_t = 0, + ) -> None: super().__init__() self.kernel_size = _single(kernel_size) self.stride = _single(stride if (stride is not None) else kernel_size) self.padding = _single(padding) - def forward(self, input: Tensor, indices: Tensor, output_size: Optional[List[int]] = None) -> Tensor: - return F.max_unpool1d(input, indices, self.kernel_size, self.stride, - self.padding, output_size) + def forward( + self, input: Tensor, indices: Tensor, output_size: Optional[List[int]] = None + ) -> Tensor: + return F.max_unpool1d( + input, indices, self.kernel_size, self.stride, self.padding, output_size + ) class MaxUnpool2d(_MaxUnpoolNd): @@ -399,15 +467,23 @@ class MaxUnpool2d(_MaxUnpoolNd): stride: _size_2_t padding: _size_2_t - def __init__(self, kernel_size: _size_2_t, stride: Optional[_size_2_t] = None, padding: _size_2_t = 0) -> None: + def __init__( + self, + kernel_size: _size_2_t, + stride: Optional[_size_2_t] = None, + padding: _size_2_t = 0, + ) -> None: super().__init__() self.kernel_size = _pair(kernel_size) self.stride = _pair(stride if (stride is not None) else kernel_size) self.padding = _pair(padding) - def forward(self, input: Tensor, indices: Tensor, output_size: Optional[List[int]] = None) -> Tensor: - return F.max_unpool2d(input, indices, self.kernel_size, self.stride, - self.padding, output_size) + def forward( + self, input: Tensor, indices: Tensor, output_size: Optional[List[int]] = None + ) -> Tensor: + return F.max_unpool2d( + input, indices, self.kernel_size, self.stride, self.padding, output_size + ) class MaxUnpool3d(_MaxUnpoolNd): @@ -469,22 +545,36 @@ class MaxUnpool3d(_MaxUnpoolNd): stride: _size_3_t padding: _size_3_t - def __init__(self, kernel_size: _size_3_t, stride: Optional[_size_3_t] = None, padding: _size_3_t = 0) -> None: + def __init__( + self, + kernel_size: _size_3_t, + stride: Optional[_size_3_t] = None, + padding: _size_3_t = 0, + ) -> None: super().__init__() self.kernel_size = _triple(kernel_size) self.stride = _triple(stride if (stride is not None) else kernel_size) self.padding = _triple(padding) - def forward(self, input: Tensor, indices: Tensor, output_size: Optional[List[int]] = None) -> Tensor: - return F.max_unpool3d(input, indices, self.kernel_size, self.stride, - self.padding, output_size) + def forward( + self, input: Tensor, indices: Tensor, output_size: Optional[List[int]] = None + ) -> Tensor: + return F.max_unpool3d( + input, indices, self.kernel_size, self.stride, self.padding, output_size + ) class _AvgPoolNd(Module): - __constants__ = ['kernel_size', 'stride', 'padding', 'ceil_mode', 'count_include_pad'] + __constants__ = [ + "kernel_size", + "stride", + "padding", + "ceil_mode", + "count_include_pad", + ] def extra_repr(self) -> str: - return f'kernel_size={self.kernel_size}, stride={self.stride}, padding={self.padding}' + return f"kernel_size={self.kernel_size}, stride={self.stride}, padding={self.padding}" class AvgPool1d(_AvgPoolNd): @@ -542,8 +632,14 @@ class AvgPool1d(_AvgPoolNd): ceil_mode: bool count_include_pad: bool - def __init__(self, kernel_size: _size_1_t, stride: _size_1_t = None, padding: _size_1_t = 0, ceil_mode: bool = False, - count_include_pad: bool = True) -> None: + def __init__( + self, + kernel_size: _size_1_t, + stride: _size_1_t = None, + padding: _size_1_t = 0, + ceil_mode: bool = False, + count_include_pad: bool = True, + ) -> None: super().__init__() self.kernel_size = _single(kernel_size) self.stride = _single(stride if stride is not None else kernel_size) @@ -553,8 +649,13 @@ def __init__(self, kernel_size: _size_1_t, stride: _size_1_t = None, padding: _s def forward(self, input: Tensor) -> Tensor: return F.avg_pool1d( - input, self.kernel_size, self.stride, self.padding, self.ceil_mode, - self.count_include_pad) + input, + self.kernel_size, + self.stride, + self.padding, + self.ceil_mode, + self.count_include_pad, + ) class AvgPool2d(_AvgPoolNd): @@ -619,7 +720,14 @@ class AvgPool2d(_AvgPoolNd): >>> output = m(input) """ - __constants__ = ['kernel_size', 'stride', 'padding', 'ceil_mode', 'count_include_pad', 'divisor_override'] + __constants__ = [ + "kernel_size", + "stride", + "padding", + "ceil_mode", + "count_include_pad", + "divisor_override", + ] kernel_size: _size_2_t stride: _size_2_t @@ -627,8 +735,15 @@ class AvgPool2d(_AvgPoolNd): ceil_mode: bool count_include_pad: bool - def __init__(self, kernel_size: _size_2_t, stride: Optional[_size_2_t] = None, padding: _size_2_t = 0, - ceil_mode: bool = False, count_include_pad: bool = True, divisor_override: Optional[int] = None) -> None: + def __init__( + self, + kernel_size: _size_2_t, + stride: Optional[_size_2_t] = None, + padding: _size_2_t = 0, + ceil_mode: bool = False, + count_include_pad: bool = True, + divisor_override: Optional[int] = None, + ) -> None: super().__init__() self.kernel_size = kernel_size self.stride = stride if (stride is not None) else kernel_size @@ -638,8 +753,15 @@ def __init__(self, kernel_size: _size_2_t, stride: Optional[_size_2_t] = None, p self.divisor_override = divisor_override def forward(self, input: Tensor) -> Tensor: - return F.avg_pool2d(input, self.kernel_size, self.stride, - self.padding, self.ceil_mode, self.count_include_pad, self.divisor_override) + return F.avg_pool2d( + input, + self.kernel_size, + self.stride, + self.padding, + self.ceil_mode, + self.count_include_pad, + self.divisor_override, + ) class AvgPool3d(_AvgPoolNd): @@ -711,7 +833,14 @@ class AvgPool3d(_AvgPoolNd): >>> output = m(input) """ - __constants__ = ['kernel_size', 'stride', 'padding', 'ceil_mode', 'count_include_pad', 'divisor_override'] + __constants__ = [ + "kernel_size", + "stride", + "padding", + "ceil_mode", + "count_include_pad", + "divisor_override", + ] kernel_size: _size_3_t stride: _size_3_t @@ -719,8 +848,15 @@ class AvgPool3d(_AvgPoolNd): ceil_mode: bool count_include_pad: bool - def __init__(self, kernel_size: _size_3_t, stride: Optional[_size_3_t] = None, padding: _size_3_t = 0, - ceil_mode: bool = False, count_include_pad: bool = True, divisor_override: Optional[int] = None) -> None: + def __init__( + self, + kernel_size: _size_3_t, + stride: Optional[_size_3_t] = None, + padding: _size_3_t = 0, + ceil_mode: bool = False, + count_include_pad: bool = True, + divisor_override: Optional[int] = None, + ) -> None: super().__init__() self.kernel_size = kernel_size self.stride = stride if (stride is not None) else kernel_size @@ -730,14 +866,21 @@ def __init__(self, kernel_size: _size_3_t, stride: Optional[_size_3_t] = None, p self.divisor_override = divisor_override def forward(self, input: Tensor) -> Tensor: - return F.avg_pool3d(input, self.kernel_size, self.stride, - self.padding, self.ceil_mode, self.count_include_pad, self.divisor_override) + return F.avg_pool3d( + input, + self.kernel_size, + self.stride, + self.padding, + self.ceil_mode, + self.count_include_pad, + self.divisor_override, + ) def __setstate__(self, d): super().__setstate__(d) - self.__dict__.setdefault('padding', 0) - self.__dict__.setdefault('ceil_mode', False) - self.__dict__.setdefault('count_include_pad', True) + self.__dict__.setdefault("padding", 0) + self.__dict__.setdefault("ceil_mode", False) + self.__dict__.setdefault("count_include_pad", True) class FractionalMaxPool2d(Module): @@ -782,37 +925,51 @@ class FractionalMaxPool2d(Module): https://arxiv.org/abs/1412.6071 """ - __constants__ = ['kernel_size', 'return_indices', 'output_size', - 'output_ratio'] + __constants__ = ["kernel_size", "return_indices", "output_size", "output_ratio"] kernel_size: _size_2_t return_indices: bool output_size: _size_2_t output_ratio: _ratio_2_t - def __init__(self, kernel_size: _size_2_t, output_size: Optional[_size_2_t] = None, - output_ratio: Optional[_ratio_2_t] = None, - return_indices: bool = False, _random_samples=None) -> None: + def __init__( + self, + kernel_size: _size_2_t, + output_size: Optional[_size_2_t] = None, + output_ratio: Optional[_ratio_2_t] = None, + return_indices: bool = False, + _random_samples=None, + ) -> None: super().__init__() self.kernel_size = _pair(kernel_size) self.return_indices = return_indices - self.register_buffer('_random_samples', _random_samples) + self.register_buffer("_random_samples", _random_samples) self.output_size = _pair(output_size) if output_size is not None else None self.output_ratio = _pair(output_ratio) if output_ratio is not None else None if output_size is None and output_ratio is None: - raise ValueError("FractionalMaxPool2d requires specifying either " - "an output size, or a pooling ratio") + raise ValueError( + "FractionalMaxPool2d requires specifying either " + "an output size, or a pooling ratio" + ) if output_size is not None and output_ratio is not None: - raise ValueError("only one of output_size and output_ratio may be specified") + raise ValueError( + "only one of output_size and output_ratio may be specified" + ) if self.output_ratio is not None: if not (0 < self.output_ratio[0] < 1 and 0 < self.output_ratio[1] < 1): - raise ValueError(f"output_ratio must be between 0 and 1 (got {output_ratio})") + raise ValueError( + f"output_ratio must be between 0 and 1 (got {output_ratio})" + ) def forward(self, input: Tensor): return F.fractional_max_pool2d( - input, self.kernel_size, self.output_size, self.output_ratio, + input, + self.kernel_size, + self.output_size, + self.output_ratio, self.return_indices, - _random_samples=self._random_samples) + _random_samples=self._random_samples, + ) class FractionalMaxPool3d(Module): @@ -854,46 +1011,69 @@ class FractionalMaxPool3d(Module): https://arxiv.org/abs/1412.6071 """ - __constants__ = ['kernel_size', 'return_indices', 'output_size', - 'output_ratio'] + __constants__ = ["kernel_size", "return_indices", "output_size", "output_ratio"] kernel_size: _size_3_t return_indices: bool output_size: _size_3_t output_ratio: _ratio_3_t - def __init__(self, kernel_size: _size_3_t, output_size: Optional[_size_3_t] = None, - output_ratio: Optional[_ratio_3_t] = None, - return_indices: bool = False, _random_samples=None) -> None: + def __init__( + self, + kernel_size: _size_3_t, + output_size: Optional[_size_3_t] = None, + output_ratio: Optional[_ratio_3_t] = None, + return_indices: bool = False, + _random_samples=None, + ) -> None: super().__init__() self.kernel_size = _triple(kernel_size) self.return_indices = return_indices - self.register_buffer('_random_samples', _random_samples) + self.register_buffer("_random_samples", _random_samples) self.output_size = _triple(output_size) if output_size is not None else None self.output_ratio = _triple(output_ratio) if output_ratio is not None else None if output_size is None and output_ratio is None: - raise ValueError("FractionalMaxPool3d requires specifying either " - "an output size, or a pooling ratio") + raise ValueError( + "FractionalMaxPool3d requires specifying either " + "an output size, or a pooling ratio" + ) if output_size is not None and output_ratio is not None: - raise ValueError("only one of output_size and output_ratio may be specified") + raise ValueError( + "only one of output_size and output_ratio may be specified" + ) if self.output_ratio is not None: - if not (0 < self.output_ratio[0] < 1 and 0 < self.output_ratio[1] < 1 and 0 < self.output_ratio[2] < 1): - raise ValueError(f"output_ratio must be between 0 and 1 (got {output_ratio})") + if not ( + 0 < self.output_ratio[0] < 1 + and 0 < self.output_ratio[1] < 1 + and 0 < self.output_ratio[2] < 1 + ): + raise ValueError( + f"output_ratio must be between 0 and 1 (got {output_ratio})" + ) def forward(self, input: Tensor): return F.fractional_max_pool3d( - input, self.kernel_size, self.output_size, self.output_ratio, + input, + self.kernel_size, + self.output_size, + self.output_ratio, self.return_indices, - _random_samples=self._random_samples) + _random_samples=self._random_samples, + ) class _LPPoolNd(Module): - __constants__ = ['norm_type', 'kernel_size', 'stride', 'ceil_mode'] + __constants__ = ["norm_type", "kernel_size", "stride", "ceil_mode"] norm_type: float ceil_mode: bool - def __init__(self, norm_type: float, kernel_size: _size_any_t, stride: Optional[_size_any_t] = None, - ceil_mode: bool = False) -> None: + def __init__( + self, + norm_type: float, + kernel_size: _size_any_t, + stride: Optional[_size_any_t] = None, + ceil_mode: bool = False, + ) -> None: super().__init__() self.norm_type = norm_type self.kernel_size = kernel_size @@ -901,8 +1081,10 @@ def __init__(self, norm_type: float, kernel_size: _size_any_t, stride: Optional[ self.ceil_mode = ceil_mode def extra_repr(self) -> str: - return 'norm_type={norm_type}, kernel_size={kernel_size}, stride={stride}, ' \ - 'ceil_mode={ceil_mode}'.format(**self.__dict__) + return ( + "norm_type={norm_type}, kernel_size={kernel_size}, stride={stride}, " + "ceil_mode={ceil_mode}".format(**self.__dict__) + ) class LPPool1d(_LPPoolNd): @@ -942,8 +1124,9 @@ class LPPool1d(_LPPoolNd): stride: _size_1_t def forward(self, input: Tensor) -> Tensor: - return F.lp_pool1d(input, float(self.norm_type), self.kernel_size, - self.stride, self.ceil_mode) + return F.lp_pool1d( + input, float(self.norm_type), self.kernel_size, self.stride, self.ceil_mode + ) class LPPool2d(_LPPoolNd): @@ -996,8 +1179,9 @@ class LPPool2d(_LPPoolNd): stride: _size_2_t def forward(self, input: Tensor) -> Tensor: - return F.lp_pool2d(input, float(self.norm_type), self.kernel_size, - self.stride, self.ceil_mode) + return F.lp_pool2d( + input, float(self.norm_type), self.kernel_size, self.stride, self.ceil_mode + ) class LPPool3d(_LPPoolNd): @@ -1054,21 +1238,25 @@ class LPPool3d(_LPPoolNd): stride: _size_3_t def forward(self, input: Tensor) -> Tensor: - return F.lp_pool3d(input, float(self.norm_type), self.kernel_size, - self.stride, self.ceil_mode) + return F.lp_pool3d( + input, float(self.norm_type), self.kernel_size, self.stride, self.ceil_mode + ) class _AdaptiveMaxPoolNd(Module): - __constants__ = ['output_size', 'return_indices'] + __constants__ = ["output_size", "return_indices"] return_indices: bool - def __init__(self, output_size: _size_any_opt_t, return_indices: bool = False) -> None: + def __init__( + self, output_size: _size_any_opt_t, return_indices: bool = False + ) -> None: super().__init__() self.output_size = output_size self.return_indices = return_indices def extra_repr(self) -> str: - return f'output_size={self.output_size}' + return f"output_size={self.output_size}" + # FIXME (by @ssnl): Improve adaptive pooling docs: specify what the input and # output shapes are, and how the operation computes output. @@ -1190,14 +1378,14 @@ def forward(self, input: Tensor): class _AdaptiveAvgPoolNd(Module): - __constants__ = ['output_size'] + __constants__ = ["output_size"] def __init__(self, output_size: _size_any_opt_t) -> None: super().__init__() self.output_size = output_size def extra_repr(self) -> str: - return f'output_size={self.output_size}' + return f"output_size={self.output_size}" class AdaptiveAvgPool1d(_AdaptiveAvgPoolNd): diff --git a/torch/nn/modules/rnn.py b/torch/nn/modules/rnn.py index 8ba4f9f0831997..2870eee29412af 100644 --- a/torch/nn/modules/rnn.py +++ b/torch/nn/modules/rnn.py @@ -1,24 +1,34 @@ # mypy: allow-untyped-defs import math -import warnings import numbers +import warnings import weakref -from typing import List, Tuple, Optional, overload +from typing import List, Optional, overload, Tuple from typing_extensions import deprecated import torch -from torch import Tensor +from torch import _VF, Tensor +from torch.nn import init +from torch.nn.parameter import Parameter +from torch.nn.utils.rnn import PackedSequence + from .module import Module -from ..parameter import Parameter -from ..utils.rnn import PackedSequence -from .. import init -from ... import _VF -__all__ = ['RNNBase', 'RNN', 'LSTM', 'GRU', 'RNNCellBase', 'RNNCell', 'LSTMCell', 'GRUCell'] + +__all__ = [ + "RNNBase", + "RNN", + "LSTM", + "GRU", + "RNNCellBase", + "RNNCell", + "LSTMCell", + "GRUCell", +] _rnn_impls = { - 'RNN_TANH': _VF.rnn_tanh, - 'RNN_RELU': _VF.rnn_relu, + "RNN_TANH": _VF.rnn_tanh, + "RNN_RELU": _VF.rnn_relu, } @@ -47,9 +57,18 @@ class RNNBase(Module): LSTM and GRU classes override some methods implemented by RNNBase. """ - __constants__ = ['mode', 'input_size', 'hidden_size', 'num_layers', 'bias', - 'batch_first', 'dropout', 'bidirectional', 'proj_size'] - __jit_unused_properties__ = ['all_weights'] + __constants__ = [ + "mode", + "input_size", + "hidden_size", + "num_layers", + "bias", + "batch_first", + "dropout", + "bidirectional", + "proj_size", + ] + __jit_unused_properties__ = ["all_weights"] mode: str input_size: int @@ -61,11 +80,21 @@ class RNNBase(Module): bidirectional: bool proj_size: int - def __init__(self, mode: str, input_size: int, hidden_size: int, - num_layers: int = 1, bias: bool = True, batch_first: bool = False, - dropout: float = 0., bidirectional: bool = False, proj_size: int = 0, - device=None, dtype=None) -> None: - factory_kwargs = {'device': device, 'dtype': dtype} + def __init__( + self, + mode: str, + input_size: int, + hidden_size: int, + num_layers: int = 1, + bias: bool = True, + batch_first: bool = False, + dropout: float = 0.0, + bidirectional: bool = False, + proj_size: int = 0, + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} super().__init__() self.mode = mode self.input_size = input_size @@ -79,35 +108,46 @@ def __init__(self, mode: str, input_size: int, hidden_size: int, self._flat_weight_refs: List[Optional[weakref.ReferenceType[Parameter]]] = [] num_directions = 2 if bidirectional else 1 - if not isinstance(dropout, numbers.Number) or not 0 <= dropout <= 1 or \ - isinstance(dropout, bool): - raise ValueError("dropout should be a number in range [0, 1] " - "representing the probability of an element being " - "zeroed") + if ( + not isinstance(dropout, numbers.Number) + or not 0 <= dropout <= 1 + or isinstance(dropout, bool) + ): + raise ValueError( + "dropout should be a number in range [0, 1] " + "representing the probability of an element being " + "zeroed" + ) if dropout > 0 and num_layers == 1: - warnings.warn("dropout option adds dropout after all but last " - "recurrent layer, so non-zero dropout expects " - f"num_layers greater than 1, but got dropout={dropout} and " - f"num_layers={num_layers}") + warnings.warn( + "dropout option adds dropout after all but last " + "recurrent layer, so non-zero dropout expects " + f"num_layers greater than 1, but got dropout={dropout} and " + f"num_layers={num_layers}" + ) if not isinstance(hidden_size, int): - raise TypeError(f"hidden_size should be of type int, got: {type(hidden_size).__name__}") + raise TypeError( + f"hidden_size should be of type int, got: {type(hidden_size).__name__}" + ) if hidden_size <= 0: raise ValueError("hidden_size must be greater than zero") if num_layers <= 0: raise ValueError("num_layers must be greater than zero") if proj_size < 0: - raise ValueError("proj_size should be a positive integer or zero to disable projections") + raise ValueError( + "proj_size should be a positive integer or zero to disable projections" + ) if proj_size >= hidden_size: raise ValueError("proj_size has to be smaller than hidden_size") - if mode == 'LSTM': + if mode == "LSTM": gate_size = 4 * hidden_size - elif mode == 'GRU': + elif mode == "GRU": gate_size = 3 * hidden_size - elif mode == 'RNN_TANH': + elif mode == "RNN_TANH": gate_size = hidden_size - elif mode == 'RNN_RELU': + elif mode == "RNN_RELU": gate_size = hidden_size else: raise ValueError("Unrecognized RNN mode: " + mode) @@ -117,10 +157,16 @@ def __init__(self, mode: str, input_size: int, hidden_size: int, for layer in range(num_layers): for direction in range(num_directions): real_hidden_size = proj_size if proj_size > 0 else hidden_size - layer_input_size = input_size if layer == 0 else real_hidden_size * num_directions - - w_ih = Parameter(torch.empty((gate_size, layer_input_size), **factory_kwargs)) - w_hh = Parameter(torch.empty((gate_size, real_hidden_size), **factory_kwargs)) + layer_input_size = ( + input_size if layer == 0 else real_hidden_size * num_directions + ) + + w_ih = Parameter( + torch.empty((gate_size, layer_input_size), **factory_kwargs) + ) + w_hh = Parameter( + torch.empty((gate_size, real_hidden_size), **factory_kwargs) + ) b_ih = Parameter(torch.empty(gate_size, **factory_kwargs)) # Second bias vector included for CuDNN compatibility. Only one # bias vector is needed in standard definition. @@ -132,18 +178,20 @@ def __init__(self, mode: str, input_size: int, hidden_size: int, else: layer_params = (w_ih, w_hh) else: - w_hr = Parameter(torch.empty((proj_size, hidden_size), **factory_kwargs)) + w_hr = Parameter( + torch.empty((proj_size, hidden_size), **factory_kwargs) + ) if bias: layer_params = (w_ih, w_hh, b_ih, b_hh, w_hr) else: layer_params = (w_ih, w_hh, w_hr) - suffix = '_reverse' if direction == 1 else '' - param_names = ['weight_ih_l{}{}', 'weight_hh_l{}{}'] + suffix = "_reverse" if direction == 1 else "" + param_names = ["weight_ih_l{}{}", "weight_hh_l{}{}"] if bias: - param_names += ['bias_ih_l{}{}', 'bias_hh_l{}{}'] + param_names += ["bias_ih_l{}{}", "bias_hh_l{}{}"] if self.proj_size > 0: - param_names += ['weight_hr_l{}{}'] + param_names += ["weight_hr_l{}{}"] param_names = [x.format(layer, suffix) for x in param_names] for name, param in zip(param_names, layer_params): @@ -156,10 +204,13 @@ def __init__(self, mode: str, input_size: int, hidden_size: int, self.reset_parameters() def _init_flat_weights(self): - self._flat_weights = [getattr(self, wn) if hasattr(self, wn) else None - for wn in self._flat_weights_names] - self._flat_weight_refs = [weakref.ref(w) if w is not None else None - for w in self._flat_weights] + self._flat_weights = [ + getattr(self, wn) if hasattr(self, wn) else None + for wn in self._flat_weights_names + ] + self._flat_weight_refs = [ + weakref.ref(w) if w is not None else None for w in self._flat_weights + ] self.flatten_parameters() def __setattr__(self, attr, value): @@ -189,8 +240,10 @@ def flatten_parameters(self) -> None: dtype = first_fw.dtype for fw in self._flat_weights: if ( - not isinstance(fw, Tensor) or not (fw.dtype == dtype) or - not fw.is_cuda or not torch.backends.cudnn.is_acceptable(fw) + not isinstance(fw, Tensor) + or not (fw.dtype == dtype) + or not fw.is_cuda + or not torch.backends.cudnn.is_acceptable(fw) ): return @@ -213,10 +266,16 @@ def flatten_parameters(self) -> None: if self.proj_size > 0: num_weights += 1 torch._cudnn_rnn_flatten_weight( - self._flat_weights, num_weights, - self.input_size, rnn.get_cudnn_mode(self.mode), - self.hidden_size, self.proj_size, self.num_layers, - self.batch_first, bool(self.bidirectional)) + self._flat_weights, + num_weights, + self.input_size, + rnn.get_cudnn_mode(self.mode), + self.hidden_size, + self.proj_size, + self.num_layers, + self.batch_first, + bool(self.bidirectional), + ) def _apply(self, fn, recurse=True): self._flat_weight_refs = [] @@ -236,32 +295,51 @@ def reset_parameters(self) -> None: def check_input(self, input: Tensor, batch_sizes: Optional[Tensor]) -> None: if not torch.jit.is_scripting(): - if input.dtype != self._flat_weights[0].dtype and not torch._C._is_any_autocast_enabled(): - raise ValueError(f'input must have the type {self._flat_weights[0].dtype}, got type {input.dtype}') + if ( + input.dtype != self._flat_weights[0].dtype + and not torch._C._is_any_autocast_enabled() + ): + raise ValueError( + f"input must have the type {self._flat_weights[0].dtype}, got type {input.dtype}" + ) expected_input_dim = 2 if batch_sizes is not None else 3 if input.dim() != expected_input_dim: raise RuntimeError( - f'input must have {expected_input_dim} dimensions, got {input.dim()}') + f"input must have {expected_input_dim} dimensions, got {input.dim()}" + ) if self.input_size != input.size(-1): raise RuntimeError( - f'input.size(-1) must be equal to input_size. Expected {self.input_size}, got {input.size(-1)}') + f"input.size(-1) must be equal to input_size. Expected {self.input_size}, got {input.size(-1)}" + ) - def get_expected_hidden_size(self, input: Tensor, batch_sizes: Optional[Tensor]) -> Tuple[int, int, int]: + def get_expected_hidden_size( + self, input: Tensor, batch_sizes: Optional[Tensor] + ) -> Tuple[int, int, int]: if batch_sizes is not None: mini_batch = int(batch_sizes[0]) else: mini_batch = input.size(0) if self.batch_first else input.size(1) num_directions = 2 if self.bidirectional else 1 if self.proj_size > 0: - expected_hidden_size = (self.num_layers * num_directions, - mini_batch, self.proj_size) + expected_hidden_size = ( + self.num_layers * num_directions, + mini_batch, + self.proj_size, + ) else: - expected_hidden_size = (self.num_layers * num_directions, - mini_batch, self.hidden_size) + expected_hidden_size = ( + self.num_layers * num_directions, + mini_batch, + self.hidden_size, + ) return expected_hidden_size - def check_hidden_size(self, hx: Tensor, expected_hidden_size: Tuple[int, int, int], - msg: str = 'Expected hidden size {}, got {}') -> None: + def check_hidden_size( + self, + hx: Tensor, + expected_hidden_size: Tuple[int, int, int], + msg: str = "Expected hidden size {}, got {}", + ) -> None: if hx.size() != expected_hidden_size: raise RuntimeError(msg.format(expected_hidden_size, list(hx.size()))) @@ -276,7 +354,9 @@ def _weights_have_changed(self): break return weights_changed - def check_forward_args(self, input: Tensor, hidden: Tensor, batch_sizes: Optional[Tensor]): + def check_forward_args( + self, input: Tensor, hidden: Tensor, batch_sizes: Optional[Tensor] + ): self.check_input(input, batch_sizes) expected_hidden_size = self.get_expected_hidden_size(input, batch_sizes) @@ -287,21 +367,20 @@ def permute_hidden(self, hx: Tensor, permutation: Optional[Tensor]): return hx return _apply_permutation(hx, permutation) - def extra_repr(self) -> str: - s = '{input_size}, {hidden_size}' + s = "{input_size}, {hidden_size}" if self.proj_size != 0: - s += ', proj_size={proj_size}' + s += ", proj_size={proj_size}" if self.num_layers != 1: - s += ', num_layers={num_layers}' + s += ", num_layers={num_layers}" if self.bias is not True: - s += ', bias={bias}' + s += ", bias={bias}" if self.batch_first is not False: - s += ', batch_first={batch_first}' + s += ", batch_first={batch_first}" if self.dropout != 0: - s += ', dropout={dropout}' + s += ", dropout={dropout}" if self.bidirectional is not False: - s += ', bidirectional={bidirectional}' + s += ", bidirectional={bidirectional}" return s.format(**self.__dict__) def _update_flat_weights(self): @@ -314,17 +393,17 @@ def __getstate__(self): self._update_flat_weights() # Don't serialize the weight references. state = self.__dict__.copy() - del state['_flat_weight_refs'] + del state["_flat_weight_refs"] return state def __setstate__(self, d): super().__setstate__(d) - if 'all_weights' in d: - self._all_weights = d['all_weights'] + if "all_weights" in d: + self._all_weights = d["all_weights"] # In PyTorch 1.8 we added a proj_size member variable to LSTM. # LSTMs that were serialized via torch.save(module) before PyTorch 1.8 # don't have it, so to preserve compatibility we set proj_size here. - if 'proj_size' not in d: + if "proj_size" not in d: self.proj_size = 0 if not isinstance(self._all_weights[0][0], str): @@ -334,9 +413,14 @@ def __setstate__(self, d): self._all_weights = [] for layer in range(num_layers): for direction in range(num_directions): - suffix = '_reverse' if direction == 1 else '' - weights = ['weight_ih_l{}{}', 'weight_hh_l{}{}', 'bias_ih_l{}{}', - 'bias_hh_l{}{}', 'weight_hr_l{}{}'] + suffix = "_reverse" if direction == 1 else "" + weights = [ + "weight_ih_l{}{}", + "weight_hh_l{}{}", + "bias_ih_l{}{}", + "bias_hh_l{}{}", + "weight_hr_l{}{}", + ] weights = [x.format(layer, suffix) for x in weights] if self.bias: if self.proj_size > 0: @@ -348,19 +432,27 @@ def __setstate__(self, d): else: if self.proj_size > 0: self._all_weights += [weights[:2]] + [weights[-1:]] - self._flat_weights_names.extend(weights[:2] + [weights[-1:]]) + self._flat_weights_names.extend( + weights[:2] + [weights[-1:]] + ) else: self._all_weights += [weights[:2]] self._flat_weights_names.extend(weights[:2]) - self._flat_weights = [getattr(self, wn) if hasattr(self, wn) else None - for wn in self._flat_weights_names] + self._flat_weights = [ + getattr(self, wn) if hasattr(self, wn) else None + for wn in self._flat_weights_names + ] - self._flat_weight_refs = [weakref.ref(w) if w is not None else None - for w in self._flat_weights] + self._flat_weight_refs = [ + weakref.ref(w) if w is not None else None for w in self._flat_weights + ] @property def all_weights(self) -> List[List[Parameter]]: - return [[getattr(self, weight) for weight in weights] for weights in self._all_weights] + return [ + [getattr(self, weight) for weight in weights] + for weights in self._all_weights + ] def _replicate_for_data_parallel(self): replica = super()._replicate_for_data_parallel() @@ -501,10 +593,19 @@ def forward(x, h_0=None): """ @overload - def __init__(self, input_size: int, hidden_size: int, num_layers: int = 1, - nonlinearity: str = 'tanh', bias: bool = True, batch_first: bool = False, - dropout: float = 0., bidirectional: bool = False, device=None, - dtype=None) -> None: + def __init__( + self, + input_size: int, + hidden_size: int, + num_layers: int = 1, + nonlinearity: str = "tanh", + bias: bool = True, + batch_first: bool = False, + dropout: float = 0.0, + bidirectional: bool = False, + device=None, + dtype=None, + ) -> None: ... @overload @@ -512,29 +613,37 @@ def __init__(self, *args, **kwargs): ... def __init__(self, *args, **kwargs): - if 'proj_size' in kwargs: - raise ValueError("proj_size argument is only supported for LSTM, not RNN or GRU") + if "proj_size" in kwargs: + raise ValueError( + "proj_size argument is only supported for LSTM, not RNN or GRU" + ) if len(args) > 3: self.nonlinearity = args[3] args = args[:3] + args[4:] else: - self.nonlinearity = kwargs.pop('nonlinearity', 'tanh') - if self.nonlinearity == 'tanh': - mode = 'RNN_TANH' - elif self.nonlinearity == 'relu': - mode = 'RNN_RELU' + self.nonlinearity = kwargs.pop("nonlinearity", "tanh") + if self.nonlinearity == "tanh": + mode = "RNN_TANH" + elif self.nonlinearity == "relu": + mode = "RNN_RELU" else: - raise ValueError(f"Unknown nonlinearity '{self.nonlinearity}'. Select from 'tanh' or 'relu'.") + raise ValueError( + f"Unknown nonlinearity '{self.nonlinearity}'. Select from 'tanh' or 'relu'." + ) super().__init__(mode, *args, **kwargs) @overload @torch._jit_internal._overload_method # noqa: F811 - def forward(self, input: Tensor, hx: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]: + def forward( + self, input: Tensor, hx: Optional[Tensor] = None + ) -> Tuple[Tensor, Tensor]: pass @overload @torch._jit_internal._overload_method # noqa: F811 - def forward(self, input: PackedSequence, hx: Optional[Tensor] = None) -> Tuple[PackedSequence, Tensor]: + def forward( + self, input: PackedSequence, hx: Optional[Tensor] = None + ) -> Tuple[PackedSequence, Tensor]: pass def forward(self, input, hx=None): # noqa: F811 @@ -548,9 +657,13 @@ def forward(self, input, hx=None): # noqa: F811 max_batch_size = batch_sizes[0] # script() is unhappy when max_batch_size is different type in cond branches, so we duplicate if hx is None: - hx = torch.zeros(self.num_layers * num_directions, - max_batch_size, self.hidden_size, - dtype=input.dtype, device=input.device) + hx = torch.zeros( + self.num_layers * num_directions, + max_batch_size, + self.hidden_size, + dtype=input.dtype, + device=input.device, + ) else: # Each batch of the hidden state should match the input sequence that # the user believes he/she is passing in. @@ -558,7 +671,9 @@ def forward(self, input, hx=None): # noqa: F811 else: batch_sizes = None if input.dim() not in (2, 3): - raise ValueError(f"RNN: Expected input to be 2D or 3D, got {input.dim()}D tensor instead") + raise ValueError( + f"RNN: Expected input to be 2D or 3D, got {input.dim()}D tensor instead" + ) is_batched = input.dim() == 3 batch_dim = 0 if self.batch_first else 1 if not is_batched: @@ -566,19 +681,25 @@ def forward(self, input, hx=None): # noqa: F811 if hx is not None: if hx.dim() != 2: raise RuntimeError( - f"For unbatched 2-D input, hx should also be 2-D but got {hx.dim()}-D tensor") + f"For unbatched 2-D input, hx should also be 2-D but got {hx.dim()}-D tensor" + ) hx = hx.unsqueeze(1) else: if hx is not None and hx.dim() != 3: raise RuntimeError( - f"For batched 3-D input, hx should also be 3-D but got {hx.dim()}-D tensor") + f"For batched 3-D input, hx should also be 3-D but got {hx.dim()}-D tensor" + ) max_batch_size = input.size(0) if self.batch_first else input.size(1) sorted_indices = None unsorted_indices = None if hx is None: - hx = torch.zeros(self.num_layers * num_directions, - max_batch_size, self.hidden_size, - dtype=input.dtype, device=input.device) + hx = torch.zeros( + self.num_layers * num_directions, + max_batch_size, + self.hidden_size, + dtype=input.dtype, + device=input.device, + ) else: # Each batch of the hidden state should match the input sequence that # the user believes he/she is passing in. @@ -586,31 +707,65 @@ def forward(self, input, hx=None): # noqa: F811 assert hx is not None self.check_forward_args(input, hx, batch_sizes) - assert self.mode == 'RNN_TANH' or self.mode == 'RNN_RELU' + assert self.mode == "RNN_TANH" or self.mode == "RNN_RELU" if batch_sizes is None: - if self.mode == 'RNN_TANH': - result = _VF.rnn_tanh(input, hx, self._flat_weights, self.bias, self.num_layers, - self.dropout, self.training, self.bidirectional, - self.batch_first) + if self.mode == "RNN_TANH": + result = _VF.rnn_tanh( + input, + hx, + self._flat_weights, + self.bias, + self.num_layers, + self.dropout, + self.training, + self.bidirectional, + self.batch_first, + ) else: - result = _VF.rnn_relu(input, hx, self._flat_weights, self.bias, self.num_layers, - self.dropout, self.training, self.bidirectional, - self.batch_first) + result = _VF.rnn_relu( + input, + hx, + self._flat_weights, + self.bias, + self.num_layers, + self.dropout, + self.training, + self.bidirectional, + self.batch_first, + ) else: - if self.mode == 'RNN_TANH': - result = _VF.rnn_tanh(input, batch_sizes, hx, self._flat_weights, self.bias, - self.num_layers, self.dropout, self.training, - self.bidirectional) + if self.mode == "RNN_TANH": + result = _VF.rnn_tanh( + input, + batch_sizes, + hx, + self._flat_weights, + self.bias, + self.num_layers, + self.dropout, + self.training, + self.bidirectional, + ) else: - result = _VF.rnn_relu(input, batch_sizes, hx, self._flat_weights, self.bias, - self.num_layers, self.dropout, self.training, - self.bidirectional) + result = _VF.rnn_relu( + input, + batch_sizes, + hx, + self._flat_weights, + self.bias, + self.num_layers, + self.dropout, + self.training, + self.bidirectional, + ) output = result[0] hidden = result[1] if isinstance(orig_input, PackedSequence): - output_packed = PackedSequence(output, batch_sizes, sorted_indices, unsorted_indices) + output_packed = PackedSequence( + output, batch_sizes, sorted_indices, unsorted_indices + ) return output_packed, self.permute_hidden(hidden, unsorted_indices) if not is_batched: # type: ignore[possibly-undefined] @@ -619,6 +774,7 @@ def forward(self, input, hx=None): # noqa: F811 return output, self.permute_hidden(hidden, unsorted_indices) + # XXX: LSTM and GRU implementation is different from RNNBase, this is because: # 1. we want to support nn.LSTM and nn.GRU in TorchScript and TorchScript in # its current state could not support the python Union Type or Any Type @@ -795,9 +951,19 @@ class LSTM(RNNBase): """ @overload - def __init__(self, input_size: int, hidden_size: int, num_layers: int = 1, bias: bool = True, - batch_first: bool = False, dropout: float = 0., bidirectional: bool = False, - proj_size: int = 0, device=None, dtype=None) -> None: + def __init__( + self, + input_size: int, + hidden_size: int, + num_layers: int = 1, + bias: bool = True, + batch_first: bool = False, + dropout: float = 0.0, + bidirectional: bool = False, + proj_size: int = 0, + device=None, + dtype=None, + ) -> None: ... @overload @@ -805,52 +971,69 @@ def __init__(self, *args, **kwargs): ... def __init__(self, *args, **kwargs): - super().__init__('LSTM', *args, **kwargs) + super().__init__("LSTM", *args, **kwargs) - def get_expected_cell_size(self, input: Tensor, batch_sizes: Optional[Tensor]) -> Tuple[int, int, int]: + def get_expected_cell_size( + self, input: Tensor, batch_sizes: Optional[Tensor] + ) -> Tuple[int, int, int]: if batch_sizes is not None: mini_batch = int(batch_sizes[0]) else: mini_batch = input.size(0) if self.batch_first else input.size(1) num_directions = 2 if self.bidirectional else 1 - expected_hidden_size = (self.num_layers * num_directions, - mini_batch, self.hidden_size) + expected_hidden_size = ( + self.num_layers * num_directions, + mini_batch, + self.hidden_size, + ) return expected_hidden_size # In the future, we should prevent mypy from applying contravariance rules here. # See torch/nn/modules/module.py::_forward_unimplemented - def check_forward_args(self, # type: ignore[override] - input: Tensor, - hidden: Tuple[Tensor, Tensor], - batch_sizes: Optional[Tensor], - ): + def check_forward_args( + self, + input: Tensor, + hidden: Tuple[Tensor, Tensor], # type: ignore[override] + batch_sizes: Optional[Tensor], + ): self.check_input(input, batch_sizes) - self.check_hidden_size(hidden[0], self.get_expected_hidden_size(input, batch_sizes), - 'Expected hidden[0] size {}, got {}') - self.check_hidden_size(hidden[1], self.get_expected_cell_size(input, batch_sizes), - 'Expected hidden[1] size {}, got {}') + self.check_hidden_size( + hidden[0], + self.get_expected_hidden_size(input, batch_sizes), + "Expected hidden[0] size {}, got {}", + ) + self.check_hidden_size( + hidden[1], + self.get_expected_cell_size(input, batch_sizes), + "Expected hidden[1] size {}, got {}", + ) # Same as above, see torch/nn/modules/module.py::_forward_unimplemented - def permute_hidden(self, # type: ignore[override] - hx: Tuple[Tensor, Tensor], - permutation: Optional[Tensor] - ) -> Tuple[Tensor, Tensor]: + def permute_hidden( + self, + hx: Tuple[Tensor, Tensor], # type: ignore[override] + permutation: Optional[Tensor], + ) -> Tuple[Tensor, Tensor]: if permutation is None: return hx - return _apply_permutation(hx[0], permutation), _apply_permutation(hx[1], permutation) + return _apply_permutation(hx[0], permutation), _apply_permutation( + hx[1], permutation + ) # Same as above, see torch/nn/modules/module.py::_forward_unimplemented @overload # type: ignore[override] @torch._jit_internal._overload_method # noqa: F811 - def forward(self, input: Tensor, hx: Optional[Tuple[Tensor, Tensor]] = None - ) -> Tuple[Tensor, Tuple[Tensor, Tensor]]: # noqa: F811 + def forward( + self, input: Tensor, hx: Optional[Tuple[Tensor, Tensor]] = None + ) -> Tuple[Tensor, Tuple[Tensor, Tensor]]: # noqa: F811 pass # Same as above, see torch/nn/modules/module.py::_forward_unimplemented @overload @torch._jit_internal._overload_method # noqa: F811 - def forward(self, input: PackedSequence, hx: Optional[Tuple[Tensor, Tensor]] = None - ) -> Tuple[PackedSequence, Tuple[Tensor, Tensor]]: # noqa: F811 + def forward( + self, input: PackedSequence, hx: Optional[Tuple[Tensor, Tensor]] = None + ) -> Tuple[PackedSequence, Tuple[Tensor, Tensor]]: # noqa: F811 pass def forward(self, input, hx=None): # noqa: F811 @@ -866,12 +1049,20 @@ def forward(self, input, hx=None): # noqa: F811 input, batch_sizes, sorted_indices, unsorted_indices = input max_batch_size = batch_sizes[0] if hx is None: - h_zeros = torch.zeros(self.num_layers * num_directions, - max_batch_size, real_hidden_size, - dtype=input.dtype, device=input.device) - c_zeros = torch.zeros(self.num_layers * num_directions, - max_batch_size, self.hidden_size, - dtype=input.dtype, device=input.device) + h_zeros = torch.zeros( + self.num_layers * num_directions, + max_batch_size, + real_hidden_size, + dtype=input.dtype, + device=input.device, + ) + c_zeros = torch.zeros( + self.num_layers * num_directions, + max_batch_size, + self.hidden_size, + dtype=input.dtype, + device=input.device, + ) hx = (h_zeros, c_zeros) else: # Each batch of the hidden state should match the input sequence that @@ -879,7 +1070,9 @@ def forward(self, input, hx=None): # noqa: F811 hx = self.permute_hidden(hx, sorted_indices) else: if input.dim() not in (2, 3): - raise ValueError(f"LSTM: Expected input to be 2D or 3D, got {input.dim()}D instead") + raise ValueError( + f"LSTM: Expected input to be 2D or 3D, got {input.dim()}D instead" + ) is_batched = input.dim() == 3 batch_dim = 0 if self.batch_first else 1 if not is_batched: @@ -888,24 +1081,36 @@ def forward(self, input, hx=None): # noqa: F811 sorted_indices = None unsorted_indices = None if hx is None: - h_zeros = torch.zeros(self.num_layers * num_directions, - max_batch_size, real_hidden_size, - dtype=input.dtype, device=input.device) - c_zeros = torch.zeros(self.num_layers * num_directions, - max_batch_size, self.hidden_size, - dtype=input.dtype, device=input.device) + h_zeros = torch.zeros( + self.num_layers * num_directions, + max_batch_size, + real_hidden_size, + dtype=input.dtype, + device=input.device, + ) + c_zeros = torch.zeros( + self.num_layers * num_directions, + max_batch_size, + self.hidden_size, + dtype=input.dtype, + device=input.device, + ) hx = (h_zeros, c_zeros) self.check_forward_args(input, hx, batch_sizes) else: if is_batched: - if (hx[0].dim() != 3 or hx[1].dim() != 3): - msg = ("For batched 3-D input, hx and cx should " - f"also be 3-D but got ({hx[0].dim()}-D, {hx[1].dim()}-D) tensors") + if hx[0].dim() != 3 or hx[1].dim() != 3: + msg = ( + "For batched 3-D input, hx and cx should " + f"also be 3-D but got ({hx[0].dim()}-D, {hx[1].dim()}-D) tensors" + ) raise RuntimeError(msg) else: if hx[0].dim() != 2 or hx[1].dim() != 2: - msg = ("For unbatched 2-D input, hx and cx should " - f"also be 2-D but got ({hx[0].dim()}-D, {hx[1].dim()}-D) tensors") + msg = ( + "For unbatched 2-D input, hx and cx should " + f"also be 2-D but got ({hx[0].dim()}-D, {hx[1].dim()}-D) tensors" + ) raise RuntimeError(msg) hx = (hx[0].unsqueeze(1), hx[1].unsqueeze(1)) # Each batch of the hidden state should match the input sequence that @@ -914,16 +1119,36 @@ def forward(self, input, hx=None): # noqa: F811 hx = self.permute_hidden(hx, sorted_indices) if batch_sizes is None: - result = _VF.lstm(input, hx, self._flat_weights, self.bias, self.num_layers, - self.dropout, self.training, self.bidirectional, self.batch_first) + result = _VF.lstm( + input, + hx, + self._flat_weights, + self.bias, + self.num_layers, + self.dropout, + self.training, + self.bidirectional, + self.batch_first, + ) else: - result = _VF.lstm(input, batch_sizes, hx, self._flat_weights, self.bias, - self.num_layers, self.dropout, self.training, self.bidirectional) + result = _VF.lstm( + input, + batch_sizes, + hx, + self._flat_weights, + self.bias, + self.num_layers, + self.dropout, + self.training, + self.bidirectional, + ) output = result[0] hidden = result[1:] # xxx: isinstance check needs to be in conditional for TorchScript to compile if isinstance(orig_input, PackedSequence): - output_packed = PackedSequence(output, batch_sizes, sorted_indices, unsorted_indices) + output_packed = PackedSequence( + output, batch_sizes, sorted_indices, unsorted_indices + ) return output_packed, self.permute_hidden(hidden, unsorted_indices) else: if not is_batched: # type: ignore[possibly-undefined] @@ -1063,9 +1288,18 @@ class GRU(RNNBase): """ @overload - def __init__(self, input_size: int, hidden_size: int, num_layers: int = 1, bias: bool = True, - batch_first: bool = False, dropout: float = 0., bidirectional: bool = False, - device=None, dtype=None) -> None: + def __init__( + self, + input_size: int, + hidden_size: int, + num_layers: int = 1, + bias: bool = True, + batch_first: bool = False, + dropout: float = 0.0, + bidirectional: bool = False, + device=None, + dtype=None, + ) -> None: ... @overload @@ -1073,18 +1307,24 @@ def __init__(self, *args, **kwargs): ... def __init__(self, *args, **kwargs): - if 'proj_size' in kwargs: - raise ValueError("proj_size argument is only supported for LSTM, not RNN or GRU") - super().__init__('GRU', *args, **kwargs) + if "proj_size" in kwargs: + raise ValueError( + "proj_size argument is only supported for LSTM, not RNN or GRU" + ) + super().__init__("GRU", *args, **kwargs) @overload # type: ignore[override] @torch._jit_internal._overload_method # noqa: F811 - def forward(self, input: Tensor, hx: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]: # noqa: F811 + def forward( + self, input: Tensor, hx: Optional[Tensor] = None + ) -> Tuple[Tensor, Tensor]: # noqa: F811 pass @overload @torch._jit_internal._overload_method # noqa: F811 - def forward(self, input: PackedSequence, hx: Optional[Tensor] = None) -> Tuple[PackedSequence, Tensor]: # noqa: F811 + def forward( + self, input: PackedSequence, hx: Optional[Tensor] = None + ) -> Tuple[PackedSequence, Tensor]: # noqa: F811 pass def forward(self, input, hx=None): # noqa: F811 @@ -1097,9 +1337,13 @@ def forward(self, input, hx=None): # noqa: F811 max_batch_size = batch_sizes[0] if hx is None: num_directions = 2 if self.bidirectional else 1 - hx = torch.zeros(self.num_layers * num_directions, - max_batch_size, self.hidden_size, - dtype=input.dtype, device=input.device) + hx = torch.zeros( + self.num_layers * num_directions, + max_batch_size, + self.hidden_size, + dtype=input.dtype, + device=input.device, + ) else: # Each batch of the hidden state should match the input sequence that # the user believes he/she is passing in. @@ -1107,7 +1351,9 @@ def forward(self, input, hx=None): # noqa: F811 else: batch_sizes = None if input.dim() not in (2, 3): - raise ValueError(f"GRU: Expected input to be 2D or 3D, got {input.dim()}D instead") + raise ValueError( + f"GRU: Expected input to be 2D or 3D, got {input.dim()}D instead" + ) is_batched = input.dim() == 3 batch_dim = 0 if self.batch_first else 1 if not is_batched: @@ -1115,20 +1361,26 @@ def forward(self, input, hx=None): # noqa: F811 if hx is not None: if hx.dim() != 2: raise RuntimeError( - f"For unbatched 2-D input, hx should also be 2-D but got {hx.dim()}-D tensor") + f"For unbatched 2-D input, hx should also be 2-D but got {hx.dim()}-D tensor" + ) hx = hx.unsqueeze(1) else: if hx is not None and hx.dim() != 3: raise RuntimeError( - f"For batched 3-D input, hx should also be 3-D but got {hx.dim()}-D tensor") + f"For batched 3-D input, hx should also be 3-D but got {hx.dim()}-D tensor" + ) max_batch_size = input.size(0) if self.batch_first else input.size(1) sorted_indices = None unsorted_indices = None if hx is None: num_directions = 2 if self.bidirectional else 1 - hx = torch.zeros(self.num_layers * num_directions, - max_batch_size, self.hidden_size, - dtype=input.dtype, device=input.device) + hx = torch.zeros( + self.num_layers * num_directions, + max_batch_size, + self.hidden_size, + dtype=input.dtype, + device=input.device, + ) else: # Each batch of the hidden state should match the input sequence that # the user believes he/she is passing in. @@ -1136,17 +1388,37 @@ def forward(self, input, hx=None): # noqa: F811 self.check_forward_args(input, hx, batch_sizes) if batch_sizes is None: - result = _VF.gru(input, hx, self._flat_weights, self.bias, self.num_layers, - self.dropout, self.training, self.bidirectional, self.batch_first) + result = _VF.gru( + input, + hx, + self._flat_weights, + self.bias, + self.num_layers, + self.dropout, + self.training, + self.bidirectional, + self.batch_first, + ) else: - result = _VF.gru(input, batch_sizes, hx, self._flat_weights, self.bias, - self.num_layers, self.dropout, self.training, self.bidirectional) + result = _VF.gru( + input, + batch_sizes, + hx, + self._flat_weights, + self.bias, + self.num_layers, + self.dropout, + self.training, + self.bidirectional, + ) output = result[0] hidden = result[1] # xxx: isinstance check needs to be in conditional for TorchScript to compile if isinstance(orig_input, PackedSequence): - output_packed = PackedSequence(output, batch_sizes, sorted_indices, unsorted_indices) + output_packed = PackedSequence( + output, batch_sizes, sorted_indices, unsorted_indices + ) return output_packed, self.permute_hidden(hidden, unsorted_indices) else: if not is_batched: # type: ignore[possibly-undefined] @@ -1157,7 +1429,7 @@ def forward(self, input, hx=None): # noqa: F811 class RNNCellBase(Module): - __constants__ = ['input_size', 'hidden_size', 'bias'] + __constants__ = ["input_size", "hidden_size", "bias"] input_size: int hidden_size: int @@ -1167,30 +1439,45 @@ class RNNCellBase(Module): # WARNING: bias_ih and bias_hh purposely not defined here. # See https://github.com/pytorch/pytorch/issues/39670 - def __init__(self, input_size: int, hidden_size: int, bias: bool, num_chunks: int, - device=None, dtype=None) -> None: - factory_kwargs = {'device': device, 'dtype': dtype} + def __init__( + self, + input_size: int, + hidden_size: int, + bias: bool, + num_chunks: int, + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} super().__init__() self.input_size = input_size self.hidden_size = hidden_size self.bias = bias - self.weight_ih = Parameter(torch.empty((num_chunks * hidden_size, input_size), **factory_kwargs)) - self.weight_hh = Parameter(torch.empty((num_chunks * hidden_size, hidden_size), **factory_kwargs)) + self.weight_ih = Parameter( + torch.empty((num_chunks * hidden_size, input_size), **factory_kwargs) + ) + self.weight_hh = Parameter( + torch.empty((num_chunks * hidden_size, hidden_size), **factory_kwargs) + ) if bias: - self.bias_ih = Parameter(torch.empty(num_chunks * hidden_size, **factory_kwargs)) - self.bias_hh = Parameter(torch.empty(num_chunks * hidden_size, **factory_kwargs)) + self.bias_ih = Parameter( + torch.empty(num_chunks * hidden_size, **factory_kwargs) + ) + self.bias_hh = Parameter( + torch.empty(num_chunks * hidden_size, **factory_kwargs) + ) else: - self.register_parameter('bias_ih', None) - self.register_parameter('bias_hh', None) + self.register_parameter("bias_ih", None) + self.register_parameter("bias_hh", None) self.reset_parameters() def extra_repr(self) -> str: - s = '{input_size}, {hidden_size}' - if 'bias' in self.__dict__ and self.bias is not True: - s += ', bias={bias}' - if 'nonlinearity' in self.__dict__ and self.nonlinearity != "tanh": - s += ', nonlinearity={nonlinearity}' + s = "{input_size}, {hidden_size}" + if "bias" in self.__dict__ and self.bias is not True: + s += ", bias={bias}" + if "nonlinearity" in self.__dict__ and self.nonlinearity != "tanh": + s += ", nonlinearity={nonlinearity}" return s.format(**self.__dict__) def reset_parameters(self) -> None: @@ -1254,45 +1541,63 @@ class RNNCell(RNNCellBase): ... output.append(hx) """ - __constants__ = ['input_size', 'hidden_size', 'bias', 'nonlinearity'] + __constants__ = ["input_size", "hidden_size", "bias", "nonlinearity"] nonlinearity: str - def __init__(self, input_size: int, hidden_size: int, bias: bool = True, nonlinearity: str = "tanh", - device=None, dtype=None) -> None: - factory_kwargs = {'device': device, 'dtype': dtype} + def __init__( + self, + input_size: int, + hidden_size: int, + bias: bool = True, + nonlinearity: str = "tanh", + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} super().__init__(input_size, hidden_size, bias, num_chunks=1, **factory_kwargs) self.nonlinearity = nonlinearity def forward(self, input: Tensor, hx: Optional[Tensor] = None) -> Tensor: if input.dim() not in (1, 2): - raise ValueError(f"RNNCell: Expected input to be 1D or 2D, got {input.dim()}D instead") + raise ValueError( + f"RNNCell: Expected input to be 1D or 2D, got {input.dim()}D instead" + ) if hx is not None and hx.dim() not in (1, 2): - raise ValueError(f"RNNCell: Expected hidden to be 1D or 2D, got {hx.dim()}D instead") + raise ValueError( + f"RNNCell: Expected hidden to be 1D or 2D, got {hx.dim()}D instead" + ) is_batched = input.dim() == 2 if not is_batched: input = input.unsqueeze(0) if hx is None: - hx = torch.zeros(input.size(0), self.hidden_size, dtype=input.dtype, device=input.device) + hx = torch.zeros( + input.size(0), self.hidden_size, dtype=input.dtype, device=input.device + ) else: hx = hx.unsqueeze(0) if not is_batched else hx if self.nonlinearity == "tanh": ret = _VF.rnn_tanh_cell( - input, hx, - self.weight_ih, self.weight_hh, - self.bias_ih, self.bias_hh, + input, + hx, + self.weight_ih, + self.weight_hh, + self.bias_ih, + self.bias_hh, ) elif self.nonlinearity == "relu": ret = _VF.rnn_relu_cell( - input, hx, - self.weight_ih, self.weight_hh, - self.bias_ih, self.bias_hh, + input, + hx, + self.weight_ih, + self.weight_hh, + self.bias_ih, + self.bias_hh, ) else: ret = input # TODO: remove when jit supports exception flow - raise RuntimeError( - f"Unknown nonlinearity: {self.nonlinearity}") + raise RuntimeError(f"Unknown nonlinearity: {self.nonlinearity}") if not is_batched: ret = ret.squeeze(0) @@ -1360,32 +1665,49 @@ class LSTMCell(RNNCellBase): >>> output = torch.stack(output, dim=0) """ - def __init__(self, input_size: int, hidden_size: int, bias: bool = True, - device=None, dtype=None) -> None: - factory_kwargs = {'device': device, 'dtype': dtype} + def __init__( + self, + input_size: int, + hidden_size: int, + bias: bool = True, + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} super().__init__(input_size, hidden_size, bias, num_chunks=4, **factory_kwargs) - def forward(self, input: Tensor, hx: Optional[Tuple[Tensor, Tensor]] = None) -> Tuple[Tensor, Tensor]: + def forward( + self, input: Tensor, hx: Optional[Tuple[Tensor, Tensor]] = None + ) -> Tuple[Tensor, Tensor]: if input.dim() not in (1, 2): - raise ValueError(f"LSTMCell: Expected input to be 1D or 2D, got {input.dim()}D instead") + raise ValueError( + f"LSTMCell: Expected input to be 1D or 2D, got {input.dim()}D instead" + ) if hx is not None: for idx, value in enumerate(hx): if value.dim() not in (1, 2): - raise ValueError(f"LSTMCell: Expected hx[{idx}] to be 1D or 2D, got {value.dim()}D instead") + raise ValueError( + f"LSTMCell: Expected hx[{idx}] to be 1D or 2D, got {value.dim()}D instead" + ) is_batched = input.dim() == 2 if not is_batched: input = input.unsqueeze(0) if hx is None: - zeros = torch.zeros(input.size(0), self.hidden_size, dtype=input.dtype, device=input.device) + zeros = torch.zeros( + input.size(0), self.hidden_size, dtype=input.dtype, device=input.device + ) hx = (zeros, zeros) else: hx = (hx[0].unsqueeze(0), hx[1].unsqueeze(0)) if not is_batched else hx ret = _VF.lstm_cell( - input, hx, - self.weight_ih, self.weight_hh, - self.bias_ih, self.bias_hh, + input, + hx, + self.weight_ih, + self.weight_hh, + self.bias_ih, + self.bias_hh, ) if not is_batched: @@ -1455,29 +1777,44 @@ class GRUCell(RNNCellBase): ... output.append(hx) """ - def __init__(self, input_size: int, hidden_size: int, bias: bool = True, - device=None, dtype=None) -> None: - factory_kwargs = {'device': device, 'dtype': dtype} + def __init__( + self, + input_size: int, + hidden_size: int, + bias: bool = True, + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} super().__init__(input_size, hidden_size, bias, num_chunks=3, **factory_kwargs) def forward(self, input: Tensor, hx: Optional[Tensor] = None) -> Tensor: if input.dim() not in (1, 2): - raise ValueError(f"GRUCell: Expected input to be 1D or 2D, got {input.dim()}D instead") + raise ValueError( + f"GRUCell: Expected input to be 1D or 2D, got {input.dim()}D instead" + ) if hx is not None and hx.dim() not in (1, 2): - raise ValueError(f"GRUCell: Expected hidden to be 1D or 2D, got {hx.dim()}D instead") + raise ValueError( + f"GRUCell: Expected hidden to be 1D or 2D, got {hx.dim()}D instead" + ) is_batched = input.dim() == 2 if not is_batched: input = input.unsqueeze(0) if hx is None: - hx = torch.zeros(input.size(0), self.hidden_size, dtype=input.dtype, device=input.device) + hx = torch.zeros( + input.size(0), self.hidden_size, dtype=input.dtype, device=input.device + ) else: hx = hx.unsqueeze(0) if not is_batched else hx ret = _VF.gru_cell( - input, hx, - self.weight_ih, self.weight_hh, - self.bias_ih, self.bias_hh, + input, + hx, + self.weight_ih, + self.weight_hh, + self.bias_ih, + self.bias_hh, ) if not is_batched: diff --git a/torch/nn/modules/sparse.py b/torch/nn/modules/sparse.py index 512b17d03222d6..487afd1c4f0f4c 100644 --- a/torch/nn/modules/sparse.py +++ b/torch/nn/modules/sparse.py @@ -3,13 +3,14 @@ import torch from torch import Tensor +from torch.nn import functional as F, init from torch.nn.parameter import Parameter from .module import Module -from .. import functional as F -from .. import init -__all__ = ['Embedding', 'EmbeddingBag'] + +__all__ = ["Embedding", "EmbeddingBag"] + class Embedding(Module): r"""A simple lookup table that stores embeddings of a fixed dictionary and size. @@ -109,8 +110,15 @@ class Embedding(Module): [ 0.6778, 0.5803, 0.2678]], requires_grad=True) """ - __constants__ = ['num_embeddings', 'embedding_dim', 'padding_idx', 'max_norm', - 'norm_type', 'scale_grad_by_freq', 'sparse'] + __constants__ = [ + "num_embeddings", + "embedding_dim", + "padding_idx", + "max_norm", + "norm_type", + "scale_grad_by_freq", + "sparse", + ] num_embeddings: int embedding_dim: int @@ -122,31 +130,49 @@ class Embedding(Module): freeze: bool sparse: bool - def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None, - max_norm: Optional[float] = None, norm_type: float = 2., scale_grad_by_freq: bool = False, - sparse: bool = False, _weight: Optional[Tensor] = None, _freeze: bool = False, - device=None, dtype=None) -> None: - factory_kwargs = {'device': device, 'dtype': dtype} + def __init__( + self, + num_embeddings: int, + embedding_dim: int, + padding_idx: Optional[int] = None, + max_norm: Optional[float] = None, + norm_type: float = 2.0, + scale_grad_by_freq: bool = False, + sparse: bool = False, + _weight: Optional[Tensor] = None, + _freeze: bool = False, + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} super().__init__() self.num_embeddings = num_embeddings self.embedding_dim = embedding_dim if padding_idx is not None: if padding_idx > 0: - assert padding_idx < self.num_embeddings, 'Padding_idx must be within num_embeddings' + assert ( + padding_idx < self.num_embeddings + ), "Padding_idx must be within num_embeddings" elif padding_idx < 0: - assert padding_idx >= -self.num_embeddings, 'Padding_idx must be within num_embeddings' + assert ( + padding_idx >= -self.num_embeddings + ), "Padding_idx must be within num_embeddings" padding_idx = self.num_embeddings + padding_idx self.padding_idx = padding_idx self.max_norm = max_norm self.norm_type = norm_type self.scale_grad_by_freq = scale_grad_by_freq if _weight is None: - self.weight = Parameter(torch.empty((num_embeddings, embedding_dim), **factory_kwargs), - requires_grad=not _freeze) + self.weight = Parameter( + torch.empty((num_embeddings, embedding_dim), **factory_kwargs), + requires_grad=not _freeze, + ) self.reset_parameters() else: - assert list(_weight.shape) == [num_embeddings, embedding_dim], \ - 'Shape of weight does not match num_embeddings and embedding_dim' + assert list(_weight.shape) == [ + num_embeddings, + embedding_dim, + ], "Shape of weight does not match num_embeddings and embedding_dim" self.weight = Parameter(_weight, requires_grad=not _freeze) self.sparse = sparse @@ -162,27 +188,40 @@ def _fill_padding_idx_with_zero(self) -> None: def forward(self, input: Tensor) -> Tensor: return F.embedding( - input, self.weight, self.padding_idx, self.max_norm, - self.norm_type, self.scale_grad_by_freq, self.sparse) + input, + self.weight, + self.padding_idx, + self.max_norm, + self.norm_type, + self.scale_grad_by_freq, + self.sparse, + ) def extra_repr(self) -> str: - s = '{num_embeddings}, {embedding_dim}' + s = "{num_embeddings}, {embedding_dim}" if self.padding_idx is not None: - s += ', padding_idx={padding_idx}' + s += ", padding_idx={padding_idx}" if self.max_norm is not None: - s += ', max_norm={max_norm}' + s += ", max_norm={max_norm}" if self.norm_type != 2: - s += ', norm_type={norm_type}' + s += ", norm_type={norm_type}" if self.scale_grad_by_freq is not False: - s += ', scale_grad_by_freq={scale_grad_by_freq}' + s += ", scale_grad_by_freq={scale_grad_by_freq}" if self.sparse is not False: - s += ', sparse=True' + s += ", sparse=True" return s.format(**self.__dict__) @classmethod - def from_pretrained(cls, embeddings, freeze=True, padding_idx=None, - max_norm=None, norm_type=2., scale_grad_by_freq=False, - sparse=False): + def from_pretrained( + cls, + embeddings, + freeze=True, + padding_idx=None, + max_norm=None, + norm_type=2.0, + scale_grad_by_freq=False, + sparse=False, + ): r"""Create Embedding instance from given 2-dimensional FloatTensor. Args: @@ -209,8 +248,9 @@ def from_pretrained(cls, embeddings, freeze=True, padding_idx=None, >>> embedding(input) tensor([[ 4.0000, 5.1000, 6.3000]]) """ - assert embeddings.dim() == 2, \ - 'Embeddings parameter is expected to be 2-dimensional' + assert ( + embeddings.dim() == 2 + ), "Embeddings parameter is expected to be 2-dimensional" rows, cols = embeddings.shape embedding = cls( num_embeddings=rows, @@ -221,7 +261,8 @@ def from_pretrained(cls, embeddings, freeze=True, padding_idx=None, max_norm=max_norm, norm_type=norm_type, scale_grad_by_freq=scale_grad_by_freq, - sparse=sparse) + sparse=sparse, + ) return embedding @@ -303,9 +344,17 @@ class EmbeddingBag(Module): mode='sum') """ - __constants__ = ['num_embeddings', 'embedding_dim', 'max_norm', 'norm_type', - 'scale_grad_by_freq', 'mode', 'sparse', 'include_last_offset', - 'padding_idx'] + __constants__ = [ + "num_embeddings", + "embedding_dim", + "max_norm", + "norm_type", + "scale_grad_by_freq", + "mode", + "sparse", + "include_last_offset", + "padding_idx", + ] num_embeddings: int embedding_dim: int @@ -318,12 +367,22 @@ class EmbeddingBag(Module): include_last_offset: bool padding_idx: Optional[int] - def __init__(self, num_embeddings: int, embedding_dim: int, - max_norm: Optional[float] = None, norm_type: float = 2., scale_grad_by_freq: bool = False, - mode: str = 'mean', sparse: bool = False, _weight: Optional[Tensor] = None, - include_last_offset: bool = False, padding_idx: Optional[int] = None, - device=None, dtype=None) -> None: - factory_kwargs = {'device': device, 'dtype': dtype} + def __init__( + self, + num_embeddings: int, + embedding_dim: int, + max_norm: Optional[float] = None, + norm_type: float = 2.0, + scale_grad_by_freq: bool = False, + mode: str = "mean", + sparse: bool = False, + _weight: Optional[Tensor] = None, + include_last_offset: bool = False, + padding_idx: Optional[int] = None, + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} super().__init__() self.num_embeddings = num_embeddings self.embedding_dim = embedding_dim @@ -332,17 +391,25 @@ def __init__(self, num_embeddings: int, embedding_dim: int, self.scale_grad_by_freq = scale_grad_by_freq if padding_idx is not None: if padding_idx > 0: - assert padding_idx < self.num_embeddings, 'padding_idx must be within num_embeddings' + assert ( + padding_idx < self.num_embeddings + ), "padding_idx must be within num_embeddings" elif padding_idx < 0: - assert padding_idx >= -self.num_embeddings, 'padding_idx must be within num_embeddings' + assert ( + padding_idx >= -self.num_embeddings + ), "padding_idx must be within num_embeddings" padding_idx = self.num_embeddings + padding_idx self.padding_idx = padding_idx if _weight is None: - self.weight = Parameter(torch.empty((num_embeddings, embedding_dim), **factory_kwargs)) + self.weight = Parameter( + torch.empty((num_embeddings, embedding_dim), **factory_kwargs) + ) self.reset_parameters() else: - assert list(_weight.shape) == [num_embeddings, embedding_dim], \ - 'Shape of weight does not match num_embeddings and embedding_dim' + assert list(_weight.shape) == [ + num_embeddings, + embedding_dim, + ], "Shape of weight does not match num_embeddings and embedding_dim" self.weight = Parameter(_weight) self.mode = mode self.sparse = sparse @@ -357,7 +424,12 @@ def _fill_padding_idx_with_zero(self) -> None: with torch.no_grad(): self.weight[self.padding_idx].fill_(0) - def forward(self, input: Tensor, offsets: Optional[Tensor] = None, per_sample_weights: Optional[Tensor] = None) -> Tensor: + def forward( + self, + input: Tensor, + offsets: Optional[Tensor] = None, + per_sample_weights: Optional[Tensor] = None, + ) -> Tensor: """Forward pass of EmbeddingBag. Args: @@ -388,30 +460,46 @@ def forward(self, input: Tensor, offsets: Optional[Tensor] = None, per_sample_we :attr:`input` will be viewed as having ``B`` bags. Empty bags (i.e., having 0-length) will have returned vectors filled by zeros. """ - return F.embedding_bag(input, self.weight, offsets, - self.max_norm, self.norm_type, - self.scale_grad_by_freq, self.mode, self.sparse, - per_sample_weights, self.include_last_offset, - self.padding_idx) + return F.embedding_bag( + input, + self.weight, + offsets, + self.max_norm, + self.norm_type, + self.scale_grad_by_freq, + self.mode, + self.sparse, + per_sample_weights, + self.include_last_offset, + self.padding_idx, + ) def extra_repr(self) -> str: - s = '{num_embeddings}, {embedding_dim}' + s = "{num_embeddings}, {embedding_dim}" if self.max_norm is not None: - s += ', max_norm={max_norm}' + s += ", max_norm={max_norm}" if self.norm_type != 2: - s += ', norm_type={norm_type}' + s += ", norm_type={norm_type}" if self.scale_grad_by_freq is not False: - s += ', scale_grad_by_freq={scale_grad_by_freq}' - s += ', mode={mode}' + s += ", scale_grad_by_freq={scale_grad_by_freq}" + s += ", mode={mode}" if self.padding_idx is not None: - s += ', padding_idx={padding_idx}' + s += ", padding_idx={padding_idx}" return s.format(**{k: repr(v) for k, v in self.__dict__.items()}) @classmethod - def from_pretrained(cls, embeddings: Tensor, freeze: bool = True, max_norm: Optional[float] = None, - norm_type: float = 2., scale_grad_by_freq: bool = False, - mode: str = 'mean', sparse: bool = False, include_last_offset: bool = False, - padding_idx: Optional[int] = None) -> 'EmbeddingBag': + def from_pretrained( + cls, + embeddings: Tensor, + freeze: bool = True, + max_norm: Optional[float] = None, + norm_type: float = 2.0, + scale_grad_by_freq: bool = False, + mode: str = "mean", + sparse: bool = False, + include_last_offset: bool = False, + padding_idx: Optional[int] = None, + ) -> "EmbeddingBag": r"""Create EmbeddingBag instance from given 2-dimensional FloatTensor. Args: @@ -438,8 +526,9 @@ def from_pretrained(cls, embeddings: Tensor, freeze: bool = True, max_norm: Opti >>> embeddingbag(input) tensor([[ 2.5000, 3.7000, 4.6500]]) """ - assert embeddings.dim() == 2, \ - 'Embeddings parameter is expected to be 2-dimensional' + assert ( + embeddings.dim() == 2 + ), "Embeddings parameter is expected to be 2-dimensional" rows, cols = embeddings.shape embeddingbag = cls( num_embeddings=rows, @@ -451,6 +540,7 @@ def from_pretrained(cls, embeddings: Tensor, freeze: bool = True, max_norm: Opti mode=mode, sparse=sparse, include_last_offset=include_last_offset, - padding_idx=padding_idx) + padding_idx=padding_idx, + ) embeddingbag.weight.requires_grad = not freeze return embeddingbag diff --git a/torch/nn/modules/transformer.py b/torch/nn/modules/transformer.py index f5980cd6b1e8d4..28511d7dd4f1d3 100644 --- a/torch/nn/modules/transformer.py +++ b/torch/nn/modules/transformer.py @@ -1,45 +1,50 @@ # mypy: allow-untyped-defs import copy -from typing import Optional, Any, Union, Callable +import warnings +from typing import Any, Callable, Optional, Union import torch -import warnings +import torch.nn.functional as F from torch import Tensor -from .. import functional as F -from .module import Module +from torch.nn.init import xavier_uniform_ + from .activation import MultiheadAttention from .container import ModuleList -from ..init import xavier_uniform_ from .dropout import Dropout from .linear import Linear +from .module import Module from .normalization import LayerNorm -__all__ = ['Transformer', 'TransformerEncoder', 'TransformerDecoder', 'TransformerEncoderLayer', 'TransformerDecoderLayer'] + +__all__ = [ + "Transformer", + "TransformerEncoder", + "TransformerDecoder", + "TransformerEncoderLayer", + "TransformerDecoderLayer", +] + def _generate_square_subsequent_mask( - sz: int, - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None, + sz: int, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, ) -> Tensor: r"""Generate a square causal mask for the sequence. The masked positions are filled with float('-inf'). Unmasked positions are filled with float(0.0). """ if device is None: - device = torch.device('cpu') + device = torch.device("cpu") if dtype is None: dtype = torch.float32 return torch.triu( - torch.full((sz, sz), float('-inf'), dtype=dtype, device=device), + torch.full((sz, sz), float("-inf"), dtype=dtype, device=device), diagonal=1, ) -def _get_seq_len( - src: Tensor, - batch_first: bool -) -> Optional[int]: - +def _get_seq_len(src: Tensor, batch_first: bool) -> Optional[int]: if src.is_nested: return None else: @@ -91,33 +96,71 @@ class Transformer(Module): https://github.com/pytorch/examples/tree/master/word_language_model """ - def __init__(self, d_model: int = 512, nhead: int = 8, num_encoder_layers: int = 6, - num_decoder_layers: int = 6, dim_feedforward: int = 2048, dropout: float = 0.1, - activation: Union[str, Callable[[Tensor], Tensor]] = F.relu, - custom_encoder: Optional[Any] = None, custom_decoder: Optional[Any] = None, - layer_norm_eps: float = 1e-5, batch_first: bool = False, norm_first: bool = False, - bias: bool = True, device=None, dtype=None) -> None: - factory_kwargs = {'device': device, 'dtype': dtype} + def __init__( + self, + d_model: int = 512, + nhead: int = 8, + num_encoder_layers: int = 6, + num_decoder_layers: int = 6, + dim_feedforward: int = 2048, + dropout: float = 0.1, + activation: Union[str, Callable[[Tensor], Tensor]] = F.relu, + custom_encoder: Optional[Any] = None, + custom_decoder: Optional[Any] = None, + layer_norm_eps: float = 1e-5, + batch_first: bool = False, + norm_first: bool = False, + bias: bool = True, + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} super().__init__() torch._C._log_api_usage_once(f"torch.nn.modules.{self.__class__.__name__}") if custom_encoder is not None: self.encoder = custom_encoder else: - encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout, - activation, layer_norm_eps, batch_first, norm_first, - bias, **factory_kwargs) - encoder_norm = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs) - self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm) + encoder_layer = TransformerEncoderLayer( + d_model, + nhead, + dim_feedforward, + dropout, + activation, + layer_norm_eps, + batch_first, + norm_first, + bias, + **factory_kwargs, + ) + encoder_norm = LayerNorm( + d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs + ) + self.encoder = TransformerEncoder( + encoder_layer, num_encoder_layers, encoder_norm + ) if custom_decoder is not None: self.decoder = custom_decoder else: - decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward, dropout, - activation, layer_norm_eps, batch_first, norm_first, - bias, **factory_kwargs) - decoder_norm = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs) - self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm) + decoder_layer = TransformerDecoderLayer( + d_model, + nhead, + dim_feedforward, + dropout, + activation, + layer_norm_eps, + batch_first, + norm_first, + bias, + **factory_kwargs, + ) + decoder_norm = LayerNorm( + d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs + ) + self.decoder = TransformerDecoder( + decoder_layer, num_decoder_layers, decoder_norm + ) self._reset_parameters() @@ -126,11 +169,20 @@ def __init__(self, d_model: int = 512, nhead: int = 8, num_encoder_layers: int = self.batch_first = batch_first - def forward(self, src: Tensor, tgt: Tensor, src_mask: Optional[Tensor] = None, tgt_mask: Optional[Tensor] = None, - memory_mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None, - tgt_key_padding_mask: Optional[Tensor] = None, memory_key_padding_mask: Optional[Tensor] = None, - src_is_causal: Optional[bool] = None, tgt_is_causal: Optional[bool] = None, - memory_is_causal: bool = False) -> Tensor: + def forward( + self, + src: Tensor, + tgt: Tensor, + src_mask: Optional[Tensor] = None, + tgt_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + src_is_causal: Optional[bool] = None, + tgt_is_causal: Optional[bool] = None, + memory_is_causal: bool = False, + ) -> Tensor: r"""Take in and process masked source/target sequences. .. note:: @@ -213,21 +265,33 @@ def forward(self, src: Tensor, tgt: Tensor, src_mask: Optional[Tensor] = None, t raise RuntimeError("the batch number of src and tgt must be equal") if src.size(-1) != self.d_model or tgt.size(-1) != self.d_model: - raise RuntimeError("the feature number of src and tgt must be equal to d_model") - - memory = self.encoder(src, mask=src_mask, src_key_padding_mask=src_key_padding_mask, - is_causal=src_is_causal) - output = self.decoder(tgt, memory, tgt_mask=tgt_mask, memory_mask=memory_mask, - tgt_key_padding_mask=tgt_key_padding_mask, - memory_key_padding_mask=memory_key_padding_mask, - tgt_is_causal=tgt_is_causal, memory_is_causal=memory_is_causal) + raise RuntimeError( + "the feature number of src and tgt must be equal to d_model" + ) + + memory = self.encoder( + src, + mask=src_mask, + src_key_padding_mask=src_key_padding_mask, + is_causal=src_is_causal, + ) + output = self.decoder( + tgt, + memory, + tgt_mask=tgt_mask, + memory_mask=memory_mask, + tgt_key_padding_mask=tgt_key_padding_mask, + memory_key_padding_mask=memory_key_padding_mask, + tgt_is_causal=tgt_is_causal, + memory_is_causal=memory_is_causal, + ) return output @staticmethod def generate_square_subsequent_mask( - sz: int, - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None, + sz: int, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, ) -> Tensor: r"""Generate a square causal mask for the sequence. @@ -262,7 +326,7 @@ class TransformerEncoder(Module): >>> out = transformer_encoder(src) """ - __constants__ = ['norm'] + __constants__ = ["norm"] def __init__( self, @@ -270,7 +334,7 @@ def __init__( num_layers: int, norm: Optional[Module] = None, enable_nested_tensor: bool = True, - mask_check: bool = True + mask_check: bool = True, ) -> None: super().__init__() torch._C._log_api_usage_once(f"torch.nn.modules.{self.__class__.__name__}") @@ -284,36 +348,46 @@ def __init__( self.mask_check = mask_check enc_layer = "encoder_layer" - why_not_sparsity_fast_path = '' + why_not_sparsity_fast_path = "" if not isinstance(encoder_layer, torch.nn.TransformerEncoderLayer): why_not_sparsity_fast_path = f"{enc_layer} was not TransformerEncoderLayer" - elif encoder_layer.norm_first : + elif encoder_layer.norm_first: why_not_sparsity_fast_path = f"{enc_layer}.norm_first was True" elif not encoder_layer.self_attn.batch_first: - why_not_sparsity_fast_path = (f"{enc_layer}.self_attn.batch_first was not True" + - "(use batch_first for better inference performance)") + why_not_sparsity_fast_path = ( + f"{enc_layer}.self_attn.batch_first was not True" + + "(use batch_first for better inference performance)" + ) elif not encoder_layer.self_attn._qkv_same_embed_dim: - why_not_sparsity_fast_path = f"{enc_layer}.self_attn._qkv_same_embed_dim was not True" + why_not_sparsity_fast_path = ( + f"{enc_layer}.self_attn._qkv_same_embed_dim was not True" + ) elif encoder_layer.self_attn.in_proj_bias is None: why_not_sparsity_fast_path = f"{enc_layer}.self_attn was passed bias=False" elif not encoder_layer.activation_relu_or_gelu: - why_not_sparsity_fast_path = f"{enc_layer}.activation_relu_or_gelu was not True" - elif not (encoder_layer.norm1.eps == encoder_layer.norm2.eps) : - why_not_sparsity_fast_path = f"{enc_layer}.norm1.eps was not equal to {enc_layer}.norm2.eps" + why_not_sparsity_fast_path = ( + f"{enc_layer}.activation_relu_or_gelu was not True" + ) + elif not (encoder_layer.norm1.eps == encoder_layer.norm2.eps): + why_not_sparsity_fast_path = ( + f"{enc_layer}.norm1.eps was not equal to {enc_layer}.norm2.eps" + ) elif encoder_layer.self_attn.num_heads % 2 == 1: why_not_sparsity_fast_path = f"{enc_layer}.self_attn.num_heads is odd" if enable_nested_tensor and why_not_sparsity_fast_path: - warnings.warn(f"enable_nested_tensor is True, but self.use_nested_tensor is False because {why_not_sparsity_fast_path}") + warnings.warn( + f"enable_nested_tensor is True, but self.use_nested_tensor is False because {why_not_sparsity_fast_path}" + ) self.use_nested_tensor = False - def forward( - self, - src: Tensor, - mask: Optional[Tensor] = None, - src_key_padding_mask: Optional[Tensor] = None, - is_causal: Optional[bool] = None) -> Tensor: + self, + src: Tensor, + mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + is_causal: Optional[bool] = None, + ) -> Tensor: r"""Pass the input through the encoder layers in turn. Args: @@ -336,7 +410,7 @@ def forward( mask_name="src_key_padding_mask", other_type=F._none_or_dtype(mask), other_name="mask", - target_type=src.dtype + target_type=src.dtype, ) mask = F._canonical_mask( @@ -352,30 +426,41 @@ def forward( convert_to_nested = False first_layer = self.layers[0] src_key_padding_mask_for_layers = src_key_padding_mask - why_not_sparsity_fast_path = '' + why_not_sparsity_fast_path = "" str_first_layer = "self.layers[0]" batch_first = first_layer.self_attn.batch_first is_fastpath_enabled = torch.backends.mha.get_fastpath_enabled() if not is_fastpath_enabled: - why_not_sparsity_fast_path = "torch.backends.mha.get_fastpath_enabled() was not True" + why_not_sparsity_fast_path = ( + "torch.backends.mha.get_fastpath_enabled() was not True" + ) elif not hasattr(self, "use_nested_tensor"): why_not_sparsity_fast_path = "use_nested_tensor attribute not present" elif not self.use_nested_tensor: - why_not_sparsity_fast_path = "self.use_nested_tensor (set in init) was not True" + why_not_sparsity_fast_path = ( + "self.use_nested_tensor (set in init) was not True" + ) elif first_layer.training: why_not_sparsity_fast_path = f"{str_first_layer} was in training mode" elif not src.dim() == 3: - why_not_sparsity_fast_path = f"input not batched; expected src.dim() of 3 but got {src.dim()}" + why_not_sparsity_fast_path = ( + f"input not batched; expected src.dim() of 3 but got {src.dim()}" + ) elif src_key_padding_mask is None: why_not_sparsity_fast_path = "src_key_padding_mask was None" - elif (((not hasattr(self, "mask_check")) or self.mask_check) - and not torch._nested_tensor_from_mask_left_aligned(src, src_key_padding_mask.logical_not())): + elif ( + (not hasattr(self, "mask_check")) or self.mask_check + ) and not torch._nested_tensor_from_mask_left_aligned( + src, src_key_padding_mask.logical_not() + ): why_not_sparsity_fast_path = "mask_check enabled, and src and src_key_padding_mask was not left aligned" elif output.is_nested: why_not_sparsity_fast_path = "NestedTensor input is not supported" elif mask is not None: - why_not_sparsity_fast_path = "src_key_padding_mask and mask were both supplied" + why_not_sparsity_fast_path = ( + "src_key_padding_mask and mask were both supplied" + ) elif torch.is_autocast_enabled(): why_not_sparsity_fast_path = "autocast is enabled" @@ -395,28 +480,43 @@ def forward( first_layer.linear2.weight, first_layer.linear2.bias, ) - _supported_device_type = ["cpu", "cuda", torch.utils.backend_registration._privateuse1_backend_name] + _supported_device_type = [ + "cpu", + "cuda", + torch.utils.backend_registration._privateuse1_backend_name, + ] if torch.overrides.has_torch_function(tensor_args): why_not_sparsity_fast_path = "some Tensor argument has_torch_function" elif src.device.type not in _supported_device_type: - why_not_sparsity_fast_path = f"src device is neither one of {_supported_device_type}" + why_not_sparsity_fast_path = ( + f"src device is neither one of {_supported_device_type}" + ) elif torch.is_grad_enabled() and any(x.requires_grad for x in tensor_args): - why_not_sparsity_fast_path = ("grad is enabled and at least one of query or the " - "input/output projection weights or biases requires_grad") + why_not_sparsity_fast_path = ( + "grad is enabled and at least one of query or the " + "input/output projection weights or biases requires_grad" + ) if (not why_not_sparsity_fast_path) and (src_key_padding_mask is not None): convert_to_nested = True - output = torch._nested_tensor_from_mask(output, src_key_padding_mask.logical_not(), mask_check=False) + output = torch._nested_tensor_from_mask( + output, src_key_padding_mask.logical_not(), mask_check=False + ) src_key_padding_mask_for_layers = None seq_len = _get_seq_len(src, batch_first) is_causal = _detect_is_causal_mask(mask, is_causal, seq_len) for mod in self.layers: - output = mod(output, src_mask=mask, is_causal=is_causal, src_key_padding_mask=src_key_padding_mask_for_layers) + output = mod( + output, + src_mask=mask, + is_causal=is_causal, + src_key_padding_mask=src_key_padding_mask_for_layers, + ) if convert_to_nested: - output = output.to_padded_tensor(0., src.size()) + output = output.to_padded_tensor(0.0, src.size()) if self.norm is not None: output = self.norm(output) @@ -440,13 +540,13 @@ class TransformerDecoder(Module): >>> out = transformer_decoder(tgt, memory) """ - __constants__ = ['norm'] + __constants__ = ["norm"] def __init__( self, decoder_layer: "TransformerDecoderLayer", num_layers: int, - norm: Optional[Module] = None + norm: Optional[Module] = None, ) -> None: super().__init__() torch._C._log_api_usage_once(f"torch.nn.modules.{self.__class__.__name__}") @@ -454,10 +554,17 @@ def __init__( self.num_layers = num_layers self.norm = norm - def forward(self, tgt: Tensor, memory: Tensor, tgt_mask: Optional[Tensor] = None, - memory_mask: Optional[Tensor] = None, tgt_key_padding_mask: Optional[Tensor] = None, - memory_key_padding_mask: Optional[Tensor] = None, tgt_is_causal: Optional[bool] = None, - memory_is_causal: bool = False) -> Tensor: + def forward( + self, + tgt: Tensor, + memory: Tensor, + tgt_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + tgt_is_causal: Optional[bool] = None, + memory_is_causal: bool = False, + ) -> Tensor: r"""Pass the inputs (and mask) through the decoder layer in turn. Args: @@ -492,18 +599,23 @@ def forward(self, tgt: Tensor, memory: Tensor, tgt_mask: Optional[Tensor] = None tgt_is_causal = _detect_is_causal_mask(tgt_mask, tgt_is_causal, seq_len) for mod in self.layers: - output = mod(output, memory, tgt_mask=tgt_mask, - memory_mask=memory_mask, - tgt_key_padding_mask=tgt_key_padding_mask, - memory_key_padding_mask=memory_key_padding_mask, - tgt_is_causal=tgt_is_causal, - memory_is_causal=memory_is_causal) + output = mod( + output, + memory, + tgt_mask=tgt_mask, + memory_mask=memory_mask, + tgt_key_padding_mask=tgt_key_padding_mask, + memory_key_padding_mask=memory_key_padding_mask, + tgt_is_causal=tgt_is_causal, + memory_is_causal=memory_is_causal, + ) if self.norm is not None: output = self.norm(output) return output + class TransformerEncoderLayer(Module): r"""TransformerEncoderLayer is made up of self-attn and feedforward network. @@ -579,17 +691,32 @@ class TransformerEncoderLayer(Module): """ - __constants__ = ['norm_first'] + __constants__ = ["norm_first"] - def __init__(self, d_model: int, nhead: int, dim_feedforward: int = 2048, dropout: float = 0.1, - activation: Union[str, Callable[[Tensor], Tensor]] = F.relu, - layer_norm_eps: float = 1e-5, batch_first: bool = False, norm_first: bool = False, - bias: bool = True, device=None, dtype=None) -> None: - factory_kwargs = {'device': device, 'dtype': dtype} + def __init__( + self, + d_model: int, + nhead: int, + dim_feedforward: int = 2048, + dropout: float = 0.1, + activation: Union[str, Callable[[Tensor], Tensor]] = F.relu, + layer_norm_eps: float = 1e-5, + batch_first: bool = False, + norm_first: bool = False, + bias: bool = True, + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} super().__init__() - self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout, - bias=bias, batch_first=batch_first, - **factory_kwargs) + self.self_attn = MultiheadAttention( + d_model, + nhead, + dropout=dropout, + bias=bias, + batch_first=batch_first, + **factory_kwargs, + ) # Implementation of Feedforward model self.linear1 = Linear(d_model, dim_feedforward, bias=bias, **factory_kwargs) self.dropout = Dropout(dropout) @@ -617,16 +744,16 @@ def __init__(self, d_model: int, nhead: int, dim_feedforward: int = 2048, dropou def __setstate__(self, state): super().__setstate__(state) - if not hasattr(self, 'activation'): + if not hasattr(self, "activation"): self.activation = F.relu - def forward( - self, - src: Tensor, - src_mask: Optional[Tensor] = None, - src_key_padding_mask: Optional[Tensor] = None, - is_causal: bool = False) -> Tensor: + self, + src: Tensor, + src_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + is_causal: bool = False, + ) -> Tensor: r"""Pass the input through the encoder layer. Args: @@ -649,7 +776,7 @@ def forward( mask_name="src_key_padding_mask", other_type=F._none_or_dtype(src_mask), other_name="src_mask", - target_type=src.dtype + target_type=src.dtype, ) src_mask = F._canonical_mask( @@ -663,11 +790,15 @@ def forward( is_fastpath_enabled = torch.backends.mha.get_fastpath_enabled() - why_not_sparsity_fast_path = '' + why_not_sparsity_fast_path = "" if not is_fastpath_enabled: - why_not_sparsity_fast_path = "torch.backends.mha.get_fastpath_enabled() was not True" + why_not_sparsity_fast_path = ( + "torch.backends.mha.get_fastpath_enabled() was not True" + ) elif not src.dim() == 3: - why_not_sparsity_fast_path = f"input not batched; expected src.dim() of 3 but got {src.dim()}" + why_not_sparsity_fast_path = ( + f"input not batched; expected src.dim() of 3 but got {src.dim()}" + ) elif self.training: why_not_sparsity_fast_path = "training is enabled" elif not self.self_attn.batch_first: @@ -680,7 +811,9 @@ def forward( why_not_sparsity_fast_path = "activation_relu_or_gelu was not True" elif not (self.norm1.eps == self.norm2.eps): why_not_sparsity_fast_path = "norm1.eps is not equal to norm2.eps" - elif src.is_nested and (src_key_padding_mask is not None or src_mask is not None): + elif src.is_nested and ( + src_key_padding_mask is not None or src_mask is not None + ): why_not_sparsity_fast_path = "neither src_key_padding_mask nor src_mask are not supported with NestedTensor input" elif self.self_attn.num_heads % 2 == 1: why_not_sparsity_fast_path = "num_head is odd" @@ -705,18 +838,30 @@ def forward( # We have to use list comprehensions below because TorchScript does not support # generator expressions. - _supported_device_type = ["cpu", "cuda", torch.utils.backend_registration._privateuse1_backend_name] + _supported_device_type = [ + "cpu", + "cuda", + torch.utils.backend_registration._privateuse1_backend_name, + ] if torch.overrides.has_torch_function(tensor_args): why_not_sparsity_fast_path = "some Tensor argument has_torch_function" - elif not all((x.device.type in _supported_device_type) for x in tensor_args): - why_not_sparsity_fast_path = ("some Tensor argument's device is neither one of " - f"{_supported_device_type}") + elif not all( + (x.device.type in _supported_device_type) for x in tensor_args + ): + why_not_sparsity_fast_path = ( + "some Tensor argument's device is neither one of " + f"{_supported_device_type}" + ) elif torch.is_grad_enabled() and any(x.requires_grad for x in tensor_args): - why_not_sparsity_fast_path = ("grad is enabled and at least one of query or the " - "input/output projection weights or biases requires_grad") + why_not_sparsity_fast_path = ( + "grad is enabled and at least one of query or the " + "input/output projection weights or biases requires_grad" + ) if not why_not_sparsity_fast_path: - merged_mask, mask_type = self.self_attn.merge_masks(src_mask, src_key_padding_mask, src) + merged_mask, mask_type = self.self_attn.merge_masks( + src_mask, src_key_padding_mask, src + ) return torch._transformer_encoder_layer_fwd( src, self.self_attn.embed_dim, @@ -743,21 +888,36 @@ def forward( # see Fig. 1 of https://arxiv.org/pdf/2002.04745v1.pdf x = src if self.norm_first: - x = x + self._sa_block(self.norm1(x), src_mask, src_key_padding_mask, is_causal=is_causal) + x = x + self._sa_block( + self.norm1(x), src_mask, src_key_padding_mask, is_causal=is_causal + ) x = x + self._ff_block(self.norm2(x)) else: - x = self.norm1(x + self._sa_block(x, src_mask, src_key_padding_mask, is_causal=is_causal)) + x = self.norm1( + x + + self._sa_block(x, src_mask, src_key_padding_mask, is_causal=is_causal) + ) x = self.norm2(x + self._ff_block(x)) return x # self-attention block - def _sa_block(self, x: Tensor, - attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor], is_causal: bool = False) -> Tensor: - x = self.self_attn(x, x, x, - attn_mask=attn_mask, - key_padding_mask=key_padding_mask, - need_weights=False, is_causal=is_causal)[0] + def _sa_block( + self, + x: Tensor, + attn_mask: Optional[Tensor], + key_padding_mask: Optional[Tensor], + is_causal: bool = False, + ) -> Tensor: + x = self.self_attn( + x, + x, + x, + attn_mask=attn_mask, + key_padding_mask=key_padding_mask, + need_weights=False, + is_causal=is_causal, + )[0] return self.dropout1(x) # feed forward block @@ -804,18 +964,40 @@ class TransformerDecoderLayer(Module): >>> out = decoder_layer(tgt, memory) """ - __constants__ = ['norm_first'] + __constants__ = ["norm_first"] - def __init__(self, d_model: int, nhead: int, dim_feedforward: int = 2048, dropout: float = 0.1, - activation: Union[str, Callable[[Tensor], Tensor]] = F.relu, - layer_norm_eps: float = 1e-5, batch_first: bool = False, norm_first: bool = False, - bias: bool = True, device=None, dtype=None) -> None: - factory_kwargs = {'device': device, 'dtype': dtype} + def __init__( + self, + d_model: int, + nhead: int, + dim_feedforward: int = 2048, + dropout: float = 0.1, + activation: Union[str, Callable[[Tensor], Tensor]] = F.relu, + layer_norm_eps: float = 1e-5, + batch_first: bool = False, + norm_first: bool = False, + bias: bool = True, + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} super().__init__() - self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=batch_first, - bias=bias, **factory_kwargs) - self.multihead_attn = MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=batch_first, - bias=bias, **factory_kwargs) + self.self_attn = MultiheadAttention( + d_model, + nhead, + dropout=dropout, + batch_first=batch_first, + bias=bias, + **factory_kwargs, + ) + self.multihead_attn = MultiheadAttention( + d_model, + nhead, + dropout=dropout, + batch_first=batch_first, + bias=bias, + **factory_kwargs, + ) # Implementation of Feedforward model self.linear1 = Linear(d_model, dim_feedforward, bias=bias, **factory_kwargs) self.dropout = Dropout(dropout) @@ -836,8 +1018,8 @@ def __init__(self, d_model: int, nhead: int, dim_feedforward: int = 2048, dropou self.activation = activation def __setstate__(self, state): - if 'activation' not in state: - state['activation'] = F.relu + if "activation" not in state: + state["activation"] = F.relu super().__setstate__(state) def forward( @@ -883,34 +1065,68 @@ def forward( x = tgt if self.norm_first: - x = x + self._sa_block(self.norm1(x), tgt_mask, tgt_key_padding_mask, tgt_is_causal) - x = x + self._mha_block(self.norm2(x), memory, memory_mask, memory_key_padding_mask, memory_is_causal) + x = x + self._sa_block( + self.norm1(x), tgt_mask, tgt_key_padding_mask, tgt_is_causal + ) + x = x + self._mha_block( + self.norm2(x), + memory, + memory_mask, + memory_key_padding_mask, + memory_is_causal, + ) x = x + self._ff_block(self.norm3(x)) else: - x = self.norm1(x + self._sa_block(x, tgt_mask, tgt_key_padding_mask, tgt_is_causal)) - x = self.norm2(x + self._mha_block(x, memory, memory_mask, memory_key_padding_mask, memory_is_causal)) + x = self.norm1( + x + self._sa_block(x, tgt_mask, tgt_key_padding_mask, tgt_is_causal) + ) + x = self.norm2( + x + + self._mha_block( + x, memory, memory_mask, memory_key_padding_mask, memory_is_causal + ) + ) x = self.norm3(x + self._ff_block(x)) return x # self-attention block - def _sa_block(self, x: Tensor, - attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor], is_causal: bool = False) -> Tensor: - x = self.self_attn(x, x, x, - attn_mask=attn_mask, - key_padding_mask=key_padding_mask, - is_causal=is_causal, - need_weights=False)[0] + def _sa_block( + self, + x: Tensor, + attn_mask: Optional[Tensor], + key_padding_mask: Optional[Tensor], + is_causal: bool = False, + ) -> Tensor: + x = self.self_attn( + x, + x, + x, + attn_mask=attn_mask, + key_padding_mask=key_padding_mask, + is_causal=is_causal, + need_weights=False, + )[0] return self.dropout1(x) # multihead attention block - def _mha_block(self, x: Tensor, mem: Tensor, - attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor], is_causal: bool = False) -> Tensor: - x = self.multihead_attn(x, mem, mem, - attn_mask=attn_mask, - key_padding_mask=key_padding_mask, - is_causal=is_causal, - need_weights=False)[0] + def _mha_block( + self, + x: Tensor, + mem: Tensor, + attn_mask: Optional[Tensor], + key_padding_mask: Optional[Tensor], + is_causal: bool = False, + ) -> Tensor: + x = self.multihead_attn( + x, + mem, + mem, + attn_mask=attn_mask, + key_padding_mask=key_padding_mask, + is_causal=is_causal, + need_weights=False, + )[0] return self.dropout2(x) # feed forward block @@ -934,9 +1150,9 @@ def _get_activation_fn(activation: str) -> Callable[[Tensor], Tensor]: def _detect_is_causal_mask( - mask: Optional[Tensor], - is_causal: Optional[bool] = None, - size: Optional[int] = None, + mask: Optional[Tensor], + is_causal: Optional[bool] = None, + size: Optional[int] = None, ) -> bool: """Return whether the given attention mask is causal. @@ -958,12 +1174,13 @@ def _detect_is_causal_mask( Otherwise, checks for any causal mask. """ # Prevent type refinement - make_causal = (is_causal is True) + make_causal = is_causal is True if is_causal is None and mask is not None: sz = size if size is not None else mask.size(-2) causal_comparison = _generate_square_subsequent_mask( - sz, device=mask.device, dtype=mask.dtype) + sz, device=mask.device, dtype=mask.dtype + ) # Do not use `torch.equal` so we handle batched masks by # broadcasting the comparison. diff --git a/torch/nn/modules/upsampling.py b/torch/nn/modules/upsampling.py index 7d674da0d5c3ee..4943840152868a 100644 --- a/torch/nn/modules/upsampling.py +++ b/torch/nn/modules/upsampling.py @@ -1,12 +1,14 @@ # mypy: allow-untyped-defs -from .module import Module -from .. import functional as F +from typing import Optional +import torch.nn.functional as F from torch import Tensor -from typing import Optional -from ..common_types import _size_2_t, _ratio_2_t, _size_any_t, _ratio_any_t +from torch.nn.common_types import _ratio_2_t, _ratio_any_t, _size_2_t, _size_any_t + +from .module import Module + -__all__ = ['Upsample', 'UpsamplingNearest2d', 'UpsamplingBilinear2d'] +__all__ = ["Upsample", "UpsamplingNearest2d", "UpsamplingBilinear2d"] class Upsample(Module): @@ -132,7 +134,14 @@ class Upsample(Module): [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]]]]) """ - __constants__ = ['size', 'scale_factor', 'mode', 'align_corners', 'name', 'recompute_scale_factor'] + __constants__ = [ + "size", + "scale_factor", + "mode", + "align_corners", + "name", + "recompute_scale_factor", + ] name: str size: Optional[_size_any_t] scale_factor: Optional[_ratio_any_t] @@ -140,9 +149,14 @@ class Upsample(Module): align_corners: Optional[bool] recompute_scale_factor: Optional[bool] - def __init__(self, size: Optional[_size_any_t] = None, scale_factor: Optional[_ratio_any_t] = None, - mode: str = 'nearest', align_corners: Optional[bool] = None, - recompute_scale_factor: Optional[bool] = None) -> None: + def __init__( + self, + size: Optional[_size_any_t] = None, + scale_factor: Optional[_ratio_any_t] = None, + mode: str = "nearest", + align_corners: Optional[bool] = None, + recompute_scale_factor: Optional[bool] = None, + ) -> None: super().__init__() self.name = type(self).__name__ self.size = size @@ -155,21 +169,27 @@ def __init__(self, size: Optional[_size_any_t] = None, scale_factor: Optional[_r self.recompute_scale_factor = recompute_scale_factor def forward(self, input: Tensor) -> Tensor: - return F.interpolate(input, self.size, self.scale_factor, self.mode, self.align_corners, - recompute_scale_factor=self.recompute_scale_factor) + return F.interpolate( + input, + self.size, + self.scale_factor, + self.mode, + self.align_corners, + recompute_scale_factor=self.recompute_scale_factor, + ) def __setstate__(self, state): - if 'recompute_scale_factor' not in state: - state['recompute_scale_factor'] = True + if "recompute_scale_factor" not in state: + state["recompute_scale_factor"] = True super().__setstate__(state) def extra_repr(self) -> str: if self.scale_factor is not None: - info = 'scale_factor=' + repr(self.scale_factor) + info = "scale_factor=" + repr(self.scale_factor) else: - info = 'size=' + repr(self.size) - info += ', mode=' + repr(self.mode) + info = "size=" + repr(self.size) + info += ", mode=" + repr(self.mode) return info @@ -214,8 +234,12 @@ class UpsamplingNearest2d(Upsample): [3., 3., 4., 4.]]]]) """ - def __init__(self, size: Optional[_size_2_t] = None, scale_factor: Optional[_ratio_2_t] = None) -> None: - super().__init__(size, scale_factor, mode='nearest') + def __init__( + self, + size: Optional[_size_2_t] = None, + scale_factor: Optional[_ratio_2_t] = None, + ) -> None: + super().__init__(size, scale_factor, mode="nearest") class UpsamplingBilinear2d(Upsample): @@ -261,5 +285,9 @@ class UpsamplingBilinear2d(Upsample): [3.0000, 3.3333, 3.6667, 4.0000]]]]) """ - def __init__(self, size: Optional[_size_2_t] = None, scale_factor: Optional[_ratio_2_t] = None) -> None: - super().__init__(size, scale_factor, mode='bilinear', align_corners=True) + def __init__( + self, + size: Optional[_size_2_t] = None, + scale_factor: Optional[_ratio_2_t] = None, + ) -> None: + super().__init__(size, scale_factor, mode="bilinear", align_corners=True) diff --git a/torch/nn/modules/utils.py b/torch/nn/modules/utils.py index 4a051ed1eba5b1..767861dbc6cd52 100644 --- a/torch/nn/modules/utils.py +++ b/torch/nn/modules/utils.py @@ -1,9 +1,10 @@ # mypy: allow-untyped-defs import collections from itertools import repeat -from typing import List, Dict, Any +from typing import Any, Dict, List -__all__ = ['consume_prefix_in_state_dict_if_present'] + +__all__ = ["consume_prefix_in_state_dict_if_present"] def _ntuple(n, name="parse"): @@ -33,19 +34,19 @@ def _reverse_repeat_tuple(t, n): def _list_with_default(out_size: List[int], defaults: List[int]) -> List[int]: import torch + if isinstance(out_size, (int, torch.SymInt)): return out_size if len(defaults) <= len(out_size): - raise ValueError( - f"Input dimension should be at least {len(out_size) + 1}" - ) + raise ValueError(f"Input dimension should be at least {len(out_size) + 1}") return [ v if v is not None else d for v, d in zip(out_size, defaults[-len(out_size) :]) ] def consume_prefix_in_state_dict_if_present( - state_dict: Dict[str, Any], prefix: str + state_dict: Dict[str, Any], + prefix: str, ) -> None: r"""Strip the prefix in state_dict in place, if any. @@ -75,6 +76,6 @@ def consume_prefix_in_state_dict_if_present( if len(key) == 0: continue # handling both, 'module' case and 'module.' cases - if key == prefix.replace('.', '') or key.startswith(prefix): + if key == prefix.replace(".", "") or key.startswith(prefix): newkey = key[len(prefix) :] state_dict._metadata[newkey] = state_dict._metadata.pop(key) From f6e6e55fa7d883a89ba99584f8632c260519ba73 Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Mon, 17 Jun 2024 03:58:20 +0800 Subject: [PATCH 083/171] [BE] enable UFMT for `torch/nn/functional.py` (#128592) Part of #123062 - #123062 Pull Request resolved: https://github.com/pytorch/pytorch/pull/128592 Approved by: https://github.com/mikaylagawarecki ghstack dependencies: #128596, #128594 --- .lintrunner.toml | 1 - tools/pyi/gen_pyi.py | 8 +- torch/__init__.py | 1 + torch/_jit_internal.py | 33 +- torch/_library/utils.py | 16 +- torch/_linalg_utils.py | 12 +- torch/_lowrank.py | 10 +- torch/_meta_registrations.py | 103 ++- torch/_utils.py | 21 +- torch/_utils_internal.py | 5 +- torch/_vmap_internals.py | 12 +- torch/distributed/__init__.py | 9 +- torch/functional.py | 40 +- torch/hub.py | 24 +- torch/library.py | 6 +- torch/nn/functional.py | 1523 ++++++++++++++++++++++++--------- torch/overrides.py | 14 +- torch/serialization.py | 18 +- 18 files changed, 1362 insertions(+), 494 deletions(-) diff --git a/.lintrunner.toml b/.lintrunner.toml index 2093f258849019..ac5bbae1302170 100644 --- a/.lintrunner.toml +++ b/.lintrunner.toml @@ -1647,7 +1647,6 @@ exclude_patterns = [ 'torch/nn/_reduction.py', 'torch/nn/common_types.py', 'torch/nn/cpp.py', - 'torch/nn/functional.py', 'torch/nn/grad.py', 'torch/nn/init.py', 'torch/nn/intrinsic/__init__.py', diff --git a/tools/pyi/gen_pyi.py b/tools/pyi/gen_pyi.py index a021d702be410d..c218db8e1d0f49 100644 --- a/tools/pyi/gen_pyi.py +++ b/tools/pyi/gen_pyi.py @@ -505,7 +505,7 @@ def gen_nn_functional(fm: FileManager) -> None: "pdist", "cosine_similarity", ] - imported_hints = [f"from .. import {_} as {_}" for _ in torch_imports] + imported_hints = [f"from torch import {_} as {_}" for _ in torch_imports] # Functions imported into `torch.nn.functional` from `torch._C._nn` c_nn_imports = [ @@ -522,9 +522,11 @@ def gen_nn_functional(fm: FileManager) -> None: "one_hot", "scaled_dot_product_attention", ] - imported_hints += [f"from .._C._nn import {_} as {_}" for _ in c_nn_imports] + imported_hints += [f"from torch._C._nn import {_} as {_}" for _ in c_nn_imports] # This is from `torch._C._nn` but renamed - imported_hints.append("from .._C._nn import log_sigmoid\nlogsigmoid = log_sigmoid") + imported_hints.append( + "from torch._C._nn import log_sigmoid\nlogsigmoid = log_sigmoid" + ) # Functions generated by `torch._jit_internal.boolean_dispatch` in `nn.functional` unsorted_dispatched_hints: Dict[str, List[str]] = {} diff --git a/torch/__init__.py b/torch/__init__.py index 8be97423c43fc7..3253adcfb04296 100644 --- a/torch/__init__.py +++ b/torch/__init__.py @@ -2007,6 +2007,7 @@ def _assert(condition, message): backends as backends, cpu as cpu, cuda as cuda, + distributed as distributed, distributions as distributions, fft as fft, futures as futures, diff --git a/torch/_jit_internal.py b/torch/_jit_internal.py index 4ed425f0435a6e..f0bef92757760c 100644 --- a/torch/_jit_internal.py +++ b/torch/_jit_internal.py @@ -14,26 +14,24 @@ import io import pickle import sys +import textwrap import threading import types import typing import warnings import weakref -from textwrap import dedent -from typing import ( # noqa: F401 +from typing import ( Any, Callable, Dict, Final, ForwardRef, - Generic, - get_args, # new in 3.8 - get_origin, # new in 3.8 + get_args, + get_origin, List, Optional, Tuple, Type, - TypeVar, Union, ) @@ -42,7 +40,7 @@ # This is needed. `torch._jit_internal` is imported before `torch.distributed.__init__`. # Explicitly ask to import `torch.distributed.__init__` first. # Otherwise, "AttributeError: module 'torch' has no attribute 'distributed'" is raised. -import torch.distributed.rpc +import torch.distributed as dist import torch.package._mangling as package_mangling from torch._awaits import _Await from torch._C import _Await as CAwait, Future as CFuture @@ -405,7 +403,7 @@ def get_type_hint_captures(fn): # by source inspection. This accounts for the case in which aliases are used # to annotate the arguments (e.g device_t = torch.device, and then d: device_t). # frontend.py cannot be used here because it includes _jit_internal, so use ast instead. - a = ast.parse(dedent(src)) + a = ast.parse(textwrap.dedent(src)) if len(a.body) != 1 or not isinstance(a.body[0], ast.FunctionDef): raise RuntimeError(f"Expected {fn} to be a function") f = a.body[0] @@ -482,7 +480,13 @@ def lookup_in_class(key): def boolean_dispatch( - arg_name, arg_index, default, if_true, if_false, module_name, func_name + arg_name, + arg_index, + default, + if_true, + if_false, + module_name, + func_name, ): """ Dispatches to either of 2 script functions based on a boolean argument. @@ -1083,7 +1087,7 @@ def is_await(ann) -> bool: return get_origin(ann) is _Await -if torch.distributed.rpc.is_available(): +if dist.rpc.is_available(): from torch._C._distributed_rpc import PyRRef from torch.distributed.rpc import RRef @@ -1227,7 +1231,9 @@ def _try_get_dispatched_fn(fn): def _get_named_tuple_properties( - obj, loc: Optional[torch._C._jit_tree_views.SourceRange] = None, rcb=None + obj, + loc: Optional[torch._C._jit_tree_views.SourceRange] = None, + rcb=None, ): if loc is None: loc = fake_range() @@ -1307,7 +1313,10 @@ def _get_named_tuple_properties( def _create_named_tuple( - t, unqual_name: str, field_names: List[str], defaults: Tuple[Any, ...] + t, + unqual_name: str, + field_names: List[str], + defaults: Tuple[Any, ...], ): TupleType = collections.namedtuple(unqual_name, field_names, defaults=defaults) # type: ignore[call-arg, no-redef, misc] return TupleType(*t) diff --git a/torch/_library/utils.py b/torch/_library/utils.py index 27d1ef92b5b3dc..e9bf3ed4e4a7bc 100644 --- a/torch/_library/utils.py +++ b/torch/_library/utils.py @@ -5,8 +5,8 @@ from typing import Any, Callable, Dict, Iterable, Tuple import torch -import torch._utils_internal as _utils_internal -from torch import _C +from torch import _C, _utils_internal +from torch._ops import OpOverload @dataclasses.dataclass @@ -56,7 +56,7 @@ def parse_namespace(qualname: str) -> Tuple[str, str]: return splits[0], splits[1] -def lookup_op(qualname: str) -> torch._ops.OpOverload: +def lookup_op(qualname: str) -> OpOverload: namespace, name = parse_namespace(qualname) if "." in name: name, overload = name.split(".") @@ -67,8 +67,8 @@ def lookup_op(qualname: str) -> torch._ops.OpOverload: return getattr(packet, overload) -def is_builtin(op: torch._ops.OpOverload) -> bool: - assert isinstance(op, torch._ops.OpOverload) +def is_builtin(op: OpOverload) -> bool: + assert isinstance(op, OpOverload) return op.namespace in {"aten", "prim", "prims"} @@ -121,7 +121,7 @@ def is_tensor_like_type(typ: Any) -> bool: return typ == _C.TensorType.get() or typ == _C.OptionalType(_C.TensorType.get()) -def mutates_and_returns_first_arg(op: torch._ops.OpOverload): +def mutates_and_returns_first_arg(op: OpOverload): """Check if an op is an inplace aten op, i.e. it mutates and returns the first arg. TODO: torchgen/model.py's FunctionSchema.parse is the source of truth for this, @@ -201,8 +201,8 @@ def zip_schema( return -def can_generate_trivial_fake_impl(op: torch._ops.OpOverload) -> bool: - assert isinstance(op, torch._ops.OpOverload) +def can_generate_trivial_fake_impl(op: OpOverload) -> bool: + assert isinstance(op, OpOverload) if is_builtin(op): # We control the built-ins. These may (in rare cases) # do input metadata mutation (which we have banned on custom ops) diff --git a/torch/_linalg_utils.py b/torch/_linalg_utils.py index fd5f574ad7eb07..ae5c0e5d39cd2e 100644 --- a/torch/_linalg_utils.py +++ b/torch/_linalg_utils.py @@ -111,7 +111,11 @@ def lstsq(input: Tensor, A: Tensor, *, out=None) -> Tuple[Tensor, Tensor]: def _symeig( - input, eigenvectors=False, upper=True, *, out=None + input, + eigenvectors=False, + upper=True, + *, + out=None, ) -> Tuple[Tensor, Tensor]: raise RuntimeError( "This function was deprecated since version 1.9 and is now removed. " @@ -128,7 +132,11 @@ def _symeig( def eig( - self: Tensor, eigenvectors: bool = False, *, e=None, v=None + self: Tensor, + eigenvectors: bool = False, + *, + e=None, + v=None, ) -> Tuple[Tensor, Tensor]: raise RuntimeError( "This function was deprecated since version 1.9 and is now removed. " diff --git a/torch/_lowrank.py b/torch/_lowrank.py index bbe01ede68aeb6..994632f818dbae 100644 --- a/torch/_lowrank.py +++ b/torch/_lowrank.py @@ -11,7 +11,10 @@ def get_approximate_basis( - A: Tensor, q: int, niter: Optional[int] = 2, M: Optional[Tensor] = None + A: Tensor, + q: int, + niter: Optional[int] = 2, + M: Optional[Tensor] = None, ) -> Tensor: """Return tensor :math:`Q` with :math:`q` orthonormal columns such that :math:`Q Q^H A` approximates :math:`A`. If :math:`M` is @@ -180,7 +183,10 @@ def _svd_lowrank( def pca_lowrank( - A: Tensor, q: Optional[int] = None, center: bool = True, niter: int = 2 + A: Tensor, + q: Optional[int] = None, + center: bool = True, + niter: int = 2, ) -> Tuple[Tensor, Tensor, Tensor]: r"""Performs linear Principal Component Analysis (PCA) on a low-rank matrix, batches of such matrices, or sparse matrix. diff --git a/torch/_meta_registrations.py b/torch/_meta_registrations.py index 0abebc58cea897..20b71d7f1e062d 100644 --- a/torch/_meta_registrations.py +++ b/torch/_meta_registrations.py @@ -311,7 +311,12 @@ def meta_randperm(n, *, generator=None, out): @register_meta(aten.randperm.default) def meta_randperm_default( - n, *, dtype=torch.long, layout=None, device=None, pin_memory=None + n, + *, + dtype=torch.long, + layout=None, + device=None, + pin_memory=None, ): return torch.empty( n, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory @@ -321,7 +326,13 @@ def meta_randperm_default( @register_meta([aten.randint.default, aten.randint.out]) @out_wrapper() def meta_randint( - high, size, *, dtype=torch.long, layout=None, device=None, pin_memory=None + high, + size, + *, + dtype=torch.long, + layout=None, + device=None, + pin_memory=None, ): return torch.empty( size, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory @@ -775,7 +786,9 @@ def linearSolveCheckInputs( # From aten/src/ATen/native/LinearAlgebraUtils.h def checkFloatingOrComplex( - t: Tensor, f_name: str, allow_low_precision_dtypes: bool = True + t: Tensor, + f_name: str, + allow_low_precision_dtypes: bool = True, ): dtype = t.dtype torch._check( @@ -816,7 +829,10 @@ def checkInputsSolver( def checkSameDevice( - fn_name: str, result: Tensor, input: Tensor, result_name: str = "result" + fn_name: str, + result: Tensor, + input: Tensor, + result_name: str = "result", ): torch._check( result.device == input.device, @@ -1036,7 +1052,11 @@ def linalg_ldl_factor_ex_meta( @register_meta([aten.linalg_ldl_solve.default, aten.linalg_ldl_solve.out]) @out_wrapper() def linalg_ldl_solve_meta( - LD: Tensor, pivots: Tensor, B: Tensor, *, hermitian: bool = False + LD: Tensor, + pivots: Tensor, + B: Tensor, + *, + hermitian: bool = False, ) -> Tensor: squareCheckInputs(LD, "torch.linalg.ldl_solve") checkFloatingOrComplex(LD, "torch.linalg.ldl_solve") @@ -1104,7 +1124,10 @@ def linalg_lu_meta(A: Tensor, *, pivot: bool = True) -> Tuple[Tensor, Tensor, Te @register_meta([aten.linalg_lu_factor_ex.default, aten.linalg_lu_factor_ex.out]) @out_wrapper("LU", "pivots", "info") def linalg_lu_factor_ex_meta( - A: Tensor, *, pivot: bool = True, check_errors: bool = False + A: Tensor, + *, + pivot: bool = True, + check_errors: bool = False, ) -> Tuple[Tensor, Tensor, Tensor]: torch._check( A.ndim >= 2, @@ -1344,7 +1367,8 @@ def _linalg_svd_meta( def _linalg_broadcast_batch_dims( - arg1: Tensor, arg2: Tensor + arg1: Tensor, + arg2: Tensor, ) -> Tuple[List[int], List[int]]: # broadcast the batch dimensions of arg1 and arg2. arg1_batch_sizes = arg1.shape[:-2] @@ -1360,7 +1384,9 @@ def _linalg_broadcast_batch_dims( def _linalg_broadcast_batch_dims_name( - arg1: Tensor, arg2: Tensor, name: Optional[str] + arg1: Tensor, + arg2: Tensor, + name: Optional[str], ) -> Tuple[Tensor, Tensor]: # If there's no name we assume we don't want to check the errors if name: @@ -3722,7 +3748,13 @@ def div_rtn(x, y): def pooling_output_shape_pad_lr( - inputSize, kernelSize, pad_l, pad_r, stride, dilation, ceil_mode + inputSize, + kernelSize, + pad_l, + pad_r, + stride, + dilation, + ceil_mode, ): outputSize = ( div_rtn( @@ -4029,7 +4061,12 @@ def avg_pool3d_backward_shape_check( def max_pool2d_checks_and_compute_shape( - input, kernel_size, stride, padding, dilation, ceil_mode + input, + kernel_size, + stride, + padding, + dilation, + ceil_mode, ): # Reference: aten/src/ATen/native/DilatedMaxPool2d.cpp def unpack(name, val): @@ -4145,7 +4182,12 @@ def _check_dim_size(t): @register_meta(aten.max_pool2d_with_indices.default) def meta_max_pool2d_with_indices( - input, kernel_size, stride=(), padding=(0,), dilation=(1,), ceil_mode=False + input, + kernel_size, + stride=(), + padding=(0,), + dilation=(1,), + ceil_mode=False, ): ( nInputPlane, @@ -5516,7 +5558,12 @@ def meta_argsort(self, *, stable, dim=-1, descending=False): def rnn_cell_checkSizes( - input_gates, hidden_gates, input_bias, hidden_bias, factor, prev_hidden + input_gates, + hidden_gates, + input_bias, + hidden_bias, + factor, + prev_hidden, ): torch._check(input_gates.ndim == 2, lambda: f"{input_gates.ndim} != 2") torch._check( @@ -5551,7 +5598,11 @@ def rnn_cell_checkSizes( @register_meta(aten._thnn_fused_lstm_cell.default) def _thnn_fused_lstm_cell_meta( - input_gates, hidden_gates, cx, input_bias=None, hidden_bias=None + input_gates, + hidden_gates, + cx, + input_bias=None, + hidden_bias=None, ): rnn_cell_checkSizes(input_gates, hidden_gates, input_bias, hidden_bias, 4, cx) workspace = torch.empty_like(input_gates, memory_format=torch.contiguous_format) @@ -5862,7 +5913,11 @@ def meta_histc(input, bins=100, min=0, max=0): [aten._upsample_bilinear2d_aa.default, aten._upsample_bicubic2d_aa.default] ) def meta_upsample_bimode2d_aa( - input, output_size, align_corners, scales_h=None, scales_w=None + input, + output_size, + align_corners, + scales_h=None, + scales_w=None, ): full_output_size = upsample_common_check( input.size(), output_size, num_spatial_dims=2 @@ -5951,7 +6006,13 @@ def t_(self): @register_meta(aten.searchsorted) @out_wrapper() def meta_searchsorted( - sorted_sequence, self, *, out_int32=False, right=False, side=None, sorter=None + sorted_sequence, + self, + *, + out_int32=False, + right=False, + side=None, + sorter=None, ): dtype = torch.int32 if out_int32 else torch.int64 if isinstance(self, torch.Tensor): @@ -5993,7 +6054,13 @@ def meta_embedding_bag_dense_backward( @register_meta(aten._embedding_bag_per_sample_weights_backward) def meta_embedding_bag_per_sample_weights_backward( - grad, weight, indices, offsets, offset2bag, mode, padding_idx=-1 + grad, + weight, + indices, + offsets, + offset2bag, + mode, + padding_idx=-1, ): MODE_SUM, MODE_MEAN, MODE_MAX = range(3) embedding_features = grad.size(1) @@ -6079,7 +6146,9 @@ def meta__jagged_to_padded_dense_forward( @register_meta(aten._padded_dense_to_jagged_forward.default) def meta__padded_dense_to_jagged_forward( - padded: Tensor, offsets: List[Tensor], total_L: Optional[int] = None + padded: Tensor, + offsets: List[Tensor], + total_L: Optional[int] = None, ): # only one jagged dim is supported for now assert len(offsets) == 1 diff --git a/torch/_utils.py b/torch/_utils.py index 5096b62618df01..27f62f21f07543 100644 --- a/torch/_utils.py +++ b/torch/_utils.py @@ -195,7 +195,13 @@ def set_tensor_metadata(tensor, metadata): def _rebuild_tensor_v2( - storage, storage_offset, size, stride, requires_grad, backward_hooks, metadata=None + storage, + storage_offset, + size, + stride, + requires_grad, + backward_hooks, + metadata=None, ): tensor = _rebuild_tensor(storage, storage_offset, size, stride) tensor.requires_grad = requires_grad @@ -346,7 +352,14 @@ def _rebuild_meta_tensor_no_storage(dtype, size, stride, requires_grad): def _rebuild_wrapper_subclass( - cls, dtype, size, stride, storage_offset, layout, device, requires_grad + cls, + dtype, + size, + stride, + storage_offset, + layout, + device, + requires_grad, ): device = _get_restore_location(device) return torch.Tensor._make_wrapper_subclass( # type: ignore[attr-defined] @@ -762,7 +775,9 @@ def get_current_device_index() -> int: def _get_device_index( - device: Any, optional: bool = False, allow_cpu: bool = False + device: Any, + optional: bool = False, + allow_cpu: bool = False, ) -> int: r"""Gets the device index from :attr:`device`, which can be a torch.device object, a Python integer, or ``None``. diff --git a/torch/_utils_internal.py b/torch/_utils_internal.py index 0001888f18edc3..2d8dabbd28dbdf 100644 --- a/torch/_utils_internal.py +++ b/torch/_utils_internal.py @@ -133,7 +133,10 @@ def check_if_torch_exportable(): def log_torch_jit_trace_exportability( - api: str, type_of_export: str, export_outcome: str, result: str + api: str, + type_of_export: str, + export_outcome: str, + result: str, ): _, _, _, _ = api, type_of_export, export_outcome, result return diff --git a/torch/_vmap_internals.py b/torch/_vmap_internals.py index cc23d7851eb553..666cd7a84cc5f9 100644 --- a/torch/_vmap_internals.py +++ b/torch/_vmap_internals.py @@ -13,7 +13,8 @@ # Checks that all args-to-be-batched have the same batch dim size def _validate_and_get_batch_size( - flat_in_dims: List[Optional[int]], flat_args: List + flat_in_dims: List[Optional[int]], + flat_args: List, ) -> int: batch_sizes = [ arg.size(in_dim) @@ -37,7 +38,9 @@ def _num_outputs(batched_outputs: Union[Tensor, Tuple[Tensor, ...]]) -> int: # If value is a tuple, check it has length `num_elements`. # If value is not a tuple, make a tuple with `value` repeated `num_elements` times def _as_tuple( - value: Any, num_elements: int, error_message_lambda: Callable[[], str] + value: Any, + num_elements: int, + error_message_lambda: Callable[[], str], ) -> Tuple: if not isinstance(value, tuple): return (value,) * num_elements @@ -49,7 +52,10 @@ def _as_tuple( # Creates BatchedTensors for every Tensor in arg that should be batched. # Returns the (potentially) batched arguments and the batch_size. def _create_batched_inputs( - in_dims: in_dims_t, args: Tuple, vmap_level: int, func: Callable + in_dims: in_dims_t, + args: Tuple, + vmap_level: int, + func: Callable, ) -> Tuple[Tuple, int]: if not isinstance(in_dims, int) and not isinstance(in_dims, tuple): raise ValueError( diff --git a/torch/distributed/__init__.py b/torch/distributed/__init__.py index b8e911c8738c24..eb339000e89e7e 100644 --- a/torch/distributed/__init__.py +++ b/torch/distributed/__init__.py @@ -1,9 +1,6 @@ # mypy: allow-untyped-defs -import os import sys -from enum import Enum import pdb -import io import torch @@ -143,4 +140,8 @@ def breakpoint(rank: int = 0): class _ProcessGroupStub: pass - sys.modules["torch.distributed"].ProcessGroup = _ProcessGroupStub # type: ignore[attr-defined] + + ProcessGroup = _ProcessGroupStub # type: ignore[misc,assignment] + + +from torch.distributed import rpc as rpc diff --git a/torch/functional.py b/torch/functional.py index 20e1cf1faf7412..d7311f56954d7f 100644 --- a/torch/functional.py +++ b/torch/functional.py @@ -153,7 +153,9 @@ def broadcast_shapes(*shapes): def split( - tensor: Tensor, split_size_or_sections: Union[int, List[int]], dim: int = 0 + tensor: Tensor, + split_size_or_sections: Union[int, List[int]], + dim: int = 0, ) -> Tuple[Tensor, ...]: r"""Splits the tensor into chunks. Each chunk is a view of the original tensor. @@ -1043,7 +1045,11 @@ def _unique_consecutive_impl( def _return_counts( - input, sorted=True, return_inverse=False, return_counts=False, dim=None + input, + sorted=True, + return_inverse=False, + return_counts=False, + dim=None, ): # type: (Tensor, bool, bool, bool, Optional[int]) -> Tuple[Tensor, Tensor] @@ -1055,7 +1061,11 @@ def _return_counts( def _return_output( - input, sorted=True, return_inverse=False, return_counts=False, dim=None + input, + sorted=True, + return_inverse=False, + return_counts=False, + dim=None, ): # type: (Tensor, bool, bool, bool, Optional[int]) -> Tensor @@ -1067,7 +1077,11 @@ def _return_output( def _return_inverse( - input, sorted=True, return_inverse=False, return_counts=False, dim=None + input, + sorted=True, + return_inverse=False, + return_counts=False, + dim=None, ): # type: (Tensor, bool, bool, bool, Optional[int]) -> Tuple[Tensor, Tensor] @@ -1116,7 +1130,10 @@ def _return_inverse( def _consecutive_return_counts( - input, return_inverse=False, return_counts=False, dim=None + input, + return_inverse=False, + return_counts=False, + dim=None, ): # type: (Tensor, bool, bool, Optional[int]) -> Tuple[Tensor, Tensor] @@ -1130,7 +1147,10 @@ def _consecutive_return_counts( def _consecutive_return_output( - input, return_inverse=False, return_counts=False, dim=None + input, + return_inverse=False, + return_counts=False, + dim=None, ): # type: (Tensor, bool, bool, Optional[int]) -> Tensor @@ -1142,7 +1162,10 @@ def _consecutive_return_output( def _consecutive_return_inverse( - input, return_inverse=False, return_counts=False, dim=None + input, + return_inverse=False, + return_counts=False, + dim=None, ): # type: (Tensor, bool, bool, Optional[int]) -> Tuple[Tensor, Tensor] @@ -1866,7 +1889,8 @@ def norm( # noqa: F811 def unravel_index( - indices: Tensor, shape: Union[int, Sequence[int], torch.Size] + indices: Tensor, + shape: Union[int, Sequence[int], torch.Size], ) -> Tuple[Tensor, ...]: r"""Converts a tensor of flat indices into a tuple of coordinate tensors that index into an arbitrary tensor of the specified shape. diff --git a/torch/hub.py b/torch/hub.py index 57a07e2db4be0c..092032a921b027 100644 --- a/torch/hub.py +++ b/torch/hub.py @@ -215,7 +215,12 @@ def _validate_not_a_forked_repo(repo_owner, repo_name, ref): def _get_cache_or_reload( - github, force_reload, trust_repo, calling_fn, verbose=True, skip_validation=False + github, + force_reload, + trust_repo, + calling_fn, + verbose=True, + skip_validation=False, ): # Setup hub_dir to save downloaded files hub_dir = get_dir() @@ -294,7 +299,11 @@ def _get_cache_or_reload( def _check_repo_is_trusted( - repo_owner, repo_name, owner_name_branch, trust_repo, calling_fn="load" + repo_owner, + repo_name, + owner_name_branch, + trust_repo, + calling_fn="load", ): hub_dir = get_dir() filepath = os.path.join(hub_dir, "trusted_list") @@ -411,7 +420,11 @@ def set_dir(d): def list( - github, force_reload=False, skip_validation=False, trust_repo=None, verbose=True + github, + force_reload=False, + skip_validation=False, + trust_repo=None, + verbose=True, ): r""" List all callable entrypoints available in the repo specified by ``github``. @@ -664,7 +677,10 @@ def _load_local(hubconf_dir, model, *args, **kwargs): def download_url_to_file( - url: str, dst: str, hash_prefix: Optional[str] = None, progress: bool = True + url: str, + dst: str, + hash_prefix: Optional[str] = None, + progress: bool = True, ) -> None: r"""Download object at the given URL to a local path. diff --git a/torch/library.py b/torch/library.py index 319041d13b1210..e9884bd6f8fd33 100644 --- a/torch/library.py +++ b/torch/library.py @@ -330,7 +330,11 @@ def _destroy(self): def _del_library( - captured_impls, op_impls, captured_defs, op_defs, registration_handles + captured_impls, + op_impls, + captured_defs, + op_defs, + registration_handles, ): captured_impls -= op_impls captured_defs -= op_defs diff --git a/torch/nn/functional.py b/torch/nn/functional.py index f67e2ddee04ac9..f756f9e20b8879 100644 --- a/torch/nn/functional.py +++ b/torch/nn/functional.py @@ -1,37 +1,42 @@ """Functional interface.""" -from typing import Callable, List, Optional, Tuple, Union + +import importlib import math import warnings -import importlib - -try: - import numpy as np -except ModuleNotFoundError: - np = None +from typing import Callable, List, Optional, Tuple, TYPE_CHECKING, Union import torch -from torch import _VF -from torch import sym_int as _sym_int -from torch._C import _infer_size, _add_docstr -from torch._torch_docs import reproducibility_notes, tf32_notes, sparse_support_notes -# A workaround to support both TorchScript and MyPy: -from typing import TYPE_CHECKING +from torch import _VF, sym_int as _sym_int, Tensor +from torch._C import _add_docstr, _infer_size +from torch._jit_internal import ( + _overload, + boolean_dispatch, + BroadcastingList1, + BroadcastingList2, + BroadcastingList3, +) +from torch._torch_docs import reproducibility_notes, sparse_support_notes, tf32_notes +from torch.nn import _reduction as _Reduction, grad # noqa: F401 +from torch.nn.modules.utils import _list_with_default, _pair, _single, _triple +from torch.overrides import ( + handle_torch_function, + has_torch_function, + has_torch_function_unary, + has_torch_function_variadic, +) + + if TYPE_CHECKING: from torch.types import _dtype as DType else: # The JIT doesn't understand Union, nor torch.dtype here DType = int -from .._jit_internal import boolean_dispatch, _overload, BroadcastingList1, BroadcastingList2, BroadcastingList3 -from ..overrides import ( - has_torch_function, has_torch_function_unary, has_torch_function_variadic, - handle_torch_function) -from . import _reduction as _Reduction -from . import grad # noqa: F401 -from .modules import utils -from .modules.utils import _single, _pair, _triple, _list_with_default +try: + import numpy as np +except ModuleNotFoundError: + np = None -Tensor = torch.Tensor conv1d = _add_docstr( torch.conv1d, @@ -429,11 +434,12 @@ def fractional_max_pool2d_with_indices( - input: Tensor, kernel_size: BroadcastingList2[int], + input: Tensor, + kernel_size: BroadcastingList2[int], output_size: Optional[BroadcastingList2[int]] = None, output_ratio: Optional[BroadcastingList2[float]] = None, return_indices: bool = False, - _random_samples: Optional[Tensor] = None + _random_samples: Optional[Tensor] = None, ) -> Tuple[Tensor, Tensor]: # noqa: D400 r""" fractional_max_pool2d(input, kernel_size, output_size=None, output_ratio=None, return_indices=False, _random_samples=None) @@ -479,26 +485,38 @@ def fractional_max_pool2d_with_indices( _random_samples=_random_samples, ) if output_size is None and output_ratio is None: - raise ValueError("fractional_max_pool2d requires specifying either an output_size or an output_ratio") + raise ValueError( + "fractional_max_pool2d requires specifying either an output_size or an output_ratio" + ) if output_size is None: assert output_ratio is not None if len(output_ratio) > 2: - raise ValueError("fractional_max_pool2d requires output_ratio to either be a single Int or tuple of Ints.") + raise ValueError( + "fractional_max_pool2d requires output_ratio to either be a single Int or tuple of Ints." + ) _output_ratio = _pair(output_ratio) - output_size = [int(input.size(-2) * _output_ratio[0]), int(input.size(-1) * _output_ratio[1])] + output_size = [ + int(input.size(-2) * _output_ratio[0]), + int(input.size(-1) * _output_ratio[1]), + ] if _random_samples is None: n_batch = 1 if input.dim() == 3 else input.size(0) - _random_samples = torch.rand(n_batch, input.size(-3), 2, dtype=input.dtype, device=input.device) - return torch._C._nn.fractional_max_pool2d(input, kernel_size, output_size, _random_samples) + _random_samples = torch.rand( + n_batch, input.size(-3), 2, dtype=input.dtype, device=input.device + ) + return torch._C._nn.fractional_max_pool2d( + input, kernel_size, output_size, _random_samples + ) def _fractional_max_pool2d( - input: Tensor, kernel_size: BroadcastingList2[int], + input: Tensor, + kernel_size: BroadcastingList2[int], output_size: Optional[BroadcastingList2[int]] = None, output_ratio: Optional[BroadcastingList2[float]] = None, return_indices: bool = False, - _random_samples: Optional[Tensor] = None + _random_samples: Optional[Tensor] = None, ) -> Tensor: if has_torch_function_variadic(input, _random_samples): return handle_torch_function( @@ -528,11 +546,12 @@ def _fractional_max_pool2d( def fractional_max_pool3d_with_indices( - input: Tensor, kernel_size: BroadcastingList3[int], + input: Tensor, + kernel_size: BroadcastingList3[int], output_size: Optional[BroadcastingList3[int]] = None, output_ratio: Optional[BroadcastingList3[float]] = None, return_indices: bool = False, - _random_samples: Optional[Tensor] = None + _random_samples: Optional[Tensor] = None, ) -> Tuple[Tensor, Tensor]: # noqa: D400 r""" fractional_max_pool3d(input, kernel_size, output_size=None, output_ratio=None, return_indices=False, _random_samples=None) @@ -585,7 +604,9 @@ def fractional_max_pool3d_with_indices( _random_samples=_random_samples, ) if output_size is None and output_ratio is None: - raise ValueError("fractional_max_pool3d requires specifying either an output_size or an output_ratio") + raise ValueError( + "fractional_max_pool3d requires specifying either an output_size or an output_ratio" + ) if output_size is None: assert output_ratio is not None _output_ratio = _triple(output_ratio) @@ -597,16 +618,21 @@ def fractional_max_pool3d_with_indices( if _random_samples is None: n_batch = 1 if input.dim() == 4 else input.size(0) - _random_samples = torch.rand(n_batch, input.size(-4), 3, dtype=input.dtype, device=input.device) - return torch._C._nn.fractional_max_pool3d(input, kernel_size, output_size, _random_samples) + _random_samples = torch.rand( + n_batch, input.size(-4), 3, dtype=input.dtype, device=input.device + ) + return torch._C._nn.fractional_max_pool3d( + input, kernel_size, output_size, _random_samples + ) def _fractional_max_pool3d( - input: Tensor, kernel_size: BroadcastingList3[int], + input: Tensor, + kernel_size: BroadcastingList3[int], output_size: Optional[BroadcastingList3[int]] = None, output_ratio: Optional[BroadcastingList3[float]] = None, return_indices: bool = False, - _random_samples: Optional[Tensor] = None + _random_samples: Optional[Tensor] = None, ) -> Tensor: if has_torch_function_variadic(input, _random_samples): return handle_torch_function( @@ -636,12 +662,13 @@ def _fractional_max_pool3d( def max_pool1d_with_indices( - input: Tensor, kernel_size: BroadcastingList1[int], + input: Tensor, + kernel_size: BroadcastingList1[int], stride: Optional[BroadcastingList1[int]] = None, padding: BroadcastingList1[int] = 0, dilation: BroadcastingList1[int] = 1, ceil_mode: bool = False, - return_indices: bool = False + return_indices: bool = False, ) -> Tuple[Tensor, Tensor]: # noqa: D400 r""" max_pool1d(input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False, return_indices=False) @@ -682,16 +709,19 @@ def max_pool1d_with_indices( ) if stride is None: stride = torch.jit.annotate(List[int], []) - return torch.max_pool1d_with_indices(input, kernel_size, stride, padding, dilation, ceil_mode) + return torch.max_pool1d_with_indices( + input, kernel_size, stride, padding, dilation, ceil_mode + ) def _max_pool1d( - input: Tensor, kernel_size: BroadcastingList1[int], + input: Tensor, + kernel_size: BroadcastingList1[int], stride: Optional[BroadcastingList1[int]] = None, padding: BroadcastingList1[int] = 0, dilation: BroadcastingList1[int] = 1, ceil_mode: bool = False, - return_indices: bool = False + return_indices: bool = False, ) -> Tensor: if has_torch_function_unary(input): return handle_torch_function( @@ -722,12 +752,13 @@ def _max_pool1d( def max_pool2d_with_indices( - input: Tensor, kernel_size: BroadcastingList2[int], + input: Tensor, + kernel_size: BroadcastingList2[int], stride: Optional[BroadcastingList2[int]] = None, padding: BroadcastingList2[int] = 0, dilation: BroadcastingList2[int] = 1, ceil_mode: bool = False, - return_indices: bool = False + return_indices: bool = False, ) -> Tuple[Tensor, Tensor]: # noqa: D400 r""" max_pool2d(input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False, return_indices=False) @@ -768,16 +799,19 @@ def max_pool2d_with_indices( ) if stride is None: stride = torch.jit.annotate(List[int], []) - return torch._C._nn.max_pool2d_with_indices(input, kernel_size, stride, padding, dilation, ceil_mode) + return torch._C._nn.max_pool2d_with_indices( + input, kernel_size, stride, padding, dilation, ceil_mode + ) def _max_pool2d( - input: Tensor, kernel_size: BroadcastingList2[int], + input: Tensor, + kernel_size: BroadcastingList2[int], stride: Optional[BroadcastingList2[int]] = None, padding: BroadcastingList2[int] = 0, dilation: BroadcastingList2[int] = 1, ceil_mode: bool = False, - return_indices: bool = False + return_indices: bool = False, ) -> Tensor: if has_torch_function_unary(input): return handle_torch_function( @@ -808,12 +842,13 @@ def _max_pool2d( def max_pool3d_with_indices( - input: Tensor, kernel_size: BroadcastingList3[int], + input: Tensor, + kernel_size: BroadcastingList3[int], stride: Optional[BroadcastingList3[int]] = None, padding: BroadcastingList3[int] = 0, dilation: BroadcastingList3[int] = 1, ceil_mode: bool = False, - return_indices: bool = False + return_indices: bool = False, ) -> Tuple[Tensor, Tensor]: # noqa: D400 r""" max_pool3d(input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False, return_indices=False) @@ -854,16 +889,19 @@ def max_pool3d_with_indices( ) if stride is None: stride = torch.jit.annotate(List[int], []) - return torch._C._nn.max_pool3d_with_indices(input, kernel_size, stride, padding, dilation, ceil_mode) + return torch._C._nn.max_pool3d_with_indices( + input, kernel_size, stride, padding, dilation, ceil_mode + ) def _max_pool3d( - input: Tensor, kernel_size: BroadcastingList3[int], + input: Tensor, + kernel_size: BroadcastingList3[int], stride: Optional[BroadcastingList3[int]] = None, padding: BroadcastingList3[int] = 0, dilation: BroadcastingList3[int] = 1, ceil_mode: bool = False, - return_indices: bool = False + return_indices: bool = False, ) -> Tensor: if has_torch_function_unary(input): return handle_torch_function( @@ -894,12 +932,20 @@ def _max_pool3d( def _unpool_output_size( - input: Tensor, kernel_size: List[int], stride: List[int], padding: List[int], output_size: Optional[List[int]] + input: Tensor, + kernel_size: List[int], + stride: List[int], + padding: List[int], + output_size: Optional[List[int]], ) -> List[int]: input_size = input.size() default_size = torch.jit.annotate(List[int], []) for d in range(len(kernel_size)): - default_size.append((input_size[-len(kernel_size) + d] - 1) * stride[d] + kernel_size[d] - 2 * padding[d]) + default_size.append( + (input_size[-len(kernel_size) + d] - 1) * stride[d] + + kernel_size[d] + - 2 * padding[d] + ) if output_size is None: ret = default_size else: @@ -923,11 +969,12 @@ def _unpool_output_size( def max_unpool1d( - input: Tensor, indices: Tensor, + input: Tensor, + indices: Tensor, kernel_size: BroadcastingList1[int], stride: Optional[BroadcastingList1[int]] = None, padding: BroadcastingList1[int] = 0, - output_size: Optional[BroadcastingList1[int]] = None + output_size: Optional[BroadcastingList1[int]] = None, ) -> Tensor: r"""Compute a partial inverse of :class:`MaxPool1d`. @@ -955,15 +1002,18 @@ def max_unpool1d( output_size = output_size + [1] else: output_size = output_size + (1,) - return torch._C._nn.max_unpool2d(input.unsqueeze(-1), indices.unsqueeze(-1), output_size).squeeze(-1) + return torch._C._nn.max_unpool2d( + input.unsqueeze(-1), indices.unsqueeze(-1), output_size + ).squeeze(-1) def max_unpool2d( - input: Tensor, indices: Tensor, + input: Tensor, + indices: Tensor, kernel_size: BroadcastingList2[int], stride: Optional[BroadcastingList2[int]] = None, padding: BroadcastingList2[int] = 0, - output_size: Optional[BroadcastingList2[int]] = None + output_size: Optional[BroadcastingList2[int]] = None, ) -> Tensor: r"""Compute a partial inverse of :class:`MaxPool2d`. @@ -991,11 +1041,12 @@ def max_unpool2d( def max_unpool3d( - input: Tensor, indices: Tensor, + input: Tensor, + indices: Tensor, kernel_size: BroadcastingList3[int], stride: Optional[BroadcastingList3[int]] = None, padding: BroadcastingList3[int] = 0, - output_size: Optional[BroadcastingList3[int]] = None + output_size: Optional[BroadcastingList3[int]] = None, ) -> Tensor: r"""Compute a partial inverse of :class:`MaxPool3d`. @@ -1023,10 +1074,11 @@ def max_unpool3d( def lp_pool3d( - input: Tensor, norm_type: Union[int, float], + input: Tensor, + norm_type: Union[int, float], kernel_size: BroadcastingList3[int], stride: Optional[BroadcastingList3[int]] = None, - ceil_mode: bool = False + ceil_mode: bool = False, ) -> Tensor: r""" Apply a 3D power-average pooling over an input signal composed of several input planes. @@ -1038,22 +1090,33 @@ def lp_pool3d( """ if has_torch_function_unary(input): return handle_torch_function( - lp_pool3d, (input,), input, norm_type, kernel_size, stride=stride, ceil_mode=ceil_mode + lp_pool3d, + (input,), + input, + norm_type, + kernel_size, + stride=stride, + ceil_mode=ceil_mode, ) - kd, kw, kh = utils._triple(kernel_size) + kd, kw, kh = _triple(kernel_size) if stride is not None: out = avg_pool3d(input.pow(norm_type), kernel_size, stride, 0, ceil_mode) else: - out = avg_pool3d(input.pow(norm_type), kernel_size, padding=0, ceil_mode=ceil_mode) + out = avg_pool3d( + input.pow(norm_type), kernel_size, padding=0, ceil_mode=ceil_mode + ) - return (torch.sign(out) * relu(torch.abs(out))).mul(kd * kw * kh).pow(1.0 / norm_type) + return ( + (torch.sign(out) * relu(torch.abs(out))).mul(kd * kw * kh).pow(1.0 / norm_type) + ) def lp_pool2d( - input: Tensor, norm_type: Union[int, float], + input: Tensor, + norm_type: Union[int, float], kernel_size: BroadcastingList2[int], stride: Optional[BroadcastingList2[int]] = None, - ceil_mode: bool = False + ceil_mode: bool = False, ) -> Tensor: r""" Apply a 2D power-average pooling over an input signal composed of several input planes. @@ -1065,22 +1128,31 @@ def lp_pool2d( """ if has_torch_function_unary(input): return handle_torch_function( - lp_pool2d, (input,), input, norm_type, kernel_size, stride=stride, ceil_mode=ceil_mode + lp_pool2d, + (input,), + input, + norm_type, + kernel_size, + stride=stride, + ceil_mode=ceil_mode, ) - kw, kh = utils._pair(kernel_size) + kw, kh = _pair(kernel_size) if stride is not None: out = avg_pool2d(input.pow(norm_type), kernel_size, stride, 0, ceil_mode) else: - out = avg_pool2d(input.pow(norm_type), kernel_size, padding=0, ceil_mode=ceil_mode) + out = avg_pool2d( + input.pow(norm_type), kernel_size, padding=0, ceil_mode=ceil_mode + ) return (torch.sign(out) * relu(torch.abs(out))).mul(kw * kh).pow(1.0 / norm_type) def lp_pool1d( - input: Tensor, norm_type: Union[int, float], + input: Tensor, + norm_type: Union[int, float], kernel_size: int, stride: Optional[BroadcastingList1[int]] = None, - ceil_mode: bool = False + ceil_mode: bool = False, ) -> Tensor: r"""Apply a 1D power-average pooling over an input signal composed of several input planes. @@ -1091,18 +1163,30 @@ def lp_pool1d( """ if has_torch_function_unary(input): return handle_torch_function( - lp_pool1d, (input,), input, norm_type, kernel_size, stride=stride, ceil_mode=ceil_mode + lp_pool1d, + (input,), + input, + norm_type, + kernel_size, + stride=stride, + ceil_mode=ceil_mode, ) if stride is not None: out = avg_pool1d(input.pow(norm_type), kernel_size, stride, 0, ceil_mode) else: - out = avg_pool1d(input.pow(norm_type), kernel_size, padding=0, ceil_mode=ceil_mode) + out = avg_pool1d( + input.pow(norm_type), kernel_size, padding=0, ceil_mode=ceil_mode + ) - return (torch.sign(out) * relu(torch.abs(out))).mul(kernel_size).pow(1.0 / norm_type) + return ( + (torch.sign(out) * relu(torch.abs(out))).mul(kernel_size).pow(1.0 / norm_type) + ) def adaptive_max_pool1d_with_indices( - input: Tensor, output_size: BroadcastingList1[int], return_indices: bool = False + input: Tensor, + output_size: BroadcastingList1[int], + return_indices: bool = False, ) -> Tuple[Tensor, Tensor]: # noqa: D400 r""" adaptive_max_pool1d(input, output_size, return_indices=False) @@ -1118,15 +1202,27 @@ def adaptive_max_pool1d_with_indices( """ if has_torch_function_unary(input): return handle_torch_function( - adaptive_max_pool1d_with_indices, (input,), input, output_size, return_indices=return_indices + adaptive_max_pool1d_with_indices, + (input,), + input, + output_size, + return_indices=return_indices, ) return torch.adaptive_max_pool1d(input, output_size) -def _adaptive_max_pool1d(input: Tensor, output_size: BroadcastingList1[int], return_indices: bool = False) -> Tensor: +def _adaptive_max_pool1d( + input: Tensor, + output_size: BroadcastingList1[int], + return_indices: bool = False, +) -> Tensor: if has_torch_function_unary(input): return handle_torch_function( - adaptive_max_pool1d, (input,), input, output_size, return_indices=return_indices + adaptive_max_pool1d, + (input,), + input, + output_size, + return_indices=return_indices, ) return adaptive_max_pool1d_with_indices(input, output_size)[0] @@ -1143,8 +1239,9 @@ def _adaptive_max_pool1d(input: Tensor, output_size: BroadcastingList1[int], ret def adaptive_max_pool2d_with_indices( - input: Tensor, output_size: BroadcastingList2[int], - return_indices: bool = False + input: Tensor, + output_size: BroadcastingList2[int], + return_indices: bool = False, ) -> Tuple[Tensor, Tensor]: # noqa: D400 r"""adaptive_max_pool2d(input, output_size, return_indices=False) @@ -1160,16 +1257,28 @@ def adaptive_max_pool2d_with_indices( """ if has_torch_function_unary(input): return handle_torch_function( - adaptive_max_pool2d_with_indices, (input,), input, output_size, return_indices=return_indices + adaptive_max_pool2d_with_indices, + (input,), + input, + output_size, + return_indices=return_indices, ) output_size = _list_with_default(output_size, input.size()) return torch._C._nn.adaptive_max_pool2d(input, output_size) -def _adaptive_max_pool2d(input: Tensor, output_size: BroadcastingList2[int], return_indices: bool = False) -> Tensor: +def _adaptive_max_pool2d( + input: Tensor, + output_size: BroadcastingList2[int], + return_indices: bool = False, +) -> Tensor: if has_torch_function_unary(input): return handle_torch_function( - adaptive_max_pool2d, (input,), input, output_size, return_indices=return_indices + adaptive_max_pool2d, + (input,), + input, + output_size, + return_indices=return_indices, ) return adaptive_max_pool2d_with_indices(input, output_size)[0] @@ -1186,8 +1295,9 @@ def _adaptive_max_pool2d(input: Tensor, output_size: BroadcastingList2[int], ret def adaptive_max_pool3d_with_indices( - input: Tensor, output_size: BroadcastingList3[int], - return_indices: bool = False + input: Tensor, + output_size: BroadcastingList3[int], + return_indices: bool = False, ) -> Tuple[Tensor, Tensor]: # noqa: D400 r""" adaptive_max_pool3d(input, output_size, return_indices=False) @@ -1204,16 +1314,28 @@ def adaptive_max_pool3d_with_indices( """ if has_torch_function_unary(input): return handle_torch_function( - adaptive_max_pool3d_with_indices, (input,), input, output_size, return_indices=return_indices + adaptive_max_pool3d_with_indices, + (input,), + input, + output_size, + return_indices=return_indices, ) output_size = _list_with_default(output_size, input.size()) return torch._C._nn.adaptive_max_pool3d(input, output_size) -def _adaptive_max_pool3d(input: Tensor, output_size: BroadcastingList3[int], return_indices: bool = False) -> Tensor: +def _adaptive_max_pool3d( + input: Tensor, + output_size: BroadcastingList3[int], + return_indices: bool = False, +) -> Tensor: if has_torch_function_unary(input): return handle_torch_function( - adaptive_max_pool3d, (input,), input, output_size, return_indices=return_indices + adaptive_max_pool3d, + (input,), + input, + output_size, + return_indices=return_indices, ) return adaptive_max_pool3d_with_indices(input, output_size)[0] @@ -1276,7 +1398,12 @@ def adaptive_avg_pool3d(input: Tensor, output_size: BroadcastingList3[int]) -> T # Activation functions -def dropout(input: Tensor, p: float = 0.5, training: bool = True, inplace: bool = False) -> Tensor: +def dropout( + input: Tensor, + p: float = 0.5, + training: bool = True, + inplace: bool = False, +) -> Tensor: r"""During training, randomly zeroes some elements of the input tensor with probability :attr:`p`. Uses samples from a Bernoulli distribution. @@ -1289,25 +1416,45 @@ def dropout(input: Tensor, p: float = 0.5, training: bool = True, inplace: bool inplace: If set to ``True``, will do this operation in-place. Default: ``False`` """ if has_torch_function_unary(input): - return handle_torch_function(dropout, (input,), input, p=p, training=training, inplace=inplace) + return handle_torch_function( + dropout, (input,), input, p=p, training=training, inplace=inplace + ) if p < 0.0 or p > 1.0: raise ValueError(f"dropout probability has to be between 0 and 1, but got {p}") - return _VF.dropout_(input, p, training) if inplace else _VF.dropout(input, p, training) + return ( + _VF.dropout_(input, p, training) if inplace else _VF.dropout(input, p, training) + ) -def alpha_dropout(input: Tensor, p: float = 0.5, training: bool = False, inplace: bool = False) -> Tensor: +def alpha_dropout( + input: Tensor, + p: float = 0.5, + training: bool = False, + inplace: bool = False, +) -> Tensor: r"""Apply alpha dropout to the input. See :class:`~torch.nn.AlphaDropout` for details. """ if has_torch_function_unary(input): - return handle_torch_function(alpha_dropout, (input,), input, p=p, training=training, inplace=inplace) + return handle_torch_function( + alpha_dropout, (input,), input, p=p, training=training, inplace=inplace + ) if p < 0.0 or p > 1.0: raise ValueError(f"dropout probability has to be between 0 and 1, but got {p}") - return _VF.alpha_dropout_(input, p, training) if inplace else _VF.alpha_dropout(input, p, training) + return ( + _VF.alpha_dropout_(input, p, training) + if inplace + else _VF.alpha_dropout(input, p, training) + ) -def dropout1d(input: Tensor, p: float = 0.5, training: bool = True, inplace: bool = False) -> Tensor: +def dropout1d( + input: Tensor, + p: float = 0.5, + training: bool = True, + inplace: bool = False, +) -> Tensor: r"""Randomly zero out entire channels (a channel is a 1D feature map). For example, the :math:`j`-th channel of the :math:`i`-th sample in the @@ -1323,21 +1470,29 @@ def dropout1d(input: Tensor, p: float = 0.5, training: bool = True, inplace: boo inplace: If set to ``True``, will do this operation in-place. Default: ``False`` """ if has_torch_function_unary(input): - return handle_torch_function(dropout1d, (input,), input, p=p, training=training, inplace=inplace) + return handle_torch_function( + dropout1d, (input,), input, p=p, training=training, inplace=inplace + ) if p < 0.0 or p > 1.0: raise ValueError(f"dropout probability has to be between 0 and 1, but got {p}") inp_dim = input.dim() if inp_dim not in (2, 3): - raise RuntimeError(f"dropout1d: Expected 2D or 3D input, but received a {inp_dim}D input. " - "Note that dropout1d exists to provide channel-wise dropout on inputs with 1 " - "spatial dimension, a channel dimension, and an optional batch dimension " - "(i.e. 2D or 3D inputs).") + raise RuntimeError( + f"dropout1d: Expected 2D or 3D input, but received a {inp_dim}D input. " + "Note that dropout1d exists to provide channel-wise dropout on inputs with 1 " + "spatial dimension, a channel dimension, and an optional batch dimension " + "(i.e. 2D or 3D inputs)." + ) is_batched = inp_dim == 3 if not is_batched: input = input.unsqueeze_(0) if inplace else input.unsqueeze(0) - result = _VF.feature_dropout_(input, p, training) if inplace else _VF.feature_dropout(input, p, training) + result = ( + _VF.feature_dropout_(input, p, training) + if inplace + else _VF.feature_dropout(input, p, training) + ) if not is_batched: result = result.squeeze_(0) if inplace else result.squeeze(0) @@ -1345,7 +1500,12 @@ def dropout1d(input: Tensor, p: float = 0.5, training: bool = True, inplace: boo return result -def dropout2d(input: Tensor, p: float = 0.5, training: bool = True, inplace: bool = False) -> Tensor: +def dropout2d( + input: Tensor, + p: float = 0.5, + training: bool = True, + inplace: bool = False, +) -> Tensor: r"""Randomly zero out entire channels (a channel is a 2D feature map). For example, the :math:`j`-th channel of the :math:`i`-th sample in the @@ -1361,16 +1521,20 @@ def dropout2d(input: Tensor, p: float = 0.5, training: bool = True, inplace: boo inplace: If set to ``True``, will do this operation in-place. Default: ``False`` """ if has_torch_function_unary(input): - return handle_torch_function(dropout2d, (input,), input, p=p, training=training, inplace=inplace) + return handle_torch_function( + dropout2d, (input,), input, p=p, training=training, inplace=inplace + ) if p < 0.0 or p > 1.0: raise ValueError(f"dropout probability has to be between 0 and 1, but got {p}") inp_dim = input.dim() if inp_dim not in (3, 4): - warn_msg = (f"dropout2d: Received a {inp_dim}-D input to dropout2d, which is deprecated " - "and will result in an error in a future release. To retain the behavior " - "and silence this warning, please use dropout instead. Note that dropout2d " - "exists to provide channel-wise dropout on inputs with 2 spatial dimensions, " - "a channel dimension, and an optional batch dimension (i.e. 3D or 4D inputs).") + warn_msg = ( + f"dropout2d: Received a {inp_dim}-D input to dropout2d, which is deprecated " + "and will result in an error in a future release. To retain the behavior " + "and silence this warning, please use dropout instead. Note that dropout2d " + "exists to provide channel-wise dropout on inputs with 2 spatial dimensions, " + "a channel dimension, and an optional batch dimension (i.e. 3D or 4D inputs)." + ) warnings.warn(warn_msg) # TODO: Properly support no-batch-dim inputs. For now, these are NOT supported; passing @@ -1378,18 +1542,29 @@ def dropout2d(input: Tensor, p: float = 0.5, training: bool = True, inplace: boo # behavior is maintained here for now. # See https://github.com/pytorch/pytorch/issues/77081 if inp_dim == 3: - warnings.warn("dropout2d: Received a 3D input to dropout2d and assuming that channel-wise " - "1D dropout behavior is desired - input is interpreted as shape (N, C, L), where C " - "is the channel dim. This behavior will change in a future release to interpret the " - "input as one without a batch dimension, i.e. shape (C, H, W). To maintain the 1D " - "channel-wise dropout behavior, please switch to using dropout1d instead.") + warnings.warn( + "dropout2d: Received a 3D input to dropout2d and assuming that channel-wise " + "1D dropout behavior is desired - input is interpreted as shape (N, C, L), where C " + "is the channel dim. This behavior will change in a future release to interpret the " + "input as one without a batch dimension, i.e. shape (C, H, W). To maintain the 1D " + "channel-wise dropout behavior, please switch to using dropout1d instead." + ) - result = _VF.feature_dropout_(input, p, training) if inplace else _VF.feature_dropout(input, p, training) + result = ( + _VF.feature_dropout_(input, p, training) + if inplace + else _VF.feature_dropout(input, p, training) + ) return result -def dropout3d(input: Tensor, p: float = 0.5, training: bool = True, inplace: bool = False) -> Tensor: +def dropout3d( + input: Tensor, + p: float = 0.5, + training: bool = True, + inplace: bool = False, +) -> Tensor: r"""Randomly zero out entire channels (a channel is a 3D feature map). For example, the :math:`j`-th channel of the :math:`i`-th sample in the @@ -1405,30 +1580,43 @@ def dropout3d(input: Tensor, p: float = 0.5, training: bool = True, inplace: boo inplace: If set to ``True``, will do this operation in-place. Default: ``False`` """ if has_torch_function_unary(input): - return handle_torch_function(dropout3d, (input,), input, p=p, training=training, inplace=inplace) + return handle_torch_function( + dropout3d, (input,), input, p=p, training=training, inplace=inplace + ) if p < 0.0 or p > 1.0: raise ValueError(f"dropout probability has to be between 0 and 1, but got {p}") inp_dim = input.dim() if inp_dim not in (4, 5): - warn_msg = (f"dropout3d: Received a {inp_dim}-D input to dropout3d, which is deprecated " - "and will result in an error in a future release. To retain the behavior " - "and silence this warning, please use dropout instead. Note that dropout3d " - "exists to provide channel-wise dropout on inputs with 3 spatial dimensions, " - "a channel dimension, and an optional batch dimension (i.e. 4D or 5D inputs).") + warn_msg = ( + f"dropout3d: Received a {inp_dim}-D input to dropout3d, which is deprecated " + "and will result in an error in a future release. To retain the behavior " + "and silence this warning, please use dropout instead. Note that dropout3d " + "exists to provide channel-wise dropout on inputs with 3 spatial dimensions, " + "a channel dimension, and an optional batch dimension (i.e. 4D or 5D inputs)." + ) warnings.warn(warn_msg) is_batched = inp_dim == 5 if not is_batched: input = input.unsqueeze_(0) if inplace else input.unsqueeze(0) - result = _VF.feature_dropout_(input, p, training) if inplace else _VF.feature_dropout(input, p, training) + result = ( + _VF.feature_dropout_(input, p, training) + if inplace + else _VF.feature_dropout(input, p, training) + ) if not is_batched: result = result.squeeze_(0) if inplace else result.squeeze(0) return result -def feature_alpha_dropout(input: Tensor, p: float = 0.5, training: bool = False, inplace: bool = False) -> Tensor: +def feature_alpha_dropout( + input: Tensor, + p: float = 0.5, + training: bool = False, + inplace: bool = False, +) -> Tensor: r"""Randomly masks out entire channels (a channel is a feature map). For example, the :math:`j`-th channel of the :math:`i`-th sample in the batch input @@ -1450,20 +1638,36 @@ def feature_alpha_dropout(input: Tensor, p: float = 0.5, training: bool = False, """ if has_torch_function_unary(input): return handle_torch_function( - feature_alpha_dropout, (input,), input, p=p, training=training, inplace=inplace + feature_alpha_dropout, + (input,), + input, + p=p, + training=training, + inplace=inplace, ) if p < 0.0 or p > 1.0: raise ValueError(f"dropout probability has to be between 0 and 1, but got {p}") - return _VF.feature_alpha_dropout_(input, p, training) if inplace else _VF.feature_alpha_dropout(input, p, training) + return ( + _VF.feature_alpha_dropout_(input, p, training) + if inplace + else _VF.feature_alpha_dropout(input, p, training) + ) -def _threshold(input: Tensor, threshold: float, value: float, inplace: bool = False) -> Tensor: +def _threshold( + input: Tensor, + threshold: float, + value: float, + inplace: bool = False, +) -> Tensor: r"""Apply a threshold to each element of the input Tensor. See :class:`~torch.nn.Threshold` for more details. """ if has_torch_function_unary(input): - return handle_torch_function(_threshold, (input,), input, threshold, value, inplace=inplace) + return handle_torch_function( + _threshold, (input,), input, threshold, value, inplace=inplace + ) if inplace: result = _VF.threshold_(input, threshold, value) else: @@ -1532,11 +1736,18 @@ def glu(input: Tensor, dim: int = -1) -> Tensor: # noqa: D400,D402 if has_torch_function_unary(input): return handle_torch_function(glu, (input,), input, dim=dim) if input.dim() == 0: - raise RuntimeError("glu does not support scalars because halving size must be even") + raise RuntimeError( + "glu does not support scalars because halving size must be even" + ) return torch._C._nn.glu(input, dim) -def hardtanh(input: Tensor, min_val: float = -1., max_val: float = 1., inplace: bool = False) -> Tensor: # noqa: D400,D402 +def hardtanh( + input: Tensor, + min_val: float = -1.0, + max_val: float = 1.0, + inplace: bool = False, +) -> Tensor: # noqa: D400,D402 r""" hardtanh(input, min_val=-1., max_val=1., inplace=False) -> Tensor @@ -1544,7 +1755,9 @@ def hardtanh(input: Tensor, min_val: float = -1., max_val: float = 1., inplace: details. """ if has_torch_function_unary(input): - return handle_torch_function(hardtanh, (input,), input, min_val=min_val, max_val=max_val, inplace=inplace) + return handle_torch_function( + hardtanh, (input,), input, min_val=min_val, max_val=max_val, inplace=inplace + ) if min_val > max_val: raise ValueError("min_val cannot be greater than max_val") if inplace: @@ -1633,7 +1846,11 @@ def selu(input: Tensor, inplace: bool = False) -> Tensor: # noqa: D400,D402 ) -def celu(input: Tensor, alpha: float = 1.0, inplace: bool = False) -> Tensor: # noqa: D400,D402 +def celu( + input: Tensor, + alpha: float = 1.0, + inplace: bool = False, +) -> Tensor: # noqa: D400,D402 r"""celu(input, alpha=1., inplace=False) -> Tensor Applies element-wise, @@ -1642,7 +1859,9 @@ def celu(input: Tensor, alpha: float = 1.0, inplace: bool = False) -> Tensor: # See :class:`~torch.nn.CELU` for more details. """ if has_torch_function_unary(input): - return handle_torch_function(celu, (input,), input, alpha=alpha, inplace=inplace) + return handle_torch_function( + celu, (input,), input, alpha=alpha, inplace=inplace + ) if inplace: result = torch.celu_(input, alpha) else: @@ -1660,7 +1879,11 @@ def celu(input: Tensor, alpha: float = 1.0, inplace: bool = False) -> Tensor: # ) -def leaky_relu(input: Tensor, negative_slope: float = 0.01, inplace: bool = False) -> Tensor: # noqa: D400,D402 +def leaky_relu( + input: Tensor, + negative_slope: float = 0.01, + inplace: bool = False, +) -> Tensor: # noqa: D400,D402 r""" leaky_relu(input, negative_slope=0.01, inplace=False) -> Tensor @@ -1670,7 +1893,9 @@ def leaky_relu(input: Tensor, negative_slope: float = 0.01, inplace: bool = Fals See :class:`~torch.nn.LeakyReLU` for more details. """ if has_torch_function_unary(input): - return handle_torch_function(leaky_relu, (input,), input, negative_slope=negative_slope, inplace=inplace) + return handle_torch_function( + leaky_relu, (input,), input, negative_slope=negative_slope, inplace=inplace + ) if inplace: result = torch._C._nn.leaky_relu_(input, negative_slope) else: @@ -1705,11 +1930,16 @@ def leaky_relu(input: Tensor, negative_slope: float = 0.01, inplace: bool = Fals :ref:`broadcasting semantics`. See :class:`~torch.nn.PReLU` for more details. -""") +""", +) def rrelu( - input: Tensor, lower: float = 1.0 / 8, upper: float = 1.0 / 3, training: bool = False, inplace: bool = False + input: Tensor, + lower: float = 1.0 / 8, + upper: float = 1.0 / 3, + training: bool = False, + inplace: bool = False, ) -> Tensor: # noqa: D400,D402 r"""rrelu(input, lower=1./8, upper=1./3, training=False, inplace=False) -> Tensor @@ -1719,7 +1949,13 @@ def rrelu( """ if has_torch_function_unary(input): return handle_torch_function( - rrelu, (input,), input, lower=lower, upper=upper, training=training, inplace=inplace + rrelu, + (input,), + input, + lower=lower, + upper=upper, + training=training, + inplace=inplace, ) if inplace: result = torch.rrelu_(input, lower, upper, training) @@ -1764,7 +2000,8 @@ def rrelu( \text{GELU}(x) = 0.5 * x * (1 + \text{Tanh}(\sqrt{2 / \pi} * (x + 0.044715 * x^3))) See `Gaussian Error Linear Units (GELUs) `_. -""") +""", +) hardshrink = _add_docstr( torch.hardshrink, @@ -1774,7 +2011,8 @@ def rrelu( Applies the hard shrinkage function element-wise See :class:`~torch.nn.Hardshrink` for more details. -""") +""", +) def tanhshrink(input): # noqa: D400,D402 @@ -1829,7 +2067,12 @@ def _get_softmax_dim(name: str, ndim: int, stacklevel: int) -> int: return ret -def softmin(input: Tensor, dim: Optional[int] = None, _stacklevel: int = 3, dtype: Optional[DType] = None) -> Tensor: +def softmin( + input: Tensor, + dim: Optional[int] = None, + _stacklevel: int = 3, + dtype: Optional[DType] = None, +) -> Tensor: r"""Apply a softmin function. Note that :math:`\text{Softmin}(x) = \text{Softmax}(-x)`. See softmax definition for mathematical formula. @@ -1845,7 +2088,9 @@ def softmin(input: Tensor, dim: Optional[int] = None, _stacklevel: int = 3, dtyp is performed. This is useful for preventing data type overflows. Default: None. """ if has_torch_function_unary(input): - return handle_torch_function(softmin, (input,), input, dim=dim, _stacklevel=_stacklevel, dtype=dtype) + return handle_torch_function( + softmin, (input,), input, dim=dim, _stacklevel=_stacklevel, dtype=dtype + ) if dim is None: dim = _get_softmax_dim("softmin", input.dim(), _stacklevel) if dtype is None: @@ -1855,7 +2100,12 @@ def softmin(input: Tensor, dim: Optional[int] = None, _stacklevel: int = 3, dtyp return ret -def softmax(input: Tensor, dim: Optional[int] = None, _stacklevel: int = 3, dtype: Optional[DType] = None) -> Tensor: +def softmax( + input: Tensor, + dim: Optional[int] = None, + _stacklevel: int = 3, + dtype: Optional[DType] = None, +) -> Tensor: r"""Apply a softmax function. Softmax is defined as: @@ -1881,7 +2131,9 @@ def softmax(input: Tensor, dim: Optional[int] = None, _stacklevel: int = 3, dtyp """ if has_torch_function_unary(input): - return handle_torch_function(softmax, (input,), input, dim=dim, _stacklevel=_stacklevel, dtype=dtype) + return handle_torch_function( + softmax, (input,), input, dim=dim, _stacklevel=_stacklevel, dtype=dtype + ) if dim is None: dim = _get_softmax_dim("softmax", input.dim(), _stacklevel) if dtype is None: @@ -1891,7 +2143,13 @@ def softmax(input: Tensor, dim: Optional[int] = None, _stacklevel: int = 3, dtyp return ret -def gumbel_softmax(logits: Tensor, tau: float = 1, hard: bool = False, eps: float = 1e-10, dim: int = -1) -> Tensor: +def gumbel_softmax( + logits: Tensor, + tau: float = 1, + hard: bool = False, + eps: float = 1e-10, + dim: int = -1, +) -> Tensor: r""" Sample from the Gumbel-Softmax distribution (`Link 1`_ `Link 2`_) and optionally discretize. @@ -1932,12 +2190,16 @@ def gumbel_softmax(logits: Tensor, tau: float = 1, hard: bool = False, eps: floa https://arxiv.org/abs/1611.01144 """ if has_torch_function_unary(logits): - return handle_torch_function(gumbel_softmax, (logits,), logits, tau=tau, hard=hard, eps=eps, dim=dim) + return handle_torch_function( + gumbel_softmax, (logits,), logits, tau=tau, hard=hard, eps=eps, dim=dim + ) if eps != 1e-10: warnings.warn("`eps` parameter is deprecated and has no effect.") gumbels = ( - -torch.empty_like(logits, memory_format=torch.legacy_contiguous_format).exponential_().log() + -torch.empty_like(logits, memory_format=torch.legacy_contiguous_format) + .exponential_() + .log() ) # ~Gumbel(0,1) gumbels = (logits + gumbels) / tau # ~Gumbel(logits,tau) y_soft = gumbels.softmax(dim) @@ -1945,7 +2207,9 @@ def gumbel_softmax(logits: Tensor, tau: float = 1, hard: bool = False, eps: floa if hard: # Straight through. index = y_soft.max(dim, keepdim=True)[1] - y_hard = torch.zeros_like(logits, memory_format=torch.legacy_contiguous_format).scatter_(dim, index, 1.0) + y_hard = torch.zeros_like( + logits, memory_format=torch.legacy_contiguous_format + ).scatter_(dim, index, 1.0) ret = y_hard - y_soft.detach() + y_soft else: # Reparametrization trick. @@ -1953,7 +2217,12 @@ def gumbel_softmax(logits: Tensor, tau: float = 1, hard: bool = False, eps: floa return ret -def log_softmax(input: Tensor, dim: Optional[int] = None, _stacklevel: int = 3, dtype: Optional[DType] = None) -> Tensor: +def log_softmax( + input: Tensor, + dim: Optional[int] = None, + _stacklevel: int = 3, + dtype: Optional[DType] = None, +) -> Tensor: r"""Apply a softmax followed by a logarithm. While mathematically equivalent to log(softmax(x)), doing these two @@ -1970,7 +2239,9 @@ def log_softmax(input: Tensor, dim: Optional[int] = None, _stacklevel: int = 3, is performed. This is useful for preventing data type overflows. Default: None. """ if has_torch_function_unary(input): - return handle_torch_function(log_softmax, (input,), input, dim=dim, _stacklevel=_stacklevel, dtype=dtype) + return handle_torch_function( + log_softmax, (input,), input, dim=dim, _stacklevel=_stacklevel, dtype=dtype + ) if dim is None: dim = _get_softmax_dim("log_softmax", input.dim(), _stacklevel) if dtype is None: @@ -2055,7 +2326,10 @@ def hardsigmoid(input: Tensor, inplace: bool = False) -> Tensor: - Weight: :math:`(out\_features, in\_features)` or :math:`(in\_features)` - Bias: :math:`(out\_features)` or :math:`()` - Output: :math:`(*, out\_features)` or :math:`(*)`, based on the shape of the weight -""".format(**sparse_support_notes)) +""".format( + **sparse_support_notes + ), +) bilinear = _add_docstr( @@ -2077,7 +2351,8 @@ def hardsigmoid(input: Tensor, inplace: bool = False) -> Tensor: - bias: :math:`(\text{out\_features})` - output: :math:`(N, *, H_{out})` where :math:`H_{out}=\text{out\_features}` and all but the last dimension are the same shape as the input. -""") +""", +) def silu(input: Tensor, inplace: bool = False) -> Tensor: @@ -2150,7 +2425,12 @@ def hardswish(input: Tensor, inplace: bool = False) -> Tensor: return torch._C._nn.hardswish(input) -def _no_grad_embedding_renorm_(weight: Tensor, input: Tensor, max_norm: float, norm_type: float) -> Tuple[Tensor, Tensor]: +def _no_grad_embedding_renorm_( + weight: Tensor, + input: Tensor, + max_norm: float, + norm_type: float, +) -> Tuple[Tensor, Tensor]: torch.embedding_renorm_(weight.detach(), input, max_norm, norm_type) @@ -2246,9 +2526,13 @@ def embedding( ) if padding_idx is not None: if padding_idx > 0: - assert padding_idx < weight.size(0), "Padding_idx must be within num_embeddings" + assert padding_idx < weight.size( + 0 + ), "Padding_idx must be within num_embeddings" elif padding_idx < 0: - assert padding_idx >= -weight.size(0), "Padding_idx must be within num_embeddings" + assert padding_idx >= -weight.size( + 0 + ), "Padding_idx must be within num_embeddings" padding_idx = weight.size(0) + padding_idx else: padding_idx = -1 @@ -2409,7 +2693,9 @@ def embedding_bag( " fixed length sequences. However, found " f"offsets of type {type_str}" ) - offsets = torch.arange(0, input.numel(), input.size(1), dtype=input.dtype, device=input.device) + offsets = torch.arange( + 0, input.numel(), input.size(1), dtype=input.dtype, device=input.device + ) input = input.reshape(-1) if per_sample_weights is not None: @@ -2420,7 +2706,9 @@ def embedding_bag( if offsets.dim() != 1: raise ValueError("offsets has to be a 1D Tensor") else: - raise ValueError(f"input has to be 1D or 2D Tensor, but got Tensor of dimension {input.dim()}") + raise ValueError( + f"input has to be 1D or 2D Tensor, but got Tensor of dimension {input.dim()}" + ) if mode == "sum": mode_enum = 0 elif mode == "mean": @@ -2429,7 +2717,9 @@ def embedding_bag( mode_enum = 2 if scale_grad_by_freq: - raise ValueError("max mode does not support scaling the gradient by the frequency") + raise ValueError( + "max mode does not support scaling the gradient by the frequency" + ) if sparse: raise ValueError("max mode does not support sparse weights") @@ -2452,7 +2742,15 @@ def embedding_bag( ) ret, _, _, _ = torch.embedding_bag( - weight, input, offsets, scale_grad_by_freq, mode_enum, sparse, per_sample_weights, include_last_offset, padding_idx + weight, + input, + offsets, + scale_grad_by_freq, + mode_enum, + sparse, + per_sample_weights, + include_last_offset, + padding_idx, ) return ret @@ -2475,7 +2773,9 @@ def _verify_batch_size(size: List[int]) -> None: for i in range(len(size) - 2): size_prods *= size[i + 2] if size_prods == 1: - raise ValueError(f"Expected more than 1 value per channel when training, got input size {size}") + raise ValueError( + f"Expected more than 1 value per channel when training, got input size {size}" + ) def batch_norm( @@ -2510,7 +2810,15 @@ def batch_norm( _verify_batch_size(input.size()) return torch.batch_norm( - input, weight, bias, running_mean, running_var, training, momentum, eps, torch.backends.cudnn.enabled + input, + weight, + bias, + running_mean, + running_var, + training, + momentum, + eps, + torch.backends.cudnn.enabled, ) @@ -2520,7 +2828,9 @@ def _verify_spatial_size(size: List[int]) -> None: for i in range(2, len(size)): size_prods *= size[i] if size_prods == 1: - raise ValueError(f"Expected more than 1 spatial element when training, got input size {size}") + raise ValueError( + f"Expected more than 1 spatial element when training, got input size {size}" + ) def instance_norm( @@ -2554,7 +2864,15 @@ def instance_norm( if use_input_stats: _verify_spatial_size(input.size()) return torch.instance_norm( - input, weight, bias, running_mean, running_var, use_input_stats, momentum, eps, torch.backends.cudnn.enabled + input, + weight, + bias, + running_mean, + running_var, + use_input_stats, + momentum, + eps, + torch.backends.cudnn.enabled, ) @@ -2571,9 +2889,18 @@ def layer_norm( """ if has_torch_function_variadic(input, weight, bias): return handle_torch_function( - layer_norm, (input, weight, bias), input, normalized_shape, weight=weight, bias=bias, eps=eps + layer_norm, + (input, weight, bias), + input, + normalized_shape, + weight=weight, + bias=bias, + eps=eps, ) - return torch.layer_norm(input, normalized_shape, weight, bias, eps, torch.backends.cudnn.enabled) + return torch.layer_norm( + input, normalized_shape, weight, bias, eps, torch.backends.cudnn.enabled + ) + def rms_norm( input: Tensor, @@ -2591,22 +2918,52 @@ def rms_norm( ) return torch.rms_norm(input, normalized_shape, weight, eps) + def group_norm( - input: Tensor, num_groups: int, weight: Optional[Tensor] = None, bias: Optional[Tensor] = None, eps: float = 1e-5 + input: Tensor, + num_groups: int, + weight: Optional[Tensor] = None, + bias: Optional[Tensor] = None, + eps: float = 1e-5, ) -> Tensor: r"""Apply Group Normalization for last certain number of dimensions. See :class:`~torch.nn.GroupNorm` for details. """ if has_torch_function_variadic(input, weight, bias): - return handle_torch_function(group_norm, (input, weight, bias,), input, num_groups, weight=weight, bias=bias, eps=eps) + return handle_torch_function( + group_norm, + ( + input, + weight, + bias, + ), + input, + num_groups, + weight=weight, + bias=bias, + eps=eps, + ) if input.dim() < 2: - raise RuntimeError(f"Expected at least 2 dimensions for input tensor but received {input.dim()}") - _verify_batch_size([input.size(0) * input.size(1) // num_groups, num_groups] + list(input.size()[2:])) - return torch.group_norm(input, num_groups, weight, bias, eps, torch.backends.cudnn.enabled) + raise RuntimeError( + f"Expected at least 2 dimensions for input tensor but received {input.dim()}" + ) + _verify_batch_size( + [input.size(0) * input.size(1) // num_groups, num_groups] + + list(input.size()[2:]) + ) + return torch.group_norm( + input, num_groups, weight, bias, eps, torch.backends.cudnn.enabled + ) -def local_response_norm(input: Tensor, size: int, alpha: float = 1e-4, beta: float = 0.75, k: float = 1.0) -> Tensor: +def local_response_norm( + input: Tensor, + size: int, + alpha: float = 1e-4, + beta: float = 0.75, + k: float = 1.0, +) -> Tensor: r"""Apply local response normalization over an input signal. The input signal is composed of several input planes, where channels occupy the second dimension. @@ -2615,7 +2972,9 @@ def local_response_norm(input: Tensor, size: int, alpha: float = 1e-4, beta: flo See :class:`~torch.nn.LocalResponseNorm` for details. """ if has_torch_function_unary(input): - return handle_torch_function(local_response_norm, (input,), input, size, alpha=alpha, beta=beta, k=k) + return handle_torch_function( + local_response_norm, (input,), input, size, alpha=alpha, beta=beta, k=k + ) dim = input.dim() if dim < 3: raise ValueError( @@ -2699,11 +3058,22 @@ def ctc_loss( return handle_torch_function( ctc_loss, (log_probs, targets, input_lengths, target_lengths), - log_probs, targets, input_lengths, target_lengths, - blank=blank, reduction=reduction, zero_infinity=zero_infinity + log_probs, + targets, + input_lengths, + target_lengths, + blank=blank, + reduction=reduction, + zero_infinity=zero_infinity, ) return torch.ctc_loss( - log_probs, targets, input_lengths, target_lengths, blank, _Reduction.get_enum(reduction), zero_infinity + log_probs, + targets, + input_lengths, + target_lengths, + blank, + _Reduction.get_enum(reduction), + zero_infinity, ) @@ -2775,7 +3145,9 @@ def nll_loss( ) if size_average is not None or reduce is not None: reduction = _Reduction.legacy_get_string(size_average, reduce) - return torch._C._nn.nll_loss_nd(input, target, weight, _Reduction.get_enum(reduction), ignore_index) + return torch._C._nn.nll_loss_nd( + input, target, weight, _Reduction.get_enum(reduction), ignore_index + ) def poisson_nll_loss( @@ -2839,7 +3211,9 @@ def poisson_nll_loss( ret = input raise ValueError(reduction + " is not a valid value for reduction") - ret = torch.poisson_nll_loss(input, target, log_input, full, eps, _Reduction.get_enum(reduction)) + ret = torch.poisson_nll_loss( + input, target, log_input, full, eps, _Reduction.get_enum(reduction) + ) return ret @@ -2884,7 +3258,6 @@ def gaussian_nll_loss( # If var.size == input.size, the case is heteroscedastic and no further checks are needed. # Otherwise: if var.size() != input.size(): - # If var is one dimension short of input, but the sizes match otherwise, then this is a homoscedastic case. # e.g. input.size = (10, 2, 3), var.size = (10, 2) # -> unsqueeze var so that var.shape = (10, 2, 1) @@ -2895,7 +3268,9 @@ def gaussian_nll_loss( # This checks if the sizes match up to the final dimension, and the final dimension of var is of size 1. # This is also a homoscedastic case. # e.g. input.size = (10, 2, 3), var.size = (10, 2, 1) - elif input.size()[:-1] == var.size()[:-1] and var.size(-1) == 1: # Heteroscedastic case + elif ( + input.size()[:-1] == var.size()[:-1] and var.size(-1) == 1 + ): # Heteroscedastic case pass # If none of the above pass, then the size of var is incorrect. @@ -2903,7 +3278,7 @@ def gaussian_nll_loss( raise ValueError("var is of incorrect size") # Check validity of reduction mode - if reduction != 'none' and reduction != 'mean' and reduction != 'sum': + if reduction != "none" and reduction != "mean" and reduction != "sum": raise ValueError(reduction + " is not valid") # Entries of var must be non-negative @@ -2916,13 +3291,13 @@ def gaussian_nll_loss( var.clamp_(min=eps) # Calculate the loss - loss = 0.5 * (torch.log(var) + (input - target)**2 / var) + loss = 0.5 * (torch.log(var) + (input - target) ** 2 / var) if full: loss += 0.5 * math.log(2 * math.pi) - if reduction == 'mean': + if reduction == "mean": return loss.mean() - elif reduction == 'sum': + elif reduction == "sum": return loss.sum() else: return loss @@ -3101,7 +3476,14 @@ def cross_entropy( ) if size_average is not None or reduce is not None: reduction = _Reduction.legacy_get_string(size_average, reduce) - return torch._C._nn.cross_entropy_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index, label_smoothing) + return torch._C._nn.cross_entropy_loss( + input, + target, + weight, + _Reduction.get_enum(reduction), + ignore_index, + label_smoothing, + ) def binary_cross_entropy( @@ -3239,9 +3621,13 @@ def binary_cross_entropy_with_logits( reduction_enum = _Reduction.get_enum(reduction) if not (target.size() == input.size()): - raise ValueError(f"Target size ({target.size()}) must be the same as input size ({input.size()})") + raise ValueError( + f"Target size ({target.size()}) must be the same as input size ({input.size()})" + ) - return torch.binary_cross_entropy_with_logits(input, target, weight, pos_weight, reduction_enum) + return torch.binary_cross_entropy_with_logits( + input, target, weight, pos_weight, reduction_enum + ) def smooth_l1_loss( @@ -3283,15 +3669,19 @@ def smooth_l1_loss( expanded_input, expanded_target = torch.broadcast_tensors(input, target) if beta == 0.0: - return torch._C._nn.l1_loss(expanded_input, expanded_target, _Reduction.get_enum(reduction)) + return torch._C._nn.l1_loss( + expanded_input, expanded_target, _Reduction.get_enum(reduction) + ) else: - return torch._C._nn.smooth_l1_loss(expanded_input, expanded_target, _Reduction.get_enum(reduction), beta) + return torch._C._nn.smooth_l1_loss( + expanded_input, expanded_target, _Reduction.get_enum(reduction), beta + ) def huber_loss( input: Tensor, target: Tensor, - reduction: str = 'mean', + reduction: str = "mean", delta: float = 1.0, ) -> Tensor: r"""Compute the Huber loss. @@ -3314,13 +3704,17 @@ def huber_loss( delta=delta, ) if not (target.size() == input.size()): - warnings.warn(f"Using a target size ({target.size()}) that is different to the input size ({input.size()}). " - "This will likely lead to incorrect results due to broadcasting. " - "Please ensure they have the same size.", - stacklevel=2) + warnings.warn( + f"Using a target size ({target.size()}) that is different to the input size ({input.size()}). " + "This will likely lead to incorrect results due to broadcasting. " + "Please ensure they have the same size.", + stacklevel=2, + ) expanded_input, expanded_target = torch.broadcast_tensors(input, target) - return torch._C._nn.huber_loss(expanded_input, expanded_target, _Reduction.get_enum(reduction), delta) + return torch._C._nn.huber_loss( + expanded_input, expanded_target, _Reduction.get_enum(reduction), delta + ) def l1_loss( @@ -3338,7 +3732,13 @@ def l1_loss( """ if has_torch_function_variadic(input, target): return handle_torch_function( - l1_loss, (input, target), input, target, size_average=size_average, reduce=reduce, reduction=reduction + l1_loss, + (input, target), + input, + target, + size_average=size_average, + reduce=reduce, + reduction=reduction, ) if not (target.size() == input.size()): warnings.warn( @@ -3351,7 +3751,9 @@ def l1_loss( reduction = _Reduction.legacy_get_string(size_average, reduce) expanded_input, expanded_target = torch.broadcast_tensors(input, target) - return torch._C._nn.l1_loss(expanded_input, expanded_target, _Reduction.get_enum(reduction)) + return torch._C._nn.l1_loss( + expanded_input, expanded_target, _Reduction.get_enum(reduction) + ) def mse_loss( @@ -3368,7 +3770,13 @@ def mse_loss( """ if has_torch_function_variadic(input, target): return handle_torch_function( - mse_loss, (input, target), input, target, size_average=size_average, reduce=reduce, reduction=reduction + mse_loss, + (input, target), + input, + target, + size_average=size_average, + reduce=reduce, + reduction=reduction, ) if not (target.size() == input.size()): warnings.warn( @@ -3381,7 +3789,9 @@ def mse_loss( reduction = _Reduction.legacy_get_string(size_average, reduce) expanded_input, expanded_target = torch.broadcast_tensors(input, target) - return torch._C._nn.mse_loss(expanded_input, expanded_target, _Reduction.get_enum(reduction)) + return torch._C._nn.mse_loss( + expanded_input, expanded_target, _Reduction.get_enum(reduction) + ) def margin_ranking_loss( @@ -3413,7 +3823,7 @@ def margin_ranking_loss( reduction_enum = _Reduction.legacy_get_enum(size_average, reduce) else: reduction_enum = _Reduction.get_enum(reduction) - if (input1.dim() != input2.dim() or input1.dim() != target.dim()): + if input1.dim() != input2.dim() or input1.dim() != target.dim(): raise RuntimeError( f"margin_ranking_loss : All input tensors should have same dimension but got sizes: " f"input1: {input1.size()}, input2: {input2.size()}, target: {target.size()} " @@ -3493,7 +3903,13 @@ def soft_margin_loss( """ if has_torch_function_variadic(input, target): return handle_torch_function( - soft_margin_loss, (input, target), input, target, size_average=size_average, reduce=reduce, reduction=reduction + soft_margin_loss, + (input, target), + input, + target, + size_average=size_average, + reduce=reduce, + reduction=reduction, ) if size_average is not None or reduce is not None: reduction_enum = _Reduction.legacy_get_enum(size_average, reduce) @@ -3618,7 +4034,9 @@ def multi_margin_loss( if weight.dim() != 1: raise ValueError("weight must be one-dimensional") - return torch._C._nn.multi_margin_loss(input, target, p, margin, weight, reduction_enum) + return torch._C._nn.multi_margin_loss( + input, target, p, margin, weight, reduction_enum + ) pixel_shuffle = _add_docstr( @@ -3755,17 +4173,36 @@ def multi_margin_loss( """, ) -@_overload # noqa: F811 -def upsample(input: Tensor, size: Optional[int] = None, scale_factor: Optional[float] = None, mode: str = "nearest", align_corners: Optional[bool] = None) -> Tensor: # noqa: F811,B950 + +@_overload +def upsample( # noqa: F811 + input: Tensor, + size: Optional[int] = None, + scale_factor: Optional[float] = None, + mode: str = "nearest", + align_corners: Optional[bool] = None, +) -> Tensor: # noqa: B950 pass -@_overload # noqa: F811 -def upsample(input: Tensor, size: Optional[List[int]] = None, scale_factor: Optional[float] = None, mode: str = "nearest", align_corners: Optional[bool] = None) -> Tensor: # noqa: F811,B950 +@_overload +def upsample( # noqa: F811 + input: Tensor, + size: Optional[List[int]] = None, + scale_factor: Optional[float] = None, + mode: str = "nearest", + align_corners: Optional[bool] = None, +) -> Tensor: # noqa: B950 pass -def upsample(input, size=None, scale_factor=None, mode="nearest", align_corners=None): # noqa: F811 +def upsample( # noqa: F811 + input, + size=None, + scale_factor=None, + mode="nearest", + align_corners=None, +): r"""Upsample input. Provided tensor is upsampled to either the given :attr:`size` or the given @@ -3848,22 +4285,46 @@ def _is_integer(x) -> bool: return isinstance(x, Tensor) and not x.is_floating_point() -@_overload # noqa: F811 -def interpolate(input: Tensor, size: Optional[int] = None, scale_factor: Optional[List[float]] = None, mode: str = 'nearest', align_corners: Optional[bool] = None, recompute_scale_factor: Optional[bool] = None, antialias: bool = False) -> Tensor: # noqa: F811,B950 +@_overload +def interpolate( # noqa: F811 + input: Tensor, + size: Optional[int] = None, + scale_factor: Optional[List[float]] = None, + mode: str = "nearest", + align_corners: Optional[bool] = None, + recompute_scale_factor: Optional[bool] = None, + antialias: bool = False, +) -> Tensor: # noqa: B950 pass -@_overload # noqa: F811 -def interpolate(input: Tensor, size: Optional[List[int]] = None, scale_factor: Optional[List[float]] = None, mode: str = 'nearest', align_corners: Optional[bool] = None, recompute_scale_factor: Optional[bool] = None, antialias: bool = False) -> Tensor: # noqa: F811,B950 +@_overload +def interpolate( # noqa: F811 + input: Tensor, + size: Optional[List[int]] = None, + scale_factor: Optional[List[float]] = None, + mode: str = "nearest", + align_corners: Optional[bool] = None, + recompute_scale_factor: Optional[bool] = None, + antialias: bool = False, +) -> Tensor: # noqa: B950 pass -@_overload # noqa: F811 -def interpolate(input: Tensor, size: Optional[int] = None, scale_factor: Optional[float] = None, mode: str = 'nearest', align_corners: Optional[bool] = None, recompute_scale_factor: Optional[bool] = None, antialias: bool = False) -> Tensor: # noqa: F811,B950 +@_overload +def interpolate( # noqa: F811 + input: Tensor, + size: Optional[int] = None, + scale_factor: Optional[float] = None, + mode: str = "nearest", + align_corners: Optional[bool] = None, + recompute_scale_factor: Optional[bool] = None, + antialias: bool = False, +) -> Tensor: # noqa: B950 pass -@_overload # noqa: F811 +@_overload def interpolate( # noqa: F811 input: Tensor, size: Optional[List[int]] = None, @@ -3872,10 +4333,19 @@ def interpolate( # noqa: F811 align_corners: Optional[bool] = None, recompute_scale_factor: Optional[bool] = None, antialias: bool = False, -) -> Tensor: # noqa: F811 +) -> Tensor: pass -def interpolate(input: Tensor, size: Optional[int] = None, scale_factor: Optional[List[float]] = None, mode: str = 'nearest', align_corners: Optional[bool] = None, recompute_scale_factor: Optional[bool] = None, antialias: bool = False) -> Tensor: # noqa: F811,B950 + +def interpolate( # noqa: F811 + input: Tensor, + size: Optional[int] = None, + scale_factor: Optional[List[float]] = None, + mode: str = "nearest", + align_corners: Optional[bool] = None, + recompute_scale_factor: Optional[bool] = None, + antialias: bool = False, +) -> Tensor: # noqa: B950 r"""Down/up samples the input. Tensor interpolated to either the given :attr:`size` or the given @@ -3954,7 +4424,7 @@ def interpolate(input: Tensor, size: Optional[int] = None, scale_factor: Optiona mode=mode, align_corners=align_corners, recompute_scale_factor=recompute_scale_factor, - antialias=antialias + antialias=antialias, ) if mode in ("nearest", "area", "nearest-exact"): @@ -4013,8 +4483,14 @@ def interpolate(input: Tensor, size: Optional[int] = None, scale_factor: Optiona else: raise ValueError("either size or scale_factor should be defined") - if recompute_scale_factor is not None and recompute_scale_factor and size is not None: - raise ValueError("recompute_scale_factor is not meaningful with an explicit size.") + if ( + recompute_scale_factor is not None + and recompute_scale_factor + and size is not None + ): + raise ValueError( + "recompute_scale_factor is not meaningful with an explicit size." + ) # "area" mode always requires an explicit size rather than scale factor. # Re-use the recompute_scale_factor code path. @@ -4028,21 +4504,31 @@ def interpolate(input: Tensor, size: Optional[int] = None, scale_factor: Optiona if not torch.jit.is_scripting() and torch._C._get_tracing_state(): # make scale_factor a tensor in tracing so constant doesn't get baked in output_size = [ - (torch.floor((input.size(i + 2).float() * torch.tensor(scale_factors[i], dtype=torch.float32)).float())) + ( + torch.floor( + ( + input.size(i + 2).float() + * torch.tensor(scale_factors[i], dtype=torch.float32) + ).float() + ) + ) for i in range(dim) ] elif torch.jit.is_scripting(): - output_size = [int(math.floor(float(input.size(i + 2)) * scale_factors[i])) - for i in range(dim)] - else: output_size = [ - _sym_int(input.size(i + 2) * scale_factors[i]) + int(math.floor(float(input.size(i + 2)) * scale_factors[i])) for i in range(dim) ] + else: + output_size = [ + _sym_int(input.size(i + 2) * scale_factors[i]) for i in range(dim) + ] scale_factors = None if antialias and not (mode in ("bilinear", "bicubic") and input.ndim == 4): - raise ValueError("Anti-alias option is restricted to bilinear and bicubic modes and requires a 4-D tensor as input") + raise ValueError( + "Anti-alias option is restricted to bilinear and bicubic modes and requires a 4-D tensor as input" + ) if input.dim() == 3 and mode == "nearest": return torch._C._nn.upsample_nearest1d(input, output_size, scale_factors) @@ -4070,11 +4556,15 @@ def interpolate(input: Tensor, size: Optional[int] = None, scale_factor: Optiona if input.dim() == 3 and mode == "linear": assert align_corners is not None - return torch._C._nn.upsample_linear1d(input, output_size, align_corners, scale_factors) + return torch._C._nn.upsample_linear1d( + input, output_size, align_corners, scale_factors + ) if input.dim() == 4 and mode == "bilinear": assert align_corners is not None if antialias: - return torch._C._nn._upsample_bilinear2d_aa(input, output_size, align_corners, scale_factors) + return torch._C._nn._upsample_bilinear2d_aa( + input, output_size, align_corners, scale_factors + ) # Two levels are necessary to prevent TorchScript from touching # are_deterministic_algorithms_enabled. if not torch.jit.is_scripting(): @@ -4082,17 +4572,26 @@ def interpolate(input: Tensor, size: Optional[int] = None, scale_factor: Optiona # Use slow decomp whose backward will be in terms of index_put # importlib is required because the import cannot be top level # (cycle) and cannot be nested (TS doesn't support) - return importlib.import_module('torch._decomp.decompositions')._upsample_linear_vec( - input, output_size, align_corners, scale_factors) - return torch._C._nn.upsample_bilinear2d(input, output_size, align_corners, scale_factors) + return importlib.import_module( + "torch._decomp.decompositions" + )._upsample_linear_vec(input, output_size, align_corners, scale_factors) + return torch._C._nn.upsample_bilinear2d( + input, output_size, align_corners, scale_factors + ) if input.dim() == 5 and mode == "trilinear": assert align_corners is not None - return torch._C._nn.upsample_trilinear3d(input, output_size, align_corners, scale_factors) + return torch._C._nn.upsample_trilinear3d( + input, output_size, align_corners, scale_factors + ) if input.dim() == 4 and mode == "bicubic": assert align_corners is not None if antialias: - return torch._C._nn._upsample_bicubic2d_aa(input, output_size, align_corners, scale_factors) - return torch._C._nn.upsample_bicubic2d(input, output_size, align_corners, scale_factors) + return torch._C._nn._upsample_bicubic2d_aa( + input, output_size, align_corners, scale_factors + ) + return torch._C._nn.upsample_bicubic2d( + input, output_size, align_corners, scale_factors + ) if input.dim() == 3 and mode == "bilinear": raise NotImplementedError("Got 3D input, but bilinear mode needs 4D input") @@ -4118,13 +4617,21 @@ def interpolate(input: Tensor, size: Optional[int] = None, scale_factor: Optiona interpolate.__doc__ = interpolate.__doc__.format(**reproducibility_notes) -@_overload # noqa: F811 -def upsample_nearest(input: Tensor, size: Optional[int] = None, scale_factor: Optional[float] = None) -> Tensor: # noqa: F811 +@_overload +def upsample_nearest( # noqa: F811 + input: Tensor, + size: Optional[int] = None, + scale_factor: Optional[float] = None, +) -> Tensor: pass -@_overload # noqa: F811 -def upsample_nearest(input: Tensor, size: Optional[List[int]] = None, scale_factor: Optional[float] = None) -> Tensor: # noqa: F811 +@_overload +def upsample_nearest( # noqa: F811 + input: Tensor, + size: Optional[List[int]] = None, + scale_factor: Optional[float] = None, +) -> Tensor: pass @@ -4160,31 +4667,39 @@ def upsample_nearest(input, size=None, scale_factor=None): # noqa: F811 upsample_nearest.__doc__ = upsample_nearest.__doc__.format(**reproducibility_notes) -@_overload # noqa: F811 -def upsample_bilinear( - input: Tensor, size: Optional[int] = None, scale_factor: Optional[float] = None -) -> Tensor: # noqa: F811 +@_overload +def upsample_bilinear( # noqa: F811 + input: Tensor, + size: Optional[int] = None, + scale_factor: Optional[float] = None, +) -> Tensor: pass -@_overload # noqa: F811 +@_overload def upsample_bilinear( # noqa: F811 - input: Tensor, size: Optional[List[int]] = None, scale_factor: Optional[float] = None -) -> Tensor: # noqa: F811 + input: Tensor, + size: Optional[List[int]] = None, + scale_factor: Optional[float] = None, +) -> Tensor: pass -@_overload # noqa: F811 +@_overload def upsample_bilinear( # noqa: F811 - input: Tensor, size: Optional[int] = None, scale_factor: Optional[List[float]] = None -) -> Tensor: # noqa: F811 + input: Tensor, + size: Optional[int] = None, + scale_factor: Optional[List[float]] = None, +) -> Tensor: pass -@_overload # noqa: F811 +@_overload def upsample_bilinear( # noqa: F811 - input: Tensor, size: Optional[List[int]] = None, scale_factor: Optional[List[float]] = None -) -> Tensor: # noqa: F811 + input: Tensor, + size: Optional[List[int]] = None, + scale_factor: Optional[List[float]] = None, +) -> Tensor: pass @@ -4217,7 +4732,9 @@ def upsample_bilinear(input, size=None, scale_factor=None): # noqa: F811 if upsample_bilinear.__doc__: - upsample_bilinear.__doc__ = upsample_bilinear.__doc__.format(**reproducibility_notes) + upsample_bilinear.__doc__ = upsample_bilinear.__doc__.format( + **reproducibility_notes + ) GRID_SAMPLE_INTERPOLATION_MODES = { "bilinear": 0, @@ -4342,13 +4859,23 @@ def grid_sample( """ if has_torch_function_variadic(input, grid): return handle_torch_function( - grid_sample, (input, grid), input, grid, mode=mode, padding_mode=padding_mode, align_corners=align_corners + grid_sample, + (input, grid), + input, + grid, + mode=mode, + padding_mode=padding_mode, + align_corners=align_corners, ) if mode != "bilinear" and mode != "nearest" and mode != "bicubic": raise ValueError( f"nn.functional.grid_sample(): expected mode to be 'bilinear', 'nearest' or 'bicubic', but got: '{mode}'" ) - if padding_mode != "zeros" and padding_mode != "border" and padding_mode != "reflection": + if ( + padding_mode != "zeros" + and padding_mode != "border" + and padding_mode != "reflection" + ): raise ValueError( "nn.functional.grid_sample(): expected padding_mode " "to be 'zeros', 'border', or 'reflection', " @@ -4381,7 +4908,11 @@ def grid_sample( return torch.grid_sampler(input, grid, mode_enum, padding_mode_enum, align_corners) -def affine_grid(theta: Tensor, size: List[int], align_corners: Optional[bool] = None) -> Tensor: +def affine_grid( + theta: Tensor, + size: List[int], + align_corners: Optional[bool] = None, +) -> Tensor: r"""Generate 2D or 3D flow field (sampling grid), given a batch of affine matrices :attr:`theta`. .. note:: @@ -4429,7 +4960,9 @@ def affine_grid(theta: Tensor, size: List[int], align_corners: Optional[bool] = (the center of the input image). """ if has_torch_function_unary(theta): - return handle_torch_function(affine_grid, (theta,), theta, size, align_corners=align_corners) + return handle_torch_function( + affine_grid, (theta,), theta, size, align_corners=align_corners + ) if align_corners is None: warnings.warn( "Default grid_sample and affine_grid behavior has changed " @@ -4441,7 +4974,9 @@ def affine_grid(theta: Tensor, size: List[int], align_corners: Optional[bool] = # enforce floating point dtype on theta if not theta.is_floating_point(): - raise ValueError(f"Expected theta to have floating point type, but got {theta.dtype}") + raise ValueError( + f"Expected theta to have floating point type, but got {theta.dtype}" + ) # check that shapes and sizes match if len(size) == 4: if theta.dim() != 3 or theta.shape[-2] != 2 or theta.shape[-1] != 3: @@ -4475,82 +5010,88 @@ def affine_grid(theta: Tensor, size: List[int], align_corners: Optional[bool] = return torch.affine_grid_generator(theta, size, align_corners) -def pad(input: Tensor, pad: List[int], mode: str = "constant", value: Optional[float] = None) -> Tensor: +def pad( + input: Tensor, + pad: List[int], + mode: str = "constant", + value: Optional[float] = None, +) -> Tensor: r""" -pad(input, pad, mode="constant", value=None) -> Tensor - -Pads tensor. - -Padding size: - The padding size by which to pad some dimensions of :attr:`input` - are described starting from the last dimension and moving forward. - :math:`\left\lfloor\frac{\text{len(pad)}}{2}\right\rfloor` dimensions - of ``input`` will be padded. - For example, to pad only the last dimension of the input tensor, then - :attr:`pad` has the form - :math:`(\text{padding\_left}, \text{padding\_right})`; - to pad the last 2 dimensions of the input tensor, then use - :math:`(\text{padding\_left}, \text{padding\_right},` - :math:`\text{padding\_top}, \text{padding\_bottom})`; - to pad the last 3 dimensions, use - :math:`(\text{padding\_left}, \text{padding\_right},` - :math:`\text{padding\_top}, \text{padding\_bottom}` - :math:`\text{padding\_front}, \text{padding\_back})`. - -Padding mode: - See :class:`torch.nn.CircularPad2d`, :class:`torch.nn.ConstantPad2d`, - :class:`torch.nn.ReflectionPad2d`, and :class:`torch.nn.ReplicationPad2d` - for concrete examples on how each of the padding modes works. Constant - padding is implemented for arbitrary dimensions. Circular, replicate and - reflection padding are implemented for padding the last 3 dimensions of a - 4D or 5D input tensor, the last 2 dimensions of a 3D or 4D input tensor, - or the last dimension of a 2D or 3D input tensor. + pad(input, pad, mode="constant", value=None) -> Tensor + + Pads tensor. + + Padding size: + The padding size by which to pad some dimensions of :attr:`input` + are described starting from the last dimension and moving forward. + :math:`\left\lfloor\frac{\text{len(pad)}}{2}\right\rfloor` dimensions + of ``input`` will be padded. + For example, to pad only the last dimension of the input tensor, then + :attr:`pad` has the form + :math:`(\text{padding\_left}, \text{padding\_right})`; + to pad the last 2 dimensions of the input tensor, then use + :math:`(\text{padding\_left}, \text{padding\_right},` + :math:`\text{padding\_top}, \text{padding\_bottom})`; + to pad the last 3 dimensions, use + :math:`(\text{padding\_left}, \text{padding\_right},` + :math:`\text{padding\_top}, \text{padding\_bottom}` + :math:`\text{padding\_front}, \text{padding\_back})`. + + Padding mode: + See :class:`torch.nn.CircularPad2d`, :class:`torch.nn.ConstantPad2d`, + :class:`torch.nn.ReflectionPad2d`, and :class:`torch.nn.ReplicationPad2d` + for concrete examples on how each of the padding modes works. Constant + padding is implemented for arbitrary dimensions. Circular, replicate and + reflection padding are implemented for padding the last 3 dimensions of a + 4D or 5D input tensor, the last 2 dimensions of a 3D or 4D input tensor, + or the last dimension of a 2D or 3D input tensor. -Note: - When using the CUDA backend, this operation may induce nondeterministic - behaviour in its backward pass that is not easily switched off. - Please see the notes on :doc:`/notes/randomness` for background. + Note: + When using the CUDA backend, this operation may induce nondeterministic + behaviour in its backward pass that is not easily switched off. + Please see the notes on :doc:`/notes/randomness` for background. -Args: - input (Tensor): N-dimensional tensor - pad (tuple): m-elements tuple, where - :math:`\frac{m}{2} \leq` input dimensions and :math:`m` is even. - mode: ``'constant'``, ``'reflect'``, ``'replicate'`` or ``'circular'``. - Default: ``'constant'`` - value: fill value for ``'constant'`` padding. Default: ``0`` + Args: + input (Tensor): N-dimensional tensor + pad (tuple): m-elements tuple, where + :math:`\frac{m}{2} \leq` input dimensions and :math:`m` is even. + mode: ``'constant'``, ``'reflect'``, ``'replicate'`` or ``'circular'``. + Default: ``'constant'`` + value: fill value for ``'constant'`` padding. Default: ``0`` -Examples:: + Examples:: - >>> t4d = torch.empty(3, 3, 4, 2) - >>> p1d = (1, 1) # pad last dim by 1 on each side - >>> out = F.pad(t4d, p1d, "constant", 0) # effectively zero padding - >>> print(out.size()) - torch.Size([3, 3, 4, 4]) - >>> p2d = (1, 1, 2, 2) # pad last dim by (1, 1) and 2nd to last by (2, 2) - >>> out = F.pad(t4d, p2d, "constant", 0) - >>> print(out.size()) - torch.Size([3, 3, 8, 4]) - >>> t4d = torch.empty(3, 3, 4, 2) - >>> p3d = (0, 1, 2, 1, 3, 3) # pad by (0, 1), (2, 1), and (3, 3) - >>> out = F.pad(t4d, p3d, "constant", 0) - >>> print(out.size()) - torch.Size([3, 9, 7, 3]) - -""" + >>> t4d = torch.empty(3, 3, 4, 2) + >>> p1d = (1, 1) # pad last dim by 1 on each side + >>> out = F.pad(t4d, p1d, "constant", 0) # effectively zero padding + >>> print(out.size()) + torch.Size([3, 3, 4, 4]) + >>> p2d = (1, 1, 2, 2) # pad last dim by (1, 1) and 2nd to last by (2, 2) + >>> out = F.pad(t4d, p2d, "constant", 0) + >>> print(out.size()) + torch.Size([3, 3, 8, 4]) + >>> t4d = torch.empty(3, 3, 4, 2) + >>> p3d = (0, 1, 2, 1, 3, 3) # pad by (0, 1), (2, 1), and (3, 3) + >>> out = F.pad(t4d, p3d, "constant", 0) + >>> print(out.size()) + torch.Size([3, 9, 7, 3]) + """ if has_torch_function_unary(input): return handle_torch_function( - torch.nn.functional.pad, (input,), input, pad, mode=mode, value=value) + torch.nn.functional.pad, (input,), input, pad, mode=mode, value=value + ) if not torch.jit.is_scripting(): if torch.are_deterministic_algorithms_enabled() and input.is_cuda: - if mode == 'replicate': + if mode == "replicate": # Use slow decomp whose backward will be in terms of index_put. # importlib is required because the import cannot be top level # (cycle) and cannot be nested (TS doesn't support) - return importlib.import_module('torch._decomp.decompositions')._replication_pad( - input, pad - ) + return importlib.import_module( + "torch._decomp.decompositions" + )._replication_pad(input, pad) return torch._C._nn.pad(input, pad, mode, value) + # TODO: Fix via https://github.com/pytorch/pytorch/issues/75798 pad.__module__ = "torch.nn.functional" @@ -4563,7 +5104,8 @@ def pad(input: Tensor, pad: List[int], mode: str = "constant", value: Optional[f pairwise_distance(x1, x2, p=2.0, eps=1e-6, keepdim=False) -> Tensor See :class:`torch.nn.PairwiseDistance` for details -""") +""", +) pdist = _add_docstr( @@ -4712,7 +5254,9 @@ def triplet_margin_loss( reduction_enum = _Reduction.get_enum(reduction) if margin <= 0: raise ValueError(f"margin must be greater than 0, got {margin}") - return torch.triplet_margin_loss(anchor, positive, negative, margin, p, eps, swap, reduction_enum) + return torch.triplet_margin_loss( + anchor, positive, negative, margin, p, eps, swap, reduction_enum + ) def triplet_margin_with_distance_loss( @@ -4723,7 +5267,7 @@ def triplet_margin_with_distance_loss( distance_function: Optional[Callable[[Tensor, Tensor], Tensor]] = None, margin: float = 1.0, swap: bool = False, - reduction: str = "mean" + reduction: str = "mean", ) -> Tensor: r"""Compute the triplet margin loss for input tensors using a custom distance function. @@ -4764,7 +5308,8 @@ def triplet_margin_with_distance_loss( raise RuntimeError( f"The anchor, positive, and negative tensors are expected to have " f"the same number of dimensions, but got: anchor {a_dim}D, " - f"positive {p_dim}D, and negative {n_dim}D inputs") + f"positive {p_dim}D, and negative {n_dim}D inputs" + ) # Calculate loss if distance_function is None: @@ -4791,7 +5336,13 @@ def triplet_margin_with_distance_loss( return loss -def normalize(input: Tensor, p: float = 2.0, dim: int = 1, eps: float = 1e-12, out: Optional[Tensor] = None) -> Tensor: +def normalize( + input: Tensor, + p: float = 2.0, + dim: int = 1, + eps: float = 1e-12, + out: Optional[Tensor] = None, +) -> Tensor: r"""Perform :math:`L_p` normalization of inputs over specified dimension. For a tensor :attr:`input` of sizes :math:`(n_0, ..., n_{dim}, ..., n_k)`, each @@ -4811,7 +5362,9 @@ def normalize(input: Tensor, p: float = 2.0, dim: int = 1, eps: float = 1e-12, o operation won't be differentiable. """ if has_torch_function_variadic(input, out): - return handle_torch_function(normalize, (input, out), input, p=p, dim=dim, eps=eps, out=out) + return handle_torch_function( + normalize, (input, out), input, p=p, dim=dim, eps=eps, out=out + ) if out is None: denom = input.norm(p, dim, keepdim=True).clamp_min(eps).expand_as(input) return input / denom @@ -4825,10 +5378,11 @@ def assert_int_or_pair(arg: List[int], arg_name: str, message: str) -> None: def unfold( - input: Tensor, kernel_size: BroadcastingList2[int], + input: Tensor, + kernel_size: BroadcastingList2[int], dilation: BroadcastingList2[int] = 1, padding: BroadcastingList2[int] = 0, - stride: BroadcastingList2[int] = 1 + stride: BroadcastingList2[int] = 1, ) -> Tensor: r"""Extract sliding local blocks from a batched input tensor. @@ -4848,17 +5402,26 @@ def unfold( """ if has_torch_function_unary(input): return handle_torch_function( - unfold, (input,), input, kernel_size, dilation=dilation, padding=padding, stride=stride + unfold, + (input,), + input, + kernel_size, + dilation=dilation, + padding=padding, + stride=stride, ) - return torch._C._nn.im2col(input, _pair(kernel_size), _pair(dilation), _pair(padding), _pair(stride)) + return torch._C._nn.im2col( + input, _pair(kernel_size), _pair(dilation), _pair(padding), _pair(stride) + ) def fold( - input: Tensor, output_size: BroadcastingList2[int], + input: Tensor, + output_size: BroadcastingList2[int], kernel_size: BroadcastingList2[int], dilation: BroadcastingList2[int] = 1, padding: BroadcastingList2[int] = 0, - stride: BroadcastingList2[int] = 1 + stride: BroadcastingList2[int] = 1, ) -> Tensor: r"""Combine an array of sliding local blocks into a large containing tensor. @@ -4869,16 +5432,30 @@ def fold( """ if has_torch_function_unary(input): return handle_torch_function( - fold, (input,), input, output_size, kernel_size, dilation=dilation, padding=padding, stride=stride + fold, + (input,), + input, + output_size, + kernel_size, + dilation=dilation, + padding=padding, + stride=stride, ) return torch._C._nn.col2im( - input, _pair(output_size), _pair(kernel_size), _pair(dilation), _pair(padding), _pair(stride) + input, + _pair(output_size), + _pair(kernel_size), + _pair(dilation), + _pair(padding), + _pair(stride), ) + # # multihead attention # + def _in_projection_packed( q: Tensor, k: Tensor, @@ -4919,7 +5496,13 @@ def _in_projection_packed( # self-attention proj = linear(q, w, b) # reshape to 3, E and not E, 3 is deliberate for better memory coalescing and keeping same order as chunk() - proj = proj.unflatten(-1, (3, E)).unsqueeze(0).transpose(0, -2).squeeze(-2).contiguous() + proj = ( + proj.unflatten(-1, (3, E)) + .unsqueeze(0) + .transpose(0, -2) + .squeeze(-2) + .contiguous() + ) return proj[0], proj[1], proj[2] else: # encoder-decoder attention @@ -4931,7 +5514,13 @@ def _in_projection_packed( q_proj = linear(q, w_q, b_q) kv_proj = linear(k, w_kv, b_kv) # reshape to 2, E and not E, 2 is deliberate for better memory coalescing and keeping same order as chunk() - kv_proj = kv_proj.unflatten(-1, (2, E)).unsqueeze(0).transpose(0, -2).squeeze(-2).contiguous() + kv_proj = ( + kv_proj.unflatten(-1, (2, E)) + .unsqueeze(0) + .transpose(0, -2) + .squeeze(-2) + .contiguous() + ) return (q_proj, kv_proj[0], kv_proj[1]) else: w_q, w_k, w_v = w.chunk(3) @@ -4987,16 +5576,33 @@ def _in_projection( """ Eq, Ek, Ev = q.size(-1), k.size(-1), v.size(-1) - assert w_q.shape == (Eq, Eq), f"expecting query weights shape of {(Eq, Eq)}, but got {w_q.shape}" - assert w_k.shape == (Eq, Ek), f"expecting key weights shape of {(Eq, Ek)}, but got {w_k.shape}" - assert w_v.shape == (Eq, Ev), f"expecting value weights shape of {(Eq, Ev)}, but got {w_v.shape}" - assert b_q is None or b_q.shape == (Eq,), f"expecting query bias shape of {(Eq,)}, but got {b_q.shape}" - assert b_k is None or b_k.shape == (Eq,), f"expecting key bias shape of {(Eq,)}, but got {b_k.shape}" - assert b_v is None or b_v.shape == (Eq,), f"expecting value bias shape of {(Eq,)}, but got {b_v.shape}" + assert w_q.shape == ( + Eq, + Eq, + ), f"expecting query weights shape of {(Eq, Eq)}, but got {w_q.shape}" + assert w_k.shape == ( + Eq, + Ek, + ), f"expecting key weights shape of {(Eq, Ek)}, but got {w_k.shape}" + assert w_v.shape == ( + Eq, + Ev, + ), f"expecting value weights shape of {(Eq, Ev)}, but got {w_v.shape}" + assert b_q is None or b_q.shape == ( + Eq, + ), f"expecting query bias shape of {(Eq,)}, but got {b_q.shape}" + assert b_k is None or b_k.shape == ( + Eq, + ), f"expecting key bias shape of {(Eq,)}, but got {b_k.shape}" + assert b_v is None or b_v.shape == ( + Eq, + ), f"expecting value bias shape of {(Eq,)}, but got {b_v.shape}" return linear(q, w_q, b_q), linear(k, w_k, b_k), linear(v, w_v, b_v) + scaled_dot_product_attention = _add_docstr( - torch._C._nn.scaled_dot_product_attention, r""" + torch._C._nn.scaled_dot_product_attention, + r""" scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None) -> Tensor: Computes scaled dot product attention on query, key and value tensors, using @@ -5080,7 +5686,9 @@ def forward(self, ...): Note: {cudnn_reproducibility_note} -""".format(**reproducibility_notes) +""".format( + **reproducibility_notes + ) + r""" Args: query (Tensor): Query tensor; shape :math:`(N, ..., L, E)`. @@ -5124,10 +5732,18 @@ def forward(self, ...): .. _Memory-Efficient Attention: https://github.com/facebookresearch/xformers -""") +""", +) -def _mha_shape_check(query: Tensor, key: Tensor, value: Tensor, - key_padding_mask: Optional[Tensor], attn_mask: Optional[Tensor], num_heads: int): + +def _mha_shape_check( + query: Tensor, + key: Tensor, + value: Tensor, + key_padding_mask: Optional[Tensor], + attn_mask: Optional[Tensor], + num_heads: int, +): # Verifies the expected shape for `query, `key`, `value`, `key_padding_mask` and `attn_mask` # and returns if the input is batched or not. # Raises an error if `query` is not 2-D (unbatched) or 3-D (batched) tensor. @@ -5136,58 +5752,67 @@ def _mha_shape_check(query: Tensor, key: Tensor, value: Tensor, if query.dim() == 3: # Batched Inputs is_batched = True - assert key.dim() == 3 and value.dim() == 3, \ - ("For batched (3-D) `query`, expected `key` and `value` to be 3-D" - f" but found {key.dim()}-D and {value.dim()}-D tensors respectively") + assert key.dim() == 3 and value.dim() == 3, ( + "For batched (3-D) `query`, expected `key` and `value` to be 3-D" + f" but found {key.dim()}-D and {value.dim()}-D tensors respectively" + ) if key_padding_mask is not None: - assert key_padding_mask.dim() == 2, \ - ("For batched (3-D) `query`, expected `key_padding_mask` to be `None` or 2-D" - f" but found {key_padding_mask.dim()}-D tensor instead") + assert key_padding_mask.dim() == 2, ( + "For batched (3-D) `query`, expected `key_padding_mask` to be `None` or 2-D" + f" but found {key_padding_mask.dim()}-D tensor instead" + ) if attn_mask is not None: - assert attn_mask.dim() in (2, 3), \ - ("For batched (3-D) `query`, expected `attn_mask` to be `None`, 2-D or 3-D" - f" but found {attn_mask.dim()}-D tensor instead") + assert attn_mask.dim() in (2, 3), ( + "For batched (3-D) `query`, expected `attn_mask` to be `None`, 2-D or 3-D" + f" but found {attn_mask.dim()}-D tensor instead" + ) elif query.dim() == 2: # Unbatched Inputs is_batched = False - assert key.dim() == 2 and value.dim() == 2, \ - ("For unbatched (2-D) `query`, expected `key` and `value` to be 2-D" - f" but found {key.dim()}-D and {value.dim()}-D tensors respectively") + assert key.dim() == 2 and value.dim() == 2, ( + "For unbatched (2-D) `query`, expected `key` and `value` to be 2-D" + f" but found {key.dim()}-D and {value.dim()}-D tensors respectively" + ) if key_padding_mask is not None: - assert key_padding_mask.dim() == 1, \ - ("For unbatched (2-D) `query`, expected `key_padding_mask` to be `None` or 1-D" - f" but found {key_padding_mask.dim()}-D tensor instead") + assert key_padding_mask.dim() == 1, ( + "For unbatched (2-D) `query`, expected `key_padding_mask` to be `None` or 1-D" + f" but found {key_padding_mask.dim()}-D tensor instead" + ) if attn_mask is not None: - assert attn_mask.dim() in (2, 3), \ - ("For unbatched (2-D) `query`, expected `attn_mask` to be `None`, 2-D or 3-D" - f" but found {attn_mask.dim()}-D tensor instead") + assert attn_mask.dim() in (2, 3), ( + "For unbatched (2-D) `query`, expected `attn_mask` to be `None`, 2-D or 3-D" + f" but found {attn_mask.dim()}-D tensor instead" + ) if attn_mask.dim() == 3: expected_shape = (num_heads, query.shape[0], key.shape[0]) - assert attn_mask.shape == expected_shape, \ - (f"Expected `attn_mask` shape to be {expected_shape} but got {attn_mask.shape}") + assert ( + attn_mask.shape == expected_shape + ), f"Expected `attn_mask` shape to be {expected_shape} but got {attn_mask.shape}" else: raise AssertionError( - f"query should be unbatched 2D or batched 3D tensor but received {query.dim()}-D query tensor") + f"query should be unbatched 2D or batched 3D tensor but received {query.dim()}-D query tensor" + ) return is_batched + def _canonical_mask( - mask: Optional[Tensor], - mask_name: str, - other_type: Optional[DType], - other_name: str, - target_type: DType, - check_other: bool = True, + mask: Optional[Tensor], + mask_name: str, + other_type: Optional[DType], + other_name: str, + target_type: DType, + check_other: bool = True, ) -> Optional[Tensor]: - if mask is not None: _mask_dtype = mask.dtype _mask_is_float = torch.is_floating_point(mask) if _mask_dtype != torch.bool and not _mask_is_float: raise AssertionError( - f"only bool and floating types of {mask_name} are supported") + f"only bool and floating types of {mask_name} are supported" + ) if check_other and other_type is not None: if _mask_dtype != other_type: warnings.warn( @@ -5195,12 +5820,12 @@ def _canonical_mask( "is deprecated. Use same type for both instead." ) if not _mask_is_float: - mask = ( - torch.zeros_like(mask, dtype=target_type) - .masked_fill_(mask, float("-inf")) + mask = torch.zeros_like(mask, dtype=target_type).masked_fill_( + mask, float("-inf") ) return mask + def _none_or_dtype(input: Optional[Tensor]) -> Optional[DType]: if input is None: return None @@ -5208,6 +5833,7 @@ def _none_or_dtype(input: Optional[Tensor]) -> Optional[DType]: return input.dtype raise RuntimeError("input to _none_or_dtype() must be None or torch.Tensor") + def multi_head_attention_forward( query: Tensor, key: Tensor, @@ -5312,7 +5938,17 @@ def multi_head_attention_forward( :math:`S` is the source sequence length. If ``average_attn_weights=False``, returns attention weights per head of shape :math:`(num_heads, L, S)` when input is unbatched or :math:`(N, num_heads, L, S)`. """ - tens_ops = (query, key, value, in_proj_weight, in_proj_bias, bias_k, bias_v, out_proj_weight, out_proj_bias) + tens_ops = ( + query, + key, + value, + in_proj_weight, + in_proj_bias, + bias_k, + bias_v, + out_proj_weight, + out_proj_bias, + ) if has_torch_function(tens_ops): return handle_torch_function( multi_head_attention_forward, @@ -5344,7 +5980,9 @@ def multi_head_attention_forward( average_attn_weights=average_attn_weights, ) - is_batched = _mha_shape_check(query, key, value, key_padding_mask, attn_mask, num_heads) + is_batched = _mha_shape_check( + query, key, value, key_padding_mask, attn_mask, num_heads + ) # For unbatched input, we unsqueeze at the expected batch-dim to pretend that the input # is batched, run the computation and before returning squeeze the @@ -5366,7 +6004,7 @@ def multi_head_attention_forward( mask_name="key_padding_mask", other_type=_none_or_dtype(attn_mask), other_name="attn_mask", - target_type=query.dtype + target_type=query.dtype, ) if is_causal and attn_mask is None: @@ -5397,36 +6035,60 @@ def multi_head_attention_forward( # longer causal. is_causal = False - assert embed_dim == embed_dim_to_check, \ - f"was expecting embedding dimension of {embed_dim_to_check}, but got {embed_dim}" + assert ( + embed_dim == embed_dim_to_check + ), f"was expecting embedding dimension of {embed_dim_to_check}, but got {embed_dim}" if isinstance(embed_dim, torch.Tensor): # embed_dim can be a tensor when JIT tracing - head_dim = embed_dim.div(num_heads, rounding_mode='trunc') + head_dim = embed_dim.div(num_heads, rounding_mode="trunc") else: head_dim = embed_dim // num_heads - assert head_dim * num_heads == embed_dim, f"embed_dim {embed_dim} not divisible by num_heads {num_heads}" + assert ( + head_dim * num_heads == embed_dim + ), f"embed_dim {embed_dim} not divisible by num_heads {num_heads}" if use_separate_proj_weight: # allow MHA to have different embedding dimensions when separate projection weights are used - assert key.shape[:2] == value.shape[:2], \ - f"key's sequence and batch dims {key.shape[:2]} do not match value's {value.shape[:2]}" + assert ( + key.shape[:2] == value.shape[:2] + ), f"key's sequence and batch dims {key.shape[:2]} do not match value's {value.shape[:2]}" else: - assert key.shape == value.shape, f"key shape {key.shape} does not match value shape {value.shape}" + assert ( + key.shape == value.shape + ), f"key shape {key.shape} does not match value shape {value.shape}" # # compute in-projection # if not use_separate_proj_weight: - assert in_proj_weight is not None, "use_separate_proj_weight is False but in_proj_weight is None" + assert ( + in_proj_weight is not None + ), "use_separate_proj_weight is False but in_proj_weight is None" q, k, v = _in_projection_packed(query, key, value, in_proj_weight, in_proj_bias) else: - assert q_proj_weight is not None, "use_separate_proj_weight is True but q_proj_weight is None" - assert k_proj_weight is not None, "use_separate_proj_weight is True but k_proj_weight is None" - assert v_proj_weight is not None, "use_separate_proj_weight is True but v_proj_weight is None" + assert ( + q_proj_weight is not None + ), "use_separate_proj_weight is True but q_proj_weight is None" + assert ( + k_proj_weight is not None + ), "use_separate_proj_weight is True but k_proj_weight is None" + assert ( + v_proj_weight is not None + ), "use_separate_proj_weight is True but v_proj_weight is None" if in_proj_bias is None: b_q = b_k = b_v = None else: b_q, b_k, b_v = in_proj_bias.chunk(3) - q, k, v = _in_projection(query, key, value, q_proj_weight, k_proj_weight, v_proj_weight, b_q, b_k, b_v) + q, k, v = _in_projection( + query, + key, + value, + q_proj_weight, + k_proj_weight, + v_proj_weight, + b_q, + b_k, + b_v, + ) # prep attention mask @@ -5435,14 +6097,20 @@ def multi_head_attention_forward( if attn_mask.dim() == 2: correct_2d_size = (tgt_len, src_len) if attn_mask.shape != correct_2d_size: - raise RuntimeError(f"The shape of the 2D attn_mask is {attn_mask.shape}, but should be {correct_2d_size}.") + raise RuntimeError( + f"The shape of the 2D attn_mask is {attn_mask.shape}, but should be {correct_2d_size}." + ) attn_mask = attn_mask.unsqueeze(0) elif attn_mask.dim() == 3: correct_3d_size = (bsz * num_heads, tgt_len, src_len) if attn_mask.shape != correct_3d_size: - raise RuntimeError(f"The shape of the 3D attn_mask is {attn_mask.shape}, but should be {correct_3d_size}.") + raise RuntimeError( + f"The shape of the 3D attn_mask is {attn_mask.shape}, but should be {correct_3d_size}." + ) else: - raise RuntimeError(f"attn_mask's dimension {attn_mask.dim()} is not supported") + raise RuntimeError( + f"attn_mask's dimension {attn_mask.dim()} is not supported" + ) # add bias along batch dimension (currently second) if bias_k is not None and bias_v is not None: @@ -5466,26 +6134,34 @@ def multi_head_attention_forward( k = k.view(k.shape[0], bsz * num_heads, head_dim).transpose(0, 1) else: # TODO finish disentangling control flow so we don't do in-projections when statics are passed - assert static_k.size(0) == bsz * num_heads, \ - f"expecting static_k.size(0) of {bsz * num_heads}, but got {static_k.size(0)}" - assert static_k.size(2) == head_dim, \ - f"expecting static_k.size(2) of {head_dim}, but got {static_k.size(2)}" + assert ( + static_k.size(0) == bsz * num_heads + ), f"expecting static_k.size(0) of {bsz * num_heads}, but got {static_k.size(0)}" + assert ( + static_k.size(2) == head_dim + ), f"expecting static_k.size(2) of {head_dim}, but got {static_k.size(2)}" k = static_k if static_v is None: v = v.view(v.shape[0], bsz * num_heads, head_dim).transpose(0, 1) else: # TODO finish disentangling control flow so we don't do in-projections when statics are passed - assert static_v.size(0) == bsz * num_heads, \ - f"expecting static_v.size(0) of {bsz * num_heads}, but got {static_v.size(0)}" - assert static_v.size(2) == head_dim, \ - f"expecting static_v.size(2) of {head_dim}, but got {static_v.size(2)}" + assert ( + static_v.size(0) == bsz * num_heads + ), f"expecting static_v.size(0) of {bsz * num_heads}, but got {static_v.size(0)}" + assert ( + static_v.size(2) == head_dim + ), f"expecting static_v.size(2) of {head_dim}, but got {static_v.size(2)}" v = static_v # add zero attention along batch dimension (now first) if add_zero_attn: zero_attn_shape = (bsz * num_heads, 1, head_dim) - k = torch.cat([k, torch.zeros(zero_attn_shape, dtype=k.dtype, device=k.device)], dim=1) - v = torch.cat([v, torch.zeros(zero_attn_shape, dtype=v.dtype, device=v.device)], dim=1) + k = torch.cat( + [k, torch.zeros(zero_attn_shape, dtype=k.dtype, device=k.device)], dim=1 + ) + v = torch.cat( + [v, torch.zeros(zero_attn_shape, dtype=v.dtype, device=v.device)], dim=1 + ) if attn_mask is not None: attn_mask = pad(attn_mask, (0, 1)) if key_padding_mask is not None: @@ -5496,10 +6172,15 @@ def multi_head_attention_forward( # merge key padding and attention masks if key_padding_mask is not None: - assert key_padding_mask.shape == (bsz, src_len), \ - f"expecting key_padding_mask shape of {(bsz, src_len)}, but got {key_padding_mask.shape}" - key_padding_mask = key_padding_mask.view(bsz, 1, 1, src_len). \ - expand(-1, num_heads, -1, -1).reshape(bsz * num_heads, 1, src_len) + assert key_padding_mask.shape == ( + bsz, + src_len, + ), f"expecting key_padding_mask shape of {(bsz, src_len)}, but got {key_padding_mask.shape}" + key_padding_mask = ( + key_padding_mask.view(bsz, 1, 1, src_len) + .expand(-1, num_heads, -1, -1) + .reshape(bsz * num_heads, 1, src_len) + ) if attn_mask is None: attn_mask = key_padding_mask else: @@ -5517,10 +6198,14 @@ def multi_head_attention_forward( B, Nt, E = q.shape q_scaled = q * math.sqrt(1.0 / float(E)) - assert not (is_causal and attn_mask is None), "FIXME: is_causal not implemented for need_weights" + assert not ( + is_causal and attn_mask is None + ), "FIXME: is_causal not implemented for need_weights" if attn_mask is not None: - attn_output_weights = torch.baddbmm(attn_mask, q_scaled, k.transpose(-2, -1)) + attn_output_weights = torch.baddbmm( + attn_mask, q_scaled, k.transpose(-2, -1) + ) else: attn_output_weights = torch.bmm(q_scaled, k.transpose(-2, -1)) attn_output_weights = softmax(attn_output_weights, dim=-1) @@ -5529,7 +6214,9 @@ def multi_head_attention_forward( attn_output = torch.bmm(attn_output_weights, v) - attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len * bsz, embed_dim) + attn_output = ( + attn_output.transpose(0, 1).contiguous().view(tgt_len * bsz, embed_dim) + ) attn_output = linear(attn_output, out_proj_weight, out_proj_bias) attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1)) @@ -5557,8 +6244,12 @@ def multi_head_attention_forward( k = k.view(bsz, num_heads, src_len, head_dim) v = v.view(bsz, num_heads, src_len, head_dim) - attn_output = scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal) - attn_output = attn_output.permute(2, 0, 1, 3).contiguous().view(bsz * tgt_len, embed_dim) + attn_output = scaled_dot_product_attention( + q, k, v, attn_mask, dropout_p, is_causal + ) + attn_output = ( + attn_output.permute(2, 0, 1, 3).contiguous().view(bsz * tgt_len, embed_dim) + ) attn_output = linear(attn_output, out_proj_weight, out_proj_bias) attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1)) diff --git a/torch/overrides.py b/torch/overrides.py index 651912bc4202d5..d005368e8dbbd9 100644 --- a/torch/overrides.py +++ b/torch/overrides.py @@ -59,7 +59,9 @@ def _disable_user_warnings( - func: Callable, regex: str = ".*is deprecated, please use.*", module: str = "torch" + func: Callable, + regex: str = ".*is deprecated, please use.*", + module: str = "torch", ) -> Callable: """ Decorator that temporarily disables ``UserWarning``s for the given ``module`` if the warning message matches the @@ -1582,7 +1584,8 @@ def wrapped(*args, **kwargs): def _get_overloaded_args( - relevant_args: Iterable[Any], get_type_fn: Callable[[Any], Type] = None + relevant_args: Iterable[Any], + get_type_fn: Callable[[Any], Type] = None, ) -> List[Any]: """Returns a list of arguments on which to call __torch_function__. @@ -1659,7 +1662,10 @@ def _get_overloaded_args( def handle_torch_function( - public_api: Callable, relevant_args: Iterable[Any], *args, **kwargs + public_api: Callable, + relevant_args: Iterable[Any], + *args, + **kwargs, ) -> Any: """Implement a function with checks for ``__torch_function__`` overrides. @@ -1794,7 +1800,7 @@ def handle_torch_function( @functools.lru_cache(None) def _get_overridable_functions() -> ( - Tuple[Dict[Any, List[Callable]], Dict[Callable, str]] + Tuple[Dict[Any, List[Callable]], Dict[Callable, str]], ): overridable_funcs = collections.defaultdict(list) index = {} diff --git a/torch/serialization.py b/torch/serialization.py index 738af26728e83f..95d8d2e5cc67b7 100644 --- a/torch/serialization.py +++ b/torch/serialization.py @@ -273,7 +273,9 @@ def register_package( def check_module_version_greater_or_equal( - module, req_version_tuple, error_if_malformed=True + module, + req_version_tuple, + error_if_malformed=True, ): """ Check if a module's version satisfies requirements @@ -424,7 +426,9 @@ def _deserialize(backend_name, obj, location): register_package(10, _cpu_tag, _cpu_deserialize) register_package( - 20, functools.partial(_backend_tag, "cuda"), functools.partial(_deserialize, "cuda") + 20, + functools.partial(_backend_tag, "cuda"), + functools.partial(_deserialize, "cuda"), ) register_package(21, _mps_tag, _mps_deserialize) register_package(22, _meta_tag, _meta_deserialize) @@ -434,15 +438,19 @@ def _deserialize(backend_name, obj, location): functools.partial(_deserialize, "privateuse1"), ) register_package( - 24, functools.partial(_backend_tag, "hpu"), functools.partial(_deserialize, "hpu") + 24, + functools.partial(_backend_tag, "hpu"), + functools.partial(_deserialize, "hpu"), ) register_package( - 25, functools.partial(_backend_tag, "xpu"), functools.partial(_deserialize, "xpu") + 25, + functools.partial(_backend_tag, "xpu"), + functools.partial(_deserialize, "xpu"), ) def location_tag( - storage: Union[Storage, torch.storage.TypedStorage, torch.UntypedStorage] + storage: Union[Storage, torch.storage.TypedStorage, torch.UntypedStorage], ): for _, tagger, _ in _package_registry: location = tagger(storage) From a87d82abd746240e7b46b992fa9df7ae6d3e6d4a Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Mon, 17 Jun 2024 03:58:20 +0800 Subject: [PATCH 084/171] [BE] enable UFMT for `torch/nn/*.py` (#128593) Part of #123062 - #123062 Pull Request resolved: https://github.com/pytorch/pytorch/pull/128593 Approved by: https://github.com/mikaylagawarecki ghstack dependencies: #128596, #128594, #128592 --- .lintrunner.toml | 7 - torch/__init__.py | 4 +- torch/nn/__init__.py | 25 ++- torch/nn/_reduction.py | 35 ++-- torch/nn/common_types.py | 8 +- torch/nn/grad.py | 182 +++++++++++++++---- torch/nn/init.py | 104 +++++++---- torch/nn/parameter.py | 70 +++---- torch/utils/data/dataloader.py | 7 +- torch/utils/data/datapipes/_hook_iterator.py | 2 +- 10 files changed, 304 insertions(+), 140 deletions(-) diff --git a/.lintrunner.toml b/.lintrunner.toml index ac5bbae1302170..08e434e8f143ba 100644 --- a/.lintrunner.toml +++ b/.lintrunner.toml @@ -1643,12 +1643,6 @@ exclude_patterns = [ 'torch/linalg/__init__.py', 'torch/monitor/__init__.py', 'torch/nested/__init__.py', - 'torch/nn/__init__.py', - 'torch/nn/_reduction.py', - 'torch/nn/common_types.py', - 'torch/nn/cpp.py', - 'torch/nn/grad.py', - 'torch/nn/init.py', 'torch/nn/intrinsic/__init__.py', 'torch/nn/intrinsic/modules/__init__.py', 'torch/nn/intrinsic/modules/fused.py', @@ -1665,7 +1659,6 @@ exclude_patterns = [ 'torch/nn/intrinsic/quantized/modules/bn_relu.py', 'torch/nn/intrinsic/quantized/modules/conv_relu.py', 'torch/nn/intrinsic/quantized/modules/linear_relu.py', - 'torch/nn/parameter.py', 'torch/nn/qat/__init__.py', 'torch/nn/qat/dynamic/__init__.py', 'torch/nn/qat/dynamic/modules/__init__.py', diff --git a/torch/__init__.py b/torch/__init__.py index 3253adcfb04296..e50844bafc43a4 100644 --- a/torch/__init__.py +++ b/torch/__init__.py @@ -1956,7 +1956,7 @@ def _manager_path(): ################################################################################ # needs to be after the above ATen bindings so we can overwrite from Python side -from torch import functional as functional # usort: skip +from torch import _VF as _VF, functional as functional # usort: skip from torch.functional import * # usort: skip # noqa: F403 ################################################################################ @@ -2054,7 +2054,7 @@ def compiled_with_cxx11_abi() -> builtins.bool: return _C._GLIBCXX_USE_CXX11_ABI -import torch._library +from torch import _library as _library, _ops as _ops # Import the ops "namespace" from torch._classes import classes as classes diff --git a/torch/nn/__init__.py b/torch/nn/__init__.py index 23447d48440920..4597d382d5c663 100644 --- a/torch/nn/__init__.py +++ b/torch/nn/__init__.py @@ -1,15 +1,20 @@ # mypy: allow-untyped-defs -from .modules import * # noqa: F403 -from .parameter import ( +from torch.nn.modules import * # noqa: F403 +from torch.nn import ( + attention as attention, + functional as functional, + init as init, + modules as modules, + parallel as parallel, + parameter as parameter, + utils as utils, +) +from torch.nn.parallel import DataParallel as DataParallel +from torch.nn.parameter import ( Parameter as Parameter, - UninitializedParameter as UninitializedParameter, UninitializedBuffer as UninitializedBuffer, + UninitializedParameter as UninitializedParameter, ) -from .parallel import DataParallel as DataParallel -from . import init -from . import functional -from . import utils -from . import attention def factory_kwargs(kwargs): @@ -48,7 +53,9 @@ def __init__(self, **kwargs): for k in simple_keys: if k in kwargs: if k in r: - raise TypeError(f"{k} specified twice, in **kwargs and in factory_kwargs") + raise TypeError( + f"{k} specified twice, in **kwargs and in factory_kwargs" + ) r[k] = kwargs[k] return r diff --git a/torch/nn/_reduction.py b/torch/nn/_reduction.py index ac2a8bb0a0e9ed..93b00dc6feb43d 100644 --- a/torch/nn/_reduction.py +++ b/torch/nn/_reduction.py @@ -1,30 +1,39 @@ -from typing import Optional import warnings +from typing import Optional + # NB: Keep this file in sync with enums in aten/src/ATen/core/Reduction.h def get_enum(reduction: str) -> int: - if reduction == 'none': + if reduction == "none": ret = 0 - elif reduction == 'mean': + elif reduction == "mean": ret = 1 - elif reduction == 'elementwise_mean': - warnings.warn("reduction='elementwise_mean' is deprecated, please use reduction='mean' instead.") + elif reduction == "elementwise_mean": + warnings.warn( + "reduction='elementwise_mean' is deprecated. " + "Please use reduction='mean' instead." + ) ret = 1 - elif reduction == 'sum': + elif reduction == "sum": ret = 2 else: ret = -1 # TODO: remove once JIT exceptions support control flow raise ValueError(f"{reduction} is not a valid value for reduction") return ret + # In order to support previous versions, accept boolean size_average and reduce # and convert them into the new constants for now # We use these functions in torch/legacy as well, in which case we'll silence the warning -def legacy_get_string(size_average: Optional[bool], reduce: Optional[bool], emit_warning: bool = True) -> str: +def legacy_get_string( + size_average: Optional[bool], + reduce: Optional[bool], + emit_warning: bool = True, +) -> str: warning = "size_average and reduce args will be deprecated, please use reduction='{}' instead." if size_average is None: @@ -33,15 +42,19 @@ def legacy_get_string(size_average: Optional[bool], reduce: Optional[bool], emit reduce = True if size_average and reduce: - ret = 'mean' + ret = "mean" elif reduce: - ret = 'sum' + ret = "sum" else: - ret = 'none' + ret = "none" if emit_warning: warnings.warn(warning.format(ret)) return ret -def legacy_get_enum(size_average: Optional[bool], reduce: Optional[bool], emit_warning: bool = True) -> int: +def legacy_get_enum( + size_average: Optional[bool], + reduce: Optional[bool], + emit_warning: bool = True, +) -> int: return get_enum(legacy_get_string(size_average, reduce, emit_warning)) diff --git a/torch/nn/common_types.py b/torch/nn/common_types.py index 884f739e27813a..74661d604c3e60 100644 --- a/torch/nn/common_types.py +++ b/torch/nn/common_types.py @@ -1,12 +1,14 @@ -from typing import TypeVar, Union, Tuple, Optional -from .. import Tensor +from typing import Optional, Tuple, TypeVar, Union + +from torch import Tensor + # Create some useful type aliases # Template for arguments which can be supplied as a tuple, or which can be a scalar which PyTorch will internally # broadcast to a tuple. # Comes in several variants: A tuple of unknown size, and a fixed-size tuple for 1d, 2d, or 3d operations. -T = TypeVar('T') +T = TypeVar("T") _scalar_or_tuple_any_t = Union[T, Tuple[T, ...]] _scalar_or_tuple_1_t = Union[T, Tuple[T]] _scalar_or_tuple_2_t = Union[T, Tuple[T, T]] diff --git a/torch/nn/grad.py b/torch/nn/grad.py index dbd38fcdd38cd4..61e817dbed612e 100644 --- a/torch/nn/grad.py +++ b/torch/nn/grad.py @@ -2,10 +2,18 @@ """Gradient interface.""" import torch -from .modules.utils import _single, _pair, _triple - - -def conv1d_input(input_size, weight, grad_output, stride=1, padding=0, dilation=1, groups=1): +from torch.nn.modules.utils import _pair, _single, _triple + + +def conv1d_input( + input_size, + weight, + grad_output, + stride=1, + padding=0, + dilation=1, + groups=1, +): r"""Compute the gradient of conv1d with respect to the input of the convolution. This is same as the 1D transposed convolution operator under the hood but requires @@ -32,12 +40,30 @@ def conv1d_input(input_size, weight, grad_output, stride=1, padding=0, dilation= """ input = grad_output.new_empty(1).expand(input_size) - return torch.ops.aten.convolution_backward(grad_output, input, weight, None, - _single(stride), _single(padding), _single(dilation), - False, [0], groups, (True, False, False))[0] - - -def conv1d_weight(input, weight_size, grad_output, stride=1, padding=0, dilation=1, groups=1): + return torch.ops.aten.convolution_backward( + grad_output, + input, + weight, + None, + _single(stride), + _single(padding), + _single(dilation), + False, + [0], + groups, + (True, False, False), + )[0] + + +def conv1d_weight( + input, + weight_size, + grad_output, + stride=1, + padding=0, + dilation=1, + groups=1, +): r"""Compute the gradient of conv1d with respect to the weight of the convolution. Args: @@ -62,12 +88,30 @@ def conv1d_weight(input, weight_size, grad_output, stride=1, padding=0, dilation """ weight = grad_output.new_empty(1).expand(weight_size) - return torch.ops.aten.convolution_backward(grad_output, input, weight, None, - _single(stride), _single(padding), _single(dilation), - False, [0], groups, (False, True, False))[1] - - -def conv2d_input(input_size, weight, grad_output, stride=1, padding=0, dilation=1, groups=1): + return torch.ops.aten.convolution_backward( + grad_output, + input, + weight, + None, + _single(stride), + _single(padding), + _single(dilation), + False, + [0], + groups, + (False, True, False), + )[1] + + +def conv2d_input( + input_size, + weight, + grad_output, + stride=1, + padding=0, + dilation=1, + groups=1, +): r"""Compute the gradient of conv2d with respect to the input of the convolution. This is same as the 2D transposed convolution operator under the hood but requires @@ -94,12 +138,30 @@ def conv2d_input(input_size, weight, grad_output, stride=1, padding=0, dilation= """ input = grad_output.new_empty(1).expand(input_size) - return torch.ops.aten.convolution_backward(grad_output, input, weight, None, - _pair(stride), _pair(padding), _pair(dilation), - False, [0], groups, (True, False, False))[0] - - -def conv2d_weight(input, weight_size, grad_output, stride=1, padding=0, dilation=1, groups=1): + return torch.ops.aten.convolution_backward( + grad_output, + input, + weight, + None, + _pair(stride), + _pair(padding), + _pair(dilation), + False, + [0], + groups, + (True, False, False), + )[0] + + +def conv2d_weight( + input, + weight_size, + grad_output, + stride=1, + padding=0, + dilation=1, + groups=1, +): r"""Compute the gradient of conv2d with respect to the weight of the convolution. Args: @@ -124,12 +186,30 @@ def conv2d_weight(input, weight_size, grad_output, stride=1, padding=0, dilation """ weight = grad_output.new_empty(1).expand(weight_size) - return torch.ops.aten.convolution_backward(grad_output, input, weight, None, - _pair(stride), _pair(padding), _pair(dilation), - False, [0], groups, (False, True, False))[1] - - -def conv3d_input(input_size, weight, grad_output, stride=1, padding=0, dilation=1, groups=1): + return torch.ops.aten.convolution_backward( + grad_output, + input, + weight, + None, + _pair(stride), + _pair(padding), + _pair(dilation), + False, + [0], + groups, + (False, True, False), + )[1] + + +def conv3d_input( + input_size, + weight, + grad_output, + stride=1, + padding=0, + dilation=1, + groups=1, +): r"""Compute the gradient of conv3d with respect to the input of the convolution. This is same as the 3D transposed convolution operator under the hood but requires @@ -156,12 +236,30 @@ def conv3d_input(input_size, weight, grad_output, stride=1, padding=0, dilation= """ input = grad_output.new_empty(1).expand(input_size) - return torch.ops.aten.convolution_backward(grad_output, input, weight, None, - _triple(stride), _triple(padding), _triple(dilation), - False, [0], groups, (True, False, False))[0] - - -def conv3d_weight(input, weight_size, grad_output, stride=1, padding=0, dilation=1, groups=1): + return torch.ops.aten.convolution_backward( + grad_output, + input, + weight, + None, + _triple(stride), + _triple(padding), + _triple(dilation), + False, + [0], + groups, + (True, False, False), + )[0] + + +def conv3d_weight( + input, + weight_size, + grad_output, + stride=1, + padding=0, + dilation=1, + groups=1, +): r"""Compute the gradient of conv3d with respect to the weight of the convolution. Args: @@ -185,6 +283,16 @@ def conv3d_weight(input, weight_size, grad_output, stride=1, padding=0, dilation """ weight = grad_output.new_empty(1).expand(weight_size) - return torch.ops.aten.convolution_backward(grad_output, input, weight, None, - _triple(stride), _triple(padding), _triple(dilation), - False, [0], groups, (False, True, False))[1] + return torch.ops.aten.convolution_backward( + grad_output, + input, + weight, + None, + _triple(stride), + _triple(padding), + _triple(dilation), + False, + [0], + groups, + (False, True, False), + )[1] diff --git a/torch/nn/init.py b/torch/nn/init.py index b3179abb49371d..f01edaee50adc0 100644 --- a/torch/nn/init.py +++ b/torch/nn/init.py @@ -2,10 +2,11 @@ """This file contains utilities for initializing neural network parameters.""" import math import warnings +from typing import Optional as _Optional -from torch import Tensor import torch -from typing import Optional as _Optional +from torch import Tensor + # These no_grad_* functions are necessary as wrappers around the parts of these # functions that use `with torch.no_grad()`. The JIT doesn't support context @@ -25,12 +26,14 @@ def _no_grad_trunc_normal_(tensor, mean, std, a, b, generator=None): # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf def norm_cdf(x): # Computes standard normal cumulative distribution function - return (1. + math.erf(x / math.sqrt(2.))) / 2. + return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0 if (mean < a - 2 * std) or (mean > b + 2 * std): - warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " - "The distribution of values may be incorrect.", - stacklevel=2) + warnings.warn( + "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " + "The distribution of values may be incorrect.", + stacklevel=2, + ) with torch.no_grad(): # Values are generated by using a truncated uniform distribution and @@ -48,7 +51,7 @@ def norm_cdf(x): tensor.erfinv_() # Transform to proper mean, std - tensor.mul_(std * math.sqrt(2.)) + tensor.mul_(std * math.sqrt(2.0)) tensor.add_(mean) # Clamp to ensure it's in the proper range @@ -100,24 +103,38 @@ def calculate_gain(nonlinearity, param=None): .. _Self-Normalizing Neural Networks: https://papers.nips.cc/paper/2017/hash/5d44ee6f2c3f71b73125876103c8f6c4-Abstract.html """ - linear_fns = ['linear', 'conv1d', 'conv2d', 'conv3d', 'conv_transpose1d', 'conv_transpose2d', 'conv_transpose3d'] - if nonlinearity in linear_fns or nonlinearity == 'sigmoid': + linear_fns = [ + "linear", + "conv1d", + "conv2d", + "conv3d", + "conv_transpose1d", + "conv_transpose2d", + "conv_transpose3d", + ] + if nonlinearity in linear_fns or nonlinearity == "sigmoid": return 1 - elif nonlinearity == 'tanh': + elif nonlinearity == "tanh": return 5.0 / 3 - elif nonlinearity == 'relu': + elif nonlinearity == "relu": return math.sqrt(2.0) - elif nonlinearity == 'leaky_relu': + elif nonlinearity == "leaky_relu": if param is None: negative_slope = 0.01 - elif not isinstance(param, bool) and isinstance(param, int) or isinstance(param, float): + elif ( + not isinstance(param, bool) + and isinstance(param, int) + or isinstance(param, float) + ): # True/False are instances of int, hence check above negative_slope = param else: raise ValueError(f"negative_slope {param} not a valid number") - return math.sqrt(2.0 / (1 + negative_slope ** 2)) - elif nonlinearity == 'selu': - return 3.0 / 4 # Value found empirically (https://github.com/pytorch/pytorch/pull/50664) + return math.sqrt(2.0 / (1 + negative_slope**2)) + elif nonlinearity == "selu": + return ( + 3.0 / 4 + ) # Value found empirically (https://github.com/pytorch/pytorch/pull/50664) else: raise ValueError(f"Unsupported nonlinearity {nonlinearity}") @@ -175,13 +192,14 @@ def normal_( ) return _no_grad_normal_(tensor, mean, std, generator) + def trunc_normal_( tensor: Tensor, - mean: float = 0., - std: float = 1., - a: float = -2., - b: float = 2., - generator: _Optional[torch.Generator] = None + mean: float = 0.0, + std: float = 1.0, + a: float = -2.0, + b: float = 2.0, + generator: _Optional[torch.Generator] = None, ) -> Tensor: r"""Fill the input Tensor with values drawn from a truncated normal distribution. @@ -218,7 +236,9 @@ def constant_(tensor: Tensor, val: float) -> Tensor: >>> nn.init.constant_(w, 0.3) """ if torch.overrides.has_torch_function_variadic(tensor): - return torch.overrides.handle_torch_function(constant_, (tensor,), tensor=tensor, val=val) + return torch.overrides.handle_torch_function( + constant_, (tensor,), tensor=tensor, val=val + ) return _no_grad_fill_(tensor, val) @@ -232,7 +252,7 @@ def ones_(tensor: Tensor) -> Tensor: >>> w = torch.empty(3, 5) >>> nn.init.ones_(w) """ - return _no_grad_fill_(tensor, 1.) + return _no_grad_fill_(tensor, 1.0) def zeros_(tensor: Tensor) -> Tensor: @@ -292,7 +312,7 @@ def dirac_(tensor, groups=1): sizes = tensor.size() if sizes[0] % groups != 0: - raise ValueError('dim 0 must be divisible by groups') + raise ValueError("dim 0 must be divisible by groups") out_chans_per_grp = sizes[0] // groups min_dim = min(out_chans_per_grp, sizes[1]) @@ -305,18 +325,29 @@ def dirac_(tensor, groups=1): if dimensions == 3: # Temporal convolution tensor[g * out_chans_per_grp + d, d, tensor.size(2) // 2] = 1 elif dimensions == 4: # Spatial convolution - tensor[g * out_chans_per_grp + d, d, tensor.size(2) // 2, - tensor.size(3) // 2] = 1 + tensor[ + g * out_chans_per_grp + d, + d, + tensor.size(2) // 2, + tensor.size(3) // 2, + ] = 1 else: # Volumetric convolution - tensor[g * out_chans_per_grp + d, d, tensor.size(2) // 2, - tensor.size(3) // 2, tensor.size(4) // 2] = 1 + tensor[ + g * out_chans_per_grp + d, + d, + tensor.size(2) // 2, + tensor.size(3) // 2, + tensor.size(4) // 2, + ] = 1 return tensor def _calculate_fan_in_and_fan_out(tensor): dimensions = tensor.dim() if dimensions < 2: - raise ValueError("Fan in and fan out can not be computed for tensor with fewer than 2 dimensions") + raise ValueError( + "Fan in and fan out can not be computed for tensor with fewer than 2 dimensions" + ) num_input_fmaps = tensor.size(1) num_output_fmaps = tensor.size(0) @@ -333,7 +364,9 @@ def _calculate_fan_in_and_fan_out(tensor): def xavier_uniform_( - tensor: Tensor, gain: float = 1.0, generator: _Optional[torch.Generator] = None + tensor: Tensor, + gain: float = 1.0, + generator: _Optional[torch.Generator] = None, ) -> Tensor: r"""Fill the input `Tensor` with values using a Xavier uniform distribution. @@ -391,17 +424,17 @@ def xavier_normal_( fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor) std = gain * math.sqrt(2.0 / float(fan_in + fan_out)) - return _no_grad_normal_(tensor, 0., std, generator) + return _no_grad_normal_(tensor, 0.0, std, generator) def _calculate_correct_fan(tensor, mode): mode = mode.lower() - valid_modes = ['fan_in', 'fan_out'] + valid_modes = ["fan_in", "fan_out"] if mode not in valid_modes: raise ValueError(f"Mode {mode} not supported, please use one of {valid_modes}") fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor) - return fan_in if mode == 'fan_in' else fan_out + return fan_in if mode == "fan_in" else fan_out def kaiming_uniform_( @@ -447,7 +480,8 @@ def kaiming_uniform_( a=a, mode=mode, nonlinearity=nonlinearity, - generator=generator) + generator=generator, + ) if 0 in tensor.shape: warnings.warn("Initializing zero-element tensors is a no-op") @@ -607,7 +641,7 @@ def deprecated_init(*args, **kwargs): ) return meth(*args, **kwargs) - deprecated_init.__doc__ = fr""" + deprecated_init.__doc__ = rf""" {old_name}(...) .. warning:: diff --git a/torch/nn/parameter.py b/torch/nn/parameter.py index 43d4f1cf40008b..3eb1083905d058 100644 --- a/torch/nn/parameter.py +++ b/torch/nn/parameter.py @@ -1,13 +1,16 @@ +from collections import OrderedDict + import torch from torch._C import _disabled_torch_function_impl -from collections import OrderedDict + # Metaclass to combine _TensorMeta and the instance check override for Parameter. class _ParameterMeta(torch._C._TensorMeta): # Make `isinstance(t, Parameter)` return True for custom tensor instances that have the _is_param flag. def __instancecheck__(self, instance): return super().__instancecheck__(instance) or ( - isinstance(instance, torch.Tensor) and getattr(instance, '_is_param', False)) + isinstance(instance, torch.Tensor) and getattr(instance, "_is_param", False) + ) class Parameter(torch.Tensor, metaclass=_ParameterMeta): @@ -42,11 +45,13 @@ def __new__(cls, data=None, requires_grad=True): # Path for custom tensors: set a flag on the instance to indicate parameter-ness. t = data.detach().requires_grad_(requires_grad) if type(t) is not type(data): - raise RuntimeError(f"Creating a Parameter from an instance of type {type(data).__name__} " - "requires that detach() returns an instance of the same type, but return " - f"type {type(t).__name__} was found instead. To use the type as a " - "Parameter, please correct the detach() semantics defined by " - "its __torch_dispatch__() implementation.") + raise RuntimeError( + f"Creating a Parameter from an instance of type {type(data).__name__} " + "requires that detach() returns an instance of the same type, but return " + f"type {type(t).__name__} was found instead. To use the type as a " + "Parameter, please correct the detach() semantics defined by " + "its __torch_dispatch__() implementation." + ) t._is_param = True return t @@ -56,12 +61,14 @@ def __deepcopy__(self, memo): if id(self) in memo: return memo[id(self)] else: - result = type(self)(self.data.clone(memory_format=torch.preserve_format), self.requires_grad) + result = type(self)( + self.data.clone(memory_format=torch.preserve_format), self.requires_grad + ) memo[id(self)] = result return result def __repr__(self): - return 'Parameter containing:\n' + super().__repr__() + return "Parameter containing:\n" + super().__repr__() def __reduce_ex__(self, proto): state = torch._utils._get_obj_state(self) @@ -71,12 +78,12 @@ def __reduce_ex__(self, proto): if not state: return ( torch._utils._rebuild_parameter, - (self.data, self.requires_grad, hooks) + (self.data, self.requires_grad, hooks), ) return ( torch._utils._rebuild_parameter_with_state, - (self.data, self.requires_grad, hooks, state) + (self.data, self.requires_grad, hooks, state), ) __torch_function__ = _disabled_torch_function_impl @@ -127,41 +134,41 @@ def materialize(self, shape, device=None, dtype=None): @property def shape(self): raise RuntimeError( - 'Can\'t access the shape of an uninitialized parameter or buffer. ' - 'This error usually happens in `load_state_dict` when trying to load ' - 'an uninitialized parameter into an initialized one. ' - 'Call `forward` to initialize the parameters before accessing their attributes.') + "Can't access the shape of an uninitialized parameter or buffer. " + "This error usually happens in `load_state_dict` when trying to load " + "an uninitialized parameter into an initialized one. " + "Call `forward` to initialize the parameters before accessing their attributes." + ) def share_memory_(self): raise RuntimeError( - 'Can\'t share memory on an uninitialized parameter or buffer. ' - 'Call `forward` to initialize the parameters before calling ' - '`module.share_memory()`.') + "Can't share memory on an uninitialized parameter or buffer. " + "Call `forward` to initialize the parameters before calling " + "`module.share_memory()`." + ) def __repr__(self): - return f'<{self.__class__.__name__}>' + return f"<{self.__class__.__name__}>" def __reduce_ex__(self, proto): # See Note [Don't serialize hooks] - return ( - self.__class__, - (self.requires_grad,) - ) + return (self.__class__, (self.requires_grad,)) @classmethod def __torch_function__(cls, func, types, args=(), kwargs=None): # method-wrapper is to detect access to Tensor properties that are # wrapped in descriptors - if func in cls._allowed_methods or func.__class__.__name__ == 'method-wrapper': + if func in cls._allowed_methods or func.__class__.__name__ == "method-wrapper": if kwargs is None: kwargs = {} return super().__torch_function__(func, types, args, kwargs) raise ValueError( - f'Attempted to use an uninitialized parameter in {func}. ' - 'This error happens when you are using a `LazyModule` or ' - f'explicitly manipulating `torch.nn.parameter.{cls.__name__}` ' - 'objects. When using LazyModules Call `forward` with a dummy batch ' - 'to initialize the parameters before calling torch functions') + f"Attempted to use an uninitialized parameter in {func}. " + "This error happens when you are using a `LazyModule` or " + f"explicitly manipulating `torch.nn.parameter.{cls.__name__}` " + "objects. When using LazyModules Call `forward` with a dummy batch " + "to initialize the parameters before calling torch functions" + ) def is_lazy(param): @@ -187,7 +194,7 @@ class UninitializedParameter(UninitializedTensorMixin, Parameter): cls_to_become = Parameter def __new__(cls, requires_grad=True, device=None, dtype=None) -> None: - factory_kwargs = {'device': device, 'dtype': dtype} + factory_kwargs = {"device": device, "dtype": dtype} data = torch.empty(0, **factory_kwargs) return torch.Tensor._make_subclass(cls, data, requires_grad) @@ -199,6 +206,7 @@ def __deepcopy__(self, memo): memo[id(self)] = result return result + class UninitializedBuffer(UninitializedTensorMixin, torch.Tensor): r"""A buffer that is not initialized. @@ -218,6 +226,6 @@ class UninitializedBuffer(UninitializedTensorMixin, torch.Tensor): cls_to_become = torch.Tensor def __new__(cls, requires_grad=False, device=None, dtype=None) -> None: - factory_kwargs = {'device': device, 'dtype': dtype} + factory_kwargs = {"device": device, "dtype": dtype} data = torch.empty(0, **factory_kwargs) return torch.Tensor._make_subclass(cls, data, requires_grad) diff --git a/torch/utils/data/dataloader.py b/torch/utils/data/dataloader.py index 9ad0db898a0479..95e4240cefd7ed 100644 --- a/torch/utils/data/dataloader.py +++ b/torch/utils/data/dataloader.py @@ -19,7 +19,6 @@ import multiprocessing as python_multiprocessing import torch import torch.distributed as dist -import torch.multiprocessing as multiprocessing import torch.utils.data.graph_settings from torch._utils import ExceptionWrapper @@ -396,13 +395,13 @@ def multiprocessing_context(self, multiprocessing_context): if multiprocessing_context is not None: if self.num_workers > 0: if isinstance(multiprocessing_context, str): - valid_start_methods = multiprocessing.get_all_start_methods() + valid_start_methods = torch.multiprocessing.get_all_start_methods() if multiprocessing_context not in valid_start_methods: raise ValueError( 'multiprocessing_context option ' f'should specify a valid start method in {valid_start_methods!r}, but got ' f'multiprocessing_context={multiprocessing_context!r}') - multiprocessing_context = multiprocessing.get_context(multiprocessing_context) + multiprocessing_context = torch.multiprocessing.get_context(multiprocessing_context) if not isinstance(multiprocessing_context, python_multiprocessing.context.BaseContext): raise TypeError('multiprocessing_context option should be a valid context ' @@ -995,7 +994,7 @@ def __init__(self, loader): assert self._prefetch_factor > 0 if loader.multiprocessing_context is None: - multiprocessing_context = multiprocessing + multiprocessing_context = torch.multiprocessing else: multiprocessing_context = loader.multiprocessing_context diff --git a/torch/utils/data/datapipes/_hook_iterator.py b/torch/utils/data/datapipes/_hook_iterator.py index 00b44cbede6163..b45bd8b00805b6 100644 --- a/torch/utils/data/datapipes/_hook_iterator.py +++ b/torch/utils/data/datapipes/_hook_iterator.py @@ -3,7 +3,7 @@ import functools from enum import Enum -import torch.autograd +import torch class _SnapshotState(Enum): From 316b7296771c87637231f9d90e2658aa1d629859 Mon Sep 17 00:00:00 2001 From: Jiashen Cao Date: Mon, 17 Jun 2024 16:42:43 +0000 Subject: [PATCH 085/171] [Fix] TS converter constant to tensor (#128442) #### Issue Tensor constant was previously lifted directly as an input in the fx graph, which results errors for multiple test cases with tensor constant. This PR introduces a fix to convert tensor constant to a `GetAttr` in the fx graph. This PR also introduces other fixes to maintain a valid `state_dict` for exported program when there are tensor constants. In short, after tensor constants are converted as `GetAttr`, they are treated as buffers during retracing. The fix will convert those back from buffer to constant. #### Test Plan Add new test cases that generate tensor constants * `pytest test/export/test_converter.py -s -k test_implicit_constant_to_tensor_handling` Pull Request resolved: https://github.com/pytorch/pytorch/pull/128442 Approved by: https://github.com/angelayi --- test/export/test_converter.py | 228 +++++++++++++++++++++++----------- torch/_export/converter.py | 64 +++++++--- torch/jit/_trace.py | 5 +- 3 files changed, 206 insertions(+), 91 deletions(-) diff --git a/test/export/test_converter.py b/test/export/test_converter.py index 8d26dc7a22ad77..300f70223a26b7 100644 --- a/test/export/test_converter.py +++ b/test/export/test_converter.py @@ -1,7 +1,7 @@ # Owner(s): ["oncall: export"] import unittest -from typing import Dict, Tuple +from typing import Dict, List, Tuple, Union import torch @@ -16,28 +16,45 @@ class TestConverter(TestCase): - def _check_equal_ts_ep_converter(self, mod, inp) -> ExportedProgram: - ts_model = torch.jit.script(mod) - ep = TS2EPConverter(ts_model, inp).convert() - ep_out, _ = pytree.tree_flatten(ep.module()(*inp)) - orig_out, _ = pytree.tree_flatten(mod(*inp)) - - # Check module. - if isinstance(mod, torch.nn.Module): - self.assertEqual( - ep.module().state_dict().keys(), - mod.state_dict().keys(), - ) - - # Check results. - self.assertEqual(len(ep_out), len(orig_out)) - for ep_t, orig_t in zip(ep_out, orig_out): - if isinstance(ep_t, torch.Tensor): - self.assertEqual(ep_t.shape, orig_t.shape) - self.assertTrue(torch.allclose(ep_t, orig_t)) + def _check_equal_ts_ep_converter( + self, mod, inp, option: Union[List[str]] = None + ) -> ExportedProgram: + # By default, it tests both jit.trace and jit.script. + if option is None: + option = ["trace", "script"] + + model_list = [] + for opt in option: + if opt == "script": + ts_model = torch.jit.script(mod) else: - self.assertEqual(ep_t, orig_t) - return ep + ts_model = torch.jit.trace(mod, inp) + model_list.append(ts_model) + + ep_list = [] + for ts_model in model_list: + ep = TS2EPConverter(ts_model, inp).convert() + ep_list.append(ep) + ep_out, _ = pytree.tree_flatten(ep.module()(*inp)) + orig_out, _ = pytree.tree_flatten(ts_model(*inp)) + + # Check module. + if isinstance(mod, torch.nn.Module): + self.assertEqual( + ep.state_dict.keys(), + ts_model.state_dict().keys(), + ) + + # Check results. + self.assertEqual(len(ep_out), len(orig_out)) + for ep_t, orig_t in zip(ep_out, orig_out): + if isinstance(ep_t, torch.Tensor) and isinstance(orig_t, torch.Tensor): + self.assertEqual(ep_t.shape, orig_t.shape) + self.assertTrue(torch.allclose(ep_t, orig_t)) + else: + self.assertEqual(type(ep_t), type(orig_t)) + self.assertEqual(ep_t, orig_t) + return ep_list def test_ts2ep_converter_basic(self): class MSingle(torch.nn.Module): @@ -78,9 +95,10 @@ def forward(self, x: torch.Tensor, y: torch.Tensor): inp = (torch.tensor(4), torch.tensor(4)) - self._check_equal_ts_ep_converter(MOutputList(), inp) + # Traced function must use immutable structure as output. + self._check_equal_ts_ep_converter(MOutputList(), inp, ["script"]) self._check_equal_ts_ep_converter(MOutputTuple(), inp) - self._check_equal_ts_ep_converter(MOutputDict(), inp) + self._check_equal_ts_ep_converter(MOutputDict(), inp, ["script"]) def test_aten_dim(self): class Module(torch.nn.Module): @@ -171,12 +189,13 @@ def forward(self, x: torch.Tensor, y: torch.Tensor): return y + y inp = (torch.tensor(True), torch.tensor(4)) - ep = self._check_equal_ts_ep_converter(M(), inp) + ep_list = self._check_equal_ts_ep_converter(M(), inp) - torch.testing.assert_close( - ep.module()(torch.tensor(False), torch.tensor(4)), - M()(torch.tensor(False), torch.tensor(4)), - ) + for ep in ep_list[1:]: + torch.testing.assert_close( + ep.module()(torch.tensor(False), torch.tensor(4)), + M()(torch.tensor(False), torch.tensor(4)), + ) def test_convert_if_multiple_out(self): class M(torch.nn.Module): @@ -197,12 +216,13 @@ def forward(self, x: torch.Tensor, y: torch.Tensor): return res[0] + res[1] inp = (torch.tensor(True), torch.tensor(4)) - ep = self._check_equal_ts_ep_converter(M(), inp) + ep_list = self._check_equal_ts_ep_converter(M(), inp) - torch.testing.assert_close( - ep.module()(torch.tensor(False), torch.tensor(4)), - M()(torch.tensor(False), torch.tensor(4)), - ) + for ep in ep_list[1:]: + torch.testing.assert_close( + ep.module()(torch.tensor(False), torch.tensor(4)), + M()(torch.tensor(False), torch.tensor(4)), + ) def test_profiler__record_function(self): class Module(torch.nn.Module): @@ -231,8 +251,9 @@ def forward( z = x + 1 return x is y, z + # Traced function must return output that has tensors. inp = (torch.randn(10, 10), torch.rand(10, 10)) - self._check_equal_ts_ep_converter(Module(), inp) + self._check_equal_ts_ep_converter(Module(), inp, ["script"]) def test_aten___isnot__(self): class Module(torch.nn.Module): @@ -242,8 +263,9 @@ def forward( z = x + 1 return x is not y, z + # Traced function must return output that has tensors. inp = (torch.randn(10, 10), torch.rand(10, 10)) - self._check_equal_ts_ep_converter(Module(), inp) + self._check_equal_ts_ep_converter(Module(), inp, ["script"]) def test_aten___not__(self): class Module(torch.nn.Module): @@ -253,8 +275,9 @@ def forward( z = x + 1 return not (x is not y), z + # Traced function must return output that has tensors. inp = (torch.randn(10, 10), torch.rand(10, 10)) - self._check_equal_ts_ep_converter(Module(), inp) + self._check_equal_ts_ep_converter(Module(), inp, ["script"]) def test_ts2ep_converter_unpack(self): class MUnpackList(torch.nn.Module): @@ -302,9 +325,9 @@ def forward(self, x: torch.Tensor): inp = (torch.ones(3),) orig_m = NestedM(3) - ep = self._check_equal_ts_ep_converter(orig_m, inp) + self._check_equal_ts_ep_converter(orig_m, inp) orig_m = SuperNestedM(3) - ep = self._check_equal_ts_ep_converter(orig_m, inp) + self._check_equal_ts_ep_converter(orig_m, inp) def test_convert_nn_module_with_nested_buffer(self): class M(torch.nn.Module): @@ -335,9 +358,9 @@ def forward(self, x: torch.Tensor): inp = (torch.ones(1),) orig_m = NestedM() - ep = self._check_equal_ts_ep_converter(orig_m, inp) + self._check_equal_ts_ep_converter(orig_m, inp) orig_m = SuperNestedM() - ep = self._check_equal_ts_ep_converter(orig_m, inp) + self._check_equal_ts_ep_converter(orig_m, inp) def test_convert_nn_module_with_nested_if_and_buffer(self): class M(torch.nn.Module): @@ -379,14 +402,16 @@ def forward(self, x: torch.Tensor): # Super nested module testing. inp = (torch.ones(1),) orig_m = SuperNestedM() - ep = self._check_equal_ts_ep_converter(orig_m, inp) + # TODO: fix trace: state_dict is not equal. + ep_list = self._check_equal_ts_ep_converter(orig_m, inp, ["script"]) t = inp[0] t -= 1 - torch.testing.assert_close( - ep.module()(*inp), - orig_m(*inp), - ) + for ep in ep_list: + torch.testing.assert_close( + ep.module()(*inp), + orig_m(*inp), + ) def test_convert_nn_module_with_nested_if_and_param(self): class M(torch.nn.Module): @@ -443,38 +468,43 @@ def forward(self, x: torch.Tensor): # Basic module testing. inp = (torch.ones(3),) orig_m = M(3) - ep = self._check_equal_ts_ep_converter(orig_m, inp) + ep_list = self._check_equal_ts_ep_converter(orig_m, inp) t = inp[0] t -= 0.8 - torch.testing.assert_close( - ep.module()(*inp), - orig_m(*inp), - ) + for ep in ep_list[1:]: + torch.testing.assert_close( + ep.module()(*inp), + orig_m(*inp), + ) # Nested module testing. inp = (torch.ones(3),) orig_m = NestedM(3) - ep = self._check_equal_ts_ep_converter(orig_m, inp) + # TODO: fix trace: state_dict is not equal. + ep_list = self._check_equal_ts_ep_converter(orig_m, inp, ["script"]) t = inp[0] t -= 0.8 - torch.testing.assert_close( - ep.module()(*inp), - orig_m(*inp), - ) + for ep in ep_list: + torch.testing.assert_close( + ep.module()(*inp), + orig_m(*inp), + ) # Super nested module testing. inp = (torch.ones(3),) orig_m = SuperNestedM1(3) - ep = self._check_equal_ts_ep_converter(orig_m, inp) + # TODO: fix trace: state_dict is not equal. + ep_list = self._check_equal_ts_ep_converter(orig_m, inp, ["script"]) t = inp[0] t -= 0.8 - torch.testing.assert_close( - ep.module()(*inp), - orig_m(*inp), - ) + for ep in ep_list: + torch.testing.assert_close( + ep.module()(*inp), + orig_m(*inp), + ) # # Super nested module testing. # inp = (torch.ones(3),) @@ -501,14 +531,16 @@ class MTensorIn(torch.nn.Module): def forward(self, x: torch.Tensor, x_dict: Dict[torch.Tensor, str]): return x in x_dict + # Traced function must return output that has tensors. inp = (torch.tensor(4),) - self._check_equal_ts_ep_converter(MIn(), inp) - self._check_equal_ts_ep_converter(MNotIn(), inp) + self._check_equal_ts_ep_converter(MIn(), inp, ["script"]) + self._check_equal_ts_ep_converter(MNotIn(), inp, ["script"]) + # TODO: update test to use reference for in. inp = (torch.tensor(4), {torch.tensor(4): "foo"}) - self._check_equal_ts_ep_converter(MTensorIn(), inp) + self._check_equal_ts_ep_converter(MTensorIn(), inp, ["script"]) inp = (torch.tensor(1), {torch.tensor(4): "foo"}) - self._check_equal_ts_ep_converter(MTensorIn(), inp) + self._check_equal_ts_ep_converter(MTensorIn(), inp, ["script"]) def test_ts2ep_converter_custom_op(self): with torch.library._scoped_library("mylib", "FRAGMENT") as lib: @@ -562,14 +594,70 @@ def func2(x, y): ) self._check_equal_ts_ep_converter(func1, inp) - ep = self._check_equal_ts_ep_converter(func2, inp) + ep_list = self._check_equal_ts_ep_converter(func2, inp) t = inp[0] t -= 1 - torch.testing.assert_close( - ep.module()(*inp), - func2(*inp), - ) + for ep in ep_list[1:]: + torch.testing.assert_close( + ep.module()(*inp), + func2(*inp), + ) + + def test_implicit_constant_to_tensor_handling(self): + def func1(x): + return x + 2 + + def func2(x, y): + return x * y / (x - 2 * y) + y + + def func3(x): + return x + torch.tensor([3]) + + def func4(): + val = torch.tensor(float("inf")) + return torch.full((10, 10), val) + + def func5(): + x = -1 + return x * torch.ones(1, dtype=torch.float), torch.zeros( + 1, dtype=torch.float + ) + + def func6(x): + return x.numel() + + class M1(torch.nn.Module): + def __init__(self, value): + super().__init__() + self.x = torch.tensor(value) + + def forward(self): + return self.x.clone() + + class M2(torch.nn.Module): + def forward(self, x): + return torch.tensor(4) + x + + inp = (torch.randn([2, 2]),) + self._check_equal_ts_ep_converter(func1, inp) + inp = (torch.randn([2, 2]), torch.randn([2, 2])) + self._check_equal_ts_ep_converter(func2, inp) + + inp = (torch.randn([2, 2]),) + self._check_equal_ts_ep_converter(func3, inp) + + self._check_equal_ts_ep_converter(func4, ()) + self._check_equal_ts_ep_converter(M1(5), ()) + + inp = (torch.randn(2),) + self._check_equal_ts_ep_converter(M2(), inp) + + self._check_equal_ts_ep_converter(func5, ()) + # TODO: NumToTensor now returns a tensor based on dtype of input + # tensor, but it should always be Long. + # inp = (torch.randn([2, 3, 4]),) + # self._check_equal_ts_ep_converter(func6, inp) if __name__ == "__main__": diff --git a/torch/_export/converter.py b/torch/_export/converter.py index 568685f9759735..2c54db38dee8b8 100644 --- a/torch/_export/converter.py +++ b/torch/_export/converter.py @@ -18,8 +18,6 @@ from torch.fx import subgraph_rewriter from torch.onnx.utils import _create_jit_graph -from torchgen.model import FunctionSchema - def inplace_optimize_sym_size_div(gm: torch.fx.GraphModule): def pattern(im, dim, scale): @@ -189,9 +187,9 @@ def _map_blocks_to_lifted_attrs(entry): def get_op_overload(node: torch._C.Node): schema_str = node.schema() - schema = FunctionSchema.parse(schema_str) - ns, op_name = str(schema.name.name).split("::") - override = schema.name.overload_name + schema = torch._C.parse_schema(schema_str) + ns, op_name = str(schema.name).split("::") + override = schema.overload_name try: op_overload_mod = getattr(torch.ops, ns) @@ -292,7 +290,12 @@ def convert(self) -> torch.fx.GraphModule: # Pass parameter and buffer to the root for lookup. gm = torch.fx.GraphModule( - {**self.subgraphs, **self.name_to_param_map, **self.name_to_buffer_map}, + { + **self.subgraphs, + **self.name_to_param_map, + **self.name_to_buffer_map, + **self.tensor_constants, + }, self.fx_graph, ) @@ -342,6 +345,20 @@ def convert_graph_inputs(self): self.name_to_node[name] = fx_node + def convert_aten_tensor(self, node: torch._C.Node): + """aten::tensor creates a constant tensor ad-hoc --> GetAttr""" + args, kwargs = self.get_args_kwargs(node, torch.ops.aten.tensor.default._schema) + for k in kwargs: + if k == "requires_grad": + kwargs[k] = bool(kwargs[k]) # 0 -> False, 1 -> True + tensor = torch.tensor(*args, **kwargs) + + output_name = node.output().debugName() + alias_name = f"lifted_tensor_{output_name}" + fx_node = self.fx_graph.get_attr(alias_name) + self.name_to_node[output_name] = fx_node + self.tensor_constants[alias_name] = tensor + def convert_prim_Constant(self, node: torch._C.Node): name = node.output().debugName() @@ -355,20 +372,11 @@ def convert_prim_Constant(self, node: torch._C.Node): elif constant_kind == "s": value = node.s("value") elif constant_kind == "t": - # lift tensor constant as a placeholder - placeholder_name = f"constant_{name}" - fx_node = self.fx_graph.placeholder(placeholder_name) - self.name_to_node[name] = fx_node - self.tensor_constants[placeholder_name] = node.t("value") - - self.input_specs.append( - InputSpec( - InputKind.CONSTANT_TENSOR, - arg=TensorArgument(name=placeholder_name), - target=placeholder_name, - ) + alias_name = ( + f"lifted_tensor_{name}" # Follow naming convention from EP tracing. ) - + fx_node = self.fx_graph.get_attr(alias_name) + self.tensor_constants[alias_name] = node.t("value") value = fx_node elif constant_kind == "ival": value = node.ival("value") @@ -730,7 +738,9 @@ def convert(self) -> ExportedProgram: ep = self.retrace_as_exported_program(gm, graph_converter.tensor_constants) return ep - def retrace_as_exported_program(self, gm: torch.fx.GraphModule, tensor_constants): + def retrace_as_exported_program( + self, gm: torch.fx.GraphModule, tensor_constants: Dict[str, torch.Tensor] + ): # TODO: adjust input orders to match GraphSignature convention ep = torch.export._trace._export( gm, @@ -738,4 +748,18 @@ def retrace_as_exported_program(self, gm: torch.fx.GraphModule, tensor_constants strict=False, pre_dispatch=True, ) + + # Post-processing to make sure the ExportedProgram states are correct. + # Because during conversion, we set tensor constants as GetAttr, + # retracing cannot recognize them as tensor constants but instead + # treat them as buffers. We need to set them again here. + ep._constants = tensor_constants + for k in tensor_constants: + ep.state_dict.pop(k, None) + for spec in ep.graph_signature.input_specs: + # Mark as constant tensors for erroneously traced buffers. + if spec.kind == InputKind.BUFFER and spec.target in tensor_constants: + spec.kind = InputKind.CONSTANT_TENSOR + ep.verifier().check(ep) + return ep diff --git a/torch/jit/_trace.py b/torch/jit/_trace.py index 7db8560242878d..83cda96030e737 100644 --- a/torch/jit/_trace.py +++ b/torch/jit/_trace.py @@ -655,7 +655,10 @@ def analyze_ts_result_with_export_result(export, trace): if type(orig) != type(loaded): return False - if isinstance(orig, torch.Tensor): + if isinstance(orig, torch._subclasses.FakeTensor): + # Skip for FakeTensor. + return True + elif isinstance(orig, torch.Tensor): if orig.dtype != loaded.dtype: return False if not torch.allclose(orig, loaded): From 73b78d1cbefda55ec1723904b68660530d7d1495 Mon Sep 17 00:00:00 2001 From: Andrew Hoblitzell Date: Mon, 17 Jun 2024 16:44:15 +0000 Subject: [PATCH 086/171] Document the torch.nn.parallel.scatter_gather.gather function (#128566) Fixes #127899 ### Description Add docstring to `torch/nn/parallel/scatter_gather.py:gather` function Pull Request resolved: https://github.com/pytorch/pytorch/pull/128566 Approved by: https://github.com/kwen2501 --- torch/nn/parallel/scatter_gather.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/torch/nn/parallel/scatter_gather.py b/torch/nn/parallel/scatter_gather.py index 690f81f227855b..950b5a75b7480c 100644 --- a/torch/nn/parallel/scatter_gather.py +++ b/torch/nn/parallel/scatter_gather.py @@ -102,7 +102,18 @@ def scatter_kwargs( def gather(outputs: Any, target_device: Union[int, torch.device], dim: int = 0) -> Any: r"""Gather tensors from different GPUs on a specified device. - Use 'cpu' for CPU to avoid a deprecation warning. + This function is useful for gathering the results of a distributed computation. + It takes a sequence of objects, one for each GPU, and returns a single object + on the specified device. + + Args: + outputs (Any): A sequence of objects (potentially tensors) to gather. + target_device (Union[int, torch.device]): The device to gather the tensors to. + Use 'cpu' for CPU to avoid a deprecation warning. + dim (int, optional): The dimension along which to gather. Default: 0. + + Returns: + Any: A gathered object (potentially tensor) on the specified device. """ def gather_map(outputs): From fc2913fb808dd67667c4c57d01983a4dccec0f66 Mon Sep 17 00:00:00 2001 From: drisspg Date: Mon, 17 Jun 2024 16:48:00 +0000 Subject: [PATCH 087/171] Remove amax return from _scaled_mm (#128683) # Summary The primary reason for the change was lack of current use case and the need to work around an two Inductor issue. - Tensor arguments as kwarg only - multiple outputs from triton templates If the need for the amax return type arises we can consider either adding it, more likely creating a separate op. In principle PyTorch is moving away from ops that bundle lots of functionality into "mega ops". We instead rely upon the compiler to generate appropriate fused kernels. ### Changes: - This removes the amax return type from scaled_mm. We have found that the common use case is to return in "high-precision" ( a type with more precision than fp8). This is only relevant when returning in low-precision. - We currently still allow for fp8 returns and scaled result. Perhaps we should also ban this as well... New signature: ```Python def meta_scaled_mm( self: torch.Tensor, mat2: torch.Tensor, scale_a: torch.Tensor, scale_b: torch.Tensor, bias: Optional[torch.Tensor] = None, scale_result: Optional[torch.Tensor] = None, out_dtype: Optional[torch.dtype] = None, use_fast_accum: bool = False, ) -> torch.Tensor: ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/128683 Approved by: https://github.com/vkuzo --- aten/src/ATen/native/cuda/Blas.cpp | 55 +++++------ aten/src/ATen/native/native_functions.yaml | 4 +- .../check_forward_backward_compatibility.py | 2 + test/test_matmul_cuda.py | 93 ++++++++++--------- torch/_meta_registrations.py | 10 +- .../aoti_torch/generated/c_shim_cuda.h | 2 +- .../csrc/inductor/aoti_torch/shim_common.cpp | 9 +- 7 files changed, 89 insertions(+), 86 deletions(-) diff --git a/aten/src/ATen/native/cuda/Blas.cpp b/aten/src/ATen/native/cuda/Blas.cpp index f7997fe7271231..728f210b66ed01 100644 --- a/aten/src/ATen/native/cuda/Blas.cpp +++ b/aten/src/ATen/native/cuda/Blas.cpp @@ -885,15 +885,15 @@ static bool _scaled_mm_allowed_device() { // - `out`: a reference to the output tensor // - `amax`: a reference to the amax tensor of the output, only needed if the output is a float8 type and will be updated inplace -std::tuple +Tensor& _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2, + const Tensor& scale_a, + const Tensor& scale_b, const std::optional& bias, - std::optional out_dtype, - const std::optional& scale_a, - const std::optional& scale_b, const std::optional& scale_result, + std::optional out_dtype, bool use_fast_accum, - Tensor& out, Tensor& amax) { + Tensor& out) { // Check sizes bool allowed_device = _scaled_mm_allowed_device(); TORCH_CHECK(allowed_device, "torch._scaled_mm is only supported on CUDA devices with compute capability >= 9.0 or 8.9, or ROCm MI300+"); @@ -902,9 +902,9 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2, TORCH_CHECK( mat1.sizes()[1] == mat2.sizes()[0], "mat1 and mat2 shapes cannot be multiplied (", mat1.sizes()[0], "x", mat1.sizes()[1], " and ", mat2.sizes()[0], "x", mat2.sizes()[1], ")"); - TORCH_CHECK(!scale_a || (scale_a->numel() == 1 && scale_a->scalar_type() == kFloat), + TORCH_CHECK((scale_a.numel() == 1 && scale_a.scalar_type() == kFloat), "scale_a must be float scalar"); - TORCH_CHECK(!scale_b || (scale_b->numel() == 1 && scale_b->scalar_type() == kFloat), + TORCH_CHECK((scale_b.numel() == 1 && scale_b.scalar_type() == kFloat), "scale_b must be a float scalar"); TORCH_CHECK(!scale_result || (scale_result->numel() == 1 && scale_result->scalar_type() == kFloat), "scale_result must be a float scalar"); @@ -922,7 +922,6 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2, mat2.sizes()[1], " must be divisible by 16"); // Check types TORCH_CHECK(!out_dtype || *out_dtype == out.scalar_type(), "out_dtype must match output matrix type"); - TORCH_CHECK(amax.scalar_type() == kFloat, "amax must be a float scalar"); TORCH_CHECK(isFloat8Type(mat1.scalar_type()), "Expected mat1 to be Float8 matrix got ", mat1.scalar_type()); TORCH_CHECK(isFloat8Type(mat2.scalar_type()), "Expected mat2 to be Float8 matrix got ", mat2.scalar_type()); // Type restrictions imposed by CuBLASLt as of CUDA-12.1 @@ -940,23 +939,25 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2, } { auto bias_ = bias.value_or(Tensor()); - auto scale_a_ = scale_a.value_or(Tensor()); - auto scale_b_ = scale_b.value_or(Tensor()); auto scale_result_ = scale_result.value_or(Tensor()); - TensorArg targs[]{{out, "out", 0}, {amax, "amax", 1}, {mat1, "mat1", 2}, {mat2, "mat2", 3}, - {bias_, "bias", 4}, {scale_a_, "scale_a", 5}, {scale_b_, "scale_b", 6}, - {scale_result_, "scale_result", 7}}; + + TensorArg targs[]{{out, "out", 0}, {mat1, "mat1", 1}, {mat2, "mat2", 2}, + {bias_, "bias", 3}, {scale_a, "scale_a", 4}, {scale_b, "scale_b", 5}, + {scale_result_, "scale_result", 6}}; checkAllSameGPU(__func__, targs); } IntArrayRef mat1_sizes = mat1.sizes(); IntArrayRef mat2_sizes = mat2.sizes(); at::native::resize_output(out, {mat1_sizes[0], mat2_sizes[1]}); - at::native::resize_output(amax, {}); cublasCommonArgs args(mat1, mat2, out); const auto out_dtype_ = args.result->scalar_type(); TORCH_CHECK(args.transa == 't' && args.transb == 'n', "Only multiplication of row-major and column-major matrices is supported by cuBLASLt"); + + // Some scaled_gemms require an amax to populate lets create one here + Tensor amax = at::empty({0}, mat1.options().dtype(ScalarType::Float)); + #ifdef USE_ROCM auto tuning_ctx = at::cuda::tunable::getTuningContext(); if (tuning_ctx->IsTunableOpEnabled()) { @@ -999,11 +1000,11 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2, params.n = args.n; params.k = args.k; params.a = args.mata->data_ptr(); - params.a_scale_ptr = scale_a ? scale_a->data_ptr() : nullptr; + params.a_scale_ptr = scale_a.data_ptr(); params.lda = args.lda; params.a_dtype = args.mata->scalar_type(); params.b = args.matb->data_ptr(); - params.b_scale_ptr = scale_b ? scale_b->data_ptr() : nullptr; + params.b_scale_ptr = scale_b.data_ptr(); params.ldb = args.ldb; params.b_dtype = args.matb->scalar_type(); params.bias_ptr = bias ? bias->data_ptr(): nullptr; @@ -1048,11 +1049,11 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2, args.n, args.k, args.mata->data_ptr(), - scale_a ? scale_a->data_ptr() : nullptr, + scale_a.data_ptr(), args.lda, args.mata->scalar_type(), args.matb->data_ptr(), - scale_b ? scale_b->data_ptr() : nullptr, + scale_b.data_ptr(), args.ldb, args.matb->scalar_type(), bias ? bias->data_ptr(): nullptr, @@ -1069,26 +1070,20 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2, use_fast_accum); } -#if defined(USE_ROCM) && ROCM_VERSION >= 60000 && ROCM_VERSION < 60200 - // ROCm's hipBLASLt does not support amax before 6.2, so calculate separately - amax = at::max(at::abs(out.to(kFloat))); -#endif - - return {out, amax}; + return out; } -std::tuple +Tensor _scaled_mm_cuda(const Tensor& mat_a, const Tensor& mat_b, + const Tensor& scale_a, + const Tensor& scale_b, const std::optional& bias, - std::optional out_dtype, - const std::optional& scale_a, - const std::optional& scale_b, const std::optional& scale_result, + std::optional out_dtype, bool use_fast_accum) { const auto out_dtype_ = out_dtype.value_or(mat_a.scalar_type()); Tensor out = at::empty({0}, mat_a.options().dtype(out_dtype_)); - Tensor amax = at::empty({0}, mat_a.options().dtype(ScalarType::Float)); - return _scaled_mm_out_cuda(mat_a, mat_b, bias, out_dtype, scale_a, scale_b, scale_result, use_fast_accum, out, amax); + return _scaled_mm_out_cuda(mat_a, mat_b, scale_a, scale_b, bias, scale_result, out_dtype, use_fast_accum, out); } } // namespace at::native diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 0715714a4d2d71..7474e0bc55d8b8 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -6994,12 +6994,12 @@ structured_delegate: _addmm_activation.out variants: function, method -- func: _scaled_mm(Tensor self, Tensor mat2, *, Tensor? bias=None, ScalarType? out_dtype=None, Tensor? scale_a=None, Tensor? scale_b=None, Tensor? scale_result=None, bool use_fast_accum=False) -> (Tensor, Tensor) +- func: _scaled_mm(Tensor self, Tensor mat2, Tensor scale_a, Tensor scale_b, Tensor? bias=None, Tensor? scale_result=None, ScalarType? out_dtype=None, bool use_fast_accum=False) -> Tensor variants: function dispatch: CUDA: _scaled_mm_cuda -- func: _scaled_mm.out(Tensor self, Tensor mat2, *, Tensor? bias=None, ScalarType? out_dtype=None, Tensor? scale_a=None, Tensor? scale_b=None, Tensor? scale_result=None, bool use_fast_accum=False, Tensor(a!) out, Tensor(b!) out_amax) -> (Tensor(a!), Tensor(b!)) +- func: _scaled_mm.out(Tensor self, Tensor mat2, Tensor scale_a, Tensor scale_b, Tensor? bias=None, Tensor? scale_result=None, ScalarType? out_dtype=None, bool use_fast_accum=False, *, Tensor(a!) out) -> Tensor(a!) variants: function dispatch: CUDA: _scaled_mm_out_cuda diff --git a/test/forward_backward_compatibility/check_forward_backward_compatibility.py b/test/forward_backward_compatibility/check_forward_backward_compatibility.py index 88927e8bf7ce50..189155f69f86ff 100644 --- a/test/forward_backward_compatibility/check_forward_backward_compatibility.py +++ b/test/forward_backward_compatibility/check_forward_backward_compatibility.py @@ -141,6 +141,8 @@ ("onednn::qconv2d_pointwise", datetime.date(2024, 12, 31)), ("onednn::qconv3d_pointwise", datetime.date(2024, 12, 31)), ("onednn::qconv2d_pointwise.binary", datetime.date(2024, 12, 31)), + ("aten::_scaled_mm.out", datetime.date(2024, 12, 31)), + ("aten::_scaled_mm", datetime.date(2024, 12, 31)), # BC-breaking change in can_cast signature: 'from' -> 'from_' ("aten::can_cast", datetime.date(2024, 5, 31)), ] diff --git a/test/test_matmul_cuda.py b/test/test_matmul_cuda.py index a5c583580848d9..83e0f9e80a6856 100644 --- a/test/test_matmul_cuda.py +++ b/test/test_matmul_cuda.py @@ -3,7 +3,7 @@ import unittest from itertools import product from functools import partial -from typing import Optional, Tuple +from typing import Optional import torch @@ -260,13 +260,13 @@ def tensor_to_scale(x: torch.Tensor, float8_dtype: torch.dtype): amax = torch.max(torch.abs(x)) return amax_to_scale(amax, float8_dtype, x.dtype) -def mm_float8_emulated(x, x_scale, y, y_scale, out_dtype): +def mm_float8_emulated(x, x_scale, y, y_scale, out_dtype) -> torch.Tensor: # naive implementation: dq -> op -> q x_fp32 = x.to(torch.float) / x_scale y_fp32 = y.to(torch.float) / y_scale out_fp32 = torch.mm(x_fp32, y_fp32) - return out_fp32.to(out_dtype), torch.max(torch.abs(out_fp32)) + return out_fp32.to(out_dtype) def addmm_float8_unwrapped( a_data: torch.Tensor, @@ -276,31 +276,31 @@ def addmm_float8_unwrapped( output_dtype: torch.dtype, output_scale: Optional[torch.Tensor], bias: Optional[torch.Tensor] = None, -) -> Tuple[torch.Tensor, torch.Tensor]: +) -> torch.Tensor: a_inverse_scale = a_scale.reciprocal() b_inverse_scale = b_scale.reciprocal() if output_dtype == torch.float32 and bias is not None: # Bias is not supported by _scaled_mm when output is fp32 - output, output_amax = torch._scaled_mm( + output = torch._scaled_mm( a_data, b_data, - out_dtype=output_dtype, scale_a=a_inverse_scale, scale_b=b_inverse_scale, scale_result=output_scale, + out_dtype=output_dtype, ) output += bias - return output, output_amax - output, output_amax = torch._scaled_mm( + return output + output = torch._scaled_mm( a_data, b_data, bias=bias, - out_dtype=output_dtype, scale_a=a_inverse_scale, scale_b=b_inverse_scale, scale_result=output_scale, + out_dtype=output_dtype, ) - return output, output_amax + return output def mm_float8( a: torch.Tensor, @@ -309,7 +309,7 @@ def mm_float8( b_scale: torch.Tensor, output_dtype: torch.dtype, # output dtype output_scale: Optional[torch.Tensor] = None, # output scale, precomputed -) -> Tuple[torch.Tensor, torch.Tensor]: +) -> torch.Tensor: return addmm_float8_unwrapped( a, a_scale, b, b_scale, output_dtype, output_scale ) @@ -342,9 +342,9 @@ def to_fp8_saturated( x_scaled = x * x_scale if fp8_dtype == e4m3_type: - x = x.clamp(min=-1 * E4M3_MAX_POS, max=E4M3_MAX_POS) + x = x_scaled.clamp(min=-1 * E4M3_MAX_POS, max=E4M3_MAX_POS) elif fp8_dtype == e5m2_type: - x = x.clamp(min=-1 * E5M2_MAX_POS, max=E5M2_MAX_POS) + x = x_scaled.clamp(min=-1 * E5M2_MAX_POS, max=E5M2_MAX_POS) else: raise ValueError(f"to_fp8_saturated(): Unsupported fp8_dtype: {fp8_dtype}") @@ -364,11 +364,11 @@ def _test_tautological_mm(self, device: str = "cuda", x_fp8 = torch.rand(size, size, device=device).to(x_dtype) y_fp8 = torch.eye(size, device=device, dtype=y_dtype).t() out_fp32 = torch.mm(x_fp8.to(torch.float), y_fp8.to(torch.float)) - (out_fp8, amax_fp8) = torch._scaled_mm(x_fp8, y_fp8, out_dtype=out_dtype) + scale_a = torch.tensor(1.0, device=device) + scale_b = torch.tensor(1.0, device=device) + out_fp8 = torch._scaled_mm(x_fp8, y_fp8, scale_a, scale_b, out_dtype=out_dtype) if out_dtype is not None: self.assertEqual(out_dtype, out_fp8.dtype) - if out_dtype not in [torch.float16, torch.bfloat16, torch.float]: - self.assertEqual(out_fp32.amax(), amax_fp8) self.assertEqual(out_fp32, out_fp8.to(torch.float)) @unittest.skipIf(not scaled_mm_supported_device(), f8_msg) @@ -399,9 +399,9 @@ def test_float8_scale(self, device) -> None: y = torch.full(size, .5, device=device, dtype=y_type).t() scale_a = torch.tensor(1.5, device=device) scale_b = torch.tensor(0.66, device=device) - out_fp8, amax_fp8 = torch._scaled_mm(x, y) + out_fp8 = torch._scaled_mm(x, y, scale_a=scale_a, scale_b=scale_b) self.assertEqual(out_fp8.to(torch.float), torch.full(size, 4., device=device)) - out_fp8_s, amax_fp8_s = torch._scaled_mm(x, y, scale_a=scale_a, scale_b=scale_b) + out_fp8_s = torch._scaled_mm(x, y, scale_a=scale_a, scale_b=scale_b) self.assertEqual(out_fp8, out_fp8_s) @unittest.skipIf(not scaled_mm_supported_device(), f8_msg) @@ -418,11 +418,11 @@ def test_scaled_mm_vs_emulated(self, base_dtype): x_scale = tensor_to_scale(x, input_dtype).float() y_scale = tensor_to_scale(y, input_dtype).float() - x_fp8 = to_fp8_saturated(x, x_scale, e4m3_type) - y_fp8 = to_fp8_saturated(y, y_scale, e4m3_type) + x_fp8 = to_fp8_saturated(x, x_scale, input_dtype) + y_fp8 = to_fp8_saturated(y, y_scale, input_dtype) # Calculate actual F8 mm - out_scaled_mm, output_amax_scaled = mm_float8( + out_scaled_mm = mm_float8( x_fp8, y_fp8, a_scale=x_scale, @@ -431,7 +431,7 @@ def test_scaled_mm_vs_emulated(self, base_dtype): ) # Calculate emulated F8 mm - out_emulated, output_amax_emulated = mm_float8_emulated( + out_emulated = mm_float8_emulated( x_fp8, x_scale, y_fp8, @@ -441,14 +441,10 @@ def test_scaled_mm_vs_emulated(self, base_dtype): if output_dtype != base_dtype: out_scaled_mm = out_scaled_mm.to(compare_type) - out_emulated = out_emulated.to(compare_type) + out_scaled_mm = out_scaled_mm / tensor_to_scale(out_scaled_mm, input_dtype) - out_scaled_mm = out_scaled_mm / amax_to_scale( - output_amax_scaled, input_dtype - ) - out_emulated = out_emulated / amax_to_scale( - output_amax_emulated, input_dtype - ) + out_emulated = out_emulated.to(compare_type) + out_emulated = out_emulated / tensor_to_scale(out_emulated, input_dtype) if base_dtype in {torch.bfloat16, torch.float16}: atol, rtol = 7e-2, 7e-2 @@ -460,24 +456,30 @@ def test_scaled_mm_vs_emulated(self, base_dtype): @unittest.skipIf(not scaled_mm_supported_device(), f8_msg) def test_float8_bias(self, device) -> None: (k, l, m) = (16, 48, 32) - x = torch.rand((k, l), device=device).to(e4m3_type) + x = torch.ones((k, l), device=device).to(e4m3_type) y = torch.full((m, l), .25, device=device, dtype=e4m3_type).t() bias = torch.full((m,), 4.0, device=device, dtype=torch.half) - out_fp8, amax_fp8 = torch._scaled_mm(x, y) - outb_fp8, amaxb_fp8 = torch._scaled_mm(x, y, bias=bias) + scale_a = torch.tensor(1.0, device=device) + scale_b = torch.tensor(1.0, device=device) + out_fp8 = torch._scaled_mm(x, y, scale_a=scale_a, scale_b=scale_b) + outb_fp8 = torch._scaled_mm(x, y, scale_a=scale_a, scale_b=scale_b, bias=bias) # this fails on ROCm currently because hipblaslt doesn't have amax op - if torch.version.hip is None: - self.assertEqual((amaxb_fp8 - amax_fp8).item(), 4.0) + out_fp32 = out_fp8.to(torch.float32) + outb_fp32 = outb_fp8.to(torch.float32) + difference = torch.abs(out_fp32 - outb_fp32) + self.assertEqual(difference, torch.tensor(4.0, device=device).expand_as(out_fp32)) @unittest.skipIf(not scaled_mm_supported_device(), f8_msg) @parametrize("bias", [True, False]) - def test_non_divisible_leading_dim(self, device, bias: torch.bool) -> None: + def test_non_divisible_leading_dim(self, device, bias: bool) -> None: x = torch.rand((17, 16), device=device).to(e4m3_type) y = torch.rand((16, 16), device=device).to(e4m3_type).t() + scale_a = torch.tensor(1.0, device=device) + scale_b = torch.tensor(1.0, device=device) input_bias = None if bias: input_bias = torch.rand((16,), device=device).to(torch.half) - out_fp8, amax_fp8 = torch._scaled_mm(x, y, bias=input_bias) + _ = torch._scaled_mm(x, y, scale_a, scale_b, bias=input_bias) @unittest.skipIf(not scaled_mm_supported_device(), f8_msg) def test_float8_bias_relu_edgecase(self, device) -> None: @@ -485,19 +487,24 @@ def test_float8_bias_relu_edgecase(self, device) -> None: x = torch.full((k, l), 0.0, device=device).to(e4m3_type) y = torch.full((m, l), 1.0, device=device, dtype=e4m3_type).t() bias = torch.full((m,), -3.0, device=device, dtype=torch.half) - outb_fp8, amaxb_fp8 = torch._scaled_mm(x, y, bias=bias) - self.assertEqual(amaxb_fp8.item(), 3.0) + scale_a = torch.tensor(1.0, device=device) + scale_b = torch.tensor(1.0, device=device) + outb_fp8 = torch._scaled_mm(x, y, scale_a, scale_b, bias=bias) + outb_fp32 = outb_fp8.to(torch.float32) + self.assertEqual(outb_fp32, torch.tensor(-3.0, device=device).expand_as(outb_fp32)) @unittest.skipIf(not scaled_mm_supported_device(), f8_msg) def test_float32_output_errors_with_bias(self, device) -> None: (k, l, m) = (16, 48, 32) x = torch.rand((k, l), device=device).to(e4m3_type) y = torch.full((m, l), .25, device=device, dtype=e4m3_type).t() + scale_a = torch.tensor(1.0, device=device) + scale_b = torch.tensor(1.0, device=device) bias = torch.full((m,), 4.0, device=device, dtype=torch.bfloat16) self.assertRaisesRegex( RuntimeError, "Bias is not supported when out_dtype is set to Float32", - lambda: torch._scaled_mm(x, y, bias=bias, out_dtype=torch.float32), + lambda: torch._scaled_mm(x, y, scale_a, scale_b, bias=bias, out_dtype=torch.float32), ) @unittest.skipIf(scaled_mm_supported_device(), @@ -506,10 +513,12 @@ def test_error_message_fp8_pre_sm89(self, device) -> None: (k, l, m) = (16, 48, 32) x = torch.rand((k, l), device=device).to(e4m3_type) y = torch.rand((m, l), device=device).to(e4m3_type).t() + scale_a = torch.tensor(1.0, device=device) + scale_b = torch.tensor(1.0, device=device) self.assertRaisesRegex( RuntimeError, r"torch\.\_scaled\_mm is only supported on CUDA devices with compute capability \>\= 9\.0 or 8\.9, or ROCm MI300\+", - lambda: torch._scaled_mm(x, y, out_dtype=torch.float32), + lambda: torch._scaled_mm(x, y, scale_a, scale_b, out_dtype=torch.float32), ) @unittest.skipIf(not scaled_mm_supported_device(), f8_msg) @@ -521,9 +530,9 @@ def test_float8_scale_fast_accum(self, device) -> None: y = torch.full(size, .5, device=device, dtype=y_type).t() scale_a = torch.tensor(1.5, device=device) scale_b = torch.tensor(0.66, device=device) - out_fp8, amax_fp8 = torch._scaled_mm(x, y, use_fast_accum=True) + out_fp8 = torch._scaled_mm(x, y, scale_a, scale_b, use_fast_accum=True) self.assertEqual(out_fp8.to(torch.float), torch.full(size, 4., device=device)) - out_fp8_s, amax_fp8_s = torch._scaled_mm(x, y, scale_a=scale_a, scale_b=scale_b, use_fast_accum=True) + out_fp8_s = torch._scaled_mm(x, y, scale_a=scale_a, scale_b=scale_b, use_fast_accum=True) self.assertEqual(out_fp8, out_fp8_s) diff --git a/torch/_meta_registrations.py b/torch/_meta_registrations.py index 20b71d7f1e062d..e3ef4fc48bef05 100644 --- a/torch/_meta_registrations.py +++ b/torch/_meta_registrations.py @@ -5326,11 +5326,11 @@ def meta__efficient_attention_backward( def meta_scaled_mm( self: torch.Tensor, mat2: torch.Tensor, + scale_a: torch.Tensor, + scale_b: torch.Tensor, bias: Optional[torch.Tensor] = None, - out_dtype: Optional[torch.dtype] = None, - scale_a: Optional[torch.Tensor] = None, - scale_b: Optional[torch.Tensor] = None, scale_result: Optional[torch.Tensor] = None, + out_dtype: Optional[torch.dtype] = None, use_fast_accum: bool = False, ): def is_row_major(stride): @@ -5372,9 +5372,7 @@ def is_fp8_type(dtype): lambda: f"Expected both inputs to be fp8 types but got self.dtype={self.dtype} and mat2.dtype={mat2.dtype}", ) _out_dtype = out_dtype if out_dtype is not None else self.dtype - return torch.empty( - self.size(0), mat2.size(1), dtype=_out_dtype, device=self.device - ), torch.empty((), dtype=torch.float32, device=self.device) + return torch.empty(self.size(0), mat2.size(1), dtype=_out_dtype, device=self.device) @register_meta([aten.scatter_reduce.two, aten.scatter_reduce.two_out]) diff --git a/torch/csrc/inductor/aoti_torch/generated/c_shim_cuda.h b/torch/csrc/inductor/aoti_torch/generated/c_shim_cuda.h index c973f69cb69dd9..1eba22f85c9795 100644 --- a/torch/csrc/inductor/aoti_torch/generated/c_shim_cuda.h +++ b/torch/csrc/inductor/aoti_torch/generated/c_shim_cuda.h @@ -38,7 +38,7 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__scaled_dot_product_efficient_a AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__scaled_dot_product_efficient_attention_backward(AtenTensorHandle grad_out_, AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle attn_bias, AtenTensorHandle out, AtenTensorHandle logsumexp, AtenTensorHandle philox_seed, AtenTensorHandle philox_offset, double dropout_p, const int32_t* grad_input_mask, int64_t grad_input_mask_len_, int32_t is_causal, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__scaled_dot_product_flash_attention(AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, double dropout_p, int32_t is_causal, int32_t return_debug_mask, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3, int64_t* ret4, int64_t* ret5, AtenTensorHandle* ret6, AtenTensorHandle* ret7, AtenTensorHandle* ret8); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__scaled_dot_product_flash_attention_backward(AtenTensorHandle grad_out, AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle out, AtenTensorHandle logsumexp, AtenTensorHandle cum_seq_q, AtenTensorHandle cum_seq_k, int64_t max_q, int64_t max_k, double dropout_p, int32_t is_causal, AtenTensorHandle philox_seed, AtenTensorHandle philox_offset, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2); -AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__scaled_mm(AtenTensorHandle self, AtenTensorHandle mat2, AtenTensorHandle* bias, int32_t* out_dtype, AtenTensorHandle* scale_a, AtenTensorHandle* scale_b, AtenTensorHandle* scale_result, int32_t use_fast_accum, AtenTensorHandle* ret0, AtenTensorHandle* ret1); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__scaled_mm(AtenTensorHandle self, AtenTensorHandle mat2, AtenTensorHandle scale_a, AtenTensorHandle scale_b, AtenTensorHandle* bias, AtenTensorHandle* scale_result, int32_t* out_dtype, int32_t use_fast_accum, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__segment_reduce_backward(AtenTensorHandle grad, AtenTensorHandle output, AtenTensorHandle data, const char* reduce, AtenTensorHandle* lengths, AtenTensorHandle* offsets, int64_t axis, double* initial, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__thnn_fused_lstm_cell(AtenTensorHandle input_gates, AtenTensorHandle hidden_gates, AtenTensorHandle cx, AtenTensorHandle* input_bias, AtenTensorHandle* hidden_bias, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__to_sparse(AtenTensorHandle self, int32_t* layout, const int64_t** blocksize, int64_t blocksize_len_, int64_t* dense_dim, AtenTensorHandle* ret0); diff --git a/torch/csrc/inductor/aoti_torch/shim_common.cpp b/torch/csrc/inductor/aoti_torch/shim_common.cpp index 1306c006ba94ee..9da7fa6e3b6218 100644 --- a/torch/csrc/inductor/aoti_torch/shim_common.cpp +++ b/torch/csrc/inductor/aoti_torch/shim_common.cpp @@ -621,17 +621,16 @@ AOTITorchError aoti_torch__scaled_mm( at::Tensor* scale_b_tensor = tensor_handle_to_tensor_pointer(scale_b); at::Tensor* scale_result_tensor = tensor_handle_to_tensor_pointer(scale_result); - auto [r0, r1] = at::_scaled_mm( + auto r0 = at::_scaled_mm( *self_tensor, *mat2_tensor, + *scale_a_tensor, + *scale_b_tensor, pointer_to_optional(bias_tensor), - pointer_to_optional(out_dtype), - pointer_to_optional(scale_a_tensor), - pointer_to_optional(scale_b_tensor), pointer_to_optional(scale_result_tensor), + pointer_to_optional(out_dtype), use_fast_accum); *ret0 = new_tensor_handle(std::move(r0)); - *ret1 = new_tensor_handle(std::move(r1)); }); } From c6b180a3166220ca7e505b891c79a67f53c23dce Mon Sep 17 00:00:00 2001 From: ibartol Date: Mon, 17 Jun 2024 16:50:37 +0000 Subject: [PATCH 088/171] Created docs (and example) for cudart function in torch.cuda (#128741) Fixes #127908 ## Description Created docs to document the torch.cuda.cudart function to solve the issue #127908. I tried to stick to the [guidelines to document a function](https://github.com/pytorch/pytorch/wiki/Docstring-Guidelines#documenting-a-function) but I was not sure if there is a consensus on how to handle the docs of a function that calls an internal function. So I went ahead and tried what the function will raise, etc. from the user endpoint and documented it (i.e. I am giving what actually _lazy_init() will raise). Updated PR from #128298 since I made quite a big mistake in my branch. I apologize for the newbie mistake. ### Summary of Changes - Added docs for torch.cuda.cudart - Added the cudart function in the autosummary of docs/source/cuda.rst ## Checklist - [X] The issue that is being fixed is referred in the description - [X] Only one issue is addressed in this pull request - [X] Labels from the issue that this PR is fixing are added to this pull request - [X] No unnecesary issues are included into this pull request Pull Request resolved: https://github.com/pytorch/pytorch/pull/128741 Approved by: https://github.com/msaroufim --- docs/source/cuda.rst | 1 + torch/cuda/__init__.py | 53 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 54 insertions(+) diff --git a/docs/source/cuda.rst b/docs/source/cuda.rst index 7b9bf536c1453d..7f6f2d2f148b67 100644 --- a/docs/source/cuda.rst +++ b/docs/source/cuda.rst @@ -12,6 +12,7 @@ torch.cuda current_blas_handle current_device current_stream + cudart default_stream device device_count diff --git a/torch/cuda/__init__.py b/torch/cuda/__init__.py index 6722114e295b88..e08572f5a09973 100644 --- a/torch/cuda/__init__.py +++ b/torch/cuda/__init__.py @@ -334,6 +334,59 @@ def _lazy_init(): def cudart(): + r"""Retrieves the CUDA runtime API module. + + + This function initializes the CUDA runtime environment if it is not already + initialized and returns the CUDA runtime API module (_cudart). The CUDA + runtime API module provides access to various CUDA runtime functions. + + Args: + ``None`` + + Returns: + module: The CUDA runtime API module (_cudart). + + Raises: + RuntimeError: If CUDA cannot be re-initialized in a forked subprocess. + AssertionError: If PyTorch is not compiled with CUDA support or if libcudart functions are unavailable. + + Example of CUDA operations with profiling: + >>> import torch + >>> from torch.cuda import cudart, check_error + >>> import os + >>> + >>> os.environ['CUDA_PROFILE'] = '1' + >>> + >>> def perform_cuda_operations_with_streams(): + >>> stream = torch.cuda.Stream() + >>> with torch.cuda.stream(stream): + >>> x = torch.randn(100, 100, device='cuda') + >>> y = torch.randn(100, 100, device='cuda') + >>> z = torch.mul(x, y) + >>> return z + >>> + >>> torch.cuda.synchronize() + >>> print("====== Start nsys profiling ======") + >>> check_error(cudart().cudaProfilerStart()) + >>> with torch.autograd.profiler.emit_nvtx(): + >>> result = perform_cuda_operations_with_streams() + >>> print("CUDA operations completed.") + >>> check_error(torch.cuda.cudart().cudaProfilerStop()) + >>> print("====== End nsys profiling ======") + + To run this example and save the profiling information, execute: + >>> $ nvprof --profile-from-start off --csv --print-summary -o trace_name.prof -f -- python cudart_test.py + + This command profiles the CUDA operations in the provided script and saves + the profiling information to a file named `trace_name.prof`. + The `--profile-from-start off` option ensures that profiling starts only + after the `cudaProfilerStart` call in the script. + The `--csv` and `--print-summary` options format the profiling output as a + CSV file and print a summary, respectively. + The `-o` option specifies the output file name, and the `-f` option forces the + overwrite of the output file if it already exists. + """ _lazy_init() return _cudart From 153362fbc9e8642fb851a4de3b99e3871a2cc714 Mon Sep 17 00:00:00 2001 From: Mihir Patel Date: Mon, 17 Jun 2024 16:59:41 +0000 Subject: [PATCH 089/171] Support HSDP + Monolith Checkpointing (#128446) Fixes #128444. Rank 0 check should be in the same group as the broadcast Pull Request resolved: https://github.com/pytorch/pytorch/pull/128446 Approved by: https://github.com/fegin --- torch/distributed/fsdp/_optim_utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torch/distributed/fsdp/_optim_utils.py b/torch/distributed/fsdp/_optim_utils.py index d4aa344c111419..54f800a168653a 100644 --- a/torch/distributed/fsdp/_optim_utils.py +++ b/torch/distributed/fsdp/_optim_utils.py @@ -341,14 +341,14 @@ def _broadcast_processed_state( group: Optional[dist.ProcessGroup], ) -> Dict[str, Any]: objects: List[Any] = [None] - if fsdp_state.rank == 0: + if dist.get_rank(group) == 0: objects[0] = tree_map_only( torch.Tensor, lambda v: v.cpu() if v.dim() == 0 else _PosDimTensorInfo(v.shape, v.dtype), # type: ignore[union-attr] optim_state, ) dist.broadcast_object_list(objects, src=0, group=group) - if fsdp_state.rank == 0: + if dist.get_rank(group) == 0: return optim_state else: return objects[0] @@ -357,7 +357,7 @@ def _broadcast_processed_state( def _broadcast_state( fsdp_state: _FSDPState, state: Any, group: Optional[dist.ProcessGroup] ) -> Any: - if fsdp_state.rank == 0: + if dist.get_rank(group) == 0: if not isinstance(state, torch.Tensor) or state.dim() == 0: return state tensor = state.to(fsdp_state.compute_device) From d35cdee97f2838329567eab74dc717baa254205f Mon Sep 17 00:00:00 2001 From: cyy Date: Mon, 17 Jun 2024 18:17:58 +0000 Subject: [PATCH 090/171] [Caffe2] Remove caffe2 onnx tests (#128687) They are not used. Pull Request resolved: https://github.com/pytorch/pytorch/pull/128687 Approved by: https://github.com/r-barnes --- test/onnx/debug_embed_params.py | 54 - test/onnx/pytorch_helper.py | 91 - test/onnx_caffe2/export_onnx_tests_filter.py | 102 - .../export_onnx_tests_generator.py | 158 - test/onnx_caffe2/test_caffe2_common.py | 46 - test/onnx_caffe2/test_custom_ops.py | 60 - test/onnx_caffe2/test_pytorch_helper.py | 73 - test/onnx_caffe2/test_pytorch_onnx_caffe2.py | 3156 ----------------- .../test_pytorch_onnx_caffe2_quantized.py | 382 -- test/onnx_caffe2/test_verify.py | 106 - 10 files changed, 4228 deletions(-) delete mode 100644 test/onnx/debug_embed_params.py delete mode 100644 test/onnx/pytorch_helper.py delete mode 100644 test/onnx_caffe2/export_onnx_tests_filter.py delete mode 100644 test/onnx_caffe2/export_onnx_tests_generator.py delete mode 100644 test/onnx_caffe2/test_caffe2_common.py delete mode 100644 test/onnx_caffe2/test_custom_ops.py delete mode 100644 test/onnx_caffe2/test_pytorch_helper.py delete mode 100644 test/onnx_caffe2/test_pytorch_onnx_caffe2.py delete mode 100644 test/onnx_caffe2/test_pytorch_onnx_caffe2_quantized.py delete mode 100644 test/onnx_caffe2/test_verify.py diff --git a/test/onnx/debug_embed_params.py b/test/onnx/debug_embed_params.py deleted file mode 100644 index 8f32a838a99842..00000000000000 --- a/test/onnx/debug_embed_params.py +++ /dev/null @@ -1,54 +0,0 @@ -import sys - -import onnx -import pytorch_test_common - -import caffe2.python.onnx.backend as c2 -import torch -import torch.jit -from torch.autograd import Variable - -torch.set_default_tensor_type("torch.FloatTensor") -try: - import torch -except ImportError: - print("Cannot import torch, hence caffe2-torch test will not run.") - sys.exit(0) - - -def run_embed_params(proto, model, input, state_dict=None, use_gpu=True): - """ - This is only a helper debug function so we can test embed_params=False - case as well on pytorch front - This should likely be removed from the release version of the code - """ - device = "CPU" - if use_gpu: - device = "CUDA" - model_def = onnx.ModelProto.FromString(proto) - onnx.checker.check_model(model_def) - prepared = c2.prepare(model_def, device=device) - - if state_dict: - parameters = [] - # Passed in state_dict may have a different order. Make - # sure our order is consistent with the model's order. - # TODO: Even better: keyword arguments! - for k in model.state_dict(): - if k in state_dict: - parameters.append(state_dict[k]) - else: - parameters = list(model.state_dict().values()) - - W = {} - for k, v in zip( - model_def.graph.input, pytorch_test_common.flatten((input, parameters)) - ): - if isinstance(v, Variable): - W[k.name] = v.data.cpu().numpy() - else: - W[k.name] = v.cpu().numpy() - - caffe2_out = prepared.run(inputs=W) - - return caffe2_out diff --git a/test/onnx/pytorch_helper.py b/test/onnx/pytorch_helper.py deleted file mode 100644 index ff1c9faadeaa4d..00000000000000 --- a/test/onnx/pytorch_helper.py +++ /dev/null @@ -1,91 +0,0 @@ -import io - -import onnx - -import torch.onnx - -from caffe2.python.core import BlobReference, Net -from caffe2.python.onnx.backend import Caffe2Backend - -_next_idx = 0 -# Clone net takes a dict instead of a lambda -# It should probably take a lambda, it is more flexible -# We fake dict here - - -class _FakeDict: - def __init__(self, fn): - self.fn = fn - - def get(self, name, _): - return self.fn(name) - - -def PyTorchModule(helper, model, sample_arguments, caffe2_inputs, prefix_name=None): - """ - Embed an ONNX-exportable PyTorch Model into a Caffe2 model being built. - - Args: - helper (caffe2.python.core.ModelHelder): the model helper where - this imported network should be inserted - model (torch.nn.Module): the model to be exported - sample_arguments (tuple of arguments): the inputs to - the model, e.g., such that ``model(*args)`` is a valid - invocation of the model. Any non-Variable arguments will - be hard-coded into the exported model; any Variable arguments - will become inputs of the exported model, in the order they - occur in args. If args is a Variable, this is equivalent - to having called it with a 1-ary tuple of that Variable. - (Note: passing keyword arguments to the model is not currently - supported. Give us a shout if you need it.) - caffe2_inputs (list of str or caffe2.python.core.BlobReference): the - caffe2 Blobs that should be inputs to this network. Must be - the same length as sample_arguments - prefix_name: prefix name to add to each member of the blob, if None then - a fresh prefix pytorch_input_N/ is used - Returns: - A tuple of caffe2.python.core.BlobReference objects referring to the - models outputs, or a single BlobReference when the model returns a single - value. - """ - if prefix_name is None: - global _next_idx - prefix_name = "pytorch_import_" + str(_next_idx) + "/" - _next_idx += 1 - - # TODO: handle the case where model cannot be exported - # and embed as a Python op in Caffe2 - f = io.BytesIO() - torch.onnx.export(model, sample_arguments, f, export_params=True) - onnx_model = onnx.load(io.BytesIO(f.getvalue())) - init_net, predict_net = Caffe2Backend.onnx_graph_to_caffe2_net(onnx_model) - - initialized = {x.name for x in onnx_model.graph.initializer} - uninitialized_inputs = { - x.name: i - for i, x in enumerate(onnx_model.graph.input) - if x.name not in initialized - } - - if len(uninitialized_inputs) != len(caffe2_inputs): - raise ValueError( - f"Expected {len(uninitialized_inputs)} inputs but found {len(caffe2_inputs)}" - ) - - def remap_blob_name(name): - if name in uninitialized_inputs: - idx = uninitialized_inputs[name] - return str(caffe2_inputs[idx]) - return prefix_name + name - - predict_net = Net(predict_net).Clone("anon", _FakeDict(remap_blob_name)) - helper.net.AppendNet(predict_net) - - init_net = Net(init_net).Clone("anon", _FakeDict(remap_blob_name)) - helper.param_init_net.AppendNet(init_net) - - results = tuple( - BlobReference(remap_blob_name(x.name), helper.net) - for x in onnx_model.graph.output - ) - return results diff --git a/test/onnx_caffe2/export_onnx_tests_filter.py b/test/onnx_caffe2/export_onnx_tests_filter.py deleted file mode 100644 index 868f72fddc342d..00000000000000 --- a/test/onnx_caffe2/export_onnx_tests_filter.py +++ /dev/null @@ -1,102 +0,0 @@ -import argparse -import glob -import os -import shutil -import traceback - -import google.protobuf.text_format -import onnx.backend.test -import onnx_test_common -from test_caffe2_common import run_generated_test - -from torch.testing._internal.common_device_type import get_all_device_types - -_fail_test_dir = os.path.join( - os.path.dirname(os.path.realpath(__file__)), "fail", "generated" -) - - -_expect_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "expect") - - -def collect_generated_testcases( - root_dir=onnx_test_common.pytorch_converted_dir, - verbose=False, - fail_dir=None, - expect=True, -): - total_pass = 0 - total_fail = 0 - for d in os.listdir(root_dir): - dir_name = os.path.join(root_dir, d) - if os.path.isdir(dir_name): - failed = False - try: - model_file = os.path.join(dir_name, "model.onnx") - data_dir_pattern = os.path.join(dir_name, "test_data_set_*") - for data_dir in glob.glob(data_dir_pattern): - for device in get_all_device_types(): - run_generated_test(model_file, data_dir, device) - if expect: - expect_file = os.path.join( - _expect_dir, f"PyTorch-generated-{d}.expect" - ) - with open(expect_file, "w") as text_file: - model = onnx.load(model_file) - onnx.checker.check_model(model) - onnx.helper.strip_doc_string(model) - text_file.write( - google.protobuf.text_format.MessageToString(model) - ) - total_pass += 1 - except Exception as e: - if verbose: - print(f"The test case in {dir_name} failed!") - traceback.print_exc() - if fail_dir is None: - shutil.rmtree(dir_name) - else: - target_dir = os.path.join(fail_dir, d) - if os.path.exists(target_dir): - shutil.rmtree(target_dir) - shutil.move(dir_name, target_dir) - total_fail += 1 - print(f"Successfully generated/updated {total_pass} test cases from PyTorch.") - if expect: - print(f"Expected pbtxt files are generated in {_expect_dir}.") - print(f"Failed {total_fail} testcases are moved to {_fail_test_dir}.") - - -if __name__ == "__main__": - parser = argparse.ArgumentParser( - description="Check and filter the failed test cases." - ) - parser.add_argument("-v", action="store_true", default=False, help="verbose") - parser.add_argument( - "--delete", action="store_true", default=False, help="delete failed test cases" - ) - parser.add_argument( - "--no-expect", - action="store_true", - default=False, - help="generate expect txt files", - ) - args = parser.parse_args() - verbose = args.v - delete = args.delete - expect = not args.no_expect - fail_dir = _fail_test_dir - if delete: - fail_dir = None - if fail_dir: - if not os.path.exists(fail_dir): - os.makedirs(fail_dir) - - collect_generated_testcases(verbose=verbose, fail_dir=fail_dir, expect=expect) - # We already generate the expect files for test_operators.py. - collect_generated_testcases( - root_dir=onnx_test_common.pytorch_operator_dir, - verbose=verbose, - fail_dir=fail_dir, - expect=False, - ) diff --git a/test/onnx_caffe2/export_onnx_tests_generator.py b/test/onnx_caffe2/export_onnx_tests_generator.py deleted file mode 100644 index 43ea9a22a60b81..00000000000000 --- a/test/onnx_caffe2/export_onnx_tests_generator.py +++ /dev/null @@ -1,158 +0,0 @@ -import io -import os -import shutil -import traceback - -import onnx -import onnx_test_common -from onnx import numpy_helper -from test_nn import new_module_tests - -import torch -from torch.autograd import Variable -from torch.testing._internal.common_nn import module_tests - - -# Take a test case (a dict) as input, return the test name. -def get_test_name(testcase): - if "fullname" in testcase: - return "test_" + testcase["fullname"] - - test_name = "test_" + testcase["constructor"].__name__ - if "desc" in testcase: - test_name += "_" + testcase["desc"] - return test_name - - -# Take a test case (a dict) as input, return the input for the module. -def gen_input(testcase): - if "input_size" in testcase: - if ( - testcase["input_size"] == () - and "desc" in testcase - and testcase["desc"][-6:] == "scalar" - ): - testcase["input_size"] = (1,) - return Variable(torch.randn(*testcase["input_size"])) - elif "input_fn" in testcase: - input = testcase["input_fn"]() - if isinstance(input, Variable): - return input - return Variable(testcase["input_fn"]()) - - -def gen_module(testcase): - if "constructor_args" in testcase: - args = testcase["constructor_args"] - module = testcase["constructor"](*args) - module.train(False) - return module - module = testcase["constructor"]() - module.train(False) - return module - - -def print_stats(FunctionalModule_nums, nn_module): - print(f"{FunctionalModule_nums} functional modules detected.") - supported = [] - unsupported = [] - not_fully_supported = [] - for key, value in nn_module.items(): - if value == 1: - supported.append(key) - elif value == 2: - unsupported.append(key) - elif value == 3: - not_fully_supported.append(key) - - def fun(info, l): - print(info) - for v in l: - print(v) - - # Fully Supported Ops: All related test cases of these ops have been exported - # Semi-Supported Ops: Part of related test cases of these ops have been exported - # Unsupported Ops: None of related test cases of these ops have been exported - for info, l in [ - [f"{len(supported)} Fully Supported Operators:", supported], - [ - f"{len(not_fully_supported)} Semi-Supported Operators:", - not_fully_supported, - ], - [f"{len(unsupported)} Unsupported Operators:", unsupported], - ]: - fun(info, l) - - -def convert_tests(testcases, sets=1): - print(f"Collect {len(testcases)} test cases from PyTorch.") - failed = 0 - FunctionalModule_nums = 0 - nn_module = {} - for t in testcases: - test_name = get_test_name(t) - module = gen_module(t) - module_name = str(module).split("(")[0] - if module_name == "FunctionalModule": - FunctionalModule_nums += 1 - else: - if module_name not in nn_module: - nn_module[module_name] = 0 - try: - input = gen_input(t) - f = io.BytesIO() - torch.onnx._export( - module, - input, - f, - operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK, - ) - onnx_model = onnx.load_from_string(f.getvalue()) - onnx.checker.check_model(onnx_model) - onnx.helper.strip_doc_string(onnx_model) - output_dir = os.path.join(onnx_test_common.pytorch_converted_dir, test_name) - - if os.path.exists(output_dir): - shutil.rmtree(output_dir) - os.makedirs(output_dir) - with open(os.path.join(output_dir, "model.onnx"), "wb") as file: - file.write(onnx_model.SerializeToString()) - - for i in range(sets): - output = module(input) - data_dir = os.path.join(output_dir, f"test_data_set_{i}") - os.makedirs(data_dir) - - for index, var in enumerate([input]): - tensor = numpy_helper.from_array(var.data.numpy()) - with open( - os.path.join(data_dir, f"input_{index}.pb"), "wb" - ) as file: - file.write(tensor.SerializeToString()) - for index, var in enumerate([output]): - tensor = numpy_helper.from_array(var.data.numpy()) - with open( - os.path.join(data_dir, f"output_{index}.pb"), "wb" - ) as file: - file.write(tensor.SerializeToString()) - input = gen_input(t) - if module_name != "FunctionalModule": - nn_module[module_name] |= 1 - except: # noqa: E722,B001 - traceback.print_exc() - if module_name != "FunctionalModule": - nn_module[module_name] |= 2 - failed += 1 - - print( - f"Collect {len(testcases)} test cases from PyTorch repo, failed to export {failed} cases." - ) - print( - f"PyTorch converted cases are stored in {onnx_test_common.pytorch_converted_dir}." - ) - print_stats(FunctionalModule_nums, nn_module) - - -if __name__ == "__main__": - testcases = module_tests + new_module_tests - convert_tests(testcases) diff --git a/test/onnx_caffe2/test_caffe2_common.py b/test/onnx_caffe2/test_caffe2_common.py deleted file mode 100644 index e85f4b8aef3657..00000000000000 --- a/test/onnx_caffe2/test_caffe2_common.py +++ /dev/null @@ -1,46 +0,0 @@ -# Owner(s): ["module: onnx"] - -import glob -import os - -import numpy as np -import onnx.backend.test -from onnx import numpy_helper - -import caffe2.python.onnx.backend as c2 - - -def load_tensor_as_numpy_array(f): - tensor = onnx.TensorProto() - with open(f, "rb") as file: - tensor.ParseFromString(file.read()) - return tensor - - -def assert_similar(ref, real): - np.testing.assert_equal(len(ref), len(real)) - for i in range(len(ref)): - np.testing.assert_allclose(ref[i], real[i], rtol=1e-3) - - -def run_generated_test(model_file, data_dir, device="CPU"): - model = onnx.load(model_file) - input_num = len(glob.glob(os.path.join(data_dir, "input_*.pb"))) - inputs = [] - for i in range(input_num): - inputs.append( - numpy_helper.to_array( - load_tensor_as_numpy_array(os.path.join(data_dir, f"input_{i}.pb")) - ) - ) - output_num = len(glob.glob(os.path.join(data_dir, "output_*.pb"))) - outputs = [] - for i in range(output_num): - outputs.append( - numpy_helper.to_array( - load_tensor_as_numpy_array(os.path.join(data_dir, f"output_{i}.pb")) - ) - ) - prepared = c2.prepare(model, device=device) - c2_outputs = prepared.run(inputs) - assert_similar(outputs, c2_outputs) diff --git a/test/onnx_caffe2/test_custom_ops.py b/test/onnx_caffe2/test_custom_ops.py deleted file mode 100644 index f25ae1b43a8478..00000000000000 --- a/test/onnx_caffe2/test_custom_ops.py +++ /dev/null @@ -1,60 +0,0 @@ -# Owner(s): ["module: onnx"] - -import numpy as np -import onnx -import pytorch_test_common -from test_pytorch_onnx_caffe2 import do_export - -import caffe2.python.onnx.backend as c2 -import torch -import torch.utils.cpp_extension -from torch.testing._internal import common_utils - - -class TestCaffe2CustomOps(pytorch_test_common.ExportTestCase): - def test_custom_add(self): - op_source = """ - #include - - torch::Tensor custom_add(torch::Tensor self, torch::Tensor other) { - return self + other; - } - - static auto registry = - torch::RegisterOperators("custom_namespace::custom_add", &custom_add); - """ - - torch.utils.cpp_extension.load_inline( - name="custom_add", - cpp_sources=op_source, - is_python_module=False, - verbose=True, - ) - - class CustomAddModel(torch.nn.Module): - def forward(self, a, b): - return torch.ops.custom_namespace.custom_add(a, b) - - def symbolic_custom_add(g, self, other): - return g.op("Add", self, other) - - torch.onnx.register_custom_op_symbolic( - "custom_namespace::custom_add", symbolic_custom_add, 9 - ) - - x = torch.randn(2, 3, 4, requires_grad=False) - y = torch.randn(2, 3, 4, requires_grad=False) - - model = CustomAddModel() - # before fixing #51833 this used to give a PyBind error - # with PyTorch 1.10dev ("Unable to cast from non-held to held - # instance (T& to Holder)") - onnxir, _ = do_export(model, (x, y), opset_version=11) - onnx_model = onnx.ModelProto.FromString(onnxir) - prepared = c2.prepare(onnx_model) - caffe2_out = prepared.run(inputs=[x.cpu().numpy(), y.cpu().numpy()]) - np.testing.assert_array_equal(caffe2_out[0], model(x, y).cpu().numpy()) - - -if __name__ == "__main__": - common_utils.run_tests() diff --git a/test/onnx_caffe2/test_pytorch_helper.py b/test/onnx_caffe2/test_pytorch_helper.py deleted file mode 100644 index 56a4932a999629..00000000000000 --- a/test/onnx_caffe2/test_pytorch_helper.py +++ /dev/null @@ -1,73 +0,0 @@ -# Owner(s): ["module: onnx"] - -# Some standard imports -import unittest - -import numpy as np -import pytorch_test_common -from pytorch_helper import PyTorchModule - -import torch.nn.init as init -import torch.onnx - -from caffe2.python.core import workspace -from caffe2.python.model_helper import ModelHelper -from torch import nn -from torch.testing._internal import common_utils -from torch.testing._internal.common_utils import skipIfNoLapack - - -class TestCaffe2Backend(pytorch_test_common.ExportTestCase): - @skipIfNoLapack - @unittest.skip("test broken because Lapack was always missing.") - def test_helper(self): - class SuperResolutionNet(nn.Module): - def __init__(self, upscale_factor, inplace=False): - super().__init__() - - self.relu = nn.ReLU(inplace=inplace) - self.conv1 = nn.Conv2d(1, 64, (5, 5), (1, 1), (2, 2)) - self.conv2 = nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1)) - self.conv3 = nn.Conv2d(64, 32, (3, 3), (1, 1), (1, 1)) - self.conv4 = nn.Conv2d(32, upscale_factor**2, (3, 3), (1, 1), (1, 1)) - self.pixel_shuffle = nn.PixelShuffle(upscale_factor) - - self._initialize_weights() - - def forward(self, x): - x = self.relu(self.conv1(x)) - x = self.relu(self.conv2(x)) - x = self.relu(self.conv3(x)) - x = self.pixel_shuffle(self.conv4(x)) - return x - - def _initialize_weights(self): - init.orthogonal(self.conv1.weight, init.calculate_gain("relu")) - init.orthogonal(self.conv2.weight, init.calculate_gain("relu")) - init.orthogonal(self.conv3.weight, init.calculate_gain("relu")) - init.orthogonal(self.conv4.weight) - - torch_model = SuperResolutionNet(upscale_factor=3) - - fake_input = torch.randn(1, 1, 224, 224, requires_grad=True) - - # use ModelHelper to create a C2 net - helper = ModelHelper(name="test_model") - start = helper.Sigmoid(["the_input"]) - # Embed the ONNX-converted pytorch net inside it - (toutput,) = PyTorchModule(helper, torch_model, (fake_input,), [start]) - output = helper.Sigmoid(toutput) - - workspace.RunNetOnce(helper.InitProto()) - workspace.FeedBlob("the_input", fake_input.data.numpy()) - # print([ k for k in workspace.blobs ]) - workspace.RunNetOnce(helper.Proto()) - c2_out = workspace.FetchBlob(str(output)) - - torch_out = torch.sigmoid(torch_model(torch.sigmoid(fake_input))) - - np.testing.assert_almost_equal(torch_out.data.cpu().numpy(), c2_out, decimal=3) - - -if __name__ == "__main__": - common_utils.run_tests() diff --git a/test/onnx_caffe2/test_pytorch_onnx_caffe2.py b/test/onnx_caffe2/test_pytorch_onnx_caffe2.py deleted file mode 100644 index 18cba5b73f7664..00000000000000 --- a/test/onnx_caffe2/test_pytorch_onnx_caffe2.py +++ /dev/null @@ -1,3156 +0,0 @@ -# Owner(s): ["module: onnx"] - -import io -import itertools -import sys -import unittest -from typing import Tuple - -import model_defs.dcgan as dcgan -import model_defs.word_language_model as word_language_model -import numpy as np -import onnx -import pytorch_test_common -import verify -from debug_embed_params import run_embed_params -from model_defs.lstm_flattening_result import LstmFlatteningResult -from model_defs.mnist import MNIST -from model_defs.rnn_model_with_packed_sequence import RnnModelWithPackedSequence -from model_defs.squeezenet import SqueezeNet -from model_defs.srresnet import SRResNet -from model_defs.super_resolution import SuperResolutionNet -from pytorch_test_common import ( - BATCH_SIZE, - RNN_BATCH_SIZE, - RNN_HIDDEN_SIZE, - RNN_INPUT_SIZE, - RNN_SEQUENCE_LENGTH, - skipIfNoCuda, - skipIfTravis, - skipIfUnsupportedMinOpsetVersion, - skipIfUnsupportedOpsetVersion, -) - -# Import various models for testing -from torchvision.models.alexnet import alexnet -from torchvision.models.densenet import densenet121 -from torchvision.models.inception import inception_v3 -from torchvision.models.resnet import resnet50 -from torchvision.models.vgg import vgg16, vgg16_bn, vgg19, vgg19_bn - -import caffe2.python.onnx.backend as c2 -import torch.onnx -import torch.onnx.operators -import torch.utils.model_zoo as model_zoo -from caffe2.python.operator_test.torch_integration_test import ( - create_bbox_transform_inputs, - generate_rois_rotated, -) -from torch import nn -from torch.autograd import function, Variable -from torch.nn.utils import rnn as rnn_utils -from torch.onnx import ExportTypes -from torch.testing._internal import common_utils -from torch.testing._internal.common_utils import skipIfNoLapack - -skip = unittest.skip - - -def skipIfEmbed(func): - def wrapper(self): - if self.embed_params: - raise unittest.SkipTest("Skip embed_params verify test") - return func(self) - - return wrapper - - -def skipIfNoEmbed(func): - def wrapper(self): - if not self.embed_params: - raise unittest.SkipTest("Skip debug embed_params test") - return func(self) - - return wrapper - - -# def import_model(proto, input, workspace=None, use_gpu=True): -# model_def = onnx.ModelProto.FromString(proto) -# onnx.checker.check_model(model_def) -# -# if workspace is None: -# workspace = {} -# if isinstance(input, tuple): -# for i in range(len(input)): -# workspace[model_def.graph.input[i]] = input[i] -# else: -# workspace[model_def.graph.input[0]] = input -# -# caffe2_out_workspace = c2.run_model( -# init_graph=None, -# predict_graph=graph_def, -# inputs=workspace, -# use_gpu=use_gpu) -# caffe2_out = caffe2_out_workspace[0] -# return caffe2_out - - -def do_export(model, inputs, *args, **kwargs): - f = io.BytesIO() - out = torch.onnx._export(model, inputs, f, *args, **kwargs) - if isinstance(model, torch.jit.ScriptModule): - # Special case for common case of passing a single Tensor - if isinstance(inputs, torch.Tensor): - inputs = (inputs,) - out = model(*inputs) - return f.getvalue(), out - - -torch.set_default_tensor_type("torch.FloatTensor") -try: - import torch -except ImportError: - print("Cannot import torch, hence caffe2-torch test will not run.") - sys.exit(0) - - -model_urls = { - "alexnet": "https://s3.amazonaws.com/download.caffe2.ai/test_data/alexnet-owt-4df8aa71.pth", - "dcgan_b": "https://s3.amazonaws.com/pytorch/test_data/export/netG_bedroom_epoch_1-0649e76b.pth", - "dcgan_f": "https://s3.amazonaws.com/pytorch/test_data/export/netG_faces_epoch_49-d86035a6.pth", - "densenet121": "https://s3.amazonaws.com/download.caffe2.ai/test_data/densenet121-d66d3027.pth", - "inception_v3_google": "https://s3.amazonaws.com/download.caffe2.ai/test_data/inception_v3_google-1a9a5a14.pth", - "resnet50": "https://s3.amazonaws.com/download.caffe2.ai/test_data/resnet50-19c8e357.pth", - "srresNet": "https://s3.amazonaws.com/pytorch/demos/srresnet-e10b2039.pth", - "super_resolution": "https://s3.amazonaws.com/pytorch/test_data/export/superres_epoch100-44c6958e.pth", - "squeezenet1_0": "https://s3.amazonaws.com/download.caffe2.ai/test_data/squeezenet1_0-a815701f.pth", - "squeezenet1_1": "https://s3.amazonaws.com/download.caffe2.ai/test_data/squeezenet1_1-f364aa15.pth", - "vgg16": "https://s3.amazonaws.com/download.caffe2.ai/test_data/vgg16-397923af.pth", - "vgg19": "https://s3.amazonaws.com/download.caffe2.ai/test_data/vgg19-dcbb9e9d.pth", -} - - -class TestCaffe2Backend_opset9(pytorch_test_common.ExportTestCase): - opset_version = 9 - embed_params = False - - def convert_cuda(self, model, input): - cuda_model = model.cuda() - # input might be nested - we want to move everything to GPU - cuda_input = function._nested_map( - lambda o: isinstance(o, (Variable, torch.Tensor)), - lambda o: o.cuda(), - )(input) - return cuda_model, cuda_input - - def run_debug_test( - self, - model, - train, - batch_size, - state_dict=None, - input=None, - use_gpu=True, - operator_export_type=torch.onnx.OperatorExportTypes.ONNX, - ): - """ - # TODO: remove this from the final release version - This test is for our debugging only for the case where - embed_params=False - """ - if not isinstance(model, torch.jit.ScriptModule): - model.train(train) - if state_dict is not None: - model.load_state_dict(state_dict) - - # Either user specified input or random (deterministic) input - if input is None: - input = torch.randn(batch_size, 3, 224, 224, requires_grad=True) - if use_gpu: - model, input = self.convert_cuda(model, input) - - onnxir, torch_out = do_export( - model, - input, - export_params=self.embed_params, - verbose=False, - do_constant_folding=False, - opset_version=self.opset_version, - keep_initializers_as_inputs=True, - add_node_names=False, - operator_export_type=operator_export_type, - ) - if isinstance(torch_out, torch.autograd.Variable): - torch_out = (torch_out,) - - caffe2_out = run_embed_params(onnxir, model, input, state_dict, use_gpu) - for _, (x, y) in enumerate(zip(torch_out, caffe2_out)): - np.testing.assert_almost_equal(x.data.cpu().numpy(), y, decimal=3) - - def run_actual_test( - self, - model, - train, - batch_size, - state_dict=None, - input=None, - use_gpu=True, - rtol=0.001, - atol=1e-7, - do_constant_folding=True, - operator_export_type=torch.onnx.OperatorExportTypes.ONNX, - input_names=None, - dynamic_axes=None, - remained_onnx_input_idx=None, - ): - """ - This is what the user facing version will look like - """ - # set the training/test mode for the model - if not isinstance(model, torch.jit.ScriptModule): - model.train(train) - # use the pre-trained model params if available - if state_dict is not None: - model.load_state_dict(state_dict) - - # Either user specified input or random (deterministic) input - if input is None: - input = torch.randn(batch_size, 3, 224, 224, requires_grad=True) - # GPU-ize the model, if requested - if use_gpu: - model, input = self.convert_cuda(model, input) - - # Verify the model runs the same in Caffe2 - verify.verify( - model, - input, - c2, - rtol=rtol, - atol=atol, - do_constant_folding=do_constant_folding, - opset_version=self.opset_version, - keep_initializers_as_inputs=True, - operator_export_type=operator_export_type, - input_names=input_names, - dynamic_axes=dynamic_axes, - remained_onnx_input_idx=remained_onnx_input_idx, - ) - - def run_model_test( - self, - model, - train, - batch_size, - state_dict=None, - input=None, - use_gpu=True, - rtol=0.001, - atol=1e-7, - do_constant_folding=True, - operator_export_type=torch.onnx.OperatorExportTypes.ONNX, - input_names=None, - dynamic_axes=None, - remained_onnx_input_idx=None, - ): - use_gpu_ = torch.cuda.is_available() and use_gpu - # NOTE: do_constant_folding is turned on only when model has - # parameters embedded (which are needed for constant folding), - # i.e. for self.embed_params=True case. self.embed_params=True - # for the TestCaffe2BackendEmbed class defined at the bottom. - if self.embed_params: - self.run_actual_test( - model, - train, - batch_size, - state_dict, - input, - use_gpu=use_gpu_, - rtol=rtol, - atol=atol, - do_constant_folding=do_constant_folding, - operator_export_type=operator_export_type, - input_names=input_names, - dynamic_axes=dynamic_axes, - remained_onnx_input_idx=remained_onnx_input_idx, - ) - else: - self.run_debug_test( - model, - train, - batch_size, - state_dict, - input, - use_gpu=use_gpu_, - operator_export_type=operator_export_type, - ) - - def test_linear(self): - class MyModel(torch.nn.Module): - def __init__(self): - super().__init__() - self.many_fc = nn.Sequential( - nn.Linear(4, 5, bias=True), - nn.ReLU(inplace=True), - nn.Linear(5, 6, bias=True), - nn.ReLU(inplace=True), - nn.Linear(6, 7, bias=True), - ) - - def forward(self, input): - return self.many_fc(input) - - model = MyModel() - input = torch.randn(3, 4, requires_grad=True) - self.run_model_test(model, train=False, batch_size=0, input=input) - - def test_onnx_export_with_parameter_renaming(self): - class SimpleFcNet(nn.Module): - def __init__(self): - super().__init__() - self.fc1 = nn.Linear(5, 10) - - def forward(self, input): - return self.fc1(input) - - model = SimpleFcNet() - input = torch.randn(7, 5) - output = model(input) - - f = io.BytesIO() - # Note that the export call explicitly sets the names of not just the input, - # but also the parameters. This test checks that the model can be loaded and - # executed in Caffe2 backend correctly. - torch.onnx._export( - model, - input, - f, - verbose=True, - export_type=ExportTypes.ZIP_ARCHIVE, - input_names=["input1", "parameter1", "parameter2"], - keep_initializers_as_inputs=True, - ) - - f.seek(0) - model_c2 = c2.prepare_zip_archive(f) - result = model_c2.run(input.numpy()) - np.testing.assert_almost_equal(output.data.cpu().numpy(), result[0], decimal=3) - - def test_onnx_export_param_name_duplication(self): - class SimpleFcNet(nn.Module): - def __init__(self): - super().__init__() - self.fc1 = nn.Linear(5, 10) - - def forward(self, input): - return self.fc1(input) - - model = SimpleFcNet() - input = torch.randn(7, 5) - output = model(input) - - f = io.BytesIO() - # The export call explicitly sets the names of the input, and the first parameter. - # But note that the target first parameter name is the same as the second parameter name. - # This test checks that given this edge condition, the model can be loaded and executed - # in Caffe2 backend correctly. - torch.onnx._export( - model, - input, - f, - verbose=True, - export_type=ExportTypes.ZIP_ARCHIVE, - input_names=["input1", "fc1.bias"], - keep_initializers_as_inputs=True, - ) - - f.seek(0) - model_c2 = c2.prepare_zip_archive(f) - result = model_c2.run(input.numpy()) - np.testing.assert_almost_equal(output.data.cpu().numpy(), result[0], decimal=3) - - def test_lstm_cell(self): - model = nn.LSTMCell(RNN_INPUT_SIZE, RNN_HIDDEN_SIZE) - input = torch.randn(BATCH_SIZE, RNN_INPUT_SIZE) - h0 = torch.randn(BATCH_SIZE, RNN_HIDDEN_SIZE) - c0 = torch.randn(BATCH_SIZE, RNN_HIDDEN_SIZE) - self.run_model_test( - model, - train=False, - batch_size=BATCH_SIZE, - input=(input, (h0, c0)), - use_gpu=False, - ) - - def test_gru_cell(self): - model = nn.GRUCell(RNN_INPUT_SIZE, RNN_HIDDEN_SIZE) - input = torch.randn(BATCH_SIZE, RNN_INPUT_SIZE) - h0 = torch.randn(BATCH_SIZE, RNN_HIDDEN_SIZE) - self.run_model_test( - model, train=False, batch_size=BATCH_SIZE, input=(input, h0), use_gpu=False - ) - - def _dispatch_rnn_test(self, name, *args, **kwargs): - if name == "elman": - self._elman_rnn_test(*args, **kwargs) - if name == "lstm": - self._lstm_test(*args, **kwargs) - if name == "gru": - self._gru_test(*args, **kwargs) - - def _elman_rnn_test( - self, - layers, - nonlinearity, - bidirectional, - initial_state, - packed_sequence, - dropout, - ): - batch_first = True if packed_sequence == 2 else False - model = nn.RNN( - RNN_INPUT_SIZE, - RNN_HIDDEN_SIZE, - layers, - nonlinearity=nonlinearity, - bidirectional=bidirectional, - dropout=dropout, - batch_first=batch_first, - ) - - if packed_sequence == 1: - model = RnnModelWithPackedSequence(model, False) - if packed_sequence == 2: - model = RnnModelWithPackedSequence(model, True) - - def make_input(batch_size): - seq_lengths = np.random.randint(1, RNN_SEQUENCE_LENGTH + 1, size=batch_size) - seq_lengths = sorted(map(int, seq_lengths), reverse=True) - inputs = [torch.randn(l, RNN_INPUT_SIZE) for l in seq_lengths] - inputs = rnn_utils.pad_sequence(inputs, batch_first=batch_first) - inputs = [inputs] - - directions = 2 if bidirectional else 1 - - if initial_state: - h0 = torch.randn(directions * layers, batch_size, RNN_HIDDEN_SIZE) - inputs.append(h0) - if packed_sequence != 0: - inputs.append(torch.IntTensor(seq_lengths)) - if len(inputs) == 1: - input = inputs[0] - else: - input = tuple(inputs) - return input - - input = make_input(RNN_BATCH_SIZE) - self.run_model_test( - model, - train=False, - batch_size=RNN_BATCH_SIZE, - input=input, - use_gpu=False, - atol=1e-7, - ) - - # test that the model still runs with a different batch size - # (save the model with a batch_size of 1 with rnn with a variable batch size, - # otherwise expand will fail) - variable_batch_size_init_input = make_input(1) - # Constant folding works when model has parameters embedded. For this case, we need to disable it - onnxir, _ = do_export( - model, - variable_batch_size_init_input, - keep_initializers_as_inputs=True, - do_constant_folding=False, - ) - other_input = make_input(RNN_BATCH_SIZE + 1) - _ = run_embed_params(onnxir, model, other_input, use_gpu=False) - - def _lstm_test( - self, layers, bidirectional, initial_state, packed_sequence, dropout - ): - batch_first = True if packed_sequence == 2 else False - model = LstmFlatteningResult( - RNN_INPUT_SIZE, - RNN_HIDDEN_SIZE, - layers, - bidirectional=bidirectional, - dropout=dropout, - batch_first=batch_first, - ) - if packed_sequence == 1: - model = RnnModelWithPackedSequence(model, False) - if packed_sequence == 2: - model = RnnModelWithPackedSequence(model, True) - - def make_input(batch_size): - seq_lengths = np.random.randint(1, RNN_SEQUENCE_LENGTH + 1, size=batch_size) - seq_lengths = sorted(map(int, seq_lengths), reverse=True) - inputs = [torch.randn(l, RNN_INPUT_SIZE) for l in seq_lengths] - inputs = rnn_utils.pad_sequence(inputs, batch_first=batch_first) - inputs = [inputs] - - directions = 2 if bidirectional else 1 - - if initial_state: - h0 = torch.randn(directions * layers, batch_size, RNN_HIDDEN_SIZE) - c0 = torch.randn(directions * layers, batch_size, RNN_HIDDEN_SIZE) - inputs.append((h0, c0)) - if packed_sequence != 0: - inputs.append(torch.IntTensor(seq_lengths)) - if len(inputs) == 1: - input = inputs[0] - else: - input = tuple(inputs) - return input - - input = make_input(RNN_BATCH_SIZE) - self.run_model_test( - model, train=False, batch_size=RNN_BATCH_SIZE, input=input, use_gpu=False - ) - - # test that the model still runs with a different batch size - # (save the model with a batch_size of 1 with rnn with a variable batch size, - # otherwise expand will fail) - variable_batch_size_init_input = make_input(1) - # Constant folding works when model has parameters embedded. For this case, we need to disable it - onnxir, _ = do_export( - model, - variable_batch_size_init_input, - keep_initializers_as_inputs=True, - do_constant_folding=False, - ) - other_input = make_input(RNN_BATCH_SIZE + 1) - _ = run_embed_params(onnxir, model, other_input, use_gpu=False) - - def _gru_test(self, layers, bidirectional, initial_state, packed_sequence, dropout): - batch_first = True if packed_sequence == 2 else False - model = nn.GRU( - RNN_INPUT_SIZE, - RNN_HIDDEN_SIZE, - layers, - bidirectional=bidirectional, - dropout=dropout, - batch_first=batch_first, - ) - if packed_sequence == 1: - model = RnnModelWithPackedSequence(model, False) - if packed_sequence == 2: - model = RnnModelWithPackedSequence(model, True) - - def make_input(batch_size): - seq_lengths = np.random.randint(1, RNN_SEQUENCE_LENGTH + 1, size=batch_size) - seq_lengths = sorted(map(int, seq_lengths), reverse=True) - inputs = [torch.randn(l, RNN_INPUT_SIZE) for l in seq_lengths] - inputs = rnn_utils.pad_sequence(inputs, batch_first=batch_first) - inputs = [inputs] - - directions = 2 if bidirectional else 1 - - if initial_state: - h0 = torch.randn(directions * layers, batch_size, RNN_HIDDEN_SIZE) - inputs.append(h0) - if packed_sequence != 0: - inputs.append(torch.IntTensor(seq_lengths)) - if len(inputs) == 1: - input = inputs[0] - else: - input = tuple(inputs) - return input - - input = make_input(RNN_BATCH_SIZE) - self.run_model_test( - model, train=False, batch_size=RNN_BATCH_SIZE, input=input, use_gpu=False - ) - - # test that the model still runs with a different batch size - # (save the model with a batch_size of 1 with rnn with a variable batch size, - # otherwise expand will fail) - variable_batch_size_init_input = make_input(1) - # Constant folding works when model has parameters embedded. For this case, we need to disable it - onnxir, _ = do_export( - model, - variable_batch_size_init_input, - keep_initializers_as_inputs=True, - do_constant_folding=False, - ) - other_input = make_input(RNN_BATCH_SIZE + 1) - _ = run_embed_params(onnxir, model, other_input, use_gpu=False) - - @unittest.skip("Disabled due to onnx optimizer deprecation") - def test_rnn_init_predict_split(self): - model = nn.LSTM(RNN_INPUT_SIZE, RNN_HIDDEN_SIZE, 3, bidirectional=True) - seq_lengths = np.random.randint(1, RNN_SEQUENCE_LENGTH + 1, size=7) - seq_lengths = sorted(map(int, seq_lengths), reverse=True) - input = [torch.randn(l, RNN_INPUT_SIZE) for l in seq_lengths] - input = rnn_utils.pad_sequence(input) - - # Test that we are correctly splitting between init and - # predict net. When we embed parameters, there should be more - # ops in the init net. - mp = onnx.ModelProto.FromString( - do_export( - model, - input, - export_params=self.embed_params, - keep_initializers_as_inputs=True, - do_constant_folding=False, - )[0] - ) - prepared = c2.prepare(mp, device="CPU") - if self.embed_params: - assert len(prepared.init_net.op) == 950 - assert len(prepared.predict_net.op) == 101 - else: - assert len(prepared.init_net.op) == 83 - assert len(prepared.predict_net.op) == 968 - - def test_alexnet(self): - state_dict = model_zoo.load_url(model_urls["alexnet"], progress=False) - self.run_model_test( - alexnet(), - train=False, - batch_size=BATCH_SIZE, - state_dict=state_dict, - atol=1e-3, - ) - - @skipIfNoCuda - def test_dcgan(self): - # dcgan is flaky on some seeds, see: - # https://github.com/ProjectToffee/onnx/pull/70 - torch.manual_seed(1) - if torch.cuda.is_available(): - torch.cuda.manual_seed_all(1) - - netD = dcgan._netD(1) - netD.apply(dcgan.weights_init) - input = torch.randn(BATCH_SIZE, 3, dcgan.imgsz, dcgan.imgsz) - self.run_model_test(netD, train=False, batch_size=BATCH_SIZE, input=input) - - netG = dcgan._netG(1) - netG.apply(dcgan.weights_init) - state_dict = model_zoo.load_url(model_urls["dcgan_b"], progress=False) - # state_dict = model_zoo.load_url(model_urls["dcgan_f"], progress=False) - noise = torch.randn(BATCH_SIZE, dcgan.nz, 1, 1).normal_(0, 1) - self.run_model_test( - netG, - train=False, - batch_size=BATCH_SIZE, - input=noise, - state_dict=state_dict, - rtol=1e-2, - atol=1e-6, - ) - - @unittest.skipIf( - not torch.cuda.is_available(), "model on net has cuda in it, awaiting fix" - ) - def test_densenet(self): - state_dict = model_zoo.load_url(model_urls["densenet121"], progress=False) - self.run_model_test( - densenet121(), - train=False, - batch_size=BATCH_SIZE, - state_dict=state_dict, - atol=1e-7, - ) - - @skip("doesn't match exactly...") - # TODO: figure out the numerical instabilities - def test_inception(self): - x = torch.randn(BATCH_SIZE, 3, 299, 299, requires_grad=True) - # state_dict = model_zoo.load_url(model_urls["inception_v3_google"], progress=False) - state_dict = None - self.run_model_test( - inception_v3(), - train=False, - batch_size=BATCH_SIZE, - state_dict=state_dict, - input=x, - ) - - @skipIfNoEmbed - def test_resnet(self): - state_dict = model_zoo.load_url(model_urls["resnet50"], progress=False) - self.run_model_test( - resnet50(), - train=False, - batch_size=BATCH_SIZE, - state_dict=state_dict, - atol=1e-5, - ) - - def test_squeezenet(self): - sqnet_v1_1 = SqueezeNet(version=1.1) - state_dict = model_zoo.load_url(model_urls["squeezenet1_1"], progress=False) - # state_dict = model_zoo.load_url(model_urls["squeezenet1_0"], progress=False) - self.run_model_test( - sqnet_v1_1, train=False, batch_size=BATCH_SIZE, state_dict=state_dict - ) - - # @skip("takes long to run, LAPACK needed for gpu") - @skipIfNoLapack - @unittest.skip("This model takes too much memory") - def test_srresnet(self): - super_resolution_net = SRResNet(rescale_factor=4, n_filters=64, n_blocks=8) - state_dict = model_zoo.load_url(model_urls["srresNet"], progress=False) - x = torch.randn(1, 3, 224, 224, requires_grad=True) - self.run_model_test( - super_resolution_net, - train=False, - batch_size=1, - state_dict=state_dict, - input=x, - use_gpu=False, - ) - - @skipIfTravis - @skipIfNoLapack - @skipIfNoCuda - def test_super_resolution(self): - super_resolution_net = SuperResolutionNet(upscale_factor=3) - state_dict = model_zoo.load_url(model_urls["super_resolution"], progress=False) - x = torch.randn(1, 1, 224, 224, requires_grad=True) - self.run_model_test( - super_resolution_net, - train=False, - batch_size=BATCH_SIZE, - state_dict=state_dict, - input=x, - use_gpu=False, - atol=1e-6, - ) - - @unittest.skip("This model takes too much memory") - def test_vgg16(self): - state_dict = model_zoo.load_url(model_urls["vgg16"], progress=False) - self.run_model_test( - vgg16(), train=False, batch_size=BATCH_SIZE, state_dict=state_dict - ) - - @skip("disable to run tests faster...") - def test_vgg16_bn(self): - self.run_model_test(vgg16_bn(), train=False, batch_size=BATCH_SIZE) - - @skip("disable to run tests faster...") - def test_vgg19(self): - state_dict = model_zoo.load_url(model_urls["vgg19"], progress=False) - self.run_model_test( - vgg19(), train=False, batch_size=BATCH_SIZE, state_dict=state_dict - ) - - @skip("disable to run tests faster...") - def test_vgg19_bn(self): - self.run_model_test(vgg19_bn(), train=False, batch_size=BATCH_SIZE) - - def run_word_language_model(self, model_name): - ntokens = 50 - emsize = 5 - nhid = 5 - nlayers = 5 - dropout = 0.2 - tied = False - batchsize = 5 - model = word_language_model.RNNModel( - model_name, ntokens, emsize, nhid, nlayers, dropout, tied, batchsize - ) - x = torch.arange(0, ntokens).long().view(-1, batchsize) - # Only support CPU version, since tracer is not working in GPU RNN. - self.run_model_test( - model, - train=False, - input=(x, model.hidden), - batch_size=batchsize, - use_gpu=False, - ) - - @unittest.skip("Disabled due to onnx optimizer deprecation") - @skipIfUnsupportedOpsetVersion([10]) - def test_word_language_model_RNN_TANH(self): - self.run_word_language_model("RNN_TANH") - - @unittest.skip("Disabled due to onnx optimizer deprecation") - @skipIfUnsupportedOpsetVersion([10]) - def test_word_language_model_RNN_RELU(self): - self.run_word_language_model("RNN_RELU") - - @unittest.skip("Disabled due to onnx optimizer deprecation") - @skipIfUnsupportedOpsetVersion([10]) - def test_word_language_model_LSTM(self): - self.run_word_language_model("LSTM") - - @unittest.skip("Disabled due to onnx optimizer deprecation") - @skipIfUnsupportedOpsetVersion([10]) - def test_word_language_model_GRU(self): - self.run_word_language_model("GRU") - - def test_batchnorm1d_special(self): - c = torch.randn(BATCH_SIZE, 224) - model = nn.BatchNorm1d(224) - self.run_model_test(model, train=True, input=c, batch_size=BATCH_SIZE) - - def test_batchnorm1d(self): - c = torch.randn(BATCH_SIZE, 224, 224) - model = nn.BatchNorm1d(224) - self.run_model_test(model, train=True, input=c, batch_size=BATCH_SIZE) - - def test_batchnorm1d_noaffine(self): - c = torch.randn(BATCH_SIZE, 224) - model = nn.BatchNorm1d(224, affine=False) - self.run_model_test(model, train=False, input=c, batch_size=BATCH_SIZE) - - def test_batchnorm2d_noaffine(self): - c = torch.randn(128, 128, 1, 1) - model = nn.BatchNorm2d(128, affine=False) - self.run_model_test(model, train=False, input=c, batch_size=BATCH_SIZE) - - def test_batchnorm3d_noaffine(self): - c = torch.randn(128, 128, 1, 1, 1) - model = nn.BatchNorm3d(128, affine=False) - self.run_model_test(model, train=False, input=c, batch_size=BATCH_SIZE) - - def test_constant(self): - c = torch.randn(BATCH_SIZE, 3, 224, 224) - - class MyModel(torch.nn.Module): - def forward(self, input): - return input + c.type_as(input) - - self.run_model_test(MyModel(), train=False, batch_size=BATCH_SIZE) - - def test_consumed_bn(self): - underlying = nn.BatchNorm2d(3) - self.run_model_test(underlying, train=True, batch_size=BATCH_SIZE) - - def _test_index_generic(self, fn): - class MyModel(torch.nn.Module): - def forward(self, input): - return fn(input) - - m1 = torch.randn(3, 4, 5, 6, 7) - self.run_model_test(MyModel(), input=m1, train=False, batch_size=BATCH_SIZE) - - def test_index_1d(self): - self._test_index_generic(lambda input: input[0]) - - @skipIfUnsupportedOpsetVersion([10]) - def test_index_2d_1dimslice(self): - self._test_index_generic(lambda input: input[0:1, :]) - - @skipIfUnsupportedOpsetVersion([10]) - def test_index_2d_sliceint(self): - self._test_index_generic(lambda input: input[1, :]) - - @skipIfUnsupportedOpsetVersion([10]) - def test_index_2d_neg_slice(self): - self._test_index_generic(lambda input: input[0:-1, :]) - - @skipIfUnsupportedOpsetVersion([10]) - def test_index_2d_2dimslice(self): - self._test_index_generic(lambda input: input[0:1, 0:1]) - - @skipIfUnsupportedOpsetVersion([10]) - def test_index_2d_neg_slice2dim(self): - self._test_index_generic(lambda input: input[0:-1, 0:-1]) - - def test_tensor_index_1d(self): - self._test_index_generic(lambda input: input[torch.tensor([0, 2])]) - - def test_tensor_index_2d_1dconstant(self): - self._test_index_generic(lambda input: input[1, torch.tensor([0, 2])]) - - @skipIfUnsupportedOpsetVersion([10]) - def test_tensor_index_2d_1dslice(self): - self._test_index_generic(lambda input: input[torch.tensor([0, 2]), 0:1]) - - @skipIfUnsupportedOpsetVersion([10]) - def test_tensor_index_2d_1dslice_first(self): - self._test_index_generic(lambda input: input[1:3, torch.tensor([0, 2])]) - - def test_tensor_index_newaxis(self): - self._test_index_generic(lambda input: input[None, torch.tensor([0, 2])]) - - def test_tensor_index_advanced_indexing(self): - self._test_index_generic( - lambda input: input[ - :, - torch.tensor([[0, 2], [1, 1]]), - :, - torch.tensor([2, 1]), - torch.tensor([0, 3]), - ] - ) - - @skipIfUnsupportedOpsetVersion([10]) - def test_tensor_index_advanced_indexing_with_slice(self): - self._test_index_generic( - lambda input: input[ - :, torch.tensor([0, 2]), None, 2:4, torch.tensor([[1, 3], [4, 0]]) - ] - ) - self._test_index_generic( - lambda input: input[ - :, - torch.tensor([0, 2]), - torch.tensor([1]), - 2:4, - torch.tensor([[1], [4]]), - ] - ) - - def test_tensor_index_advanced_indexing_consecutive(self): - self._test_index_generic( - lambda input: input[ - :, torch.tensor([0, 2]), torch.tensor([[1, 3], [4, 0]]), None - ] - ) - - @skipIfUnsupportedMinOpsetVersion(9) - def test_tensor_index_advanced_indexing_masked(self): - self._test_index_generic( - lambda input: input[ - :, - torch.tensor([1, 0, 1, 0], dtype=torch.uint8), - torch.tensor([[1, 3], [4, 0]]), - None, - ] - ) - - def test_chunk(self): - class MyModel(torch.nn.Module): - def forward(self, input): - # TODO: Why index? This returns a tuple and test runner doesn't - # support tuple comparison. - return input.chunk(8, dim=2)[-1] - - self.run_model_test(MyModel(), train=False, batch_size=BATCH_SIZE) - - def test_sqrt(self): - class MyModel(torch.nn.Module): - def forward(self, input): - return input.sqrt() - - input = torch.empty(BATCH_SIZE, 10, 10).uniform_(4, 9) - self.run_model_test(MyModel(), train=False, input=input, batch_size=BATCH_SIZE) - - def test_rsqrt(self): - class MyModel(torch.nn.Module): - def forward(self, input): - return input.rsqrt() - - input = torch.randn(4, 2, 3, requires_grad=True) - self.run_model_test(MyModel(), train=False, input=input, batch_size=BATCH_SIZE) - - def test_log(self): - class MyModel(torch.nn.Module): - def forward(self, input): - return input.log() - - input = torch.empty(BATCH_SIZE, 10, 10).uniform_(4, 9) - self.run_model_test(MyModel(), train=False, input=input, batch_size=BATCH_SIZE) - - @skipIfUnsupportedMinOpsetVersion(9) - def test_erf(self): - class MyModel(torch.nn.Module): - def forward(self, input): - return input.erf() - - input = torch.empty(BATCH_SIZE, 10, 10).uniform_(4, 9) - self.run_model_test(MyModel(), train=False, input=input, batch_size=BATCH_SIZE) - - def test_trigonometry(self): - def test_func(name): - class MyModel(torch.nn.Module): - def forward(self, input): - return getattr(input, name)() - - input = torch.empty(BATCH_SIZE, 10, 10).uniform_() - self.run_model_test( - MyModel(), train=False, input=input, batch_size=BATCH_SIZE - ) - - test_func("cos") - test_func("sin") - test_func("tan") - test_func("acos") - test_func("asin") - test_func("atan") - - def test_addconstant(self): - class MyModel(torch.nn.Module): - def forward(self, input): - # TODO: Why index? This returns a tuple and test runner doesn't - # support tuple comparison. - return input + 1 - - self.run_model_test(MyModel(), train=False, batch_size=BATCH_SIZE) - - def test_subconstant(self): - class MyModel(torch.nn.Module): - def forward(self, input): - # TODO: Why index? This returns a tuple and test runner doesn't - # support tuple comparison. - return input - 1 - - self.run_model_test(MyModel(), train=False, batch_size=BATCH_SIZE) - - def test_arithmetic(self): - class ArithmeticModule(torch.nn.Module): - def forward(self, x): - x = x + 2 - x = x - 4 - x = x * 6 - x = x / 8 - return x - - x = torch.randn(2, 3, 4) - self.run_model_test( - ArithmeticModule(), input=x, train=False, batch_size=BATCH_SIZE - ) - - def test_embedding(self): - model = nn.Embedding(10, 3, padding_idx=-1) - input = torch.LongTensor(list(range(10))[::-1]) - self.run_model_test(model, train=False, input=input, batch_size=BATCH_SIZE) - - def test_constantpad2d(self): - model = nn.ConstantPad2d((1, 2, 3, 4), 3.5) - self.run_model_test(model, train=False, batch_size=BATCH_SIZE) - - def test_reflectionpad2d(self): - model = nn.ReflectionPad2d((1, 2, 3, 4)) - self.run_model_test(model, train=False, batch_size=BATCH_SIZE) - - def test_replicationpad2d(self): - model = nn.ReplicationPad2d((1, 2, 3, 4)) - self.run_model_test(model, train=False, batch_size=BATCH_SIZE) - - def test_maxpool2d(self): - model = nn.MaxPool2d(5, padding=(1, 2)) - self.run_model_test(model, train=False, batch_size=BATCH_SIZE) - - def test_maxpool2d_single_padding(self): - model = nn.MaxPool2d(5, padding=2) - self.run_model_test(model, train=False, batch_size=BATCH_SIZE) - - @skipIfUnsupportedOpsetVersion([10]) - def test_maxpool1d_ceil(self): - model = nn.MaxPool1d(3, 2, ceil_mode=True) - x = torch.randn(20, 16, 50, requires_grad=True) - self.run_model_test(model, train=False, input=x, batch_size=BATCH_SIZE) - - @skipIfUnsupportedOpsetVersion([10]) - def test_maxpool2d_ceil(self): - model = nn.MaxPool2d(3, 2, ceil_mode=True) - x = torch.randn(20, 16, 50, 32, requires_grad=True) - self.run_model_test(model, train=False, input=x, batch_size=BATCH_SIZE) - - @skipIfUnsupportedOpsetVersion([10]) - def test_maxpool3d_ceil(self): - model = nn.MaxPool3d(3, 2, ceil_mode=True) - x = torch.randn(20, 16, 50, 44, 31, requires_grad=True) - self.run_model_test(model, train=False, input=x, batch_size=BATCH_SIZE) - - @unittest.skip("C2 and PyTorch have small difference in padding implementation") - def test_avgpool2d(self): - model = nn.AvgPool2d(5, padding=(2)) - self.run_model_test(model, train=False, batch_size=BATCH_SIZE) - - def test_avgpool2d_with_count_include_pad_set_false(self): - model = nn.AvgPool2d(7, padding=(2), count_include_pad=False) - self.run_model_test(model, train=False, batch_size=BATCH_SIZE) - - def test_avgpool2d_with_count_include_pad_set_true(self): - model = nn.AvgPool2d(7, padding=(2), count_include_pad=True) - self.run_model_test(model, train=False, batch_size=BATCH_SIZE) - - def test_avgpool2d_no_padding(self): - model = nn.AvgPool2d(5) - self.run_model_test(model, train=False, batch_size=BATCH_SIZE) - - @unittest.skip("Disabled due to onnx optimizer deprecation") - @skipIfUnsupportedOpsetVersion([10]) - def test_avg_pool1D_ceil(self): - model = torch.nn.AvgPool1d(3, 2, ceil_mode=True) - x = torch.randn(1, 1, 7, requires_grad=True) - self.run_model_test(model, train=False, input=x, batch_size=BATCH_SIZE) - - @skipIfUnsupportedOpsetVersion([10]) - def test_avg_pool2D_ceil(self): - model = torch.nn.AvgPool2d(3, 2, ceil_mode=True) - x = torch.randn(20, 16, 50, 32, requires_grad=True) - self.run_model_test(model, train=False, input=x, batch_size=BATCH_SIZE) - - @unittest.skip("Disabled due to onnx optimizer deprecation") - @skipIfUnsupportedOpsetVersion([10]) - def test_avg_pool3D_ceil(self): - model = torch.nn.AvgPool3d(3, 2, ceil_mode=True) - x = torch.randn(20, 16, 50, 44, 31, requires_grad=True) - self.run_model_test(model, train=False, input=x, batch_size=BATCH_SIZE) - - def test_adaptive_avg_pool1D(self): - model = torch.nn.AdaptiveAvgPool1d(5) - x = torch.randn(20, 16, 50, requires_grad=True) - self.run_model_test(model, train=False, input=x, batch_size=BATCH_SIZE) - - def test_adaptive_avg_pool2D(self): - model = torch.nn.AdaptiveAvgPool2d((5, 4)) - x = torch.randn(20, 16, 50, 32, requires_grad=True) - self.run_model_test(model, train=False, input=x, batch_size=BATCH_SIZE) - - def test_adaptive_avg_pool3D(self): - model = torch.nn.AdaptiveAvgPool3d((5, 4, 3)) - x = torch.randn(20, 16, 50, 44, 30, requires_grad=True) - self.run_model_test(model, train=False, input=x, batch_size=BATCH_SIZE) - - @skipIfUnsupportedMinOpsetVersion(8) - def test_adaptive_max_pool1D(self): - model = torch.nn.AdaptiveMaxPool1d(5) - x = torch.randn(20, 16, 50, requires_grad=True) - self.run_model_test(model, train=False, input=x, batch_size=BATCH_SIZE) - - @skipIfUnsupportedMinOpsetVersion(8) - def test_adaptive_max_pool2D(self): - model = torch.nn.AdaptiveMaxPool2d((5, 4)) - x = torch.randn(20, 16, 50, 32, requires_grad=True) - self.run_model_test(model, train=False, input=x, batch_size=BATCH_SIZE) - - @skipIfUnsupportedMinOpsetVersion(8) - def test_adaptive_max_pool3D(self): - model = torch.nn.AdaptiveMaxPool3d((5, 4, 3)) - x = torch.randn(20, 16, 50, 44, 30, requires_grad=True) - self.run_model_test(model, train=False, input=x, batch_size=BATCH_SIZE) - - def test_weight_norm(self): - model = nn.utils.weight_norm(nn.Conv1d(1, 1, 3)) - input = torch.randn(1, 1, 5, requires_grad=True) - self.run_model_test(model, train=True, batch_size=0, input=input, use_gpu=False) - - def test_mnist(self): - model = MNIST() - input = torch.randn(BATCH_SIZE, 1, 28, 28) - state_dict = None - # TODO: test with state_dict - self.run_model_test( - model, - train=False, - input=input, - batch_size=BATCH_SIZE, - state_dict=state_dict, - ) - - def test_mm(self): - class MyModel(torch.nn.Module): - def forward(self, m1, m2): - return torch.mm(m1, m2) - - m1 = torch.randn(3, 4) - m2 = torch.randn(4, 5) - self.run_model_test( - MyModel(), train=False, input=(m1, m2), batch_size=BATCH_SIZE, use_gpu=False - ) - - def test_addmm(self): - class MyModel(torch.nn.Module): - def forward(self, ma, m1, m2): - return torch.addmm(ma, m1, m2) - - ma = torch.randn(5) - m1 = torch.randn(3, 4) - m2 = torch.randn(4, 5) - self.run_model_test( - MyModel(), - train=False, - input=(ma, m1, m2), - batch_size=BATCH_SIZE, - use_gpu=False, - ) - - def test_fuse_addmm(self): - class AddmmModel(torch.nn.Module): - def forward(self, x): - return torch.mm(x, x) + x - - x = torch.randn(3, 3) - self.run_model_test( - AddmmModel(), train=False, input=x, batch_size=BATCH_SIZE, use_gpu=False - ) - - def test_scalar_type(self): - class ArithmeticModel(torch.nn.Module): - def forward(self, x): - return x.size(0) * 2 * x - - x = torch.ones(2, 3, dtype=torch.float32) - self.run_model_test( - ArithmeticModel(), input=x, train=False, batch_size=BATCH_SIZE - ) - - class ReciprocalModel(torch.nn.Module): - def forward(self, x): - return torch.reciprocal(x) - - x = torch.tensor([2.0, 4.0], dtype=torch.double) - self.run_model_test( - ReciprocalModel(), input=x, train=False, batch_size=BATCH_SIZE - ) - - class ComparisonModel(torch.nn.Module): - def forward(self, x, y): - return x.ge(0.5) & y.le(2) - - x = torch.ones(2, 3, dtype=torch.int32) - y = torch.ones(2, 3, dtype=torch.float32) - self.run_model_test( - ComparisonModel(), input=(x, y), train=False, batch_size=BATCH_SIZE - ) - - class MatMulModel(torch.nn.Module): - def forward(self, x, y): - return torch.mm(x, y) - - x = torch.ones(3, 4) - y = torch.ones(4, 5) - self.run_model_test( - MatMulModel(), input=(x, y), train=False, batch_size=BATCH_SIZE - ) - - class AddMMModel(torch.nn.Module): - def forward(self, x): - return torch.mm(x, x) + x - - x = torch.ones(3, 3) - self.run_model_test(AddMMModel(), input=x, train=False, batch_size=BATCH_SIZE) - - # test for a pytorch optimization pass, see https://github.com/pytorch/pytorch/pull/7872 - def test_consecutive_transposes(self): - class MyModel(torch.nn.Module): - def forward(self, x): - return x.transpose(1, 2).transpose(2, 3) - - x = torch.randn(5, 6, 7, 8) - self.run_model_test( - MyModel(), train=False, input=x, batch_size=BATCH_SIZE, use_gpu=False - ) - - def test_sum(self): - shape = (3, 4, 5) - for params in [{}] + [{"dim": i} for i in range(len(shape))]: - - class MyModel(torch.nn.Module): - def forward(self, x): - return torch.sum(x, **params) - - x = torch.randn(*shape) - self.run_model_test( - MyModel(), train=False, input=(x), batch_size=BATCH_SIZE, use_gpu=False - ) - - def test_cumsum(self): - shape = (3, 4, 5) - for params in [{"dim": i} for i in range(len(shape))]: - - class MyModel(torch.nn.Module): - def forward(self, x): - return torch.cumsum(x, **params) - - x = torch.randn(*shape) - self.run_model_test( - MyModel(), - train=False, - input=(x), - batch_size=BATCH_SIZE, - use_gpu=False, - operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK, - ) - - def test_cosine_similarity(self): - shape = (100, 128) - x = torch.randn(*shape) - y = torch.randn(*shape) - self.run_model_test( - torch.nn.CosineSimilarity(dim=1, eps=1e-6), - train=False, - input=(x, y), - batch_size=BATCH_SIZE, - use_gpu=False, - operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK, - ) - - @unittest.skip("Disabled due to onnx optimizer deprecation") - @skipIfUnsupportedOpsetVersion([10]) - def test_lstm_constant_folding(self): - class LstmNet(nn.Module): - def __init__(self, input_size, hidden_size, num_layers, bidirectional): - super().__init__() - self.lstm = nn.LSTM( - input_size, hidden_size, num_layers, bidirectional=bidirectional - ) - - def forward(self, input, initial_state): - return self.lstm(input, initial_state) - - def get_LstmNet_model_and_inputs( - input_size, hidden_size, num_layers, batch_size, seq_len, bidirectional - ): - num_directions = 2 if bidirectional else 1 - model = LstmNet(input_size, hidden_size, num_layers, bidirectional) - input = torch.randn(seq_len, batch_size, input_size) - h0 = torch.randn(num_layers * num_directions, batch_size, hidden_size) - c0 = torch.randn(num_layers * num_directions, batch_size, hidden_size) - return model, (input, (h0, c0)) - - batch_size1 = 3 - model1, input1 = get_LstmNet_model_and_inputs(7, 3, 2, batch_size1, 5, True) - self.run_actual_test( - model1, - train=False, - batch_size=batch_size1, - input=input1, - use_gpu=False, - do_constant_folding=True, - ) - - batch_size2 = 4 - model2, input2 = get_LstmNet_model_and_inputs(5, 4, 3, batch_size2, 7, False) - self.run_actual_test( - model2, - train=False, - batch_size=batch_size2, - input=input2, - use_gpu=False, - do_constant_folding=True, - ) - - @unittest.skip("Disabled due to onnx optimizer deprecation") - @skipIfUnsupportedOpsetVersion([10]) - def test_gru_constant_folding(self): - class GruNet(nn.Module): - def __init__(self, input_size, hidden_size, num_layers, bidirectional): - super().__init__() - self.mygru = nn.GRU( - input_size, hidden_size, num_layers, bidirectional=bidirectional - ) - - def forward(self, input, initial_state): - out = self.mygru(input, initial_state) - return out - - def get_GruNet_model_and_inputs( - input_size, hidden_size, num_layers, batch_size, seq_len, bidirectional - ): - num_directions = 2 if bidirectional else 1 - model = GruNet(input_size, hidden_size, num_layers, bidirectional) - input = torch.randn(seq_len, batch_size, input_size) - h0 = torch.randn(num_layers * num_directions, batch_size, hidden_size) - return model, (input, h0) - - batch_size1 = 3 - model1, input1 = get_GruNet_model_and_inputs(7, 3, 2, batch_size1, 5, True) - self.run_actual_test( - model1, - train=False, - batch_size=batch_size1, - input=input1, - use_gpu=False, - do_constant_folding=True, - ) - - batch_size2 = 4 - model2, input2 = get_GruNet_model_and_inputs(5, 4, 3, batch_size2, 7, False) - self.run_actual_test( - model2, - train=False, - batch_size=batch_size2, - input=input2, - use_gpu=False, - do_constant_folding=True, - ) - - def test_repeat(self): - class MyModel(torch.nn.Module): - def forward(self, x): - return x.repeat(1, 2, 3, 4) - - x = torch.randn(4, 3, 2, 1, requires_grad=True) - self.run_model_test( - MyModel(), train=False, input=(x), batch_size=BATCH_SIZE, use_gpu=False - ) - - @skipIfUnsupportedOpsetVersion([10]) - def test_upsample(self): - x = torch.randn(1, 2, 3, 4, requires_grad=True) - model = nn.Upsample(size=[v * 2 for v in x.size()[2:]], mode="nearest") - self.run_model_test( - model, train=False, input=(x), batch_size=BATCH_SIZE, use_gpu=False - ) - - @skipIfUnsupportedOpsetVersion([10]) - def test_interpolate_upsample(self): - class MyModel(torch.nn.Module): - def forward(self, x): - size = [v * 2 for v in x.size()[2:]] - # work around for now: turn the dynamic sizes into constant - size = [int(i) for i in size] - return nn.functional.interpolate(x, size=size, mode="nearest") - - x = torch.randn(1, 2, 3, 4, requires_grad=True) - model = MyModel() - self.run_model_test( - model, train=False, input=(x), batch_size=BATCH_SIZE, use_gpu=False - ) - - @skipIfUnsupportedOpsetVersion([7, 8, 10]) - def test_interpolate_upsample_dynamic_sizes(self): - class MyModel(torch.nn.Module): - def forward(self, x): - size = [v * 2 for v in x.size()[2:]] - return nn.functional.interpolate(x, size=size, mode="nearest") - - x = torch.randn(1, 2, 3, 4, requires_grad=True) - model = MyModel() - self.run_model_test( - model, train=False, input=(x), batch_size=BATCH_SIZE, use_gpu=False - ) - - def test_repeat_dim_overflow(self): - class MyModel(torch.nn.Module): - def forward(self, x): - return x.repeat(1, 2, 3, 4) - - x = torch.randn(1, 2, requires_grad=True) - self.run_model_test( - MyModel(), train=False, input=(x), batch_size=BATCH_SIZE, use_gpu=False - ) - - def test_repeat_dynamic(self): - class MyModel(torch.nn.Module): - def forward(self, x, y): - return x.repeat(y.size()[0] // 2, y.size()[1] * 2) - - x = torch.randn(1, 2, requires_grad=True) - y = torch.randn(2, 4, requires_grad=True) - self.run_model_test( - MyModel(), - train=False, - input=(x, y), - batch_size=BATCH_SIZE, - use_gpu=False, - input_names=["x", "y"], - dynamic_axes={"x": [0, 1], "y": [0, 1]}, - ) - self.run_model_test( - MyModel(), - train=False, - input=(x, y), - batch_size=BATCH_SIZE, - use_gpu=False, - remained_onnx_input_idx=[0], - ) - - def test_mean(self): - shape = (3, 4, 5) - for params in [{}] + [{"dim": i} for i in range(len(shape))]: - - class MyModel(torch.nn.Module): - def forward(self, x): - return torch.mean(x, **params) - - x = torch.randn(*shape) - self.run_model_test( - MyModel(), train=False, input=(x), batch_size=BATCH_SIZE, use_gpu=False - ) - - # TODO: Add test cases for prod once Caffe2 has support for ReduceProd - def test_softmax(self): - for i in range(2, 8): - for d in range(0, i - 1): - model = nn.Softmax(dim=d) - dims = [2] * (i - 2) + [3, 4] - input = torch.ones(*dims, requires_grad=True) - self.run_model_test( - model, train=False, batch_size=BATCH_SIZE, input=input - ) - - def test_softmax_dtype(self): - class SoftmaxModel(torch.nn.Module): - def forward(self, input): - return nn.functional.softmax(input, dim=0, dtype=torch.float64) - - x = torch.randn(1, 2, 3, requires_grad=True, dtype=torch.float32) - self.run_model_test(SoftmaxModel(), train=False, input=x, batch_size=BATCH_SIZE) - - def test_logsoftmax(self): - for i in range(7)[2:]: - model = nn.LogSoftmax(dim=i - 1) - dims = [2] * (i - 2) + [3, 4] - input = torch.ones(*dims, requires_grad=True) - self.run_model_test(model, train=False, batch_size=BATCH_SIZE, input=input) - - def test_logsoftmax_dim(self): - for i in range(-4, 3): - model = nn.LogSoftmax(dim=i) - input = torch.randn(3, 4, 5, 6) - self.run_model_test(model, train=False, batch_size=BATCH_SIZE, input=input) - - def test_randn(self): - x = torch.randn(1, 2, 3, 4) - - class MyModule(torch.nn.Module): - def forward(self, x): - return (torch.randn(1, 2, 3, 4) + x).shape - - self.run_model_test( - MyModule(), - train=False, - input=(x), - batch_size=BATCH_SIZE, - use_gpu=False, - remained_onnx_input_idx=[], - ) - - def test_rand(self): - x = torch.randn(1, 2, 3, 4) - - class MyModule(torch.nn.Module): - def forward(self, x): - return (torch.rand(1, 2, 3, 4) + x).shape - - self.run_model_test( - MyModule(), - train=False, - input=(x), - batch_size=BATCH_SIZE, - use_gpu=False, - remained_onnx_input_idx=[], - ) - - def test_convtranspose(self): - model = nn.ConvTranspose2d( - 3, 3, 3, stride=3, bias=False, padding=1, output_padding=2 - ) - self.run_model_test(model, train=False, batch_size=BATCH_SIZE, atol=1e-7) - - def test_unsqueeze(self): - shape = (3, 4, 5) - # test negative dim as well. - for dim in range(-len(shape) - 1, len(shape) + 1): - - class MyModel(torch.nn.Module): - def forward(self, x): - return x.unsqueeze(dim) - - x = torch.randn(*shape) - self.run_model_test( - MyModel(), train=False, input=(x), batch_size=BATCH_SIZE, atol=1e-7 - ) - - def test_squeeze(self): - shape = (1, 1, 1) - # test negative dim as well - for dim in range(-len(shape), len(shape)): - - class MyModel(torch.nn.Module): - def forward(self, x): - return x.squeeze(dim) - - x = torch.randn(*shape) - self.run_model_test( - MyModel(), train=False, input=(x), batch_size=BATCH_SIZE, atol=1e-7 - ) - - # NB: InstanceNorm model includes unused weights, so skip this in TestCaffe2BackendEmbed - # TODO: We should have another pass to eliminate the unused initializers in ONNX models. - @skipIfEmbed - def test_instance_norm(self): - underlying = nn.InstanceNorm2d(3) - self.run_model_test(underlying, train=False, batch_size=BATCH_SIZE) - - @unittest.skip("Disabled due to onnx optimizer deprecation") - def test_pixel_shuffle(self): - underlying = nn.PixelShuffle(4) - shape = (1, 32, 5, 5) - input = Variable(torch.randn(*shape), requires_grad=True) - self.run_model_test( - underlying, train=False, input=(input), batch_size=BATCH_SIZE - ) - - def test_dynamic_sizes(self): - class MyModel(torch.nn.Module): - def forward(self, x): - shape = torch.onnx.operators.shape_as_tensor(x) - new_shape = torch.cat((torch.LongTensor([-1]), shape[0].view(1))) - return torch.onnx.operators.reshape_from_tensor_shape(x, new_shape) - - x = torch.randn(3, 5, 7) - self.run_model_test( - MyModel(), train=False, input=x, batch_size=BATCH_SIZE, use_gpu=False - ) - - def test_advanced_broadcast(self): - class MyModel(torch.nn.Module): - def forward(self, x, y): - return torch.mul(x, y) - - x = torch.randn(1, 5, 10) - y = torch.randn(1, 5, 1) - self.run_model_test( - MyModel(), train=False, input=(x, y), batch_size=BATCH_SIZE, use_gpu=False - ) - - def test_int8_export(self): - class MyModel(torch.nn.Module): - def __init__(self): - super().__init__() - self.param = torch.ByteTensor(3, 4).random_() - - def forward(self, x): - return x * self.param.float() - - import io - - f = io.BytesIO() - from torch.onnx import ExportTypes - - torch.onnx._export( - MyModel(), - (torch.rand(3, 4),), - f, - verbose=True, - export_type=ExportTypes.ZIP_ARCHIVE, - keep_initializers_as_inputs=True, - ) - - X = np.random.rand(3, 4).astype(np.float32) - - f.seek(0) - import caffe2.python.onnx.backend as c2 - - model = c2.prepare_zip_archive(f) - model.run(X) - - @skipIfUnsupportedOpsetVersion([10]) - def test_neg_slice(self): - class NegSlice(torch.nn.Module): - def forward(self, x): - return x[-1, :, :] - - x = torch.randn(3, 4, 5) - self.run_model_test( - NegSlice(), train=False, input=(x,), batch_size=BATCH_SIZE, use_gpu=False - ) - - @skipIfUnsupportedOpsetVersion([10]) - def test_neg_slice_large(self): - class NegSlice(torch.nn.Module): - def forward(self, x): - return x[:, :, :, :, -3] - - x = torch.randn(3, 4, 5, 6, 7) - self.run_model_test( - NegSlice(), train=False, input=(x,), batch_size=BATCH_SIZE, use_gpu=False - ) - - @unittest.skip("https://github.com/pytorch/pytorch/issues/10984") - @skipIfUnsupportedOpsetVersion([10]) - def test_neg_slice_large_negone(self): - class NegSlice(torch.nn.Module): - def forward(self, x): - return x[:, :, :, :, -1] - - x = torch.randn(3, 4, 5, 6, 7) - self.run_model_test( - NegSlice(), train=False, input=(x,), batch_size=BATCH_SIZE, use_gpu=False - ) - - @skipIfUnsupportedMinOpsetVersion(11) - def test_dynamic_slice(self): - class DynamicSliceExportMod(torch.nn.Module): - def forward(self, x): - results = [] - for i in range(4): - results.append(x[: x.size(0) - i, i : x.size(2), i:3]) - return tuple(results) - - x = torch.rand(5, 5, 5) - self.run_model_test( - DynamicSliceExportMod(), - train=False, - input=(x,), - batch_size=BATCH_SIZE, - use_gpu=False, - ) - - @skipIfUnsupportedMinOpsetVersion(11) - def test_dynamic_slice_script(self): - class DynamicSliceModel(torch.jit.ScriptModule): - @torch.jit.script_method - def forward(self, x): - return x[1 : x.size(0)] - - module = DynamicSliceModel() - x = torch.rand(1, 2) - self.run_model_test( - DynamicSliceModel(), - train=False, - input=(x,), - batch_size=BATCH_SIZE, - use_gpu=False, - ) - - @skipIfUnsupportedMinOpsetVersion(11) - def test_dynamic_slice_to_the_end(self): - class DynamicSliceExportMod(torch.nn.Module): - def forward(self, x): - results = [] - for i in range(4): - results.append(x[:, i:, x.size(2) - 5]) - return tuple(results) - - x = torch.rand(5, 5, 5) - self.run_model_test( - DynamicSliceExportMod(), - train=False, - input=(x,), - batch_size=BATCH_SIZE, - use_gpu=False, - ) - - def test_unbind(self): - class UnbindModel(torch.nn.Module): - def forward(self, input): - return input.unbind() - - x = torch.randn(3, 4, 5) - self.run_model_test( - UnbindModel(), train=False, input=(x,), batch_size=BATCH_SIZE, use_gpu=False - ) - - class UnbindModel2(torch.nn.Module): - def forward(self, input): - _, out, _, _ = input.unbind(1) - return out - - x = torch.randn(3, 4, 5) - self.run_model_test( - UnbindModel2(), - train=False, - input=(x,), - batch_size=BATCH_SIZE, - use_gpu=False, - ) - - @skipIfUnsupportedMinOpsetVersion(9) - def test_inplace_zero(self): - class Zero_(torch.nn.Module): - def forward(self, x): - return x.zero_() - - x = torch.randn(2, 3, 4) - self.run_model_test( - Zero_(), - train=False, - input=(x,), - batch_size=BATCH_SIZE, - use_gpu=False, - input_names=["x"], - dynamic_axes={"x": [0, 1, 2]}, - ) - self.run_model_test( - Zero_(), - train=False, - input=(x,), - batch_size=BATCH_SIZE, - use_gpu=False, - remained_onnx_input_idx=[], - ) - - @skipIfUnsupportedMinOpsetVersion(9) - def test_inplace_fill(self): - class Fill_(torch.nn.Module): - def forward(self, x): - return x.fill_(3) - - x = torch.randn(2, 3, 4) - self.run_model_test( - Fill_(), - train=False, - input=(x,), - batch_size=BATCH_SIZE, - use_gpu=False, - input_names=["x"], - dynamic_axes={"x": [0, 1, 2]}, - ) - self.run_model_test( - Fill_(), - train=False, - input=(x,), - batch_size=BATCH_SIZE, - use_gpu=False, - remained_onnx_input_idx=[], - ) - - # ConstantFill is a deprecated experimental op (used in opsets < 9). - # Shape inference does not cover this op. - @skipIfUnsupportedMinOpsetVersion(9) - def test_inplace_arithmetic(self): - class Arithmetic(torch.jit.ScriptModule): - @torch.jit.script_method - def forward(self): - x = torch.ones(2, 3, 4) - y = torch.ones(2, 3, 4) * 2 - x.add_(3) - y.mul_(x) - return x, y - - x = torch.ones(2, 3, 4) - y = torch.ones(2, 3, 4) * 2 - self.run_model_test( - Arithmetic(), train=False, input=(), batch_size=BATCH_SIZE, use_gpu=False - ) - - def test_tensor_factories(self): - class TensorFactory(torch.nn.Module): - def forward(self, x): - return torch.zeros(x.size()) + torch.ones(x.size()) - - x = torch.randn(2, 3, 4) - self.run_model_test( - TensorFactory(), - train=False, - input=(x,), - batch_size=BATCH_SIZE, - use_gpu=False, - input_names=["x"], - dynamic_axes={"x": [0, 1, 2]}, - ) - self.run_model_test( - TensorFactory(), - train=False, - input=(x,), - batch_size=BATCH_SIZE, - use_gpu=False, - remained_onnx_input_idx=[], - ) - - def test_tensor_factories_script(self): - class TensorFactory(torch.jit.ScriptModule): - @torch.jit.script_method - def forward(self, x): - return torch.zeros(x.shape, dtype=torch.float) + torch.ones( - x.shape, dtype=torch.float - ) - - x = torch.randn(2, 3, 4) - self.run_model_test( - TensorFactory(), - train=False, - input=(x,), - batch_size=BATCH_SIZE, - use_gpu=False, - input_names=["x"], - dynamic_axes={"x": [0, 1, 2]}, - ) - self.run_model_test( - TensorFactory(), - train=False, - input=(x,), - batch_size=BATCH_SIZE, - use_gpu=False, - remained_onnx_input_idx=[], - ) - - def test_tensor_like_factories_script(self): - class TensorFactory(torch.jit.ScriptModule): - @torch.jit.script_method - def forward(self, x): - zeros = torch.zeros_like( - x, - dtype=torch.float, - layout=torch.strided, - device=torch.device("cpu"), - ) - ones = torch.ones_like( - x, - dtype=torch.float, - layout=torch.strided, - device=torch.device("cpu"), - ) - return zeros + ones - - x = torch.randn(2, 3, 4) - self.run_model_test( - TensorFactory(), - train=False, - input=(x,), - batch_size=BATCH_SIZE, - use_gpu=False, - input_names=["x"], - dynamic_axes={"x": [0, 1, 2]}, - ) - remained_onnx_input_idx = None if self.opset_version < 9 else [] - self.run_model_test( - TensorFactory(), - train=False, - input=(x,), - batch_size=BATCH_SIZE, - use_gpu=False, - remained_onnx_input_idx=remained_onnx_input_idx, - ) - - def test_full(self): - class FullModel(torch.nn.Module): - def forward(self, x): - return torch.full((3, 4), x, dtype=torch.long) - - x = torch.tensor(12) - self.run_model_test( - FullModel(), train=False, input=(x,), batch_size=BATCH_SIZE, use_gpu=False - ) - - def test_full_script(self): - class FullClass(torch.jit.ScriptModule): - @torch.jit.script_method - def forward(self, x): - return torch.full((4, 5), x, dtype=torch.long) - - x = torch.tensor(12) - self.run_model_test( - FullClass(), train=False, input=(x,), batch_size=BATCH_SIZE, use_gpu=False - ) - - def test_clamp(self): - class ClampModel(torch.nn.Module): - def forward(self, x): - return x.clamp(-0.5, 0.5) - - x = torch.randn(3, 4) - self.run_model_test( - ClampModel(), train=False, input=(x,), batch_size=BATCH_SIZE - ) - - class ClampMinModel(torch.nn.Module): - def forward(self, x): - return x.clamp(min=-0.5) - - x = torch.randn(3, 4) - self.run_model_test( - ClampMinModel(), train=False, input=(x,), batch_size=BATCH_SIZE - ) - - class ClampMaxModel(torch.nn.Module): - def forward(self, x): - return x.clamp(max=0.5) - - x = torch.randn(3, 4) - self.run_model_test( - ClampMaxModel(), train=False, input=(x,), batch_size=BATCH_SIZE - ) - - @skipIfUnsupportedMinOpsetVersion(9) - def test_where_functional(self): - class WhereFunctional(torch.nn.Module): - def forward(self, x): - return torch.where(x > 2.0, x, torch.neg(x)) - - x = torch.randn(3, 4) - self.run_model_test( - WhereFunctional(), - train=False, - input=(x,), - batch_size=BATCH_SIZE, - use_gpu=False, - ) - - @skipIfUnsupportedMinOpsetVersion(9) - def test_where_method(self): - class WhereMethod(torch.nn.Module): - def forward(self, x): - return x.where(x > 2.0, torch.neg(x)) - - x = torch.randn(3, 4) - self.run_model_test( - WhereMethod(), train=False, input=(x,), batch_size=BATCH_SIZE, use_gpu=False - ) - - def test_data_dependent_zeros_factory(self): - class ZerosFactory(torch.nn.Module): - def forward(self, input): - return torch.cat( - [input, torch.zeros(input.size(0), 1).type_as(input)], dim=1 - ) - - x = torch.zeros(3, 4) - self.run_model_test( - ZerosFactory(), - train=False, - input=(x,), - batch_size=BATCH_SIZE, - use_gpu=False, - ) - - def test_implicit_expand(self): - class ImplicitExpandExportMod(torch.nn.Module): - def forward(self, x): - return x + 1 - - x = torch.randn(3, 4) - self.run_model_test( - ImplicitExpandExportMod(), - train=False, - input=(x,), - batch_size=BATCH_SIZE, - use_gpu=False, - ) - - def test_reduce_sum(self): - class ReduceSumNegativeIndices(torch.nn.Module): - def forward(self, x): - return x.sum(-1) - - x = torch.randn(2, 3, 4) - self.run_model_test( - ReduceSumNegativeIndices(), - train=False, - input=(x,), - batch_size=BATCH_SIZE, - use_gpu=False, - ) - - def test_reduce_sum_multi_dim(self): - class ReduceSumMultipleAxes(torch.nn.Module): - def forward(self, x): - return x.sum(dim=(2, 3), keepdim=True) - - x = torch.randn(16, 3, 256, 256) - self.run_model_test( - ReduceSumMultipleAxes(), - train=False, - input=(x,), - batch_size=BATCH_SIZE, - use_gpu=False, - ) - - # InstanceNorm model (used in the subgraph) includes unused weights, - # so skip this in TestCaffe2BackendEmbed - @skipIfEmbed - def test_group_norm(self): - c = torch.randn(BATCH_SIZE, 6, 224, 224) - model = nn.GroupNorm(3, 6, eps=0.0002) - self.run_model_test(model, train=True, input=c, batch_size=BATCH_SIZE) - - # InstanceNorm model (used in the subgraph) includes unused weights, - # so skip this in TestCaffe2BackendEmbed - @skipIfEmbed - def test_group_norm_noaffine(self): - c = torch.randn(BATCH_SIZE, 6, 224, 224) - model = nn.GroupNorm(3, 6, eps=0.0002, affine=False) - self.run_model_test(model, train=True, input=c, batch_size=BATCH_SIZE) - - def test_rsub(self): - class RsubModel(torch.nn.Module): - def forward(self, x): - return 1 - x - - x = torch.randn(1, 2) - self.run_model_test( - RsubModel(), train=False, input=(x,), batch_size=BATCH_SIZE, use_gpu=False - ) - - @skipIfUnsupportedMinOpsetVersion(9) - def test_isnan(self): - class IsNaNModel(torch.nn.Module): - def forward(self, input): - return torch.isnan(input) - - x = torch.tensor([1.0, float("nan"), 2.0]) - self.run_model_test( - IsNaNModel(), train=False, input=x, batch_size=BATCH_SIZE, use_gpu=False - ) - - @skipIfUnsupportedMinOpsetVersion(9) - def test_scatter(self): - class ScatterModel(torch.nn.Module): - def forward(self, input, indices, values): - return input.scatter(1, indices, values) - - input = torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]) - indices = torch.tensor([[1, 0], [0, 2], [0, 1]], dtype=torch.int64) - values = torch.tensor([[1.0, 1.1], [2.0, 2.1], [3.0, 3.1]]) - self.run_model_test( - ScatterModel(), - train=False, - input=(input, indices, values), - batch_size=BATCH_SIZE, - use_gpu=False, - ) - - input = torch.zeros(3, 4, 5, 6) - indices = torch.tensor([[1, 0], [0, 2], [0, 1]], dtype=torch.int64) - indices = indices.view(3, 2, 1, 1).expand(3, 2, 5, 6) - values = torch.arange(3 * 2 * 5 * 6, dtype=torch.float32).view(3, 2, 5, 6) - self.run_model_test( - ScatterModel(), - train=False, - input=(input, indices, values), - batch_size=BATCH_SIZE, - use_gpu=False, - ) - - input = torch.zeros(3, 4, 2) - indices = torch.tensor([[[1, 0], [0, 2]], [[1, 1], [0, 1]], [[2, 1], [2, 2]]]) - values = torch.arange(3 * 2 * 2, dtype=torch.float32).view(3, 2, 2) - self.run_model_test( - ScatterModel(), - train=False, - input=(input, indices, values), - batch_size=BATCH_SIZE, - use_gpu=False, - ) - - @skipIfUnsupportedOpsetVersion([10]) - def test_flatten(self): - class FlattenModel(torch.nn.Module): - def forward(self, input): - return torch.flatten(input) - - x = torch.randn(1, 2, 3, 4, requires_grad=True) - self.run_model_test(FlattenModel(), train=False, input=x, batch_size=BATCH_SIZE) - - def test_flatten2D(self): - class FlattenModel(torch.nn.Module): - def forward(self, input): - return torch.flatten(input, 1) - - x = torch.randn(1, 2, 3, 4, requires_grad=True) - self.run_model_test(FlattenModel(), train=False, input=x, batch_size=BATCH_SIZE) - - def test_max(self): - class MaxModel(torch.nn.Module): - def forward(self, input): - return torch.max(input, dim=1) - - x = torch.randn(4, 4, requires_grad=True) - self.run_model_test(MaxModel(), train=False, input=x, batch_size=BATCH_SIZE) - - def test_max_keepdim(self): - class MaxModel(torch.nn.Module): - def forward(self, input): - return torch.max(input, dim=1, keepdim=True) - - x = torch.randn(4, 4, requires_grad=True) - self.run_model_test(MaxModel(), train=False, input=x, batch_size=BATCH_SIZE) - - def test_max_tensors(self): - class MaxModel(torch.nn.Module): - def forward(self, input, other): - return torch.max(input, other) - - x = torch.randn(4, 4, requires_grad=True) - y = torch.randn(4, 4, requires_grad=True) - self.run_model_test( - MaxModel(), train=False, input=(x, y), batch_size=BATCH_SIZE - ) - - def test_min(self): - class MinModel(torch.nn.Module): - def forward(self, input): - return torch.min(input, dim=1) - - x = torch.randn(4, 4, requires_grad=True) - self.run_model_test(MinModel(), train=False, input=x, batch_size=BATCH_SIZE) - - def test_argmax(self): - class ArgmaxModel(torch.nn.Module): - def forward(self, input): - return torch.argmax(input, dim=1) - - x = torch.randn(4, 4, requires_grad=True) - self.run_model_test(ArgmaxModel(), train=False, input=x, batch_size=BATCH_SIZE) - - def test_argmax_none_dim(self): - class ArgmaxModel(torch.nn.Module): - def forward(self, input): - return torch.argmax(input) - - x = torch.randn(4, 4, requires_grad=True) - self.run_model_test(ArgmaxModel(), train=False, input=x, batch_size=BATCH_SIZE) - - def test_argmin(self): - class ArgminModel(torch.nn.Module): - def forward(self, input): - return torch.argmin(input, dim=1) - - x = torch.randn(4, 4, requires_grad=True) - self.run_model_test(ArgminModel(), train=False, input=x, batch_size=BATCH_SIZE) - - def test_argmin_none_dim(self): - class ArgminModel(torch.nn.Module): - def forward(self, input): - return torch.argmin(input) - - x = torch.randn(4, 4, requires_grad=True) - self.run_model_test(ArgminModel(), train=False, input=x, batch_size=BATCH_SIZE) - - def test_reshape(self): - class ReshapeModel(torch.nn.Module): - def forward(self, input): - return input.reshape(1, 1) - - x = torch.randn(1, requires_grad=True) - self.run_model_test(ReshapeModel(), train=False, input=x, batch_size=BATCH_SIZE) - - def test_reshape_as(self): - class ReshapeAsModel(torch.nn.Module): - def forward(self, input): - y = torch.randn(3, 1, 2, 1, requires_grad=False) - return input.reshape_as(y) - - x = torch.randn(2, 3, requires_grad=True) - self.run_model_test( - ReshapeAsModel(), train=False, input=x, batch_size=BATCH_SIZE - ) - - @skipIfUnsupportedOpsetVersion([10]) - def test_narrow(self): - class NarrowModel(torch.nn.Module): - def forward(self, input): - return torch.narrow(input, 0, 0, 2) - - x = torch.randn(3, 3, requires_grad=True) - self.run_model_test(NarrowModel(), train=False, input=x, batch_size=BATCH_SIZE) - - def test_randn_like(self): - class RandNLikeModel(torch.nn.Module): - def forward(self, input): - return torch.randn_like(input) - - x = torch.randn(2, 3, 4, requires_grad=False) - model = RandNLikeModel() - onnxir, _ = do_export(model, x, keep_initializers_as_inputs=True) - onnx_model = onnx.ModelProto.FromString(onnxir) - prepared = c2.prepare(onnx_model) - caffe2_out = prepared.run(inputs=[x.cpu().numpy()]) - self.assertEqual(caffe2_out[0].shape, x.shape) - - def test_traced_ints(self): - A = 4 - H = 10 - W = 8 - img_count = 3 - - # in this model, the constant propagation in JIT doesn't work - # so we have ListConstruct in the symbolic - class MyModel(torch.nn.Module): - def __init__(self): - super().__init__() - self.conv = torch.nn.Conv2d(A, 4 * A, 1, stride=1) - - def forward(self, feature, im_info, anchors): - bbox_deltas = self.conv(feature) - a, b = torch.ops._caffe2.GenerateProposals( - feature, - bbox_deltas, - im_info, - anchors, - 2.0, - 6000, - 300, - 0.7, - 16, - True, - -90, - 90, - 1.0, - True, - ) - output = torch.ops._caffe2.RoIAlign( - feature, - a, - order="NCHW", - spatial_scale=1.0, - pooled_h=3, - pooled_w=3, - sampling_ratio=0, - aligned=False, - ) - return output - - feature = torch.empty(img_count, A, H, W) - im_info = torch.ones(img_count, 3, dtype=torch.float32) - anchors = torch.ones(A, 4, dtype=torch.float32) - inputs = (feature, im_info, anchors) - - model = MyModel() - with torch.no_grad(): - self.run_model_test( - MyModel(), train=False, input=inputs, batch_size=BATCH_SIZE - ) - - def test_c2_roi_align(self): - class MyModel(torch.nn.Module): - def forward(self, feature, rois): - roi_feature = torch.ops._caffe2.RoIAlign( - feature, - rois, - order="NCHW", - spatial_scale=1.0, - pooled_h=3, - pooled_w=3, - sampling_ratio=3, - aligned=False, - ) - return roi_feature - - def rand_roi(N, C, H, W): - return [ - float(int(N * np.random.rand())), - 0.5 * np.random.rand() * W, - 0.5 * np.random.rand() * H, - (0.5 + 0.5 * np.random.rand()) * W, - (0.5 + 0.5 * np.random.rand()) * H, - ] - - N, C, H, W = 1, 4, 10, 8 - feature = torch.randn(N, C, H, W) - rois = torch.tensor([rand_roi(N, C, H, W) for _ in range(10)]) - inputs = (feature, rois) - self.run_model_test(MyModel(), train=False, input=inputs, batch_size=3) - - def test_c2_generate_proposals(self): - class MyModel(torch.nn.Module): - def forward(self, scores, bbox_deltas, im_info, anchors): - a, b = torch.ops._caffe2.GenerateProposals( - scores, - bbox_deltas, - im_info, - anchors, - 2.0, - 6000, - 300, - 0.7, - 16, - True, - -90, - 90, - 1.0, - True, - ) - return a, b - - A = 4 - H = 10 - W = 8 - img_count = 3 - scores = torch.ones(img_count, A, H, W, dtype=torch.float32) - bbox_deltas = torch.linspace( - 0, 10, steps=img_count * 4 * A * H * W, dtype=torch.float32 - ) - bbox_deltas = bbox_deltas.view(img_count, 4 * A, H, W) - im_info = torch.ones(img_count, 3, dtype=torch.float32) - anchors = torch.ones(A, 4, dtype=torch.float32) - inputs = (scores, bbox_deltas, im_info, anchors) - self.run_model_test(MyModel(), train=False, input=inputs, batch_size=3) - - def test_c2_bbox_transform(self): - class MyModel(torch.nn.Module): - def forward(self, rois, deltas, im_info): - a, b = torch.ops._caffe2.BBoxTransform( - rois, - deltas, - im_info, - weights=[1.0, 1.0, 1.0, 1.0], - apply_scale=False, - rotated=True, - angle_bound_on=True, - angle_bound_lo=-90, - angle_bound_hi=90, - clip_angle_thresh=0.5, - legacy_plus_one=True, - ) - return a, b - - roi_counts = [0, 2, 3, 4, 5] - batch_size = len(roi_counts) - total_rois = sum(roi_counts) - im_dims = np.random.randint(100, 600, batch_size) - rois = generate_rois_rotated(roi_counts, im_dims) - box_dim = 5 - num_classes = 7 - deltas = np.random.randn(total_rois, box_dim * num_classes).astype(np.float32) - im_info = np.zeros((batch_size, 3)).astype(np.float32) - im_info[:, 0] = im_dims - im_info[:, 1] = im_dims - im_info[:, 2] = 1.0 - im_info = torch.zeros((batch_size, 3)) - inputs = (torch.tensor(rois), torch.tensor(deltas), torch.tensor(im_info)) - self.run_model_test( - MyModel(), train=False, input=inputs, batch_size=3, use_gpu=False - ) - - # BoxWithNMSLimits has requirements for the inputs, so randomly generated inputs - # in Caffe2BackendTestEmbed doesn't work with this op. - @skipIfEmbed - def test_c2_box_with_nms_limits(self): - roi_counts = [0, 2, 3, 4, 5] - num_classes = 7 - rotated = False - angle_bound_on = True - clip_angle_thresh = 0.5 - rois, deltas, im_info = create_bbox_transform_inputs( - roi_counts, num_classes, rotated - ) - pred_bbox, batch_splits = ( - t.detach().numpy() - for t in torch.ops._caffe2.BBoxTransform( - torch.tensor(rois), - torch.tensor(deltas), - torch.tensor(im_info), - [1.0, 1.0, 1.0, 1.0], - False, - rotated, - angle_bound_on, - -90, - 90, - clip_angle_thresh, - legacy_plus_one=True, - ) - ) - class_prob = np.random.randn(sum(roi_counts), num_classes).astype(np.float32) - score_thresh = 0.5 - nms_thresh = 0.5 - topk_per_image = int(sum(roi_counts) / 2) - - class MyModel(torch.nn.Module): - def forward(self, class_prob, pred_bbox, batch_splits): - a, b, c, d, e, f = torch.ops._caffe2.BoxWithNMSLimit( - class_prob, - pred_bbox, - batch_splits, - score_thresh=score_thresh, - nms=nms_thresh, - detections_per_im=topk_per_image, - soft_nms_enabled=False, - soft_nms_method="linear", - soft_nms_sigma=0.5, - soft_nms_min_score_thres=0.001, - rotated=rotated, - cls_agnostic_bbox_reg=False, - input_boxes_include_bg_cls=True, - output_classes_include_bg_cls=True, - legacy_plus_one=True, - ) - return a, b, c, d, e, f - - inputs = ( - torch.tensor(class_prob), - torch.tensor(pred_bbox), - torch.tensor(batch_splits), - ) - self.run_model_test( - MyModel(), train=False, input=inputs, batch_size=3, use_gpu=False - ) - - def test_c2_inference_lstm(self): - num_layers = 4 - seq_lens = 6 - emb_lens = 10 - has_bias = True - batch_first = True - is_bidirectional = True - - class MyModel(torch.nn.Module): - def forward(self, lstm_in): - a, b, c = torch.ops._caffe2.InferenceLSTM( - lstm_in, num_layers, has_bias, batch_first, is_bidirectional - ) - return a, b, c - - num_directions = 2 - bsz = 5 - hidden_size = 7 - hx = np.zeros((num_layers * num_directions, bsz, hidden_size), dtype=np.float32) - inputs = np.random.randn(bsz, seq_lens, emb_lens).astype(np.float32) - torch_lstm = torch.nn.LSTM( - emb_lens, - hidden_size, - batch_first=batch_first, - bidirectional=is_bidirectional, - bias=has_bias, - num_layers=num_layers, - ) - lstm_in = ( - [ - torch.from_numpy(inputs), - torch.from_numpy(hx), - torch.from_numpy(hx), - ] - + [param.detach() for param in torch_lstm._flat_weights], - ) - - self.run_model_test( - MyModel(), train=False, input=lstm_in, batch_size=3, use_gpu=False - ) - - def test_tuple_input_output(self): - class TupleModel(torch.jit.ScriptModule): - @torch.jit.script_method - def forward( - self, a: Tuple[torch.Tensor, torch.Tensor] - ) -> Tuple[torch.Tensor, torch.Tensor]: - return a - - x = (torch.randn(3, 4), torch.randn(4, 3)) - self.run_model_test( - TupleModel(), train=False, input=(x,), batch_size=BATCH_SIZE - ) - - def test_nested_tuple_input_output(self): - class NestedTupleModel(torch.jit.ScriptModule): - @torch.jit.script_method - def forward( - self, - a: torch.Tensor, - b: Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], - ) -> torch.Tensor: - return a + b[0] + b[1][0] + b[1][1] - - x = torch.randn(4, 5) - y = (torch.randn(4, 5), (torch.randn(4, 5), torch.randn(4, 5))) - self.run_model_test( - NestedTupleModel(), train=False, input=(x, y), batch_size=BATCH_SIZE - ) - - def test_topk(self): - class TopKModel(torch.nn.Module): - def forward(self, input): - return torch.topk(input, 3) - - x = torch.arange(1.0, 6.0) - self.run_model_test(TopKModel(), train=False, input=x, batch_size=BATCH_SIZE) - - def test_topk_script(self): - class TopKModel(torch.jit.ScriptModule): - @torch.jit.script_method - def forward(self, input): - return torch.topk(input, 3, dim=0) - - x = torch.randn(4, 3, requires_grad=True) - self.run_model_test(TopKModel(), train=False, input=(x,), batch_size=BATCH_SIZE) - - def test_floor(self): - class FloorModel(torch.nn.Module): - def forward(self, input): - return torch.floor(input) - - x = torch.randn(1, 2, 3, 4, requires_grad=True) - self.run_model_test(FloorModel(), train=False, input=x, batch_size=BATCH_SIZE) - - def test_ceil(self): - class CeilModel(torch.nn.Module): - def forward(self, input): - return torch.ceil(input) - - x = torch.randn(1, 2, 3, 4, requires_grad=True) - self.run_model_test(CeilModel(), train=False, input=x, batch_size=BATCH_SIZE) - - @skipIfUnsupportedMinOpsetVersion(9) - def test__dim_arange(self): - class DimArange(torch.nn.Module): - def forward(self, input): - return torch._dim_arange(input, 1) - - x = torch.ones(5, 6) - self.run_model_test( - DimArange(), - train=False, - input=x, - batch_size=BATCH_SIZE, - operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK, - ) - - @skipIfUnsupportedMinOpsetVersion(9) - def test_arange_end(self): - class ArangeScript(torch.jit.ScriptModule): - @torch.jit.script_method - def forward(self, a): - return torch.arange(a.size(0), dtype=torch.float).view(-1, 1) + a - - x = torch.randn(3, 4, requires_grad=True) - self.run_model_test( - ArangeScript(), train=False, input=(x,), batch_size=BATCH_SIZE - ) - - class ArangeModel(torch.nn.Module): - def forward(self, a): - return torch.arange(a.size(0), dtype=torch.float).view(-1, 1) + a - - self.run_model_test( - ArangeModel(), train=False, input=(x,), batch_size=BATCH_SIZE - ) - - @skipIfUnsupportedMinOpsetVersion(9) - def test_arange_start_end(self): - class ArangeScript(torch.jit.ScriptModule): - @torch.jit.script_method - def forward(self, a): - return torch.arange(2, a.size(0) + 2, dtype=torch.float).view(-1, 1) + a - - x = torch.randn(3, 4, requires_grad=True) - self.run_model_test( - ArangeScript(), train=False, input=(x,), batch_size=BATCH_SIZE - ) - - class ArangeModel(torch.nn.Module): - def forward(self, a): - return torch.arange(2, a.size(0) + 2, dtype=torch.float).view(-1, 1) + a - - self.run_model_test( - ArangeModel(), train=False, input=(x,), batch_size=BATCH_SIZE - ) - - @skipIfUnsupportedMinOpsetVersion(9) - def test_arange_start_end_step(self): - class ArangeScript(torch.jit.ScriptModule): - @torch.jit.script_method - def forward(self, a): - return ( - torch.arange( - 2, a.size(0) * a.size(1) + 2, a.size(1), dtype=torch.float - ).view(-1, 1) - + a - ) - - x = torch.randn(3, 4, requires_grad=True) - self.run_model_test( - ArangeScript(), train=False, input=(x,), batch_size=BATCH_SIZE - ) - - class ArangeModel(torch.nn.Module): - def forward(self, a): - return ( - torch.arange( - 2, a.size(0) * a.size(1) + 2, a.size(1), dtype=torch.float - ).view(-1, 1) - + a - ) - - self.run_model_test( - ArangeModel(), train=False, input=(x,), batch_size=BATCH_SIZE - ) - - @skipIfUnsupportedMinOpsetVersion(9) - def test_size(self): - class SizeModel(torch.nn.Module): - def forward(self, input): - return torch.arange(input.size(0)), torch.arange(input.size(-1)) - - x = torch.randn(5, 3, 2) - self.run_model_test( - SizeModel(), - train=False, - input=(x,), - batch_size=BATCH_SIZE, - input_names=["x"], - dynamic_axes={"x": [0, 1, 2]}, - ) - self.run_model_test( - SizeModel(), - train=False, - input=(x,), - batch_size=BATCH_SIZE, - remained_onnx_input_idx=[], - ) - - def test_log2(self): - class Log2Model(torch.nn.Module): - def forward(self, input): - return torch.log2(input) - - x = torch.empty(BATCH_SIZE, 10, 10).uniform_(4, 9) - self.run_model_test(Log2Model(), train=False, input=x, batch_size=BATCH_SIZE) - - def test__sample_dirichlet(self): - class DirichletModel(torch.nn.Module): - def forward(self, input): - return torch._sample_dirichlet(input) - - x = torch.randn(2, 3, 4, requires_grad=False) - model = DirichletModel() - onnxir, _ = do_export( - model, - x, - keep_initializers_as_inputs=True, - operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK, - ) - onnx_model = onnx.ModelProto.FromString(onnxir) - prepared = c2.prepare(onnx_model) - caffe2_out = prepared.run(inputs=[x.cpu().numpy()]) - self.assertEqual(caffe2_out[0].shape, x.shape) - - def test__standard_gamma(self): - class GammaModel(torch.nn.Module): - def forward(self, input): - return torch._standard_gamma(input) - - x = torch.randn(2, 3, 4, requires_grad=False) - model = GammaModel() - onnxir, _ = do_export( - model, - x, - keep_initializers_as_inputs=True, - operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK, - ) - onnx_model = onnx.ModelProto.FromString(onnxir) - prepared = c2.prepare(onnx_model) - caffe2_out = prepared.run(inputs=[x.cpu().numpy()]) - self.assertEqual(caffe2_out[0].shape, x.shape) - - # The order of returned indices from Multinomial is undefined, so randomly generated inputs - # in Caffe2BackendTestEmbed doesn't work with this op. - @skipIfEmbed - def test_multinomial(self): - class Multinomial(torch.nn.Module): - def forward(self, weight): - return torch.multinomial(weight, 3, replacement=True) - - class MultinomialNoReplacement(torch.nn.Module): - def forward(self, weight): - return torch.multinomial(weight, 1) - - weight = torch.tensor([[0, 10, 0, 0], [0, 0, 100, 0]], dtype=torch.float) - self.run_model_test( - Multinomial(), train=False, input=weight, batch_size=BATCH_SIZE - ) - self.run_model_test( - MultinomialNoReplacement(), train=False, input=weight, batch_size=BATCH_SIZE - ) - - def test_prim_shape(self): - x = torch.randn(4, 5, requires_grad=True) - - @torch.jit.script - def view_by_prim_shape(x): - return x.view(x.shape) - - class PrimShapeModel(torch.nn.Module): - def forward(self, input): - return view_by_prim_shape(input) - - self.run_model_test( - PrimShapeModel(), train=False, input=x, batch_size=BATCH_SIZE - ) - - def test_and(self): - class AndModel(torch.nn.Module): - def forward(self, x, y): - return x & y - - x = torch.randint(0, 1, (3, 5), dtype=torch.bool) - y = torch.randint(0, 1, (3, 5), dtype=torch.bool) - self.run_model_test( - AndModel(), train=False, input=(x, y), batch_size=BATCH_SIZE - ) - - def test_or(self): - class OrModel(torch.nn.Module): - def forward(self, x, y): - return x | y - - x = torch.randint(0, 1, (3, 5), dtype=torch.bool) - y = torch.randint(0, 1, (3, 5), dtype=torch.bool) - self.run_model_test(OrModel(), train=False, input=(x, y), batch_size=BATCH_SIZE) - - def test_dropout(self): - class DropoutModel(torch.nn.Module): - def __init__(self): - super().__init__() - self.dropout = torch.nn.Dropout(0.5) - - def forward(self, x): - return self.dropout(x) - - x = torch.randn(1, 2, 3) - self.run_model_test(DropoutModel(), train=False, input=x, batch_size=BATCH_SIZE) - - @skipIfUnsupportedMinOpsetVersion(9) - def test_while(self): - class WhileModel(torch.jit.ScriptModule): - @torch.jit.script_method - def forward(self, x): - a = 0 - while a < 4: - a += 1 - return x + a - - model = WhileModel() - inputs = torch.zeros(1, 2, 3, dtype=torch.long) - self.run_model_test( - model, - train=False, - input=(inputs,), - batch_size=BATCH_SIZE, - ) - - def test_while_cond(self): - class WhileModel(torch.jit.ScriptModule): - @torch.jit.script_method - def forward(self, x, a): - b = a < 4 - while b: - a += b.to(torch.long) - b = a < 4 - return x + a - - model = WhileModel() - x = torch.zeros(1, 2, 3, dtype=torch.long) - a = torch.tensor([0], dtype=torch.long) - self.run_model_test(model, train=False, input=(x, a), batch_size=BATCH_SIZE) - - @unittest.skip("Disabled due to onnx optimizer deprecation") - def test_loop(self): - class LoopModel(torch.jit.ScriptModule): - @torch.jit.script_method - def forward(self, x): - for i in range(5): - x = x + i - return x - - model = LoopModel() - inputs = torch.zeros(1, 2, 3, dtype=torch.long) - self.run_model_test(model, train=False, input=(inputs,), batch_size=BATCH_SIZE) - - @unittest.skip("Disabled due to onnx optimizer deprecation") - def test_dynamic_loop(self): - class LoopModel(torch.jit.ScriptModule): - @torch.jit.script_method - def forward(self, x): - for i in range(x.size(2)): - x = x + i - return x - - model = LoopModel() - inputs = torch.zeros(1, 2, 3, dtype=torch.long) - self.run_model_test(model, train=False, input=(inputs,), batch_size=BATCH_SIZE) - - @unittest.skip("Disabled due to onnx optimizer deprecation") - @skipIfUnsupportedMinOpsetVersion(9) - def test_nested_loops(self): - class NestedLoopsModel(torch.jit.ScriptModule): - @torch.jit.script_method - def forward(self, x): - for i in range(5): - a = 0 - while a < 4: - a += 1 - for j in range(a): - x = x + j - x = x + a - return x - - model = NestedLoopsModel() - inputs = torch.zeros(1, 2, 3, dtype=torch.long) - self.run_model_test( - model, - train=False, - input=(inputs,), - batch_size=BATCH_SIZE, - ) - - def test_select(self): - class SelectModel(torch.nn.Module): - def forward(self, x): - return torch.select(x, 0, 1) - - model = SelectModel() - inputs = torch.randn(3, 2, 1) - self.run_model_test(model, train=False, input=(inputs,), batch_size=BATCH_SIZE) - - def test_std(self): - class StandardDeviation(torch.nn.Module): - def forward(self, input): - return torch.std(input, unbiased=False) - - model = StandardDeviation() - inputs = torch.randn(2, 3, 4) - self.run_model_test(model, train=False, input=(inputs,), batch_size=BATCH_SIZE) - - def test_std_along_dims(self): - class StandardDeviationAlongDims(torch.nn.Module): - def forward(self, input): - return torch.std(input, dim=(0, 1), unbiased=False, keepdim=False) - - model = StandardDeviationAlongDims() - inputs = torch.randn(2, 3, 4) - self.run_model_test(model, train=False, input=(inputs,), batch_size=BATCH_SIZE) - - @skipIfUnsupportedMinOpsetVersion(9) - def test_masked_fill(self): - class MaskedFillModel(torch.nn.Module): - def forward(self, x): - mask = torch.tensor([[0, 0, 1], [1, 1, 0]], dtype=torch.uint8) - return x.masked_fill(mask, 2) - - x = torch.zeros(4, 2, 3, requires_grad=True) - self.run_model_test( - MaskedFillModel(), input=(x,), train=False, batch_size=BATCH_SIZE - ) - - class MaskedFillModel2(torch.nn.Module): - def forward(self, x): - return x.masked_fill(x > 3, -1) - - x = torch.arange(16).view(2, 2, 4).to(torch.float32) - self.run_model_test( - MaskedFillModel2(), input=(x,), train=False, batch_size=BATCH_SIZE - ) - - @skipIfUnsupportedMinOpsetVersion(8) - def test_meshgrid(self): - class MeshgridModel(torch.nn.Module): - def forward(self, x, y, z): - return torch.meshgrid(x, y, z) - - x = torch.ones(3, requires_grad=True) - y = torch.zeros(4, requires_grad=True) - z = torch.ones(5, requires_grad=True) - model = MeshgridModel() - self.run_model_test(model, train=False, input=(x, y, z), batch_size=BATCH_SIZE) - - def test_remainder(self): - class RemainderModel(torch.nn.Module): - def forward(self, input, other): - return torch.remainder(input, other) - - x = torch.randn(4, 2, 3) - y = torch.randn(1, 2, 1) - model = RemainderModel() - self.run_model_test(model, train=False, input=(x, y), batch_size=BATCH_SIZE) - - def test_remainder_scalar(self): - class RemainderModel(torch.nn.Module): - def forward(self, input): - return torch.remainder(input, 2.55) - - inputs = torch.randint(10, (2, 3)) - model = RemainderModel() - self.run_model_test( - model, - train=False, - input=(inputs,), - batch_size=BATCH_SIZE, - ) - - def test_baddbmm(self): - class MyModule(torch.nn.Module): - def forward(self, input, batch1, batch2): - return torch.baddbmm( - input, batch1, batch2, alpha=torch.tensor(5), beta=3.5 - ) - - x = torch.randn(10, 3, 5) - batch1 = torch.randn(10, 3, 4) - batch2 = torch.randn(10, 4, 5) - self.run_model_test( - MyModule(), input=(x, batch1, batch2), train=False, batch_size=BATCH_SIZE - ) - - @skipIfUnsupportedMinOpsetVersion(9) - def test_gelu(self): - class GeluModel(torch.nn.Module): - def forward(self, x): - return torch.nn.functional.gelu(x, approximate="none") - - model = GeluModel() - inputs = torch.randn(2, 4, 5, 6, requires_grad=True) - self.run_model_test(model, train=False, input=(inputs,), batch_size=BATCH_SIZE) - - @skipIfUnsupportedMinOpsetVersion(9) - def test_tanh_gelu(self): - class GeluModel(torch.nn.Module): - def forward(self, x): - return torch.nn.functional.gelu(x, approximate="tanh") - - model = GeluModel() - inputs = torch.randn(2, 4, 5, 6, requires_grad=True) - self.run_model_test(model, train=False, input=(inputs,), batch_size=BATCH_SIZE) - - @skipIfUnsupportedMinOpsetVersion(9) - def test_index_fill(self): - class IndexFillModel(torch.nn.Module): - def forward(self, input): - index = torch.tensor([2, 0]) - return input.index_fill(2, index, -1) - - x = torch.randn(3, 4, 5, requires_grad=True) - self.run_model_test( - IndexFillModel(), input=(x,), train=False, batch_size=BATCH_SIZE - ) - - @skipIfUnsupportedMinOpsetVersion(9) - def test_index_copy(self): - class IndexCopyModel(torch.nn.Module): - def forward(self, input): - index = torch.tensor([2, 0]) - source = torch.ones(3, 2, 5) - return input.index_copy(1, index, source) - - x = torch.randn(3, 4, 5, requires_grad=True) - self.run_model_test( - IndexCopyModel(), input=(x,), train=False, batch_size=BATCH_SIZE - ) - - -# a bit of metaprogramming to set up all the rnn tests - - -def make_test( - name, - base, - layer, - bidirectional, - initial_state, - variable_length, - dropout, - **extra_kwargs, -): - test_name = str( - "_".join( - [ - "test", - name, - layer[1], - bidirectional[1], - initial_state[1], - variable_length[1], - dropout[1], - ] - ) - ) - - @unittest.skip("Disabled due to onnx optimizer deprecation") - @skipIfUnsupportedOpsetVersion([10]) - @skipIfUnsupportedMinOpsetVersion(8) - def f(self): - self._dispatch_rnn_test( - base, - layers=layer[0], - bidirectional=bidirectional[0], - initial_state=initial_state[0], - packed_sequence=variable_length[0], - dropout=dropout[0], - **extra_kwargs, - ) - - f.__name__ = test_name - setattr(TestCaffe2Backend_opset9, f.__name__, f) - - -def setup_rnn_tests(): - layers_opts = [(1, "unilayer"), (3, "trilayer")] - bidirectional_opts = [(False, "forward"), (True, "bidirectional")] - initial_state_opts = [(True, "with_initial_state"), (False, "no_initial_state")] - variable_length_opts = [ - (0, "without_sequence_lengths"), - (1, "with_variable_length_sequences"), - (2, "with_batch_first_sequence_lengths"), - ] - dropout_opts = [(0.2, "with_dropout"), (0.0, "without_dropout")] - test_count = 0 - for ( - layer, - bidirectional, - initial_state, - variable_length, - dropout, - ) in itertools.product( - layers_opts, - bidirectional_opts, - initial_state_opts, - variable_length_opts, - dropout_opts, - ): - for base, name, extra_kwargs in ( - ("elman", "elman_relu", {"nonlinearity": "relu"}), - ("elman", "elman_tanh", {"nonlinearity": "tanh"}), - ("lstm", "lstm", {}), - ("gru", "gru", {}), - ): - make_test( - name, - base, - layer, - bidirectional, - initial_state, - variable_length, - dropout, - **extra_kwargs, - ) - test_count += 1 - - # sanity check that a representative example does exist - TestCaffe2Backend_opset9.test_gru_trilayer_forward_with_initial_state_without_sequence_lengths_with_dropout - - # make sure no one accidentally disables all the tests without - # noticing - assert test_count == 192, test_count - - -setup_rnn_tests() - -# add the same test suite as above, but switch embed_params=False -# to embed_params=True -TestCaffe2BackendEmbed_opset9 = type( - "TestCaffe2BackendEmbed_opset9", - (pytorch_test_common.ExportTestCase,), - dict(TestCaffe2Backend_opset9.__dict__, embed_params=True), -) - -# opset 7 tests -TestCaffe2Backend_opset7 = type( - "TestCaffe2Backend_opset7", - (pytorch_test_common.ExportTestCase,), - dict(TestCaffe2Backend_opset9.__dict__, opset_version=7), -) -TestCaffe2BackendEmbed_opset7 = type( - "TestCaffe2BackendEmbed_opset7", - (pytorch_test_common.ExportTestCase,), - dict(TestCaffe2Backend_opset9.__dict__, embed_params=True, opset_version=7), -) - -# opset 8 tests -TestCaffe2Backend_opset8 = type( - "TestCaffe2Backend_opset8", - (pytorch_test_common.ExportTestCase,), - dict(TestCaffe2Backend_opset9.__dict__, opset_version=8), -) -TestCaffe2BackendEmbed_opset8 = type( - "TestCaffe2BackendEmbed_opset8", - (pytorch_test_common.ExportTestCase,), - dict(TestCaffe2Backend_opset9.__dict__, embed_params=True, opset_version=8), -) - -# opset 10 tests -TestCaffe2Backend_opset10 = type( - "TestCaffe2Backend_opset10", - (pytorch_test_common.ExportTestCase,), - dict(TestCaffe2Backend_opset9.__dict__, opset_version=10), -) - -TestCaffe2BackendEmbed_opset10 = type( - "TestCaffe2BackendEmbed_opset10", - (pytorch_test_common.ExportTestCase,), - dict(TestCaffe2Backend_opset9.__dict__, embed_params=True, opset_version=10), -) - -# add the same test suite as above, but switch embed_params=False -# to embed_params=True -TestCaffe2BackendEmbed_opset9_new_jit_API = type( - "TestCaffe2BackendEmbed_opset9_new_jit_API", - (pytorch_test_common.ExportTestCase,), - dict(TestCaffe2Backend_opset9.__dict__, embed_params=True), -) - -if __name__ == "__main__": - common_utils.run_tests() diff --git a/test/onnx_caffe2/test_pytorch_onnx_caffe2_quantized.py b/test/onnx_caffe2/test_pytorch_onnx_caffe2_quantized.py deleted file mode 100644 index 372961c000eddb..00000000000000 --- a/test/onnx_caffe2/test_pytorch_onnx_caffe2_quantized.py +++ /dev/null @@ -1,382 +0,0 @@ -# Owner(s): ["module: unknown"] - -import io - -import numpy as np -import onnx -import pytorch_test_common - -import caffe2.python.onnx.backend as c2 -import torch.ao.nn.quantized as nnq -import torch.nn as nn -import torch.onnx -from torch.testing._internal import common_utils - - -class TestQuantizedOps(pytorch_test_common.ExportTestCase): - def generic_test( - self, model, sample_inputs, input_names=None, decimal=3, relaxed_check=False - ): - torch.backends.quantized.engine = "qnnpack" - pt_inputs = tuple(torch.from_numpy(x) for x in sample_inputs) - model.qconfig = torch.ao.quantization.get_default_qconfig("qnnpack") - q_model = torch.ao.quantization.prepare(model, inplace=False) - q_model = torch.ao.quantization.convert(q_model, inplace=False) - - traced_model = torch.jit.trace(q_model, pt_inputs) - buf = io.BytesIO() - torch.jit.save(traced_model, buf) - buf.seek(0) - q_model = torch.jit.load(buf) - - q_model.eval() - output = q_model(*pt_inputs) - - f = io.BytesIO() - torch.onnx.export( - q_model, - pt_inputs, - f, - input_names=input_names, - operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK, - # Caffe2 doesn't support newer opset versions - opset_version=9, - ) - f.seek(0) - onnx_model = onnx.load(f) - caffe_res = c2.run_model(onnx_model, dict(zip(input_names, sample_inputs)))[0] - # Due to change in requantization logic for certain ops such conv, linear - # in pytorch's integration of qnnpack, numerics may have a mismatc with C2. - # This mismatch should not be off my more than 1. - # This flag helps us override default behavior under certain circumstances. - if relaxed_check: - output_diff = np.absolute(np.squeeze(output.detach().numpy()) - caffe_res) - max_diff = np.amax(output_diff) - - # This check had to be changed to account for changes in - # qnnpack's requant logic. - np.testing.assert_( - max_diff <= 1, "Maximum absolute difference must be less than 1" - ) - else: - np.testing.assert_almost_equal( - output.detach().numpy(), caffe_res, decimal=decimal - ) - - def generic_unary_test(self, op): - class QModule(torch.nn.Module): - def __init__(self, op): - super().__init__() - self.quant1 = torch.ao.quantization.QuantStub() - self.op = op - self.dequant = torch.ao.quantization.DeQuantStub() - - def forward(self, x): - res = self.op(self.quant1(x)) - return self.dequant(res) - - x = np.random.random((1, 2)).astype("float32") - self.generic_test(QModule(op), (x,), input_names=["x"]) - - def test_quantized_add(self): - class QAddModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.quant1 = torch.ao.quantization.QuantStub() - self.quant2 = torch.ao.quantization.QuantStub() - self.dequant = torch.ao.quantization.DeQuantStub() - - def forward(self, x, y): - res = torch.ops.quantized.add(self.quant1(x), self.quant2(y), 1.0, 0) - return self.dequant(res) - - x = np.random.random(2).astype("float32") - y = np.random.random(2).astype("float32") - self.generic_test(QAddModule(), (x, y), input_names=["x", "y"]) - - def test_quantized_relu(self): - self.generic_unary_test(torch.nn.ReLU()) - - def export_to_onnx(self, model, input, input_names): - traced = torch.jit.trace(model, input) - buf = io.BytesIO() - torch.jit.save(traced, buf) - buf.seek(0) - - model = torch.jit.load(buf) - f = io.BytesIO() - torch.onnx.export( - model, - input, - f, - input_names=input_names, - operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK, - # Caffe2 doesn't support newer opset versions - opset_version=9, - ) - f.seek(0) - - onnx_model = onnx.load(f) - return onnx_model - - def test_qlinear_model(self): - class LinearModel(torch.nn.Module): - def __init__(self): - super().__init__() - self.qconfig = torch.ao.quantization.default_qconfig - self.fc1 = torch.ao.quantization.QuantWrapper( - torch.nn.Linear(5, 10).to(dtype=torch.float) - ) - - def forward(self, x): - x = self.fc1(x) - return x - - torch.backends.quantized.engine = "qnnpack" - qconfig = torch.ao.quantization.default_qconfig - model = LinearModel() - model.qconfig = qconfig - model = torch.ao.quantization.prepare(model) - model = torch.ao.quantization.convert(model) - - x_numpy = np.random.rand(1, 2, 5).astype(np.float32) - x = torch.from_numpy(x_numpy).to(dtype=torch.float) - outputs = model(x) - input_names = ["x"] - onnx_model = self.export_to_onnx(model, x, input_names) - - caffe_res = c2.run_model(onnx_model, dict(zip(input_names, x_numpy)))[0] - output_diff = np.absolute(np.squeeze(outputs.numpy()) - caffe_res) - max_diff = np.amax(output_diff) - - # Permute pytorch output to NHWC - # This check had to be changed to account for changes in - # qnnpack's requant logic. - np.testing.assert_( - max_diff <= 1, "Maximum absolute difference must be less than 1" - ) - - def test_qconv_model(self): - class ConvModel(torch.nn.Module): - def __init__(self): - super().__init__() - self.qconfig = torch.ao.quantization.default_qconfig - self.fc1 = torch.ao.quantization.QuantWrapper( - torch.nn.Conv2d(3, 5, 2, bias=True).to(dtype=torch.float) - ) - - def forward(self, x): - x = self.fc1(x) - return x - - torch.backends.quantized.engine = "qnnpack" - qconfig = torch.ao.quantization.default_qconfig - model = ConvModel() - model.qconfig = qconfig - model = torch.ao.quantization.prepare(model) - model = torch.ao.quantization.convert(model) - - x_numpy = np.random.rand(1, 3, 6, 6).astype(np.float32) - x = torch.from_numpy(x_numpy).to(dtype=torch.float) - outputs = model(x) - input_names = ["x"] - onnx_model = self.export_to_onnx(model, x, input_names) - - y = np.expand_dims(x_numpy, axis=0) - caffe_res = c2.run_model(onnx_model, dict(zip(input_names, y)))[0] - output_diff = np.absolute(np.squeeze(outputs.numpy()) - caffe_res) - max_diff = np.amax(output_diff) - - # Permute pytorch output to NHWC - # This check had to be changed to account for changes in - # qnnpack's requant logic. - np.testing.assert_( - max_diff <= 1, "Maximum absolute difference must be less than 1" - ) - - def test_upsample(self): - class QUpsampleModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.quant1 = torch.ao.quantization.QuantStub() - self.dequant = torch.ao.quantization.DeQuantStub() - - def forward(self, x): - res = torch.ao.nn.quantized.functional.interpolate( - self.quant1(x), size=[6, 8], mode="nearest" - ) - return self.dequant(res) - - x = np.random.rand(1, 2, 3, 4).astype("float32") - self.generic_test(QUpsampleModule(), (x,), input_names=["x"], decimal=5) - - def test_avg_pool2d(self): - class QAvgPool2dModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.quant1 = torch.ao.quantization.QuantStub() - self.dequant = torch.ao.quantization.DeQuantStub() - - def forward(self, x): - res = torch.nn.functional.avg_pool2d( - self.quant1(x), kernel_size=2, stride=1, padding=0 - ) - return self.dequant(res) - - x = np.random.rand(1, 2, 8, 8).astype("float32") - self.generic_test( - QAvgPool2dModule(), (x,), input_names=["x"], relaxed_check=True - ) - - def test_reshape(self): - class QReshapeModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.quant1 = torch.ao.quantization.QuantStub() - self.dequant = torch.ao.quantization.DeQuantStub() - - def forward(self, x): - res = self.quant1(x).reshape((1, 2, 1, 12)) - return self.dequant(res) - - x = np.random.rand(1, 2, 3, 4).astype("float32") - self.generic_test(QReshapeModule(), (x,), input_names=["x"], decimal=5) - - def test_slice(self): - class QSliceModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.quant1 = torch.ao.quantization.QuantStub() - self.dequant = torch.ao.quantization.DeQuantStub() - - def forward(self, x): - qx = self.quant1(x) - res = qx[:, 1:2] - return self.dequant(res) - - x = np.random.rand(1, 2, 3, 4).astype("float32") - self.generic_test(QSliceModule(), (x,), input_names=["x"], decimal=5) - - def test_cat(self): - class QConcatModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.quant1 = torch.ao.quantization.QuantStub() - self.dequant = torch.ao.quantization.DeQuantStub() - - def forward(self, x, y): - res = torch.ops.quantized.cat( - [self.quant1(x), self.quant1(y)], dim=1, scale=1.0, zero_point=0 - ) - return self.dequant(res) - - x = np.random.rand(1, 2, 3, 4).astype("float32") - y = np.random.rand(1, 4, 3, 4).astype("float32") - self.generic_test( - QConcatModule(), - ( - x, - y, - ), - input_names=["x", "y"], - ) - - def test_max_pool2d(self): - class QMaxPool2dModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.quant1 = torch.ao.quantization.QuantStub() - self.dequant = torch.ao.quantization.DeQuantStub() - - def forward(self, x): - res = torch.nn.functional.max_pool2d( - self.quant1(x), kernel_size=2, stride=1, padding=0 - ) - return self.dequant(res) - - x = np.random.rand(1, 2, 8, 8).astype("float32") - self.generic_test(QMaxPool2dModule(), (x,), input_names=["x"], decimal=5) - - def test_quantized_sigmoid(self): - self.generic_unary_test(torch.nn.Sigmoid()) - - def test_small_model(self): - class SimpleModel(torch.nn.Module): - def __init__(self): - super().__init__() - self.quant = torch.ao.quantization.QuantStub() - self.dequant = torch.ao.quantization.DeQuantStub() - self.func_add = nnq.FloatFunctional() - self.conv1 = nn.Conv2d(3, 2, 5, bias=None).to(dtype=torch.float) - self.act1 = nn.Sigmoid() - self.conv2 = nn.Conv2d(2, 2, 1, bias=None).to(dtype=torch.float) - self.fc = nn.Linear(72, 10).to(dtype=torch.float) - self.fc.qconfig = None - - def forward(self, x): - x = self.quant(x) - x = self.func_add.add(x, x) - x = self.conv1(x) - x = self.act1(x) - x = self.conv2(x) - x = self.dequant(x) - x = x.reshape(-1, 72).contiguous() - x = self.fc(x) - return x - - x = np.random.rand(2, 3, 10, 10).astype("float32") - self.generic_test(SimpleModel(), (x,), input_names=["x"], relaxed_check=True) - - def test_sequential(self): - class ConvBNReLUModule(nn.Sequential): - def __init__(self): - super().__init__( - nn.Conv2d(3, 3, 1, 1, bias=False), - nn.BatchNorm2d(3), - nn.ReLU(inplace=False), - ) - - class ModelWithClassifierHead(nn.Module): - def __init__(self): - super().__init__() - self.conv1 = nn.Conv2d(3, 3, 1) - self.relu1 = nn.ReLU(inplace=False) - layers = [] - for i in range(3): - layers.append(ConvBNReLUModule()) - self.features = nn.Sequential(*layers) - head = [nn.Linear(300, 10), nn.ReLU(inplace=False)] - self.classifier = nn.Sequential(*head) - self.seq = nn.Sequential() - self.quant = torch.ao.quantization.QuantStub() - self.dequant = torch.ao.quantization.DeQuantStub() - - def forward(self, x): - x = self.quant(x) - x = self.conv1(x) - x = self.relu1(x) - x = self.features(x) - x = torch.reshape(x, (-1, 3 * 10 * 10)) - x = self.classifier(x) - x = self.seq(x) - x = self.dequant(x) - return x - - model = ModelWithClassifierHead().eval() - torch.ao.quantization.fuse_modules( - model, - [ - ["conv1", "relu1"], - ["features.0.0", "features.0.1", "features.0.2"], - ["features.1.0", "features.1.1", "features.1.2"], - ["features.2.0", "features.2.1", "features.2.2"], - ], - inplace=True, - ) - - x = np.random.rand(1, 3, 10, 10).astype("float32") - self.generic_test(model, (x,), input_names=["x"], relaxed_check=True) - - -if __name__ == "__main__": - common_utils.run_tests() diff --git a/test/onnx_caffe2/test_verify.py b/test/onnx_caffe2/test_verify.py deleted file mode 100644 index 05ca8084045933..00000000000000 --- a/test/onnx_caffe2/test_verify.py +++ /dev/null @@ -1,106 +0,0 @@ -# Owner(s): ["module: onnx"] - -from verify import verify - -import caffe2.python.onnx.backend as backend -import torch -from torch.autograd import Function -from torch.nn import Module, Parameter -from torch.testing._internal import common_utils - - -class TestVerify(common_utils.TestCase): - maxDiff = None - - def assertVerifyExpectFail(self, *args, **kwargs): - try: - verify(*args, **kwargs) - except AssertionError as e: - if str(e): - # substring a small piece of string because the exact message - # depends on system's formatting settings - # self.assertExpected(str(e)[:60]) - # NB: why we comment out the above check? because numpy keeps - # changing the error format, and we have to keep updating the - # expect files let's relax this constraint - return - else: - raise - # Don't put this in the try block; the AssertionError will catch it - self.assertTrue(False, msg="verify() did not fail when expected to") - - def test_result_different(self): - class BrokenAdd(Function): - @staticmethod - def symbolic(g, a, b): - return g.op("Add", a, b) - - @staticmethod - def forward(ctx, a, b): - return a.sub(b) # yahaha! you found me! - - class MyModel(Module): - def forward(self, x, y): - return BrokenAdd().apply(x, y) - - x = torch.tensor([1, 2]) - y = torch.tensor([3, 4]) - self.assertVerifyExpectFail(MyModel(), (x, y), backend) - - def test_jumbled_params(self): - class MyModel(Module): - def forward(self, x): - y = x * x - self.param = Parameter(torch.tensor([2.0])) - return y - - x = torch.tensor([1, 2]) - with self.assertRaisesRegex(RuntimeError, "state_dict changed"): - verify(MyModel(), x, backend) - - def test_dynamic_model_structure(self): - class MyModel(Module): - def __init__(self): - super().__init__() - self.iters = 0 - - def forward(self, x): - if self.iters % 2 == 0: - r = x * x - else: - r = x + x - self.iters += 1 - return r - - x = torch.tensor([1, 2]) - self.assertVerifyExpectFail(MyModel(), x, backend) - - def test_embedded_constant_difference(self): - class MyModel(Module): - def __init__(self): - super().__init__() - self.iters = 0 - - def forward(self, x): - r = x[self.iters % 2] - self.iters += 1 - return r - - x = torch.tensor([[1, 2], [3, 4]]) - self.assertVerifyExpectFail(MyModel(), x, backend) - - def test_explicit_test_args(self): - class MyModel(Module): - def forward(self, x): - if x.data.sum() == 1.0: - return x + x - else: - return x * x - - x = torch.tensor([[6, 2]]) - y = torch.tensor([[2, -1]]) - self.assertVerifyExpectFail(MyModel(), x, backend, test_args=[(y,)]) - - -if __name__ == "__main__": - common_utils.run_tests() From 5344c41d431042dfb9f4a8cfb23b84cfa1352569 Mon Sep 17 00:00:00 2001 From: eellison Date: Mon, 17 Jun 2024 10:25:09 -0700 Subject: [PATCH 091/171] Use forked torchbench branch with pinned numpy (#128856) Adds pinned numpy commit to yolov3 dependencies to the existing pinned commit. Pull Request resolved: https://github.com/pytorch/pytorch/pull/128856 Approved by: https://github.com/huydhn, https://github.com/PaliC --- .ci/pytorch/common_utils.sh | 2 +- .ci/pytorch/perf_test/test_cpu_speed_mini_sequence_labeler.sh | 2 +- .ci/pytorch/perf_test/test_gpu_speed_cudnn_lstm.sh | 2 +- .ci/pytorch/perf_test/test_gpu_speed_lstm.sh | 2 +- .ci/pytorch/perf_test/test_gpu_speed_mlstm.sh | 2 +- .github/ci_commit_pins/torchbench.txt | 2 +- benchmarks/dynamo/Makefile | 2 +- 7 files changed, 7 insertions(+), 7 deletions(-) diff --git a/.ci/pytorch/common_utils.sh b/.ci/pytorch/common_utils.sh index 91c2d1b5dd3bd7..2f03e8c4255e64 100644 --- a/.ci/pytorch/common_utils.sh +++ b/.ci/pytorch/common_utils.sh @@ -191,7 +191,7 @@ function clone_pytorch_xla() { function checkout_install_torchbench() { local commit commit=$(get_pinned_commit torchbench) - git clone https://github.com/pytorch/benchmark torchbench + git clone https://github.com/eellison/benchmark torchbench pushd torchbench git checkout "$commit" diff --git a/.ci/pytorch/perf_test/test_cpu_speed_mini_sequence_labeler.sh b/.ci/pytorch/perf_test/test_cpu_speed_mini_sequence_labeler.sh index 72496691286e4c..70c4be781e2886 100644 --- a/.ci/pytorch/perf_test/test_cpu_speed_mini_sequence_labeler.sh +++ b/.ci/pytorch/perf_test/test_cpu_speed_mini_sequence_labeler.sh @@ -9,7 +9,7 @@ test_cpu_speed_mini_sequence_labeler () { export OMP_NUM_THREADS=4 export MKL_NUM_THREADS=4 - git clone https://github.com/pytorch/benchmark.git + git clone https://github.com/eellison/benchmark.git cd benchmark/ diff --git a/.ci/pytorch/perf_test/test_gpu_speed_cudnn_lstm.sh b/.ci/pytorch/perf_test/test_gpu_speed_cudnn_lstm.sh index 1693b00f17e2d1..9633f7dfdfae38 100644 --- a/.ci/pytorch/perf_test/test_gpu_speed_cudnn_lstm.sh +++ b/.ci/pytorch/perf_test/test_gpu_speed_cudnn_lstm.sh @@ -9,7 +9,7 @@ test_gpu_speed_cudnn_lstm () { export OMP_NUM_THREADS=4 export MKL_NUM_THREADS=4 - git clone https://github.com/pytorch/benchmark.git + git clone https://github.com/eellison/benchmark.git cd benchmark/ diff --git a/.ci/pytorch/perf_test/test_gpu_speed_lstm.sh b/.ci/pytorch/perf_test/test_gpu_speed_lstm.sh index 2e26b9902b868f..b8548f8206a9cb 100644 --- a/.ci/pytorch/perf_test/test_gpu_speed_lstm.sh +++ b/.ci/pytorch/perf_test/test_gpu_speed_lstm.sh @@ -9,7 +9,7 @@ test_gpu_speed_lstm () { export OMP_NUM_THREADS=4 export MKL_NUM_THREADS=4 - git clone https://github.com/pytorch/benchmark.git + git clone https://github.com/eellison/benchmark.git cd benchmark/ diff --git a/.ci/pytorch/perf_test/test_gpu_speed_mlstm.sh b/.ci/pytorch/perf_test/test_gpu_speed_mlstm.sh index a0617530194a16..e224dd27f74f4f 100644 --- a/.ci/pytorch/perf_test/test_gpu_speed_mlstm.sh +++ b/.ci/pytorch/perf_test/test_gpu_speed_mlstm.sh @@ -9,7 +9,7 @@ test_gpu_speed_mlstm () { export OMP_NUM_THREADS=4 export MKL_NUM_THREADS=4 - git clone https://github.com/pytorch/benchmark.git + git clone https://github.com/eellison/benchmark.git cd benchmark/ diff --git a/.github/ci_commit_pins/torchbench.txt b/.github/ci_commit_pins/torchbench.txt index 3df9dd6cf80389..8779f5b61aa9ba 100644 --- a/.github/ci_commit_pins/torchbench.txt +++ b/.github/ci_commit_pins/torchbench.txt @@ -1 +1 @@ -d6015d42d9a1834bc7595c4bd6852562fb80b30b +pin_yolo_dep diff --git a/benchmarks/dynamo/Makefile b/benchmarks/dynamo/Makefile index 720542f28608bd..dacddec4b2919c 100644 --- a/benchmarks/dynamo/Makefile +++ b/benchmarks/dynamo/Makefile @@ -10,7 +10,7 @@ clone-deps: && (test -e detectron2 || git clone --recursive https://github.com/facebookresearch/detectron2) \ && (test -e FBGEMM || git clone --recursive https://github.com/pytorch/FBGEMM) \ && (test -e torchrec || git clone --recursive https://github.com/pytorch/torchrec) \ - && (test -e torchbenchmark || git clone --recursive https://github.com/pytorch/benchmark torchbenchmark) \ + && (test -e torchbenchmark || git clone --recursive https://github.com/eellison/benchmark torchbenchmark) \ ) pull-deps: clone-deps From c172b58fe01a48be758c2054d27848c3a405f54a Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Mon, 17 Jun 2024 18:49:15 +0000 Subject: [PATCH 092/171] Revert "Update DALLE2_pytorch expected accuracy result on CPU (#128718)" This reverts commit fd27138c4a86bd763a6b8128d940a7c98f951603. Reverted https://github.com/pytorch/pytorch/pull/128718 on behalf of https://github.com/huydhn due to This has reverted back to the previous expected value for some reason https://hud.pytorch.org/pytorch/pytorch/commit/153362fbc9e8642fb851a4de3b99e3871a2cc714 ([comment](https://github.com/pytorch/pytorch/pull/128718#issuecomment-2174194219)) --- .../cpu_inductor_torchbench_freezing_inference.csv | 2 +- .../ci_expected_accuracy/cpu_inductor_torchbench_inference.csv | 2 +- .../dynamic_cpu_inductor_torchbench_inference.csv | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_torchbench_freezing_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_torchbench_freezing_inference.csv index c577edbb8aa46d..8ae5d51a38ec35 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_torchbench_freezing_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_torchbench_freezing_inference.csv @@ -10,7 +10,7 @@ Background_Matting,pass_due_to_skip,0 -DALLE2_pytorch,eager_fail_to_run,0 +DALLE2_pytorch,model_fail_to_load,0 diff --git a/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_torchbench_inference.csv index 85a1e3b0751ec9..13db429816597f 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_torchbench_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_torchbench_inference.csv @@ -10,7 +10,7 @@ Background_Matting,pass_due_to_skip,0 -DALLE2_pytorch,eager_fail_to_run,0 +DALLE2_pytorch,model_fail_to_load,0 diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_inductor_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_inductor_torchbench_inference.csv index d6487b0ce21a40..2a71ec9ef5095a 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_inductor_torchbench_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_inductor_torchbench_inference.csv @@ -10,7 +10,7 @@ Background_Matting,pass_due_to_skip,0 -DALLE2_pytorch,eager_fail_to_run,0 +DALLE2_pytorch,model_fail_to_load,0 From 213eba7d2e9285171f447169de5c6f1d448c8bf6 Mon Sep 17 00:00:00 2001 From: Catherine Lee Date: Mon, 17 Jun 2024 18:53:56 +0000 Subject: [PATCH 093/171] Configure mergebot via config (#128840) Fixes #ISSUE_NUMBER * Companion to https://github.com/pytorch/test-infra/pull/5312 * See the above for details + possible risks * Without the above PR, this should have no effects Pull Request resolved: https://github.com/pytorch/pytorch/pull/128840 Approved by: https://github.com/huydhn --- .github/pytorch-probot.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/pytorch-probot.yml b/.github/pytorch-probot.yml index 0d624788fc61ee..7440791aff63bc 100644 --- a/.github/pytorch-probot.yml +++ b/.github/pytorch-probot.yml @@ -26,3 +26,4 @@ retryable_workflows: - windows-binary labeler_config: labeler.yml label_to_label_config: label_to_label.yml +mergebot: True From b181b58857462edaeafb92951e35a1214b6f8bb6 Mon Sep 17 00:00:00 2001 From: Mikayla Gawarecki Date: Mon, 17 Jun 2024 08:55:40 -0700 Subject: [PATCH 094/171] Fix Storage.filename to not track the filename when storage was mmap-ed with MAP_PRIVATE (#128725) Pull Request resolved: https://github.com/pytorch/pytorch/pull/128725 Approved by: https://github.com/albanD --- aten/src/ATen/MapAllocator.h | 4 ++++ test/test_tensor_creation_ops.py | 7 ++++--- torch/csrc/StorageMethods.cpp | 3 ++- torch/storage.py | 6 ++++-- 4 files changed, 14 insertions(+), 6 deletions(-) diff --git a/aten/src/ATen/MapAllocator.h b/aten/src/ATen/MapAllocator.h index f4a30edef62395..db1258beee525c 100644 --- a/aten/src/ATen/MapAllocator.h +++ b/aten/src/ATen/MapAllocator.h @@ -55,6 +55,10 @@ class TORCH_API MapAllocator { return base_ptr_; } + int flags() const { + return flags_; + } + static MapAllocator* fromDataPtr(const at::DataPtr&); static at::DataPtr makeDataPtr( c10::string_view filename, diff --git a/test/test_tensor_creation_ops.py b/test/test_tensor_creation_ops.py index fe40a2f16b6d71..fcef43ad943b46 100644 --- a/test/test_tensor_creation_ops.py +++ b/test/test_tensor_creation_ops.py @@ -3177,18 +3177,19 @@ def test_from_file(self, device, shared): dtype = torch.float64 t = torch.randn(2, 5, dtype=dtype, device=device) with tempfile.NamedTemporaryFile() as f: + expected_filename = f.name if shared else None t.numpy().tofile(f) t_mapped = torch.from_file(f.name, shared=shared, size=t.numel(), dtype=dtype) - self.assertTrue(t_mapped.storage().filename == f.name) + self.assertTrue(t_mapped.untyped_storage().filename == expected_filename) self.assertEqual(torch.flatten(t), t_mapped) s = torch.UntypedStorage.from_file(f.name, shared, t.numel() * dtype.itemsize) - self.assertTrue(s.filename == f.name) + self.assertTrue(s.filename == expected_filename) @onlyCPU def test_storage_filename(self, device): t = torch.randn(2, 5, device=device) - self.assertIsNone(t.storage().filename) + self.assertIsNone(t.untyped_storage().filename) # Class for testing random tensor creation ops, like torch.randint diff --git a/torch/csrc/StorageMethods.cpp b/torch/csrc/StorageMethods.cpp index 540268d1522445..3f3079618d1f4c 100644 --- a/torch/csrc/StorageMethods.cpp +++ b/torch/csrc/StorageMethods.cpp @@ -624,7 +624,8 @@ static PyObject* THPStorage__get_filename(PyObject* self, PyObject* noargs) { const c10::DataPtr& data_ptr = self_.data_ptr(); at::MapAllocator* map_allocator = at::MapAllocator::fromDataPtr(data_ptr); - if (map_allocator == nullptr) { + if (map_allocator == nullptr || + !(map_allocator->flags() & at::ALLOCATOR_MAPPED_SHARED)) { Py_RETURN_NONE; } std::string filename = map_allocator->filename(); diff --git a/torch/storage.py b/torch/storage.py index c094ba5ac3e964..f3143796719354 100644 --- a/torch/storage.py +++ b/torch/storage.py @@ -363,8 +363,10 @@ def is_hpu(self): @property def filename(self) -> _Optional[str]: - """Returns the file name associated with this storage if the storage was memory mapped from a file. - or ``None`` if the storage was not created by memory mapping a file.""" + """Returns the file name associated with this storage. + + The file name will be a string if the storage is on CPU and was created via + :meth:`~torch.from_file()` with ``shared`` as ``True``. This attribute is ``None`` otherwise.""" return self._get_filename() @_share_memory_lock_protected From 1577328ea40fce57fd5253c7c7e34b66c13ff9b5 Mon Sep 17 00:00:00 2001 From: Huy Do Date: Mon, 17 Jun 2024 19:24:09 +0000 Subject: [PATCH 095/171] Set bash shell on Windows (#128854) Attempt to fix the missing python3 command on the new Windows AMI https://github.com/pytorch/pytorch/actions/runs/9551494945/job/26325922503. I added the logic to copy python to python3 to make the command available, it worked with the previous AMI, but start to fail now and the cause is not clear (maybe it's not the AMI, but a new GitHub runner version) Pull Request resolved: https://github.com/pytorch/pytorch/pull/128854 Approved by: https://github.com/kit1980, https://github.com/malfet, https://github.com/atalman --- .github/workflows/_win-build.yml | 4 ++++ .github/workflows/_win-test.yml | 4 ++++ 2 files changed, 8 insertions(+) diff --git a/.github/workflows/_win-build.yml b/.github/workflows/_win-build.yml index bc381c50628d10..72112c99e7d54c 100644 --- a/.github/workflows/_win-build.yml +++ b/.github/workflows/_win-build.yml @@ -47,6 +47,9 @@ jobs: timeout-minutes: 240 outputs: test-matrix: ${{ steps.filter.outputs.test-matrix }} + defaults: + run: + shell: bash steps: # Duplicated in win-test because this MUST go before a checkout - name: Enable git symlinks on Windows and disable fsmonitor daemon @@ -89,6 +92,7 @@ jobs: - name: Parse ref id: parse-ref + shell: bash run: python3 .github/scripts/parse_ref.py - name: Get workflow job id diff --git a/.github/workflows/_win-test.yml b/.github/workflows/_win-test.yml index 99d037f0355ce6..143088c195ed3f 100644 --- a/.github/workflows/_win-test.yml +++ b/.github/workflows/_win-test.yml @@ -41,6 +41,9 @@ jobs: fail-fast: false runs-on: ${{ matrix.runner }} timeout-minutes: ${{ matrix.mem_leak_check == 'mem_leak_check' && 600 || inputs.timeout-minutes }} + defaults: + run: + shell: bash steps: # Duplicated in win-build because this MUST go before a checkout - name: Enable git symlinks on Windows and disable fsmonitor daemon @@ -224,6 +227,7 @@ jobs: - name: Parse ref id: parse-ref + shell: bash run: python3 .github/scripts/parse_ref.py - name: Uninstall PyTorch From 0f89e66d1745b8f4b304ebf46174bc726f0c28f5 Mon Sep 17 00:00:00 2001 From: Kurman Karabukaev Date: Mon, 17 Jun 2024 20:07:13 +0000 Subject: [PATCH 096/171] Validate logs are created by default (#128522) Summary: Make sure that logs are caputured in default settings Test Plan: ci Differential Revision: D58395812 Pull Request resolved: https://github.com/pytorch/pytorch/pull/128522 Approved by: https://github.com/d4l3k --- test/distributed/launcher/test_run.py | 31 ++++++++++++++++++++++++++- 1 file changed, 30 insertions(+), 1 deletion(-) diff --git a/test/distributed/launcher/test_run.py b/test/distributed/launcher/test_run.py index ba58aec4387154..f71bffd527c1e2 100644 --- a/test/distributed/launcher/test_run.py +++ b/test/distributed/launcher/test_run.py @@ -6,6 +6,7 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import io import multiprocessing as mp import os import runpy @@ -14,7 +15,7 @@ import sys import tempfile import uuid -from contextlib import closing +from contextlib import closing, redirect_stderr, redirect_stdout from unittest import mock from unittest.mock import MagicMock, Mock, patch @@ -629,6 +630,34 @@ def test_init_method_env_with_torchelastic(self): ) # nothing to validate, just make sure it runs + def test_capture_logs_using_default_logs_specs(self): + run_id = str(uuid.uuid4().int) + nnodes = 1 + nproc_per_node = 4 + args = [ + f"--nnodes={nnodes}", + f"--nproc-per-node={nproc_per_node}", + f"--rdzv-id={run_id}", + "--redirect=3", + "--tee=3", + "--monitor-interval=1", + "--start-method=spawn", + "--no-python", + ] + + script_args = [path("bin/test_script.sh"), f"{self.test_dir}"] + + captured_out = io.StringIO() + captured_err = io.StringIO() + with redirect_stdout(captured_out), redirect_stderr(captured_err): + with patch.dict( + os.environ, {"TORCHELASTIC_LOG_LINE_PREFIX_TEMPLATE": "[rank${rank}]: "} + ): + launch.main(args + script_args) + + for i in range(nproc_per_node): + self.assertTrue(f"[rank{i}]: creating " in captured_out.getvalue()) + if __name__ == "__main__": run_tests() From a59766ee058ba10d61e94c96daf2f7ded63efdb8 Mon Sep 17 00:00:00 2001 From: Masaki Kozuki Date: Mon, 17 Jun 2024 20:50:22 +0000 Subject: [PATCH 097/171] replace `AT_ERROR(...)` with `TORCH_CHECK(false, ...)` (#128788) as per title. encountered the old-fashioned by chance Pull Request resolved: https://github.com/pytorch/pytorch/pull/128788 Approved by: https://github.com/mikaylagawarecki --- aten/src/ATen/native/TensorShape.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/aten/src/ATen/native/TensorShape.cpp b/aten/src/ATen/native/TensorShape.cpp index 3a473495ff9f15..250fe68ff5e66f 100644 --- a/aten/src/ATen/native/TensorShape.cpp +++ b/aten/src/ATen/native/TensorShape.cpp @@ -1619,7 +1619,7 @@ Tensor alias_with_sizes_and_strides( Tensor reshape_symint(const Tensor& self, c10::SymIntArrayRef proposed_shape) { if (self.is_sparse()) { - AT_ERROR("reshape is not implemented for sparse tensors"); + TORCH_CHECK(false, "reshape is not implemented for sparse tensors"); } if (self.is_contiguous() && !self.is_mkldnn()) { @@ -1682,7 +1682,7 @@ Tensor _reshape_copy_symint(const Tensor& self, c10::SymIntArrayRef proposed_sha // minimize breakages. Tensor reshape(const Tensor& self, IntArrayRef proposed_shape) { if (self.is_sparse()) { - AT_ERROR("reshape is not implemented for sparse tensors"); + TORCH_CHECK(false, "reshape is not implemented for sparse tensors"); } DimVector shape = infer_size_dv(proposed_shape, self.numel()); From 8c06eae17eb470e3eb97f58cf6c0eddad26937f6 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Sat, 15 Jun 2024 18:30:41 -0700 Subject: [PATCH 098/171] [GPT-benchmark] Add metric: compilation time for GPT models (#128768) Pull Request resolved: https://github.com/pytorch/pytorch/pull/128768 Approved by: https://github.com/Chillee --- benchmarks/gpt_fast/generate.py | 40 ++++++++++++++++++++++++++++----- 1 file changed, 35 insertions(+), 5 deletions(-) diff --git a/benchmarks/gpt_fast/generate.py b/benchmarks/gpt_fast/generate.py index 3ec72bf1e3195e..92d00cb1bdb6b3 100644 --- a/benchmarks/gpt_fast/generate.py +++ b/benchmarks/gpt_fast/generate.py @@ -27,6 +27,7 @@ class GPTModelConfig: quantizer: type token_per_sec: float memory_bandwidth: float + compilation_time: float def device_sync(device): @@ -190,6 +191,7 @@ def run_experiment( aggregate_metrics = {"tokens_per_sec": [], "memory_bandwidth": []} start = -1 + compilation_time = None for i in range(start, num_samples): device_sync(device=device) # MKG @@ -200,7 +202,8 @@ def run_experiment( ) if i == -1: - print(f"Compilation time: {time.perf_counter() - t0:.2f} seconds") + compilation_time = time.perf_counter() - t0 + print(f"Compilation time: {compilation_time:.2f} seconds") continue device_sync(device=device) # MKG @@ -217,7 +220,7 @@ def run_experiment( print(f"Average tokens/sec: {token_per_sec:.2f} tokens/sec") print(f"Average bandwidth achieved: {memory_bandwidth:.02f} GB/s") print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB") - return token_per_sec, memory_bandwidth + return token_per_sec, memory_bandwidth, compilation_time # token_per_sec and memory_bandwidth target numbers are for A100-40GB, which are different from the typical A100-80GB. @@ -231,8 +234,9 @@ def run_llama2_7b_bf16(device: str = "cuda"): LLaMAWeightOnlyInt8QuantHandler, 94, 1253, + 162, ) - token_per_sec, memory_bandwidth = run_experiment(model) + token_per_sec, memory_bandwidth, compilation_time = run_experiment(model) return [ Experiment( model.name, @@ -250,6 +254,14 @@ def run_llama2_7b_bf16(device: str = "cuda"): model.mode, device, ), + Experiment( + model.name, + "compilation_time(s)", + model.compilation_time, + f"{compilation_time:.02f}", + model.mode, + device, + ), ] @@ -264,8 +276,9 @@ def run_llama2_7b_int8(device: str = "cuda"): LLaMAWeightOnlyInt8QuantHandler, 144, 957, + 172, ) - token_per_sec, memory_bandwidth = run_experiment(model) + token_per_sec, memory_bandwidth, compilation_time = run_experiment(model) return [ Experiment( model.name, @@ -283,6 +296,14 @@ def run_llama2_7b_int8(device: str = "cuda"): model.mode, device, ), + Experiment( + model.name, + "compilation_time(s)", + model.compilation_time, + f"{compilation_time:.02f}", + model.mode, + device, + ), ] @@ -298,8 +319,9 @@ def run_mixtral_8x7b_int8(device: str = "cuda"): MixtralMoEWeightOnlyInt8QuantHandler, 175, 4129, + 162, ) - token_per_sec, memory_bandwidth = run_experiment(model) + token_per_sec, memory_bandwidth, compilation_time = run_experiment(model) return [ Experiment( model.name, @@ -317,4 +339,12 @@ def run_mixtral_8x7b_int8(device: str = "cuda"): model.mode, device, ), + Experiment( + model.name, + "compilation_time(s)", + model.compilation_time, + f"{compilation_time:.02f}", + model.mode, + device, + ), ] From a489792bb2d59ad7e36e0d3ae55074ce707b47e8 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Sat, 15 Jun 2024 18:30:44 -0700 Subject: [PATCH 099/171] [GPT-benchmark] Fix memory bandwidth for MoE (#128783) Pull Request resolved: https://github.com/pytorch/pytorch/pull/128783 Approved by: https://github.com/Chillee ghstack dependencies: #128768 --- benchmarks/gpt_fast/generate.py | 28 ++++++++++++++++++++++++++-- 1 file changed, 26 insertions(+), 2 deletions(-) diff --git a/benchmarks/gpt_fast/generate.py b/benchmarks/gpt_fast/generate.py index 92d00cb1bdb6b3..19c32d06be1047 100644 --- a/benchmarks/gpt_fast/generate.py +++ b/benchmarks/gpt_fast/generate.py @@ -3,8 +3,9 @@ import time from typing import Optional, Tuple -from mixtral_moe_model import Transformer as MixtralMoE +from mixtral_moe_model import ConditionalFeedForward, Transformer as MixtralMoE from mixtral_moe_quantize import ( + ConditionalFeedForwardInt8, WeightOnlyInt8QuantHandler as MixtralMoEWeightOnlyInt8QuantHandler, ) from model import Transformer as LLaMA @@ -154,6 +155,7 @@ def _load_model(x: GPTModelConfig, device="cuda", precision=torch.bfloat16): return model.eval() +# Only count activated parameters and buffers. def _get_model_size(model): model_size = 0 for name, child in model.named_children(): @@ -164,6 +166,28 @@ def _get_model_size(model): for p in itertools.chain(child.parameters(), child.buffers()) ] ) + + # Remove the inactivated experts from the model size if this is mixture of experts + # architecture, since only activated experts are loaded. + if hasattr(model.config, "num_experts"): + config = model.config + for submodule in model.modules(): + if isinstance( + submodule, (ConditionalFeedForward, ConditionalFeedForwardInt8) + ): + model_size -= ( + sum( + [ + p.numel() * p.dtype.itemsize + for p in itertools.chain( + submodule.parameters(), child.buffers() + ) + ] + ) + * (config.num_experts - config.num_activated_experts) + / config.num_experts + ) + return model_size @@ -318,7 +342,7 @@ def run_mixtral_8x7b_int8(device: str = "cuda"): "int8", MixtralMoEWeightOnlyInt8QuantHandler, 175, - 4129, + 1280, 162, ) token_per_sec, memory_bandwidth, compilation_time = run_experiment(model) From 8953725e6d68b3b7011626319a17fca5bd0b3e75 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Mon, 17 Jun 2024 21:10:55 +0000 Subject: [PATCH 100/171] [Inductor][FlexAttention] Tune backwards kernel block sizes (#128853) This replaces #128767 which somehow closed by mistake. Pull Request resolved: https://github.com/pytorch/pytorch/pull/128853 Approved by: https://github.com/angelayi --- torch/_inductor/kernel/flex_attention.py | 23 +++++++++++------------ 1 file changed, 11 insertions(+), 12 deletions(-) diff --git a/torch/_inductor/kernel/flex_attention.py b/torch/_inductor/kernel/flex_attention.py index 932bcd50b9203e..987dc6d89328b1 100644 --- a/torch/_inductor/kernel/flex_attention.py +++ b/torch/_inductor/kernel/flex_attention.py @@ -361,7 +361,7 @@ def _get_default_config_bwd(query) -> Tuple[int, int, int, int]: return (64, 64, 4, 1) return (128, 128, 4, 3) elif head_dim <= 256 and torch.cuda.get_device_capability() >= (8, 0): # A100 - return (32, 32, 4, 1) + return (64, 64, 4, 1) else: # modest hardware or extremely large head_dim return (16, 16, 4, 1) @@ -763,14 +763,13 @@ def flex_attention_backward(*args, **kwargs): configs: List[Tuple[int, int, int, int]] = [] configs.append(_get_default_config_bwd(query)) if config.max_autotune: - configs += [ - (128, 128, 4, 3), - (128, 128, 8, 1), - (64, 64, 4, 3), - (64, 64, 8, 1), - ] + for BLOCK1 in [32, 64]: + for BLOCK2 in [32, 64]: + for w in [4, 8]: + for s in [1, 3]: + configs.append((BLOCK1, BLOCK2, w, s)) - for BLOCK_M, BLOCK_N, num_warps, num_stages in configs: + for BLOCK1, BLOCK2, num_warps, num_stages in configs: flex_attention_backward_template.maybe_append_choice( choices=choices, input_nodes=[ @@ -790,10 +789,10 @@ def flex_attention_backward(*args, **kwargs): call_sizes=query.get_size() + [key.get_size()[2]], num_stages=num_stages, num_warps=num_warps, - BLOCK_M1=BLOCK_M, - BLOCK_N1=BLOCK_N, - BLOCK_M2=BLOCK_N, - BLOCK_N2=BLOCK_M, + BLOCK_M1=BLOCK1, + BLOCK_N1=BLOCK1, + BLOCK_M2=BLOCK2, + BLOCK_N2=BLOCK2, BLOCK_DMODEL=query.get_size()[-1], # For now, we always assume the "sound" option SCORE_MOD_IS_LINEAR=False, From 163847b1bb5cc36a0915a189b2dd4cfbbfaf9c49 Mon Sep 17 00:00:00 2001 From: cyy Date: Mon, 17 Jun 2024 21:25:55 +0000 Subject: [PATCH 101/171] [1/N] [Caffe2] Remove caffe2_aten_fallback code (#128675) Fixes #ISSUE_NUMBER Pull Request resolved: https://github.com/pytorch/pytorch/pull/128675 Approved by: https://github.com/r-barnes --- test/onnx/test_export_modes.py | 20 --- test/onnx/test_operators.py | 6 +- test/onnx/test_pytorch_onnx_no_runtime.py | 80 +--------- test/onnx/test_utility_funs.py | 22 +-- test/quantization/core/test_quantized_op.py | 43 +----- test/test_jit.py | 52 +------ torch/_C/_onnx.pyi | 1 - torch/csrc/onnx/init.cpp | 2 - torch/onnx/__init__.py | 7 +- torch/onnx/_internal/jit_utils.py | 19 --- torch/onnx/symbolic_helper.py | 9 +- torch/onnx/symbolic_opset11.py | 25 ---- torch/onnx/symbolic_opset12.py | 2 - torch/onnx/symbolic_opset16.py | 3 - torch/onnx/symbolic_opset9.py | 137 ++---------------- torch/onnx/utils.py | 106 ++------------ torch/onnx/verification.py | 5 +- .../testing/_internal/common_quantization.py | 8 - torch/testing/_internal/common_utils.py | 15 -- 19 files changed, 33 insertions(+), 529 deletions(-) diff --git a/test/onnx/test_export_modes.py b/test/onnx/test_export_modes.py index 5bf84c1b409a0d..6d48b2f4578de5 100644 --- a/test/onnx/test_export_modes.py +++ b/test/onnx/test_export_modes.py @@ -86,26 +86,6 @@ def foo(a): x = torch.ones(3) torch.onnx.export(foo, (x,), f) - @common_utils.skipIfNoCaffe2 - @common_utils.skipIfNoLapack - def test_caffe2_aten_fallback(self): - class ModelWithAtenNotONNXOp(nn.Module): - def forward(self, x, y): - abcd = x + y - defg = torch.linalg.qr(abcd) - return defg - - x = torch.rand(3, 4) - y = torch.rand(3, 4) - torch.onnx.export_to_pretty_string( - ModelWithAtenNotONNXOp(), - (x, y), - add_node_names=False, - do_constant_folding=False, - operator_export_type=OperatorExportTypes.ONNX_ATEN_FALLBACK, - ) - - @common_utils.skipIfCaffe2 @common_utils.skipIfNoLapack def test_aten_fallback(self): class ModelWithAtenNotONNXOp(nn.Module): diff --git a/test/onnx/test_operators.py b/test/onnx/test_operators.py index 99f0d533a61ca9..87ec424cf65d57 100644 --- a/test/onnx/test_operators.py +++ b/test/onnx/test_operators.py @@ -39,7 +39,7 @@ parse_args, ) from torch.testing._internal import common_utils -from torch.testing._internal.common_utils import skipIfCaffe2, skipIfNoLapack +from torch.testing._internal.common_utils import skipIfNoLapack unittest.TestCase.maxDiff = None @@ -414,7 +414,6 @@ def test_maxpool_indices(self): x = torch.randn(20, 16, 50) self.assertONNX(nn.MaxPool1d(3, stride=2, return_indices=True), x) - @skipIfCaffe2 def test_at_op(self): x = torch.randn(3, 4) @@ -694,7 +693,6 @@ def test_batchnorm_noaffine(self): keep_initializers_as_inputs=True, ) - @skipIfCaffe2 def test_embedding_bags(self): emb_bag = nn.EmbeddingBag(10, 8) input = torch.tensor([1, 2, 3, 4]).long() @@ -949,7 +947,6 @@ def forward(self, input, other): other = torch.randint(-50, 50, (2, 3, 4), dtype=torch.int8) self.assertONNX(BiwiseAndModel(), (input, other), opset_version=18) - @skipIfCaffe2 def test_layer_norm_aten(self): model = torch.nn.LayerNorm([10, 10]) x = torch.randn(20, 5, 10, 10) @@ -1203,7 +1200,6 @@ def forward(self, x, y): torch.onnx.unregister_custom_op_symbolic("::embedding", _onnx_opset_version) # This is test_aten_embedding_1 with shape inference on custom symbolic aten::embedding. - @skipIfCaffe2 def test_aten_embedding_2(self): _onnx_opset_version = 12 diff --git a/test/onnx/test_pytorch_onnx_no_runtime.py b/test/onnx/test_pytorch_onnx_no_runtime.py index 54fc1782515392..324806eaf0adfa 100644 --- a/test/onnx/test_pytorch_onnx_no_runtime.py +++ b/test/onnx/test_pytorch_onnx_no_runtime.py @@ -20,7 +20,7 @@ import torch import torch.nn.functional as F from torch import Tensor -from torch.onnx import OperatorExportTypes, symbolic_helper, utils +from torch.onnx import symbolic_helper, utils from torch.onnx._internal import registration from torch.testing._internal import common_quantization, common_utils, jit_utils @@ -394,7 +394,6 @@ def forward(self, input): for node in graph.nodes(): self.assertTrue(node.sourceRange()) - @common_utils.skipIfCaffe2 def test_clip_aten_fallback_due_exception(self): def bad_clamp(g, self, min, max): return symbolic_helper._onnx_unsupported("Bad boy!") @@ -411,7 +410,6 @@ def forward(self, x): ) self.assertAtenOp(onnx_model, "clamp", "Tensor") - @common_utils.skipIfCaffe2 def test_clip_aten_fallback_explicit_request(self): class MyClip(torch.nn.Module): def forward(self, x): @@ -961,60 +959,6 @@ def forward(self, x, w): torch.onnx.export_to_pretty_string(Mod(), (torch.rand(3, 4), torch.rand(4, 5))) - @common_utils.skipIfNoCaffe2 - def test_caffe2_aten_fallback_must_fallback(self): - class ModelWithAtenNotONNXOp(torch.nn.Module): - def forward(self, x, y): - abcd = x + y - defg = torch.linalg.qr(abcd) - return defg - - # TODO: Refactor common_utils._decide_skip_caffe2 to support parametrize - for operator_export_type in ( - OperatorExportTypes.ONNX_ATEN, - OperatorExportTypes.ONNX_ATEN_FALLBACK, - ): - x = torch.rand(3, 4) - y = torch.rand(3, 4) - f = io.BytesIO() - torch.onnx.export( - ModelWithAtenNotONNXOp(), - (x, y), - f, - do_constant_folding=False, - operator_export_type=operator_export_type, - # support for linalg.qr was added in later op set versions. - opset_version=9, - ) - onnx_model = onnx.load(io.BytesIO(f.getvalue())) - self.assertAtenOp(onnx_model, "linalg_qr") - - @common_utils.skipIfNoCaffe2 - def test_caffe2_onnx_aten_must_not_fallback(self): - class ModelWithAtenFmod(torch.nn.Module): - def forward(self, x, y): - return torch.fmod(x, y) - - # TODO: Refactor common_utils._decide_skip_caffe2 to support parametrize - for operator_export_type in ( - OperatorExportTypes.ONNX_ATEN_FALLBACK, - OperatorExportTypes.ONNX_ATEN, - ): - x = torch.randn(3, 4, dtype=torch.float32) - y = torch.randn(3, 4, dtype=torch.float32) - f = io.BytesIO() - torch.onnx.export( - ModelWithAtenFmod(), - (x, y), - f, - do_constant_folding=False, - operator_export_type=operator_export_type, - opset_version=10, # or higher - ) - onnx_model = onnx.load(io.BytesIO(f.getvalue())) - assert onnx_model.graph.node[0].op_type == "Mod" - - @common_utils.skipIfCaffe2 def test_aten_fallback_must_fallback(self): class ModelWithAtenNotONNXOp(torch.nn.Module): def forward(self, x, y): @@ -1037,7 +981,6 @@ def forward(self, x, y): onnx_model = onnx.load(io.BytesIO(f.getvalue())) self.assertAtenOp(onnx_model, "linalg_qr") - @common_utils.skipIfCaffe2 def test_onnx_aten(self): class ModelWithAtenFmod(torch.nn.Module): def forward(self, x, y): @@ -1056,7 +999,6 @@ def forward(self, x, y): onnx_model = onnx.load(io.BytesIO(f.getvalue())) self.assertAtenOp(onnx_model, "fmod", "Tensor") - @common_utils.skipIfCaffe2 def test_onnx_aten_fallback_must_not_fallback(self): # For BUILD_CAFFE2=0, aten fallback only when not exportable class ONNXExportable(torch.nn.Module): @@ -1233,26 +1175,6 @@ def _export_to_onnx(model, input, input_names): _export_to_onnx(model, data, input_names) - @common_quantization.skipIfNoFBGEMM - @common_utils.skipIfNoCaffe2 - def test_lower_graph_linear(self): - model = torch.ao.quantization.QuantWrapper( - torch.nn.Linear(5, 10, bias=True) - ).to(dtype=torch.float) - data_numpy = np.random.rand(1, 2, 5).astype(np.float32) - data = torch.from_numpy(data_numpy).to(dtype=torch.float) - self._test_lower_graph_impl(model, data) - - @common_quantization.skipIfNoFBGEMM - @common_utils.skipIfNoCaffe2 - def test_lower_graph_conv2d(self): - model = torch.ao.quantization.QuantWrapper( - torch.nn.Conv2d(3, 5, 2, bias=True) - ).to(dtype=torch.float) - data_numpy = np.random.rand(1, 3, 6, 6).astype(np.float32) - data = torch.from_numpy(data_numpy).to(dtype=torch.float) - self._test_lower_graph_impl(model, data) - @common_quantization.skipIfNoFBGEMM @unittest.skip( "onnx opset9 does not support quantize_per_tensor and caffe2 \ diff --git a/test/onnx/test_utility_funs.py b/test/onnx/test_utility_funs.py index 9ee4129879652f..e7c8f407810334 100644 --- a/test/onnx/test_utility_funs.py +++ b/test/onnx/test_utility_funs.py @@ -17,7 +17,6 @@ skipIfUnsupportedMaxOpsetVersion, skipIfUnsupportedMinOpsetVersion, ) -from verify import verify import torch import torch.onnx @@ -26,7 +25,7 @@ from torch.onnx._globals import GLOBALS from torch.onnx.symbolic_helper import _unpack_list, parse_args from torch.testing._internal import common_utils -from torch.testing._internal.common_utils import skipIfNoCaffe2, skipIfNoLapack +from torch.testing._internal.common_utils import skipIfNoLapack def _remove_test_environment_prefix_from_scope_name(scope_name: str) -> str: @@ -1623,25 +1622,6 @@ def forward(self, x): "Graph parameter names does not match model parameters.", ) - @skipIfNoCaffe2 - def test_modifying_params(self): - class MyModel(torch.nn.Module): - def __init__(self): - super().__init__() - self.param = torch.nn.Parameter(torch.tensor([2.0])) - - def forward(self, x): - y = x * x - self.param.data.add_(1.0) - return y - - x = torch.tensor([1, 2]) - # Move import to local as caffe2 backend requires additional build flag, - # and is only used in this test case. - import caffe2.python.onnx.backend as backend - - verify(MyModel(), x, backend, do_constant_folding=False) - def test_fuse_conv_bn(self): class Fuse(torch.nn.Module): def __init__(self): diff --git a/test/quantization/core/test_quantized_op.py b/test/quantization/core/test_quantized_op.py index 5b86693e11c101..2e606938192dd2 100644 --- a/test/quantization/core/test_quantized_op.py +++ b/test/quantization/core/test_quantized_op.py @@ -23,7 +23,7 @@ from torch.testing._internal.common_cuda import SM80OrLater from torch.testing._internal.common_utils import TestCase -from torch.testing._internal.common_utils import IS_PPC, TEST_WITH_UBSAN, IS_MACOS, BUILD_WITH_CAFFE2, IS_SANDCASTLE +from torch.testing._internal.common_utils import IS_PPC, TEST_WITH_UBSAN, IS_MACOS, IS_SANDCASTLE from torch.testing._internal.common_quantization import skipIfNoFBGEMM, skipIfNoQNNPACK, skipIfNoONEDNN from torch.testing._internal.common_quantized import _quantize, _dequantize, _calculate_dynamic_qparams, \ override_quantized_engine, supported_qengines, override_qengines, _snr @@ -4524,47 +4524,6 @@ def _test_embedding_bag_unpack_fn(self, pack_fn, unpack_fn, num_embeddings, embe self._test_embedding_bag_unpack_impl(pack_fn, unpack_fn, bit_rate, optimized_qparams, weight) - """ Tests the correctness of the embedding_bag_8bit pack/unpack op against C2 """ - @unittest.skipIf(not BUILD_WITH_CAFFE2, "Test needs Caffe2") - @given(num_embeddings=st.integers(10, 100), - embedding_dim=st.integers(5, 50).filter(lambda x: x % 4 == 0), - num_batches=st.integers(1, 5), - data_type=st.sampled_from([np.float32, np.float16]),) - def test_embedding_bag_byte_unpack(self, num_embeddings, embedding_dim, num_batches, data_type): - pack_fn = torch.ops.quantized.embedding_bag_byte_prepack - unpack_fn = torch.ops.quantized.embedding_bag_byte_unpack - - self._test_embedding_bag_unpack_fn( - pack_fn, unpack_fn, num_embeddings, embedding_dim, 8, False, num_batches, data_type=data_type) - - """ Tests the correctness of the embedding_bag_4bit pack/unpack op against C2 """ - @unittest.skipIf(not BUILD_WITH_CAFFE2, "Test needs Caffe2") - @given(num_embeddings=st.integers(10, 100), - embedding_dim=st.integers(5, 50).filter(lambda x: x % 4 == 0), - optimized_qparams=st.booleans(), - data_type=st.sampled_from([np.float32, np.float16]),) - def test_embedding_bag_4bit_unpack(self, num_embeddings, embedding_dim, optimized_qparams, data_type): - pack_fn = torch.ops.quantized.embedding_bag_4bit_prepack - unpack_fn = torch.ops.quantized.embedding_bag_4bit_unpack - - # 4bit and 2bit quantization right now only works for 2D Tensor so we set the num_batches to 1 - self._test_embedding_bag_unpack_fn( - pack_fn, unpack_fn, num_embeddings, embedding_dim, 4, optimized_qparams, 1, data_type=data_type) - - """ Tests the correctness of the embedding_bag_2bit pack/unpack op against C2 """ - @unittest.skipIf(not BUILD_WITH_CAFFE2, "Test needs Caffe2") - @given(num_embeddings=st.integers(10, 100), - embedding_dim=st.integers(5, 50).filter(lambda x: x % 8 == 0), - optimized_qparams=st.booleans(), - data_type=st.sampled_from([np.float32, np.float16]),) - def test_embedding_bag_2bit_unpack(self, num_embeddings, embedding_dim, optimized_qparams, data_type): - pack_fn = torch.ops.quantized.embedding_bag_2bit_prepack - unpack_fn = torch.ops.quantized.embedding_bag_2bit_unpack - - # 4bit and 2bit quantization right now only works for 2D Tensor so we set the num_batches to 1 - self._test_embedding_bag_unpack_fn( - pack_fn, unpack_fn, num_embeddings, embedding_dim, 2, optimized_qparams, 1, data_type=data_type) - def embedding_bag_rowwise_offsets_run( self, bit_rate, num_embeddings, diff --git a/test/test_jit.py b/test/test_jit.py index 13bdd07be6cd98..afecb5f390402d 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -96,7 +96,7 @@ from torch.testing._internal import jit_utils from torch.testing._internal.common_jit import check_against_reference from torch.testing._internal.common_utils import run_tests, IS_WINDOWS, TEST_WITH_UBSAN, \ - suppress_warnings, BUILD_WITH_CAFFE2, IS_SANDCASTLE, GRAPH_EXECUTOR, ProfilingMode, TestCase, \ + suppress_warnings, IS_SANDCASTLE, GRAPH_EXECUTOR, ProfilingMode, TestCase, \ freeze_rng_state, slowTest, TemporaryFileName, \ enable_profiling_mode_for_profiling_tests, TEST_MKL, set_default_dtype, num_profiled_runs, \ skipIfCrossRef, skipIfTorchDynamo @@ -15299,56 +15299,6 @@ def is_tensor_value(item): continue self.assertEqual(value, getattr(loaded, "_" + name)) - @unittest.skipIf(IS_WINDOWS or IS_SANDCASTLE, "NYI: TemporaryFileName support for Windows or Sandcastle") - @unittest.skipIf(not BUILD_WITH_CAFFE2, "PyTorch is build without Caffe2 support") - def test_old_models_bc(self): - model = { - 'archive/version': b'1', - 'archive/code/archive.py': - b''' - op_version_set = 0 - def forward(self, - _0: Tensor) -> Tensor: - _1 = torch.zeros([10], dtype=6, layout=0, device=torch.device("cpu")) - result = torch.to(torch.fill_(_1, 5), dtype=6, layout=0, device=torch.device("cpu"), - non_blocking=False, copy=False) - result2 = torch.rand([10], dtype=6, layout=0, device=torch.device("cpu")) - result3 = torch.rand_like(result2, dtype=6, layout=0, device=torch.device("cpu")) - _2 = torch.add(torch.add(result, result2, alpha=1), result3, alpha=1) - return _2 - ''', - 'archive/attributes.pkl': b'\x80\x02](e.', - 'archive/libs.py': b'op_version_set = 0\n', - 'archive/model.json': - b''' - { - "protoVersion":"2", - "mainModule":{ - "torchscriptArena":{ - "key":"code/archive.py" - }, - "name":"archive", - "optimize":true - }, - "producerName":"pytorch", - "producerVersion":"1.0", - "libs":{ - "torchscriptArena":{ - "key":"libs.py" - } - } - }'''} - with TemporaryFileName() as fname: - archive_name = os.path.basename(os.path.normpath(fname)) - with zipfile.ZipFile(fname, 'w') as archive: - for k, v in model.items(): - archive.writestr(k, v) - - with open(fname, "rb") as f: - fn = torch.jit.load(f) - - x = torch.zeros(10) - fn(x) def test_submodule_attribute_serialization(self): class S(torch.jit.ScriptModule): diff --git a/torch/_C/_onnx.pyi b/torch/_C/_onnx.pyi index 2e8e5a0c661172..349e0b9ad12f0d 100644 --- a/torch/_C/_onnx.pyi +++ b/torch/_C/_onnx.pyi @@ -2,7 +2,6 @@ from enum import Enum -_CAFFE2_ATEN_FALLBACK: bool PRODUCER_VERSION: str class TensorProtoDataType(Enum): diff --git a/torch/csrc/onnx/init.cpp b/torch/csrc/onnx/init.cpp index b8bef342323c57..6b06eb649cae05 100644 --- a/torch/csrc/onnx/init.cpp +++ b/torch/csrc/onnx/init.cpp @@ -292,7 +292,5 @@ void initONNXBindings(PyObject* module) { .value("TRAINING", TrainingMode::TRAINING); onnx.attr("PRODUCER_VERSION") = py::str(TORCH_VERSION); - - onnx.attr("_CAFFE2_ATEN_FALLBACK") = false; } } // namespace torch::onnx diff --git a/torch/onnx/__init__.py b/torch/onnx/__init__.py index 2b2f2bdae0de3c..3f013b12358420 100644 --- a/torch/onnx/__init__.py +++ b/torch/onnx/__init__.py @@ -1,12 +1,7 @@ # mypy: allow-untyped-defs from torch import _C from torch._C import _onnx as _C_onnx -from torch._C._onnx import ( - _CAFFE2_ATEN_FALLBACK, - OperatorExportTypes, - TensorProtoDataType, - TrainingMode, -) +from torch._C._onnx import OperatorExportTypes, TensorProtoDataType, TrainingMode from . import ( # usort:skip. Keep the order instead of sorting lexicographically _deprecation, diff --git a/torch/onnx/_internal/jit_utils.py b/torch/onnx/_internal/jit_utils.py index 13ae4209da5dd8..ed064f6f874d7a 100644 --- a/torch/onnx/_internal/jit_utils.py +++ b/torch/onnx/_internal/jit_utils.py @@ -12,7 +12,6 @@ import torch from torch import _C -from torch._C import _onnx as _C_onnx from torch.onnx._globals import GLOBALS from torch.onnx._internal import _beartype, registration @@ -329,14 +328,6 @@ def _scalar(x: torch.Tensor): return x[0] -@_beartype.beartype -def _is_caffe2_aten_fallback() -> bool: - return ( - GLOBALS.operator_export_type == _C_onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK - and _C_onnx._CAFFE2_ATEN_FALLBACK - ) - - @_beartype.beartype def _add_attribute(node: _C.Node, key: str, value: Any, aten: bool): r"""Initializes the right attribute based on type of value.""" @@ -350,16 +341,6 @@ def _add_attribute(node: _C.Node, key: str, value: Any, aten: bool): if _is_onnx_list(value): kind += "s" - if aten and _is_caffe2_aten_fallback(): - if isinstance(value, torch.Tensor): - # Caffe2 proto does not support tensor attribute. - if value.numel() > 1: - raise ValueError("Should not pass tensor attribute") - value = _scalar(value) - if isinstance(value, float): - kind = "f" - else: - kind = "i" return getattr(node, f"{kind}_")(name, value) diff --git a/torch/onnx/symbolic_helper.py b/torch/onnx/symbolic_helper.py index 676c3d68048b0a..6d876486f642c8 100644 --- a/torch/onnx/symbolic_helper.py +++ b/torch/onnx/symbolic_helper.py @@ -537,10 +537,7 @@ def is_complex_value(x: _C.Value) -> bool: @_beartype.beartype def is_caffe2_aten_fallback() -> bool: - return ( - GLOBALS.operator_export_type == _C_onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK - and _C_onnx._CAFFE2_ATEN_FALLBACK - ) + return False @_beartype.beartype @@ -592,9 +589,7 @@ def _get_dim_for_cross(x: _C.Value, dim: Optional[int]): @_beartype.beartype def _unimplemented(op: str, msg: str, value: Optional[_C.Value] = None) -> None: # For BC reasons, the behavior for Caffe2 does not raise exception for unimplemented operators - if _C_onnx._CAFFE2_ATEN_FALLBACK: - warnings.warn(f"ONNX export failed on {op} because {msg} not supported") - elif GLOBALS.operator_export_type == _C_onnx.OperatorExportTypes.ONNX: + if GLOBALS.operator_export_type == _C_onnx.OperatorExportTypes.ONNX: _onnx_unsupported(f"{op}, {msg}", value) diff --git a/torch/onnx/symbolic_opset11.py b/torch/onnx/symbolic_opset11.py index e562d5a47567c1..90963c4f17fa77 100644 --- a/torch/onnx/symbolic_opset11.py +++ b/torch/onnx/symbolic_opset11.py @@ -211,10 +211,6 @@ def index_put( indices_list = symbolic_helper._unpack_list(indices_list_value) else: indices_list = [indices_list_value] - if symbolic_helper.is_caffe2_aten_fallback(): - args = [self] + indices_list + [values, accumulate] - return g.at("index_put", *args) - accumulate = symbolic_helper._parse_arg(accumulate, "b") if len(indices_list) == 0: @@ -398,8 +394,6 @@ def __interpolate( def gather(g: jit_utils.GraphContext, self, dim, index, sparse_grad=False): if symbolic_helper._maybe_get_const(sparse_grad, "i"): return symbolic_helper._unimplemented("gather", "sparse_grad == True") - if symbolic_helper.is_caffe2_aten_fallback(): - return g.at("gather", self, dim, index, sparse_grad) return g.op("GatherElements", self, index, axis_i=dim) @@ -407,8 +401,6 @@ def gather(g: jit_utils.GraphContext, self, dim, index, sparse_grad=False): @symbolic_helper.parse_args("v", "i", "v", "v") @_beartype.beartype def scatter(g: jit_utils.GraphContext, self, dim, index, src): - if symbolic_helper.is_caffe2_aten_fallback(): - return g.at("scatter", self, dim, index, src, overload_name="src") src_type = _type_utils.JitScalarType.from_value(src) src = symbolic_helper._maybe_get_scalar(src) if symbolic_helper._is_value(src): @@ -898,8 +890,6 @@ def _dim_arange(g: jit_utils.GraphContext, like, dim): stop = g.op( "Gather", like_shape, g.op("Constant", value_t=torch.tensor(dim)), axis_i=0 ) - if symbolic_helper.is_caffe2_aten_fallback(): - return g.op("_caffe2::Range", stop) return arange(g, stop, 4, None, None, None) @@ -982,9 +972,6 @@ def mm(g: jit_utils.GraphContext, self, other): @_onnx_symbolic("aten::index") @_beartype.beartype def index(g: jit_utils.GraphContext, self, index): - if symbolic_helper.is_caffe2_aten_fallback(): - return g.at("index", self, index, overload_name="Tensor") - if symbolic_helper._is_packed_list(index): indices = symbolic_helper._unpack_list(index) else: @@ -1007,16 +994,6 @@ def index(g: jit_utils.GraphContext, self, index): @_beartype.beartype def index_fill(g: jit_utils.GraphContext, self, dim, index, value): dim_value = symbolic_helper._parse_arg(dim, "i") - if symbolic_helper.is_caffe2_aten_fallback(): - return g.at( - "index_fill", - self, - index, - value, - overload_name="int_Scalar", - dim_i=dim_value, - ) - expanded_index_shape, expanded_index = symbolic_helper._index_fill_reshape_helper( g, self, dim, index ) @@ -1030,8 +1007,6 @@ def index_fill(g: jit_utils.GraphContext, self, dim, index, value): @_beartype.beartype def index_copy(g: jit_utils.GraphContext, self, dim, index, source): dim_value = symbolic_helper._parse_arg(dim, "i") - if symbolic_helper.is_caffe2_aten_fallback(): - return g.at("index_copy", self, index, source, dim_i=dim_value) expanded_index_shape, expanded_index = symbolic_helper._index_fill_reshape_helper( g, self, dim, index ) diff --git a/torch/onnx/symbolic_opset12.py b/torch/onnx/symbolic_opset12.py index 5a6bf720df36f3..cf24fe43247ca7 100644 --- a/torch/onnx/symbolic_opset12.py +++ b/torch/onnx/symbolic_opset12.py @@ -330,8 +330,6 @@ def unfold(g: jit_utils.GraphContext, input, dimension, size, step): const_step ): return opset9.unfold(g, input, dimension, const_size, const_step) - if symbolic_helper.is_caffe2_aten_fallback(): - return g.at("unfold", input, dimension_i=dimension, size_i=size, step_i=step) sizedim = symbolic_helper._get_tensor_dim_size(input, dimension) if sizedim is not None: diff --git a/torch/onnx/symbolic_opset16.py b/torch/onnx/symbolic_opset16.py index cd5829ada850d7..8df3d954ba4332 100644 --- a/torch/onnx/symbolic_opset16.py +++ b/torch/onnx/symbolic_opset16.py @@ -71,9 +71,6 @@ def grid_sampler( @symbolic_helper.parse_args("v", "i", "v", "v") @_beartype.beartype def scatter_add(g: jit_utils.GraphContext, self, dim, index, src): - if symbolic_helper.is_caffe2_aten_fallback(): - return g.at("scatter", self, dim, index, src, overload_name="src") - src_type = _type_utils.JitScalarType.from_value( src, _type_utils.JitScalarType.UNDEFINED ) diff --git a/torch/onnx/symbolic_opset9.py b/torch/onnx/symbolic_opset9.py index b4c937ed3f66b7..f43a09aa4b1479 100644 --- a/torch/onnx/symbolic_opset9.py +++ b/torch/onnx/symbolic_opset9.py @@ -841,36 +841,18 @@ def _reduce_with_dtype(onnx_op: str, name: str, allow_multi_dim_support: bool = @symbolic_helper.parse_args("v", "i", "none") @_beartype.beartype def cumsum(g: jit_utils.GraphContext, input, dim, dtype): - if symbolic_helper.is_caffe2_aten_fallback(): - if dtype.node().kind() != "prim::Constant": - return symbolic_helper._unimplemented("cumsum", "dtype", dtype) - return g.at("cumsum", input, dim_i=dim) - symbolic_helper._onnx_opset_unsupported("cumsum", 9, 11, input) @_onnx_symbolic("aten::_sample_dirichlet") @_beartype.beartype def _sample_dirichlet(g: jit_utils.GraphContext, self, generator): - if symbolic_helper.is_caffe2_aten_fallback(): - if not symbolic_helper._is_none(generator): - return symbolic_helper._unimplemented( - "_sample_dirichlet", "We are not able to export generator", self - ) - return g.at("_sample_dirichlet", self) return symbolic_helper._onnx_unsupported("_sample_dirichlet", self) @_onnx_symbolic("aten::_standard_gamma") @_beartype.beartype def _standard_gamma(g: jit_utils.GraphContext, self, generator): - if symbolic_helper.is_caffe2_aten_fallback(): - if not symbolic_helper._is_none(generator): - return symbolic_helper._unimplemented( - "_standard_gamma", "not able to export generator", self - ) - return g.at("_standard_gamma", self) - return symbolic_helper._onnx_unsupported("_standard_gamma", self) @@ -1007,19 +989,6 @@ def embedding_bag( return symbolic_helper._onnx_unsupported( "embedding_bag with per_sample_weights" ) - if symbolic_helper.is_caffe2_aten_fallback(): - return g.at( - "embedding_bag", - embedding_matrix, - indices, - offsets, - outputs=4, - scale_grad_by_freq_i=scale_grad_by_freq, - mode_i=mode, - sparse_i=sparse, - include_last_offset_i=include_last_offset, - padding_idx_i=padding_idx, - ) return symbolic_helper._onnx_unsupported("embedding_bag", embedding_matrix) @@ -1052,10 +1021,6 @@ def transpose(g: jit_utils.GraphContext, self, dim0, dim1): axes = list(range(rank)) axes[dim0], axes[dim1] = axes[dim1], axes[dim0] return g.op("Transpose", self, perm_i=axes) - elif symbolic_helper.is_caffe2_aten_fallback(): - # if we don't have dim information we cannot - # output a permute so use ATen instead - return g.at("transpose", self, overload_name="int", dim0_i=dim0, dim1_i=dim1) else: raise errors.SymbolicValueError( "Unsupported: ONNX export of transpose for tensor of unknown rank.", @@ -2927,16 +2892,6 @@ def layer_norm( eps: float, cudnn_enable: bool, ) -> _C.Value: - if symbolic_helper.is_caffe2_aten_fallback(): - return g.at( - "layer_norm", - input, - weight, - bias, - normalized_shape_i=normalized_shape, - eps_f=eps, - cudnn_enable_i=cudnn_enable, - ) normalized, _, _ = native_layer_norm(g, input, normalized_shape, weight, bias, eps) return normalized @@ -3043,8 +2998,6 @@ def instance_norm( @symbolic_helper.parse_args("v", "i", "i", "i") @_beartype.beartype def unfold(g: jit_utils.GraphContext, input, dimension, size, step): - if symbolic_helper.is_caffe2_aten_fallback(): - return g.at("unfold", input, dimension_i=dimension, size_i=size, step_i=step) sizes = symbolic_helper._get_tensor_sizes(input) # FIXME(justinchuby): Get rid of the try catch here to improve readability try: @@ -3119,9 +3072,6 @@ def index_put(g: jit_utils.GraphContext, self, indices_list_value, values, accum indices_list = symbolic_helper._unpack_list(indices_list_value) else: indices_list = [indices_list_value] - if symbolic_helper.is_caffe2_aten_fallback(): - args = [self] + indices_list + [values, accumulate] - return g.at("index_put", *args) accumulate = symbolic_helper._parse_arg(accumulate, "b") @@ -3136,16 +3086,6 @@ def index_put(g: jit_utils.GraphContext, self, indices_list_value, values, accum @_beartype.beartype def index_fill(g: jit_utils.GraphContext, self, dim, index, value): dim_value = symbolic_helper._parse_arg(dim, "i") - if symbolic_helper.is_caffe2_aten_fallback(): - return g.at( - "index_fill", - self, - index, - value, - overload_name="int_Scalar", - dim_i=dim_value, - ) - expanded_index_shape, expanded_index = symbolic_helper._index_fill_reshape_helper( g, self, dim, index ) @@ -3160,8 +3100,6 @@ def index_fill(g: jit_utils.GraphContext, self, dim, index, value): @_beartype.beartype def index_copy(g: jit_utils.GraphContext, self, dim, index, source): dim_value = symbolic_helper._parse_arg(dim, "i") - if symbolic_helper.is_caffe2_aten_fallback(): - return g.at("index_copy", self, index, source, dim_i=dim_value) expanded_index_shape, expanded_index = symbolic_helper._index_fill_reshape_helper( g, self, dim, index ) @@ -3220,10 +3158,6 @@ def type_as(g: jit_utils.GraphContext, self, other): to_i=other_dtype.onnx_type(), ) - if symbolic_helper.is_caffe2_aten_fallback(): - # We don't know the type of other, bail by emitting ATen - return g.at("type_as", self, other) - raise errors.SymbolicValueError( "Unsupported: ONNX export of type_as for tensor " "of unknown dtype. Please check if the dtype of the " @@ -3236,8 +3170,6 @@ def type_as(g: jit_utils.GraphContext, self, other): @symbolic_helper.parse_args("v", "v", "i", "f") @_beartype.beartype def cosine_similarity(g: jit_utils.GraphContext, x1, x2, dim, eps): - if symbolic_helper.is_caffe2_aten_fallback(): - return g.at("cosine_similarity", x1, x2, dim_i=dim, eps_f=eps) cross = symbolic_helper._reducesum_helper( g, mul(g, x1, x2), axes_i=[dim], keepdims_i=0 ) @@ -3516,50 +3448,28 @@ def norm(g: jit_utils.GraphContext, self, p, dim, keepdim, dtype=None): @symbolic_helper.parse_args("v", "v", "v", "i") @_beartype.beartype def conv_tbc(g: jit_utils.GraphContext, input, weight, bias, pad): - if symbolic_helper.is_caffe2_aten_fallback(): - return g.at("conv_tbc", input, weight, bias, pad_i=pad) - else: - # input must have 3 dimensions, see: - # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/ConvolutionTBC.cpp#L8-L10 - # input = (time, batch, in_channels) - # weight = (kernel_width, in_channels, out_channels) - # bias = (out_channels,) - input = g.op("Transpose", input, perm_i=[1, 2, 0]) - weight = g.op("Transpose", weight, perm_i=[2, 1, 0]) - conv = conv1d(g, input, weight, bias, [1], [pad], [1], 1) - return g.op("Transpose", conv, perm_i=[2, 0, 1]) + # input must have 3 dimensions, see: + # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/ConvolutionTBC.cpp#L8-L10 + # input = (time, batch, in_channels) + # weight = (kernel_width, in_channels, out_channels) + # bias = (out_channels,) + input = g.op("Transpose", input, perm_i=[1, 2, 0]) + weight = g.op("Transpose", weight, perm_i=[2, 1, 0]) + conv = conv1d(g, input, weight, bias, [1], [pad], [1], 1) + return g.op("Transpose", conv, perm_i=[2, 0, 1]) @_onnx_symbolic("aten::_unique") @symbolic_helper.parse_args("v", "i", "i") @_beartype.beartype def _unique(g: jit_utils.GraphContext, input, sorted, return_inverse): - if symbolic_helper.is_caffe2_aten_fallback(): - return g.at( - "_unique", - input, - sorted_i=sorted, - return_inverse_i=return_inverse, - outputs=2, - ) - else: - return symbolic_helper._onnx_unsupported("_unique", input) + return symbolic_helper._onnx_unsupported("_unique", input) @_onnx_symbolic("aten::_unique2") @symbolic_helper.parse_args("v", "i", "i", "i") @_beartype.beartype def _unique2(g: jit_utils.GraphContext, input, sorted, return_inverse, return_counts): - if symbolic_helper.is_caffe2_aten_fallback(): - return g.at( - "_unique2", - input, - sorted_i=sorted, - return_inverse_i=return_inverse, - return_counts_i=return_counts, - outputs=3, - ) - symbolic_helper._onnx_opset_unsupported("_unique2", 9, 11, input) @@ -4973,11 +4883,8 @@ def _dim_arange(g: jit_utils.GraphContext, like, dim): stop = g.op( "Gather", like_shape, g.op("Constant", value_t=torch.tensor(dim)), axis_i=0 ) - if symbolic_helper.is_caffe2_aten_fallback(): - return g.op("_caffe2::Range", stop) - else: - # aten::arange(Scalar end, ScalarType dtype, Layout, Device, bool pin_memory) - return arange(g, stop, 4, None, None, None) + # aten::arange(Scalar end, ScalarType dtype, Layout, Device, bool pin_memory) + return arange(g, stop, 4, None, None, None) @_onnx_symbolic("aten::detach") @@ -5543,9 +5450,6 @@ def logsumexp(g: jit_utils.GraphContext, input, dim, keepdim): @_onnx_symbolic("aten::arange") @_beartype.beartype def arange(g: jit_utils.GraphContext, *args): - if symbolic_helper.is_caffe2_aten_fallback(): - return g.at("arange", *args) - @_beartype.beartype def _get_arange_dtype(dtype): dtype = symbolic_helper._maybe_get_const(dtype, "i") @@ -5665,9 +5569,6 @@ def masked_fill_(g: jit_utils.GraphContext, self, mask, value): @_onnx_symbolic("aten::index") @_beartype.beartype def index(g: jit_utils.GraphContext, self, index): - if symbolic_helper.is_caffe2_aten_fallback(): - return g.at("index", self, index, overload_name="Tensor") - if symbolic_helper._is_packed_list(index): indices = symbolic_helper._unpack_list(index) else: @@ -6083,17 +5984,6 @@ def gelu(g: jit_utils.GraphContext, self: torch._C.Value, approximate: str = "no def group_norm( g: jit_utils.GraphContext, input, num_groups, weight, bias, eps, cudnn_enabled ): - if symbolic_helper.is_caffe2_aten_fallback(): - return g.at( - "group_norm", - input, - weight, - bias, - num_groups_i=num_groups, - eps_f=eps, - cudnn_enabled_i=cudnn_enabled, - ) - channel_size = symbolic_helper._get_tensor_dim_size(input, 1) if channel_size is not None: assert channel_size % num_groups == 0 @@ -6169,9 +6059,6 @@ def _weight_norm(g: jit_utils.GraphContext, weight_v, weight_g, dim): norm_v = norm(g, weight_v, 2, axes, 1) div = g.op("Div", weight_v, norm_v) return g.op("Mul", div, weight_g) - if symbolic_helper.is_caffe2_aten_fallback(): - return g.at("_weight_norm", weight_v, weight_g, dim_i=dim) - raise errors.SymbolicValueError( "Unsupported: ONNX export of _weight_norm for tensor of unknown rank.", weight_v, diff --git a/torch/onnx/utils.py b/torch/onnx/utils.py index 0d02fabd1beb5b..870a599aebce21 100644 --- a/torch/onnx/utils.py +++ b/torch/onnx/utils.py @@ -11,7 +11,6 @@ import inspect import io import re -import textwrap import typing import warnings from typing import ( @@ -681,27 +680,6 @@ def _optimize_graph( _C._jit_pass_onnx_unpack_quantized_weights( graph, params_dict, symbolic_helper.is_caffe2_aten_fallback() ) - if symbolic_helper.is_caffe2_aten_fallback(): - # Insert permutes before and after each conv op to ensure correct order. - _C._jit_pass_onnx_quantization_insert_permutes(graph, params_dict) - - # Find consecutive permutes that are no-ops and remove them. - _C._jit_pass_custom_pattern_based_rewrite_graph( - textwrap.dedent( - """\ - graph(%Pi): - %Pq = quantized::nhwc2nchw(%Pi) - %Pr = quantized::nchw2nhwc(%Pq) - return (%Pr)""" - ), - textwrap.dedent( - """\ - graph(%Ri): - return (%Ri)""" - ), - graph, - ) - # onnx only supports tensors, so we turn all out number types into tensors _C._jit_pass_erase_number_types(graph) if GLOBALS.onnx_shape_inference: @@ -734,18 +712,9 @@ def _optimize_graph( graph = _C._jit_pass_canonicalize(graph) _C._jit_pass_lint(graph) if GLOBALS.onnx_shape_inference: - try: - _C._jit_pass_onnx_graph_shape_type_inference( - graph, params_dict, GLOBALS.export_onnx_opset_version - ) - except RuntimeError as exc: - if ( - _C_onnx._CAFFE2_ATEN_FALLBACK - and exc.args[0] - == "ScalarType UNKNOWN_SCALAR is an unexpected tensor scalar type!" - ): - # Caffe2 builds can have UNKNOWN_SCALAR for some tensors - pass + _C._jit_pass_onnx_graph_shape_type_inference( + graph, params_dict, GLOBALS.export_onnx_opset_version + ) return graph @@ -783,17 +752,6 @@ def warn_on_static_input_change(input_states): @_beartype.beartype def _resolve_args_by_export_type(arg_name, arg_value, operator_export_type): """Resolves the arguments that are ignored when export_type != operator_export_type.ONNX.""" - if ( - operator_export_type is not operator_export_type.ONNX - and _C_onnx._CAFFE2_ATEN_FALLBACK - ): - if arg_value is True: - warnings.warn( - f"'{arg_name}' can be set to True only when 'operator_export_type' is " - "`ONNX`. Since 'operator_export_type' is not set to 'ONNX', " - f"'{arg_name}' argument will be ignored." - ) - arg_value = False return arg_value @@ -1298,18 +1256,9 @@ def _model_to_graph( _C._jit_pass_dce_allow_deleting_nodes_with_side_effects(graph) if GLOBALS.onnx_shape_inference: - try: - _C._jit_pass_onnx_graph_shape_type_inference( - graph, params_dict, GLOBALS.export_onnx_opset_version - ) - except RuntimeError as exc: - if ( - _C_onnx._CAFFE2_ATEN_FALLBACK - and exc.args[0] - == "ScalarType UNKNOWN_SCALAR is an unexpected tensor scalar type!" - ): - # Caffe2 builds can have UNKNOWN_SCALAR for some tensors - pass + _C._jit_pass_onnx_graph_shape_type_inference( + graph, params_dict, GLOBALS.export_onnx_opset_version + ) params_dict = _C._jit_pass_onnx_eliminate_unused_items(graph, params_dict) @@ -1612,15 +1561,6 @@ def _export( if export_type is None: export_type = _exporter_states.ExportTypes.PROTOBUF_FILE - # Discussed deprecation with Nikita Shulga and Sergii Dymchenko from Meta - if _C_onnx._CAFFE2_ATEN_FALLBACK: - warnings.warn( - "Caffe2 ONNX exporter is deprecated in version 2.0 and will be " - "removed in 2.2. Please use PyTorch 2.1 or older for this capability.", - category=FutureWarning, - stacklevel=2, - ) - if isinstance(model, torch.nn.DataParallel): raise ValueError( "torch.nn.DataParallel is not supported by ONNX " @@ -1655,10 +1595,7 @@ def _export( "no local function support. " ) if not operator_export_type: - if _C_onnx._CAFFE2_ATEN_FALLBACK: - operator_export_type = _C_onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK - else: - operator_export_type = _C_onnx.OperatorExportTypes.ONNX + operator_export_type = _C_onnx.OperatorExportTypes.ONNX # By default, training=TrainingMode.EVAL, # which is good because running a model in training mode could result in @@ -1904,21 +1841,12 @@ def _should_aten_fallback( is_aten_fallback_export = ( operator_export_type == _C_onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK ) - is_caffe2_build = _C_onnx._CAFFE2_ATEN_FALLBACK if not name.startswith("aten::"): return False - if is_caffe2_build: - if ( - is_onnx_aten_export or is_aten_fallback_export - ) and not is_exportable_aten_op: - return True - else: - if is_onnx_aten_export or ( - is_aten_fallback_export and not is_exportable_aten_op - ): - return True + if is_onnx_aten_export or (is_aten_fallback_export and not is_exportable_aten_op): + return True return False @@ -1968,7 +1896,7 @@ def wrapper(graph_context: jit_utils.GraphContext, *args, **kwargs): def _get_aten_op_overload_name(n: _C.Node) -> str: # Returns `overload_name` attribute to ATen ops on non-Caffe2 builds schema = n.schema() - if not schema.startswith("aten::") or symbolic_helper.is_caffe2_aten_fallback(): + if not schema.startswith("aten::"): return "" return _C.parse_schema(schema).overload_name @@ -2032,14 +1960,7 @@ def _run_symbolic_function( ) try: - # Caffe2-specific: Quantized op symbolics are registered for opset 9 only. - if symbolic_helper.is_caffe2_aten_fallback() and opset_version == 9: - symbolic_caffe2.register_quantized_ops("caffe2", opset_version) - - if namespace == "quantized" and symbolic_helper.is_caffe2_aten_fallback(): - domain = "caffe2" - else: - domain = namespace + domain = namespace symbolic_function_name = f"{domain}::{op_name}" symbolic_function_group = registration.registry.get_function_group( @@ -2073,10 +1994,7 @@ def _run_symbolic_function( except RuntimeError: if operator_export_type == _C_onnx.OperatorExportTypes.ONNX_FALLTHROUGH: return None - elif ( - operator_export_type == _C_onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK - and not symbolic_helper.is_caffe2_aten_fallback() - ): + elif operator_export_type == _C_onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK: # Emit ATen op for non-Caffe2 builds when `operator_export_type==ONNX_ATEN_FALLBACK` attrs = { k + "_" + node.kindOf(k)[0]: symbolic_helper._node_get(node, k) diff --git a/torch/onnx/verification.py b/torch/onnx/verification.py index 95ed873bf63358..38a23893a8ba50 100644 --- a/torch/onnx/verification.py +++ b/torch/onnx/verification.py @@ -633,10 +633,7 @@ def _onnx_graph_from_model( utils._setup_trace_module_map(model, export_modules_as_functions) if not operator_export_type: - if _C_onnx._CAFFE2_ATEN_FALLBACK: - operator_export_type = _C_onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK - else: - operator_export_type = _C_onnx.OperatorExportTypes.ONNX + operator_export_type = _C_onnx.OperatorExportTypes.ONNX GLOBALS.export_onnx_opset_version = opset_version GLOBALS.operator_export_type = operator_export_type diff --git a/torch/testing/_internal/common_quantization.py b/torch/testing/_internal/common_quantization.py index 0c27032b9871b8..f7a5016bd8f41f 100644 --- a/torch/testing/_internal/common_quantization.py +++ b/torch/testing/_internal/common_quantization.py @@ -329,14 +329,6 @@ def wrapper(*args, **kwargs): fn(*args, **kwargs) return wrapper - @functools.wraps(fn) - def wrapper(*args, **kwargs): - if not torch.onnx._CAFFE2_ATEN_FALLBACK: - raise unittest.SkipTest(reason) - else: - fn(*args, **kwargs) - return wrapper - def withQNNPACKBackend(fn): # TODO(future PR): consider combining with skipIfNoQNNPACK, # will require testing of existing callsites diff --git a/torch/testing/_internal/common_utils.py b/torch/testing/_internal/common_utils.py index bb5f3fa8e33084..2d5ea4a6c64ffd 100644 --- a/torch/testing/_internal/common_utils.py +++ b/torch/testing/_internal/common_utils.py @@ -1252,8 +1252,6 @@ def TemporaryDirectoryName(suffix=None): TEST_Z3 = _check_module_exists('z3') -BUILD_WITH_CAFFE2 = torch.onnx._CAFFE2_ATEN_FALLBACK - def split_if_not_empty(x: str): return x.split(",") if len(x) != 0 else [] @@ -1886,19 +1884,6 @@ def skipIfNotRegistered(op_name, message): """ return unittest.skip("Pytorch is compiled without Caffe2") -def _decide_skip_caffe2(expect_caffe2, reason): - def skip_dec(func): - @wraps(func) - def wrapper(self): - if torch.onnx._CAFFE2_ATEN_FALLBACK != expect_caffe2: - raise unittest.SkipTest(reason) - return func(self) - return wrapper - return skip_dec - -skipIfCaffe2 = _decide_skip_caffe2(False, "Not compatible with Caffe2") -skipIfNoCaffe2 = _decide_skip_caffe2(True, "Caffe2 is not available") - def skipIfNoSciPy(fn): @wraps(fn) def wrapper(*args, **kwargs): From 1fd7496ab2e66ac116a801d9aef54915230dbe44 Mon Sep 17 00:00:00 2001 From: Jun Luo Date: Mon, 17 Jun 2024 21:58:46 +0000 Subject: [PATCH 102/171] [MTIA] Fix synchronize API (#128714) Reviewed By: fenypatel99 Differential Revision: D58590313 Pull Request resolved: https://github.com/pytorch/pytorch/pull/128714 Approved by: https://github.com/aaronenyeshi --- torch/csrc/mtia/Module.cpp | 2 +- torch/mtia/__init__.py | 5 +++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/torch/csrc/mtia/Module.cpp b/torch/csrc/mtia/Module.cpp index 84cc11f7187593..63cfae19725521 100644 --- a/torch/csrc/mtia/Module.cpp +++ b/torch/csrc/mtia/Module.cpp @@ -56,7 +56,7 @@ void initModule(PyObject* module) { return at::detail::getMTIAHooks().getCurrentStream(device_index); }); - m.def("_mtia_deviceSynchronize", [](c10::DeviceIndex device_index) { + m.def("_mtia_deviceSynchronize", []() { torch::utils::device_lazy_init(at::kMTIA); at::detail::getMTIAHooks().deviceSynchronize( at::detail::getMTIAHooks().getCurrentDevice()); diff --git a/torch/mtia/__init__.py b/torch/mtia/__init__.py index f9554a9bcb277f..1bd7d2a9b7c6f1 100644 --- a/torch/mtia/__init__.py +++ b/torch/mtia/__init__.py @@ -107,9 +107,10 @@ def is_available() -> bool: return device_count() > 0 -def synchronize() -> None: +def synchronize(device: Optional[_device_t] = None) -> None: r"""Waits for all jobs in all streams on a MTIA device to complete.""" - return torch._C._mtia_deviceSynchronize() + with torch.mtia.device(device): + return torch._C._mtia_deviceSynchronize() def device_count() -> int: From 7baf32b5e7440cb6c32b6ecf5dad0454bff39794 Mon Sep 17 00:00:00 2001 From: Shengbao Zheng Date: Mon, 17 Jun 2024 22:07:40 +0000 Subject: [PATCH 103/171] [c10d] fix p2p group commsplit (#128803) Summary: For PointToPoint(sendrecv), the deviceId is lower_rank:higher_rank. This means a p2p group cannot be created through commSplit since it cannot find a parent. Fix this by using the right device key of current rank. Differential Revision: D58631639 Pull Request resolved: https://github.com/pytorch/pytorch/pull/128803 Approved by: https://github.com/shuqiangzhang --- torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp index e7699b55245147..d293c4d470b837 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp @@ -2118,7 +2118,7 @@ std::shared_ptr ProcessGroupNCCL::getNCCLComm( // Find a valid, healthy communicator to split from if possible. std::lock_guard lock(options_->split_from->mutex_); auto& other_comms = options_->split_from->devNCCLCommMap_; - auto dit = other_comms.find(deviceKey); + auto dit = other_comms.find(getKeyFromDevice(device)); if (dit != other_comms.end()) { auto& parentComm = dit->second; if (parentComm != nullptr && !parentComm->isAborted()) { From 1835e3beab7e6e019b2a61137779297bfc3852ae Mon Sep 17 00:00:00 2001 From: Xu Zhao Date: Mon, 17 Jun 2024 22:20:33 +0000 Subject: [PATCH 104/171] Fix the inductor ci (#128879) Fix the torchbench+inductor ci on trunk due to recent upgrade to numpy 2.0.0rc1. We have to remove DALLE2_pytorch model, since it depends on embedding-reader, which is not compatible with numpy>2: https://github.com/rom1504/embedding-reader/blob/main/requirements.txt#L3 Fixes #128845 Pull Request resolved: https://github.com/pytorch/pytorch/pull/128879 Approved by: https://github.com/eellison --- .ci/pytorch/common_utils.sh | 2 +- .ci/pytorch/perf_test/test_cpu_speed_mini_sequence_labeler.sh | 2 +- .ci/pytorch/perf_test/test_gpu_speed_cudnn_lstm.sh | 2 +- .ci/pytorch/perf_test/test_gpu_speed_lstm.sh | 2 +- .ci/pytorch/perf_test/test_gpu_speed_mlstm.sh | 2 +- .github/ci_commit_pins/torchbench.txt | 2 +- benchmarks/dynamo/Makefile | 2 +- 7 files changed, 7 insertions(+), 7 deletions(-) diff --git a/.ci/pytorch/common_utils.sh b/.ci/pytorch/common_utils.sh index 2f03e8c4255e64..91c2d1b5dd3bd7 100644 --- a/.ci/pytorch/common_utils.sh +++ b/.ci/pytorch/common_utils.sh @@ -191,7 +191,7 @@ function clone_pytorch_xla() { function checkout_install_torchbench() { local commit commit=$(get_pinned_commit torchbench) - git clone https://github.com/eellison/benchmark torchbench + git clone https://github.com/pytorch/benchmark torchbench pushd torchbench git checkout "$commit" diff --git a/.ci/pytorch/perf_test/test_cpu_speed_mini_sequence_labeler.sh b/.ci/pytorch/perf_test/test_cpu_speed_mini_sequence_labeler.sh index 70c4be781e2886..72496691286e4c 100644 --- a/.ci/pytorch/perf_test/test_cpu_speed_mini_sequence_labeler.sh +++ b/.ci/pytorch/perf_test/test_cpu_speed_mini_sequence_labeler.sh @@ -9,7 +9,7 @@ test_cpu_speed_mini_sequence_labeler () { export OMP_NUM_THREADS=4 export MKL_NUM_THREADS=4 - git clone https://github.com/eellison/benchmark.git + git clone https://github.com/pytorch/benchmark.git cd benchmark/ diff --git a/.ci/pytorch/perf_test/test_gpu_speed_cudnn_lstm.sh b/.ci/pytorch/perf_test/test_gpu_speed_cudnn_lstm.sh index 9633f7dfdfae38..1693b00f17e2d1 100644 --- a/.ci/pytorch/perf_test/test_gpu_speed_cudnn_lstm.sh +++ b/.ci/pytorch/perf_test/test_gpu_speed_cudnn_lstm.sh @@ -9,7 +9,7 @@ test_gpu_speed_cudnn_lstm () { export OMP_NUM_THREADS=4 export MKL_NUM_THREADS=4 - git clone https://github.com/eellison/benchmark.git + git clone https://github.com/pytorch/benchmark.git cd benchmark/ diff --git a/.ci/pytorch/perf_test/test_gpu_speed_lstm.sh b/.ci/pytorch/perf_test/test_gpu_speed_lstm.sh index b8548f8206a9cb..2e26b9902b868f 100644 --- a/.ci/pytorch/perf_test/test_gpu_speed_lstm.sh +++ b/.ci/pytorch/perf_test/test_gpu_speed_lstm.sh @@ -9,7 +9,7 @@ test_gpu_speed_lstm () { export OMP_NUM_THREADS=4 export MKL_NUM_THREADS=4 - git clone https://github.com/eellison/benchmark.git + git clone https://github.com/pytorch/benchmark.git cd benchmark/ diff --git a/.ci/pytorch/perf_test/test_gpu_speed_mlstm.sh b/.ci/pytorch/perf_test/test_gpu_speed_mlstm.sh index e224dd27f74f4f..a0617530194a16 100644 --- a/.ci/pytorch/perf_test/test_gpu_speed_mlstm.sh +++ b/.ci/pytorch/perf_test/test_gpu_speed_mlstm.sh @@ -9,7 +9,7 @@ test_gpu_speed_mlstm () { export OMP_NUM_THREADS=4 export MKL_NUM_THREADS=4 - git clone https://github.com/eellison/benchmark.git + git clone https://github.com/pytorch/benchmark.git cd benchmark/ diff --git a/.github/ci_commit_pins/torchbench.txt b/.github/ci_commit_pins/torchbench.txt index 8779f5b61aa9ba..4a60ff3d38d408 100644 --- a/.github/ci_commit_pins/torchbench.txt +++ b/.github/ci_commit_pins/torchbench.txt @@ -1 +1 @@ -pin_yolo_dep +0dab1dd97709096e8129f8a08115ee83f64f2194 diff --git a/benchmarks/dynamo/Makefile b/benchmarks/dynamo/Makefile index dacddec4b2919c..720542f28608bd 100644 --- a/benchmarks/dynamo/Makefile +++ b/benchmarks/dynamo/Makefile @@ -10,7 +10,7 @@ clone-deps: && (test -e detectron2 || git clone --recursive https://github.com/facebookresearch/detectron2) \ && (test -e FBGEMM || git clone --recursive https://github.com/pytorch/FBGEMM) \ && (test -e torchrec || git clone --recursive https://github.com/pytorch/torchrec) \ - && (test -e torchbenchmark || git clone --recursive https://github.com/eellison/benchmark torchbenchmark) \ + && (test -e torchbenchmark || git clone --recursive https://github.com/pytorch/benchmark torchbenchmark) \ ) pull-deps: clone-deps From 3b8c9b8ab11682b958dfe002d7106d94cf75ef7a Mon Sep 17 00:00:00 2001 From: atalman Date: Mon, 17 Jun 2024 22:51:12 +0000 Subject: [PATCH 105/171] [Docker Release] Test if pytorch was compiled with CUDA before pushing to repo (#128852) Related to: https://github.com/pytorch/pytorch/issues/125879 Would check if we are compiled with CUDA before publishing CUDA Docker nightly image Test ``` #18 [conda-installs 5/5] RUN IS_CUDA=$(python -c 'import torch ; print(torch.cuda._is_compiled())'); echo "Is torch compiled with cuda: ${IS_CUDA}"; if test "${IS_CUDA}" != "True" -a ! -z "12.4.0"; then exit 1; fi #18 1.656 Is torch compiled with cuda: False #18 ERROR: process "/bin/sh -c IS_CUDA=$(python -c 'import torch ; print(torch.cuda._is_compiled())'); echo \"Is torch compiled with cuda: ${IS_CUDA}\"; if test \"${IS_CUDA}\" != \"True\" -a ! -z \"${CUDA_VERSION}\"; then \texit 1; fi" did not complete successfully: exit code: 1 ------ > [conda-installs 5/5] RUN IS_CUDA=$(python -c 'import torch ; print(torch.cuda._is_compiled())'); echo "Is torch compiled with cuda: ${IS_CUDA}"; if test "${IS_CUDA}" != "True" -a ! -z "12.4.0"; then exit 1; fi: 1.656 Is torch compiled with cuda: False ------ Dockerfile:80 -------------------- 79 | RUN /opt/conda/bin/pip install torchelastic 80 | >>> RUN IS_CUDA=$(python -c 'import torch ; print(torch.cuda._is_compiled())');\ 81 | >>> echo "Is torch compiled with cuda: ${IS_CUDA}"; \ 82 | >>> if test "${IS_CUDA}" != "True" -a ! -z "${CUDA_VERSION}"; then \ 83 | >>> exit 1; \ 84 | >>> fi 85 | -------------------- ERROR: failed to solve: process "/bin/sh -c IS_CUDA=$(python -c 'import torch ; print(torch.cuda._is_compiled())'); echo \"Is torch compiled with cuda: ${IS_CUDA}\"; if test \"${IS_CUDA}\" != \"True\" -a ! -z \"${CUDA_VERSION}\"; then \texit 1; fi" did not complete successfully: exit code: 1 (base) [ec2-user@ip-172-30-2-248 pytorch]$ docker buildx build --progress=plain --platform="linux/amd64" --target official -t ghcr.io/pytorch/pytorch:2.5.0.dev20240617-cuda12.4-cudnn9-devel --build-arg BASE_IMAGE=nvidia/cuda:12.4.0-devel-ubuntu22.04 --build-arg PYTHON_VERSION=3.11 --build-arg CUDA_VERSION= --build-arg CUDA_CHANNEL=nvidia --build-arg PYTORCH_VERSION=2.5.0.dev20240617 --build-arg INSTALL_CHANNEL=pytorch --build-arg TRITON_VERSION= --build-arg CMAKE_VARS="" . #0 building with "default" instance using docker driver ``` Please note looks like we are installing from pytorch rather then nighlty channel on PR hence cuda 12.4 is failing since its not in pytorch channel yet: https://github.com/pytorch/pytorch/actions/runs/9555354734/job/26338476741?pr=128852 Pull Request resolved: https://github.com/pytorch/pytorch/pull/128852 Approved by: https://github.com/malfet --- Dockerfile | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/Dockerfile b/Dockerfile index ae88187972ef22..b751c64a8439e9 100644 --- a/Dockerfile +++ b/Dockerfile @@ -77,6 +77,11 @@ RUN case ${TARGETPLATFORM} in \ esac && \ /opt/conda/bin/conda clean -ya RUN /opt/conda/bin/pip install torchelastic +RUN IS_CUDA=$(python -c 'import torch ; print(torch.cuda._is_compiled())'); \ + echo "Is torch compiled with cuda: ${IS_CUDA}"; \ + if test "${IS_CUDA}" != "True" -a ! -z "${CUDA_VERSION}"; then \ + exit 1; \ + fi FROM ${BASE_IMAGE} as official ARG PYTORCH_VERSION From 8415a4ba98f337e6d21a3c0b026917c03a19e955 Mon Sep 17 00:00:00 2001 From: Xiaodong Wang Date: Mon, 17 Jun 2024 22:52:25 +0000 Subject: [PATCH 106/171] Back out "[ROCm] TunableOp for gemm_and_bias (#128143)" (#128815) Summary: Original commit changeset: 35083f04fdae Original Phabricator Diff: D58501726 This PR is bringing a large numerical gap. e.g. for 256 x 4096 x 4096 GEMM, if we enable tunable op + DISABLE_ADDMM_HIP_LT=0, the results are way off. Differential Revision: D58660832 Pull Request resolved: https://github.com/pytorch/pytorch/pull/128815 Approved by: https://github.com/mxz297, https://github.com/eqy, https://github.com/malfet --- aten/src/ATen/cuda/tunable/GemmCommon.h | 76 +----------- aten/src/ATen/cuda/tunable/GemmHipblaslt.h | 133 ++++----------------- aten/src/ATen/cuda/tunable/Tunable.cpp | 4 +- aten/src/ATen/cuda/tunable/TunableGemm.h | 68 +---------- aten/src/ATen/native/cuda/Blas.cpp | 63 ++-------- 5 files changed, 38 insertions(+), 306 deletions(-) diff --git a/aten/src/ATen/cuda/tunable/GemmCommon.h b/aten/src/ATen/cuda/tunable/GemmCommon.h index 64a482bc2781bd..a2c7c734a551f3 100644 --- a/aten/src/ATen/cuda/tunable/GemmCommon.h +++ b/aten/src/ATen/cuda/tunable/GemmCommon.h @@ -81,8 +81,7 @@ struct GemmParams : OpParams { } std::string Signature() const override { - static std::string val = c10::str(transa, transb, "_", m, "_", n, "_", k); - return val; + return c10::str(transa, transb, "_", m, "_", n, "_", k); } size_t GetSize(bool duplicate_inputs) const { @@ -144,73 +143,6 @@ struct GemmParams : OpParams { bool duplicate_inputs_; }; -template -struct GemmAndBiasParams : OpParams { - std::string Signature() const override { - static std::string val = c10::str(transa, transb, "_", m, "_", n, "_", k); - return val; - } - - size_t GetSize(bool duplicate_inputs) const { - size_t size = sizeof(T) * ldc * n; - if (duplicate_inputs) { - size += sizeof(T) * lda * ((transa == 'n' || transa == 'N') ? k : m); - size += sizeof(T) * ldb * ((transb == 'n' || transb == 'N') ? n : k); - } - return size; - } - - GemmAndBiasParams* DeepCopy(bool duplicate_inputs) const { - GemmAndBiasParams* copy = new GemmAndBiasParams; - *copy = *this; - c10::DeviceIndex device = 0; - AT_CUDA_CHECK(c10::cuda::GetDevice(&device)); - size_t c_size = ldc * n * sizeof(T); - copy->c = static_cast(c10::cuda::CUDACachingAllocator::raw_alloc(c_size)); - AT_CUDA_CHECK(c10::cuda::CUDACachingAllocator::memcpyAsync( - copy->c, device, c, device, c_size, getCurrentCUDAStream(device), true)); - if (duplicate_inputs) { - size_t a_size = sizeof(T) * lda * ((transa == 'n' || transa == 'N') ? k : m); - size_t b_size = sizeof(T) * ldb * ((transb == 'n' || transb == 'N') ? n : k); - copy->a = static_cast(c10::cuda::CUDACachingAllocator::raw_alloc(a_size)); - copy->b = static_cast(c10::cuda::CUDACachingAllocator::raw_alloc(b_size)); - copy->duplicate_inputs_ = true; - } - return copy; - } - - // only call on object returned by DeepCopy - void Delete() { - c10::cuda::CUDACachingAllocator::raw_delete(c); - if (duplicate_inputs_) { - c10::cuda::CUDACachingAllocator::raw_delete(const_cast(a)); - c10::cuda::CUDACachingAllocator::raw_delete(const_cast(b)); - } - } - - TuningStatus NumericalCheck(GemmAndBiasParams *other) { - auto c_dtype = c10::CppTypeToScalarType::value; - return detail::NumericalCheck(c_dtype, c, other->c, ldc*n) ? OK : FAIL; - } - - char transa; - char transb; - int64_t m; - int64_t n; - int64_t k; - at::opmath_type alpha; - const T* a; - int64_t lda; - const T* b; - int64_t ldb; - T* c; - int64_t ldc; - const T* bias; - at::cuda::blas::GEMMAndBiasActivationEpilogue activation; -private: - bool duplicate_inputs_; -}; - template struct GemmStridedBatchedParams : OpParams { GemmStridedBatchedParams() { @@ -218,8 +150,7 @@ struct GemmStridedBatchedParams : OpParams { } std::string Signature() const override { - static std::string val = c10::str(transa, transb, "_", m, "_", n, "_", k, "_B_", batch); - return val; + return c10::str(transa, transb, "_", m, "_", n, "_", k, "_B_", batch); } size_t GetSize(bool duplicate_inputs) const { @@ -292,8 +223,7 @@ struct ScaledGemmParams : OpParams { } std::string Signature() const override { - static std::string val = c10::str(transa, transb, "_", m, "_", n, "_", k); - return val; + return c10::str(transa, transb, "_", m, "_", n, "_", k); } size_t GetSize(bool duplicate_inputs) const { diff --git a/aten/src/ATen/cuda/tunable/GemmHipblaslt.h b/aten/src/ATen/cuda/tunable/GemmHipblaslt.h index ab1525bef65229..a9c420700275e4 100644 --- a/aten/src/ATen/cuda/tunable/GemmHipblaslt.h +++ b/aten/src/ATen/cuda/tunable/GemmHipblaslt.h @@ -25,35 +25,35 @@ namespace at::cuda::tunable { template -constexpr hipblasDatatype_t HipDataTypeFor(); +constexpr hipblasDatatype_t HipBlasDataTypeFor(); template <> -constexpr hipblasDatatype_t HipDataTypeFor() { - return HIP_R_32F; +constexpr hipblasDatatype_t HipBlasDataTypeFor() { + return HIPBLAS_R_32F; } template <> -constexpr hipblasDatatype_t HipDataTypeFor() { - return HIP_R_16F; +constexpr hipblasDatatype_t HipBlasDataTypeFor() { + return HIPBLAS_R_16F; } template <> -constexpr hipblasDatatype_t HipDataTypeFor() { - return HIP_R_16BF; +constexpr hipblasDatatype_t HipBlasDataTypeFor() { + return HIPBLAS_R_16B; } template <> -constexpr hipblasDatatype_t HipDataTypeFor() { - return HIP_R_64F; +constexpr hipblasDatatype_t HipBlasDataTypeFor() { + return HIPBLAS_R_64F; } template <> -constexpr hipblasDatatype_t HipDataTypeFor() { +constexpr hipblasDatatype_t HipBlasDataTypeFor() { return HIP_R_8F_E4M3_FNUZ; } template <> -constexpr hipblasDatatype_t HipDataTypeFor() { +constexpr hipblasDatatype_t HipBlasDataTypeFor() { return HIP_R_8F_E5M2_FNUZ; } @@ -62,11 +62,6 @@ int GetBatchFromParams(const GemmParams* params) { return 1; } -template -int GetBatchFromParams(const GemmAndBiasParams* params) { - return 1; -} - template int GetBatchFromParams(const GemmStridedBatchedParams* params) { return params->batch; @@ -82,11 +77,6 @@ int GetStrideAFromParams(const GemmParams* params) { return 1; } -template -int GetStrideAFromParams(const GemmAndBiasParams* params) { - return 1; -} - template int GetStrideAFromParams(const GemmStridedBatchedParams* params) { return params->stride_a; @@ -102,11 +92,6 @@ int GetStrideBFromParams(const GemmParams* params) { return 1; } -template -int GetStrideBFromParams(const GemmAndBiasParams* params) { - return 1; -} - template int GetStrideBFromParams(const GemmStridedBatchedParams* params) { return params->stride_b; @@ -122,11 +107,6 @@ int GetStrideCFromParams(const GemmParams* params) { return 1; } -template -int GetStrideCFromParams(const GemmAndBiasParams* params) { - return 1; -} - template int GetStrideCFromParams(const GemmStridedBatchedParams* params) { return params->stride_c; @@ -142,11 +122,6 @@ float GetAlphaFromParams(const GemmParams* params) { return params->alpha; } -template -float GetAlphaFromParams(const GemmAndBiasParams* params) { - return params->alpha; -} - template float GetAlphaFromParams(const GemmStridedBatchedParams* params) { return params->alpha; @@ -162,11 +137,6 @@ float GetBetaFromParams(const GemmParams* params) { return params->beta; } -template -float GetBetaFromParams(const GemmAndBiasParams* params) { - return 0.0; -} - template float GetBetaFromParams(const GemmStridedBatchedParams* params) { return params->beta; @@ -182,11 +152,6 @@ const void* GetAScalePointerFromParams(const GemmParams* params) { return nullptr; } -template -const void* GetAScalePointerFromParams(const GemmAndBiasParams* params) { - return nullptr; -} - template const void* GetAScalePointerFromParams(const GemmStridedBatchedParams* params) { return nullptr; @@ -202,11 +167,6 @@ const void* GetBScalePointerFromParams(const GemmParams* params) { return nullptr; } -template -const void* GetBScalePointerFromParams(const GemmAndBiasParams* params) { - return nullptr; -} - template const void* GetBScalePointerFromParams(const GemmStridedBatchedParams* params) { return nullptr; @@ -222,11 +182,6 @@ const void* GetDScalePointerFromParams(const GemmParams* params) { return nullptr; } -template -const void* GetDScalePointerFromParams(const GemmAndBiasParams* params) { - return nullptr; -} - template const void* GetDScalePointerFromParams(const GemmStridedBatchedParams* params) { return nullptr; @@ -242,11 +197,6 @@ const void* GetBiasPointerFromParams(const GemmParams* params) { return nullptr; } -template -const void* GetBiasPointerFromParams(const GemmAndBiasParams* params) { - return params->bias; -} - template const void* GetBiasPointerFromParams(const GemmStridedBatchedParams* params) { return nullptr; @@ -262,11 +212,6 @@ hipDataType GetBiasTypeFromParams(const GemmParams* params) { return HIP_R_32F; } -template -hipDataType GetBiasTypeFromParams(const GemmAndBiasParams* params) { - return HipDataTypeFor(); -} - template hipDataType GetBiasTypeFromParams(const GemmStridedBatchedParams* params) { return HIP_R_32F; @@ -277,26 +222,6 @@ hipDataType GetBiasTypeFromParams(const ScaledGemmParams* params) { return at::cuda::ScalarTypeToCudaDataType(params->bias_dtype); } -template -at::cuda::blas::GEMMAndBiasActivationEpilogue GetActivationFromParams(const GemmParams* params) { - return at::cuda::blas::GEMMAndBiasActivationEpilogue::None; -} - -template -at::cuda::blas::GEMMAndBiasActivationEpilogue GetActivationFromParams(const GemmAndBiasParams* params) { - return params->activation; -} - -template -at::cuda::blas::GEMMAndBiasActivationEpilogue GetActivationFromParams(const GemmStridedBatchedParams* params) { - return at::cuda::blas::GEMMAndBiasActivationEpilogue::None; -} - -template -at::cuda::blas::GEMMAndBiasActivationEpilogue GetActivationFromParams(const ScaledGemmParams* params) { - return at::cuda::blas::GEMMAndBiasActivationEpilogue::None; -} - static hipblasOperation_t _hipblasOpFromChar(char op) { switch (op) { case 'n': @@ -402,9 +327,9 @@ class HipblasltGemmOp : public Callable { TuningStatus Call(const ParamsT* params) override { hipblasOperation_t transa_outer = MapLayoutToHipBlasLt(ALayout); hipblasOperation_t transb_outer = MapLayoutToHipBlasLt(BLayout); - auto a_datatype = HipDataTypeFor(); - auto b_datatype = HipDataTypeFor(); - auto in_out_datatype = HipDataTypeFor(); + auto a_datatype = HipBlasDataTypeFor(); + auto b_datatype = HipBlasDataTypeFor(); + auto in_out_datatype = HipBlasDataTypeFor(); auto opa = _hipblasOpFromChar(params->transa); auto opb = _hipblasOpFromChar(params->transb); @@ -460,22 +385,13 @@ class HipblasltGemmOp : public Callable { matmul.setAttribute(HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER, mat1_scale_ptr); matmul.setAttribute(HIPBLASLT_MATMUL_DESC_B_SCALE_POINTER, mat2_scale_ptr); matmul.setAttribute(HIPBLASLT_MATMUL_DESC_D_SCALE_POINTER, result_scale_ptr); - } - const void* bias_ptr = GetBiasPointerFromParams(params); - auto bias_datatype = GetBiasTypeFromParams(params); - if (bias_ptr) { - matmul.setAttribute(HIPBLASLT_MATMUL_DESC_BIAS_POINTER, bias_ptr); - matmul.setAttribute(HIPBLASLT_MATMUL_DESC_BIAS_DATA_TYPE, bias_datatype); - auto activation = GetActivationFromParams(params); - if (activation == at::cuda::blas::GEMMAndBiasActivationEpilogue::RELU) { - matmul.setAttribute(HIPBLASLT_MATMUL_DESC_EPILOGUE, HIPBLASLT_EPILOGUE_RELU_BIAS); - } - else if (activation == at::cuda::blas::GEMMAndBiasActivationEpilogue::GELU) { - matmul.setAttribute(HIPBLASLT_MATMUL_DESC_EPILOGUE, HIPBLASLT_EPILOGUE_GELU_BIAS); - } - else { + const void* bias_ptr = GetBiasPointerFromParams(params); + auto bias_datatype = GetBiasTypeFromParams(params); + if (bias_ptr) { + matmul.setAttribute(HIPBLASLT_MATMUL_DESC_BIAS_POINTER, bias_ptr); matmul.setAttribute(HIPBLASLT_MATMUL_DESC_EPILOGUE, HIPBLASLT_EPILOGUE_BIAS); + matmul.setAttribute(HIPBLASLT_MATMUL_DESC_BIAS_DATA_TYPE, bias_datatype); } } @@ -544,9 +460,9 @@ template (); - auto b_datatype = HipDataTypeFor(); - auto in_out_datatype = HipDataTypeFor(); + auto a_datatype = HipBlasDataTypeFor(); + auto b_datatype = HipBlasDataTypeFor(); + auto in_out_datatype = HipBlasDataTypeFor(); std::vector heuristic_result; hipblasLtHandle_t handle; @@ -589,11 +505,6 @@ auto GetHipBlasLtGemmTypeStringAndOps() { return GetHipBlasLtTypeStringAndOps>(); } -template -auto GetHipBlasLtGemmAndBiasTypeStringAndOps() { - return GetHipBlasLtTypeStringAndOps>(); -} - template auto GetHipBlasLtGemmStridedBatchedTypeStringAndOps() { return GetHipBlasLtTypeStringAndOps>(); diff --git a/aten/src/ATen/cuda/tunable/Tunable.cpp b/aten/src/ATen/cuda/tunable/Tunable.cpp index d3d2333323e7f1..fc27fab77d7907 100644 --- a/aten/src/ATen/cuda/tunable/Tunable.cpp +++ b/aten/src/ATen/cuda/tunable/Tunable.cpp @@ -376,8 +376,8 @@ void TuningContext::EnableNumericsCheck(bool value) { bool TuningContext::IsNumericsCheckEnabled() const { static const char *env = getenv("PYTORCH_TUNABLEOP_NUMERICAL_CHECK"); - if (env != nullptr && strcmp(env, "1") == 0) { - return true; + if (env != nullptr && strcmp(env, "0") == 0) { + return false; } return numerics_check_enable_; } diff --git a/aten/src/ATen/cuda/tunable/TunableGemm.h b/aten/src/ATen/cuda/tunable/TunableGemm.h index 6b02e26ade4d77..53e6154120c92f 100644 --- a/aten/src/ATen/cuda/tunable/TunableGemm.h +++ b/aten/src/ATen/cuda/tunable/TunableGemm.h @@ -48,28 +48,6 @@ class DefaultGemmOp : public Callable> { } }; -static bool _transposeBoolFromChar(char op) { - return op == 't' || op == 'T'; -} - -template -class DefaultGemmAndBiasOp : public Callable> { - public: - TuningStatus Call(const GemmAndBiasParams* params) override { - at::cuda::blas::gemm_and_bias( - _transposeBoolFromChar(params->transa), - _transposeBoolFromChar(params->transb), - params->m, params->n, params->k, - params->alpha, - params->a, params->lda, - params->b, params->ldb, - params->bias, - params->c, params->ldc, - params->activation); - return OK; - } -}; - template class DefaultGemmStridedBatchedOp : public Callable> { public: @@ -287,45 +265,7 @@ class GemmTunableOp : public TunableOp, StreamTimer> { } std::string Signature() override { - static std::string val = c10::str("GemmTunableOp_", TypeName(T{}), "_", BlasOpToString(ALayout), BlasOpToString(BLayout)); - return val; - } -}; - -template -class GemmAndBiasTunableOp : public TunableOp, StreamTimer> { - public: - GemmAndBiasTunableOp() { - this->RegisterOp(std::string("Default"), std::make_unique>()); - - auto validators = getTuningContext()->GetTuningResultsValidator().GetAllValidators(); - -#if defined(USE_ROCM) - bool rocm_validators = false; - - static const char *env_hipblaslt = std::getenv("PYTORCH_TUNABLEOP_HIPBLASLT_ENABLED"); - if (env_hipblaslt == nullptr || strcmp(env_hipblaslt, "1") == 0) { - rocm_validators = true; - // disallow tuning of hipblaslt with c10::complex - if constexpr ( - !std::is_same_v> && - !std::is_same_v>) { - for (auto&& [name, op] : GetHipBlasLtGemmAndBiasTypeStringAndOps()) { - this->RegisterOp(std::move(name), std::move(op)); - } - } - AddHipblasltValidator(); - } - - if (rocm_validators) { - AddRocmValidator(); - } -#endif - } - - std::string Signature() override { - static std::string val = c10::str("GemmAndBiasTunableOp_", TypeName(T{}), "_", BlasOpToString(ALayout), BlasOpToString(BLayout)); - return val; + return c10::str("GemmTunableOp_", TypeName(T{}), "_", BlasOpToString(ALayout), BlasOpToString(BLayout)); } }; @@ -368,8 +308,7 @@ class GemmStridedBatchedTunableOp : public TunableOp } std::string Signature() override { - static std::string val = c10::str("GemmStridedBatchedTunableOp_", TypeName(T{}), "_", BlasOpToString(ALayout), BlasOpToString(BLayout)); - return val; + return c10::str("GemmStridedBatchedTunableOp_", TypeName(T{}), "_", BlasOpToString(ALayout), BlasOpToString(BLayout)); } }; @@ -391,12 +330,11 @@ class ScaledGemmTunableOp : public TunableOp, StreamTimer> } std::string Signature() override { - static std::string val = c10::str("ScaledGemmTunableOp", + return c10::str("ScaledGemmTunableOp", "_", TypeName(AT{}), "_", TypeName(BT{}), "_", TypeName(CT{}), "_", BlasOpToString(ALayout), BlasOpToString(BLayout)); - return val; } }; diff --git a/aten/src/ATen/native/cuda/Blas.cpp b/aten/src/ATen/native/cuda/Blas.cpp index 728f210b66ed01..ff8eb60b290ba2 100644 --- a/aten/src/ATen/native/cuda/Blas.cpp +++ b/aten/src/ATen/native/cuda/Blas.cpp @@ -175,6 +175,12 @@ cuda::blas::GEMMAndBiasActivationEpilogue activation_to_gemm_and_blas_arg(Activa static bool getDisableAddmmCudaLt() { static const char* env_value = std::getenv("DISABLE_ADDMM_CUDA_LT"); #ifdef USE_ROCM + // if we enable tunable op, it'll take priority over just hipblaslt (heuristics) + // note the current tunable op is not the hipblaslt path (gemm_and_bias) + auto tuning_ctx = at::cuda::tunable::getTuningContext(); + if (tuning_ctx->IsTunableOpEnabled()) { + return true; + } // allow both CUDA and HIP env var names for ROCm builds // also, current default for ROCm builds is disable by default if (env_value == nullptr) { @@ -208,49 +214,6 @@ static bool isSupportedHipLtROCmArch(int index) { } #endif -template -static void launchTunableGemmAndBias(cublasCommonArgs &args, Tensor& result, const Tensor& self, bool is_rocm) { - bool transa_ = ((args.transa != 'n') && (args.transa != 'N')); - bool transb_ = ((args.transb != 'n') && (args.transb != 'N')); - at::cuda::tunable::GemmAndBiasParams params; - params.transa = args.transa; - params.transb = args.transb; - params.m = args.m; - params.n = args.n; - params.k = args.k; - params.a = args.mata->const_data_ptr(); - params.lda = args.lda; - params.b = args.matb->const_data_ptr(); - params.ldb = args.ldb; - if (is_rocm) { - params.bias = (&result != &self) ? self.const_data_ptr() : nullptr; - } - else { - params.bias = self.const_data_ptr(); - } - params.c = args.result->data_ptr(); - params.ldc = args.result_ld; - if (transa_ && transb_) { - static at::cuda::tunable::GemmAndBiasTunableOp gemm{}; - gemm(¶ms); - } - else if (transa_ && !transb_) { - static at::cuda::tunable::GemmAndBiasTunableOp gemm{}; - gemm(¶ms); - } - else if (!transa_ && transb_) { - static at::cuda::tunable::GemmAndBiasTunableOp gemm{}; - gemm(¶ms); - } - else if (!transa_ && !transb_) { - static at::cuda::tunable::GemmAndBiasTunableOp gemm{}; - gemm(¶ms); - } - else { - TORCH_CHECK(false, "unreachable"); - } -} - Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& mat1, const Tensor& mat2, const Scalar& beta, const Scalar& alpha, Activation activation=Activation::None) { // Make sure to keep addmm_cuda below in sync with this code; it // preflights a check to try to avoid actually needing to call @@ -378,11 +341,6 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma scalar_type, "addmm_cuda_lt", [&] { - auto tuning_ctx = at::cuda::tunable::getTuningContext(); - if (tuning_ctx->IsTunableOpEnabled()) { - launchTunableGemmAndBias(args, result, self, true); - } - else { at::cuda::blas::gemm_and_bias( args.transa == 't', args.transb == 't', @@ -401,7 +359,7 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma args.result_ld, activation_to_gemm_and_blas_arg(activation) ); - }}); + }); #else auto activation_epilogue = activation_to_gemm_and_blas_arg(activation); #if (defined(CUDA_VERSION) && (CUDA_VERSION < 11080)) @@ -419,11 +377,6 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma scalar_type, "addmm_cuda_lt", [&] { - auto tuning_ctx = at::cuda::tunable::getTuningContext(); - if (tuning_ctx->IsTunableOpEnabled()) { - launchTunableGemmAndBias(args, result, self, false); - } - else { at::cuda::blas::gemm_and_bias( args.transa == 't', args.transb == 't', @@ -440,7 +393,7 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma args.result_ld, activation_epilogue ); - }}); + }); #endif } else { From 95b5ea9cdef67d211ec2b1e7242100c7e2fad52a Mon Sep 17 00:00:00 2001 From: "Edward Z. Yang" Date: Thu, 13 Jun 2024 12:25:02 -0700 Subject: [PATCH 107/171] Add mark_unbacked (#128638) Signed-off-by: Edward Z. Yang Pull Request resolved: https://github.com/pytorch/pytorch/pull/128638 Approved by: https://github.com/IvanKobzarev --- torch/_dynamo/decorators.py | 24 ++++++++++++++++++++++++ torch/_dynamo/variables/builder.py | 5 ++++- torch/fx/experimental/symbolic_shapes.py | 8 ++++++++ 3 files changed, 36 insertions(+), 1 deletion(-) diff --git a/torch/_dynamo/decorators.py b/torch/_dynamo/decorators.py index ec25d06281fc0e..79bbb493865c87 100644 --- a/torch/_dynamo/decorators.py +++ b/torch/_dynamo/decorators.py @@ -184,6 +184,30 @@ class directly; instead, use :func:`mark_dynamic`. max: int +@forbid_in_graph +def mark_unbacked(t, index): + """ + Mark a tensor as having an unbacked dim. This changes the semantics of operations, + we will always report the size does not equal zero/one, we will turn asserts + on this index into runtime asserts, and if you try to get the real value we will + raise an exception. In other words, we will treat this dimension as if it was + data dependent (we do not know anything about its value.) + """ + # You could have copied the mark_dynamic behavior but I'm not convinced + # it's what you want + assert not is_traceable_wrapper_subclass(t), "not implemented yet" + + if isinstance(index, int): + if not hasattr(t, "_dynamo_unbacked_indices"): + t._dynamo_unbacked_indices = set() + t._dynamo_unbacked_indices.add(index) + return + + assert isinstance(index, (list, tuple)) + for i in index: + mark_unbacked(t, i) + + @forbid_in_graph def mark_dynamic(t, index, *, min=None, max=None): """ diff --git a/torch/_dynamo/variables/builder.py b/torch/_dynamo/variables/builder.py index f36f53b6537aa0..2097690b88b036 100644 --- a/torch/_dynamo/variables/builder.py +++ b/torch/_dynamo/variables/builder.py @@ -2211,6 +2211,7 @@ def update_dim2constraint(dim, constraint_range, debug_name): constraint_dims = [] for i in range(e.dim()): # NB: mark dynamic has precedence over static + marked_unbacked = i in getattr(e, "_dynamo_unbacked_indices", set()) marked_dynamic = i in getattr(e, "_dynamo_dynamic_indices", set()) marked_weak_dynamic = i in getattr(e, "_dynamo_weak_dynamic_indices", set()) marked_static = i in getattr(e, "_dynamo_static_indices", set()) @@ -2262,7 +2263,9 @@ def update_dim2constraint(dim, constraint_range, debug_name): constraint_dims.append(constraint_dim) # Now, figure out if the dim is dynamic/duck/static - if ( + if marked_unbacked: + dynamic = DimDynamic.SIZE_LIKE_UNBACKED + elif ( constraint_dim is not None or marked_dynamic or marked_weak_dynamic diff --git a/torch/fx/experimental/symbolic_shapes.py b/torch/fx/experimental/symbolic_shapes.py index fcfe7d9667daf0..29948534084650 100644 --- a/torch/fx/experimental/symbolic_shapes.py +++ b/torch/fx/experimental/symbolic_shapes.py @@ -1017,6 +1017,8 @@ class DimDynamic(Enum): DUCK = 1 # Treat the dimension statically based on its hint STATIC = 2 + # Treat the dimension as a size-like unbacked + SIZE_LIKE_UNBACKED = 3 # NB: These constraints affect both clients and backends: given some @@ -3433,6 +3435,12 @@ def create_symbol( ) -> "sympy.Expr": """Create a new symbol which is tracked by this ShapeEnv """ + if dynamic_dim is DimDynamic.SIZE_LIKE_UNBACKED: + r = self.create_unbacked_symint().node.expr + self._constrain_range_for_size(r) + # TODO: maybe put the hint somewhere + return r + # check if constraint_dim is actually static integer if isinstance(constraint_dim, StrictMinMaxConstraint) and constraint_dim.vr.lower == constraint_dim.vr.upper: dynamic_dim = DimDynamic.STATIC From b70440f0a7ff031decaf994c15474148007b5aa5 Mon Sep 17 00:00:00 2001 From: awayzjj Date: Mon, 17 Jun 2024 23:42:40 +0000 Subject: [PATCH 108/171] Document the torch.cuda.profiler.profile function (#128216) Fixes https://github.com/pytorch/pytorch/issues/127901 Pull Request resolved: https://github.com/pytorch/pytorch/pull/128216 Approved by: https://github.com/malfet, https://github.com/eqy --- torch/cuda/profiler.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/torch/cuda/profiler.py b/torch/cuda/profiler.py index f95aae0f85a7d0..65269414f55a3a 100644 --- a/torch/cuda/profiler.py +++ b/torch/cuda/profiler.py @@ -65,6 +65,18 @@ def stop(): @contextlib.contextmanager def profile(): + """ + Enable profiling. + + Context Manager to enabling profile collection by the active profiling tool from CUDA backend. + Example: + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA) + >>> import torch + >>> model = torch.nn.Linear(20, 30).cuda() + >>> inputs = torch.randn(128, 20).cuda() + >>> with torch.cuda.profiler.profile() as prof: + ... model(inputs) + """ try: start() yield From 11ff5345d249c27950a06a347cc70aa0047dd46e Mon Sep 17 00:00:00 2001 From: chilli Date: Mon, 17 Jun 2024 12:27:40 -0700 Subject: [PATCH 109/171] Changed colored logging to only be turned on if printing to interactive terminal (#128874) Pull Request resolved: https://github.com/pytorch/pytorch/pull/128874 Approved by: https://github.com/anijain2305 --- test/dynamo/test_misc.py | 10 +--------- torch/fx/_utils.py | 4 ++++ 2 files changed, 5 insertions(+), 9 deletions(-) diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py index c47552fc1b2a76..128b1fbbe4ecc4 100644 --- a/test/dynamo/test_misc.py +++ b/test/dynamo/test_misc.py @@ -49,12 +49,7 @@ unsupported, xfailIfPy312, ) -from torch._dynamo.utils import ( - CompileProfiler, - counters, - ifdynstaticdefault, - strip_color_from_string, -) +from torch._dynamo.utils import CompileProfiler, counters, ifdynstaticdefault from torch._inductor.utils import run_and_get_code from torch.ao.quantization import MinMaxObserver from torch.ao.quantization.fake_quantize import FakeQuantize @@ -748,7 +743,6 @@ def f(x, y, z, n): post_grad_graphs = "\n".join( log_stream.getvalue().strip().split("\n")[3:] ).strip() - post_grad_graphs = strip_color_from_string(post_grad_graphs) # Check the graph under static shapes if torch._dynamo.config.assume_static_by_default: @@ -811,7 +805,6 @@ def f(x, y, z, n): post_grad_graphs = "\n".join( log_stream.getvalue().strip().split("\n")[3:] ).strip() - post_grad_graphs = strip_color_from_string(post_grad_graphs) self.assertExpectedInline( post_grad_graphs, """\ @@ -904,7 +897,6 @@ def f(x, y, z, n): post_grad_graphs = "\n".join( log_stream.getvalue().strip().split("\n")[3:] ).strip() - post_grad_graphs = strip_color_from_string(post_grad_graphs) self.assertExpectedInline( post_grad_graphs, """\ diff --git a/torch/fx/_utils.py b/torch/fx/_utils.py index 36c831dfdee069..b27e1df5539183 100644 --- a/torch/fx/_utils.py +++ b/torch/fx/_utils.py @@ -1,4 +1,5 @@ # mypy: allow-untyped-defs +import sys from typing import Dict, Optional import torch @@ -20,6 +21,9 @@ def format_name(): if "print_output" not in kwargs: kwargs["print_output"] = False + if "colored" in kwargs and not sys.stdout.isatty(): + kwargs["colored"] = False + return LazyString( lambda: _format_graph_code( f"===== {format_name()} =====\n", From beb29836cd1e5b30df8c5a3c1122c926ef4021bc Mon Sep 17 00:00:00 2001 From: leslie-fang-intel Date: Sat, 15 Jun 2024 17:15:38 -0700 Subject: [PATCH 110/171] [Inductor][CPP] Add Min/Max with VecMask (#126841) **Summary** Fix issue: https://github.com/pytorch/pytorch/issues/126824 which is missing the support of `min/max` with `VecMask`. **TestPlan** ``` python test/inductor/test_torchinductor_opinfo.py -k test_comprehensive_clamp_max_cpu_bool python test/inductor/test_torchinductor_opinfo.py -k test_comprehensive_clamp_min_cpu_bool ``` Co-authored-by: Isuru Fernando Pull Request resolved: https://github.com/pytorch/pytorch/pull/126841 Approved by: https://github.com/isuruf, https://github.com/jgong5, https://github.com/peterbell10 --- test/inductor/test_torchinductor_opinfo.py | 2 -- torch/_inductor/codegen/cpp.py | 31 +++++++++++++++++----- torch/_inductor/codegen/cpp_utils.py | 26 +++++++++++++++--- 3 files changed, 47 insertions(+), 12 deletions(-) diff --git a/test/inductor/test_torchinductor_opinfo.py b/test/inductor/test_torchinductor_opinfo.py index 7998a3aff58d6a..8c85e731e98c40 100644 --- a/test/inductor/test_torchinductor_opinfo.py +++ b/test/inductor/test_torchinductor_opinfo.py @@ -413,8 +413,6 @@ def wrapper_noop_set_seed(op, *args, **kwargs): "addmv": {f16}, "argsort": {b8, f16, f32, f64, i32, i64}, "as_strided.partial_views": {f16}, - "clamp_max": {b8}, - "clamp_min": {b8}, "corrcoef": {f16}, "diff": {f16}, "einsum": {f16, i32}, diff --git a/torch/_inductor/codegen/cpp.py b/torch/_inductor/codegen/cpp.py index 6b8574b9268ad7..0b6dca76526519 100644 --- a/torch/_inductor/codegen/cpp.py +++ b/torch/_inductor/codegen/cpp.py @@ -64,7 +64,14 @@ OptimizationContext, ) -from .cpp_utils import cexpr, cexpr_index, DTYPE_TO_CPP, INDEX_TYPE, value_to_cpp +from .cpp_utils import ( + cexpr, + cexpr_index, + DTYPE_TO_CPP, + INDEX_TYPE, + unify_mask_base_type, + value_to_cpp, +) schedule_log = torch._logging.getArtifactLogger(__name__, "schedule") @@ -1311,11 +1318,21 @@ def truncdiv(a, b): @staticmethod def minimum(a, b): - return f"at::vec::minimum({a}, {b})" + if a.dtype == torch.bool: + assert b.dtype == torch.bool + a_cast, b_cast = unify_mask_base_type(V.kernel.compute, (a, b)) + return f"{a_cast} & {b_cast}" + else: + return f"at::vec::minimum({a}, {b})" @staticmethod def maximum(a, b): - return f"at::vec::maximum({a}, {b})" + if a.dtype == torch.bool: + assert b.dtype == torch.bool + a_cast, b_cast = unify_mask_base_type(V.kernel.compute, (a, b)) + return f"{a_cast} | {b_cast}" + else: + return f"at::vec::maximum({a}, {b})" @staticmethod def square(a): @@ -1326,10 +1343,10 @@ def where(a, b, c): assert isinstance(V.kernel, CppVecKernel) if b.dtype == torch.bool: assert c.dtype == torch.bool - blendv_a = f"{V.kernel._get_mask_cast(a, torch.float)}" - blendv_b = f"{V.kernel._get_mask_cast(b, torch.float)}" - blendv_c = f"{V.kernel._get_mask_cast(c, torch.float)}" - return f"decltype({b})::blendv({blendv_c}, {blendv_b}, {blendv_a})" + blendv_a, blendv_b, blendv_c = unify_mask_base_type( + V.kernel.compute, (a, b, c) + ) + return f"decltype({blendv_b})::blendv({blendv_c}, {blendv_b}, {blendv_a})" else: return f"decltype({b})::blendv({c}, {b}, {V.kernel._get_mask_cast(a, b.dtype)})" diff --git a/torch/_inductor/codegen/cpp_utils.py b/torch/_inductor/codegen/cpp_utils.py index 336837328a0e5a..66f2dfb54aac09 100644 --- a/torch/_inductor/codegen/cpp_utils.py +++ b/torch/_inductor/codegen/cpp_utils.py @@ -4,7 +4,7 @@ import math from collections import namedtuple -from typing import Dict, List +from typing import Dict, List, Tuple from unittest.mock import patch import sympy @@ -12,10 +12,11 @@ import torch from torch.utils._sympy.symbol import symbol_is_type, SymT from .. import ir -from ..utils import sympy_index_symbol_with_prefix +from ..utils import IndentedBuffer, sympy_index_symbol_with_prefix from ..virtualized import V -from .common import ExprPrinter, Kernel +from .common import CSEVariable, ExprPrinter, Kernel + DTYPE_TO_CPP = { torch.float32: "float", @@ -421,3 +422,22 @@ def inner(index): return inner return [wrap_inner_fn_for_node(node, inner_fn_wrapper) for node in nodes] + + +def unify_mask_base_type( + buffer: IndentedBuffer, + vars: Tuple[CSEVariable, ...], + dtype=torch.float, +): + """ + Given list of cse variables, + Cast each to new mask base dtype and return casted cse variable. + """ + new_vars = ( + V.kernel.cse.generate( + buffer, + f"{V.kernel._get_mask_cast(var, dtype)}", + ) + for var in vars + ) + return new_vars From c35ffaf954ffdfc76aac24e9c503fb0e5d190722 Mon Sep 17 00:00:00 2001 From: leslie-fang-intel Date: Sat, 15 Jun 2024 17:15:38 -0700 Subject: [PATCH 111/171] [Inductor][CPP] Add ne with VecMask (#126940) **Summary** Fix https://github.com/pytorch/pytorch/issues/126824#issuecomment-2125039161 which is missing the support of `ne` with `VecMask`. **Test Plan** ``` python test/inductor/test_torchinductor_opinfo.py -k test_comprehensive_ne_cpu_bool ``` Co-authored-by: Isuru Fernando Pull Request resolved: https://github.com/pytorch/pytorch/pull/126940 Approved by: https://github.com/isuruf, https://github.com/jgong5, https://github.com/peterbell10 ghstack dependencies: #126841 --- aten/src/ATen/cpu/vec/vec_mask.h | 1 + test/inductor/test_torchinductor_opinfo.py | 1 - torch/_inductor/codegen/cpp.py | 9 +++++++-- 3 files changed, 8 insertions(+), 3 deletions(-) diff --git a/aten/src/ATen/cpu/vec/vec_mask.h b/aten/src/ATen/cpu/vec/vec_mask.h index 6b773c40ca8c9c..ebec8d4a3e3c5d 100644 --- a/aten/src/ATen/cpu/vec/vec_mask.h +++ b/aten/src/ATen/cpu/vec/vec_mask.h @@ -259,6 +259,7 @@ VEC_MASK_DEFINE_BINARY_OP_WITH_EXPR_GLOBAL(operator<, ~a& b) VEC_MASK_DEFINE_BINARY_OP_WITH_EXPR_GLOBAL(operator==, ~(a ^ b)) VEC_MASK_DEFINE_BINARY_OP_WITH_EXPR_GLOBAL(operator>=, (a == b) | (a > b)) VEC_MASK_DEFINE_BINARY_OP_WITH_EXPR_GLOBAL(operator<=, (a == b) | (a < b)) +VEC_MASK_DEFINE_BINARY_OP_WITH_EXPR_GLOBAL(operator!=, (a ^ b)) #undef VEC_MASK_DEFINE_UNARY_OP_GLOBAL #undef VEC_MASK_DEFINE_BINARY_OP_GLOBAL diff --git a/test/inductor/test_torchinductor_opinfo.py b/test/inductor/test_torchinductor_opinfo.py index 8c85e731e98c40..5f97c2f0fd7121 100644 --- a/test/inductor/test_torchinductor_opinfo.py +++ b/test/inductor/test_torchinductor_opinfo.py @@ -431,7 +431,6 @@ def wrapper_noop_set_seed(op, *args, **kwargs): "maximum": {b8}, "min.binary": {b8}, "minimum": {b8}, - "ne": {b8}, "new_empty_strided": {f16}, "nn.functional.adaptive_avg_pool3d": {f16}, "nn.functional.adaptive_max_pool1d": {f16, f32}, diff --git a/torch/_inductor/codegen/cpp.py b/torch/_inductor/codegen/cpp.py index 0b6dca76526519..2c800b41adff50 100644 --- a/torch/_inductor/codegen/cpp.py +++ b/torch/_inductor/codegen/cpp.py @@ -1114,8 +1114,13 @@ def eq(x, y): def ne(x, y): assert isinstance(V.kernel, CppVecKernel) assert isinstance(x, CppCSEVariable) - assert x.dtype is not None - return f"{V.kernel._get_mask_type(x.dtype)}({x} != {y})" + if x.dtype == torch.bool: + assert y.dtype == torch.bool + x_cast, y_cast = unify_mask_base_type(V.kernel.compute, (x, y)) + return f"{x_cast} != {y_cast}" + else: + assert x.dtype is not None + return f"{V.kernel._get_mask_type(x.dtype)}({x} != {y})" @staticmethod def lt(x, y): From fbc7559ceb372d88b55c96ef6984accbaa0ec3ec Mon Sep 17 00:00:00 2001 From: Shangdi Yu Date: Tue, 18 Jun 2024 00:55:48 +0000 Subject: [PATCH 112/171] [custom ops] convert string type annotation to real type (#128809) Fixes #105157 Bug source: `from __future__ import annotations` converts type annotation to strings to make forwards references easier. However, existing custom ops do not consider strings to be valid types. Fix: We check if the argument and return type annotation is string type. If so, we try to use `eval` to convert it to a type. Pull Request resolved: https://github.com/pytorch/pytorch/pull/128809 Approved by: https://github.com/zou3519 --- .../test_infer_schema_annotation.py | 207 ++++++++++++++++++ torch/_library/infer_schema.py | 28 ++- 2 files changed, 232 insertions(+), 3 deletions(-) create mode 100644 test/custom_operator/test_infer_schema_annotation.py diff --git a/test/custom_operator/test_infer_schema_annotation.py b/test/custom_operator/test_infer_schema_annotation.py new file mode 100644 index 00000000000000..9de44224f1c039 --- /dev/null +++ b/test/custom_operator/test_infer_schema_annotation.py @@ -0,0 +1,207 @@ +# Owner(s): ["module: pt2-dispatcher"] +from __future__ import annotations + +import typing +from typing import List, Optional, Sequence, Union # noqa: F401 + +import torch +import torch._custom_op.impl +from torch import Tensor, types +from torch.testing._internal.common_utils import run_tests, TestCase + + +mutates_args = {} + + +class TestInferSchemaWithAnnotation(TestCase): + def test_tensor(self): + def foo_op(x: torch.Tensor) -> torch.Tensor: + return x.clone() + + result = torch._custom_op.impl.infer_schema(foo_op, mutates_args) + self.assertEqual(result, "(Tensor x) -> Tensor") + + def foo_op_2(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + return x.clone() + y + + result = torch._custom_op.impl.infer_schema(foo_op_2, mutates_args) + self.assertEqual(result, "(Tensor x, Tensor y) -> Tensor") + + def test_native_types(self): + def foo_op(x: int) -> int: + return x + + result = torch._custom_op.impl.infer_schema(foo_op, mutates_args) + self.assertEqual(result, "(SymInt x) -> SymInt") + + def foo_op_2(x: bool) -> bool: + return x + + result = torch._custom_op.impl.infer_schema(foo_op_2, mutates_args) + self.assertEqual(result, "(bool x) -> bool") + + def foo_op_3(x: str) -> int: + return 1 + + result = torch._custom_op.impl.infer_schema(foo_op_3, mutates_args) + self.assertEqual(result, "(str x) -> SymInt") + + def foo_op_4(x: float) -> float: + return x + + result = torch._custom_op.impl.infer_schema(foo_op_4, mutates_args) + self.assertEqual(result, "(float x) -> float") + + def test_torch_types(self): + def foo_op_1(x: torch.types.Number) -> torch.types.Number: + return x + + result = torch._custom_op.impl.infer_schema(foo_op_1, mutates_args) + self.assertEqual(result, "(Scalar x) -> Scalar") + + def foo_op_2(x: torch.dtype) -> int: + return 1 + + result = torch._custom_op.impl.infer_schema(foo_op_2, mutates_args) + self.assertEqual(result, "(ScalarType x) -> SymInt") + + def foo_op_3(x: torch.device) -> int: + return 1 + + result = torch._custom_op.impl.infer_schema(foo_op_3, mutates_args) + self.assertEqual(result, "(Device x) -> SymInt") + + def test_type_variants(self): + def foo_op_1(x: typing.Optional[int]) -> int: + return 1 + + result = torch._custom_op.impl.infer_schema(foo_op_1, mutates_args) + self.assertEqual(result, "(SymInt? x) -> SymInt") + + def foo_op_2(x: typing.Sequence[int]) -> int: + return 1 + + result = torch._custom_op.impl.infer_schema(foo_op_2, mutates_args) + self.assertEqual(result, "(SymInt[] x) -> SymInt") + + def foo_op_3(x: typing.List[int]) -> int: + return 1 + + result = torch._custom_op.impl.infer_schema(foo_op_3, mutates_args) + self.assertEqual(result, "(SymInt[] x) -> SymInt") + + def foo_op_4(x: typing.Optional[typing.Sequence[int]]) -> int: + return 1 + + result = torch._custom_op.impl.infer_schema(foo_op_4, mutates_args) + self.assertEqual(result, "(SymInt[]? x) -> SymInt") + + def foo_op_5(x: typing.Optional[typing.List[int]]) -> int: + return 1 + + result = torch._custom_op.impl.infer_schema(foo_op_5, mutates_args) + self.assertEqual(result, "(SymInt[]? x) -> SymInt") + + def foo_op_6(x: typing.Union[int, float, bool]) -> types.Number: + return x + + result = torch._custom_op.impl.infer_schema(foo_op_6, mutates_args) + self.assertEqual(result, "(Scalar x) -> Scalar") + + def foo_op_7(x: typing.Union[int, bool, float]) -> types.Number: + return x + + result = torch._custom_op.impl.infer_schema(foo_op_7, mutates_args) + self.assertEqual(result, "(Scalar x) -> Scalar") + + def test_no_library_prefix(self): + def foo_op(x: Tensor) -> Tensor: + return x.clone() + + result = torch._custom_op.impl.infer_schema(foo_op, mutates_args) + self.assertEqual(result, "(Tensor x) -> Tensor") + + def foo_op_2(x: Tensor) -> torch.Tensor: + return x.clone() + + result = torch._custom_op.impl.infer_schema(foo_op_2, mutates_args) + self.assertEqual(result, "(Tensor x) -> Tensor") + + def foo_op_3(x: torch.Tensor) -> Tensor: + return x.clone() + + result = torch._custom_op.impl.infer_schema(foo_op_3, mutates_args) + self.assertEqual(result, "(Tensor x) -> Tensor") + + def foo_op_4(x: List[int]) -> types.Number: + return x[0] + + result = torch._custom_op.impl.infer_schema(foo_op_4, mutates_args) + self.assertEqual(result, "(SymInt[] x) -> Scalar") + + def foo_op_5(x: Optional[int]) -> int: + return 1 + + result = torch._custom_op.impl.infer_schema(foo_op_5, mutates_args) + self.assertEqual(result, "(SymInt? x) -> SymInt") + + def foo_op_6(x: Sequence[int]) -> int: + return 1 + + result = torch._custom_op.impl.infer_schema(foo_op_6, mutates_args) + self.assertEqual(result, "(SymInt[] x) -> SymInt") + + def foo_op_7(x: List[int]) -> int: + return 1 + + result = torch._custom_op.impl.infer_schema(foo_op_7, mutates_args) + self.assertEqual(result, "(SymInt[] x) -> SymInt") + + def foo_op_8(x: Optional[Sequence[int]]) -> int: + return 1 + + result = torch._custom_op.impl.infer_schema(foo_op_8, mutates_args) + self.assertEqual(result, "(SymInt[]? x) -> SymInt") + + def foo_op_9(x: Optional[List[int]]) -> int: + return 1 + + result = torch._custom_op.impl.infer_schema(foo_op_9, mutates_args) + self.assertEqual(result, "(SymInt[]? x) -> SymInt") + + def foo_op_10(x: Union[int, float, bool]) -> types.Number: + return x + + result = torch._custom_op.impl.infer_schema(foo_op_10, mutates_args) + self.assertEqual(result, "(Scalar x) -> Scalar") + + def foo_op_11(x: Union[int, bool, float]) -> types.Number: + return x + + result = torch._custom_op.impl.infer_schema(foo_op_11, mutates_args) + self.assertEqual(result, "(Scalar x) -> Scalar") + + def test_unsupported_annotation(self): + with self.assertRaisesRegex( + ValueError, + r"Unsupported type annotation D. It is not a type.", + ): + + def foo_op(x: D) -> Tensor: # noqa: F821 + return torch.Tensor(x) + + torch._custom_op.impl.infer_schema(foo_op, mutates_args) + + with self.assertRaisesRegex( + ValueError, + r"Unsupported type annotation E. It is not a type.", + ): + + def foo_op_2(x: Tensor) -> E: # noqa: F821 + return x + + torch._custom_op.impl.infer_schema(foo_op_2, mutates_args) + + +if __name__ == "__main__": + run_tests() diff --git a/torch/_library/infer_schema.py b/torch/_library/infer_schema.py index 6305375e4433d3..c4f7b8ee51e6c2 100644 --- a/torch/_library/infer_schema.py +++ b/torch/_library/infer_schema.py @@ -1,7 +1,9 @@ # mypy: allow-untyped-defs import inspect import typing +from typing import List, Optional, Sequence, Union # noqa: F401 +import torch # noqa: F401 from .. import device, dtype, Tensor, types @@ -12,6 +14,9 @@ def infer_schema(prototype_function: typing.Callable, mutates_args=()) -> str: write custom ops in real life: - none of the outputs alias any of the inputs or each other. - only the args listed in mutates_args are being mutated. + - string type annotations "device, dtype, Tensor, types" without library specification + are assumed to be torch.*. Similarly, string type annotations "Optional, List, Sequence, Union" + without library specification are assumed to be typing.*. Callers (e.g. the custom ops API) are responsible for checking these assumptions. """ @@ -22,6 +27,14 @@ def error_fn(what): f"infer_schema(func): {what} " f"Got func with signature {sig})" ) + def convert_type_string(annotation_type: str): + try: + return eval(annotation_type) + except Exception as e: + error_fn( + f"Unsupported type annotation {annotation_type}. It is not a type." + ) + params = [] seen_args = set() saw_kwarg_only_arg = False @@ -38,13 +51,19 @@ def error_fn(what): if param.annotation is inspect.Parameter.empty: error_fn(f"Parameter {name} must have a type annotation.") - if param.annotation not in SUPPORTED_PARAM_TYPES.keys(): + # The annotation might be converted to a string by annotation, + # we convert it to the actual type. + annotation_type = param.annotation + if type(annotation_type) == str: + annotation_type = convert_type_string(annotation_type) + + if annotation_type not in SUPPORTED_PARAM_TYPES.keys(): error_fn( f"Parameter {name} has unsupported type {param.annotation}. " f"The valid types are: {SUPPORTED_PARAM_TYPES.keys()}." ) - schema_type = SUPPORTED_PARAM_TYPES[param.annotation] + schema_type = SUPPORTED_PARAM_TYPES[annotation_type] if name in mutates_args: if not schema_type.startswith("Tensor"): error_fn( @@ -72,7 +91,10 @@ def error_fn(what): f"mutates_args should contain the names of all args that the " f"custom op mutates." ) - ret = parse_return(sig.return_annotation, error_fn) + return_annotation = sig.return_annotation + if type(return_annotation) == str: + return_annotation = convert_type_string(return_annotation) + ret = parse_return(return_annotation, error_fn) return f"({', '.join(params)}) -> {ret}" From 9e8443b56f5a83877803be3ba43f7941841904c9 Mon Sep 17 00:00:00 2001 From: Huy Do Date: Tue, 18 Jun 2024 01:26:45 +0000 Subject: [PATCH 113/171] Remove dtype from gpt-fast micro benchmark experiments model name (#128789) Per comments on https://github.com/pytorch/test-infra/pull/5344, we already have a dtype column with the same information Pull Request resolved: https://github.com/pytorch/pytorch/pull/128789 Approved by: https://github.com/yanboliang --- benchmarks/gpt_fast/benchmark.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/benchmarks/gpt_fast/benchmark.py b/benchmarks/gpt_fast/benchmark.py index 16f3e55af17b04..1c09fe03a904a1 100644 --- a/benchmarks/gpt_fast/benchmark.py +++ b/benchmarks/gpt_fast/benchmark.py @@ -76,7 +76,7 @@ def run_mlp_layer_norm_gelu(device: str = "cuda"): dtype_str = str(dtype).replace("torch.", "") results.append( Experiment( - f"mlp_layer_norm_gelu_{dtype_str}", + "mlp_layer_norm_gelu", "flops_utilization", expected_flops_utilization, f"{flops_utilization:.02f}", @@ -113,7 +113,7 @@ def run_layer_norm(device: str = "cuda"): dtype_str = str(dtype).replace("torch.", "") results.append( Experiment( - f"layer_norm_{dtype_str}", + "layer_norm", "memory_bandwidth(GB/s)", expected_memory_bandwidth, f"{memory_bandwidth:.02f}", @@ -156,7 +156,7 @@ def gather_gemv(W, score_idxs, x): dtype_str = str(dtype).replace("torch.", "") results.append( Experiment( - f"gather_gemv_{dtype_str}", + "gather_gemv", "memory_bandwidth(GB/s)", expected_memory_bandwidth, f"{memory_bandwidth:.02f}", @@ -197,7 +197,7 @@ def gemv(W, x): dtype_str = str(dtype).replace("torch.", "") results.append( Experiment( - f"gemv_{dtype_str}", + "gemv", "memory_bandwidth(GB/s)", expected_memory_bandwidth, f"{memory_bandwidth:.02f}", From e12fa93b8bb3b7b7148f6111577e454bd3251223 Mon Sep 17 00:00:00 2001 From: Fuzzkatt Date: Tue, 18 Jun 2024 02:00:01 +0000 Subject: [PATCH 114/171] add is_big_gpu(0) check to test_select_algorithm tests in tests/inductor/test_cuda_cpp_wrapper.py (#128652) In NVIDIA internal CI, on Jetson devices we are seeing this failure for `python test/inductor/test_cuda_cpp_wrapper.py -k test_addmm_cuda_cuda_wrapper -k test_linear_relu_cuda_cuda_wrapper`: ``` /usr/local/lib/python3.10/dist-packages/torch/_inductor/compile_fx.py:132: UserWarning: TensorFloat32 tensor cores for float32 matrix multiplication available but not enabled. Consider setting `torch.set_float32_matmul_precision('high')` for better performance. warnings.warn( W0613 20:57:17.722000 281473279256672 torch/_inductor/utils.py:902] [0/0] Not enough SMs to use max_autotune_gemm mode frames [('total', 1), ('ok', 1)] stats [('calls_captured', 2), ('unique_graphs', 1)] inductor [('extern_calls', 2), ('fxgraph_cache_miss', 1), ('pattern_matcher_count', 1), ('pattern_matcher_nodes', 1)] aot_autograd [('total', 1), ('ok', 1)] F ====================================================================== FAIL: test_linear_relu_cuda_cuda_wrapper (__main__.TestCudaWrapper) ---------------------------------------------------------------------- Traceback (most recent call last): File "/usr/local/lib/python3.10/dist-packages/torch/testing/_internal/common_utils.py", line 2759, in wrapper method(*args, **kwargs) File "/opt/pytorch/pytorch/test/inductor/test_torchinductor.py", line 9818, in new_test return value(self) File "/usr/lib/python3.10/contextlib.py", line 79, in inner return func(*args, **kwds) File "/opt/pytorch/pytorch/test/inductor/test_cuda_cpp_wrapper.py", line 152, in fn _, code = test_torchinductor.run_and_get_cpp_code( File "/opt/pytorch/pytorch/test/inductor/test_torchinductor.py", line 356, in run_and_get_cpp_code result = fn(*args, **kwargs) File "/opt/pytorch/pytorch/test/inductor/test_select_algorithm.py", line 43, in wrapped return fn(*args, **kwargs) File "/usr/lib/python3.10/contextlib.py", line 79, in inner return func(*args, **kwds) File "/usr/lib/python3.10/unittest/mock.py", line 1379, in patched return func(*newargs, **newkeywargs) File "/usr/lib/python3.10/contextlib.py", line 79, in inner return func(*args, **kwds) File "/usr/lib/python3.10/contextlib.py", line 79, in inner return func(*args, **kwds) File "/opt/pytorch/pytorch/test/inductor/test_select_algorithm.py", line 62, in test_linear_relu_cuda self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1) File "/usr/local/lib/python3.10/dist-packages/torch/testing/_internal/common_utils.py", line 3642, in assertEqual raise error_metas.pop()[0].to_error( AssertionError: Scalars are not equal! Expected 1 but got 0. Absolute difference: 1 Relative difference: 1.0 ``` Looking into it, we see the failure is from https://github.com/pytorch/pytorch/blob/main/test/inductor/test_select_algorithm.py#L62. The warning `W0613 20:57:17.722000 281473279256672 torch/_inductor/utils.py:902] [0/0] Not enough SMs to use max_autotune_gemm ` is triggered from https://github.com/pytorch/pytorch/blob/main/torch/_inductor/utils.py#L973. Printing torch.cuda.get_device_properties(0).multi_processor_count returns 16 on the computelab AGX Orin; thus it makes sense that this check is failing, since the min_required_sms is 68, thus not letting it pick the autotune algorithm. Looking at the main for test_select_algorithm.py, we see that these tests should only be run if is_big_gpu(0) is true: https://github.com/pytorch/pytorch/blob/main/test/inductor/test_select_algorithm.py#L344. Thus this PR adds a similar check to the invocation of these tests in test_cuda_cpp_wrapper.py. Pull Request resolved: https://github.com/pytorch/pytorch/pull/128652 Approved by: https://github.com/soulitzer, https://github.com/eqy --- test/inductor/test_cuda_cpp_wrapper.py | 23 +++++++++++++++-------- 1 file changed, 15 insertions(+), 8 deletions(-) diff --git a/test/inductor/test_cuda_cpp_wrapper.py b/test/inductor/test_cuda_cpp_wrapper.py index eaa0134be8f09b..1289de27436595 100644 --- a/test/inductor/test_cuda_cpp_wrapper.py +++ b/test/inductor/test_cuda_cpp_wrapper.py @@ -194,14 +194,6 @@ class BaseTest(NamedTuple): "test_cat_slice_cat", tests=test_pattern_matcher.TestPatternMatcher(), ), - BaseTest( - "test_addmm", - tests=test_select_algorithm.TestSelectAlgorithm(), - ), - BaseTest( - "test_linear_relu", - tests=test_select_algorithm.TestSelectAlgorithm(), - ), # TODO: Re-enable this test after fixing cuda wrapper for conv Triton templates with dynamic shapes. # This test is unstable: it succeeds when an ATEN kernel is used, and fails when a Triton kernel is used. # Currently it passes on CI (an ATEN kernel is chosen) and fails locally (a Triton kernel is chosen). @@ -226,6 +218,21 @@ class BaseTest(NamedTuple): ]: make_test_case(item.name, item.device, item.tests) + from torch._inductor.utils import is_big_gpu + + if is_big_gpu(0): + for item in [ + BaseTest( + "test_addmm", + tests=test_select_algorithm.TestSelectAlgorithm(), + ), + BaseTest( + "test_linear_relu", + tests=test_select_algorithm.TestSelectAlgorithm(), + ), + ]: + make_test_case(item.name, item.device, item.tests) + test_torchinductor.copy_tests( CudaWrapperTemplate, TestCudaWrapper, "cuda_wrapper", test_failures_cuda_wrapper ) From 43998711a794b6c324a59397ded048786e0e9312 Mon Sep 17 00:00:00 2001 From: Boyuan Feng Date: Tue, 18 Jun 2024 02:07:03 +0000 Subject: [PATCH 115/171] [CUDAGraph] add more docs for cudagraph trees (#127963) This PR adds more documentation for CUDAGraph Trees, including - Iteration Support - Input Mutation Support - Dynamic Shape Support - NCCL Support - Reasons for Skipping CUDAGraph Pull Request resolved: https://github.com/pytorch/pytorch/pull/127963 Approved by: https://github.com/eellison --- .../source/torch.compiler_cudagraph_trees.rst | 192 +++++++++++++++++- 1 file changed, 188 insertions(+), 4 deletions(-) diff --git a/docs/source/torch.compiler_cudagraph_trees.rst b/docs/source/torch.compiler_cudagraph_trees.rst index b1986dc0dc47fd..360fbf0c5d9ce5 100644 --- a/docs/source/torch.compiler_cudagraph_trees.rst +++ b/docs/source/torch.compiler_cudagraph_trees.rst @@ -1,7 +1,10 @@ CUDAGraph Trees ================ -CUDAGraph Background +**Background** +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +CUDAGraph -------------------- For a longer background on CUDAGraphs, read `accelerating pytorch with CUDAGraphs `_. @@ -35,8 +38,8 @@ TorchDynamo Previous CUDA Graphs Integration Running with ``cudagraph_trees=False`` does not reuse memory across separate graph captures, which can lead to large memory regressions. Even for a model that has no graph breaks, this has issues. The forward and backward are separate graph captures, so the memory pools for forward and backward are not shared. In particular, memory for activations that are saved in the forward cannot be reclaimed in the backward. -CUDAGraph Trees Integration ---------------------------- +**CUDAGraph Trees Integration** +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Like Graph Callables, CUDA Graph Trees use a single memory pool across all graph captures. However, instead of requiring a single sequence of invocations, CUDA Graph Trees create separate trees of CUDA Graph captures. Let’s take a look at an illustrative example: @@ -90,6 +93,181 @@ The second time we hit graph 3 we are warmed up and ready to record. We record g \\ \\ 4 4 + +Input Mutation Support +---------------------- + +Input mutation function refers to a function conducting in-place writes to an input tensor, +as illustrated below: + +.. code-block:: python + + def foo(x, y): + # mutates input x + x.add_(1) + return x + y + +Input mutation functions generally lead to challenges for CUDAGraph Trees. Due to the static +CUDA memory address requirement from CUDAGraph, for each input tensor x, CUDAGraph Trees may +allocate a static memory address x'. During execution, CUDAGraph Trees first copy the input +tensor x to the static memory address x', and then replay the recorded CUDAGraph. For input +mutation function, x' is in-place updated, which is not reflected on the input tensor x since +x and x' reside on different CUDA memory addresses. + +A closer look at input mutation functions reveals that there are three types of inputs: + +* **inputs from eager**: These tensors we assume will vary input tensor addresses from + execution to execution. Because cudagraphs freeze memory addresses, we need to copy these + inputs to a static address tensor prior to graph recording and execution. +* **Parameters and buffers**: These tensors we assume (and runtime-check) have the same tensor + addresses on every execution. We do not need to copy over their contents because the recorded + memory address will be the same as the executed memory address. +* **Tensors which are prior outputs from CUDAGraph Trees**: Because the output tensor addresses + of a cudagraph are fixed, if we run CUDAGraph1, then run CUDAGraph2, the inputs which came from + CUDAGraph1 into CUDAGraph2 will have a fixed memory address. These inputs, like parameters and + buffers, do not require copying over to a static address tensor. We check to make sure that + these inputs are stable at runtime, and if they're not we will re-record. + +CUDAGraph Trees support input mutation on parameters and buffers, and tensors which are prior +outputs from CUDAGraph Trees. For mutation on inputs from eager, CUDAGraph Trees will run the +function without CUDAGraph and emit *skipping due to mutated inputs* log. The following example +shows CUDAGraph Trees' support for tensors which are prior outputs from CUDAGraph Trees. + + +.. code-block:: python + + import torch + + @torch.compile(mode="reduce-overhead") + def foo(x): + return x + 1 + + @torch.compile(mode="reduce-overhead") + def mut(x): + return x.add_(2) + + # Enable input mutation support + torch._inductor.config.triton.cudagraph_support_input_mutation = True + + for i in range(3): + torch.compiler.cudagraph_mark_step_begin() + inp = torch.rand([4], device="cuda") + + # CUDAGraph is applied since `foo` does not mutate `inp` + tmp = foo(inp) + # Although `mut` mutates `tmp`, which is an output of a CUDAGraph + # managed function. So CUDAGraph is still applied. + mut(tmp) + + + torch.compiler.cudagraph_mark_step_begin() + inp = torch.rand([4], device="cuda") + + tmp = foo(inp) + # While `tmp` is a CUDAGraph Tree managed function's output, `tmp.clone()` + # is not. So CUDAGraph is not applied to `mut` and there is a log + # `skipping cudagraphs due to mutated inputs` + mut(tmp.clone()) + + +To enable CUDAGraph Trees for a function mutating inputs from eager, please re-write +the function to avoid input mutation. + +.. note:: Enable input mutation support by setting + `torch._inductor.config.cudagraph_support_input_mutation = True `_ + for "reduce-overhead" mode. + + +Dynamic Shape Support +--------------------- + +`Dynamic shape `_ +means that an input tensor has different shapes across function calls. Since CUDAGraph +requires fixed tensor addresses, CUDAGraph Trees re-record CUDAGraph for every unique +shape of an input tensor. This leads to multiple CUDAGraphs for a single inductor graph. +When there are limited shapes (e.g., batch sizes in inference), it is profitable to +re-record CUDAGraphs. However, if input tensor shapes change frequently or even on +every invocation, re-recording CUDAGraph may not be profitable. Nvidia uses 64 KB of +device memory per kernel launch in CUDAGraph, up until CUDA 12.4 and Driver Version 550+. +This memory cost can be significant with many CUDAGraph re-recordings. + +For functions with frequently changing input tensor shapes, we suggest padding input +tensors to a few fixed tensor shapes to still enjoy benefits from CUDAGraph. In addition, +setting `torch._inductor.config.triton.cudagraph_skip_dynamic_graphs=True `_ +allows to skip cudagraphing functions with dynamic shape inputs and only cudagraphing +functions with static input tensor shapes. + + +NCCL Support +------------ + +CUDAGraph Trees support functions with nccl operators. While CUDAGraph Trees perform per-device +record for CUDAGraph, NCCL support allows cross-device communication. + +.. code-block:: python + + @torch.compile(mode="reduce-overhead") + def func(x): + y = x * x + y = torch.distributed.all_reduce(y, op=torch.distributed.ReduceOp.SUM) + x = torch.nn.functional.silu(x) + return x * y + + +Reasons for Skipping CUDAGraph +------------------------------ + +Since CUDAGraph has requirements such as static input tensor addresses and not supporting +CPU operators, CUDAGraph Trees check whether a function satisfies these requirements and +may skip CUDAGraph when necessary. Here, we list common reasons for skipping CUDAGraph. + +* **Input mutation**: CUDAGraph Trees skip functions that in-place mutates eager input. + In-place mutating parameters and buffers, or output tensors from CUDAGraph Tree managed + functions are still supported. Please see *Input Mutation Support* section for more details. +* **CPU operators**: Functions containing CPU operator are skipped. Please split the + function into multiple functions and apply CUDAGraph Trees on functions with only GPU operators. +* **Multi-device operators**: A function is skipped if it contains operators on multiple + devices. Currently, CUDAGraph is applied on a per-device basis. Please use supported + libraries such as NCCL for cross-device communication. Please see *NCCL Support* + section for more details. +* **Free unbacked symbols**: Free unbacked symbols usually happen during + `dynamic shapes `_. + CUDAGraph Trees currently record a CUDAGraph for every unique input tensor shapes. + Please see *Dynamic Shape Support* for more details. +* **Incompatible operators**: CUDAGraph Trees skip a function if it contain incompatible + operators. Please replace these operators in a function with supported operators. We + show an exhaustive list of incompatible operators: + + +.. code-block:: python + + aten._fused_moving_avg_obs_fq_helper.default + aten._fused_moving_avg_obs_fq_helper_functional.default + aten.multinomial.default + fbgemm.dense_to_jagged.default + fbgemm.jagged_to_padded_dense.default + run_and_save_rng_state + run_with_rng_state + aten._local_scalar_dense + aten._assert_scalar + + +The following operators are incompatible when `torch.are_deterministic_algorithms_enabled() `_. + + +.. code-block:: python + + aten._fused_moving_avg_obs_fq_helper.default + aten._fused_moving_avg_obs_fq_helper_functional.default + aten.multinomial.default + fbgemm.dense_to_jagged.default + fbgemm.jagged_to_padded_dense.default + run_and_save_rng_state + run_with_rng_state + aten._local_scalar_dense + aten._assert_scalar + + Limitations ----------- @@ -112,8 +290,14 @@ Let’s say we are benchmarking running inference with the following code: print(y1) # RuntimeError: Error: accessing tensor output of CUDAGraphs that has been overwritten by a subsequent run. +In the Separate CUDA Graph implementation, the output from the first invocation will be overwritten by the second invocation. In CUDAGraph +Trees, we don’t want to add unintended dependencies between iterations that would cause us to not hit the hot path, nor do we want we want +to prematurely free memory from a prior invocation. Our heuristics are in inference we start a new iteration on each invocation for +torch.compile, and in training we do the same so long as there is not a pending backward that has not been invoked. If those heuristics +are wrong, you can mark the start of a new iteration with +`torch.compiler.mark_step_begin() `_, or clone +tensors of a prior iteration (outside of torch.compile) before you begin the next run. -In the Separate CUDA Graph implementation, the output from the first invocation will be overwritten by the second invocation. In CUDA Graph Trees, we don’t want to add unintended dependencies between iterations that would cause us to not hit the hot path, nor do we want we want to prematurely free memory from a prior invocation. Our heuristics are in inference we start a new iteration on each invocation for torch.compile, and in training we do the same so long as there is not a pending backward that has not been invoked. If those heuristics are wrong, you can mark the start of a new iteration with torch.compiler.mark_step_begin(), or clone tensors of a prior iteration (outside of torch.compile) before you begin the next run. Comparisons ----------- From 22f1793c0ac644a357ee44ccaa78e1252731f57e Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Mon, 17 Jun 2024 12:43:18 -0700 Subject: [PATCH 116/171] [dynamo][easy] Use LazyVariableTracker for UserDefinedObject var_getattr (#128877) Pull Request resolved: https://github.com/pytorch/pytorch/pull/128877 Approved by: https://github.com/mlazos ghstack dependencies: #128315, #128748 --- torch/_dynamo/variables/user_defined.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/torch/_dynamo/variables/user_defined.py b/torch/_dynamo/variables/user_defined.py index 2b97d921b73b11..fb2b3c1b6ac4f2 100644 --- a/torch/_dynamo/variables/user_defined.py +++ b/torch/_dynamo/variables/user_defined.py @@ -941,8 +941,7 @@ def var_getattr(self, tx, name): ) ): if source: - install_guard(source.make_guard(GuardBuilder.HASATTR)) - return VariableBuilder(tx, source)(subobj) + return variables.LazyVariableTracker.create(subobj, source) elif ConstantVariable.is_literal(subobj): return ConstantVariable.create(subobj) elif ( From 4e97d37fd947236333d5ccb37c9d9382878b4003 Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Mon, 17 Jun 2024 12:43:21 -0700 Subject: [PATCH 117/171] [inlining-inbuilt-nn-modules][pre-grad] Adjust efficient_conv_bn_eval_graph for inlining (#128878) Pull Request resolved: https://github.com/pytorch/pytorch/pull/128878 Approved by: https://github.com/mlazos ghstack dependencies: #128315, #128748, #128877 --- .../fx_passes/efficient_conv_bn_eval.py | 91 +++++++++++++++++++ 1 file changed, 91 insertions(+) diff --git a/torch/_inductor/fx_passes/efficient_conv_bn_eval.py b/torch/_inductor/fx_passes/efficient_conv_bn_eval.py index 7aecc3f15f33d8..c8165a1a3926a6 100644 --- a/torch/_inductor/fx_passes/efficient_conv_bn_eval.py +++ b/torch/_inductor/fx_passes/efficient_conv_bn_eval.py @@ -139,6 +139,97 @@ def efficient_conv_bn_eval_decomposed( return conv(*((input, weight_on_the_fly, bias_on_the_fly) + conv_remainging_args)) +@register_graph_pattern( + CallFunctionVarArgs( + [ + torch.nn.functional.batch_norm, + ] + ), + pass_dict=efficient_conv_bn_eval_pass, + extra_check=lambda match: not inductor_config.freezing + and inductor_config.efficient_conv_bn_eval_fx_passes, +) +def efficient_conv_bn_eval_graph_transform_inlined(match: Match, *args, **kwargs): + bn_node = match.nodes[0] + graph = match.graph + assert len(bn_node.args) == 8 + + # We can only use efficient conv-bn for eval mode with track_running_stats + # bn_node.args is `training` + if bn_node.args[-3]: + return + + # Check if the input is Conv + input_node = bn_node.args[0] + + if input_node.op != "call_function": # type: ignore[union-attr] + return + + input_fn = input_node.target # type: ignore[arg-type, union-attr] + supported_convs = [ + torch._C._nn.linear, + torch.conv1d, + torch.conv2d, + torch.conv3d, + torch.conv_transpose1d, + torch.conv_transpose2d, + torch.conv_transpose3d, + ] + + if not any(input_fn is cls for cls in supported_convs): + return + + conv_node = input_node + # Output of conv is used by other nodes, cannot optimize + if len(conv_node.users) > 1: # type: ignore[union-attr] + return + + counters["inductor"]["efficient_conv_bn_eval"] += 1 + + with graph.inserting_before(bn_node): + # prepare args for the fused function + bn_running_mean = bn_node.args[1] + bn_running_var = bn_node.args[2] + bn_weight = bn_node.args[3] + bn_bias = bn_node.args[4] + bn_eps = bn_node.args[7] + assert len(conv_node.args) >= 2 # type: ignore[union-attr] + conv_input = conv_node.args[0] # type: ignore[union-attr] + conv_weight = conv_node.args[1] # type: ignore[union-attr] + conv_bias = conv_node.args[2] if len(conv_node.args) >= 3 else None # type: ignore[union-attr] + conv_remainging_args = conv_node.args[3:] # type: ignore[union-attr] + args = ( + bn_weight, + bn_bias, + bn_running_mean, + bn_running_var, + bn_eps, + conv_node.target, # type: ignore[union-attr] + conv_weight, + conv_bias, + conv_input, + conv_remainging_args, + ) + + # create a new node + new_node = graph.create_node( + op="call_function", + target=efficient_conv_bn_eval_decomposed, + args=args, + name="efficient_conv_bn_eval", + ) + + # this node replaces the original conv + bn, and therefore + # should replace the uses of bn_node + bn_node.replace_all_uses_with(new_node) + # take care of the deletion order: + # delete bn_node first, and then conv_node + graph.erase_node(bn_node) + graph.erase_node(conv_node) + + return + + @register_graph_pattern( CallFunctionVarArgs( [ From c017c97333dfb9d17f2e5357980241827e50e8d5 Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Mon, 17 Jun 2024 12:47:25 -0700 Subject: [PATCH 118/171] [dynamo][inlining-inbuilt-nn-modules] Update test output (#128880) Pull Request resolved: https://github.com/pytorch/pytorch/pull/128880 Approved by: https://github.com/mlazos ghstack dependencies: #128315, #128748, #128877, #128878 --- test/dynamo/test_structured_trace.py | 34 +++++++++++++++++++++++++++- 1 file changed, 33 insertions(+), 1 deletion(-) diff --git a/test/dynamo/test_structured_trace.py b/test/dynamo/test_structured_trace.py index f84e08b8f9cce8..e3a82921a838bc 100644 --- a/test/dynamo/test_structured_trace.py +++ b/test/dynamo/test_structured_trace.py @@ -379,18 +379,50 @@ def forward(self, x): {"dynamo_cpp_guards_str": {}, "rank": 0, "frame_id": 0, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"} {"compilation_metrics": "METRICS", "rank": 0, "frame_id": 0, "frame_compile_id": 0, "attempt": 1} {"dynamo_start": {"stack": "STACK"}, "rank": 0, "frame_id": 1, "frame_compile_id": 0, "attempt": 0} -{"dynamo_output_graph": {"sizes": {"l_self_layers_0_weight": [1024, 1024], "l_self_layers_0_bias": [1024], "l_x_": [1024, 1024], "l_self_layers_1_weight": [1024, 1024], "l_self_layers_1_bias": [1024], "input_1": [1024, 1024], "input_2": [1024, 1024]}}, "rank": 0, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"describe_storage": {"id": 0, "describer_id": "ID", "size": 4194304}, "rank": 0, "frame_id": 1, "frame_compile_id": 0, "attempt": 0} +{"describe_tensor": {"id": 0, "ndim": 2, "dtype": "torch.float32", "device": "device(type='cuda', index=0)", "size": [1024, 1024], "is_leaf": true, "requires_grad": true, "is_parameter": true, "stride": [1024, 1], "storage": 0, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "rank": 0, "frame_id": 1, "frame_compile_id": 0, "attempt": 0} +{"describe_source": {"describer_id": "ID", "id": 0, "source": "L['self']._modules['layers']._modules['0']._parameters['weight']"}, "rank": 0, "frame_id": 1, "frame_compile_id": 0, "attempt": 0} +{"describe_storage": {"id": 1, "describer_id": "ID", "size": 4096}, "rank": 0, "frame_id": 1, "frame_compile_id": 0, "attempt": 0} +{"describe_tensor": {"id": 1, "ndim": 1, "dtype": "torch.float32", "device": "device(type='cuda', index=0)", "size": [1024], "is_leaf": true, "requires_grad": true, "is_parameter": true, "stride": [1], "storage": 1, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "rank": 0, "frame_id": 1, "frame_compile_id": 0, "attempt": 0} +{"describe_source": {"describer_id": "ID", "id": 1, "source": "L['self']._modules['layers']._modules['0']._parameters['bias']"}, "rank": 0, "frame_id": 1, "frame_compile_id": 0, "attempt": 0} +{"describe_storage": {"id": 2, "describer_id": "ID", "size": 4194304}, "rank": 0, "frame_id": 1, "frame_compile_id": 0, "attempt": 0} +{"describe_tensor": {"id": 2, "ndim": 2, "dtype": "torch.float32", "device": "device(type='cuda', index=0)", "size": [1024, 1024], "is_leaf": true, "stride": [1024, 1], "storage": 2, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "rank": 0, "frame_id": 1, "frame_compile_id": 0, "attempt": 0} +{"describe_source": {"describer_id": "ID", "id": 2, "source": "L['x']"}, "rank": 0, "frame_id": 1, "frame_compile_id": 0, "attempt": 0} +{"describe_storage": {"id": 3, "describer_id": "ID", "size": 4194304}, "rank": 0, "frame_id": 1, "frame_compile_id": 0, "attempt": 0} +{"describe_tensor": {"id": 8, "ndim": 2, "dtype": "torch.float32", "device": "device(type='cuda', index=0)", "size": [1024, 1024], "is_leaf": true, "requires_grad": true, "is_parameter": true, "stride": [1024, 1], "storage": 3, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "rank": 0, "frame_id": 1, "frame_compile_id": 0, "attempt": 0} +{"describe_source": {"describer_id": "ID", "id": 8, "source": "L['self']._modules['layers']._modules['1']._parameters['weight']"}, "rank": 0, "frame_id": 1, "frame_compile_id": 0, "attempt": 0} +{"describe_storage": {"id": 4, "describer_id": "ID", "size": 4096}, "rank": 0, "frame_id": 1, "frame_compile_id": 0, "attempt": 0} +{"describe_tensor": {"id": 9, "ndim": 1, "dtype": "torch.float32", "device": "device(type='cuda', index=0)", "size": [1024], "is_leaf": true, "requires_grad": true, "is_parameter": true, "stride": [1], "storage": 4, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "rank": 0, "frame_id": 1, "frame_compile_id": 0, "attempt": 0} +{"describe_source": {"describer_id": "ID", "id": 9, "source": "L['self']._modules['layers']._modules['1']._parameters['bias']"}, "rank": 0, "frame_id": 1, "frame_compile_id": 0, "attempt": 0} +{"dynamo_output_graph": {"sizes": {"l_self_modules_layers_modules_0_parameters_weight_": [1024, 1024], "l_self_modules_layers_modules_0_parameters_bias_": [1024], "l_x_": [1024, 1024], "l_self_modules_layers_modules_1_parameters_weight_": [1024, 1024], "l_self_modules_layers_modules_1_parameters_bias_": [1024], "input_1": [1024, 1024], "input_2": [1024, 1024]}}, "rank": 0, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"optimize_ddp_split_graph": {}, "rank": 0, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"optimize_ddp_split_child": {"name": "submod_0"}, "rank": 0, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"optimize_ddp_split_child": {"name": "submod_1"}, "rank": 0, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"describe_storage": {"id": 0, "describer_id": "ID", "size": 4194304}, "rank": 0, "frame_id": 1, "frame_compile_id": 0, "attempt": 0} +{"describe_tensor": {"id": 0, "ndim": 2, "dtype": "torch.float32", "device": "device(type='cuda', index=0)", "size": [1024, 1024], "is_leaf": true, "stride": [1024, 1], "storage": 0, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "rank": 0, "frame_id": 1, "frame_compile_id": 0, "attempt": 0} +{"describe_source": {"describer_id": "ID", "id": 0, "source": "L['x']"}, "rank": 0, "frame_id": 1, "frame_compile_id": 0, "attempt": 0} +{"describe_storage": {"id": 1, "describer_id": "ID", "size": 4194304}, "rank": 0, "frame_id": 1, "frame_compile_id": 0, "attempt": 0} +{"describe_tensor": {"id": 1, "ndim": 2, "dtype": "torch.float32", "device": "device(type='cuda', index=0)", "size": [1024, 1024], "is_leaf": true, "requires_grad": true, "is_parameter": true, "stride": [1024, 1], "storage": 1, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "rank": 0, "frame_id": 1, "frame_compile_id": 0, "attempt": 0} +{"describe_source": {"describer_id": "ID", "id": 1, "source": "L['self']._modules['layers']._modules['0']._parameters['weight']"}, "rank": 0, "frame_id": 1, "frame_compile_id": 0, "attempt": 0} +{"describe_storage": {"id": 2, "describer_id": "ID", "size": 4096}, "rank": 0, "frame_id": 1, "frame_compile_id": 0, "attempt": 0} +{"describe_tensor": {"id": 2, "ndim": 1, "dtype": "torch.float32", "device": "device(type='cuda', index=0)", "size": [1024], "is_leaf": true, "requires_grad": true, "is_parameter": true, "stride": [1], "storage": 2, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "rank": 0, "frame_id": 1, "frame_compile_id": 0, "attempt": 0} +{"describe_source": {"describer_id": "ID", "id": 2, "source": "L['self']._modules['layers']._modules['0']._parameters['bias']"}, "rank": 0, "frame_id": 1, "frame_compile_id": 0, "attempt": 0} {"aot_joint_graph": {}, "rank": 0, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"aot_forward_graph": {}, "rank": 0, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"aot_backward_graph": {}, "rank": 0, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"artifact": {"name": "fx_graph_cache_hash", "encoding": "json"}, "rank": 0, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"inductor_post_grad_graph": {}, "rank": 0, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"inductor_output_code": {"filename": "FILENAME"}, "rank": 0, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"describe_storage": {"id": 16, "describer_id": "ID", "size": 4194304}, "rank": 0, "frame_id": 1, "frame_compile_id": 0, "attempt": 0} +{"describe_tensor": {"id": 31, "ndim": 2, "dtype": "torch.float32", "device": "device(type='cuda', index=0)", "size": [1024, 1024], "is_leaf": true, "requires_grad": true, "is_parameter": true, "stride": [1024, 1], "storage": 16, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "rank": 0, "frame_id": 1, "frame_compile_id": 0, "attempt": 0} +{"describe_source": {"describer_id": "ID", "id": 31, "source": "L['self']._modules['layers']._modules['1']._parameters['weight']"}, "rank": 0, "frame_id": 1, "frame_compile_id": 0, "attempt": 0} +{"describe_storage": {"id": 17, "describer_id": "ID", "size": 4096}, "rank": 0, "frame_id": 1, "frame_compile_id": 0, "attempt": 0} +{"describe_tensor": {"id": 32, "ndim": 1, "dtype": "torch.float32", "device": "device(type='cuda', index=0)", "size": [1024], "is_leaf": true, "requires_grad": true, "is_parameter": true, "stride": [1], "storage": 17, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "rank": 0, "frame_id": 1, "frame_compile_id": 0, "attempt": 0} +{"describe_source": {"describer_id": "ID", "id": 32, "source": "L['self']._modules['layers']._modules['1']._parameters['bias']"}, "rank": 0, "frame_id": 1, "frame_compile_id": 0, "attempt": 0} {"aot_joint_graph": {}, "rank": 0, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"aot_forward_graph": {}, "rank": 0, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"aot_backward_graph": {}, "rank": 0, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"artifact": {"name": "fx_graph_cache_hash", "encoding": "json"}, "rank": 0, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"inductor_post_grad_graph": {}, "rank": 0, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"inductor_output_code": {"filename": "FILENAME"}, "rank": 0, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"dynamo_guards": {}, "rank": 0, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} From 4061b3b8225f522ae0ed6db00111441e7d3cc3d5 Mon Sep 17 00:00:00 2001 From: Joel Schlosser Date: Mon, 17 Jun 2024 17:06:46 -0400 Subject: [PATCH 119/171] Forward fix to skip ROCm tests for #122836 (#128891) Fixes broken ROCm tests from #122836. Pull Request resolved: https://github.com/pytorch/pytorch/pull/128891 Approved by: https://github.com/huydhn ghstack dependencies: #127007, #128057, #122836 --- test/test_nestedtensor.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/test/test_nestedtensor.py b/test/test_nestedtensor.py index 86f58b5a0de3a0..50d6deea92911e 100644 --- a/test/test_nestedtensor.py +++ b/test/test_nestedtensor.py @@ -5470,6 +5470,7 @@ def test_jagged_padded_dense_conversion_kernels(self, device, dtype): ) @unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile") @skipCUDAIf(not SM70OrLater, "GPU capability is < SM70") + @skipCUDAIfRocm def test_compile_preserves_metadata_cache(self, device, dtype): # shape (B, *, D) nt = random_nt_from_dims( @@ -5500,6 +5501,7 @@ def f(nt): ) @unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile") @skipCUDAIf(not SM70OrLater, "GPU capability is < SM70") + @skipCUDAIfRocm def test_compile_with_dynamic_max_seq_len(self, device, dtype): # shape (B, *, D) # max seq len: 18 @@ -5536,6 +5538,7 @@ def f(nt): ) @unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile") @skipCUDAIf(not SM70OrLater, "GPU capability is < SM70") + @skipCUDAIfRocm def test_compile_with_dynamic_min_seq_len(self, device, dtype): # shape (B, *, D) # min seq len: 7 @@ -5572,6 +5575,7 @@ def f(nt): ) @unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile") @skipCUDAIf(not SM70OrLater, "GPU capability is < SM70") + @skipCUDAIfRocm def test_compile_with_propagated_dynamic_max_seq_len(self, device, dtype): # shape (B, *, D) # max seq len: 18 From 17abbafdfc6935bcc133e5f43ba32d44914fe316 Mon Sep 17 00:00:00 2001 From: Xu Han Date: Tue, 18 Jun 2024 03:25:20 +0000 Subject: [PATCH 120/171] [inductor] Fix some windows cpp builder issue (#128765) 1. fix some Windows build args. 2. fix c++20 likely issue on Windows, reference: https://github.com/pytorch/pytorch/pull/124997. 3. remove compiler return value check, different compilers return variant value, let's check exception to catch error. Pull Request resolved: https://github.com/pytorch/pytorch/pull/128765 Approved by: https://github.com/jgong5, https://github.com/jansel --- torch/_inductor/codecache.py | 7 ++++--- torch/_inductor/cpp_builder.py | 9 +++++++-- 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/torch/_inductor/codecache.py b/torch/_inductor/codecache.py index 3d265f181b159a..422728a9e59ae2 100644 --- a/torch/_inductor/codecache.py +++ b/torch/_inductor/codecache.py @@ -1376,8 +1376,6 @@ def __bool__(self) -> bool: output_path = x86_isa_help_builder.get_target_file_path() if not os.path.isfile(output_path): status, target_file = x86_isa_help_builder.build() - if status: - return False # Check build result subprocess.check_call( @@ -2573,11 +2571,14 @@ class CppPythonBindingsCodeCache(CppCodeCache): #ifndef _MSC_VER #if __cplusplus < 202002L - // C++20 earlier code + // C++20 (earlier) code // https://en.cppreference.com/w/cpp/language/attributes/likely #define likely(x) __builtin_expect(!!(x), 1) #define unlikely(x) __builtin_expect(!!(x), 0) #endif + #else + #define likely(x) (x) + #define unlikely(x) (x) #endif // This is defined in guards.cpp so we don't need to import PyTorch headers that are slooow. diff --git a/torch/_inductor/cpp_builder.py b/torch/_inductor/cpp_builder.py index f75f079d72db2b..a574b334734250 100644 --- a/torch/_inductor/cpp_builder.py +++ b/torch/_inductor/cpp_builder.py @@ -345,7 +345,11 @@ def _get_optimization_cflags() -> List[str]: def _get_shared_cflag(compile_only: bool) -> List[str]: if _IS_WINDOWS: - SHARED_FLAG = ["DLL"] + """ + MSVC `/MD` using python `ucrtbase.dll` lib as runtime. + https://learn.microsoft.com/en-us/cpp/c-runtime-library/crt-library-features?view=msvc-170 + """ + SHARED_FLAG = ["DLL", "MD"] else: if compile_only: return ["fPIC"] @@ -567,7 +571,7 @@ def _get_torch_related_args(include_pytorch: bool, aot_mode: bool): ] libraries_dirs = [TORCH_LIB_PATH] libraries = [] - if sys.platform == "linux" and not config.is_fbcode(): + if sys.platform != "darwin" and not config.is_fbcode(): libraries = ["torch", "torch_cpu"] if not aot_mode: libraries.append("torch_python") @@ -663,6 +667,7 @@ def _get_openmp_args(cpp_compiler): # msvc openmp: https://learn.microsoft.com/zh-cn/cpp/build/reference/openmp-enable-openmp-2-0-support?view=msvc-170 cflags.append("openmp") + cflags.append("openmp:experimental") # MSVC CL libs = [] else: if config.is_fbcode(): From 59b4983dc06f12eded69ab1471817c67c1fc72c0 Mon Sep 17 00:00:00 2001 From: Tristan Rice Date: Tue, 18 Jun 2024 03:40:14 +0000 Subject: [PATCH 121/171] DebugPlane: add dump_traceback handler (#128904) This adds a `dump_traceback` handler so you can see all running threads for a job. This uses a temporary file as a buffer when calling `faulthandler.dump_traceback` and requires the GIL to be held during dumping. Test plan: ``` python test/distributed/elastic/test_control_plane.py -v -k traceback ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/128904 Approved by: https://github.com/c-p-i-o --- build_variables.bzl | 1 + .../distributed/elastic/test_control_plane.py | 6 +++ .../c10d/control_plane/PythonHandlers.cpp | 44 +++++++++++++++++++ 3 files changed, 51 insertions(+) create mode 100644 torch/csrc/distributed/c10d/control_plane/PythonHandlers.cpp diff --git a/build_variables.bzl b/build_variables.bzl index b4b4d1ab139cd9..793b611a0a6f07 100644 --- a/build_variables.bzl +++ b/build_variables.bzl @@ -927,6 +927,7 @@ libtorch_python_distributed_sources = libtorch_python_distributed_core_sources + "torch/csrc/distributed/rpc/unpickled_python_call.cpp", "torch/csrc/distributed/rpc/unpickled_python_remote_call.cpp", "torch/csrc/jit/runtime/register_distributed_ops.cpp", + "torch/csrc/distributed/c10d/control_plane/PythonHandlers.cpp", ] def glob_libtorch_python_sources(gencode_pattern = ":generate-code[{}]"): diff --git a/test/distributed/elastic/test_control_plane.py b/test/distributed/elastic/test_control_plane.py index 775b062451b163..7d01bd9eb03005 100644 --- a/test/distributed/elastic/test_control_plane.py +++ b/test/distributed/elastic/test_control_plane.py @@ -92,6 +92,12 @@ def test_tcp(self) -> None: server.shutdown() + def test_dump_traceback(self) -> None: + with local_worker_server() as pool: + resp = pool.request("POST", "/handler/dump_traceback") + self.assertEqual(resp.status, 200) + self.assertIn(b"in test_dump_traceback\n", resp.data) + if __name__ == "__main__": run_tests() diff --git a/torch/csrc/distributed/c10d/control_plane/PythonHandlers.cpp b/torch/csrc/distributed/c10d/control_plane/PythonHandlers.cpp new file mode 100644 index 00000000000000..cc1539a9527b4b --- /dev/null +++ b/torch/csrc/distributed/c10d/control_plane/PythonHandlers.cpp @@ -0,0 +1,44 @@ +#include + +#include +#include +#include + +#include +#include +#include + +namespace c10d::control_plane { +namespace { + +RegisterHandler tracebackHandler{ + "dump_traceback", + [](const Request&, Response& res) { + auto tmpfile = c10::make_tempfile("torch-dump_traceback"); + + auto cfile = ::fopen(tmpfile.name.c_str(), "w"); + if (!cfile) { + throw std::runtime_error("failed to open file for writing"); + } + + { + py::gil_scoped_acquire guard{}; + + auto faulthandler = py::module::import("faulthandler"); + faulthandler.attr("dump_traceback")(fileno(cfile), true); + } + + ::fclose(cfile); + + std::ifstream file(tmpfile.name); + std::string str; + std::string file_contents; + while (std::getline(file, str)) { + file_contents += str; + file_contents.push_back('\n'); + } + + res.setContent(std::move(file_contents), "text/plain"); + }}; +} +} // namespace c10d::control_plane From d9eaa224f2512639e55cb11b372fcd1983d22ea5 Mon Sep 17 00:00:00 2001 From: Joona Havukainen Date: Tue, 18 Jun 2024 03:44:38 +0000 Subject: [PATCH 122/171] Fixes #128429: NaN in triu op on MPS (#128575) Fixes triu op when k > 0 and the lower triangle of the input tensor contains inf leading to NaNs in the computation through complement. Fixed by using select API instead. Fixes #128429 Pull Request resolved: https://github.com/pytorch/pytorch/pull/128575 Approved by: https://github.com/kulinseth --- .../ATen/native/mps/operations/TriangularOps.mm | 15 ++++++++++----- test/test_mps.py | 8 ++++++++ 2 files changed, 18 insertions(+), 5 deletions(-) diff --git a/aten/src/ATen/native/mps/operations/TriangularOps.mm b/aten/src/ATen/native/mps/operations/TriangularOps.mm index 5fa0b221845353..dcea978655b85e 100644 --- a/aten/src/ATen/native/mps/operations/TriangularOps.mm +++ b/aten/src/ATen/native/mps/operations/TriangularOps.mm @@ -35,11 +35,16 @@ if (k > 0) { MPSGraphTensor* diagMinusOneTensor = [mpsGraph constantWithScalar:(k - 1) dataType:MPSDataTypeInt32]; - MPSGraphTensor* complementTensor = [mpsGraph bandPartWithTensor:inputTensor - numLowerTensor:minusOneTensor - numUpperTensor:diagMinusOneTensor - name:nil]; - outputTensor = [mpsGraph subtractionWithPrimaryTensor:inputTensor secondaryTensor:complementTensor name:nil]; + MPSGraphTensor* onesTensor = [mpsGraph constantWithScalar:1 dataType:MPSDataTypeInt32]; + onesTensor = [mpsGraph broadcastTensor:onesTensor toShape:inputTensor.shape name:nil]; + MPSGraphTensor* maskTensor = [mpsGraph bandPartWithTensor:onesTensor + numLowerTensor:minusOneTensor + numUpperTensor:diagMinusOneTensor + name:nil]; + outputTensor = [mpsGraph selectWithPredicateTensor:maskTensor + truePredicateTensor:[mpsGraph constantWithScalar:0 dataType:inputTensor.dataType] + falsePredicateTensor:inputTensor + name:nil]; } else { MPSGraphTensor* minusDiagTensor = [mpsGraph constantWithScalar:(-k) dataType:MPSDataTypeInt32]; outputTensor = [mpsGraph bandPartWithTensor:inputTensor diff --git a/test/test_mps.py b/test/test_mps.py index 275013f20effcb..311cf8245c4f3a 100644 --- a/test/test_mps.py +++ b/test/test_mps.py @@ -1617,6 +1617,14 @@ def test_exp(self, device="mps", dtype=torch.float): a = torch.tensor(v, dtype=dtype, device="mps") * b self.compare_with_numpy(torch.exp, np.exp, a) + def test_triu_inf(self, device="mps", dtype=torch.float): + for diag in [-1, 0, 1]: + mask = torch.full((3, 6, 6), float("-inf")) + mask_mps = mask.clone().detach().to('mps') + cpu_ref = torch.triu(mask, diagonal=diag) + mps_out = torch.triu(mask_mps, diagonal=diag) + self.assertEqual(cpu_ref, mps_out) + def test_exp1(self, device="mps", dtype=torch.float): input = torch.tensor([-0.1, 1.0, -0.9, 0.1], device=device, dtype=dtype) output = torch.exp(input) From f7eae279463b719c5f25587aac225bd2be891373 Mon Sep 17 00:00:00 2001 From: Chirag Pandya Date: Mon, 17 Jun 2024 10:06:14 -0700 Subject: [PATCH 123/171] Pass params to dump_nccl_trace_pickle (#128781) Summary Pass parameters from request to dump_nccl_trace_pickle handler. The supported parameters + value are all lowercase. includecollectives={true, false} includestacktraces={true, false} onlyactive={true, false} Example post is: /handler/dump_nccl_trace_pickle?includecollectives=true&includestacktraces=false&onlyactive=true Test Plan: unit tests Differential Revision: [D58640474](https://our.internmc.facebook.com/intern/diff/D58640474) Pull Request resolved: https://github.com/pytorch/pytorch/pull/128781 Approved by: https://github.com/d4l3k --- .../distributed/elastic/test_control_plane.py | 37 ++++++++++++++ torch/csrc/distributed/c10d/NCCLUtils.cpp | 51 +++++++++++++++++++ .../distributed/c10d/ProcessGroupNCCL.cpp | 10 ---- .../c10d/control_plane/Handlers.hpp | 3 ++ .../c10d/control_plane/WorkerServer.cpp | 4 ++ 5 files changed, 95 insertions(+), 10 deletions(-) diff --git a/test/distributed/elastic/test_control_plane.py b/test/distributed/elastic/test_control_plane.py index 7d01bd9eb03005..9eb57952e2bdfc 100644 --- a/test/distributed/elastic/test_control_plane.py +++ b/test/distributed/elastic/test_control_plane.py @@ -80,6 +80,43 @@ def test_dump_nccl_trace_pickle(self) -> None: resp = pool.request("POST", "/handler/dump_nccl_trace_pickle") self.assertEqual(resp.status, 200) out = pickle.loads(resp.data) + self.assertIsInstance(out, dict) + self.assertIn("version", out) + + @requires_cuda + def test_dump_nccl_trace_pickle_with_params(self) -> None: + with local_worker_server() as pool: + # bad key - not lower case + resp = pool.request( + "POST", "/handler/dump_nccl_trace_pickle?includeCollectives=true" + ) + self.assertEqual(resp.status, 400) + # unknown key + resp = pool.request( + "POST", "/handler/dump_nccl_trace_pickle?unknownkey=true" + ) + self.assertEqual(resp.status, 400) + # bad value - not a bool + resp = pool.request( + "POST", "/handler/dump_nccl_trace_pickle?includecollectives=notabool" + ) + self.assertEqual(resp.status, 400) + # bad value - value not lowercase + resp = pool.request( + "POST", "/handler/dump_nccl_trace_pickle?includecollectives=True" + ) + self.assertEqual(resp.status, 400) + # good key and value + resp = pool.request( + "POST", "/handler/dump_nccl_trace_pickle?includecollectives=true" + ) + self.assertEqual(resp.status, 200) + # multiple good keys and values + resp = pool.request( + "POST", + "/handler/dump_nccl_trace_pickle?includecollectives=true&includestacktraces=false&onlyactive=true", + ) + self.assertEqual(resp.status, 200) def test_tcp(self) -> None: import requests diff --git a/torch/csrc/distributed/c10d/NCCLUtils.cpp b/torch/csrc/distributed/c10d/NCCLUtils.cpp index 6507fe6abc2a2b..d3a997625e1440 100644 --- a/torch/csrc/distributed/c10d/NCCLUtils.cpp +++ b/torch/csrc/distributed/c10d/NCCLUtils.cpp @@ -1,7 +1,10 @@ #include +#include +#include #include #include +#include #ifdef USE_C10D_NCCL #include @@ -238,6 +241,54 @@ std::string getNcclErrorDetailStr( return interpret + err; } +control_plane::RegisterHandler dumpHandler{ + "dump_nccl_trace_pickle", + [](const control_plane::Request& req, control_plane::Response& res) { + const auto params = req.params(); + size_t validParamCount = 0; + + // valid params + const std::string includeCollectivesStr = "includecollectives"; + const std::string includeStackTracesStr = "includestacktraces"; + const std::string onlyActiveStr = "onlyactive"; + + std::unordered_map expectedParams = { + {includeCollectivesStr, true}, + {includeStackTracesStr, true}, + {onlyActiveStr, false}}; + + for (const auto& [paramName, paramValue] : params) { + auto it = expectedParams.find(paramName); + if (it != expectedParams.end()) { + validParamCount++; + if (paramValue == "true") { + it->second = true; + } else if (paramValue == "false") { + it->second = false; + } else { + res.setStatus(400); + res.setContent( + "Invalid value for " + paramName + + " valid values are true or false", + "text/plain"); + return; + } + } + } + if (validParamCount < params.size()) { + res.setStatus(400); + res.setContent( + "Invalid parameters - unexpected param passed in", "text/plain"); + return; + } + res.setContent( + dump_nccl_trace( + expectedParams[includeCollectivesStr], + expectedParams[includeStackTracesStr], + expectedParams[onlyActiveStr]), + "application/octet-stream"); + }}; + } // namespace c10d #endif // USE_C10D_NCCL diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp index d293c4d470b837..06804a544a388d 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp @@ -29,7 +29,6 @@ #include #include #include -#include #include #include @@ -380,15 +379,6 @@ std::string dump_nccl_trace( } #endif -// TODO(c-p-i-o): add a JSON endpoint. -control_plane::RegisterHandler dumpHandler{ - "dump_nccl_trace_pickle", - [](const control_plane::Request& req, control_plane::Response& res) { - // TODO: c-p-i-o: params from the request need to go to dump_nccl_trace. - res.setContent( - dump_nccl_trace(true, true, false), "application/octet-stream"); - }}; - std::optional)>>& get_cpp_trace_dumper() { static std::optional< diff --git a/torch/csrc/distributed/c10d/control_plane/Handlers.hpp b/torch/csrc/distributed/c10d/control_plane/Handlers.hpp index 0c10630549312f..f230e7a4c0e47f 100644 --- a/torch/csrc/distributed/c10d/control_plane/Handlers.hpp +++ b/torch/csrc/distributed/c10d/control_plane/Handlers.hpp @@ -1,6 +1,7 @@ #pragma once #include +#include #include #include @@ -15,6 +16,8 @@ class TORCH_API Request { virtual ~Request() = default; virtual const std::string& body() = 0; + + virtual const std::multimap& params() const = 0; }; // Response represents a response to the handler. This conceptually maps to an diff --git a/torch/csrc/distributed/c10d/control_plane/WorkerServer.cpp b/torch/csrc/distributed/c10d/control_plane/WorkerServer.cpp index 947a281982f14f..0e9de35322abb4 100644 --- a/torch/csrc/distributed/c10d/control_plane/WorkerServer.cpp +++ b/torch/csrc/distributed/c10d/control_plane/WorkerServer.cpp @@ -23,6 +23,10 @@ class RequestImpl : public Request { return req_.body; } + const std::multimap& params() const override { + return req_.params; + } + private: const httplib::Request& req_; }; From e3a39d49a0b06399f074b30c4be6ef9670633185 Mon Sep 17 00:00:00 2001 From: Will Feng Date: Tue, 18 Jun 2024 06:22:14 +0000 Subject: [PATCH 124/171] [Traceable FSDP][Compiled Autograd] Add queue_callback() support (#126366) Adds support for `Variable._execution_engine.queue_callback()`, which is used in FSDP2. Important tests: - `pytest -rA test/inductor/test_compiled_autograd.py::TestCompiledAutograd::test_callback_graph_break_throws_error` - `pytest -rA test/inductor/test_compiled_autograd.py::TestAutogradWithCompiledAutograd::test_callback_adds_callback` - `PYTORCH_TEST_WITH_DYNAMO=1 python test/test_autograd.py -k TestAutograd.test_callback_adds_callback` Pull Request resolved: https://github.com/pytorch/pytorch/pull/126366 Approved by: https://github.com/xmfan --- test/inductor/test_compiled_autograd.py | 28 ++++++++++++++++- torch/_dynamo/compiled_autograd.py | 12 +++++++- torch/_dynamo/external_utils.py | 19 ++++++++++++ torch/_dynamo/side_effects.py | 12 ++++++++ torch/_dynamo/symbolic_convert.py | 2 ++ torch/_dynamo/variables/builder.py | 19 ++++++++++++ torch/_dynamo/variables/misc.py | 41 +++++++++++++++++++++++++ 7 files changed, 131 insertions(+), 2 deletions(-) diff --git a/test/inductor/test_compiled_autograd.py b/test/inductor/test_compiled_autograd.py index a3dfcb59f2fddb..91b8178ae6ccd5 100644 --- a/test/inductor/test_compiled_autograd.py +++ b/test/inductor/test_compiled_autograd.py @@ -1767,6 +1767,33 @@ def fn(inputs): out = compiled_fn(activations) self.assertTrue(len(activations) == 0) + def test_callback_graph_break_throws_error(self): + called = [0] + + def callback_final(): + called[0] += 1 + + class MyFunc(torch.autograd.Function): + @staticmethod + def forward(ctx, input): + return input + + @staticmethod + @torch.autograd.function.once_differentiable + def backward(ctx, grad): + torch.autograd.Variable._execution_engine.queue_callback(callback_final) + torch._dynamo.graph_break() + return grad + + a = torch.rand((3, 3), requires_grad=True) + with self.assertRaisesRegex( + AssertionError, + "only supported when Compiled Autograd is enabled with fullgraph=True", + ): + with compiled_autograd.enable(make_compiler_fn(fullgraph=False)): + b = MyFunc.apply(a) + b.sum().backward() + @unittest.skipIf(not HAS_CUDA, "requires cuda") def test_cudagraphs_cpu_division(self): from torch._dynamo.testing import reduce_to_scalar_loss @@ -2177,7 +2204,6 @@ def wrap_test_class(orig_cls): "test_autograd_multiple_views_python", # torch._dynamo.exc.Unsupported: call_function args: TensorVariable( "test_autograd_node_isinstance", # torch._dynamo.exc.Unsupported: 'inline in skipfiles: TestCase.assertIsInstance "test_autograd_simple_views_python", # torch._dynamo.exc.TorchRuntimeError: Failed running call_function - "test_callback_adds_callback", # torch._dynamo.exc.Unsupported: call_method UserDefinedObjectVariable "test_callback_propagates_errors_from_device_thread", # AssertionError: "blah" does not match "call_method "test_custom_autograd_no_early_free", # torch.autograd.gradcheck.GradcheckError: While computing batched gradients "test_custom_function_cycle", # torch._dynamo.exc.Unsupported: call_function UserDefinedClassVariable() [] {} diff --git a/torch/_dynamo/compiled_autograd.py b/torch/_dynamo/compiled_autograd.py index 2570278ef4788c..e72cf40d65ded8 100644 --- a/torch/_dynamo/compiled_autograd.py +++ b/torch/_dynamo/compiled_autograd.py @@ -4,7 +4,11 @@ from typing import Dict, List, Optional, TYPE_CHECKING import torch -from torch._dynamo.external_utils import call_backward, call_hook +from torch._dynamo.external_utils import ( + call_backward, + call_hook, + FakeCompiledAutogradEngine, +) from torch._dynamo.source import GetItemSource, LocalSource from torch._dynamo.utils import counters, lazy_format_graph_code, set_locals_to_steal from torch._logging import getArtifactLogger, trace_structured @@ -255,6 +259,12 @@ def move_graph_nodes_to_cuda(self, graph) -> List[int]: return [] def end_capture(self, outputs): + self.fx_tracer.create_proxy( + "call_function", + FakeCompiledAutogradEngine._exec_final_callbacks_stub, + (), + {}, + ) self.stack.close() self.fx_tracer.create_node( "output", diff --git a/torch/_dynamo/external_utils.py b/torch/_dynamo/external_utils.py index caea92bc6be082..7d3b0fc6ada43c 100644 --- a/torch/_dynamo/external_utils.py +++ b/torch/_dynamo/external_utils.py @@ -98,6 +98,25 @@ def untyped_storage_size(x: torch.Tensor): return x.untyped_storage().size() +class FakeCompiledAutogradEngine: + @staticmethod + def queue_callback(final_callbacks, cb): + final_callbacks.append(cb) + + @staticmethod + def exec_final_callbacks(final_callbacks): + i = 0 + while i < len(final_callbacks): + cb = final_callbacks[i] + cb() + i += 1 + final_callbacks.clear() + + @staticmethod + def _exec_final_callbacks_stub(): + pass + + def call_hook_from_backward_state(*args, bw_state, hook_name: str, **kwargs): return getattr(bw_state, hook_name)(*args, **kwargs) diff --git a/torch/_dynamo/side_effects.py b/torch/_dynamo/side_effects.py index 28ce9811b4c384..5689fa0977db87 100644 --- a/torch/_dynamo/side_effects.py +++ b/torch/_dynamo/side_effects.py @@ -89,6 +89,9 @@ def __init__( self.keepalive = keepalive or [] self.save_for_backward = save_for_backward or [] self.tensor_hooks = tensor_hooks or {} + # Track Compiled Autograd final callbacks that must be called at the end of Compiled Autograd backward graph. + # Only applicable if this graph is created from Dynamo tracing in Compiled Autograd. + self.ca_final_callbacks_var = None def __eq__(self, other: object) -> bool: assert isinstance(other, SideEffects) @@ -476,6 +479,15 @@ def codegen_hooks(self, cg): # be associated with the return value of register_hook(). This consumes the top of stack. cg.add_cache(handle) + def get_ca_final_callbacks_var(self): + from .variables.base import MutableLocal + + if self.ca_final_callbacks_var is None: + self.ca_final_callbacks_var = variables.ListVariable( + [], mutable_local=MutableLocal() + ) + return self.ca_final_callbacks_var + def codegen_update_mutated(self, cg: PyCodegen): suffixes = [] for var in self._get_modified_vars(): diff --git a/torch/_dynamo/symbolic_convert.py b/torch/_dynamo/symbolic_convert.py index 7e129a05a09051..6105ae466b012b 100644 --- a/torch/_dynamo/symbolic_convert.py +++ b/torch/_dynamo/symbolic_convert.py @@ -2305,6 +2305,7 @@ def __init__( self.nn_module_stack: Dict[str, Tuple[str, Type[Any]]] = {} # Flag to indicate whether tracing is used for export. self.export = export + self.one_graph = False self.current_speculation = None @@ -2860,6 +2861,7 @@ def __init__( self.symbolic_result = None self.closure_cells = closure_cells self.nn_module_stack = parent.nn_module_stack.copy() + self.one_graph = parent.one_graph @property def fake_mode(self): diff --git a/torch/_dynamo/variables/builder.py b/torch/_dynamo/variables/builder.py index 2097690b88b036..af91edb432c887 100644 --- a/torch/_dynamo/variables/builder.py +++ b/torch/_dynamo/variables/builder.py @@ -129,6 +129,7 @@ CollectiveFunctionRewriteVariable, FunctoolsPartialVariable, TritonKernelVariable, + UserFunctionVariable, UserMethodVariable, ) from .higher_order_ops import TorchHigherOrderOperatorVariable @@ -146,6 +147,7 @@ TupleVariable, ) from .misc import ( + AutogradEngineVariable, AutogradFunctionContextVariable, AutogradFunctionVariable, ComptimeVariable, @@ -726,6 +728,23 @@ def build_key_value(i, k, v): ), "apply", ) + elif isinstance(value, torch._C._ImperativeEngine): + self.install_guards(GuardBuilder.ID_MATCH) + return AutogradEngineVariable(value, source=self.source) + elif ( + value + is torch._dynamo.external_utils.FakeCompiledAutogradEngine._exec_final_callbacks_stub + ): + self.install_guards(GuardBuilder.FUNCTION_MATCH) + return LambdaVariable( + lambda: UserFunctionVariable( + torch._dynamo.external_utils.FakeCompiledAutogradEngine.exec_final_callbacks, + ).call_function( + self.tx, + (self.tx.output.side_effects.get_ca_final_callbacks_var(),), + {}, + ) + ) elif callable(value) and trace_rules.lookup_callable(value) is not None: if is_callable_allowed(value): self.tx.output.has_user_defined_allowed_in_graph = True diff --git a/torch/_dynamo/variables/misc.py b/torch/_dynamo/variables/misc.py index 179bb9a52bf98a..0e54e0f613a349 100644 --- a/torch/_dynamo/variables/misc.py +++ b/torch/_dynamo/variables/misc.py @@ -643,6 +643,47 @@ def var_getattr(self, tx, name): return super().var_getattr(tx, name) +class AutogradEngineVariable(UserDefinedObjectVariable): + """ + Represents a torch._C._ImperativeEngine instance. + """ + + def __init__( + self, + value, + value_type=None, + **kwargs, + ): + super().__init__(value=value, value_type=value_type, **kwargs) + + def call_method( + self, + tx, + name, + args: "List[VariableTracker]", + kwargs: "Dict[str, VariableTracker]", + ) -> "VariableTracker": + if name == "queue_callback": + if torch._dynamo.compiled_autograd.compiled_autograd_enabled: + assert ( + tx.one_graph + ), "queue_callback() is only supported when Compiled Autograd is enabled with fullgraph=True" + return variables.UserFunctionVariable( + torch._dynamo.external_utils.FakeCompiledAutogradEngine.queue_callback, + source=self.source, + ).call_function( + tx, + (tx.output.side_effects.get_ca_final_callbacks_var(), *args), + kwargs, + ) + else: + unimplemented( + "queue_callback() is only supported when Compiled Autograd is enabled with fullgraph=True" + ) + else: + unimplemented(f"torch._C._ImperativeEngine method: {name}") + + class LambdaVariable(VariableTracker): def __init__(self, fn, **kwargs): super().__init__(**kwargs) From 60baeee59f7a6ff610c42411bf2709d2bbd5bd2c Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Mon, 17 Jun 2024 13:31:23 -0700 Subject: [PATCH 125/171] [BE] Skip the test if CUDA is not available (#128885) As title Differential Revision: [D58690210](https://our.internmc.facebook.com/intern/diff/D58690210/) Pull Request resolved: https://github.com/pytorch/pytorch/pull/128885 Approved by: https://github.com/wz337 --- test/distributed/_tensor/debug/test_comm_mode.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/test/distributed/_tensor/debug/test_comm_mode.py b/test/distributed/_tensor/debug/test_comm_mode.py index 5483b3171f3092..bd862220b210d3 100644 --- a/test/distributed/_tensor/debug/test_comm_mode.py +++ b/test/distributed/_tensor/debug/test_comm_mode.py @@ -116,6 +116,9 @@ def f(x, y): @requires_nccl() def test_comm_mode_with_c10d(self): + if not torch.cuda.is_available(): + return + world_pg = self.world_pg inp = torch.rand(2, 8, 16).cuda() From 6e43897912d149d7dad676f496c608fe32a31978 Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Mon, 17 Jun 2024 09:40:54 -0700 Subject: [PATCH 126/171] [BE][ptd_fb_test][3/N] Enable TestSlide for MultiThreadedTestCase (#128843) Enabling testslide for MultiThreadedTestCase, similar to https://github.com/pytorch/pytorch/pull/127512. Differential Revision: [D58677457](https://our.internmc.facebook.com/intern/diff/D58677457/) Pull Request resolved: https://github.com/pytorch/pytorch/pull/128843 Approved by: https://github.com/wz337 --- torch/testing/_internal/common_distributed.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/torch/testing/_internal/common_distributed.py b/torch/testing/_internal/common_distributed.py index 473e5c35e07a17..a0a3429797c284 100644 --- a/torch/testing/_internal/common_distributed.py +++ b/torch/testing/_internal/common_distributed.py @@ -994,10 +994,14 @@ def wrapper(self): return types.MethodType(wrapper, self) - def __init__(self, method_name: str = "runTest") -> None: + def __init__(self, method_name: str = "runTest", methodName: str = "runTest") -> None: + # methodName is the correct naming in unittest and testslide uses keyword arguments. + # So we need to use both to 1) not break BC and, 2) support testslide. + if methodName != "runTest": + method_name = methodName super().__init__(method_name) - test_fn = getattr(self, method_name, None) - setattr(self, method_name, self.join_or_run(test_fn)) + fn = getattr(self, method_name) + setattr(self, method_name, self.join_or_run(fn)) def perThreadSetUp(self): # super().setUp() # TestCase.setUp() calls torch.manual_seed() From 304c9345726e68c9bbd0ea370b3c056db6964c4b Mon Sep 17 00:00:00 2001 From: leslie-fang-intel Date: Sat, 15 Jun 2024 17:15:38 -0700 Subject: [PATCH 127/171] Move MKLDNN Specific IR to Separate File (#126504) **Summary** Following the discussion in https://github.com/pytorch/pytorch/pull/122593#discussion_r1604144782, Move Inductor MKLDNN specific IRs to a separate file. Co-authored-by: Isuru Fernando Pull Request resolved: https://github.com/pytorch/pytorch/pull/126504 Approved by: https://github.com/desertfire, https://github.com/jgong5 ghstack dependencies: #126841, #126940 --- torch/_inductor/ir.py | 1632 -------------------------- torch/_inductor/mkldnn_ir.py | 1659 +++++++++++++++++++++++++++ torch/_inductor/mkldnn_lowerings.py | 26 +- 3 files changed, 1672 insertions(+), 1645 deletions(-) create mode 100644 torch/_inductor/mkldnn_ir.py diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py index 9e1c90e9953783..898eed09268eae 100644 --- a/torch/_inductor/ir.py +++ b/torch/_inductor/ir.py @@ -81,7 +81,6 @@ get_kernel_metadata, is_dynamic, is_gpu, - pad_listlike, sympy_dot, sympy_index_symbol, sympy_index_symbol_with_prefix, @@ -5792,1637 +5791,6 @@ def get_inputs_that_alias_output(self): ] -def _prepare_convolution_fusion_create( - cls, - x: "TensorBox", - weight: "TensorBox", - bias: "TensorBox", - padding: List[int], - stride: List[int], - dilation: List[int], - groups: int, - transposed: bool = False, - output_padding: Optional[List[int]] = None, -): - """ - This function is a helper function to prepare inputs, layout and constant args - for convolution post-op fusion's create function, including deciding the output - layout (channels first or channels last), realizing inputs and make them etc. The - function only supports the CPU device since conv post-op fusion kernel is only - supported on CPU right now. - """ - - # Port from aten/src/ATen/native/ConvUtils.h: _conv_input_size - def _conv_input_size( - output_size, weight_size, padding, output_padding, stride, dilation, groups - ): - assert len(output_size) == len(weight_size), "Expect input dim == weight dim" - dim = len(output_size) - assert dim > 2, "Expect input dim > 2" - - BATCH_DIM = 0 - WEIGHT_INPUT_CHANNELS_DIM = 1 - input_size = [] - input_size.append(output_size[BATCH_DIM]) - input_size.append(weight_size[WEIGHT_INPUT_CHANNELS_DIM] * groups) - for d in range(2, dim): - kernel = (weight_size[d] - 1) * dilation[d - 2] + 1 - input_size_d = ( - (output_size[d] - 1) * stride[d - 2] - - (padding[d - 2] * 2) - + kernel - + output_padding[d - 2] - ) - input_size.append(input_size_d) - return list(map(int, input_size)) - - # The size of prepacked_weight is the prepacked weight size of deconv: - # Groups > 1: [g*o, i/g, ...] - # Groups == 1: [o, i, ...] - # Returns original weight size in [i, o, ...] - def _original_deconv_weight_size( - prepacked_weight, - groups, - ): - prepacked_weight_size = prepacked_weight.size() - dim = len(prepacked_weight_size) - assert dim > 2, "Expect weight dim > 2" - if groups > 1: - weight_size = [] - weight_size.append(prepacked_weight_size[1] * groups) - weight_size.append(prepacked_weight_size[0] / groups) - for d in range(2, dim): - weight_size.append(prepacked_weight_size[d]) - else: - weight_size = prepacked_weight.transpose(0, 1).size() - return weight_size - - x.realize() - weight.realize() - if bias is not None: - bias.realize() - with V.graph.fake_mode: - # TODO cleaned up the fake_tensor trace as Linear implementation - x_fake = ir_node_to_tensor(x, guard_shape=True) - weight_fake = ir_node_to_tensor(weight, guard_shape=True) - dims = len(x_fake.size()) - 2 - assert 0 < len(padding) <= dims - assert 0 < len(dilation) <= dims - assert 0 < len(stride) <= dims - padding = pad_listlike(padding, dims) - dilation = pad_listlike(dilation, dims) - stride = pad_listlike(stride, dims) - if output_padding is None: - output_padding = pad_listlike([0], dims) - else: - assert 0 < len(output_padding) <= dims - output_padding = pad_listlike(output_padding, dims) - assert isinstance(groups, int) - if transposed: - # When transposed, the size of the prepacked oneDNN weight is different - # from the PyTorch weight. We're not able to run aten conv with such - # size. We infer the output size from the input params here: - weight_size = _original_deconv_weight_size(weight_fake, groups) - input_size = x_fake.size() - output_size = _conv_input_size( - input_size, - weight_size, - padding, - output_padding, - stride, - dilation, - groups, - ) - else: - bias_fake = ( - ir_node_to_tensor(bias, guard_shape=True) if bias is not None else bias - ) - output = torch.ops.aten.convolution( - x_fake, - weight_fake, - bias_fake, - stride, - padding, - dilation, - transposed, - output_padding, - groups, - ) - output_size = output.size() - - req_stride_order = [0] + list(reversed(range(1, len(stride) + 1))) - req_stride_order = [len(req_stride_order)] + req_stride_order - - x = cls.require_stride_order(x, req_stride_order) - - # We won't do weight prepack for Conv if dynamic_shapes. - # In static shape cases, since weight is prepacked, we'll always force output to be channels last in the Conv kernel. - # In dynamic shape cases, for input with channels = 1, like tensor of size (s0, 1, 28, 28) and stride (784, 784, 28, 1), - # x = cls.require_stride_order(x, req_stride_order) where req_stride_order is in the channels last order - # won't change the stride of this tensor since stride for dimensions of size 1 is ignored. While in Conv kernel, - # this tensor is considered as channels first and the output will be in contiguous format. - # To align the behavior of the Conv kernel, we set the output_stride in such case to be contiguous instead of channels last. - dynamic_shapes = not all(isinstance(i, int) for i in (output_size)) - if dynamic_shapes and is_contiguous_storage_and_layout(x): - output_stride = FlexibleLayout.contiguous_strides(output_size) - else: - output_stride = make_channels_last_strides_for(output_size) - - assert x.get_device().type == "cpu" and weight.get_device().type == "cpu" - inputs = [x, weight] - - kernel_layout = FixedLayout( - x.get_device(), - x.get_dtype(), - convert_shape_to_inductor(output_size), - convert_shape_to_inductor(output_stride), - ) - constant_args = [padding, stride, dilation, groups] - if transposed: - constant_args.insert(1, output_padding) - - if bias is not None: - inputs.append(bias) - else: - constant_args.insert(0, bias) - return inputs, constant_args, kernel_layout, req_stride_order - - -def _prepare_linear_fusion_create( - cls, - x: "TensorBox", - weight: "TensorBox", - bias: "TensorBox", -): - """ - This function is a helper function to prepare inputs, layout and constant args - for linear post-op fusion's create function. The function only supports the CPU device - since linear post-op fusion kernel is only supported on CPU right now. - """ - x.realize() - weight.realize() - if bias is not None: - bias.realize() - - *m, _ = x.get_size() - # The weight has been transposed during the qlinear weight prepack process. - # https://github.com/pytorch/pytorch/blob/4979f9c0d72490970e2019bb1d2284f83d93f76b/ - # aten/src/ATen/native/quantized/cpu/qlinear_prepack.cpp#L291 - _, oc = weight.get_size() - output_size = list(m) + [oc] - req_stride_order = list(reversed(range(len(x.get_size())))) - - x = cls.require_stride_order(x, req_stride_order) - assert x.get_device().type == "cpu" and weight.get_device().type == "cpu" - inputs = [x, weight] - - output_stride = FlexibleLayout.contiguous_strides(output_size) - kernel_layout = FixedLayout( - x.get_device(), - x.get_dtype(), - output_size, - output_stride, - ) - constant_args: List[Any] = [] - - if bias is not None: - inputs.append(bias) - else: - constant_args.insert(0, bias) - return inputs, constant_args, kernel_layout, req_stride_order - - -class ConvolutionUnary(ExternKernelAlloc): - def __init__( - self, - layout, - inputs, - constant_args=(), - ): - super().__init__( - layout, - inputs, - constant_args, - None, - python_kernel_name="torch.ops.mkldnn._convolution_pointwise", - cpp_kernel_name="mkldnn::_convolution_pointwise", - ) - self.cpp_kernel_key = "convolution_pointwise" - self.cpp_op_schema = """ - at::Tensor( - const at::Tensor& input_t, - const at::Tensor& weight_t, - const c10::optional& bias_opt, - at::IntArrayRef padding, - at::IntArrayRef stride, - at::IntArrayRef dilation, - int64_t groups, - c10::string_view attr, - torch::List> scalars, - c10::optional algorithm)""" - - def codegen(self, wrapper): - wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed( - self.get_name(), - self.python_kernel_name, - self.cpp_kernel_name, - self.codegen_args(), - self.cpp_op_schema, - self.cpp_kernel_key, - ) - if isinstance(self.layout, Layout): - self.codegen_size_asserts(wrapper) - - @classmethod - def create( - cls, - x: "TensorBox", - weight: "TensorBox", - bias: "TensorBox", - padding_: List[int], - stride_: List[int], - dilation_: List[int], - groups: int, - attr, - scalars: Optional[List[Any]], - algorithm, - ): - (inputs, constant_args, kernel_layout, _) = _prepare_convolution_fusion_create( - cls, x, weight, bias, padding_, stride_, dilation_, groups - ) - constant_args = constant_args + [ - attr, - may_convert_to_optional(scalars), - algorithm, - ] - return ConvolutionUnary( - layout=kernel_layout, - inputs=inputs, - constant_args=constant_args, - ) - - -class ConvolutionBinary(ExternKernelAlloc): - def __init__( - self, - layout, - inputs, - constant_args=(), - cpp_constant_args=(), - ): - super().__init__( - layout, - inputs, - constant_args, - None, - python_kernel_name="torch.ops.mkldnn._convolution_pointwise.binary", - cpp_kernel_name="mkldnn::_convolution_pointwise", - ) - self.cpp_kernel_overload_name = "binary" - self.cpp_kernel_key = "convolution_pointwise_binary" - self.cpp_op_schema = """ - at::Tensor( - const at::Tensor& input_t, - const at::Tensor& other_t, - const at::Tensor& weight_t, - const c10::optional& bias_opt, - at::IntArrayRef padding, - at::IntArrayRef stride, - at::IntArrayRef dilation, - int64_t groups, - c10::string_view binary_attr, - c10::optional alpha, - c10::optional unary_attr, - torch::List> unary_scalars, - c10::optional unary_algorithm)""" - self.cpp_constant_args = cpp_constant_args - - def codegen(self, wrapper): - wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed( - self.get_name(), - self.python_kernel_name, - self.cpp_kernel_name, - self.codegen_args(), - self.cpp_op_schema, - self.cpp_kernel_key, - self.cpp_kernel_overload_name, - ) - if isinstance(self.layout, Layout): - self.codegen_size_asserts(wrapper) - - @classmethod - def create( - cls, - x: "TensorBox", - other: "TensorBox", - weight: "TensorBox", - bias: "TensorBox", - padding_: List[int], - stride_: List[int], - dilation_: List[int], - groups: int, - binary_attr: str, - binary_alpha: Optional[float], - unary_attr: Optional[str], - unary_scalars: Optional[List[Any]], - unary_algorithm: Optional[str], - ): - ( - inputs, - constant_args, - kernel_layout, - req_stride_order, - ) = _prepare_convolution_fusion_create( - cls, x, weight, bias, padding_, stride_, dilation_, groups - ) - other = cls.require_stride_order(other, req_stride_order) - inputs.insert(1, other) - constant_args = constant_args + [ - binary_attr, - binary_alpha, - unary_attr, - may_convert_to_optional(unary_scalars), - unary_algorithm, - ] - return ConvolutionBinary( - layout=kernel_layout, - inputs=inputs, - constant_args=constant_args, - ) - - -class ConvolutionBinaryInplace(ExternKernelAlloc): - def __init__( - self, - kernel_layout, - inputs, - constant_args=(), - ): - # Due to constrain of op.call, other (Tensor&) should be at input[0] - reordered_inputs = [inputs[1], inputs[0]] + inputs[2:] - - super().__init__( - kernel_layout, - reordered_inputs, - constant_args, - None, - python_kernel_name="torch.ops.mkldnn._convolution_pointwise_.binary", - cpp_kernel_name="mkldnn::_convolution_pointwise_", - ) - self.cpp_kernel_overload_name = "binary" - self.cpp_kernel_key = "convolution_pointwise_binary_" - # TODO: op.call: input[0] should be at::Tensor& - self.cpp_op_schema = """ - at::Tensor&( - at::Tensor& other_t, - const at::Tensor& input_t, - const at::Tensor& weight_t, - const c10::optional& bias_opt, - at::IntArrayRef padding, - at::IntArrayRef stride, - at::IntArrayRef dilation, - int64_t groups, - c10::string_view binary_attr, - c10::optional alpha, - c10::optional unary_attr, - torch::List> unary_scalars, - c10::optional unary_algorithm)""" - - def codegen(self, wrapper): - wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed( - self.get_name(), - self.python_kernel_name, - self.cpp_kernel_name, - self.codegen_args(), - self.cpp_op_schema, - self.cpp_kernel_key, - self.cpp_kernel_overload_name, - ) - - def get_mutation_names(self): - return [self.inputs[0].get_name()] - - def get_unbacked_symbol_defs(self) -> Set[sympy.Symbol]: - return set() - - @classmethod - def create( - cls, - x: "TensorBox", - other: "TensorBox", - weight: "TensorBox", - bias: "TensorBox", - padding_: List[int], - stride_: List[int], - dilation_: List[int], - groups: int, - binary_attr: str, - binary_alpha: Optional[float], - unary_attr: Optional[str], - unary_scalars: Optional[List[Any]], - unary_algorithm: Optional[str], - ): - ( - inputs, - constant_args, - _, - req_stride_order, - ) = _prepare_convolution_fusion_create( - cls, x, weight, bias, padding_, stride_, dilation_, groups - ) - other = cls.require_stride_order(other, req_stride_order) - inputs.insert(1, other) - constant_args = constant_args + [ - binary_attr, - binary_alpha, - unary_attr, - may_convert_to_optional(unary_scalars), - unary_algorithm, - ] - packed = ConvolutionBinaryInplace( - kernel_layout=NoneLayout(inputs[1].get_device()), # type: ignore[arg-type] - inputs=inputs, - constant_args=constant_args, - ) - mark_node_as_mutating(packed, inputs[1]) - # This op mutates in place which means that the result is not the - # target but rather the input that is being mutated - # init reorders the inputs, so inputs[1] becomes packed.inputs[0] - return packed.inputs[0] - - -class MKLPackedLinear(ExternKernelAlloc): - def __init__( - self, - layout, - inputs, - constant_args=(), - ): - super().__init__( - layout, - inputs, - constant_args, - None, - python_kernel_name="torch.ops.mkl._mkl_linear", - cpp_kernel_name="mkl::_mkl_linear", - ) - self.cpp_kernel_key = "mkl_linear" - self.cpp_op_schema = """ - at::Tensor( - const at::Tensor& self, - const at::Tensor& mkl_weight_t, - const at::Tensor& origin_weight_t, - const c10::optional& bias_opt, - const int64_t prepack_batch_size)""" - - def codegen(self, wrapper): - wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed( - self.get_name(), - self.python_kernel_name, - self.cpp_kernel_name, - self.codegen_args(), - self.cpp_op_schema, - self.cpp_kernel_key, - ) - - @classmethod - def create(cls, x, packed_w, orig_w, B, batch_size): - x = cls.require_stride1(cls.realize_input(x)) - orig_w = cls.require_stride1(cls.realize_input(orig_w)) - *m, _ = x.get_size() - oc, _ = orig_w.get_size() - output_size = list(m) + [oc] - output_stride = FlexibleLayout.contiguous_strides(output_size) - inputs = [x, packed_w, orig_w] - constant_args = [batch_size] - if B is not None: - inputs += [B] - else: - constant_args.insert(0, None) - - return MKLPackedLinear( - layout=FixedLayout( - x.get_device(), x.get_dtype(), output_size, output_stride - ), - inputs=inputs, - constant_args=constant_args, - ) - - -class LinearUnary(ExternKernelAlloc): - def __init__( - self, - layout, - inputs, - constant_args=(), - ): - super().__init__( - layout, - inputs, - constant_args, - None, - python_kernel_name="torch.ops.mkldnn._linear_pointwise", - cpp_kernel_name="mkldnn::_linear_pointwise", - ) - self.cpp_kernel_key = "linear_pointwise" - self.cpp_op_schema = """ - at::Tensor( - const at::Tensor& input_t, - const at::Tensor& weight_t, - const c10::optional& bias_opt, - c10::string_view attr, - torch::List> scalars, - c10::optional algorithm)""" - - def codegen(self, wrapper): - wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed( - self.get_name(), - self.python_kernel_name, - self.cpp_kernel_name, - self.codegen_args(), - self.cpp_op_schema, - self.cpp_kernel_key, - ) - - @classmethod - def create(cls, x, w, B, attr, scalars, algorithm): - x = cls.require_contiguous(cls.realize_input(x)) - w = cls.require_contiguous(cls.realize_input(w)) - - *m, ic = x.get_size() - oc, ic = w.get_size() - inputs = [x, w] - constant_args = [attr, scalars if scalars else [-1], algorithm] - if B is not None: - B = cls.require_contiguous(cls.realize_input(B)) - inputs.append(B) - else: - constant_args.insert(0, None) - - return LinearUnary( - layout=FlexibleLayout( - device=x.get_device(), - dtype=x.get_dtype(), - size=list(m) + [oc], - ), - inputs=inputs, - constant_args=constant_args, - ) - - def apply_constraint(self): - pass - - -class LinearBinary(ExternKernelAlloc): - kernel = "torch.ops.mkldnn._linear_pointwise.binary" - - def __init__( - self, - layout, - inputs, - constant_args=(), - ): - super().__init__( - layout, - inputs, - constant_args, - None, - python_kernel_name="torch.ops.mkldnn._linear_pointwise.binary", - cpp_kernel_name="mkldnn::_linear_pointwise", - ) - self.cpp_kernel_overload_name = "binary" - self.cpp_kernel_key = "linear_pointwise_binary" - self.cpp_op_schema = """ - at::Tensor( - const at::Tensor& input_t, - const at::Tensor& other_t, - const at::Tensor& weight_t, - const c10::optional& bias_opt, - c10::string_view attr) - """ - - def codegen(self, wrapper): - wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed( - self.get_name(), - self.python_kernel_name, - self.cpp_kernel_name, - self.codegen_args(), - self.cpp_op_schema, - self.cpp_kernel_key, - self.cpp_kernel_overload_name, - ) - - @classmethod - def create(cls, x, y, w, B, attr): - x = cls.require_contiguous(cls.realize_input(x)) - y = cls.require_contiguous(cls.realize_input(y)) - w = cls.require_contiguous(cls.realize_input(w)) - - *m, ic = x.get_size() - oc, ic = w.get_size() - - inputs = [x, y, w] - constant_args = [attr] - if B is not None: - B = cls.require_contiguous(cls.realize_input(B)) - inputs.append(B) - else: - constant_args.insert(0, B) - - return LinearBinary( - layout=FlexibleLayout( - device=x.get_device(), - dtype=x.get_dtype(), - size=list(m) + [oc], - ), - inputs=inputs, - constant_args=constant_args, - ) - - def apply_constraint(self): - pass - - -class ConvolutionTransposeUnary(ExternKernelAlloc): - def __init__( - self, - layout, - inputs, - constant_args=(), - ): - super().__init__( - layout, - inputs, - constant_args, - None, - python_kernel_name="torch.ops.mkldnn._convolution_transpose_pointwise", - cpp_kernel_name="mkldnn::_convolution_transpose_pointwise", - ) - self.cpp_kernel_key = "convolution_transpose_pointwise" - self.cpp_op_schema = """ - at::Tensor( - const at::Tensor& input_t, - const at::Tensor& weight_t, - const c10::optional& bias_opt, - at::IntArrayRef padding, - at::IntArrayRef output_padding, - at::IntArrayRef stride, - at::IntArrayRef dilation, - int64_t groups, - c10::string_view attr, - torch::List> scalars, - c10::optional algorithm)""" - - def codegen(self, wrapper): - wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed( - self.get_name(), - self.python_kernel_name, - self.cpp_kernel_name, - self.codegen_args(), - self.cpp_op_schema, - self.cpp_kernel_key, - ) - - @classmethod - def create( - cls, - x: "TensorBox", - weight: "TensorBox", - bias: "TensorBox", - padding_: List[int], - output_padding_: List[int], - stride_: List[int], - dilation_: List[int], - groups_: int, - attr, - scalars: Optional[List[Any]], - algorithm, - ): - transposed = True - ( - inputs, - constant_args, - kernel_layout, - _, - ) = _prepare_convolution_fusion_create( - cls, - x, - weight, - bias, - padding_, - stride_, - dilation_, - groups_, - transposed, - output_padding_, - ) - constant_args = constant_args + [ - attr, - may_convert_to_optional(scalars), - algorithm, - ] - return ConvolutionTransposeUnary( - layout=kernel_layout, - inputs=inputs, - constant_args=constant_args, - ) - - -class MkldnnRnnLayer(ExternKernelAlloc): - def __init__( - self, - layout, - inputs, - constant_args=(), - ): - super().__init__( - layout, - inputs, - constant_args, - None, - python_kernel_name="aten.mkldnn_rnn_layer", - cpp_kernel_name="at::mkldnn_rnn_layer", - ) - - @classmethod - def create( - cls, - x: "TensorBox", - w0: "TensorBox", - w1: "TensorBox", - w2: "TensorBox", - w3: "TensorBox", - hx: "TensorBox", - cx: "TensorBox", - reverse: bool, - batch_sizes: List[int], - mode: int, - hidden_size: int, - num_layers: int, - has_biases: bool, - bidirectional: bool, - batch_first: bool, - train: bool, - ): - x = cls.require_stride1(cls.realize_input(x)) - # If batch_first, x has been permuted in lstm before entering the mkldnn_rnn_layer. - # Make sure x is contiguous in batch_first case. - x.freeze_layout() - w0 = cls.require_stride1(cls.realize_input(w0)) - w1 = cls.require_stride1(cls.realize_input(w1)) - w2 = cls.require_stride1(cls.realize_input(w2)) - w3 = cls.require_stride1(cls.realize_input(w3)) - hx = cls.require_stride1(cls.realize_input(hx)) - hx.freeze_layout() - cx = cls.require_stride1(cls.realize_input(cx)) - cx.freeze_layout() - - input_size = x.get_size() - assert len(input_size) == 3, "Expect lstm input to be 3D" - # batch_first is handled in the lstm OP. When entering - # rnn_layer here, we'll always have batch_first = False - seq_length, mini_batch, input_size = input_size - output_shape = [seq_length, mini_batch, hidden_size] - - hy_shape = hx.get_size() - cy_shape = cx.get_size() - - res: List[IRNode] = [] - - inputs = [x, w0, w1, w2, w3, hx, cx] - constant_args = [ - reverse, - batch_sizes, - mode, - hidden_size, - num_layers, - has_biases, - bidirectional, - batch_first, - train, - ] - - packed = MkldnnRnnLayer( - MultiOutputLayout(x.get_device()), - inputs=inputs, - constant_args=constant_args, - ) - - def get_strides_of_lstm_output(output_shape, batch_first): - assert len(output_shape) == 3, "Expect output_shape to be 3D" - return FlexibleLayout.contiguous_strides(output_shape) - - output_sizes = [output_shape, hy_shape, cy_shape] - output_strides = [ - get_strides_of_lstm_output(output_shape, batch_first), - FlexibleLayout.contiguous_strides(hy_shape), - FlexibleLayout.contiguous_strides(cy_shape), - ] - output_ir = [ - MultiOutput( - FixedLayout( - x.get_device(), - x.get_dtype(), - output_size, - output_stride, - ), - packed, - [(tuple, i)], - ) - for i, (output_size, output_stride) in enumerate( - zip(output_sizes, output_strides) - ) - ] - - return output_ir - - -class QConvPointWisePT2E(ExternKernelAlloc): - def __init__( - self, - layout, - inputs, - constant_args=(), - ): - """ - if bias is not None - - inputs = [x, w, b, weight_scale, weight_zp] - - const_args is: [stride, padding, dilation, groups, x_scale, x_zp, o_inv_scale, o_zp, - fp32_output, unary_attr, unary_scalars, unary_algorithm] - else - - inputs = [x, w, weight_scale, weight_zp] - - const_args is: [bias, stride, padding, dilation, groups, x_scale, x_zp, o_inv_scale, o_zp, - fp32_output, unary_attr, unary_scalars, unary_algorithm] - """ - self.has_bias = len(inputs) == 5 - super().__init__( - layout, - inputs, - constant_args, - None, - python_kernel_name="torch.ops.onednn.qconv2d_pointwise", - cpp_kernel_name="onednn::qconv2d_pointwise", - ) - self.cpp_kernel_key = "qconv2d_pointwise" - self.cpp_op_schema = """ - at::Tensor( - at::Tensor act, - double act_scale, - int64_t act_zero_point, - at::Tensor weight, - at::Tensor weight_scales, - at::Tensor weight_zero_points, - c10::optional bias, - torch::List stride, - torch::List padding, - torch::List dilation, - int64_t groups, - double output_scale, - int64_t output_zero_point, - c10::optional output_dtype, - c10::string_view attr, - torch::List> scalars, - c10::optional algorithm)""" - - def codegen(self, wrapper): - # Parser the inputs and constant - args = [x.codegen_reference() for x in self.inputs] - const_args = [] - const_args.extend(self.codegen_const_args()) - - x = args[0] - packed_weight = args[1] - bias = args[2] if self.has_bias else const_args[0] - w_scale, w_zp = args[-2], args[-1] - ( - stride, - padding, - dilation, - groups, - x_scale, - x_zp, - o_inv_scale, - o_zp, - output_dtype, - unary_attr, - unary_scalars, - unary_algorithm, - ) = const_args[-12:] - - codegen_args = ( - x, - x_scale, - x_zp, - packed_weight, - w_scale, - w_zp, - bias, - stride, - padding, - dilation, - groups, - o_inv_scale, - o_zp, - output_dtype, - unary_attr, - unary_scalars, - unary_algorithm, - ) - wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed( - self.get_name(), - self.python_kernel_name, - self.cpp_kernel_name, - codegen_args, - self.cpp_op_schema, - self.cpp_kernel_key, - ) - if isinstance(self.layout, Layout): - self.codegen_size_asserts(wrapper) - - @classmethod - def create( - cls, - x: "TensorBox", - x_scale: float, - x_zp: int, - weight: "TensorBox", # packed_weight - w_scale: "TensorBox", - w_zp: "TensorBox", - bias: "TensorBox", - stride_: List[int], - padding_: List[int], - dilation_: List[int], - groups: int, - o_inv_scale: float, - output_zero_point: int, - output_dtype, - unary_attr, - unary_scalars, - unary_algorithm, - ): - transposed = False - output_padding = None - (inputs, constant_args, kernel_layout, _) = _prepare_convolution_fusion_create( - cls, - x, - weight, - bias, - padding_, - stride_, - dilation_, - groups, - transposed, - output_padding, - ) - # swap padding and stride to align with functional conv arg order - if bias is None: - constant_args[1], constant_args[2] = constant_args[2], constant_args[1] - else: - constant_args[0], constant_args[1] = constant_args[1], constant_args[0] - - w_scale.realize() - w_zp.realize() - inputs = inputs + [w_scale, w_zp] - constant_args = constant_args + [ - x_scale, - x_zp, - o_inv_scale, - output_zero_point, - output_dtype, - unary_attr, - may_convert_to_optional(unary_scalars), - unary_algorithm, - ] - - if output_dtype is not None: - assert output_dtype in [torch.float32, torch.bfloat16] - # in _prepare_convolution_fusion_create, we use x.dtype (uint8) to create kernel_layout - # if we set output_dtype is not None, the output buf should be output_dtype instead of uint8. - kernel_layout.dtype = output_dtype - - return QConvPointWisePT2E( - layout=kernel_layout, - inputs=inputs, - constant_args=constant_args, - ) - - -class QConvPointWiseBinaryPT2E(ExternKernelAlloc): - def __init__( - self, - layout, - inputs, - constant_args=(), - ): - """ - Needs input/weight/output qparams - if bias is not None - - inputs = [x, w, b, accum, w_scale, w_zp] - - const_args = [stride, padding, dilation, groups, x_scale, x_zp, accum_scale, accum_zp, o_inv_scale, o_zp, - fp32_output, binary_attr, aplha, unary_attr, unary_scalars, unary_algorithm] - else - - inputs = [x, w, accum, w_scale, w_zp] - - const_args = const_args is: [bias, stride, padding, dilation, groups, x_scale, x_zp, accum_scale, - accum_zp, o_inv_scale, o_zp, fp32_output, binary_attr, aplha, unary_attr, unary_scalars, unary_algorithm] - """ - self.has_bias = len(inputs) == 6 - self.idx_for_inplace_sum = 3 if self.has_bias else 2 - super().__init__( - layout, - inputs, - constant_args, - None, - python_kernel_name="torch.ops.onednn.qconv2d_pointwise.binary", - cpp_kernel_name="onednn::qconv2d_pointwise", - ) - self.cpp_kernel_overload_name = "binary" - self.cpp_kernel_key = "qconv2d_pointwise_binary" - self.cpp_op_schema = """ - at::Tensor( - at::Tensor act, - double act_scale, - int64_t act_zero_point, - at::Tensor accum, - double accum_scale, - int64_t accum_zero_point, - at::Tensor weight, - at::Tensor weight_scales, - at::Tensor weight_zero_points, - c10::optional bias, - torch::List stride, - torch::List padding, - torch::List dilation, - int64_t groups, - double output_scale, - int64_t output_zero_point, - c10::optional output_dtype, - c10::string_view binary_attr, - c10::optional alpha, - c10::optional attr, - torch::List> scalars, - c10::optional algorithm)""" - - def codegen(self, wrapper): - # Parser the inputs and constant - args = [x.codegen_reference() for x in self.inputs] - const_args = [] - const_args.extend(self.codegen_const_args()) - - x = args[0] - packed_weight = args[1] - bias = args[2] if self.has_bias else const_args[0] - accum, w_scale, w_zp = args[-3], args[-2], args[-1] - ( - stride, - padding, - dilation, - groups, - x_scale, - x_zp, - accum_scale, - accum_zp, - o_inv_scale, - o_zp, - output_dtype, - binary_attr, - alpha, - unary_attr, - unary_scalars, - unary_algorithm, - ) = const_args[-16:] - conv_args = ( - x, - x_scale, - x_zp, - accum, - accum_scale, - accum_zp, - packed_weight, - w_scale, - w_zp, - bias, - stride, - padding, - dilation, - groups, - o_inv_scale, - o_zp, - output_dtype, - binary_attr, - alpha, - unary_attr, - unary_scalars, - unary_algorithm, - ) - wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed( - self.get_name(), - self.python_kernel_name, - self.cpp_kernel_name, - conv_args, - self.cpp_op_schema, - self.cpp_kernel_key, - self.cpp_kernel_overload_name, - ) - if isinstance(self.layout, Layout): - self.codegen_size_asserts(wrapper) - - def get_mutation_names(self): - return [self.inputs[self.idx_for_inplace_sum].get_name()] - - def get_unbacked_symbol_defs(self) -> Set[sympy.Symbol]: - return set() - - @classmethod - def create( - cls, - x: "TensorBox", - x_scale, - x_zp, - accum: "TensorBox", - accum_scale, - accum_zp, - weight: "TensorBox", # packed_weight - w_scale, - w_zp, - bias: "TensorBox", - stride_: List[int], - padding_: List[int], - dilation_: List[int], - groups: int, - o_inv_scale: "TensorBox", - output_zero_point: "TensorBox", - output_dtype, - binary_attr, - alpha, - unary_attr, - unary_scalars, - unary_algorithm, - ): - transposed = False - output_padding = None - ( - inputs, - constant_args, - kernel_layout, - req_stride_order, - ) = _prepare_convolution_fusion_create( - cls, - x, - weight, - bias, - padding_, - stride_, - dilation_, - groups, - transposed, - output_padding, - ) - - accum = cls.require_stride_order(accum, req_stride_order) - inputs.append(accum) - - # swap padding and stride to align with functional conv arg order - if bias is None: - constant_args[1], constant_args[2] = constant_args[2], constant_args[1] - else: - constant_args[0], constant_args[1] = constant_args[1], constant_args[0] - - w_scale.realize() - w_zp.realize() - inputs = inputs + [w_scale, w_zp] - constant_args = constant_args + [ - x_scale, - x_zp, - accum_scale, - accum_zp, - o_inv_scale, - output_zero_point, - output_dtype, - binary_attr, - alpha, - unary_attr, - may_convert_to_optional(unary_scalars), - unary_algorithm, - ] - - assert ( - binary_attr == "sum" - ), "For now, only post op sum is supported in QConvPointWiseBinaryPT2E." - - packed = QConvPointWiseBinaryPT2E( - layout=NoneLayout(accum.get_device()), - inputs=inputs, - constant_args=constant_args, - ) - mark_node_as_mutating(packed, accum) - - # Return accum since it has been inplace changed. - return packed.inputs[packed.idx_for_inplace_sum] - - -class QLinearPointwisePT2E(ExternKernelAlloc): - def __init__( - self, - layout, - inputs, - constant_args=(), - has_bias=True, - x_scale_zp_are_tensors=False, - ): - """ - if bias is not None - - inputs = [x, w, b, weight_scale, weight_zp] - - const_args is: [x_scale, x_zp, o_inv_scale, o_zp, - fp32_output, unary_attr, unary_scalars, unary_algorithm] - else - - inputs = [x, w, weight_scale, weight_zp] - - const_args is: [bias, x_scale, x_zp, o_inv_scale, o_zp, - fp32_output, unary_attr, unary_scalars, unary_algorithm] - """ - self.has_bias = has_bias - self.x_scale_zp_are_tensors = x_scale_zp_are_tensors - super().__init__( - layout, - inputs, - constant_args, - None, - python_kernel_name=( - "torch.ops.onednn.qlinear_pointwise.tensor" - if x_scale_zp_are_tensors - else "torch.ops.onednn.qlinear_pointwise.default" - ), - cpp_kernel_name="onednn::qlinear_pointwise", - ) - self.cpp_kernel_overload_name = "tensor" if x_scale_zp_are_tensors else "" - self.cpp_kernel_key = "qlinear_pointwise" - x_scale_type_str, x_zp_type_str = ( - ("at::Tensor", "at::Tensor") - if x_scale_zp_are_tensors - else ("double", "int64_t") - ) - self.cpp_op_schema = f""" - at::Tensor( - at::Tensor act, - {x_scale_type_str} act_scale, - {x_zp_type_str} act_zero_point, - at::Tensor weight, - at::Tensor weight_scales, - at::Tensor weight_zero_points, - c10::optional bias, - double output_scale, - int64_t output_zero_point, - c10::optional output_dtype, - c10::string_view post_op_name, - torch::List> post_op_args, - c10::string_view post_op_algorithm)""" - - def codegen(self, wrapper): - # Parser the inputs and constant - args = [x.codegen_reference() for x in self.inputs] - const_args = [] - const_args.extend(self.codegen_const_args()) - - x = args[0] - packed_weight = args[1] - bias = args[2] if self.has_bias else const_args[0] - w_scale, w_zp = args[-2], args[-1] - if self.x_scale_zp_are_tensors: - assert len(args) >= 4 - x_scale, x_zp = args[-4], args[-3] - ( - o_inv_scale, - o_zp, - output_dtype, - unary_attr, - unary_scalars, - unary_algorithm, - ) = const_args[-6:] - else: - assert len(const_args) >= 8 - ( - x_scale, - x_zp, - o_inv_scale, - o_zp, - output_dtype, - unary_attr, - unary_scalars, - unary_algorithm, - ) = const_args[-8:] - - codegen_args = ( - x, - x_scale, - x_zp, - packed_weight, - w_scale, - w_zp, - bias, - o_inv_scale, - o_zp, - output_dtype, - unary_attr, - unary_scalars, - unary_algorithm, - ) - wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed( - self.get_name(), - self.python_kernel_name, - self.cpp_kernel_name, - codegen_args, - self.cpp_op_schema, - self.cpp_kernel_key, - self.cpp_kernel_overload_name, - ) - if isinstance(self.layout, Layout): - self.codegen_size_asserts(wrapper) - - @classmethod - def create( - cls, - x: "TensorBox", - x_scale: float, - x_zp: int, - weight: "TensorBox", # packed_weight - w_scale: "TensorBox", - w_zp: "TensorBox", - bias: "TensorBox", - o_inv_scale: float, - output_zero_point: int, - output_dtype, - unary_attr, - unary_scalars, - unary_algorithm, - ): - (inputs, constant_args, kernel_layout, _) = _prepare_linear_fusion_create( - cls, - x, - weight, - bias, - ) - - if isinstance(x_scale, TensorBox) and isinstance(x_zp, TensorBox): - x_scale.realize() - x_zp.realize() - inputs = inputs + [x_scale, x_zp] - x_scale_zp_are_tensors = True - else: - assert isinstance(x_scale, float) and isinstance(x_zp, int) - constant_args = constant_args + [x_scale, x_zp] - x_scale_zp_are_tensors = False - w_scale.realize() - w_zp.realize() - inputs = inputs + [w_scale, w_zp] - constant_args = constant_args + [ - o_inv_scale, - output_zero_point, - output_dtype, - unary_attr, - may_convert_to_optional(unary_scalars), - unary_algorithm, - ] - - if output_dtype is not None: - assert output_dtype in [torch.float32, torch.bfloat16] - # in _prepare_linear_fusion_create, we use x.dtype (uint8) to create kernel_layout - # if we set fp32_output, the output buf should be dtype float32 instead of uint8. - kernel_layout.dtype = output_dtype - - return QLinearPointwisePT2E( - layout=kernel_layout, - inputs=inputs, - constant_args=constant_args, - has_bias=(bias is not None), - x_scale_zp_are_tensors=x_scale_zp_are_tensors, - ) - - -class QLinearPointwiseBinaryPT2E(ExternKernelAlloc): - def __init__( - self, - layout, - inputs, - constant_args=(), - has_bias=True, - x_scale_zp_are_tensors=False, - ): - """ - if bias is not None - - inputs = [x, w, b, weight_scale, weight_zp, x2] - - const_args is: [x_scale, x_zp, o_inv_scale, o_zp, - fp32_output, binary_attr, aplha, unary_attr, unary_scalars, unary_algorithm] - else - - inputs = [x, w, weight_scale, weight_zp, x2] - - const_args is: [bias, x_scale, x_zp, o_inv_scale, o_zp, - fp32_output, binary_attr, aplha, unary_attr, unary_scalars, unary_algorithm] - """ - self.has_bias = has_bias - self.x_scale_zp_are_tensors = x_scale_zp_are_tensors - super().__init__( - layout, - inputs, - constant_args, - None, - python_kernel_name=( - "torch.ops.onednn.qlinear_pointwise.binary_tensor" - if x_scale_zp_are_tensors - else "torch.ops.onednn.qlinear_pointwise.binary" - ), - cpp_kernel_name="onednn::qlinear_pointwise", - ) - self.cpp_kernel_overload_name = ( - "binary_tensor" if x_scale_zp_are_tensors else "binary" - ) - self.cpp_kernel_key = "qlinear_pointwise_binary" - x_scale_type_str, x_zp_type_str = ( - ("at::Tensor", "at::Tensor") - if x_scale_zp_are_tensors - else ("double", "int64_t") - ) - self.cpp_op_schema = f""" - at::Tensor( - at::Tensor act, - {x_scale_type_str} act_scale, - {x_zp_type_str} act_zero_point, - at::Tensor weight, - at::Tensor weight_scales, - at::Tensor weight_zero_points, - c10::optional bias, - double inv_output_scale, - int64_t output_zero_point, - c10::optional output_dtype, - c10::optional other, - double other_scale, - int64_t other_zero_point, - c10::string_view binary_post_op, - double binary_alpha, - c10::string_view unary_post_op, - torch::List> unary_post_op_args, - c10::string_view unary_post_op_algorithm)""" - - def codegen(self, wrapper): - # Parser the inputs and constant - args = [x.codegen_reference() for x in self.inputs] - const_args = [] - const_args.extend(self.codegen_const_args()) - - x = args[0] - packed_weight = args[1] - bias = args[2] if self.has_bias else const_args[0] - w_scale, w_zp, other = args[-3], args[-2], args[-1] - if self.x_scale_zp_are_tensors: - assert len(args) >= 5 - x_scale, x_zp = args[-5], args[-4] - ( - o_inv_scale, - o_zp, - output_dtype, - other_scale, - other_zp, - binary_attr, - alpha, - unary_attr, - unary_scalars, - unary_algorithm, - ) = const_args[-10:] - else: - assert len(const_args) >= 8 - ( - x_scale, - x_zp, - o_inv_scale, - o_zp, - output_dtype, - other_scale, - other_zp, - binary_attr, - alpha, - unary_attr, - unary_scalars, - unary_algorithm, - ) = const_args[-12:] - - codegen_args = ( - x, - x_scale, - x_zp, - packed_weight, - w_scale, - w_zp, - bias, - o_inv_scale, - o_zp, - output_dtype, - other, - other_scale, - other_zp, - binary_attr, - alpha, - unary_attr, - unary_scalars, - unary_algorithm, - ) - wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed( - self.get_name(), - self.python_kernel_name, - self.cpp_kernel_name, - codegen_args, - self.cpp_op_schema, - self.cpp_kernel_key, - self.cpp_kernel_overload_name, - ) - if isinstance(self.layout, Layout): - self.codegen_size_asserts(wrapper) - - @classmethod - def create( - cls, - x: "TensorBox", - x_scale: float, - x_zp: int, - weight: "TensorBox", # packed_weight - w_scale: "TensorBox", - w_zp: "TensorBox", - bias: "TensorBox", - o_inv_scale: float, - output_zero_point: int, - output_dtype, - other: "TensorBox", - other_scale, - other_zp, - binary_attr, - alpha, - unary_attr, - unary_scalars, - unary_algorithm, - ): - ( - inputs, - constant_args, - kernel_layout, - req_stride_order, - ) = _prepare_linear_fusion_create( - cls, - x, - weight, - bias, - ) - - if isinstance(x_scale, TensorBox) and isinstance(x_zp, TensorBox): - x_scale.realize() - x_zp.realize() - inputs = inputs + [x_scale, x_zp] - x_scale_zp_are_tensors = True - else: - assert isinstance(x_scale, float) and isinstance(x_zp, int) - constant_args = constant_args + [x_scale, x_zp] - x_scale_zp_are_tensors = False - w_scale.realize() - w_zp.realize() - inputs = inputs + [w_scale, w_zp] - if binary_attr == "sum": - other = cls.require_stride_order(other, req_stride_order) - inputs.append(other) - constant_args = constant_args + [ - o_inv_scale, - output_zero_point, - output_dtype, - other_scale, - other_zp, - binary_attr, - alpha, - unary_attr, - may_convert_to_optional(unary_scalars), - unary_algorithm, - ] - - if binary_attr == "sum": - packed = QLinearPointwiseBinaryPT2E( - layout=NoneLayout(other.get_device()), - inputs=inputs, - constant_args=constant_args, - has_bias=(bias is not None), - x_scale_zp_are_tensors=x_scale_zp_are_tensors, - ) - mark_node_as_mutating(packed, other) - # Return other since it has been inplace changed. - return packed.inputs[-1] - - if output_dtype is not None: - assert output_dtype in [torch.float32, torch.bfloat16] - # in _prepare_linear_fusion_create, we use x.dtype (uint8) to create kernel_layout - # if we set fp32_output, the output buf should be dtype float32 instead of uint8. - kernel_layout.dtype = output_dtype - - return QLinearPointwiseBinaryPT2E( - layout=kernel_layout, - inputs=inputs, - constant_args=constant_args, - has_bias=(bias is not None), - x_scale_zp_are_tensors=x_scale_zp_are_tensors, - ) - - @dataclasses.dataclass class MutableBox(IRNode): """ diff --git a/torch/_inductor/mkldnn_ir.py b/torch/_inductor/mkldnn_ir.py new file mode 100644 index 00000000000000..36be03772e8995 --- /dev/null +++ b/torch/_inductor/mkldnn_ir.py @@ -0,0 +1,1659 @@ +# mypy: allow-untyped-defs +from typing import Any, List, Optional, Set + +import sympy + +import torch + +from torch._prims_common import make_channels_last_strides_for + +from .ir import ( + ExternKernelAlloc, + FixedLayout, + FlexibleLayout, + ir_node_to_tensor, + IRNode, + is_contiguous_storage_and_layout, + Layout, + mark_node_as_mutating, + may_convert_to_optional, + MultiOutput, + MultiOutputLayout, + NoneLayout, + TensorBox, +) + +from .utils import convert_shape_to_inductor, pad_listlike + +from .virtualized import V + + +def _prepare_convolution_fusion_create( + cls, + x: "TensorBox", + weight: "TensorBox", + bias: "TensorBox", + padding: List[int], + stride: List[int], + dilation: List[int], + groups: int, + transposed: bool = False, + output_padding: Optional[List[int]] = None, +): + """ + This function is a helper function to prepare inputs, layout and constant args + for convolution post-op fusion's create function, including deciding the output + layout (channels first or channels last), realizing inputs and make them etc. The + function only supports the CPU device since conv post-op fusion kernel is only + supported on CPU right now. + """ + + # Port from aten/src/ATen/native/ConvUtils.h: _conv_input_size + def _conv_input_size( + output_size, weight_size, padding, output_padding, stride, dilation, groups + ): + assert len(output_size) == len(weight_size), "Expect input dim == weight dim" + dim = len(output_size) + assert dim > 2, "Expect input dim > 2" + + BATCH_DIM = 0 + WEIGHT_INPUT_CHANNELS_DIM = 1 + input_size = [] + input_size.append(output_size[BATCH_DIM]) + input_size.append(weight_size[WEIGHT_INPUT_CHANNELS_DIM] * groups) + for d in range(2, dim): + kernel = (weight_size[d] - 1) * dilation[d - 2] + 1 + input_size_d = ( + (output_size[d] - 1) * stride[d - 2] + - (padding[d - 2] * 2) + + kernel + + output_padding[d - 2] + ) + input_size.append(input_size_d) + return list(map(int, input_size)) + + # The size of prepacked_weight is the prepacked weight size of deconv: + # Groups > 1: [g*o, i/g, ...] + # Groups == 1: [o, i, ...] + # Returns original weight size in [i, o, ...] + def _original_deconv_weight_size( + prepacked_weight, + groups, + ): + prepacked_weight_size = prepacked_weight.size() + dim = len(prepacked_weight_size) + assert dim > 2, "Expect weight dim > 2" + if groups > 1: + weight_size = [] + weight_size.append(prepacked_weight_size[1] * groups) + weight_size.append(prepacked_weight_size[0] / groups) + for d in range(2, dim): + weight_size.append(prepacked_weight_size[d]) + else: + weight_size = prepacked_weight.transpose(0, 1).size() + return weight_size + + x.realize() + weight.realize() + if bias is not None: + bias.realize() + with V.graph.fake_mode: + # TODO cleaned up the fake_tensor trace as Linear implementation + x_fake = ir_node_to_tensor(x, guard_shape=True) + weight_fake = ir_node_to_tensor(weight, guard_shape=True) + dims = len(x_fake.size()) - 2 + assert 0 < len(padding) <= dims + assert 0 < len(dilation) <= dims + assert 0 < len(stride) <= dims + padding = pad_listlike(padding, dims) + dilation = pad_listlike(dilation, dims) + stride = pad_listlike(stride, dims) + if output_padding is None: + output_padding = pad_listlike([0], dims) + else: + assert 0 < len(output_padding) <= dims + output_padding = pad_listlike(output_padding, dims) + assert isinstance(groups, int) + if transposed: + # When transposed, the size of the prepacked oneDNN weight is different + # from the PyTorch weight. We're not able to run aten conv with such + # size. We infer the output size from the input params here: + weight_size = _original_deconv_weight_size(weight_fake, groups) + input_size = x_fake.size() + output_size = _conv_input_size( + input_size, + weight_size, + padding, + output_padding, + stride, + dilation, + groups, + ) + else: + bias_fake = ( + ir_node_to_tensor(bias, guard_shape=True) if bias is not None else bias + ) + output = torch.ops.aten.convolution( + x_fake, + weight_fake, + bias_fake, + stride, + padding, + dilation, + transposed, + output_padding, + groups, + ) + output_size = output.size() + + req_stride_order = [0] + list(reversed(range(1, len(stride) + 1))) + req_stride_order = [len(req_stride_order)] + req_stride_order + + x = cls.require_stride_order(x, req_stride_order) + + # We won't do weight prepack for Conv if dynamic_shapes. + # In static shape cases, since weight is prepacked, we'll always force output to be channels last in the Conv kernel. + # In dynamic shape cases, for input with channels = 1, like tensor of size (s0, 1, 28, 28) and stride (784, 784, 28, 1), + # x = cls.require_stride_order(x, req_stride_order) where req_stride_order is in the channels last order + # won't change the stride of this tensor since stride for dimensions of size 1 is ignored. While in Conv kernel, + # this tensor is considered as channels first and the output will be in contiguous format. + # To align the behavior of the Conv kernel, we set the output_stride in such case to be contiguous instead of channels last. + dynamic_shapes = not all(isinstance(i, int) for i in (output_size)) + if dynamic_shapes and is_contiguous_storage_and_layout(x): + output_stride = FlexibleLayout.contiguous_strides(output_size) + else: + output_stride = make_channels_last_strides_for(output_size) + + assert x.get_device().type == "cpu" and weight.get_device().type == "cpu" + inputs = [x, weight] + + kernel_layout = FixedLayout( + x.get_device(), + x.get_dtype(), + convert_shape_to_inductor(output_size), + convert_shape_to_inductor(output_stride), + ) + constant_args = [padding, stride, dilation, groups] + if transposed: + constant_args.insert(1, output_padding) + + if bias is not None: + inputs.append(bias) + else: + constant_args.insert(0, bias) + return inputs, constant_args, kernel_layout, req_stride_order + + +def _prepare_linear_fusion_create( + cls, + x: "TensorBox", + weight: "TensorBox", + bias: "TensorBox", +): + """ + This function is a helper function to prepare inputs, layout and constant args + for linear post-op fusion's create function. The function only supports the CPU device + since linear post-op fusion kernel is only supported on CPU right now. + """ + x.realize() + weight.realize() + if bias is not None: + bias.realize() + + *m, _ = x.get_size() + # The weight has been transposed during the qlinear weight prepack process. + # https://github.com/pytorch/pytorch/blob/4979f9c0d72490970e2019bb1d2284f83d93f76b/ + # aten/src/ATen/native/quantized/cpu/qlinear_prepack.cpp#L291 + _, oc = weight.get_size() + output_size = list(m) + [oc] + req_stride_order = list(reversed(range(len(x.get_size())))) + + x = cls.require_stride_order(x, req_stride_order) + assert x.get_device().type == "cpu" and weight.get_device().type == "cpu" + inputs = [x, weight] + + output_stride = FlexibleLayout.contiguous_strides(output_size) + kernel_layout = FixedLayout( + x.get_device(), + x.get_dtype(), + output_size, + output_stride, + ) + constant_args: List[Any] = [] + + if bias is not None: + inputs.append(bias) + else: + constant_args.insert(0, bias) + return inputs, constant_args, kernel_layout, req_stride_order + + +class ConvolutionUnary(ExternKernelAlloc): + def __init__( + self, + layout, + inputs, + constant_args=(), + ): + super().__init__( + layout, + inputs, + constant_args, + None, + python_kernel_name="torch.ops.mkldnn._convolution_pointwise", + cpp_kernel_name="mkldnn::_convolution_pointwise", + ) + self.cpp_kernel_key = "convolution_pointwise" + self.cpp_op_schema = """ + at::Tensor( + const at::Tensor& input_t, + const at::Tensor& weight_t, + const c10::optional& bias_opt, + at::IntArrayRef padding, + at::IntArrayRef stride, + at::IntArrayRef dilation, + int64_t groups, + c10::string_view attr, + torch::List> scalars, + c10::optional algorithm)""" + + def codegen(self, wrapper): + wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed( + self.get_name(), + self.python_kernel_name, + self.cpp_kernel_name, + self.codegen_args(), + self.cpp_op_schema, + self.cpp_kernel_key, + ) + if isinstance(self.layout, Layout): + self.codegen_size_asserts(wrapper) + + @classmethod + def create( + cls, + x: "TensorBox", + weight: "TensorBox", + bias: "TensorBox", + padding_: List[int], + stride_: List[int], + dilation_: List[int], + groups: int, + attr, + scalars: Optional[List[Any]], + algorithm, + ): + (inputs, constant_args, kernel_layout, _) = _prepare_convolution_fusion_create( + cls, x, weight, bias, padding_, stride_, dilation_, groups + ) + constant_args = constant_args + [ + attr, + may_convert_to_optional(scalars), + algorithm, + ] + return ConvolutionUnary( + layout=kernel_layout, + inputs=inputs, + constant_args=constant_args, + ) + + +class ConvolutionBinary(ExternKernelAlloc): + def __init__( + self, + layout, + inputs, + constant_args=(), + cpp_constant_args=(), + ): + super().__init__( + layout, + inputs, + constant_args, + None, + python_kernel_name="torch.ops.mkldnn._convolution_pointwise.binary", + cpp_kernel_name="mkldnn::_convolution_pointwise", + ) + self.cpp_kernel_overload_name = "binary" + self.cpp_kernel_key = "convolution_pointwise_binary" + self.cpp_op_schema = """ + at::Tensor( + const at::Tensor& input_t, + const at::Tensor& other_t, + const at::Tensor& weight_t, + const c10::optional& bias_opt, + at::IntArrayRef padding, + at::IntArrayRef stride, + at::IntArrayRef dilation, + int64_t groups, + c10::string_view binary_attr, + c10::optional alpha, + c10::optional unary_attr, + torch::List> unary_scalars, + c10::optional unary_algorithm)""" + self.cpp_constant_args = cpp_constant_args + + def codegen(self, wrapper): + wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed( + self.get_name(), + self.python_kernel_name, + self.cpp_kernel_name, + self.codegen_args(), + self.cpp_op_schema, + self.cpp_kernel_key, + self.cpp_kernel_overload_name, + ) + if isinstance(self.layout, Layout): + self.codegen_size_asserts(wrapper) + + @classmethod + def create( + cls, + x: "TensorBox", + other: "TensorBox", + weight: "TensorBox", + bias: "TensorBox", + padding_: List[int], + stride_: List[int], + dilation_: List[int], + groups: int, + binary_attr: str, + binary_alpha: Optional[float], + unary_attr: Optional[str], + unary_scalars: Optional[List[Any]], + unary_algorithm: Optional[str], + ): + ( + inputs, + constant_args, + kernel_layout, + req_stride_order, + ) = _prepare_convolution_fusion_create( + cls, x, weight, bias, padding_, stride_, dilation_, groups + ) + other = cls.require_stride_order(other, req_stride_order) + inputs.insert(1, other) + constant_args = constant_args + [ + binary_attr, + binary_alpha, + unary_attr, + may_convert_to_optional(unary_scalars), + unary_algorithm, + ] + return ConvolutionBinary( + layout=kernel_layout, + inputs=inputs, + constant_args=constant_args, + ) + + +class ConvolutionBinaryInplace(ExternKernelAlloc): + def __init__( + self, + kernel_layout, + inputs, + constant_args=(), + ): + # Due to constrain of op.call, other (Tensor&) should be at input[0] + reordered_inputs = [inputs[1], inputs[0]] + inputs[2:] + + super().__init__( + kernel_layout, + reordered_inputs, + constant_args, + None, + python_kernel_name="torch.ops.mkldnn._convolution_pointwise_.binary", + cpp_kernel_name="mkldnn::_convolution_pointwise_", + ) + self.cpp_kernel_overload_name = "binary" + self.cpp_kernel_key = "convolution_pointwise_binary_" + # TODO: op.call: input[0] should be at::Tensor& + self.cpp_op_schema = """ + at::Tensor&( + at::Tensor& other_t, + const at::Tensor& input_t, + const at::Tensor& weight_t, + const c10::optional& bias_opt, + at::IntArrayRef padding, + at::IntArrayRef stride, + at::IntArrayRef dilation, + int64_t groups, + c10::string_view binary_attr, + c10::optional alpha, + c10::optional unary_attr, + torch::List> unary_scalars, + c10::optional unary_algorithm)""" + + def codegen(self, wrapper): + wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed( + self.get_name(), + self.python_kernel_name, + self.cpp_kernel_name, + self.codegen_args(), + self.cpp_op_schema, + self.cpp_kernel_key, + self.cpp_kernel_overload_name, + ) + + def get_mutation_names(self): + return [self.inputs[0].get_name()] + + def get_unbacked_symbol_defs(self) -> Set[sympy.Symbol]: + return set() + + @classmethod + def create( + cls, + x: "TensorBox", + other: "TensorBox", + weight: "TensorBox", + bias: "TensorBox", + padding_: List[int], + stride_: List[int], + dilation_: List[int], + groups: int, + binary_attr: str, + binary_alpha: Optional[float], + unary_attr: Optional[str], + unary_scalars: Optional[List[Any]], + unary_algorithm: Optional[str], + ): + ( + inputs, + constant_args, + _, + req_stride_order, + ) = _prepare_convolution_fusion_create( + cls, x, weight, bias, padding_, stride_, dilation_, groups + ) + other = cls.require_stride_order(other, req_stride_order) + inputs.insert(1, other) + constant_args = constant_args + [ + binary_attr, + binary_alpha, + unary_attr, + may_convert_to_optional(unary_scalars), + unary_algorithm, + ] + packed = ConvolutionBinaryInplace( + kernel_layout=NoneLayout(inputs[1].get_device()), # type: ignore[arg-type] + inputs=inputs, + constant_args=constant_args, + ) + mark_node_as_mutating(packed, inputs[1]) + # This op mutates in place which means that the result is not the + # target but rather the input that is being mutated + # init reorders the inputs, so inputs[1] becomes packed.inputs[0] + return packed.inputs[0] + + +class ConvolutionTransposeUnary(ExternKernelAlloc): + def __init__( + self, + layout, + inputs, + constant_args=(), + ): + super().__init__( + layout, + inputs, + constant_args, + None, + python_kernel_name="torch.ops.mkldnn._convolution_transpose_pointwise", + cpp_kernel_name="mkldnn::_convolution_transpose_pointwise", + ) + self.cpp_kernel_key = "convolution_transpose_pointwise" + self.cpp_op_schema = """ + at::Tensor( + const at::Tensor& input_t, + const at::Tensor& weight_t, + const c10::optional& bias_opt, + at::IntArrayRef padding, + at::IntArrayRef output_padding, + at::IntArrayRef stride, + at::IntArrayRef dilation, + int64_t groups, + c10::string_view attr, + torch::List> scalars, + c10::optional algorithm)""" + + def codegen(self, wrapper): + wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed( + self.get_name(), + self.python_kernel_name, + self.cpp_kernel_name, + self.codegen_args(), + self.cpp_op_schema, + self.cpp_kernel_key, + ) + + @classmethod + def create( + cls, + x: "TensorBox", + weight: "TensorBox", + bias: "TensorBox", + padding_: List[int], + output_padding_: List[int], + stride_: List[int], + dilation_: List[int], + groups_: int, + attr, + scalars: Optional[List[Any]], + algorithm, + ): + transposed = True + ( + inputs, + constant_args, + kernel_layout, + _, + ) = _prepare_convolution_fusion_create( + cls, + x, + weight, + bias, + padding_, + stride_, + dilation_, + groups_, + transposed, + output_padding_, + ) + constant_args = constant_args + [ + attr, + may_convert_to_optional(scalars), + algorithm, + ] + return ConvolutionTransposeUnary( + layout=kernel_layout, + inputs=inputs, + constant_args=constant_args, + ) + + +class QConvPointWisePT2E(ExternKernelAlloc): + def __init__( + self, + layout, + inputs, + constant_args=(), + ): + """ + if bias is not None + - inputs = [x, w, b, weight_scale, weight_zp] + - const_args is: [stride, padding, dilation, groups, x_scale, x_zp, o_scale, o_zp, + fp32_output, unary_attr, unary_scalars, unary_algorithm] + else + - inputs = [x, w, weight_scale, weight_zp] + - const_args is: [bias, stride, padding, dilation, groups, x_scale, x_zp, o_scale, o_zp, + fp32_output, unary_attr, unary_scalars, unary_algorithm] + """ + self.has_bias = len(inputs) == 5 + super().__init__( + layout, + inputs, + constant_args, + None, + python_kernel_name="torch.ops.onednn.qconv2d_pointwise", + cpp_kernel_name="onednn::qconv2d_pointwise", + ) + self.cpp_kernel_key = "qconv2d_pointwise" + self.cpp_op_schema = """ + at::Tensor( + at::Tensor act, + double act_scale, + int64_t act_zero_point, + at::Tensor weight, + at::Tensor weight_scales, + at::Tensor weight_zero_points, + c10::optional bias, + torch::List stride, + torch::List padding, + torch::List dilation, + int64_t groups, + double output_scale, + int64_t output_zero_point, + c10::optional output_dtype, + c10::string_view attr, + torch::List> scalars, + c10::optional algorithm)""" + + def codegen(self, wrapper): + # Parser the inputs and constant + args = [x.codegen_reference() for x in self.inputs] + const_args = [] + const_args.extend(self.codegen_const_args()) + + x = args[0] + packed_weight = args[1] + bias = args[2] if self.has_bias else const_args[0] + w_scale, w_zp = args[-2], args[-1] + ( + stride, + padding, + dilation, + groups, + x_scale, + x_zp, + o_scale, + o_zp, + output_dtype, + unary_attr, + unary_scalars, + unary_algorithm, + ) = const_args[-12:] + + codegen_args = ( + x, + x_scale, + x_zp, + packed_weight, + w_scale, + w_zp, + bias, + stride, + padding, + dilation, + groups, + o_scale, + o_zp, + output_dtype, + unary_attr, + unary_scalars, + unary_algorithm, + ) + wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed( + self.get_name(), + self.python_kernel_name, + self.cpp_kernel_name, + codegen_args, + self.cpp_op_schema, + self.cpp_kernel_key, + ) + if isinstance(self.layout, Layout): + self.codegen_size_asserts(wrapper) + + @classmethod + def create( + cls, + qx: "TensorBox", + x_scale: float, + x_zero_point: int, + qw: "TensorBox", # qw + w_scale: "TensorBox", + w_zero_point: "TensorBox", + bias: "TensorBox", + stride: List[int], + padding: List[int], + dilation: List[int], + groups: int, + output_scale: float, + output_zero_point: int, + output_dtype, + attr, + scalars, + algorithm, + ): + transposed = False + output_padding = None + (inputs, constant_args, kernel_layout, _) = _prepare_convolution_fusion_create( + cls, + qx, + qw, + bias, + padding, + stride, + dilation, + groups, + transposed, + output_padding, + ) + # swap padding and stride to align with functional conv arg order + if bias is None: + constant_args[1], constant_args[2] = constant_args[2], constant_args[1] + else: + constant_args[0], constant_args[1] = constant_args[1], constant_args[0] + + w_scale.realize() + w_zero_point.realize() + inputs = inputs + [w_scale, w_zero_point] + constant_args = constant_args + [ + x_scale, + x_zero_point, + output_scale, + output_zero_point, + output_dtype, + attr, + may_convert_to_optional(scalars), + algorithm, + ] + + if output_dtype is not None: + assert output_dtype in [torch.float32, torch.bfloat16] + # in _prepare_convolution_fusion_create, we use x.dtype (uint8) to create kernel_layout + # if we set output_dtype is not None, the output buf should be output_dtype instead of uint8. + kernel_layout.dtype = output_dtype + + return QConvPointWisePT2E( + layout=kernel_layout, + inputs=inputs, + constant_args=constant_args, + ) + + +class QConvPointWiseBinaryPT2E(ExternKernelAlloc): + def __init__( + self, + layout, + inputs, + constant_args=(), + ): + """ + Needs input/weight/output qparams + if bias is not None + - inputs = [x, w, b, accum, w_scale, w_zp] + - const_args = [stride, padding, dilation, groups, x_scale, x_zp, accum_scale, accum_zp, o_scale, o_zp, + fp32_output, binary_attr, aplha, unary_attr, unary_scalars, unary_algorithm] + else + - inputs = [x, w, accum, w_scale, w_zp] + - const_args = const_args is: [bias, stride, padding, dilation, groups, x_scale, x_zp, accum_scale, + accum_zp, o_scale, o_zp, fp32_output, binary_attr, aplha, unary_attr, unary_scalars, unary_algorithm] + """ + self.has_bias = len(inputs) == 6 + self.idx_for_inplace_sum = 3 if self.has_bias else 2 + super().__init__( + layout, + inputs, + constant_args, + None, + python_kernel_name="torch.ops.onednn.qconv2d_pointwise.binary", + cpp_kernel_name="onednn::qconv2d_pointwise", + ) + self.cpp_kernel_overload_name = "binary" + self.cpp_kernel_key = "qconv2d_pointwise_binary" + self.cpp_op_schema = """ + at::Tensor( + at::Tensor act, + double act_scale, + int64_t act_zero_point, + at::Tensor accum, + double accum_scale, + int64_t accum_zero_point, + at::Tensor weight, + at::Tensor weight_scales, + at::Tensor weight_zero_points, + c10::optional bias, + torch::List stride, + torch::List padding, + torch::List dilation, + int64_t groups, + double output_scale, + int64_t output_zero_point, + c10::optional output_dtype, + c10::string_view binary_attr, + c10::optional alpha, + c10::optional attr, + torch::List> scalars, + c10::optional algorithm)""" + + def codegen(self, wrapper): + # Parser the inputs and constant + args = [x.codegen_reference() for x in self.inputs] + const_args = [] + const_args.extend(self.codegen_const_args()) + + x = args[0] + packed_weight = args[1] + bias = args[2] if self.has_bias else const_args[0] + accum, w_scale, w_zp = args[-3], args[-2], args[-1] + ( + stride, + padding, + dilation, + groups, + x_scale, + x_zp, + accum_scale, + accum_zp, + o_scale, + o_zp, + output_dtype, + binary_attr, + alpha, + unary_attr, + unary_scalars, + unary_algorithm, + ) = const_args[-16:] + conv_args = ( + x, + x_scale, + x_zp, + accum, + accum_scale, + accum_zp, + packed_weight, + w_scale, + w_zp, + bias, + stride, + padding, + dilation, + groups, + o_scale, + o_zp, + output_dtype, + binary_attr, + alpha, + unary_attr, + unary_scalars, + unary_algorithm, + ) + wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed( + self.get_name(), + self.python_kernel_name, + self.cpp_kernel_name, + conv_args, + self.cpp_op_schema, + self.cpp_kernel_key, + self.cpp_kernel_overload_name, + ) + if isinstance(self.layout, Layout): + self.codegen_size_asserts(wrapper) + + def get_mutation_names(self): + return [self.inputs[self.idx_for_inplace_sum].get_name()] + + def get_unbacked_symbol_defs(self) -> Set[sympy.Symbol]: + return set() + + @classmethod + def create( + cls, + qx: "TensorBox", + x_scale, + x_zero_point, + qaccum: "TensorBox", + accum_scale, + accum_zero_point, + qw: "TensorBox", # packed_weight + w_scale, + w_zero_point, + bias: "TensorBox", + stride: List[int], + padding: List[int], + dilation: List[int], + groups: int, + output_scale: "TensorBox", + output_zero_point: "TensorBox", + output_dtype, + binary_attr, + alpha, + unary_attr, + unary_scalars, + unary_algorithm, + ): + transposed = False + output_padding = None + ( + inputs, + constant_args, + kernel_layout, + req_stride_order, + ) = _prepare_convolution_fusion_create( + cls, + qx, + qw, + bias, + padding, + stride, + dilation, + groups, + transposed, + output_padding, + ) + + qaccum = cls.require_stride_order(qaccum, req_stride_order) + inputs.append(qaccum) + + # swap padding and stride to align with functional conv arg order + if bias is None: + constant_args[1], constant_args[2] = constant_args[2], constant_args[1] + else: + constant_args[0], constant_args[1] = constant_args[1], constant_args[0] + + w_scale.realize() + w_zero_point.realize() + inputs = inputs + [w_scale, w_zero_point] + constant_args = constant_args + [ + x_scale, + x_zero_point, + accum_scale, + accum_zero_point, + output_scale, + output_zero_point, + output_dtype, + binary_attr, + alpha, + unary_attr, + may_convert_to_optional(unary_scalars), + unary_algorithm, + ] + + assert ( + binary_attr == "sum" + ), "For now, only post op sum is supported in QConvPointWiseBinaryPT2E." + + packed = QConvPointWiseBinaryPT2E( + layout=NoneLayout(qaccum.get_device()), + inputs=inputs, + constant_args=constant_args, + ) + mark_node_as_mutating(packed, qaccum) + + # Return accum since it has been inplace changed. + return packed.inputs[packed.idx_for_inplace_sum] + + +class MKLPackedLinear(ExternKernelAlloc): + def __init__( + self, + layout, + inputs, + constant_args=(), + ): + super().__init__( + layout, + inputs, + constant_args, + None, + python_kernel_name="torch.ops.mkl._mkl_linear", + cpp_kernel_name="mkl::_mkl_linear", + ) + self.cpp_kernel_key = "mkl_linear" + self.cpp_op_schema = """ + at::Tensor( + const at::Tensor& self, + const at::Tensor& mkl_weight_t, + const at::Tensor& origin_weight_t, + const c10::optional& bias_opt, + const int64_t prepack_batch_size)""" + + def codegen(self, wrapper): + wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed( + self.get_name(), + self.python_kernel_name, + self.cpp_kernel_name, + self.codegen_args(), + self.cpp_op_schema, + self.cpp_kernel_key, + ) + + @classmethod + def create(cls, x, packed_w, orig_w, B, batch_size): + x = cls.require_stride1(cls.realize_input(x)) + orig_w = cls.require_stride1(cls.realize_input(orig_w)) + *m, _ = x.get_size() + oc, _ = orig_w.get_size() + output_size = list(m) + [oc] + output_stride = FlexibleLayout.contiguous_strides(output_size) + inputs = [x, packed_w, orig_w] + constant_args = [batch_size] + if B is not None: + inputs += [B] + else: + constant_args.insert(0, None) + + return MKLPackedLinear( + layout=FixedLayout( + x.get_device(), x.get_dtype(), output_size, output_stride + ), + inputs=inputs, + constant_args=constant_args, + ) + + +class LinearUnary(ExternKernelAlloc): + def __init__( + self, + layout, + inputs, + constant_args=(), + ): + super().__init__( + layout, + inputs, + constant_args, + None, + python_kernel_name="torch.ops.mkldnn._linear_pointwise", + cpp_kernel_name="mkldnn::_linear_pointwise", + ) + self.cpp_kernel_key = "linear_pointwise" + self.cpp_op_schema = """ + at::Tensor( + const at::Tensor& input_t, + const at::Tensor& weight_t, + const c10::optional& bias_opt, + c10::string_view attr, + torch::List> scalars, + c10::optional algorithm)""" + + def codegen(self, wrapper): + wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed( + self.get_name(), + self.python_kernel_name, + self.cpp_kernel_name, + self.codegen_args(), + self.cpp_op_schema, + self.cpp_kernel_key, + ) + + @classmethod + def create(cls, x, w, B, attr, scalars, algorithm): + x = cls.require_contiguous(cls.realize_input(x)) + w = cls.require_contiguous(cls.realize_input(w)) + + *m, ic = x.get_size() + oc, ic = w.get_size() + inputs = [x, w] + constant_args = [attr, scalars if scalars else [-1], algorithm] + if B is not None: + B = cls.require_contiguous(cls.realize_input(B)) + inputs.append(B) + else: + constant_args.insert(0, None) + + return LinearUnary( + layout=FlexibleLayout( + device=x.get_device(), + dtype=x.get_dtype(), + size=list(m) + [oc], + ), + inputs=inputs, + constant_args=constant_args, + ) + + def apply_constraint(self): + pass + + +class LinearBinary(ExternKernelAlloc): + kernel = "torch.ops.mkldnn._linear_pointwise.binary" + + def __init__( + self, + layout, + inputs, + constant_args=(), + ): + super().__init__( + layout, + inputs, + constant_args, + None, + python_kernel_name="torch.ops.mkldnn._linear_pointwise.binary", + cpp_kernel_name="mkldnn::_linear_pointwise", + ) + self.cpp_kernel_overload_name = "binary" + self.cpp_kernel_key = "linear_pointwise_binary" + self.cpp_op_schema = """ + at::Tensor( + const at::Tensor& input_t, + const at::Tensor& other_t, + const at::Tensor& weight_t, + const c10::optional& bias_opt, + c10::string_view attr) + """ + + def codegen(self, wrapper): + wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed( + self.get_name(), + self.python_kernel_name, + self.cpp_kernel_name, + self.codegen_args(), + self.cpp_op_schema, + self.cpp_kernel_key, + self.cpp_kernel_overload_name, + ) + + @classmethod + def create(cls, x, y, w, B, attr): + x = cls.require_contiguous(cls.realize_input(x)) + y = cls.require_contiguous(cls.realize_input(y)) + w = cls.require_contiguous(cls.realize_input(w)) + + *m, ic = x.get_size() + oc, ic = w.get_size() + + inputs = [x, y, w] + constant_args = [attr] + if B is not None: + B = cls.require_contiguous(cls.realize_input(B)) + inputs.append(B) + else: + constant_args.insert(0, B) + + return LinearBinary( + layout=FlexibleLayout( + device=x.get_device(), + dtype=x.get_dtype(), + size=list(m) + [oc], + ), + inputs=inputs, + constant_args=constant_args, + ) + + def apply_constraint(self): + pass + + +class QLinearPointwisePT2E(ExternKernelAlloc): + def __init__( + self, + layout, + inputs, + constant_args=(), + has_bias=True, + x_scale_zp_are_tensors=False, + ): + """ + if bias is not None + - inputs = [x, w, b, weight_scale, weight_zp] + - const_args is: [x_scale, x_zp, o_scale, o_zp, + fp32_output, unary_attr, unary_scalars, unary_algorithm] + else + - inputs = [x, w, weight_scale, weight_zp] + - const_args is: [bias, x_scale, x_zp, o_scale, o_zp, + fp32_output, unary_attr, unary_scalars, unary_algorithm] + """ + self.has_bias = has_bias + self.x_scale_zp_are_tensors = x_scale_zp_are_tensors + super().__init__( + layout, + inputs, + constant_args, + None, + python_kernel_name=( + "torch.ops.onednn.qlinear_pointwise.tensor" + if x_scale_zp_are_tensors + else "torch.ops.onednn.qlinear_pointwise.default" + ), + cpp_kernel_name="onednn::qlinear_pointwise", + ) + self.cpp_kernel_overload_name = "tensor" if x_scale_zp_are_tensors else "" + self.cpp_kernel_key = "qlinear_pointwise" + x_scale_type_str, x_zp_type_str = ( + ("at::Tensor", "at::Tensor") + if x_scale_zp_are_tensors + else ("double", "int64_t") + ) + self.cpp_op_schema = f""" + at::Tensor( + at::Tensor act, + {x_scale_type_str} act_scale, + {x_zp_type_str} act_zero_point, + at::Tensor weight, + at::Tensor weight_scales, + at::Tensor weight_zero_points, + c10::optional bias, + double output_scale, + int64_t output_zero_point, + c10::optional output_dtype, + c10::string_view post_op_name, + torch::List> post_op_args, + c10::string_view post_op_algorithm)""" + + def codegen(self, wrapper): + # Parser the inputs and constant + args = [x.codegen_reference() for x in self.inputs] + const_args = [] + const_args.extend(self.codegen_const_args()) + + x = args[0] + packed_weight = args[1] + bias = args[2] if self.has_bias else const_args[0] + w_scale, w_zp = args[-2], args[-1] + if self.x_scale_zp_are_tensors: + assert len(args) >= 4 + x_scale, x_zp = args[-4], args[-3] + ( + o_scale, + o_zp, + output_dtype, + unary_attr, + unary_scalars, + unary_algorithm, + ) = const_args[-6:] + else: + assert len(const_args) >= 8 + ( + x_scale, + x_zp, + o_scale, + o_zp, + output_dtype, + unary_attr, + unary_scalars, + unary_algorithm, + ) = const_args[-8:] + + codegen_args = ( + x, + x_scale, + x_zp, + packed_weight, + w_scale, + w_zp, + bias, + o_scale, + o_zp, + output_dtype, + unary_attr, + unary_scalars, + unary_algorithm, + ) + wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed( + self.get_name(), + self.python_kernel_name, + self.cpp_kernel_name, + codegen_args, + self.cpp_op_schema, + self.cpp_kernel_key, + self.cpp_kernel_overload_name, + ) + if isinstance(self.layout, Layout): + self.codegen_size_asserts(wrapper) + + @classmethod + def create( + cls, + qx: "TensorBox", + x_scale: float, + x_zero_point: int, + qw: "TensorBox", # packed_weight + w_scale: "TensorBox", + w_zero_point: "TensorBox", + bias: "TensorBox", + output_scale: float, + output_zero_point: int, + output_dtype, + post_op_name, + post_op_args, + post_op_algorithm, + ): + (inputs, constant_args, kernel_layout, _) = _prepare_linear_fusion_create( + cls, + qx, + qw, + bias, + ) + + if isinstance(x_scale, TensorBox) and isinstance(x_zero_point, TensorBox): + x_scale.realize() + x_zero_point.realize() + inputs = inputs + [x_scale, x_zero_point] + x_scale_zp_are_tensors = True + else: + assert isinstance(x_scale, float) and isinstance(x_zero_point, int) + constant_args = constant_args + [x_scale, x_zero_point] + x_scale_zp_are_tensors = False + w_scale.realize() + w_zero_point.realize() + inputs = inputs + [w_scale, w_zero_point] + constant_args = constant_args + [ + output_scale, + output_zero_point, + output_dtype, + post_op_name, + may_convert_to_optional(post_op_args), + post_op_algorithm, + ] + + if output_dtype is not None: + assert output_dtype in [torch.float32, torch.bfloat16] + # in _prepare_linear_fusion_create, we use x.dtype (uint8) to create kernel_layout + # if we set fp32_output, the output buf should be dtype float32 instead of uint8. + kernel_layout.dtype = output_dtype + + return QLinearPointwisePT2E( + layout=kernel_layout, + inputs=inputs, + constant_args=constant_args, + has_bias=(bias is not None), + x_scale_zp_are_tensors=x_scale_zp_are_tensors, + ) + + +class QLinearPointwiseBinaryPT2E(ExternKernelAlloc): + def __init__( + self, + layout, + inputs, + constant_args=(), + has_bias=True, + x_scale_zp_are_tensors=False, + ): + """ + if bias is not None + - inputs = [x, w, b, weight_scale, weight_zp, x2] + - const_args is: [x_scale, x_zp, o_scale, o_zp, + fp32_output, binary_attr, aplha, unary_attr, unary_scalars, unary_algorithm] + else + - inputs = [x, w, weight_scale, weight_zp, x2] + - const_args is: [bias, x_scale, x_zp, o_scale, o_zp, + fp32_output, binary_attr, aplha, unary_attr, unary_scalars, unary_algorithm] + """ + self.has_bias = has_bias + self.x_scale_zp_are_tensors = x_scale_zp_are_tensors + super().__init__( + layout, + inputs, + constant_args, + None, + python_kernel_name=( + "torch.ops.onednn.qlinear_pointwise.binary_tensor" + if x_scale_zp_are_tensors + else "torch.ops.onednn.qlinear_pointwise.binary" + ), + cpp_kernel_name="onednn::qlinear_pointwise", + ) + self.cpp_kernel_overload_name = ( + "binary_tensor" if x_scale_zp_are_tensors else "binary" + ) + self.cpp_kernel_key = "qlinear_pointwise_binary" + x_scale_type_str, x_zp_type_str = ( + ("at::Tensor", "at::Tensor") + if x_scale_zp_are_tensors + else ("double", "int64_t") + ) + self.cpp_op_schema = f""" + at::Tensor( + at::Tensor act, + {x_scale_type_str} act_scale, + {x_zp_type_str} act_zero_point, + at::Tensor weight, + at::Tensor weight_scales, + at::Tensor weight_zero_points, + c10::optional bias, + double inv_output_scale, + int64_t output_zero_point, + c10::optional output_dtype, + c10::optional other, + double other_scale, + int64_t other_zero_point, + c10::string_view binary_post_op, + double binary_alpha, + c10::string_view unary_post_op, + torch::List> unary_post_op_args, + c10::string_view unary_post_op_algorithm)""" + + def codegen(self, wrapper): + # Parser the inputs and constant + args = [x.codegen_reference() for x in self.inputs] + const_args = [] + const_args.extend(self.codegen_const_args()) + + x = args[0] + packed_weight = args[1] + bias = args[2] if self.has_bias else const_args[0] + w_scale, w_zp, other = args[-3], args[-2], args[-1] + if self.x_scale_zp_are_tensors: + assert len(args) >= 5 + x_scale, x_zp = args[-5], args[-4] + ( + o_scale, + o_zp, + output_dtype, + other_scale, + other_zp, + binary_attr, + alpha, + unary_attr, + unary_scalars, + unary_algorithm, + ) = const_args[-10:] + else: + assert len(const_args) >= 8 + ( + x_scale, + x_zp, + o_scale, + o_zp, + output_dtype, + other_scale, + other_zp, + binary_attr, + alpha, + unary_attr, + unary_scalars, + unary_algorithm, + ) = const_args[-12:] + + codegen_args = ( + x, + x_scale, + x_zp, + packed_weight, + w_scale, + w_zp, + bias, + o_scale, + o_zp, + output_dtype, + other, + other_scale, + other_zp, + binary_attr, + alpha, + unary_attr, + unary_scalars, + unary_algorithm, + ) + wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed( + self.get_name(), + self.python_kernel_name, + self.cpp_kernel_name, + codegen_args, + self.cpp_op_schema, + self.cpp_kernel_key, + self.cpp_kernel_overload_name, + ) + if isinstance(self.layout, Layout): + self.codegen_size_asserts(wrapper) + + @classmethod + def create( + cls, + qx: "TensorBox", + x_scale: float, + x_zero_point: int, + qw: "TensorBox", # packed_weight + w_scale: "TensorBox", + w_zero_point: "TensorBox", + bias: "TensorBox", + output_scale: float, + output_zero_point: int, + output_dtype, + other: "TensorBox", + other_scale, + other_zp, + binary_post_op, + binary_alpha, + unary_post_op, + unary_post_op_args, + unary_post_op_algorithm, + ): + ( + inputs, + constant_args, + kernel_layout, + req_stride_order, + ) = _prepare_linear_fusion_create( + cls, + qx, + qw, + bias, + ) + + if isinstance(x_scale, TensorBox) and isinstance(x_zero_point, TensorBox): + x_scale.realize() + x_zero_point.realize() + inputs = inputs + [x_scale, x_zero_point] + x_scale_zp_are_tensors = True + else: + assert isinstance(x_scale, float) and isinstance(x_zero_point, int) + constant_args = constant_args + [x_scale, x_zero_point] + x_scale_zp_are_tensors = False + w_scale.realize() + w_zero_point.realize() + inputs = inputs + [w_scale, w_zero_point] + if binary_post_op == "sum": + other = cls.require_stride_order(other, req_stride_order) + inputs.append(other) + constant_args = constant_args + [ + output_scale, + output_zero_point, + output_dtype, + other_scale, + other_zp, + binary_post_op, + binary_alpha, + unary_post_op, + may_convert_to_optional(unary_post_op_args), + unary_post_op_algorithm, + ] + + if binary_post_op == "sum": + packed = QLinearPointwiseBinaryPT2E( + layout=NoneLayout(other.get_device()), + inputs=inputs, + constant_args=constant_args, + has_bias=(bias is not None), + x_scale_zp_are_tensors=x_scale_zp_are_tensors, + ) + mark_node_as_mutating(packed, other) + # Return other since it has been inplace changed. + return packed.inputs[-1] + + if output_dtype is not None: + assert output_dtype in [torch.float32, torch.bfloat16] + # in _prepare_linear_fusion_create, we use x.dtype (uint8) to create kernel_layout + # if we set fp32_output, the output buf should be dtype float32 instead of uint8. + kernel_layout.dtype = output_dtype + + return QLinearPointwiseBinaryPT2E( + layout=kernel_layout, + inputs=inputs, + constant_args=constant_args, + has_bias=(bias is not None), + x_scale_zp_are_tensors=x_scale_zp_are_tensors, + ) + + +class MkldnnRnnLayer(ExternKernelAlloc): + def __init__( + self, + layout, + inputs, + constant_args=(), + ): + super().__init__( + layout, + inputs, + constant_args, + None, + python_kernel_name="aten.mkldnn_rnn_layer", + cpp_kernel_name="at::mkldnn_rnn_layer", + ) + + @classmethod + def create( + cls, + x: "TensorBox", + w0: "TensorBox", + w1: "TensorBox", + w2: "TensorBox", + w3: "TensorBox", + hx: "TensorBox", + cx: "TensorBox", + reverse: bool, + batch_sizes: List[int], + mode: int, + hidden_size: int, + num_layers: int, + has_biases: bool, + bidirectional: bool, + batch_first: bool, + train: bool, + ): + x = cls.require_stride1(cls.realize_input(x)) + # If batch_first, x has been permuted in lstm before entering the mkldnn_rnn_layer. + # Make sure x is contiguous in batch_first case. + x.freeze_layout() + w0 = cls.require_stride1(cls.realize_input(w0)) + w1 = cls.require_stride1(cls.realize_input(w1)) + w2 = cls.require_stride1(cls.realize_input(w2)) + w3 = cls.require_stride1(cls.realize_input(w3)) + hx = cls.require_stride1(cls.realize_input(hx)) + hx.freeze_layout() + cx = cls.require_stride1(cls.realize_input(cx)) + cx.freeze_layout() + + input_size = x.get_size() + assert len(input_size) == 3, "Expect lstm input to be 3D" + # batch_first is handled in the lstm OP. When entering + # rnn_layer here, we'll always have batch_first = False + seq_length, mini_batch, input_size = input_size + output_shape = [seq_length, mini_batch, hidden_size] + + hy_shape = hx.get_size() + cy_shape = cx.get_size() + + res: List[IRNode] = [] + + inputs = [x, w0, w1, w2, w3, hx, cx] + constant_args = [ + reverse, + batch_sizes, + mode, + hidden_size, + num_layers, + has_biases, + bidirectional, + batch_first, + train, + ] + + packed = MkldnnRnnLayer( + MultiOutputLayout(x.get_device()), + inputs=inputs, + constant_args=constant_args, + ) + + def get_strides_of_lstm_output(output_shape, batch_first): + assert len(output_shape) == 3, "Expect output_shape to be 3D" + return FlexibleLayout.contiguous_strides(output_shape) + + output_sizes = [output_shape, hy_shape, cy_shape] + output_strides = [ + get_strides_of_lstm_output(output_shape, batch_first), + FlexibleLayout.contiguous_strides(hy_shape), + FlexibleLayout.contiguous_strides(cy_shape), + ] + output_ir = [ + MultiOutput( + FixedLayout( + x.get_device(), + x.get_dtype(), + output_size, + output_stride, + ), + packed, + [(tuple, i)], + ) + for i, (output_size, output_stride) in enumerate( + zip(output_sizes, output_strides) + ) + ] + + return output_ir diff --git a/torch/_inductor/mkldnn_lowerings.py b/torch/_inductor/mkldnn_lowerings.py index 809c8f10b711d2..c006af0095e6c4 100644 --- a/torch/_inductor/mkldnn_lowerings.py +++ b/torch/_inductor/mkldnn_lowerings.py @@ -4,7 +4,7 @@ import torch import torch.utils._pytree as pytree from torch._inductor.kernel.mm_common import mm_args -from . import ir +from . import ir, mkldnn_ir from .codegen.cpp_gemm_template import CppPackedGemmTemplate from .ir import TensorBox from .lowering import ( @@ -173,13 +173,13 @@ def register_onednn_fusion_ops(): torch.ops.mkldnn._linear_pointwise, "mkldnn::_linear_pointwise", has_out_variant=False, - kernel_creator=ir.LinearUnary.create, + kernel_creator=mkldnn_ir.LinearUnary.create, ) aten_mkldnn_linear_binary = ExternKernelChoice( torch.ops.mkldnn._linear_pointwise.binary, "mkldnn::_linear_pointwise", has_out_variant=False, - kernel_creator=ir.LinearBinary.create, + kernel_creator=mkldnn_ir.LinearBinary.create, ) cpu_needs_realized_inputs = [ torch.ops.mkldnn._convolution_pointwise, @@ -204,7 +204,7 @@ def convolution_unary( algorithm, ): return TensorBox.create( - ir.ConvolutionUnary.create( + mkldnn_ir.ConvolutionUnary.create( x, weight, bias, @@ -235,7 +235,7 @@ def convolution_binary( unary_algorithm, ): return TensorBox.create( - ir.ConvolutionBinary.create( + mkldnn_ir.ConvolutionBinary.create( x, other, weight, @@ -269,7 +269,7 @@ def convolution_binary_inplace( unary_algorithm, ): return TensorBox.create( - ir.ConvolutionBinaryInplace.create( + mkldnn_ir.ConvolutionBinaryInplace.create( x, other, weight, @@ -429,7 +429,7 @@ def convolution_transpose_unary( algorithm, ): return TensorBox.create( - ir.ConvolutionTransposeUnary.create( + mkldnn_ir.ConvolutionTransposeUnary.create( x, weight, bias, @@ -465,7 +465,7 @@ def mkldnn_rnn_layer( ): return pytree.tree_map( TensorBox.create, - ir.MkldnnRnnLayer.create( + mkldnn_ir.MkldnnRnnLayer.create( x, w0, w1, @@ -506,7 +506,7 @@ def qconvolution_unary( algorithm, ): return TensorBox.create( - ir.QConvPointWisePT2E.create( + mkldnn_ir.QConvPointWisePT2E.create( x, x_scale, x_zp, @@ -566,7 +566,7 @@ def qconvolution_binary( # we will do accum dtype convertion here. accum = to_dtype(accum, output_dtype) return TensorBox.create( - ir.QConvPointWiseBinaryPT2E.create( + mkldnn_ir.QConvPointWiseBinaryPT2E.create( x, x_scale, x_zp, @@ -609,7 +609,7 @@ def qlinear_unary( algorithm, ): return TensorBox.create( - ir.QLinearPointwisePT2E.create( + mkldnn_ir.QLinearPointwisePT2E.create( x, x_scale, x_zp, @@ -668,7 +668,7 @@ def qlinear_binary( x2.get_dtype() == output_dtype ), "dtype of accum for qlinear post op sum should be the same as output" return TensorBox.create( - ir.QLinearPointwiseBinaryPT2E.create( + mkldnn_ir.QLinearPointwiseBinaryPT2E.create( x, x_scale, x_zp, @@ -695,7 +695,7 @@ def qlinear_binary( torch.ops.mkl._mkl_linear, "mkl::_mkl_linear", has_out_variant=False, - kernel_creator=ir.MKLPackedLinear.create, + kernel_creator=mkldnn_ir.MKLPackedLinear.create, ) cpu_needs_realized_inputs.append(torch.ops.mkl._mkl_linear) From 3dd5f0ecbbb71e3f8edc134baf5fe9fcf638ad07 Mon Sep 17 00:00:00 2001 From: Ahmed Gheith Date: Tue, 18 Jun 2024 12:30:13 +0000 Subject: [PATCH 128/171] Remove circular import (#128875) Summary: A spurious import is causing circular dependency errors Test Plan: phabricator signals Differential Revision: D58685676 Pull Request resolved: https://github.com/pytorch/pytorch/pull/128875 Approved by: https://github.com/kit1980 --- torch/optim/__init__.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/torch/optim/__init__.py b/torch/optim/__init__.py index 341d07b1a2e825..f794a1eafe243a 100644 --- a/torch/optim/__init__.py +++ b/torch/optim/__init__.py @@ -37,6 +37,3 @@ del optimizer # type: ignore[name-defined] # noqa: F821 del nadam # type: ignore[name-defined] # noqa: F821 del lbfgs # type: ignore[name-defined] # noqa: F821 - - -import torch.optim._multi_tensor From f2805a0408cfeb01c4d77a960dd4ca8e9a49db49 Mon Sep 17 00:00:00 2001 From: Andrew Gu Date: Mon, 17 Jun 2024 13:18:20 -0700 Subject: [PATCH 129/171] [FSDP2] Added APIs for explicit fwd/bwd prefetching (#128884) This PR adds two APIs `set_modules_to_forward_prefetch` and `set_modules_to_backward_prefetch` to enable explicit forward/backward all-gather prefetching, respectively. ``` def set_modules_to_forward_prefetch(self, modules: List[FSDPModule]): -> None def set_modules_to_backward_prefetch(self, modules: List[FSDPModule]): -> None ``` **Motivation** FSDP2 implements _reasonable defaults_ for forward and backward prefetching. In forward, it uses implicit prefetching and allows two all-gather output tensors to be alive at once (so that the current all-gather copy-out can overlap with the next all-gather). In backward, it uses explicit prefetching based on the reverse post-forward order. However, there may be cases where with expert knowledge, we can reduce communication bubbles by moving all-gathers manually. One way to expose such behavior is to expose _prefetching limits_, i.e. integers that configure how many outstanding all-gathers/all-gather output tensors can be alive at once. IMIHO, this leans toward _easy_, not _simple_ (see [PyTorch design principles](https://pytorch.org/docs/stable/community/design.html#principle-2-simple-over-easy)). The crux of the problem is that there may be special cases where manual intervention can give better performance. Exposing a prefetching limit and allowing users to pass a value >1 just smooths over the problem since such a limit would generally apply over the entire model even though it possibly should not. Then, expert users will see a specific all-gather that they want to deviate from this limit, and there is little we can do. Thus, we instead choose to expose the most primitive extension point: namely, every `FSDPModule` gives an opportunity to prefetch other all-gathers in forward and in backward. How to leverage this extension point is fully up to the user. Implementing the prefetch limit can be done using this extension point (e.g. record the post-forward order yourself using forward hooks, iterate over that order, and call the `set_modules_to_forward_prefetch` / `set_modules_to_backward_prefetch` APIs). Differential Revision: [D58700346](https://our.internmc.facebook.com/intern/diff/D58700346) Pull Request resolved: https://github.com/pytorch/pytorch/pull/128884 Approved by: https://github.com/ckluk2, https://github.com/weifengpy --- .../_composable/fsdp/test_fully_shard_comm.py | 205 +++++++++++++++++- .../fsdp/test_fully_shard_training.py | 40 +++- .../_composable/fsdp/_fsdp_param_group.py | 29 ++- .../_composable/fsdp/_fsdp_state.py | 11 +- .../_composable/fsdp/fully_shard.py | 48 +++- torch/testing/_internal/common_fsdp.py | 13 ++ 6 files changed, 334 insertions(+), 12 deletions(-) diff --git a/test/distributed/_composable/fsdp/test_fully_shard_comm.py b/test/distributed/_composable/fsdp/test_fully_shard_comm.py index 5acb9d895b4132..c0e3fbc9aea885 100644 --- a/test/distributed/_composable/fsdp/test_fully_shard_comm.py +++ b/test/distributed/_composable/fsdp/test_fully_shard_comm.py @@ -43,6 +43,7 @@ FSDPTestMultiThread, MLP, patch_post_backward, + patch_reshard, patch_unshard, ) from torch.testing._internal.common_utils import run_tests @@ -372,7 +373,7 @@ def test_manual_reshard_with_reshard_after_forward_false(self): ) -class TestFullyShardBackwardPrefetch(FSDPTest): +class TestFullyShardPrefetch(FSDPTest): @property def world_size(self) -> int: return min(4, torch.cuda.device_count()) @@ -578,6 +579,193 @@ def _test_backward_prefetch_unused_in_backward( self.assertEqual(events, expected_events) events.clear() + @skip_if_lt_x_gpu(2) + def test_set_modules_to_forward_prefetch(self): + n_layers = 4 + reshard_after_forward = True + checkpoint_impl = "utils" + model, _, inp = self._init_transformer( + n_layers, reshard_after_forward, checkpoint_impl + ) + + def set_forward_prefetch(model: Transformer, num_to_prefetch: int) -> None: + # Use model-specific knowledge to configure forward prefetching: + # each transformer block (layer) prefetches for the next few + for i, layer in enumerate(model.layers): + if i >= len(model.layers) - num_to_prefetch: + break + layers_to_prefetch = [ + model.layers[i + j] for j in range(1, num_to_prefetch + 1) + ] + layer.set_modules_to_forward_prefetch(layers_to_prefetch) + + events: List[EventType] = [] + unshard_with_record = self._get_unshard_with_record( + FSDPParamGroup.unshard, events + ) + reshard_with_record = self._get_reshard_with_record( + FSDPParamGroup.reshard, events + ) + post_backward_with_record = self._get_post_backward_with_record( + FSDPParamGroup.post_backward, events + ) + expected_backward_events = [ + # Default backward prefetching + ("unshard", "layers.3", TrainingState.PRE_BACKWARD), + ("unshard", "layers.2", TrainingState.PRE_BACKWARD), + ("reshard", "layers.3", TrainingState.POST_BACKWARD), + ("post_backward", "layers.3", TrainingState.POST_BACKWARD), + ("unshard", "layers.1", TrainingState.PRE_BACKWARD), + ("reshard", "layers.2", TrainingState.POST_BACKWARD), + ("post_backward", "layers.2", TrainingState.POST_BACKWARD), + ("unshard", "layers.0", TrainingState.PRE_BACKWARD), + ("reshard", "layers.1", TrainingState.POST_BACKWARD), + ("post_backward", "layers.1", TrainingState.POST_BACKWARD), + ("reshard", "layers.0", TrainingState.POST_BACKWARD), + ("post_backward", "layers.0", TrainingState.POST_BACKWARD), + ("reshard", "", TrainingState.POST_BACKWARD), + ("post_backward", "", TrainingState.POST_BACKWARD), + ] + with patch_unshard(unshard_with_record), patch_reshard( + reshard_with_record + ), patch_post_backward(post_backward_with_record): + set_forward_prefetch(model, num_to_prefetch=1) + loss = model(inp) + expected_forward_events = [ + ("unshard", "", TrainingState.FORWARD), + # `layers.i` prefetches `layers.i+1` + ("unshard", "layers.0", TrainingState.FORWARD), + ("unshard", "layers.1", TrainingState.FORWARD), + ("reshard", "layers.0", TrainingState.FORWARD), + ("unshard", "layers.2", TrainingState.FORWARD), + ("reshard", "layers.1", TrainingState.FORWARD), + ("unshard", "layers.3", TrainingState.FORWARD), + ("reshard", "layers.2", TrainingState.FORWARD), + ("reshard", "layers.3", TrainingState.FORWARD), + ] + self.assertEqual(events, expected_forward_events) + events.clear() + loss.sum().backward() + self.assertEqual(events, expected_backward_events) + events.clear() + + set_forward_prefetch(model, num_to_prefetch=2) + loss = model(inp) + expected_forward_events = [ + ("unshard", "", TrainingState.FORWARD), + # `layers.i` prefetches `layers.i+1` and `layers.i+2` + ("unshard", "layers.0", TrainingState.FORWARD), + ("unshard", "layers.1", TrainingState.FORWARD), + ("unshard", "layers.2", TrainingState.FORWARD), + ("reshard", "layers.0", TrainingState.FORWARD), + ("unshard", "layers.3", TrainingState.FORWARD), + ("reshard", "layers.1", TrainingState.FORWARD), + ("reshard", "layers.2", TrainingState.FORWARD), + ("reshard", "layers.3", TrainingState.FORWARD), + ] + self.assertEqual(events, expected_forward_events) + events.clear() + loss.sum().backward() + self.assertEqual(events, expected_backward_events) + events.clear() + + @skip_if_lt_x_gpu(2) + def test_set_modules_to_backward_prefetch(self): + n_layers = 4 + reshard_after_forward = True + checkpoint_impl = "utils" + model, _, inp = self._init_transformer( + n_layers, reshard_after_forward, checkpoint_impl + ) + + def set_backward_prefetch(model: Transformer, num_to_prefetch: int) -> None: + # Use model-specific knowledge to configure backward prefetching: + # each transformer block (layer) prefetches for the previous few + for i, layer in enumerate(model.layers): + if i < num_to_prefetch: + continue + layers_to_prefetch = [ + model.layers[i - j] for j in range(1, num_to_prefetch + 1) + ] + layer.set_modules_to_backward_prefetch(layers_to_prefetch) + + events: List[EventType] = [] + unshard_with_record = self._get_unshard_with_record( + FSDPParamGroup.unshard, events + ) + reshard_with_record = self._get_reshard_with_record( + FSDPParamGroup.reshard, events + ) + post_backward_with_record = self._get_post_backward_with_record( + FSDPParamGroup.post_backward, events + ) + expected_forward_events = [ + # Default forward prefetching + ("unshard", "", TrainingState.FORWARD), # root + ("unshard", "layers.0", TrainingState.FORWARD), + ("reshard", "layers.0", TrainingState.FORWARD), + ("unshard", "layers.1", TrainingState.FORWARD), + ("reshard", "layers.1", TrainingState.FORWARD), + ("unshard", "layers.2", TrainingState.FORWARD), + ("reshard", "layers.2", TrainingState.FORWARD), + ("unshard", "layers.3", TrainingState.FORWARD), + ("reshard", "layers.3", TrainingState.FORWARD), + ] + with patch_unshard(unshard_with_record), patch_reshard( + reshard_with_record + ), patch_post_backward(post_backward_with_record): + set_backward_prefetch(model, num_to_prefetch=1) + loss = model(inp) + self.assertEqual(events, expected_forward_events) + events.clear() + loss.sum().backward() + expected_backward_events = [ + # Root prefetches `layers.3` per default + ("unshard", "layers.3", TrainingState.PRE_BACKWARD), + # `layers.i` prefetches for `layers.i-1` (same as default) + ("unshard", "layers.2", TrainingState.PRE_BACKWARD), + ("reshard", "layers.3", TrainingState.POST_BACKWARD), + ("post_backward", "layers.3", TrainingState.POST_BACKWARD), + ("unshard", "layers.1", TrainingState.PRE_BACKWARD), + ("reshard", "layers.2", TrainingState.POST_BACKWARD), + ("post_backward", "layers.2", TrainingState.POST_BACKWARD), + ("unshard", "layers.0", TrainingState.PRE_BACKWARD), + ("reshard", "layers.1", TrainingState.POST_BACKWARD), + ("post_backward", "layers.1", TrainingState.POST_BACKWARD), + ("reshard", "layers.0", TrainingState.POST_BACKWARD), + ("post_backward", "layers.0", TrainingState.POST_BACKWARD), + ("reshard", "", TrainingState.POST_BACKWARD), + ("post_backward", "", TrainingState.POST_BACKWARD), + ] + self.assertEqual(events, expected_backward_events) + events.clear() + + set_backward_prefetch(model, num_to_prefetch=2) + loss = model(inp) + self.assertEqual(events, expected_forward_events) + events.clear() + loss.sum().backward() + expected_backward_events = [ + # Root prefetches `layers.3` per default + ("unshard", "layers.3", TrainingState.PRE_BACKWARD), + # `layers.i` prefetches for `layers.i-1` and `layers.i-2` + ("unshard", "layers.2", TrainingState.PRE_BACKWARD), + ("unshard", "layers.1", TrainingState.PRE_BACKWARD), + ("reshard", "layers.3", TrainingState.POST_BACKWARD), + ("post_backward", "layers.3", TrainingState.POST_BACKWARD), + ("unshard", "layers.0", TrainingState.PRE_BACKWARD), + ("reshard", "layers.2", TrainingState.POST_BACKWARD), + ("post_backward", "layers.2", TrainingState.POST_BACKWARD), + ("reshard", "layers.1", TrainingState.POST_BACKWARD), + ("post_backward", "layers.1", TrainingState.POST_BACKWARD), + ("reshard", "layers.0", TrainingState.POST_BACKWARD), + ("post_backward", "layers.0", TrainingState.POST_BACKWARD), + ("reshard", "", TrainingState.POST_BACKWARD), + ("post_backward", "", TrainingState.POST_BACKWARD), + ] + self.assertEqual(events, expected_backward_events) + events.clear() + def _init_transformer( self, n_layers: int, @@ -614,6 +802,21 @@ def unshard_with_record(self, *args, **kwargs): return unshard_with_record + def _get_reshard_with_record( + self, orig_reshard: Callable, events: List[EventType] + ) -> Callable: + def reshard_with_record(self, *args, **kwargs): + nonlocal events + if ( + self._training_state == TrainingState.FORWARD + and not self._reshard_after_forward + ): # skip no-ops + return + events.append(("reshard", self._module_fqn, self._training_state)) + return orig_reshard(self, *args, **kwargs) + + return reshard_with_record + def _get_post_backward_with_record( self, orig_post_backward: Callable, events: List[EventType] ) -> Callable: diff --git a/test/distributed/_composable/fsdp/test_fully_shard_training.py b/test/distributed/_composable/fsdp/test_fully_shard_training.py index 836013f7fb2434..3dbaa652437940 100644 --- a/test/distributed/_composable/fsdp/test_fully_shard_training.py +++ b/test/distributed/_composable/fsdp/test_fully_shard_training.py @@ -3,6 +3,7 @@ import contextlib import copy import functools +import itertools import unittest from typing import Iterable, List, Tuple, Type, Union @@ -337,7 +338,6 @@ def _test_train_parity_multi_group( return assert device_type in ("cuda", "cpu"), f"{device_type}" torch.manual_seed(42) - lin_dim = 32 vocab_size = 1024 model_args = ModelArgs( n_layers=3, @@ -494,6 +494,44 @@ def forward(self, x): _optim.step() self.assertEqual(losses[0], losses[1]) + @skip_if_lt_x_gpu(2) + def test_explicit_prefetching(self): + torch.manual_seed(42) + model_args = ModelArgs(n_layers=8, dropout_p=0.0) + model = Transformer(model_args) + ref_model = replicate(copy.deepcopy(model).cuda()) + ref_optim = torch.optim.AdamW(ref_model.parameters(), lr=1e-2) + for layer in itertools.chain(model.layers, [model]): + fully_shard(layer) + optim = torch.optim.AdamW(model.parameters(), lr=1e-2) + + num_to_forward_prefetch = num_to_backward_prefetch = 2 + for i, layer in enumerate(model.layers): + if i >= len(model.layers) - num_to_forward_prefetch: + break + layers_to_prefetch = [ + model.layers[i + j] for j in range(1, num_to_forward_prefetch + 1) + ] + layer.set_modules_to_forward_prefetch(layers_to_prefetch) + for i, layer in enumerate(model.layers): + if i < num_to_backward_prefetch: + continue + layers_to_prefetch = [ + model.layers[i - j] for j in range(1, num_to_backward_prefetch + 1) + ] + layer.set_modules_to_backward_prefetch(layers_to_prefetch) + + torch.manual_seed(42 + self.rank) + inp = torch.randint(0, model_args.vocab_size, (2, 8), device="cuda") + for iter_idx in range(10): + losses: List[torch.Tensor] = [] + for _model, _optim in ((ref_model, ref_optim), (model, optim)): + _optim.zero_grad() + losses.append(_model(inp).sum()) + losses[-1].backward() + _optim.step() + self.assertEqual(losses[0], losses[1]) + class TestFullyShard1DTrainingCompose(FSDPTest): @property diff --git a/torch/distributed/_composable/fsdp/_fsdp_param_group.py b/torch/distributed/_composable/fsdp/_fsdp_param_group.py index 63142466f001fd..06fa90e060e70d 100644 --- a/torch/distributed/_composable/fsdp/_fsdp_param_group.py +++ b/torch/distributed/_composable/fsdp/_fsdp_param_group.py @@ -283,14 +283,15 @@ def _record_post_forward(self) -> None: self.comm_ctx.post_forward_order.append(self) self._post_forward_indices.append(post_forward_index) - def pre_backward(self, *unused: Any): + def pre_backward(self, default_prefetch: bool, *unused: Any): if self._training_state == TrainingState.PRE_BACKWARD: return with record_function(self._with_fqn("FSDP::pre_backward")): self._training_state = TrainingState.PRE_BACKWARD self.unshard() # no-op if prefetched self.wait_for_unshard() - self._prefetch_unshard() + if default_prefetch: + self._backward_prefetch() def post_backward(self, *unused: Any): self._training_state = TrainingState.POST_BACKWARD @@ -348,7 +349,7 @@ def finalize_backward(self): fsdp_param.grad_offload_event = None self._post_forward_indices.clear() - def _prefetch_unshard(self): + def _backward_prefetch(self) -> None: if self._training_state == TrainingState.PRE_BACKWARD: if not self._post_forward_indices: # Can be cleared if running multiple `backward`s @@ -360,11 +361,23 @@ def _prefetch_unshard(self): # have mistargeted prefetches if not all modules used in forward # are used in this backward target_fsdp_param_group = self.comm_ctx.post_forward_order[target_index] - target_fqn = target_fsdp_param_group._module_fqn - with record_function( - self._with_fqn(f"FSDP::backward_prefetch for {target_fqn}") - ), target_fsdp_param_group.use_training_state(TrainingState.PRE_BACKWARD): - target_fsdp_param_group.unshard() + self._prefetch_unshard(target_fsdp_param_group, "backward") + + @staticmethod + def _prefetch_unshard( + target_fsdp_param_group: "FSDPParamGroup", pass_type: str + ) -> None: + if pass_type == "backward": + training_state = TrainingState.PRE_BACKWARD + elif pass_type == "forward": + training_state = TrainingState.FORWARD + else: + raise ValueError(f"Unknown pass type: {pass_type}") + target_fqn = target_fsdp_param_group._module_fqn + with record_function( + f"FSDP::{pass_type}_prefetch for {target_fqn}" + ), target_fsdp_param_group.use_training_state(training_state): + target_fsdp_param_group.unshard() # Utilities # def _to_sharded(self): diff --git a/torch/distributed/_composable/fsdp/_fsdp_state.py b/torch/distributed/_composable/fsdp/_fsdp_state.py index f080e755033842..79a09342704ff1 100644 --- a/torch/distributed/_composable/fsdp/_fsdp_state.py +++ b/torch/distributed/_composable/fsdp/_fsdp_state.py @@ -56,6 +56,8 @@ def __init__(self): self._state_ctx = FSDPStateContext() self._comm_ctx = FSDPCommContext() self._training_state: TrainingState = TrainingState.IDLE + self._states_to_forward_prefetch: List[FSDPState] = [] + self._states_to_backward_prefetch: List[FSDPState] = [] # Define a separate init since `__init__` is called in the contract def init( @@ -171,6 +173,9 @@ def _pre_forward( args, kwargs = tree_map(cast_fn, args), tree_map(cast_fn, kwargs) if self._fsdp_param_group: args, kwargs = self._fsdp_param_group.pre_forward(module, args, kwargs) + for fsdp_state in self._states_to_forward_prefetch: + if (target_param_group := fsdp_state._fsdp_param_group) is not None: + FSDPParamGroup._prefetch_unshard(target_param_group, "forward") return args, kwargs @disable_if_config_true @@ -205,7 +210,11 @@ def _pre_backward(self, grad: torch.Tensor) -> torch.Tensor: self._training_state = TrainingState.PRE_BACKWARD self._register_root_post_backward_final_callback() if self._fsdp_param_group: - self._fsdp_param_group.pre_backward() + default_prefetch = len(self._states_to_backward_prefetch) == 0 + self._fsdp_param_group.pre_backward(default_prefetch) + for fsdp_state in self._states_to_backward_prefetch: + if (target_param_group := fsdp_state._fsdp_param_group) is not None: + FSDPParamGroup._prefetch_unshard(target_param_group, "backward") return grad def _root_post_backward_final_callback(self) -> None: diff --git a/torch/distributed/_composable/fsdp/fully_shard.py b/torch/distributed/_composable/fsdp/fully_shard.py index d3e70b38eac919..61b7878d467ff2 100644 --- a/torch/distributed/_composable/fsdp/fully_shard.py +++ b/torch/distributed/_composable/fsdp/fully_shard.py @@ -1,7 +1,7 @@ # mypy: allow-untyped-defs import functools -from typing import Any, cast, NoReturn, Optional, Union +from typing import Any, cast, Iterable, List, NoReturn, Optional, Union import torch import torch.nn as nn @@ -270,6 +270,46 @@ def set_reshard_after_backward( if fsdp_param_group := state._fsdp_param_group: fsdp_param_group.reshard_after_backward = reshard_after_backward + def set_modules_to_forward_prefetch(self, modules: List["FSDPModule"]) -> None: + """ + Sets the FSDP modules for which this FSDP module should explicitly + prefetch all-gathers in forward. The prefetching runs after this + module's all-gather copy-out. + + Passing a singleton list containing the next FSDP module gives the same + all-gather overlap behavior as the default overlap behavior, except the + prefetched all-gather is issued earlier from the CPU. Passing a list + with at least length two is required for more aggressive overlap and + will use more reserved memory. + + Args: + modules (List[FSDPModule]): FSDP modules to prefetch. + """ + _assert_all_fsdp_modules(modules) + self._get_fsdp_state()._states_to_forward_prefetch = [ + module._get_fsdp_state() for module in modules + ] + + def set_modules_to_backward_prefetch(self, modules: List["FSDPModule"]) -> None: + """ + Sets the FSDP modules for which this FSDP module should explicitly + prefetch all-gathers in backward. This overrides the default backward + pretching implementation that prefetches the next FSDP module based on + the reverse post-forward order. + + Passing a singleton list containing the previous FSDP module gives the + same all-gather overlap behavior as the default overlap behavior. + Passing a list with at least length two is required for more aggressive + overlap and will use more reserved memory. + + Args: + modules (List[FSDPModule]): FSDP modules to prefetch. + """ + _assert_all_fsdp_modules(modules) + self._get_fsdp_state()._states_to_backward_prefetch = [ + module._get_fsdp_state() for module in modules + ] + def _get_fsdp_state(self) -> FSDPState: if (state := _get_module_fsdp_state(cast(nn.Module, self))) is None: raise AssertionError(f"No FSDP state found on {self}") @@ -350,3 +390,9 @@ def wrapped_method(self, *args, **kwargs): method_name, wrapped_method.__get__(module, type(module)), # type:ignore[attr-defined] ) + + +def _assert_all_fsdp_modules(modules: Iterable[Any]) -> None: + for module in modules: + if not isinstance(module, FSDPModule): + raise ValueError(f"Expects FSDPModule but got {type(module)}: {module}") diff --git a/torch/testing/_internal/common_fsdp.py b/torch/testing/_internal/common_fsdp.py index 2b5fdc613c2e23..cfa16307da334d 100644 --- a/torch/testing/_internal/common_fsdp.py +++ b/torch/testing/_internal/common_fsdp.py @@ -997,6 +997,19 @@ def patch_unshard(new_unshard: Callable): FSDPParamGroup.unshard = orig_unshard +@no_type_check +@contextlib.contextmanager +def patch_reshard(new_reshard: Callable): + orig_reshard = FSDPParamGroup.reshard + dist.barrier() + FSDPParamGroup.reshard = new_reshard + try: + yield + finally: + dist.barrier() + FSDPParamGroup.reshard = orig_reshard + + @no_type_check @contextlib.contextmanager def patch_post_backward(new_post_backward: Callable): From e6d4451ae8987bf8d6ad85eb7cde685fac746f6f Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Tue, 18 Jun 2024 14:31:38 +0800 Subject: [PATCH 130/171] [BE][Easy] enable UFMT for `torch/distributed/{algorithms,autograd,benchmarks,checkpoint,elastic}/` (#128866) Part of #123062 - #123062 Pull Request resolved: https://github.com/pytorch/pytorch/pull/128866 Approved by: https://github.com/fegin --- .lintrunner.toml | 66 -------------- torch/distributed/algorithms/__init__.py | 4 +- .../_checkpoint/checkpoint_wrapper.py | 11 ++- .../algorithms/_comm_hooks/__init__.py | 2 +- .../algorithms/_comm_hooks/default_hooks.py | 45 ++++++---- .../_optimizer_overlap/optimizer_overlap.py | 23 ++--- .../algorithms/_quantization/quantization.py | 53 ++++++------ .../algorithms/ddp_comm_hooks/__init__.py | 13 +-- .../ddp_comm_hooks/ddp_zero_hook.py | 55 +++++++----- .../ddp_comm_hooks/debugging_hooks.py | 1 + .../ddp_comm_hooks/default_hooks.py | 1 + .../ddp_comm_hooks/mixed_precision_hooks.py | 25 +++--- .../ddp_comm_hooks/optimizer_overlap_hooks.py | 33 ++++--- .../ddp_comm_hooks/post_localSGD_hook.py | 6 +- .../ddp_comm_hooks/powerSGD_hook.py | 47 +++++----- torch/distributed/algorithms/join.py | 26 +++--- .../algorithms/model_averaging/averagers.py | 21 +++-- .../hierarchical_model_averager.py | 22 +++-- .../algorithms/model_averaging/utils.py | 28 ++++-- torch/distributed/autograd/__init__.py | 18 ++-- .../benchmarks/benchmark_ddp_rpc.py | 13 ++- .../checkpoint/_dedup_save_plans.py | 1 + .../distributed/checkpoint/_dedup_tensors.py | 1 + .../checkpoint/_fsspec_filesystem.py | 1 + torch/distributed/checkpoint/_nested_dict.py | 1 + .../checkpoint/_sharded_tensor_utils.py | 1 + .../distributed/checkpoint/_storage_utils.py | 1 - torch/distributed/checkpoint/_traverse.py | 1 + torch/distributed/checkpoint/api.py | 1 + .../distributed/checkpoint/default_planner.py | 1 + .../examples/fsdp_checkpoint_example.py | 2 +- torch/distributed/checkpoint/filesystem.py | 2 +- torch/distributed/checkpoint/logger.py | 1 + torch/distributed/checkpoint/metadata.py | 1 + torch/distributed/checkpoint/optimizer.py | 1 + torch/distributed/checkpoint/planner.py | 1 - .../distributed/checkpoint/planner_helpers.py | 3 +- torch/distributed/checkpoint/resharding.py | 1 + torch/distributed/checkpoint/staging.py | 2 +- torch/distributed/checkpoint/state_dict.py | 1 + .../checkpoint/state_dict_loader.py | 1 + .../checkpoint/state_dict_saver.py | 1 - torch/distributed/checkpoint/storage.py | 2 +- torch/distributed/checkpoint/utils.py | 1 + torch/distributed/elastic/agent/server/api.py | 76 ++++++++++++----- .../agent/server/health_check_server.py | 1 + .../agent/server/local_elastic_agent.py | 59 ++++++++----- torch/distributed/elastic/control_plane.py | 1 + torch/distributed/elastic/events/__init__.py | 4 +- torch/distributed/elastic/events/api.py | 5 +- torch/distributed/elastic/metrics/__init__.py | 6 +- torch/distributed/elastic/metrics/api.py | 20 ++++- .../elastic/multiprocessing/__init__.py | 4 +- .../elastic/multiprocessing/api.py | 79 ++++++++++++----- .../multiprocessing/errors/__init__.py | 15 +++- .../multiprocessing/errors/error_handler.py | 26 +++--- .../multiprocessing/errors/handlers.py | 4 +- .../elastic/multiprocessing/redirects.py | 1 + .../subprocess_handler/__init__.py | 1 + .../subprocess_handler/handlers.py | 1 + .../subprocess_handler/subprocess_handler.py | 2 +- .../elastic/multiprocessing/tail_log.py | 8 +- .../elastic/rendezvous/__init__.py | 4 +- torch/distributed/elastic/rendezvous/api.py | 17 +++- .../rendezvous/c10d_rendezvous_backend.py | 16 ++-- .../elastic/rendezvous/dynamic_rendezvous.py | 47 ++++++---- .../elastic/rendezvous/etcd_rendezvous.py | 33 ++++--- .../rendezvous/etcd_rendezvous_backend.py | 11 ++- .../elastic/rendezvous/etcd_server.py | 1 + .../elastic/rendezvous/etcd_store.py | 4 +- .../elastic/rendezvous/registry.py | 11 ++- .../rendezvous/static_tcp_rendezvous.py | 7 +- torch/distributed/elastic/rendezvous/utils.py | 20 +++-- torch/distributed/elastic/timer/__init__.py | 14 ++- torch/distributed/elastic/timer/api.py | 25 ++++-- .../elastic/timer/debug_info_logging.py | 1 + .../elastic/timer/file_based_local_timer.py | 85 +++++++++++++------ .../distributed/elastic/timer/local_timer.py | 6 +- torch/distributed/elastic/utils/api.py | 2 +- .../distributed/elastic/utils/distributed.py | 13 ++- torch/distributed/elastic/utils/store.py | 13 +-- 81 files changed, 729 insertions(+), 456 deletions(-) diff --git a/.lintrunner.toml b/.lintrunner.toml index 08e434e8f143ba..2ea1579ee64c27 100644 --- a/.lintrunner.toml +++ b/.lintrunner.toml @@ -1430,77 +1430,11 @@ exclude_patterns = [ 'torch/distributed/_sharding_spec/__init__.py', 'torch/distributed/_tools/__init__.py', 'torch/distributed/_tools/memory_tracker.py', - 'torch/distributed/algorithms/__init__.py', - 'torch/distributed/algorithms/_checkpoint/__init__.py', - 'torch/distributed/algorithms/_checkpoint/checkpoint_wrapper.py', - 'torch/distributed/algorithms/_comm_hooks/__init__.py', - 'torch/distributed/algorithms/_comm_hooks/default_hooks.py', - 'torch/distributed/algorithms/_optimizer_overlap/__init__.py', - 'torch/distributed/algorithms/_optimizer_overlap/optimizer_overlap.py', - 'torch/distributed/algorithms/_quantization/__init__.py', - 'torch/distributed/algorithms/_quantization/quantization.py', - 'torch/distributed/algorithms/ddp_comm_hooks/__init__.py', - 'torch/distributed/algorithms/ddp_comm_hooks/ddp_zero_hook.py', - 'torch/distributed/algorithms/ddp_comm_hooks/debugging_hooks.py', - 'torch/distributed/algorithms/ddp_comm_hooks/default_hooks.py', - 'torch/distributed/algorithms/ddp_comm_hooks/mixed_precision_hooks.py', - 'torch/distributed/algorithms/ddp_comm_hooks/optimizer_overlap_hooks.py', - 'torch/distributed/algorithms/ddp_comm_hooks/post_localSGD_hook.py', - 'torch/distributed/algorithms/ddp_comm_hooks/powerSGD_hook.py', - 'torch/distributed/algorithms/ddp_comm_hooks/quantization_hooks.py', - 'torch/distributed/algorithms/join.py', - 'torch/distributed/algorithms/model_averaging/__init__.py', - 'torch/distributed/algorithms/model_averaging/averagers.py', - 'torch/distributed/algorithms/model_averaging/hierarchical_model_averager.py', - 'torch/distributed/algorithms/model_averaging/utils.py', 'torch/distributed/argparse_util.py', - 'torch/distributed/autograd/__init__.py', - 'torch/distributed/benchmarks/benchmark_ddp_rpc.py', 'torch/distributed/c10d_logger.py', 'torch/distributed/collective_utils.py', 'torch/distributed/constants.py', 'torch/distributed/distributed_c10d.py', - 'torch/distributed/elastic/__init__.py', - 'torch/distributed/elastic/agent/__init__.py', - 'torch/distributed/elastic/agent/server/__init__.py', - 'torch/distributed/elastic/agent/server/api.py', - 'torch/distributed/elastic/agent/server/local_elastic_agent.py', - 'torch/distributed/elastic/events/__init__.py', - 'torch/distributed/elastic/events/api.py', - 'torch/distributed/elastic/events/handlers.py', - 'torch/distributed/elastic/metrics/__init__.py', - 'torch/distributed/elastic/metrics/api.py', - 'torch/distributed/elastic/multiprocessing/__init__.py', - 'torch/distributed/elastic/multiprocessing/api.py', - 'torch/distributed/elastic/multiprocessing/errors/__init__.py', - 'torch/distributed/elastic/multiprocessing/errors/error_handler.py', - 'torch/distributed/elastic/multiprocessing/errors/handlers.py', - 'torch/distributed/elastic/multiprocessing/redirects.py', - 'torch/distributed/elastic/multiprocessing/tail_log.py', - 'torch/distributed/elastic/rendezvous/__init__.py', - 'torch/distributed/elastic/rendezvous/api.py', - 'torch/distributed/elastic/rendezvous/c10d_rendezvous_backend.py', - 'torch/distributed/elastic/rendezvous/dynamic_rendezvous.py', - 'torch/distributed/elastic/rendezvous/etcd_rendezvous.py', - 'torch/distributed/elastic/rendezvous/etcd_rendezvous_backend.py', - 'torch/distributed/elastic/rendezvous/etcd_server.py', - 'torch/distributed/elastic/rendezvous/etcd_store.py', - 'torch/distributed/elastic/rendezvous/registry.py', - 'torch/distributed/elastic/rendezvous/static_tcp_rendezvous.py', - 'torch/distributed/elastic/rendezvous/utils.py', - 'torch/distributed/elastic/timer/__init__.py', - 'torch/distributed/elastic/timer/api.py', - 'torch/distributed/elastic/timer/file_based_local_timer.py', - 'torch/distributed/elastic/timer/local_timer.py', - 'torch/distributed/elastic/utils/__init__.py', - 'torch/distributed/elastic/utils/api.py', - 'torch/distributed/elastic/utils/data/__init__.py', - 'torch/distributed/elastic/utils/data/cycling_iterator.py', - 'torch/distributed/elastic/utils/data/elastic_distributed_sampler.py', - 'torch/distributed/elastic/utils/distributed.py', - 'torch/distributed/elastic/utils/log_level.py', - 'torch/distributed/elastic/utils/logging.py', - 'torch/distributed/elastic/utils/store.py', 'torch/distributed/examples/memory_tracker_example.py', 'torch/distributed/launch.py', 'torch/distributed/launcher/__init__.py', diff --git a/torch/distributed/algorithms/__init__.py b/torch/distributed/algorithms/__init__.py index a07470a0cfd403..06c81429569940 100644 --- a/torch/distributed/algorithms/__init__.py +++ b/torch/distributed/algorithms/__init__.py @@ -1,3 +1 @@ -from .join import Join -from .join import Joinable -from .join import JoinHook +from .join import Join, Joinable, JoinHook diff --git a/torch/distributed/algorithms/_checkpoint/checkpoint_wrapper.py b/torch/distributed/algorithms/_checkpoint/checkpoint_wrapper.py index 86ab1de003db42..8cc15f4aba3116 100644 --- a/torch/distributed/algorithms/_checkpoint/checkpoint_wrapper.py +++ b/torch/distributed/algorithms/_checkpoint/checkpoint_wrapper.py @@ -10,6 +10,7 @@ from torch.distributed.utils import _pack_kwargs, _replace_by_prefix, _unpack_kwargs from torch.utils.checkpoint import checkpoint as torch_utils_checkpoint + _CHECKPOINT_WRAPPED_MODULE = "_checkpoint_wrapped_module" _CHECKPOINT_PREFIX = _CHECKPOINT_WRAPPED_MODULE + "." @@ -286,8 +287,12 @@ def apply_activation_checkpointing( """ # TODO: Importing inside function to avoid circular import issue between FSDP and # checkpoint_wrapper. This can be resolved once wrap() APIs are decoupled from FSDP code. - from torch.distributed.fsdp.wrap import _recursive_wrap, lambda_auto_wrap_policy, _Policy from torch.distributed.fsdp._wrap_utils import _construct_wrap_fn, _post_order_apply + from torch.distributed.fsdp.wrap import ( + _Policy, + _recursive_wrap, + lambda_auto_wrap_policy, + ) policy = ( auto_wrap_policy @@ -302,7 +307,9 @@ def apply_activation_checkpointing( target_module_to_kwargs = policy._run_policy( model, ignored_modules=set(), root_kwargs={} ) - wrap_fn = _construct_wrap_fn(model, target_module_to_kwargs, checkpoint_wrapper_fn) + wrap_fn = _construct_wrap_fn( + model, target_module_to_kwargs, checkpoint_wrapper_fn + ) _post_order_apply(model, wrap_fn) return diff --git a/torch/distributed/algorithms/_comm_hooks/__init__.py b/torch/distributed/algorithms/_comm_hooks/__init__.py index d07adc17247b71..7b57a075ad729d 100644 --- a/torch/distributed/algorithms/_comm_hooks/__init__.py +++ b/torch/distributed/algorithms/_comm_hooks/__init__.py @@ -1,6 +1,6 @@ - from . import default_hooks as default + LOW_PRECISION_HOOKS = [ default.fp16_compress_hook, default.bf16_compress_hook, diff --git a/torch/distributed/algorithms/_comm_hooks/default_hooks.py b/torch/distributed/algorithms/_comm_hooks/default_hooks.py index d370fabafc3718..0acafd6868d3bb 100644 --- a/torch/distributed/algorithms/_comm_hooks/default_hooks.py +++ b/torch/distributed/algorithms/_comm_hooks/default_hooks.py @@ -1,8 +1,9 @@ # mypy: allow-untyped-defs import functools +from typing import Optional + import torch import torch.distributed as dist -from typing import Optional class DefaultState: @@ -17,13 +18,10 @@ class DefaultState: "process_group", "world_size", "gradient_predivide_factor", - "gradient_postdivide_factor" + "gradient_postdivide_factor", ] - def __init__( - self, - process_group: dist.ProcessGroup - ): + def __init__(self, process_group: dist.ProcessGroup): if process_group is None: raise ValueError(f"Expected to pass in an explicit ProcessGroup to {self}.") self.process_group = process_group @@ -33,7 +31,9 @@ def __init__( self.gradient_predivide_factor = self._get_gradient_predivide_factor( self.world_size ) - self.gradient_postdivide_factor = self.world_size / self.gradient_predivide_factor + self.gradient_postdivide_factor = ( + self.world_size / self.gradient_predivide_factor + ) @staticmethod def _get_gradient_predivide_factor(world_size: int) -> float: @@ -42,6 +42,7 @@ def _get_gradient_predivide_factor(world_size: int) -> float: factor *= 2 return float(factor) + class LowPrecisionState(DefaultState): r""" Stores state needed to perform gradient communication in a lower precision within a communication hook. @@ -82,12 +83,15 @@ def _decompress(state: LowPrecisionState, grad: torch.Tensor): device_type = grad.device.type backend = getattr(torch, device_type) except AttributeError as e: - raise AttributeError(f"Device {grad.device} does not have a \ - corresponding backend registered as 'torch.device_type'.") from e + raise AttributeError( + f"Device {grad.device} does not have a \ + corresponding backend registered as 'torch.device_type'." + ) from e # Don't let this memory get reused until after the transfer. orig_grad_data.record_stream(backend.current_stream()) # type: ignore[arg-type] + def allreduce_hook(state: DefaultState, grad: torch.Tensor): r""" Implement the FSDP communication hook for ``all_reduce`` algorithm and a necessary pre- and post-division of gradients. @@ -106,6 +110,7 @@ def allreduce_hook(state: DefaultState, grad: torch.Tensor): if state.gradient_postdivide_factor > 1: grad.div_(state.gradient_postdivide_factor) + def reduce_scatter_hook(state: DefaultState, grad: torch.Tensor, output: torch.Tensor): r""" Implement the FSDP communication hook for ``reduce_scatter`` algorithm. @@ -121,14 +126,18 @@ def reduce_scatter_hook(state: DefaultState, grad: torch.Tensor, output: torch.T # Average grad by pre-division factor. if state.gradient_predivide_factor > 1: grad.div_(state.gradient_predivide_factor) - dist.reduce_scatter_tensor( - output, grad, group=state.process_group - ) + dist.reduce_scatter_tensor(output, grad, group=state.process_group) # Average grad's shard by post-division factor. if state.gradient_postdivide_factor > 1: output.div_(state.gradient_postdivide_factor) -def _low_precision_hook(prec: torch.dtype, state: LowPrecisionState, grad: torch.Tensor, output: torch.Tensor): + +def _low_precision_hook( + prec: torch.dtype, + state: LowPrecisionState, + grad: torch.Tensor, + output: torch.Tensor, +): if grad.dtype != prec: grad.data = grad.data.to(prec) if output is not None: @@ -140,7 +149,10 @@ def _low_precision_hook(prec: torch.dtype, state: LowPrecisionState, grad: torch allreduce_hook(state, grad) _decompress(state, grad) -def fp16_compress_hook(state: LowPrecisionState, grad: torch.Tensor, output: Optional[torch.Tensor] = None): + +def fp16_compress_hook( + state: LowPrecisionState, grad: torch.Tensor, output: Optional[torch.Tensor] = None +): r""" Implement FSDP communication hook for a simple gradient compression approach. Casts ``grad`` to half-precision floating-point format (``torch.float16``). @@ -158,7 +170,10 @@ def fp16_compress_hook(state: LowPrecisionState, grad: torch.Tensor, output: Opt fp16_hook = functools.partial(_low_precision_hook, torch.float16) return fp16_hook(state, grad, output) -def bf16_compress_hook(state: LowPrecisionState, grad: torch.Tensor, output: Optional[torch.Tensor] = None): + +def bf16_compress_hook( + state: LowPrecisionState, grad: torch.Tensor, output: Optional[torch.Tensor] = None +): r""" Implement FSDP communication hook for a simple gradient compression approach . Casts ``grad`` to half-precision floating-point format. diff --git a/torch/distributed/algorithms/_optimizer_overlap/optimizer_overlap.py b/torch/distributed/algorithms/_optimizer_overlap/optimizer_overlap.py index 1afbb8d7967fc3..ada39ca24d9705 100644 --- a/torch/distributed/algorithms/_optimizer_overlap/optimizer_overlap.py +++ b/torch/distributed/algorithms/_optimizer_overlap/optimizer_overlap.py @@ -1,19 +1,18 @@ # mypy: allow-untyped-defs -from abc import ABC, abstractmethod import inspect +from abc import ABC, abstractmethod from typing import Dict, Type -from torch.distributed.fsdp import FullyShardedDataParallel -from torch.nn.parallel import DistributedDataParallel -from torch.optim import Optimizer -from torch.distributed.optim import as_functional_optim - from torch.distributed.algorithms.ddp_comm_hooks.default_hooks import allreduce_hook - from torch.distributed.algorithms.ddp_comm_hooks.optimizer_overlap_hooks import ( + _hook_then_optimizer, _OptimizerHookState, - _hook_then_optimizer ) +from torch.distributed.fsdp import FullyShardedDataParallel +from torch.distributed.optim import as_functional_optim +from torch.nn.parallel import DistributedDataParallel +from torch.optim import Optimizer + # Contains the mappings between the regular and overlapped optimizer types. _registered_overlapped_optims: Dict[Type, Type] = {} @@ -29,6 +28,7 @@ def decorator(target_overlapped_optim_cls): ) _registered_overlapped_optims[optim_cls] = target_overlapped_optim_cls return target_overlapped_optim_cls + return decorator @@ -71,7 +71,7 @@ def register_ddp(self, ddp_inst: DistributedDataParallel): # yet supported. ddp_inst.register_comm_hook( # type: ignore[operator] None, # wrapped hook state - _hook_then_optimizer(allreduce_hook, self._opt_hook_state) + _hook_then_optimizer(allreduce_hook, self._opt_hook_state), ) # TODO: register_fsdp once FSDP supports communication hook. @@ -81,11 +81,14 @@ def register_fsdp(self, fsdp: FullyShardedDataParallel) -> None: f"{self.__class__.__name__} does not support overlapped FSDP." ) + def _as_overlapped_optim(optim_cls: Type, params, *args, **kwargs): """Return a new ``OverlappedOptimizer`` instance that supports ``optim_cls``.""" for clz in inspect.getmro(optim_cls): try: - return _registered_overlapped_optims[clz](optim_cls, params, *args, **kwargs) + return _registered_overlapped_optims[clz]( + optim_cls, params, *args, **kwargs + ) except KeyError: pass diff --git a/torch/distributed/algorithms/_quantization/quantization.py b/torch/distributed/algorithms/_quantization/quantization.py index c421076bde3ecf..a579a0a02feae9 100644 --- a/torch/distributed/algorithms/_quantization/quantization.py +++ b/torch/distributed/algorithms/_quantization/quantization.py @@ -1,22 +1,23 @@ # mypy: allow-untyped-defs import functools +from enum import Enum + import torch import torch.distributed as dist -from enum import Enum - - TORCH_HALF_MIN = torch.finfo(torch.float16).min TORCH_HALF_MAX = torch.finfo(torch.float16).max + class DQuantType(Enum): """ Different quantization methods for auto_quantize API are identified here. auto_quantize API currently supports fp16 and bfp16 methods. """ - FP16 = "fp16", + + FP16 = ("fp16",) BFP16 = "bfp16" def __str__(self) -> str: @@ -26,6 +27,7 @@ def __str__(self) -> str: def _fp32_to_fp16_with_clamp(tensor: torch.Tensor) -> torch.Tensor: return torch.clamp(tensor, TORCH_HALF_MIN, TORCH_HALF_MAX).half() + def _quantize_tensor(tensor, qtype): if not isinstance(tensor, torch.Tensor): raise RuntimeError( @@ -36,9 +38,8 @@ def _quantize_tensor(tensor, qtype): elif qtype == DQuantType.BFP16: return torch.ops.quantization._FloatToBfloat16Quantized(tensor) else: - raise RuntimeError( - f'Quantization type {qtype} is not supported' - ) + raise RuntimeError(f"Quantization type {qtype} is not supported") + def _quantize_tensor_list(tensor_list, qtype): if not isinstance(tensor_list, list) or not all( @@ -50,6 +51,7 @@ def _quantize_tensor_list(tensor_list, qtype): quantized_tensor_list = [_quantize_tensor(t, qtype) for t in tensor_list] return quantized_tensor_list + def _dequantize_tensor(tensor, qtype, quant_loss=None): if not isinstance(tensor, torch.Tensor): raise RuntimeError( @@ -72,9 +74,7 @@ def _dequantize_tensor(tensor, qtype, quant_loss=None): else: return torch.ops.quantization._Bfloat16QuantizedToFloat(tensor) else: - raise RuntimeError( - f'Quantization type {qtype} is not supported' - ) + raise RuntimeError(f"Quantization type {qtype} is not supported") def _dequantize_tensor_list(tensor_list, qtype, quant_loss=None): @@ -103,20 +103,21 @@ def auto_quantize(func, qtype, quant_loss=None): Returns: (Callable): the same collective as func but enables automatic quantization/dequantization. """ + @functools.wraps(func) def wrapper(*args, **kwargs): - group = kwargs.get('group', None) - async_op = kwargs.get('async_op', False) + group = kwargs.get("group", None) + async_op = kwargs.get("async_op", False) if async_op is True: - raise RuntimeError( - 'The async_op=True mode is not supported yet.' - ) + raise RuntimeError("The async_op=True mode is not supported yet.") if func == dist.all_gather: tensors = args[0] input_tensors = _quantize_tensor(args[1], qtype) out_tensors = _quantize_tensor_list(tensors, qtype) dist.all_gather(out_tensors, input_tensors, group=group, async_op=async_op) - for i, t in enumerate(_dequantize_tensor_list(out_tensors, qtype, quant_loss=quant_loss)): + for i, t in enumerate( + _dequantize_tensor_list(out_tensors, qtype, quant_loss=quant_loss) + ): tensors[i] = t elif func == dist.all_to_all: @@ -124,22 +125,26 @@ def wrapper(*args, **kwargs): input_tensors = _quantize_tensor_list(args[1], qtype) out_tensors = _quantize_tensor_list(tensors, qtype) dist.all_to_all(out_tensors, input_tensors, group=group, async_op=async_op) - for i, t in enumerate(_dequantize_tensor_list(out_tensors, qtype, quant_loss=quant_loss)): + for i, t in enumerate( + _dequantize_tensor_list(out_tensors, qtype, quant_loss=quant_loss) + ): tensors[i] = t elif func == dist.all_to_all_single: tensors = args[0] - out_splits = kwargs.get('out_splits', None) - in_splits = kwargs.get('in_splits', None) + out_splits = kwargs.get("out_splits", None) + in_splits = kwargs.get("in_splits", None) # Quantizing the input/output tensor input_tensors = _quantize_tensor(args[1], qtype) out_tensors = _quantize_tensor(tensors, qtype) - dist.all_to_all_single(out_tensors, input_tensors, out_splits, in_splits, group=group) - for i, t in enumerate(_dequantize_tensor(out_tensors, qtype, quant_loss=quant_loss)): + dist.all_to_all_single( + out_tensors, input_tensors, out_splits, in_splits, group=group + ) + for i, t in enumerate( + _dequantize_tensor(out_tensors, qtype, quant_loss=quant_loss) + ): tensors[i] = t else: - raise RuntimeError( - f"The collective op {func} is not supported yet" - ) + raise RuntimeError(f"The collective op {func} is not supported yet") return wrapper diff --git a/torch/distributed/algorithms/ddp_comm_hooks/__init__.py b/torch/distributed/algorithms/ddp_comm_hooks/__init__.py index 2366a9d28c1388..a1d1ffd2fc8771 100644 --- a/torch/distributed/algorithms/ddp_comm_hooks/__init__.py +++ b/torch/distributed/algorithms/ddp_comm_hooks/__init__.py @@ -7,12 +7,14 @@ from . import ( debugging_hooks as debugging, default_hooks as default, + optimizer_overlap_hooks as optimizer_overlap, powerSGD_hook as powerSGD, quantization_hooks as quantization, - optimizer_overlap_hooks as optimizer_overlap, ) -__all__ = ['DDPCommHookType', 'register_ddp_comm_hook'] + +__all__ = ["DDPCommHookType", "register_ddp_comm_hook"] + def _ddp_comm_hook_wrapper(comm_hook, model, state): model.register_comm_hook(state, comm_hook) @@ -86,13 +88,12 @@ class DDPCommHookType(Enum): matrix_approximation_rank=2, ) NOOP = partial( - _ddp_comm_hook_wrapper, comm_hook=debugging.noop_hook, + _ddp_comm_hook_wrapper, + comm_hook=debugging.noop_hook, ) -def register_ddp_comm_hook( - comm_hook_type: DDPCommHookType, model, state=None -): +def register_ddp_comm_hook(comm_hook_type: DDPCommHookType, model, state=None): """ Register ``ddp_comm_hooks`` to DDP model. diff --git a/torch/distributed/algorithms/ddp_comm_hooks/ddp_zero_hook.py b/torch/distributed/algorithms/ddp_comm_hooks/ddp_zero_hook.py index 8ab58cb584421a..6db6d1831b1fde 100644 --- a/torch/distributed/algorithms/ddp_comm_hooks/ddp_zero_hook.py +++ b/torch/distributed/algorithms/ddp_comm_hooks/ddp_zero_hook.py @@ -5,11 +5,10 @@ import torch import torch.distributed as dist from torch.distributed.optim import ZeroRedundancyOptimizer -from torch.distributed.optim.zero_redundancy_optimizer import ( - _OverlapStatus, -) +from torch.distributed.optim.zero_redundancy_optimizer import _OverlapStatus from torch.nn.parallel.distributed import DistributedDataParallel + __all__ = ["hook_with_zero_step", "hook_with_zero_step_interleaved"] # Functional optimizers require passing a list of gradients to their `step()` @@ -39,22 +38,25 @@ def _perform_local_step( """ overlap_info = zero._overlap_info bucket_index = bucket.index() - assert len(zero.optim.param_groups) == 1, \ - "Overlapping DDP with ZeRO only supports a single parameter group" + assert ( + len(zero.optim.param_groups) == 1 + ), "Overlapping DDP with ZeRO only supports a single parameter group" # Construct the `gradients` input for the local optimizer step, which # expects `None` in a list position to indicate that the corresponding # parameter should not be updated num_local_optim_params = len(zero.optim.param_groups[0]["params"]) - gradients: List[Optional[torch.Tensor]] = \ - [_NO_PARAM_UPDATE for _ in range(num_local_optim_params)] - assert bucket_index in overlap_info.offsets, \ - f"Bucket index {bucket_index} was not assigned to rank {rank}" + gradients: List[Optional[torch.Tensor]] = [ + _NO_PARAM_UPDATE for _ in range(num_local_optim_params) + ] + assert ( + bucket_index in overlap_info.offsets + ), f"Bucket index {bucket_index} was not assigned to rank {rank}" gradients_offset = overlap_info.offsets[bucket_index] bucket_assignment = zero._bucket_assignments_per_rank[rank][bucket_index] bucket_offset = bucket_assignment.offset length = len(bucket_assignment.parameters) - bucket_gradients = bucket.gradients()[bucket_offset:bucket_offset + length] + bucket_gradients = bucket.gradients()[bucket_offset : bucket_offset + length] for i, grad in enumerate(bucket_gradients): gradients[gradients_offset + i] = grad @@ -75,12 +77,14 @@ def _broadcast_bucket( :class:`ZeroRedundancyOptimizer` instance. """ overlap_info = zero._overlap_info - assert len(overlap_info.assigned_ranks_per_bucket) > bucket_index, \ - "`assigned_ranks_per_bucket` is not fully constructed" + assert ( + len(overlap_info.assigned_ranks_per_bucket) > bucket_index + ), "`assigned_ranks_per_bucket` is not fully constructed" # Sort to ensure the same ordering across ranks assigned_ranks = sorted(overlap_info.assigned_ranks_per_bucket[bucket_index]) - assert len(assigned_ranks) > 0, f"Bucket {bucket_index} should be " \ - "assigned to at least one rank" + assert len(assigned_ranks) > 0, ( + f"Bucket {bucket_index} should be " "assigned to at least one rank" + ) for assigned_rank in assigned_ranks: bucket_assignments = zero._bucket_assignments_per_rank[assigned_rank] if bucket_index in bucket_assignments: @@ -229,7 +233,7 @@ def hook_with_zero_step( # NOTE: Gloo may hang with this overlapping approach, so we require # NCCL/HCCL backend for now; see https://github.com/pytorch/pytorch/issues/62300 pg = dist.get_backend(ddp_ref().process_group) # type: ignore[union-attr] - if ((pg != dist.Backend.NCCL) and (pg != 'hccl')): + if (pg != dist.Backend.NCCL) and (pg != "hccl"): raise RuntimeError( "Overlapping DDP with ZeRO using this approach currently requires " "NCCL/HCCL backend to avoid hangs" @@ -267,9 +271,12 @@ def hook_with_zero_fn( rank = zero.global_rank assert overlap_info.status == _OverlapStatus.INITIALIZED - assert len(overlap_info.assigned_ranks_per_bucket) > bucket_index, \ - "`assigned_ranks_per_bucket` is not fully constructed" - assigned_to_bucket = rank in overlap_info.assigned_ranks_per_bucket[bucket_index] + assert ( + len(overlap_info.assigned_ranks_per_bucket) > bucket_index + ), "`assigned_ranks_per_bucket` is not fully constructed" + assigned_to_bucket = ( + rank in overlap_info.assigned_ranks_per_bucket[bucket_index] + ) # Save the bucket reference and all-reduce future for the final bucket if assigned_to_bucket: @@ -279,8 +286,9 @@ def hook_with_zero_fn( # Check that buckets are indexed incrementally starting from 0 in the # order of their autograd hooks firing if len(overlap_info.bucket_indices_seen) > 0: - assert overlap_info.bucket_indices_seen[-1] == bucket_index - 1, \ - "Bucket indices are not in incremental order" + assert ( + overlap_info.bucket_indices_seen[-1] == bucket_index - 1 + ), "Bucket indices are not in incremental order" else: assert bucket_index == 0, "Bucket indices do not start from 0" overlap_info.bucket_indices_seen.append(bucket_index) @@ -302,9 +310,10 @@ def hook_with_zero_fn( if rank in assigned_ranks: # Wait on the bucket's all-reduce future to ensure correct # gradients - assert bucket_index in overlap_info.bucket_index_to_future, \ - f"All-reduce future for bucket {bucket_index} not saved " \ + assert bucket_index in overlap_info.bucket_index_to_future, ( + f"All-reduce future for bucket {bucket_index} not saved " f"on rank {rank}" + ) allreduce_future = overlap_info.bucket_index_to_future[bucket_index] allreduce_future.wait() @@ -386,7 +395,7 @@ def hook_with_zero_step_interleaved( # NOTE: Gloo may hang with this overlapping approach, so we require # NCCL/HCCL backend for now; see https://github.com/pytorch/pytorch/issues/62300 pg = dist.get_backend(ddp_ref().process_group) # type: ignore[union-attr] - if ((pg != dist.Backend.NCCL) and (pg != 'hccl')): + if (pg != dist.Backend.NCCL) and (pg != "hccl"): raise RuntimeError( "Overlapping DDP with ZeRO using this approach currently requires " "NCCL/HCCL backend to avoid hangs" diff --git a/torch/distributed/algorithms/ddp_comm_hooks/debugging_hooks.py b/torch/distributed/algorithms/ddp_comm_hooks/debugging_hooks.py index a552f9a359f7ed..53a184839a06f4 100644 --- a/torch/distributed/algorithms/ddp_comm_hooks/debugging_hooks.py +++ b/torch/distributed/algorithms/ddp_comm_hooks/debugging_hooks.py @@ -3,6 +3,7 @@ import torch from torch.distributed import GradBucket + __all__ = ["noop_hook"] diff --git a/torch/distributed/algorithms/ddp_comm_hooks/default_hooks.py b/torch/distributed/algorithms/ddp_comm_hooks/default_hooks.py index 621e46fc198963..b1296ae712f0c0 100644 --- a/torch/distributed/algorithms/ddp_comm_hooks/default_hooks.py +++ b/torch/distributed/algorithms/ddp_comm_hooks/default_hooks.py @@ -4,6 +4,7 @@ import torch import torch.distributed as dist + __all__ = [ "allreduce_hook", "fp16_compress_hook", diff --git a/torch/distributed/algorithms/ddp_comm_hooks/mixed_precision_hooks.py b/torch/distributed/algorithms/ddp_comm_hooks/mixed_precision_hooks.py index 31b243d44e0fd2..4727bbf9d45e66 100644 --- a/torch/distributed/algorithms/ddp_comm_hooks/mixed_precision_hooks.py +++ b/torch/distributed/algorithms/ddp_comm_hooks/mixed_precision_hooks.py @@ -1,11 +1,12 @@ +from dataclasses import dataclass +from typing import Any, no_type_check + import torch import torch.distributed as dist from torch.autograd import Variable - -from dataclasses import dataclass -from typing import Any, no_type_check from torch.distributed.utils import _free_storage + @dataclass class _AllreduceUpcastHookState: """ @@ -19,6 +20,7 @@ class _AllreduceUpcastHookState: upcast_stream: torch.cuda.Stream wait_for_stream_enqueued: bool = False + @no_type_check def _reducer_allreduce_and_upcast_hook( hook_state: _AllreduceUpcastHookState, bucket: dist.GradBucket @@ -35,10 +37,13 @@ def _reducer_allreduce_and_upcast_hook( gradient_is_bucket_view = ddp_weakref().gradient_as_bucket_view # Cast bucket if different than param_dtype. if ( - ddp_weakref().mixed_precision.param_dtype != ddp_weakref().mixed_precision.reduce_dtype + ddp_weakref().mixed_precision.param_dtype + != ddp_weakref().mixed_precision.reduce_dtype ): # Cast bucket tensor to reduce_dtype - bucket.set_buffer(bucket.buffer().to(ddp_weakref().mixed_precision.reduce_dtype)) + bucket.set_buffer( + bucket.buffer().to(ddp_weakref().mixed_precision.reduce_dtype) + ) fut = reducer._run_allreduce_hook(bucket) ret_fut = torch.futures.Future() stream = hook_state.upcast_stream @@ -66,19 +71,17 @@ def wait_for_stream_cb(): # by hook above as they don't have a grad hook installed, so cast them # back here. for n, p in ddp_weakref().module.named_parameters(): - if hasattr(p, '_ddp_mp_hook_state'): + if hasattr(p, "_ddp_mp_hook_state"): p._ddp_mp_hook_state[1].remove() - delattr(p, '_ddp_mp_hook_state') - if not p.requires_grad and not hasattr(p, '_ddp_ignored'): + delattr(p, "_ddp_mp_hook_state") + if not p.requires_grad and not hasattr(p, "_ddp_ignored"): p.data = p._fp_param # reset for next backward pass hook_state.wait_for_stream_enqueued = False if not hook_state.wait_for_stream_enqueued: - Variable._execution_engine.queue_callback( - wait_for_stream_cb - ) + Variable._execution_engine.queue_callback(wait_for_stream_cb) # mark that the callback is enqueued hook_state.wait_for_stream_enqueued = True diff --git a/torch/distributed/algorithms/ddp_comm_hooks/optimizer_overlap_hooks.py b/torch/distributed/algorithms/ddp_comm_hooks/optimizer_overlap_hooks.py index 76d4cd6de2bdc8..5ae242b04a9c53 100644 --- a/torch/distributed/algorithms/ddp_comm_hooks/optimizer_overlap_hooks.py +++ b/torch/distributed/algorithms/ddp_comm_hooks/optimizer_overlap_hooks.py @@ -1,16 +1,18 @@ # mypy: allow-untyped-defs +from dataclasses import dataclass +from functools import partial from typing import Any, Callable, List, no_type_check import torch import torch.distributed as dist from torch.autograd import Variable -from functools import partial -from dataclasses import dataclass + __all__: List[str] = [] _FUNCTIONAL_OPTIM_STEP_METHOD_NAME = "step_param" + class _OptimizerHookState: """ Holds state for running optimizer in-line after DDP communication hook. @@ -42,9 +44,10 @@ class _OptimInBackwardHookState: optim_stream: torch.cuda.Stream wait_for_optim_stream_enqueued: bool + @no_type_check def _apply_optim_in_backward_hook( - gradient_is_bucket_view: bool + gradient_is_bucket_view: bool, ) -> Callable[[Any, dist.GradBucket], torch.futures.Future[torch.Tensor]]: r""" Register hook to apply the optimizer in backward. @@ -59,7 +62,9 @@ def _apply_optim_in_backward_hook( ) def apply_optim_in_backward_hook( - hook_state: Any, bucket: dist.GradBucket, optim_stream_state, + hook_state: Any, + bucket: dist.GradBucket, + optim_stream_state, ) -> torch.futures.Future[torch.Tensor]: # Run original hook ddp_weakref = hook_state @@ -78,7 +83,7 @@ def apply_optim_in_backward_hook( # TODO (rohan-varma): upcast as needed for DDP mixed precision, # once optimizer in backward + DDP mixed precision is supported. for p, g in zip(model_params, grads): - if hasattr(p, '_in_backward_optimizers'): + if hasattr(p, "_in_backward_optimizers"): # Note: need to set grad to the bucket's grad, because # running allreduce results in the bucket's grad being # reduced, but not grad field. @@ -94,21 +99,17 @@ def apply_optim_in_backward_hook( # enqueue a callback to wait for this optimizer stream at the end of # backward and set all DDP managed grads to None. def wait_for_optim_stream_callback(): - torch.cuda.current_stream().wait_stream( - optim_stream_state.optim_stream - ) + torch.cuda.current_stream().wait_stream(optim_stream_state.optim_stream) # Set DDP managed grads to None for param in ddp_inst._get_data_parallel_params(ddp_inst.module): - if hasattr(param, '_in_backward_optimizers'): + if hasattr(param, "_in_backward_optimizers"): param.grad = None # reset for the next backwards pass optim_stream_state.wait_for_optim_stream_enqueued = False if not optim_stream_state.wait_for_optim_stream_enqueued: - Variable._execution_engine.queue_callback( - wait_for_optim_stream_callback - ) + Variable._execution_engine.queue_callback(wait_for_optim_stream_callback) # mark that the callback is enqueued optim_stream_state.wait_for_optim_stream_enqueued = True @@ -123,13 +124,14 @@ def wait_for_optim_stream_callback(): return comm_hook + def _hook_then_optimizer( hook: Callable[[Any, dist.GradBucket], torch.futures.Future[torch.Tensor]], optimizer_state: _OptimizerHookState, ) -> Callable[[Any, dist.GradBucket], torch.futures.Future[torch.Tensor]]: r"""Run optimizer in a functional fashion after DDP communication hook.""" has_set_params = ( - hasattr(optimizer_state, 'params_to_optimize') + hasattr(optimizer_state, "params_to_optimize") and optimizer_state.params_to_optimize is not None ) @@ -143,7 +145,10 @@ def optimizer_step(fut): gradient_tensors = bucket.gradients() model_params = bucket.parameters() for grad_tensor, model_param in zip(gradient_tensors, model_params): - if not has_set_params or model_param in optimizer_state.params_to_optimize: + if ( + not has_set_params + or model_param in optimizer_state.params_to_optimize + ): optimizer_state.functional_optimizer.step_param( model_param, grad_tensor, diff --git a/torch/distributed/algorithms/ddp_comm_hooks/post_localSGD_hook.py b/torch/distributed/algorithms/ddp_comm_hooks/post_localSGD_hook.py index 3528f3987479fa..d8da01e6e1fe2e 100644 --- a/torch/distributed/algorithms/ddp_comm_hooks/post_localSGD_hook.py +++ b/torch/distributed/algorithms/ddp_comm_hooks/post_localSGD_hook.py @@ -6,6 +6,7 @@ from . import default_hooks as default + logger = logging.getLogger(__name__) @@ -62,9 +63,8 @@ def maybe_increase_iter(self, bucket): self.iter += 1 if self.iter == self.start_localSGD_iter: - logger.info( - "Start to apply local SGD after %s iterations.", self.iter - ) + logger.info("Start to apply local SGD after %s iterations.", self.iter) + def post_localSGD_hook( state: PostLocalSGDState, bucket: dist.GradBucket diff --git a/torch/distributed/algorithms/ddp_comm_hooks/powerSGD_hook.py b/torch/distributed/algorithms/ddp_comm_hooks/powerSGD_hook.py index fbc3b9e8739e46..96b3b888511ae0 100644 --- a/torch/distributed/algorithms/ddp_comm_hooks/powerSGD_hook.py +++ b/torch/distributed/algorithms/ddp_comm_hooks/powerSGD_hook.py @@ -1,18 +1,17 @@ # mypy: allow-untyped-defs -from collections import defaultdict import logging import math +from collections import defaultdict from typing import Dict import torch import torch.distributed as dist +from torch.distributed import distributed_c10d from . import default_hooks as default -from torch.distributed import distributed_c10d -__all__ = [ - "PowerSGDState", "powerSGD_hook", "batched_powerSGD_hook" -] + +__all__ = ["PowerSGDState", "powerSGD_hook", "batched_powerSGD_hook"] logger = logging.getLogger(__name__) @@ -35,10 +34,13 @@ def _orthogonalize(matrices, epsilon=0): matrices, out=( matrices, - torch.empty(num_matrices, rank, rank, device=matrices.device, dtype=dtype) - ) + torch.empty( + num_matrices, rank, rank, device=matrices.device, dtype=dtype + ), + ), ) + def _orthogonalize_gram_schmidt(matrices, epsilon=0): """ Apply Gram-Schmidt procedure to orthogonalize a batch of matrices. @@ -103,14 +105,15 @@ def _should_compress( def _report_compression_stats(bucket, state): """Report compression stats at frequency of ``compression_stats_logging_frequency`` specified in PowerSGD state.""" - if ( - bucket.is_last() - and state.iter >= state.next_stats_report - ): + if bucket.is_last() and state.iter >= state.next_stats_report: stats = state.compression_stats() logger.info( "Compression stats: iter %s, total before compression %s, total after compression %s, " - "rate %s", state.iter, stats[1], stats[2], stats[0] + "rate %s", + state.iter, + stats[1], + stats[2], + stats[0], ) state.next_stats_report = state.iter + state.compression_stats_logging_frequency @@ -244,6 +247,7 @@ def __init__( # If the same random projection is used, # there will be differences between the gradients that are never synchronized. import numpy as np + self.rng = np.random.RandomState(random_seed) # Since there is only a single state instance for all the input buckets, # need to maintain a dictionary that maps each bucket index to the local error. @@ -280,7 +284,8 @@ def __getstate__(self): ) return { slot: getattr(self, slot) - for slot in self.__slots__ if slot != "process_group" + for slot in self.__slots__ + if slot != "process_group" } def __setstate__(self, state): @@ -305,9 +310,7 @@ def maybe_increase_iter(self, bucket): self.iter += 1 if self.iter == self.start_powerSGD_iter: - logger.info( - "Start to apply PowerSGD after %s iterations.", self.iter - ) + logger.info("Start to apply PowerSGD after %s iterations.", self.iter) def compression_stats(self): r""" @@ -420,7 +423,7 @@ def powerSGD_hook( else: logger.info( "A zero tensor of length %s that represents local error is created.", - total_length + total_length, ) state.error_dict[bucket_index] = torch.zeros( total_length, device=device, dtype=dtype @@ -478,7 +481,8 @@ def powerSGD_hook( if state.warm_start: logger.info( "Allocating contiguous memory of length %s for Ps, and of length %s for Qs, respectively.", - total_Ps_size, total_Qs_size + total_Ps_size, + total_Qs_size, ) state.p_memory_dict[bucket_index] = torch.empty( total_Ps_size, device=device, dtype=dtype @@ -724,7 +728,7 @@ def batched_powerSGD_hook( state.total_numel_after_compression += ( square_side_length * state.matrix_approximation_rank * 2 ) - padded_total_length = square_side_length ** 2 + padded_total_length = square_side_length**2 input_tensor.resize_(padded_total_length) input_tensor[total_length:padded_total_length].fill_(0) @@ -739,7 +743,7 @@ def batched_powerSGD_hook( else: logger.info( "A zero tensor of length %s that represents local error is created.", - padded_total_length + padded_total_length, ) state.error_dict[bucket_index] = torch.zeros( padded_total_length, device=device, dtype=input_tensor.dtype @@ -759,7 +763,8 @@ def batched_powerSGD_hook( if state.warm_start: logger.info( "Initializing low-rank tensors P and Q, each of which has a shape of %s x %s.", - square_side_length, state.matrix_approximation_rank + square_side_length, + state.matrix_approximation_rank, ) def create_low_rank_tensor(fill_random_values, rng): diff --git a/torch/distributed/algorithms/join.py b/torch/distributed/algorithms/join.py index 2936747a1c6ece..140844851938b7 100644 --- a/torch/distributed/algorithms/join.py +++ b/torch/distributed/algorithms/join.py @@ -7,7 +7,9 @@ import torch import torch.distributed as dist -__all__ = ['JoinHook', 'Joinable', 'Join'] + +__all__ = ["JoinHook", "Joinable", "Join"] + class JoinHook: r""" @@ -97,13 +99,10 @@ def construct_disabled_join_config(): e.g. if the caller is not in a join context manager. """ return _JoinConfig( - enable=False, - throw_on_early_termination=False, - is_first_joinable=False + enable=False, throw_on_early_termination=False, is_first_joinable=False ) - class Join: r""" This class defines the generic join context manager, which allows custom hooks to be called after a process joins. @@ -176,7 +175,9 @@ def __init__( if len(joinables) == 0: raise ValueError("The join context manager requires at least one joinable") self._joinables = joinables - self._join_hooks = [joinable.join_hook(**kwargs) for joinable in self._joinables] + self._join_hooks = [ + joinable.join_hook(**kwargs) for joinable in self._joinables + ] self._enable = enable self._throw_on_early_termination = throw_on_early_termination self._set_joinable_configs() @@ -190,7 +191,7 @@ def _set_joinable_configs(self) -> None: joinable._join_config = _JoinConfig( enable=self._enable, throw_on_early_termination=self._throw_on_early_termination, - is_first_joinable=is_first_joinable + is_first_joinable=is_first_joinable, ) is_first_joinable = False @@ -215,7 +216,9 @@ def _extract_dist_info(self) -> None: if process_group is None: process_group = joinable.join_process_group elif process_group != joinable.join_process_group: - raise ValueError("Using join context manager with multiple process groups") + raise ValueError( + "Using join context manager with multiple process groups" + ) if device is None: device = joinable.join_device self._process_group = process_group @@ -229,7 +232,7 @@ def __exit__( self, type: Optional[Type[BaseException]], value: Optional[BaseException], - traceback: Optional[TracebackType] + traceback: Optional[TracebackType], ): r""" Repeatedly runs the main hooks until all processes join; then, runs the post-hooks. @@ -318,9 +321,10 @@ def notify_join_context(joinable: Joinable): manager that the process has not yet joined if ``joinable`` is the first one passed into the context manager; ``None`` otherwise. """ - assert hasattr(joinable, "_join_config"), \ - f"Check that the {type(joinable)} constructor calls the " \ + assert hasattr(joinable, "_join_config"), ( + f"Check that the {type(joinable)} constructor calls the " "``Joinable`` constructor" + ) join_config = joinable._join_config # First joinable is responsible for the collective communications diff --git a/torch/distributed/algorithms/model_averaging/averagers.py b/torch/distributed/algorithms/model_averaging/averagers.py index 178efd1dbad92f..e15154e3f8578f 100644 --- a/torch/distributed/algorithms/model_averaging/averagers.py +++ b/torch/distributed/algorithms/model_averaging/averagers.py @@ -1,12 +1,15 @@ # mypy: allow-untyped-defs import warnings from abc import ABC, abstractmethod -from typing import Union, Iterable, Dict +from typing import Dict, Iterable, Union + import torch import torch.distributed as dist import torch.distributed.algorithms.model_averaging.utils as utils -__all__ = ['ModelAverager', 'PeriodicModelAverager'] + +__all__ = ["ModelAverager", "PeriodicModelAverager"] + class ModelAverager(ABC): r"""Base class for all model averagers. @@ -82,12 +85,7 @@ class PeriodicModelAverager(ModelAverager): >>> averager.average_parameters(model.parameters()) """ - def __init__( - self, - period, - warmup_steps=0, - process_group=None - ): + def __init__(self, period, warmup_steps=0, process_group=None): super().__init__(process_group) if warmup_steps < 0: raise ValueError("Arg ``warmup_steps`` must be a non-negative number.") @@ -103,7 +101,12 @@ def __init__( ) self.period = period - def average_parameters(self, params: Union[Iterable[torch.nn.Parameter], Iterable[Dict[str, torch.nn.Parameter]]]): + def average_parameters( + self, + params: Union[ + Iterable[torch.nn.Parameter], Iterable[Dict[str, torch.nn.Parameter]] + ], + ): """ Averages parameters or parameter groups of an optimizer if ``step`` is no less than ``warmup_steps``. diff --git a/torch/distributed/algorithms/model_averaging/hierarchical_model_averager.py b/torch/distributed/algorithms/model_averaging/hierarchical_model_averager.py index 02802466ab62a2..a27f3b762a9e3a 100644 --- a/torch/distributed/algorithms/model_averaging/hierarchical_model_averager.py +++ b/torch/distributed/algorithms/model_averaging/hierarchical_model_averager.py @@ -3,13 +3,14 @@ import logging import warnings from collections import OrderedDict -from typing import Union, Iterable, Dict +from typing import Dict, Iterable, Union import torch import torch.distributed as dist import torch.distributed.algorithms.model_averaging.averagers as averagers import torch.distributed.algorithms.model_averaging.utils as utils + logger = logging.getLogger(__name__) @@ -103,7 +104,9 @@ def __init__(self, period_group_size_dict=None, warmup_steps=0, process_group=No raise ValueError("Arg ``period_group_size_dict`` must not be empty.") self._periods = list(period_group_size_dict.keys()) if self._periods[0] <= 0: - raise ValueError("The minimum period in arg ``period_group_size_dict`` must be a positive value.") + raise ValueError( + "The minimum period in arg ``period_group_size_dict`` must be a positive value." + ) elif self._periods[-1] == 1: warnings.warn( "When the maximum period in arg ``period_group_size_dict`` is 1, " @@ -124,10 +127,14 @@ def __init__(self, period_group_size_dict=None, warmup_steps=0, process_group=No for period, group_size in period_group_size_dict.items(): logger.info( "\tEach group that has %s processes average parameters every %s iterations, " - "if no higher-level averaging.", group_size, period) + "if no higher-level averaging.", + group_size, + period, + ) if group_size != overall_group_size: self.period_process_group_dict[period], _ = dist.new_subgroups( - group_size=group_size, group=self.process_group) + group_size=group_size, group=self.process_group + ) else: self.period_process_group_dict[period] = self.process_group @@ -149,7 +156,12 @@ def _find_process_group(self): return self.period_process_group_dict[period] return None - def average_parameters(self, params: Union[Iterable[torch.nn.Parameter], Iterable[Dict[str, torch.nn.Parameter]]]): + def average_parameters( + self, + params: Union[ + Iterable[torch.nn.Parameter], Iterable[Dict[str, torch.nn.Parameter]] + ], + ): """ Averages parameters or parameter groups of an optimizer. diff --git a/torch/distributed/algorithms/model_averaging/utils.py b/torch/distributed/algorithms/model_averaging/utils.py index de1977959d21c1..20f75152f0b876 100644 --- a/torch/distributed/algorithms/model_averaging/utils.py +++ b/torch/distributed/algorithms/model_averaging/utils.py @@ -1,16 +1,23 @@ # mypy: allow-untyped-defs # flake8: noqa C101 import itertools -from typing import Union, Iterable, Dict, Iterator +from typing import Dict, Iterable, Iterator, Union import torch import torch.distributed as dist + # The two imports below are not always available depending on the # USE_DISTRIBUTED compile flag. Make sure they raise import error # if we're trying to use them. -from torch.distributed import ProcessGroup, group +from torch.distributed import group, ProcessGroup + + +__all__ = [ + "average_parameters", + "get_params_to_average", + "average_parameters_or_parameter_groups", +] -__all__ = ["average_parameters", "get_params_to_average", "average_parameters_or_parameter_groups"] def average_parameters( params: Iterator[torch.nn.Parameter], process_group: ProcessGroup @@ -43,7 +50,9 @@ def average_parameters( offset += p.numel() -def get_params_to_average(params: Union[Iterable[torch.nn.Parameter], Iterable[Dict[str, torch.nn.Parameter]]]): +def get_params_to_average( + params: Union[Iterable[torch.nn.Parameter], Iterable[Dict[str, torch.nn.Parameter]]] +): """ Return a list of parameters that need to average. @@ -64,10 +73,17 @@ def get_params_to_average(params: Union[Iterable[torch.nn.Parameter], Iterable[D if param_data.grad is not None: filtered_params.append(param_data) else: - raise NotImplementedError(f"Parameter input of type {type(param)} is not supported") + raise NotImplementedError( + f"Parameter input of type {type(param)} is not supported" + ) return filtered_params -def average_parameters_or_parameter_groups(params: Union[Iterable[torch.nn.Parameter], Iterable[Dict[str, torch.nn.Parameter]]], process_group: ProcessGroup): +def average_parameters_or_parameter_groups( + params: Union[ + Iterable[torch.nn.Parameter], Iterable[Dict[str, torch.nn.Parameter]] + ], + process_group: ProcessGroup, +): """Averages parameters of a model or parameter groups of an optimizer.""" average_parameters(iter(get_params_to_average(params)), process_group) diff --git a/torch/distributed/autograd/__init__.py b/torch/distributed/autograd/__init__.py index 6546c38a37b99b..b1cf0aec6140f8 100644 --- a/torch/distributed/autograd/__init__.py +++ b/torch/distributed/autograd/__init__.py @@ -1,6 +1,5 @@ # mypy: allow-untyped-defs -import sys import torch @@ -13,22 +12,22 @@ def is_available(): if is_available(): from torch._C._distributed_autograd import ( - get_gradients, - backward, + _current_context, + _get_debug_info, + _get_max_id, _init, + _is_valid_context, _new_context, _release_context, - _get_max_id, - _is_valid_context, _retrieve_context, - _current_context, - _get_debug_info, + backward, DistAutogradContext, + get_gradients, ) class context: - ''' + """ Context object to wrap forward and backward passes when using distributed autograd. The ``context_id`` generated in the ``with`` statement is required to uniquely identify a distributed backward pass @@ -44,7 +43,8 @@ class context: >>> t2 = torch.rand((3, 3), requires_grad=True) >>> loss = rpc.rpc_sync("worker1", torch.add, args=(t1, t2)).sum() >>> dist_autograd.backward(context_id, [loss]) - ''' + """ + def __enter__(self): self.autograd_context = _new_context() return self.autograd_context._context_id() diff --git a/torch/distributed/benchmarks/benchmark_ddp_rpc.py b/torch/distributed/benchmarks/benchmark_ddp_rpc.py index 60f71e12213be9..d36568953d6b58 100644 --- a/torch/distributed/benchmarks/benchmark_ddp_rpc.py +++ b/torch/distributed/benchmarks/benchmark_ddp_rpc.py @@ -8,12 +8,13 @@ import time import numpy as np + import torch -import torch.nn as nn import torch.distributed as dist import torch.distributed.autograd as dist_autograd import torch.distributed.rpc as rpc import torch.multiprocessing as mp +import torch.nn as nn import torch.optim as optim from torch.distributed.optim import DistributedOptimizer from torch.distributed.rpc import RRef, TensorPipeRpcBackendOptions @@ -210,14 +211,13 @@ def run_worker(rank, world_size): # Rank 16. Master if rank == (NUM_TRAINERS + NUM_PS): - rpc.init_rpc( - "master", rank=rank, + "master", + rank=rank, backend=BackendType.TENSORPIPE, # type: ignore[attr-defined] - world_size=world_size + world_size=world_size, ) - # Build the Embedding tables on the Parameter Servers. emb_rref_list = [] index = 0 @@ -256,7 +256,6 @@ def run_worker(rank, world_size): # Rank 0-7. Trainers elif rank >= 0 and rank < NUM_PS: - # Initialize process group for Distributed DataParallel on trainers. dist.init_process_group( backend=dist.Backend.GLOO, @@ -292,7 +291,7 @@ def run_worker(rank, world_size): if __name__ == "__main__": - """ Initializing the distributed environment. """ + """Initializing the distributed environment.""" output = _run_printable("nvidia-smi topo -m") print("-------------------------------------------") diff --git a/torch/distributed/checkpoint/_dedup_save_plans.py b/torch/distributed/checkpoint/_dedup_save_plans.py index 16d46e73baffd0..dd37634a0aa64b 100644 --- a/torch/distributed/checkpoint/_dedup_save_plans.py +++ b/torch/distributed/checkpoint/_dedup_save_plans.py @@ -5,6 +5,7 @@ from torch.distributed.checkpoint.planner import SavePlan, WriteItem + if TYPE_CHECKING: from torch.distributed.checkpoint.metadata import MetadataIndex diff --git a/torch/distributed/checkpoint/_dedup_tensors.py b/torch/distributed/checkpoint/_dedup_tensors.py index 7689b9452e8ccb..687afb287b3c74 100644 --- a/torch/distributed/checkpoint/_dedup_tensors.py +++ b/torch/distributed/checkpoint/_dedup_tensors.py @@ -5,6 +5,7 @@ from torch.distributed.checkpoint.planner import SavePlan + if TYPE_CHECKING: from torch.distributed.checkpoint.metadata import MetadataIndex diff --git a/torch/distributed/checkpoint/_fsspec_filesystem.py b/torch/distributed/checkpoint/_fsspec_filesystem.py index 7fdd04dff311cd..b57df9c3456ca6 100644 --- a/torch/distributed/checkpoint/_fsspec_filesystem.py +++ b/torch/distributed/checkpoint/_fsspec_filesystem.py @@ -17,6 +17,7 @@ FileSystemWriter, ) + __all__ = [ "FsspecWriter", "FsspecReader", diff --git a/torch/distributed/checkpoint/_nested_dict.py b/torch/distributed/checkpoint/_nested_dict.py index 527a67e6892fe5..3347ea8bc432ae 100644 --- a/torch/distributed/checkpoint/_nested_dict.py +++ b/torch/distributed/checkpoint/_nested_dict.py @@ -5,6 +5,7 @@ from ._traverse import OBJ_PATH, set_element, STATE_DICT_ITEM, traverse_state_dict + """ TODO: Need to add ability to handle tuple, OrderedDict, NamedTuple. diff --git a/torch/distributed/checkpoint/_sharded_tensor_utils.py b/torch/distributed/checkpoint/_sharded_tensor_utils.py index f71f129e127c76..a68bcddeb7f9d9 100644 --- a/torch/distributed/checkpoint/_sharded_tensor_utils.py +++ b/torch/distributed/checkpoint/_sharded_tensor_utils.py @@ -11,6 +11,7 @@ from ._traverse import OBJ_PATH, set_element, STATE_DICT_ITEM, traverse_state_dict from .utils import _element_wise_add, _normalize_device_info + if TYPE_CHECKING: from torch.distributed._shard.sharded_tensor.metadata import ShardedTensorMetadata diff --git a/torch/distributed/checkpoint/_storage_utils.py b/torch/distributed/checkpoint/_storage_utils.py index 0f5205a1f20305..194c9c8c4b9b15 100644 --- a/torch/distributed/checkpoint/_storage_utils.py +++ b/torch/distributed/checkpoint/_storage_utils.py @@ -2,7 +2,6 @@ from typing import List, Type, Union from .filesystem import FileSystemReader, FileSystemWriter - from .storage import StorageReader, StorageWriter diff --git a/torch/distributed/checkpoint/_traverse.py b/torch/distributed/checkpoint/_traverse.py index 5d5e87bf130877..8bcb832c71980f 100644 --- a/torch/distributed/checkpoint/_traverse.py +++ b/torch/distributed/checkpoint/_traverse.py @@ -17,6 +17,7 @@ from torch.distributed._tensor import DTensor from torch.distributed.checkpoint.metadata import STATE_DICT_TYPE + PATH_ITEM = Union[str, int] OBJ_PATH = Tuple[PATH_ITEM, ...] T = TypeVar("T") diff --git a/torch/distributed/checkpoint/api.py b/torch/distributed/checkpoint/api.py index 660196bc28de8b..e587580617a1bf 100644 --- a/torch/distributed/checkpoint/api.py +++ b/torch/distributed/checkpoint/api.py @@ -2,6 +2,7 @@ import traceback as tb from typing import Any, Dict, Tuple + WRAPPED_EXCEPTION = Tuple[BaseException, tb.StackSummary] __all__ = ["CheckpointException"] diff --git a/torch/distributed/checkpoint/default_planner.py b/torch/distributed/checkpoint/default_planner.py index 83b76718a6b7a8..cbf855d51417f8 100644 --- a/torch/distributed/checkpoint/default_planner.py +++ b/torch/distributed/checkpoint/default_planner.py @@ -46,6 +46,7 @@ ) from torch.distributed.checkpoint.utils import find_state_dict_object + logger: logging.Logger = logging.getLogger(__name__) diff --git a/torch/distributed/checkpoint/examples/fsdp_checkpoint_example.py b/torch/distributed/checkpoint/examples/fsdp_checkpoint_example.py index 38c637d3a4fd1a..f2f03840b0d576 100644 --- a/torch/distributed/checkpoint/examples/fsdp_checkpoint_example.py +++ b/torch/distributed/checkpoint/examples/fsdp_checkpoint_example.py @@ -16,10 +16,10 @@ import torch.distributed.checkpoint as dist_cp import torch.multiprocessing as mp from torch.distributed.checkpoint.optimizer import load_sharded_optimizer_state_dict - from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType + CHECKPOINT_DIR = f"/scratch/{os.environ['LOGNAME']}/checkpoint" diff --git a/torch/distributed/checkpoint/filesystem.py b/torch/distributed/checkpoint/filesystem.py index 4d512891f12231..859476d71e16f6 100644 --- a/torch/distributed/checkpoint/filesystem.py +++ b/torch/distributed/checkpoint/filesystem.py @@ -32,7 +32,6 @@ from torch import Tensor from torch._utils import _get_available_device_type, _get_device_module from torch.distributed._shard._utils import narrow_tensor_by_index - from torch.distributed.checkpoint.metadata import ( Metadata, MetadataIndex, @@ -58,6 +57,7 @@ from torch.distributed.checkpoint.utils import _create_file_view from torch.futures import Future + __all__ = ["FileSystemWriter", "FileSystemReader", "FileSystem", "FileSystemBase"] _metadata_fn: str = ".metadata" diff --git a/torch/distributed/checkpoint/logger.py b/torch/distributed/checkpoint/logger.py index 270240490c99da..c210819ec5ad7c 100644 --- a/torch/distributed/checkpoint/logger.py +++ b/torch/distributed/checkpoint/logger.py @@ -7,6 +7,7 @@ import torch.distributed.c10d_logger as c10d_logger from torch.distributed.checkpoint.logging_handlers import DCP_LOGGER_NAME + __all__: List[str] = [] global _dcp_logger diff --git a/torch/distributed/checkpoint/metadata.py b/torch/distributed/checkpoint/metadata.py index b3bc7a580dad0a..d1f87e2d9cba8a 100644 --- a/torch/distributed/checkpoint/metadata.py +++ b/torch/distributed/checkpoint/metadata.py @@ -7,6 +7,7 @@ import torch from torch.distributed.checkpoint.stateful import StatefulT + __all__ = [ "ChunkStorageMetadata", "TensorStorageMetadata", diff --git a/torch/distributed/checkpoint/optimizer.py b/torch/distributed/checkpoint/optimizer.py index 26468d046f29a5..220ca22f703e52 100644 --- a/torch/distributed/checkpoint/optimizer.py +++ b/torch/distributed/checkpoint/optimizer.py @@ -40,6 +40,7 @@ from torch.distributed.fsdp._shard_utils import _create_chunk_sharded_tensor from torch.distributed.remote_device import _remote_device + STATE_DICT_2D_LAYOUT = Dict[str, Tuple[Optional[Sequence[int]], Sequence[int]]] diff --git a/torch/distributed/checkpoint/planner.py b/torch/distributed/checkpoint/planner.py index 5eec8bf7546651..d3e79950e0c762 100644 --- a/torch/distributed/checkpoint/planner.py +++ b/torch/distributed/checkpoint/planner.py @@ -7,7 +7,6 @@ from typing import Any, List, Optional, Tuple, Union import torch - from torch.distributed.checkpoint.metadata import ( ChunkStorageMetadata, Metadata, diff --git a/torch/distributed/checkpoint/planner_helpers.py b/torch/distributed/checkpoint/planner_helpers.py index 4bbe26876c881c..56e17281d4e4b5 100644 --- a/torch/distributed/checkpoint/planner_helpers.py +++ b/torch/distributed/checkpoint/planner_helpers.py @@ -4,13 +4,11 @@ import torch import torch.distributed as dist from torch._utils import _get_device_module - from torch.distributed._shard.metadata import ShardMetadata from torch.distributed._shard.sharded_tensor import ShardedTensor from torch.distributed._tensor import DTensor from torch.distributed._tensor._utils import compute_local_shape_and_global_offset from torch.distributed.checkpoint.planner import _Checkpointable - from torch.utils._pytree import tree_map_only from .metadata import ( @@ -35,6 +33,7 @@ _shards_get_overlap_region_wrt_saved_tensor, ) + __all__: List[str] = ["create_read_items_for_chunk_list"] diff --git a/torch/distributed/checkpoint/resharding.py b/torch/distributed/checkpoint/resharding.py index a1bf112f179506..0e5153df8da0ad 100644 --- a/torch/distributed/checkpoint/resharding.py +++ b/torch/distributed/checkpoint/resharding.py @@ -3,6 +3,7 @@ from torch.distributed.checkpoint.metadata import ChunkStorageMetadata + __all__: List[str] = [] diff --git a/torch/distributed/checkpoint/staging.py b/torch/distributed/checkpoint/staging.py index dba7ea0b413615..40f2fbdf0a0d95 100644 --- a/torch/distributed/checkpoint/staging.py +++ b/torch/distributed/checkpoint/staging.py @@ -6,9 +6,9 @@ _create_cpu_state_dict, _offload_state_dict_to_cpu, ) - from torch.distributed.checkpoint.metadata import STATE_DICT_TYPE + __all__ = ["AsyncStager", "BlockingAsyncStager"] diff --git a/torch/distributed/checkpoint/state_dict.py b/torch/distributed/checkpoint/state_dict.py index c906ff3dcc202e..16a1ddde215869 100644 --- a/torch/distributed/checkpoint/state_dict.py +++ b/torch/distributed/checkpoint/state_dict.py @@ -54,6 +54,7 @@ from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils._pytree import tree_map_only + __all__ = [ "FQNS_T", "PrimitiveType", diff --git a/torch/distributed/checkpoint/state_dict_loader.py b/torch/distributed/checkpoint/state_dict_loader.py index f443f73f02d6da..c4d1d853e9c662 100644 --- a/torch/distributed/checkpoint/state_dict_loader.py +++ b/torch/distributed/checkpoint/state_dict_loader.py @@ -16,6 +16,7 @@ from .storage import StorageReader from .utils import _all_gather_keys, _api_bc_check, _DistWrapper, _profile + __all__ = ["load_state_dict", "load"] diff --git a/torch/distributed/checkpoint/state_dict_saver.py b/torch/distributed/checkpoint/state_dict_saver.py index ba1695d832122d..20abc2212f5e1a 100644 --- a/torch/distributed/checkpoint/state_dict_saver.py +++ b/torch/distributed/checkpoint/state_dict_saver.py @@ -9,7 +9,6 @@ import torch import torch.distributed as dist from torch.distributed._state_dict_utils import _offload_state_dict_to_cpu - from torch.distributed.checkpoint._storage_utils import _storage_setup from torch.distributed.checkpoint.default_planner import DefaultSavePlanner from torch.distributed.checkpoint.logger import _dcp_method_logger diff --git a/torch/distributed/checkpoint/storage.py b/torch/distributed/checkpoint/storage.py index bd786671c45260..dd46fe9246fd41 100644 --- a/torch/distributed/checkpoint/storage.py +++ b/torch/distributed/checkpoint/storage.py @@ -10,9 +10,9 @@ SavePlan, SavePlanner, ) - from torch.futures import Future + __all__ = ["WriteResult", "StorageWriter", "StorageReader"] diff --git a/torch/distributed/checkpoint/utils.py b/torch/distributed/checkpoint/utils.py index 0efba34a551bc8..32649455163e65 100644 --- a/torch/distributed/checkpoint/utils.py +++ b/torch/distributed/checkpoint/utils.py @@ -25,6 +25,7 @@ ) from .metadata import MetadataIndex, STATE_DICT_TYPE + __all__ = ["find_tensor_shard", "find_state_dict_object"] T = TypeVar("T") diff --git a/torch/distributed/elastic/agent/server/api.py b/torch/distributed/elastic/agent/server/api.py index 28937ca47b1a7d..50c69b8e962740 100644 --- a/torch/distributed/elastic/agent/server/api.py +++ b/torch/distributed/elastic/agent/server/api.py @@ -14,6 +14,7 @@ import time import traceback import warnings +from collections import defaultdict from contextlib import contextmanager from dataclasses import dataclass, field from enum import Enum @@ -21,16 +22,13 @@ import torch.distributed.elastic.rendezvous as rdzv import torch.distributed.elastic.utils.store as store_util -from torch.distributed.elastic.rendezvous import RendezvousGracefulExitError from torch.distributed.elastic.events import Event, EventSource, record from torch.distributed.elastic.metrics import prof, put_metric -from torch.distributed.elastic.multiprocessing import ( - ProcessFailure, - SignalException, -) -from collections import defaultdict +from torch.distributed.elastic.multiprocessing import ProcessFailure, SignalException +from torch.distributed.elastic.rendezvous import RendezvousGracefulExitError from torch.distributed.elastic.utils.logging import get_logger + __all__ = [ "WorkerSpec", "Worker", @@ -250,7 +248,16 @@ class WorkerGroup: group contains cross instance workers or not depends on the implementation of the agent. """ - __slots__ = ["spec", "workers", "store", "group_rank", "group_world_size", "state", "master_addr", "master_port"] + __slots__ = [ + "spec", + "workers", + "store", + "group_rank", + "group_world_size", + "state", + "master_addr", + "master_port", + ] def __init__(self, spec: WorkerSpec): self.spec = spec @@ -450,7 +457,9 @@ def _start_workers(self, worker_group: WorkerGroup) -> Dict[int, Any]: raise NotImplementedError @abc.abstractmethod - def _stop_workers(self, worker_group: WorkerGroup, is_restart: bool = False) -> None: + def _stop_workers( + self, worker_group: WorkerGroup, is_restart: bool = False + ) -> None: r"""Stop all workers in the given worker group. Implementors must deal with workers in all states defined by @@ -468,7 +477,9 @@ def _monitor_workers(self, worker_group: WorkerGroup) -> RunResult: raise NotImplementedError @abc.abstractmethod - def _shutdown(self, death_sig: signal.Signals = signal.SIGTERM, is_restart: bool = False) -> None: + def _shutdown( + self, death_sig: signal.Signals = signal.SIGTERM, is_restart: bool = False + ) -> None: """Clean up any resources that were allocated during the agent's work. Args: @@ -499,7 +510,9 @@ def _rendezvous(self, worker_group: WorkerGroup) -> None: self._store = store with self.record_duration("ASSIGN_WORKER_RANKS"): - workers = self._assign_worker_ranks(store, group_rank, group_world_size, spec) + workers = self._assign_worker_ranks( + store, group_rank, group_world_size, spec + ) worker_group.workers = workers worker_group.store = store worker_group.group_rank = group_rank @@ -532,8 +545,8 @@ def _rendezvous(self, worker_group: WorkerGroup) -> None: "role_ranks": [worker.role_rank for worker in workers], "global_ranks": [worker.global_rank for worker in workers], "role_world_sizes": [worker.role_world_size for worker in workers], - "global_world_sizes": [worker.world_size for worker in workers] - } + "global_world_sizes": [worker.world_size for worker in workers], + }, ) # pyre-fixme[56]: Pyre was not able to infer the type of the decorator @@ -612,9 +625,12 @@ def _assign_worker_ranks( store.multi_set(keys, values) # get will block until the data is available in the store. - base_global_rank, global_world_size, base_role_rank, role_world_size = json.loads( - store.get(f"{ASSIGNED_RANKS_PREFIX}{group_rank}") - ) + ( + base_global_rank, + global_world_size, + base_role_rank, + role_world_size, + ) = json.loads(store.get(f"{ASSIGNED_RANKS_PREFIX}{group_rank}")) workers = [] for local_rank in range(spec.local_world_size): @@ -733,7 +749,11 @@ def record_duration(self, state: str): finally: end_time = time.perf_counter() duration_ms = (end_time - start_time) * 1000 - record(self._construct_event(state=state, source=EventSource.AGENT, duration_ms=duration_ms)) + record( + self._construct_event( + state=state, source=EventSource.AGENT, duration_ms=duration_ms + ) + ) def _construct_event( self, @@ -844,7 +864,8 @@ def _invoke_run(self, role: str = DEFAULT_ROLE) -> RunResult: logger.info( "[%s] worker group successfully finished." " Waiting %s seconds for other agents to finish.", - role, self._exit_barrier_timeout + role, + self._exit_barrier_timeout, ) self._exit_barrier() return run_result @@ -854,7 +875,10 @@ def _invoke_run(self, role: str = DEFAULT_ROLE) -> RunResult: "[%s] Worker group %s. " "%s/%s attempts left;" " will restart worker group", - role, state.name, self._remaining_restarts, spec.max_restarts + role, + state.name, + self._remaining_restarts, + spec.max_restarts, ) self._remaining_restarts -= 1 self._restart_workers(self._worker_group) @@ -871,11 +895,15 @@ def _invoke_run(self, role: str = DEFAULT_ROLE) -> RunResult: "[%s] Detected %s " "new nodes from group_rank=%s; " "will restart worker group", - role, num_nodes_waiting, group_rank + role, + num_nodes_waiting, + group_rank, ) self._restart_workers(self._worker_group) else: - raise Exception(f"[{role}] Worker group in {state.name} state") # noqa: TRY002 + raise Exception( # noqa: TRY002 + f"[{role}] Worker group in {state.name} state" + ) def _exit_barrier(self): """ @@ -889,7 +917,8 @@ def _exit_barrier(self): logger.info( "Local worker group finished (%s). " "Waiting %s seconds for other agents to finish", - self._worker_group.state, self._exit_barrier_timeout + self._worker_group.state, + self._exit_barrier_timeout, ) start = time.time() try: @@ -900,7 +929,8 @@ def _exit_barrier(self): barrier_timeout=self._exit_barrier_timeout, ) logger.info( - "Done waiting for other agents. Elapsed: %s seconds", time.time() - start + "Done waiting for other agents. Elapsed: %s seconds", + time.time() - start, ) except SignalException as e: logger.warning("Got termination signal: %s", e.sigval) @@ -908,5 +938,5 @@ def _exit_barrier(self): except Exception: logger.exception( "Error waiting on exit barrier. Elapsed: %s seconds", - time.time() - start + time.time() - start, ) diff --git a/torch/distributed/elastic/agent/server/health_check_server.py b/torch/distributed/elastic/agent/server/health_check_server.py index 00160730551520..d54915f7461685 100644 --- a/torch/distributed/elastic/agent/server/health_check_server.py +++ b/torch/distributed/elastic/agent/server/health_check_server.py @@ -10,6 +10,7 @@ from torch.distributed.elastic.utils.logging import get_logger + log = get_logger(__name__) __all__ = ["HealthCheckServer", "create_healthcheck_server"] diff --git a/torch/distributed/elastic/agent/server/local_elastic_agent.py b/torch/distributed/elastic/agent/server/local_elastic_agent.py index 232f28234e6534..9423ef16f5c01d 100644 --- a/torch/distributed/elastic/agent/server/local_elastic_agent.py +++ b/torch/distributed/elastic/agent/server/local_elastic_agent.py @@ -12,14 +12,13 @@ import os import signal import socket -from string import Template import time import uuid +from string import Template from typing import Any, Dict, Optional, Tuple, TYPE_CHECKING import torch.distributed.elastic.timer as timer from torch.distributed.elastic import events - from torch.distributed.elastic.agent.server.api import ( RunResult, SimpleElasticAgent, @@ -32,10 +31,15 @@ HealthCheckServer, ) from torch.distributed.elastic.metrics.api import prof -from torch.distributed.elastic.multiprocessing import PContext, start_processes, LogsSpecs +from torch.distributed.elastic.multiprocessing import ( + LogsSpecs, + PContext, + start_processes, +) from torch.distributed.elastic.utils import macros from torch.distributed.elastic.utils.logging import get_logger + if TYPE_CHECKING: from torch.distributed.elastic.events.api import EventMetadataValue @@ -52,6 +56,7 @@ TORCHELASTIC_HEALTH_CHECK_PORT = "TORCHELASTIC_HEALTH_CHECK_PORT" TORCHELASTIC_TIMER_FILE = "TORCHELASTIC_TIMER_FILE" + class LocalElasticAgent(SimpleElasticAgent): """An implementation of :py:class:`torchelastic.agent.server.ElasticAgent` that handles host-local workers. @@ -158,7 +163,6 @@ def __init__( self._logs_specs = logs_specs self._health_check_server: Optional[HealthCheckServer] = None - def _setup_local_watchdog(self, envs: Dict[int, Dict[str, str]]) -> None: enable_watchdog_env_name = TORCHELASTIC_ENABLE_FILE_TIMER watchdog_enabled = os.getenv(enable_watchdog_env_name) @@ -169,8 +173,10 @@ def _setup_local_watchdog(self, envs: Dict[int, Dict[str, str]]) -> None: watchdog_file_path = "/tmp/watchdog_timer_" + str(uuid.uuid4()) logger.info("Starting a FileTimerServer with %s ...", watchdog_file_path) if not envs: - logger.warning("Empty envs variables, using empty run_id for FileTimerServer") - run_id = '' + logger.warning( + "Empty envs variables, using empty run_id for FileTimerServer" + ) + run_id = "" else: run_id = envs[0]["TORCHELASTIC_RUN_ID"] self._worker_watchdog = timer.FileTimerServer( @@ -178,11 +184,15 @@ def _setup_local_watchdog(self, envs: Dict[int, Dict[str, str]]) -> None: run_id=run_id, max_interval=0.1, daemon=True, - log_event=self._log_watchdog_event) + log_event=self._log_watchdog_event, + ) self._worker_watchdog.start() logger.info("FileTimerServer started") else: - logger.info("Environment variable '%s' not found. Do not start FileTimerServer.", enable_watchdog_env_name) + logger.info( + "Environment variable '%s' not found. Do not start FileTimerServer.", + enable_watchdog_env_name, + ) # Propagate the watchdog file env to worker processes if watchdog_file_path is not None: for worker_env in envs.values(): @@ -202,7 +212,9 @@ def _setup_healthcheck(self) -> None: healthcheck_port, ) if self._worker_watchdog is None: - logger.info("FileTimerServer doesn't exist, using current time as dummy callback") + logger.info( + "FileTimerServer doesn't exist, using current time as dummy callback" + ) alive_callback = LocalElasticAgent._get_current_time_secs else: alive_callback = self._worker_watchdog.get_last_progress_time @@ -219,7 +231,6 @@ def _setup_healthcheck(self) -> None: healthcheck_port_env_name, ) - def _get_fq_hostname(self) -> str: return socket.getfqdn(socket.gethostname()) @@ -230,9 +241,7 @@ def _log_watchdog_event( ) -> None: wg = self._worker_group spec = wg.spec - md = { - "watchdog_event": name - } + md = {"watchdog_event": name} if request is not None: md["worker_pid"] = str(request.worker_pid) md["scope_id"] = request.scope_id @@ -264,7 +273,9 @@ def _log_watchdog_event( # pyre-fixme[56]: Pyre was not able to infer the type of the decorator # `torch.distributed.elastic.metrics.prof`. @prof - def _stop_workers(self, worker_group: WorkerGroup, is_restart: bool = False) -> None: + def _stop_workers( + self, worker_group: WorkerGroup, is_restart: bool = False + ) -> None: self._shutdown(is_restart=is_restart) # pyre-fixme[56]: Pyre was not able to infer the type of the decorator @@ -280,7 +291,9 @@ def _start_workers(self, worker_group: WorkerGroup) -> Dict[int, Any]: args: Dict[int, Tuple] = {} envs: Dict[int, Dict[str, str]] = {} - log_line_prefixes: Optional[Dict[int, str]] = {} if self._log_line_prefix_template else None + log_line_prefixes: Optional[Dict[int, str]] = ( + {} if self._log_line_prefix_template else None + ) for worker in worker_group.workers: local_rank = worker.local_rank worker_env = { @@ -306,12 +319,14 @@ def _start_workers(self, worker_group: WorkerGroup) -> Dict[int, Any]: if "OMP_NUM_THREADS" in os.environ: worker_env["OMP_NUM_THREADS"] = os.environ["OMP_NUM_THREADS"] - if self._log_line_prefix_template: - log_line_prefix = Template(self._log_line_prefix_template).safe_substitute( + log_line_prefix = Template( + self._log_line_prefix_template + ).safe_substitute( role_name=spec.role, rank=worker.global_rank, - local_rank=local_rank,) + local_rank=local_rank, + ) log_line_prefixes[local_rank] = log_line_prefix envs[local_rank] = worker_env @@ -336,7 +351,9 @@ def _start_workers(self, worker_group: WorkerGroup) -> Dict[int, Any]: return self._pcontext.pids() - def _shutdown(self, death_sig: signal.Signals = signal.SIGTERM, is_restart: bool = False) -> None: + def _shutdown( + self, death_sig: signal.Signals = signal.SIGTERM, is_restart: bool = False + ) -> None: if self._worker_watchdog is not None: self._worker_watchdog.stop() self._worker_watchdog = None @@ -360,7 +377,9 @@ def _monitor_workers(self, worker_group: WorkerGroup) -> RunResult: logger.error( "[%s] worker pids do not match process_context pids." " Expected: %s, actual: %s", - role, worker_pids, pc_pids + role, + worker_pids, + pc_pids, ) return RunResult(state=WorkerState.UNKNOWN) diff --git a/torch/distributed/elastic/control_plane.py b/torch/distributed/elastic/control_plane.py index 160383637865b6..e778c086838477 100644 --- a/torch/distributed/elastic/control_plane.py +++ b/torch/distributed/elastic/control_plane.py @@ -4,6 +4,7 @@ from torch.distributed.elastic.multiprocessing.errors import record + __all__ = [ "worker_main", ] diff --git a/torch/distributed/elastic/events/__init__.py b/torch/distributed/elastic/events/__init__.py index 9f6e1733518afa..5e4a0a6f236918 100644 --- a/torch/distributed/elastic/events/__init__.py +++ b/torch/distributed/elastic/events/__init__.py @@ -24,7 +24,6 @@ import os import socket import traceback -from enum import Enum from typing import Dict, Optional from torch.distributed.elastic.events.handlers import get_logging_handler @@ -37,8 +36,10 @@ RdzvEvent, ) + _events_loggers: Dict[str, logging.Logger] = {} + def _get_or_create_logger(destination: str = "null") -> logging.Logger: """ Construct python logger based on the destination type or extends if provided. @@ -71,6 +72,7 @@ def _get_or_create_logger(destination: str = "null") -> logging.Logger: def record(event: Event, destination: str = "null") -> None: _get_or_create_logger(destination).info(event.serialize()) + def record_rdzv_event(event: RdzvEvent) -> None: _get_or_create_logger("dynamic_rendezvous").info(event.serialize()) diff --git a/torch/distributed/elastic/events/api.py b/torch/distributed/elastic/events/api.py index 082499b3af6388..c610cfd4cb3540 100644 --- a/torch/distributed/elastic/events/api.py +++ b/torch/distributed/elastic/events/api.py @@ -10,9 +10,10 @@ import json from dataclasses import asdict, dataclass, field from enum import Enum -from typing import Dict, Union, Optional +from typing import Dict, Optional, Union -__all__ = ['EventSource', 'Event', 'NodeState', 'RdzvEvent'] + +__all__ = ["EventSource", "Event", "NodeState", "RdzvEvent"] EventMetadataValue = Union[str, int, float, bool, None] diff --git a/torch/distributed/elastic/metrics/__init__.py b/torch/distributed/elastic/metrics/__init__.py index d8bea0b3c07917..4b72dcd7c6020d 100644 --- a/torch/distributed/elastic/metrics/__init__.py +++ b/torch/distributed/elastic/metrics/__init__.py @@ -139,14 +139,14 @@ def emit(self, metric_data): from typing import Optional from .api import ( # noqa: F401 + configure, ConsoleMetricHandler, + get_elapsed_time_ms, + getStream, MetricData, MetricHandler, MetricsConfig, NullMetricHandler, - configure, - get_elapsed_time_ms, - getStream, prof, profile, publish_metric, diff --git a/torch/distributed/elastic/metrics/api.py b/torch/distributed/elastic/metrics/api.py index 7b6d8295ef0511..2c07d3b5c47bbd 100644 --- a/torch/distributed/elastic/metrics/api.py +++ b/torch/distributed/elastic/metrics/api.py @@ -14,9 +14,22 @@ from typing import Dict, Optional from typing_extensions import deprecated -__all__ = ['MetricsConfig', 'MetricHandler', 'ConsoleMetricHandler', 'NullMetricHandler', 'MetricStream', - 'configure', 'getStream', 'prof', 'profile', 'put_metric', 'publish_metric', 'get_elapsed_time_ms', - 'MetricData'] + +__all__ = [ + "MetricsConfig", + "MetricHandler", + "ConsoleMetricHandler", + "NullMetricHandler", + "MetricStream", + "configure", + "getStream", + "prof", + "profile", + "put_metric", + "publish_metric", + "get_elapsed_time_ms", + "MetricData", +] MetricData = namedtuple("MetricData", ["timestamp", "group_name", "name", "value"]) @@ -150,6 +163,7 @@ def profile(group=None): @metrics.profile("my_metric_group") def some_function(): """ + def wrap(func): @wraps(func) def wrapper(*args, **kwargs): diff --git a/torch/distributed/elastic/multiprocessing/__init__.py b/torch/distributed/elastic/multiprocessing/__init__.py index 4e26ab1744a98c..21cb5e47d44192 100644 --- a/torch/distributed/elastic/multiprocessing/__init__.py +++ b/torch/distributed/elastic/multiprocessing/__init__.py @@ -62,8 +62,7 @@ def trainer(a, b, c): implementations of the parent :class:`api.PContext` class. """ -import os -from typing import Callable, Dict, Optional, Tuple, Union, Set +from typing import Callable, Dict, Optional, Tuple, Union from torch.distributed.elastic.multiprocessing.api import ( # noqa: F401 _validate_full_rank, @@ -81,6 +80,7 @@ def trainer(a, b, c): ) from torch.distributed.elastic.utils.logging import get_logger + __all__ = [ "start_processes", "MultiprocessContext", diff --git a/torch/distributed/elastic/multiprocessing/api.py b/torch/distributed/elastic/multiprocessing/api.py index 5d294a7d080217..8968dbdc8e6db2 100644 --- a/torch/distributed/elastic/multiprocessing/api.py +++ b/torch/distributed/elastic/multiprocessing/api.py @@ -17,13 +17,13 @@ import sys import tempfile import time +from abc import ABC, abstractmethod from contextlib import nullcontext from dataclasses import dataclass, field from enum import IntFlag from multiprocessing import synchronize from types import FrameType from typing import Any, Callable, Dict, Optional, Set, Tuple, Union -from abc import ABC, abstractmethod import torch.multiprocessing as mp from torch.distributed.elastic.multiprocessing.errors import ProcessFailure, record @@ -31,10 +31,13 @@ redirect_stderr, redirect_stdout, ) - -from torch.distributed.elastic.multiprocessing.subprocess_handler import SubprocessHandler, get_subprocess_handler +from torch.distributed.elastic.multiprocessing.subprocess_handler import ( + get_subprocess_handler, + SubprocessHandler, +) from torch.distributed.elastic.multiprocessing.tail_log import TailLog + IS_WINDOWS = sys.platform == "win32" IS_MACOS = sys.platform == "darwin" @@ -55,6 +58,7 @@ "LogsSpecs", ] + class SignalException(Exception): """ Exception is raised inside the torchelastic agent process by the termination handler @@ -178,6 +182,7 @@ class LogsDest: """ For each log type, holds mapping of local rank ids to file paths. """ + stdouts: Dict[int, str] = field(default_factory=dict) stderrs: Dict[int, str] = field(default_factory=dict) tee_stdouts: Dict[int, str] = field(default_factory=dict) @@ -215,7 +220,10 @@ def __init__( self._local_ranks_filter = local_ranks_filter @abstractmethod - def reify(self, envs: Dict[int, Dict[str, str]],) -> LogsDest: + def reify( + self, + envs: Dict[int, Dict[str, str]], + ) -> LogsDest: """ Given the environment variables, builds destination of log files for each of the local ranks. @@ -229,6 +237,7 @@ def reify(self, envs: Dict[int, Dict[str, str]],) -> LogsDest: def root_log_dir(self) -> str: pass + class DefaultLogsSpecs(LogsSpecs): """ Default LogsSpecs implementation: @@ -236,6 +245,7 @@ class DefaultLogsSpecs(LogsSpecs): - `log_dir` will be created if it doesn't exist - Generates nested folders for each attempt and rank. """ + def __init__( self, log_dir: Optional[str] = None, @@ -266,7 +276,10 @@ def _make_log_dir(self, log_dir: Optional[str], rdzv_run_id: str): logger.info("log directory set to: %s", dir) return dir - def reify(self, envs: Dict[int, Dict[str, str]],) -> LogsDest: + def reify( + self, + envs: Dict[int, Dict[str, str]], + ) -> LogsDest: """ Uses following scheme to build log destination paths: @@ -279,7 +292,9 @@ def reify(self, envs: Dict[int, Dict[str, str]],) -> LogsDest: if nprocs > 0: global_env = envs[0] else: - logger.warning("Empty envs map provided when defining logging destinations.") + logger.warning( + "Empty envs map provided when defining logging destinations." + ) # Keys are always defined, but values can be missing in unit tests run_id = global_env.get("TORCHELASTIC_RUN_ID", "test_run_id") restart_count = global_env.get("TORCHELASTIC_RESTART_COUNT", "0") @@ -321,7 +336,6 @@ def reify(self, envs: Dict[int, Dict[str, str]],) -> LogsDest: error_files = {} for local_rank in range(nprocs): - if attempt_log_dir == os.devnull: tee_stdouts[local_rank] = os.devnull tee_stderrs[local_rank] = os.devnull @@ -343,7 +357,10 @@ def reify(self, envs: Dict[int, Dict[str, str]],) -> LogsDest: if t & Std.ERR == Std.ERR: tee_stderrs[local_rank] = stderrs[local_rank] - if self._local_ranks_filter and local_rank not in self._local_ranks_filter: + if ( + self._local_ranks_filter + and local_rank not in self._local_ranks_filter + ): # If stream is tee'd, only write to file, but don't tail if local_rank in tee_stdouts: tee_stdouts.pop(local_rank, None) @@ -358,7 +375,9 @@ def reify(self, envs: Dict[int, Dict[str, str]],) -> LogsDest: error_file = os.path.join(clogdir, "error.json") error_files[local_rank] = error_file - logger.info("Setting worker%s reply file to: %s", local_rank, error_file) + logger.info( + "Setting worker%s reply file to: %s", local_rank, error_file + ) envs[local_rank]["TORCHELASTIC_ERROR_FILE"] = error_file return LogsDest(stdouts, stderrs, tee_stdouts, tee_stderrs, error_files) @@ -423,7 +442,6 @@ def __init__( envs: Dict[int, Dict[str, str]], logs_specs: LogsSpecs, log_line_prefixes: Optional[Dict[int, str]] = None, - ): self.name = name # validate that all mappings have the same number of keys and @@ -444,8 +462,12 @@ def __init__( self.error_files = logs_dest.error_files self.nprocs = nprocs - self._stdout_tail = TailLog(name, logs_dest.tee_stdouts, sys.stdout, log_line_prefixes) - self._stderr_tail = TailLog(name, logs_dest.tee_stderrs, sys.stderr, log_line_prefixes) + self._stdout_tail = TailLog( + name, logs_dest.tee_stdouts, sys.stdout, log_line_prefixes + ) + self._stderr_tail = TailLog( + name, logs_dest.tee_stderrs, sys.stderr, log_line_prefixes + ) def start(self) -> None: """Start processes using parameters defined in the constructor.""" @@ -678,7 +700,9 @@ def _poll(self) -> Optional[RunProcsResult]: # But the child process might still have not exited. Wait for them. # pc.join() blocks [forever] until "a" proc exits. Loop until all of them exits. while not self._pc.join(): - logger.debug("entrypoint fn finished, waiting for all child procs to exit...") + logger.debug( + "entrypoint fn finished, waiting for all child procs to exit..." + ) _validate_full_rank( self._return_values, self.nprocs, "return_value queue" @@ -704,8 +728,10 @@ def _poll(self) -> Optional[RunProcsResult]: " local_rank: %s (pid: %s)" " of fn: %s (start_method: %s)", failed_proc.exitcode, - failed_local_rank, e.pid, - fn_name, self.start_method, + failed_local_rank, + e.pid, + fn_name, + self.start_method, ) self.close() @@ -731,7 +757,9 @@ def _close(self, death_sig: signal.Signals, timeout: int = 30) -> None: return for proc in self._pc.processes: if proc.is_alive(): - logger.warning("Closing process %s via signal %s", proc.pid, death_sig.name) + logger.warning( + "Closing process %s via signal %s", proc.pid, death_sig.name + ) try: os.kill(proc.pid, death_sig) except ProcessLookupError: @@ -748,7 +776,9 @@ def _close(self, death_sig: signal.Signals, timeout: int = 30) -> None: if proc.is_alive(): logger.warning( "Unable to shutdown process %s via %s, forcefully exiting via %s", - proc.pid, death_sig, _get_kill_signal() + proc.pid, + death_sig, + _get_kill_signal(), ) try: os.kill(proc.pid, _get_kill_signal()) @@ -758,6 +788,7 @@ def _close(self, death_sig: signal.Signals, timeout: int = 30) -> None: pass proc.join() + class SubprocessContext(PContext): """``PContext`` holding worker processes invoked as a binary.""" @@ -769,7 +800,6 @@ def __init__( envs: Dict[int, Dict[str, str]], logs_specs: LogsSpecs, log_line_prefixes: Optional[Dict[int, str]] = None, - ): super().__init__( name, @@ -834,7 +864,10 @@ def _poll(self) -> Optional[RunProcsResult]: "failed (exitcode: %s)" " local_rank: %s (pid: %s)" " of binary: %s", - first_failure.exitcode, first_failure.local_rank, first_failure.pid, self.entrypoint + first_failure.exitcode, + first_failure.local_rank, + first_failure.pid, + self.entrypoint, ) else: # Populate return with dummy values. This provides consistency with MultiprocessingHandler @@ -856,7 +889,9 @@ def _close(self, death_sig: signal.Signals, timeout: int = 30) -> None: for handler in self.subprocess_handlers.values(): if handler.proc.poll() is None: logger.warning( - "Sending process %s closing signal %s", handler.proc.pid, death_sig.name + "Sending process %s closing signal %s", + handler.proc.pid, + death_sig.name, ) handler.close(death_sig=death_sig) end = time.monotonic() + timeout @@ -874,7 +909,9 @@ def _close(self, death_sig: signal.Signals, timeout: int = 30) -> None: if handler.proc.poll() is None: logger.warning( "Unable to shutdown process %s via %s, forcefully exiting via %s", - handler.proc.pid, death_sig, _get_kill_signal() + handler.proc.pid, + death_sig, + _get_kill_signal(), ) handler.close(death_sig=_get_kill_signal()) handler.proc.wait() diff --git a/torch/distributed/elastic/multiprocessing/errors/__init__.py b/torch/distributed/elastic/multiprocessing/errors/__init__.py index d63c283b4c35ec..2f5ed2d1ab0b81 100644 --- a/torch/distributed/elastic/multiprocessing/errors/__init__.py +++ b/torch/distributed/elastic/multiprocessing/errors/__init__.py @@ -66,7 +66,14 @@ from .error_handler import ErrorHandler # noqa: F401 from .handlers import get_error_handler # noqa: F401 -__all__ = ["ProcessFailure", "ChildFailedError", "record", "ErrorHandler", "get_error_handler"] + +__all__ = [ + "ProcessFailure", + "ChildFailedError", + "record", + "ErrorHandler", + "get_error_handler", +] logger = get_logger(__name__) @@ -113,7 +120,8 @@ def __post_init__(self): with open(self.error_file) as fp: self.error_file_data = json.load(fp) logger.debug( - "User process failed with error data: %s", json.dumps(self.error_file_data, indent=2) + "User process failed with error data: %s", + json.dumps(self.error_file_data, indent=2), ) self.message, self.timestamp = self._get_error_data( self.error_file_data @@ -264,7 +272,6 @@ def format_msg(self, boarder_delim="=", section_delim="-"): def _format_failure( self, idx: int, rank: int, failure: ProcessFailure ) -> Tuple[str, int]: - # failure.message is either a str (when the failure does not generate a traceback - e.g. signals) # or a dict (json) of the form # {"message": $ERROR_MSG, "extraInfo": {"py_callstack": $TRACEBACK, timestamp: $TS}} @@ -363,7 +370,7 @@ def wrapper(*args, **kwargs): "local_rank %s FAILED with no error file." " Decorate your entrypoint fn with @record for traceback info." " See: https://pytorch.org/docs/stable/elastic/errors.html", - rank + rank, ) ) raise diff --git a/torch/distributed/elastic/multiprocessing/errors/error_handler.py b/torch/distributed/elastic/multiprocessing/errors/error_handler.py index 34d6229dda3b67..89e7fffdd5c7dc 100644 --- a/torch/distributed/elastic/multiprocessing/errors/error_handler.py +++ b/torch/distributed/elastic/multiprocessing/errors/error_handler.py @@ -15,7 +15,8 @@ import warnings from typing import Any, Dict, Optional -__all__ = ['ErrorHandler'] + +__all__ = ["ErrorHandler"] logger = logging.getLogger(__name__) @@ -93,13 +94,14 @@ def override_error_code_in_rootcause_data( logger.warning( "child error file (%s) does not have field `message`. \n" "cannot override error code: %s", - rootcause_error_file, error_code + rootcause_error_file, + error_code, ) elif isinstance(rootcause_error["message"], str): logger.warning( "child error file (%s) has a new message format. \n" "skipping error code override", - rootcause_error_file + rootcause_error_file, ) else: rootcause_error["message"]["errorCode"] = error_code @@ -111,11 +113,13 @@ def dump_error_file(self, rootcause_error_file: str, error_code: int = 0): # Override error code since the child process cannot capture the error code if it # is terminated by signals like SIGSEGV. if error_code: - self.override_error_code_in_rootcause_data(rootcause_error_file, rootcause_error, error_code) + self.override_error_code_in_rootcause_data( + rootcause_error_file, rootcause_error, error_code + ) logger.debug( - "child error file (%s) contents:\n" - "%s", - rootcause_error_file, json.dumps(rootcause_error, indent=2) + "child error file (%s) contents:\n" "%s", + rootcause_error_file, + json.dumps(rootcause_error, indent=2), ) my_error_file = self._get_error_file_path() @@ -135,7 +139,8 @@ def dump_error_file(self, rootcause_error_file: str, error_code: int = 0): logger.info("dumped error file to parent's %s", my_error_file) else: logger.error( - "no error file defined for parent, to copy child error file (%s)", rootcause_error_file + "no error file defined for parent, to copy child error file (%s)", + rootcause_error_file, ) def _rm(self, my_error_file): @@ -148,13 +153,14 @@ def _rm(self, my_error_file): "%s already exists" " and will be overwritten." " Original contents:\n%s", - my_error_file, original + my_error_file, + original, ) except json.decoder.JSONDecodeError: logger.warning( "%s already exists" " and will be overwritten." " Unable to load original contents:\n", - my_error_file + my_error_file, ) os.remove(my_error_file) diff --git a/torch/distributed/elastic/multiprocessing/errors/handlers.py b/torch/distributed/elastic/multiprocessing/errors/handlers.py index 09b2aca55f16ae..b8a78e73702fd1 100644 --- a/torch/distributed/elastic/multiprocessing/errors/handlers.py +++ b/torch/distributed/elastic/multiprocessing/errors/handlers.py @@ -11,7 +11,9 @@ from torch.distributed.elastic.multiprocessing.errors.error_handler import ErrorHandler -__all__ = ['get_error_handler'] + +__all__ = ["get_error_handler"] + def get_error_handler(): return ErrorHandler() diff --git a/torch/distributed/elastic/multiprocessing/redirects.py b/torch/distributed/elastic/multiprocessing/redirects.py index 8ad3e2edf1c15f..057013fbb9e5b8 100644 --- a/torch/distributed/elastic/multiprocessing/redirects.py +++ b/torch/distributed/elastic/multiprocessing/redirects.py @@ -16,6 +16,7 @@ from contextlib import contextmanager from functools import partial + IS_WINDOWS = sys.platform == "win32" IS_MACOS = sys.platform == "darwin" diff --git a/torch/distributed/elastic/multiprocessing/subprocess_handler/__init__.py b/torch/distributed/elastic/multiprocessing/subprocess_handler/__init__.py index 4c335964c7322a..f56d423ce080fd 100644 --- a/torch/distributed/elastic/multiprocessing/subprocess_handler/__init__.py +++ b/torch/distributed/elastic/multiprocessing/subprocess_handler/__init__.py @@ -12,4 +12,5 @@ SubprocessHandler, ) + __all__ = ["SubprocessHandler", "get_subprocess_handler"] diff --git a/torch/distributed/elastic/multiprocessing/subprocess_handler/handlers.py b/torch/distributed/elastic/multiprocessing/subprocess_handler/handlers.py index e122f89a94f777..2660be5af399a4 100644 --- a/torch/distributed/elastic/multiprocessing/subprocess_handler/handlers.py +++ b/torch/distributed/elastic/multiprocessing/subprocess_handler/handlers.py @@ -12,6 +12,7 @@ SubprocessHandler, ) + __all__ = ["get_subprocess_handler"] diff --git a/torch/distributed/elastic/multiprocessing/subprocess_handler/subprocess_handler.py b/torch/distributed/elastic/multiprocessing/subprocess_handler/subprocess_handler.py index 7cacf986857500..c548d092092265 100644 --- a/torch/distributed/elastic/multiprocessing/subprocess_handler/subprocess_handler.py +++ b/torch/distributed/elastic/multiprocessing/subprocess_handler/subprocess_handler.py @@ -9,9 +9,9 @@ import signal import subprocess import sys - from typing import Any, Dict, Optional, Tuple + __all__ = ["SubprocessHandler"] IS_WINDOWS = sys.platform == "win32" diff --git a/torch/distributed/elastic/multiprocessing/tail_log.py b/torch/distributed/elastic/multiprocessing/tail_log.py index 804e2e5a6323d6..2c814ffb7be998 100644 --- a/torch/distributed/elastic/multiprocessing/tail_log.py +++ b/torch/distributed/elastic/multiprocessing/tail_log.py @@ -14,6 +14,7 @@ from threading import Event from typing import Dict, List, Optional, TextIO, TYPE_CHECKING + if TYPE_CHECKING: from concurrent.futures._base import Future @@ -25,7 +26,6 @@ def tail_logfile( header: str, file: str, dst: TextIO, finished: Event, interval_sec: float ): - while not os.path.exists(file): if finished.is_set(): return @@ -143,8 +143,10 @@ def stop(self) -> None: except Exception as e: logger.error( "error in log tailor for %s%s. %s: %s", - self._name, local_rank, - e.__class__.__qualname__, e, + self._name, + local_rank, + e.__class__.__qualname__, + e, ) if self._threadpool: diff --git a/torch/distributed/elastic/rendezvous/__init__.py b/torch/distributed/elastic/rendezvous/__init__.py index f6ec6a6eb62f6c..62a31adab27b01 100644 --- a/torch/distributed/elastic/rendezvous/__init__.py +++ b/torch/distributed/elastic/rendezvous/__init__.py @@ -128,8 +128,8 @@ class that implements the rendezvous mechanism described above. It is a backend- ) """ - from .api import ( + rendezvous_handler_registry, RendezvousClosedError, RendezvousConnectionError, RendezvousError, @@ -142,9 +142,7 @@ class that implements the rendezvous mechanism described above. It is a backend- RendezvousStateError, RendezvousStoreInfo, RendezvousTimeoutError, - rendezvous_handler_registry, ) - from .registry import _register_default_handlers diff --git a/torch/distributed/elastic/rendezvous/api.py b/torch/distributed/elastic/rendezvous/api.py index 7ddcd7c70b9af9..9cde6758981abf 100644 --- a/torch/distributed/elastic/rendezvous/api.py +++ b/torch/distributed/elastic/rendezvous/api.py @@ -6,7 +6,6 @@ # LICENSE file in the root directory of this source tree. import socket - from abc import ABC, abstractmethod from dataclasses import dataclass from typing import Any, Callable, ClassVar, Dict, Optional @@ -51,15 +50,18 @@ class RendezvousConnectionError(RendezvousError): class RendezvousStateError(RendezvousError): """Raised when the state of a rendezvous is corrupt.""" + class RendezvousGracefulExitError(RendezvousError): """Raised when node wasn't not included in rendezvous and gracefully exits. Exception is a mechanism to exit the stack, however does not mean a failure. """ + @dataclass class RendezvousStoreInfo: """Store address and port that can be used to bootstrap trainer distributed comms""" + MASTER_ADDR_KEY: ClassVar[str] = "MASTER_ADDR" MASTER_PORT_KEY: ClassVar[str] = "MASTER_PORT" master_addr: str @@ -79,13 +81,22 @@ def build(rank: int, store: Store) -> "RendezvousStoreInfo": store.set(RendezvousStoreInfo.MASTER_PORT_KEY, str(port).encode(encoding="UTF-8")) # type: ignore[arg-type] addr = store.get(RendezvousStoreInfo.MASTER_ADDR_KEY).decode(encoding="UTF-8") - port = int(store.get(RendezvousStoreInfo.MASTER_PORT_KEY).decode(encoding="UTF-8")) + port = int( + store.get(RendezvousStoreInfo.MASTER_PORT_KEY).decode(encoding="UTF-8") + ) return RendezvousStoreInfo(master_addr=addr, master_port=port) class RendezvousInfo: """Holds the information about the rendezvous.""" - def __init__(self, store: Store, rank: int, world_size: int, bootstrap_store_info: RendezvousStoreInfo): + + def __init__( + self, + store: Store, + rank: int, + world_size: int, + bootstrap_store_info: RendezvousStoreInfo, + ): self._store = store self._rank = rank self._world_size = world_size diff --git a/torch/distributed/elastic/rendezvous/c10d_rendezvous_backend.py b/torch/distributed/elastic/rendezvous/c10d_rendezvous_backend.py index 4c1c687411ef2d..26c3153d9785b6 100644 --- a/torch/distributed/elastic/rendezvous/c10d_rendezvous_backend.py +++ b/torch/distributed/elastic/rendezvous/c10d_rendezvous_backend.py @@ -11,13 +11,10 @@ import tempfile from base64 import b64decode, b64encode from datetime import timedelta -from typing import Any, Optional, Tuple, cast +from typing import Any, cast, Optional, Tuple from torch.distributed import FileStore, Store, TCPStore -from torch.distributed.elastic.events import ( - NodeState, - construct_and_record_rdzv_event, -) +from torch.distributed.elastic.events import construct_and_record_rdzv_event, NodeState from .api import ( RendezvousConnectionError, @@ -28,6 +25,7 @@ from .dynamic_rendezvous import RendezvousBackend, Token from .utils import _matches_machine_hostname, parse_rendezvous_endpoint + logger = logging.getLogger(__name__) @@ -96,7 +94,9 @@ def set_state( else: token = self._NULL_SENTINEL - base64_state: bytes = self._call_store("compare_set", self._key, token, base64_state_str) + base64_state: bytes = self._call_store( + "compare_set", self._key, token, base64_state_str + ) state_token_pair = self._decode_state(base64_state) if state_token_pair is None: @@ -256,7 +256,9 @@ def create_backend(params: RendezvousParameters) -> Tuple[C10dRendezvousBackend, elif store_type == "tcp": store = _create_tcp_store(params) else: - raise ValueError("Invalid store type given. Currently only supports file and tcp.") + raise ValueError( + "Invalid store type given. Currently only supports file and tcp." + ) backend = C10dRendezvousBackend(store, params.run_id) diff --git a/torch/distributed/elastic/rendezvous/dynamic_rendezvous.py b/torch/distributed/elastic/rendezvous/dynamic_rendezvous.py index ad45077d3943ea..31627cf0a0b27c 100644 --- a/torch/distributed/elastic/rendezvous/dynamic_rendezvous.py +++ b/torch/distributed/elastic/rendezvous/dynamic_rendezvous.py @@ -20,7 +20,6 @@ from typing import Any, Callable, cast, Dict, List, Optional, Set, Tuple import torch.distributed as dist - from torch.distributed import Store from torch.distributed.elastic.events import construct_and_record_rdzv_event, NodeState @@ -37,6 +36,7 @@ ) from .utils import _delay, _PeriodicTimer + __all__ = [ "RendezvousBackend", "RendezvousTimeout", @@ -57,6 +57,7 @@ def get_method_name(depth=2): Token = Any """Represent an opaque fencing token used by the rendezvous backend.""" + class RendezvousBackend(ABC): """Represent a backend that holds the rendezvous state.""" @@ -157,7 +158,9 @@ def __init__( close: Optional[timedelta] = None, heartbeat: Optional[timedelta] = None, ) -> None: - self._set_timeouts(join=join, last_call=last_call, close=close, heartbeat=heartbeat) + self._set_timeouts( + join=join, last_call=last_call, close=close, heartbeat=heartbeat + ) @property def join(self) -> timedelta: @@ -311,7 +314,9 @@ def __init__(self) -> None: self.last_heartbeats = {} -def _remove_participant_epilogue(state: _RendezvousState, settings: RendezvousSettings) -> None: +def _remove_participant_epilogue( + state: _RendezvousState, settings: RendezvousSettings +) -> None: if state.complete: # If we do not have any participants left, move to the next round. if not state.participants: @@ -424,7 +429,9 @@ def sync(self) -> Optional[bool]: if self._cache_duration > 0: # Avoid overloading the backend if we are asked to retrieve the # state repeatedly. Try to serve the cached state. - if self._last_sync_time >= max(time.monotonic() - self._cache_duration, 0): + if self._last_sync_time >= max( + time.monotonic() - self._cache_duration, 0 + ): return None get_response = self._backend.get_state() @@ -917,14 +924,19 @@ def __call__(self, ctx: _RendezvousContext, deadline: float) -> _Action: if ctx.node not in state.wait_list: return _Action.ADD_TO_WAIT_LIST elif len(state.participants) >= ctx.settings.max_nodes: - if ctx.node not in state.redundancy_list and ctx.node not in state.wait_list: + if ( + ctx.node not in state.redundancy_list + and ctx.node not in state.wait_list + ): return _Action.ADD_TO_REDUNDANCY_LIST elif is_participant: # If the rendezvous has enough number of participants including us, # check whether we have passed the rendezvous deadline. If yes, # complete it. - if len(state.participants) >= ctx.settings.min_nodes and \ - len(state.participants) <= ctx.settings.max_nodes: + if ( + len(state.participants) >= ctx.settings.min_nodes + and len(state.participants) <= ctx.settings.max_nodes + ): if cast(datetime, state.deadline) < datetime.utcnow(): msg = ( f"The node '{ctx.node}' marking the rendezvous complete, " @@ -1143,10 +1155,7 @@ def next_rendezvous(self) -> RendezvousInfo: deadline = self._get_deadline(self._settings.timeout.join) self._op_executor.run(exit_op, deadline) - self._op_executor.run( - join_op, - deadline, - self._get_deadline) + self._op_executor.run(join_op, deadline, self._get_deadline) self._start_heartbeats() @@ -1182,7 +1191,9 @@ def next_rendezvous(self) -> RendezvousInfo: if isinstance(self._store, dist.TCPStore): addr = self._store.host port = self._store.port - self._bootstrap_store_info = RendezvousStoreInfo(master_addr=addr, master_port=port) + self._bootstrap_store_info = RendezvousStoreInfo( + master_addr=addr, master_port=port + ) if rank == 0: self._shared_tcp_store_server = self._store else: @@ -1190,7 +1201,9 @@ def next_rendezvous(self) -> RendezvousInfo: # bootstrapping info across ranks self._bootstrap_store_info = RendezvousStoreInfo.build(rank, store) if rank == 0: - self._shared_tcp_store_server = self._create_tcp_store_server(self._bootstrap_store_info) + self._shared_tcp_store_server = self._create_tcp_store_server( + self._bootstrap_store_info + ) assert self._bootstrap_store_info is not None if rank == 0: @@ -1321,7 +1334,9 @@ def _start_heartbeats(self) -> None: self._settings.keep_alive_interval, self._keep_alive_weak, weakref.ref(self) ) - self._keep_alive_timer.set_name(f"RendezvousKeepAliveTimer_{self._this_node.local_id}") + self._keep_alive_timer.set_name( + f"RendezvousKeepAliveTimer_{self._this_node.local_id}" + ) self._keep_alive_timer.start() @@ -1337,7 +1352,9 @@ def _get_world(self) -> Tuple[int, int]: return state.participants[self._this_node], len(state.participants) def _wrap_store(self, store: Store) -> Store: - key_prefix = f"torch.rendezvous.{self._settings.run_id}.{self._state_holder.state.round}" + key_prefix = ( + f"torch.rendezvous.{self._settings.run_id}.{self._state_holder.state.round}" + ) return dist.PrefixStore(key_prefix, store) diff --git a/torch/distributed/elastic/rendezvous/etcd_rendezvous.py b/torch/distributed/elastic/rendezvous/etcd_rendezvous.py index 1a371b74275a19..fe6170ede0159e 100644 --- a/torch/distributed/elastic/rendezvous/etcd_rendezvous.py +++ b/torch/distributed/elastic/rendezvous/etcd_rendezvous.py @@ -15,6 +15,7 @@ from typing import Optional import etcd # type: ignore[import] + from torch.distributed.elastic.rendezvous import ( RendezvousClosedError, RendezvousError, @@ -25,15 +26,16 @@ RendezvousTimeoutError, ) +from .etcd_store import cas_delay, EtcdStore from .utils import parse_rendezvous_endpoint -from .etcd_store import EtcdStore, cas_delay + __all__ = [ "EtcdRendezvousRetryableFailure", "EtcdRendezvousRetryImmediately", "EtcdRendezvousHandler", "EtcdRendezvous", - "create_rdzv_handler" + "create_rdzv_handler", ] _log_fmt = logging.Formatter("%(levelname)s %(asctime)s %(message)s") @@ -373,7 +375,9 @@ def join_phase(self, expected_version): state = json.loads(active_version.value) logger.info( "Joined rendezvous version %s as rank %s. Full state: %s", - state["version"], this_rank, state + state["version"], + this_rank, + state, ) # If this worker was first to reach num_min_workers requirement, @@ -418,7 +422,8 @@ def confirm_phase(self, expected_version, this_rank): logger.info( "Rendezvous version %s is complete. Final state: %s", - state["version"], state + state["version"], + state, ) # Rendezvous version number; our rank in it; world size @@ -436,12 +441,13 @@ def handle_existing_rendezvous(self, expected_version): # 2. if keep alives are missing, destroy it and bail out. active_state = self.announce_self_waiting(expected_version) logger.info( - "Added self to waiting list. Rendezvous full state: %s", - active_state.value + "Added self to waiting list. Rendezvous full state: %s", active_state.value ) self.wait_for_rendezvous_to_free(expected_version) - logger.info("Previously existing rendezvous state changed. Will re-try joining.") + logger.info( + "Previously existing rendezvous state changed. Will re-try joining." + ) def try_create_rendezvous(self): """ @@ -688,8 +694,7 @@ def wait_for_rendezvous_to_free(self, expected_version): # rendezvous version as dead (but only if it hadn't changed) logger.info("Keep-alive key %s is not renewed.", key) logger.info( - "Rendezvous version %s is incomplete. ", - expected_version + "Rendezvous version %s is incomplete. ", expected_version ) logger.info("Attempting to destroy it.") @@ -703,7 +708,7 @@ def wait_for_rendezvous_to_free(self, expected_version): logger.info( "Destroyed rendezvous version %s successfully.", - expected_version + expected_version, ) # We can return (and retry) immediately @@ -770,7 +775,9 @@ def handle_join_last_call(self, expected_version, deadline): # We successfully made this rendezvous frozen. return except etcd.EtcdCompareFailed: - logger.info("Join last-call transition CAS unsuccessful. Will retry") + logger.info( + "Join last-call transition CAS unsuccessful. Will retry" + ) cas_delay() active_version, state = self.get_rdzv_state() continue @@ -1051,6 +1058,8 @@ def create_rdzv_handler(params: RendezvousParameters) -> RendezvousHandler: num_min_workers=params.min_nodes, num_max_workers=params.max_nodes, timeout=params.get_as_int("timeout", _DEFAULT_TIMEOUT), - last_call_timeout=params.get_as_int("last_call_timeout", _DEFAULT_LAST_CALL_TIMEOUT), + last_call_timeout=params.get_as_int( + "last_call_timeout", _DEFAULT_LAST_CALL_TIMEOUT + ), ) return EtcdRendezvousHandler(rdzv_impl=rdzv) diff --git a/torch/distributed/elastic/rendezvous/etcd_rendezvous_backend.py b/torch/distributed/elastic/rendezvous/etcd_rendezvous_backend.py index c9d60abdc2369e..75ae347293c8fb 100644 --- a/torch/distributed/elastic/rendezvous/etcd_rendezvous_backend.py +++ b/torch/distributed/elastic/rendezvous/etcd_rendezvous_backend.py @@ -7,17 +7,18 @@ import binascii from base64 import b64decode, b64encode -from typing import Optional, Tuple, cast +from typing import cast, Optional, Tuple import urllib3.exceptions # type: ignore[import] -from etcd import Client as EtcdClient # type: ignore[import] -from etcd import ( +from etcd import ( # type: ignore[import] + Client as EtcdClient, EtcdAlreadyExist, EtcdCompareFailed, EtcdException, EtcdKeyNotFound, EtcdResult, ) + from torch.distributed import Store from .api import RendezvousConnectionError, RendezvousParameters, RendezvousStateError @@ -207,7 +208,9 @@ def create_backend(params: RendezvousParameters) -> Tuple[EtcdRendezvousBackend, """ client = _create_etcd_client(params) - backend = EtcdRendezvousBackend(client, params.run_id, key_prefix="/torch/elastic/rendezvous") + backend = EtcdRendezvousBackend( + client, params.run_id, key_prefix="/torch/elastic/rendezvous" + ) store = EtcdStore(client, "/torch/elastic/store") diff --git a/torch/distributed/elastic/rendezvous/etcd_server.py b/torch/distributed/elastic/rendezvous/etcd_server.py index 891858534c565b..8af8c01c028ae0 100644 --- a/torch/distributed/elastic/rendezvous/etcd_server.py +++ b/torch/distributed/elastic/rendezvous/etcd_server.py @@ -17,6 +17,7 @@ import time from typing import Optional, TextIO, Union + try: import etcd # type: ignore[import] except ModuleNotFoundError: diff --git a/torch/distributed/elastic/rendezvous/etcd_store.py b/torch/distributed/elastic/rendezvous/etcd_store.py index 60559647568649..4fa1bef06857d4 100644 --- a/torch/distributed/elastic/rendezvous/etcd_store.py +++ b/torch/distributed/elastic/rendezvous/etcd_store.py @@ -178,7 +178,9 @@ def _try_wait_get(self, b64_keys, override_timeout=None): # Read whole directory (of keys), filter only the ones waited for all_nodes = self.client.get(key=self.prefix) req_nodes = { - node.key: node.value for node in all_nodes.children if node.key in b64_keys + node.key: node.value + for node in all_nodes.children + if node.key in b64_keys } if len(req_nodes) == len(b64_keys): diff --git a/torch/distributed/elastic/rendezvous/registry.py b/torch/distributed/elastic/rendezvous/registry.py index eaa5bcfd80e247..1a91d0a8ff7946 100644 --- a/torch/distributed/elastic/rendezvous/registry.py +++ b/torch/distributed/elastic/rendezvous/registry.py @@ -4,11 +4,16 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from .api import RendezvousHandler, RendezvousParameters -from .api import rendezvous_handler_registry as handler_registry +from .api import ( + rendezvous_handler_registry as handler_registry, + RendezvousHandler, + RendezvousParameters, +) from .dynamic_rendezvous import create_handler -__all__ = ['get_rendezvous_handler'] + +__all__ = ["get_rendezvous_handler"] + def _create_static_handler(params: RendezvousParameters) -> RendezvousHandler: from . import static_tcp_rendezvous diff --git a/torch/distributed/elastic/rendezvous/static_tcp_rendezvous.py b/torch/distributed/elastic/rendezvous/static_tcp_rendezvous.py index ace82d0a222674..5d2679d9fb4a0c 100644 --- a/torch/distributed/elastic/rendezvous/static_tcp_rendezvous.py +++ b/torch/distributed/elastic/rendezvous/static_tcp_rendezvous.py @@ -11,15 +11,16 @@ import logging from typing import cast, Optional -from torch.distributed import Store, TCPStore, PrefixStore +from torch.distributed import PrefixStore, Store, TCPStore from torch.distributed.elastic.rendezvous import ( - RendezvousInfo, RendezvousHandler, + RendezvousInfo, + RendezvousParameters, RendezvousStoreInfo, - RendezvousParameters ) from torch.distributed.elastic.rendezvous.utils import parse_rendezvous_endpoint + __all__ = ["StaticTCPRendezvous", "create_rdzv_handler"] logger = logging.getLogger(__name__) diff --git a/torch/distributed/elastic/rendezvous/utils.py b/torch/distributed/elastic/rendezvous/utils.py index 8419051d29f82b..a93c9c39e58635 100644 --- a/torch/distributed/elastic/rendezvous/utils.py +++ b/torch/distributed/elastic/rendezvous/utils.py @@ -15,7 +15,9 @@ from threading import Event, Thread from typing import Any, Callable, Dict, Optional, Tuple, Union -__all__ = ['parse_rendezvous_endpoint'] + +__all__ = ["parse_rendezvous_endpoint"] + def _parse_rendezvous_config(config_str: str) -> Dict[str, str]: """Extract key-value pairs from a rendezvous configuration string. @@ -62,7 +64,9 @@ def _try_parse_port(port_str: str) -> Optional[int]: return None -def parse_rendezvous_endpoint(endpoint: Optional[str], default_port: int) -> Tuple[str, int]: +def parse_rendezvous_endpoint( + endpoint: Optional[str], default_port: int +) -> Tuple[str, int]: """Extract the hostname and the port number from a rendezvous endpoint. Args: @@ -92,7 +96,7 @@ def parse_rendezvous_endpoint(endpoint: Optional[str], default_port: int) -> Tup if len(rest) == 1: port = _try_parse_port(rest[0]) - if port is None or port >= 2 ** 16: + if port is None or port >= 2**16: raise ValueError( f"The port number of the rendezvous endpoint '{endpoint}' must be an integer " "between 0 and 65536." @@ -135,10 +139,7 @@ def _matches_machine_hostname(host: str) -> bool: except (ValueError, socket.gaierror) as _: host_addr_list = [] - host_ip_list = [ - host_addr_info[4][0] - for host_addr_info in host_addr_list - ] + host_ip_list = [host_addr_info[4][0] for host_addr_info in host_addr_list] this_host = socket.gethostname() if host == this_host: @@ -246,7 +247,10 @@ def start(self) -> None: raise RuntimeError("The timer has already started.") self._thread = Thread( - target=self._run, name=self._name or "PeriodicTimer", args=(self._ctx,), daemon=True + target=self._run, + name=self._name or "PeriodicTimer", + args=(self._ctx,), + daemon=True, ) # We avoid using a regular finalizer (a.k.a. __del__) for stopping the diff --git a/torch/distributed/elastic/timer/__init__.py b/torch/distributed/elastic/timer/__init__.py index ea4b2a46c4231d..b9c2ea349cc67f 100644 --- a/torch/distributed/elastic/timer/__init__.py +++ b/torch/distributed/elastic/timer/__init__.py @@ -39,6 +39,16 @@ def trainer_func(message_queue): complete, then the worker process is killed and the agent retries the worker group. """ -from .api import TimerClient, TimerRequest, TimerServer, configure, expires # noqa: F401 +from .api import ( # noqa: F401 + configure, + expires, + TimerClient, + TimerRequest, + TimerServer, +) +from .file_based_local_timer import ( # noqa: F401 + FileTimerClient, + FileTimerRequest, + FileTimerServer, +) from .local_timer import LocalTimerClient, LocalTimerServer # noqa: F401 -from .file_based_local_timer import FileTimerClient, FileTimerServer, FileTimerRequest # noqa: F401 diff --git a/torch/distributed/elastic/timer/api.py b/torch/distributed/elastic/timer/api.py index 77fcaaceed4f2e..fe8d440b1afb8b 100644 --- a/torch/distributed/elastic/timer/api.py +++ b/torch/distributed/elastic/timer/api.py @@ -12,10 +12,19 @@ from inspect import getframeinfo, stack from typing import Any, Dict, List, Optional, Set -__all__ = ['TimerRequest', 'TimerClient', 'RequestQueue', 'TimerServer', 'configure', 'expires'] + +__all__ = [ + "TimerRequest", + "TimerClient", + "RequestQueue", + "TimerServer", + "configure", + "expires", +] logger = logging.getLogger(__name__) + class TimerRequest: """ Data object representing a countdown timer acquisition and release @@ -192,9 +201,9 @@ def _run_watchdog(self): reaped_worker_ids = set() for worker_id, expired_timers in self.get_expired_timers(now).items(): logger.info( - "Reaping worker_id=[%s]." - " Expired timers: %s", - worker_id, self._get_scopes(expired_timers) + "Reaping worker_id=[%s]." " Expired timers: %s", + worker_id, + self._get_scopes(expired_timers), ) if self._reap_worker_no_throw(worker_id): logger.info("Successfully reaped worker=[%s]", worker_id) @@ -210,10 +219,10 @@ def _get_scopes(self, timer_requests): def start(self) -> None: logger.info( - "Starting %s..." - " max_interval=%s," - " daemon=%s", - type(self).__name__, self._max_interval, self._daemon + "Starting %s..." " max_interval=%s," " daemon=%s", + type(self).__name__, + self._max_interval, + self._daemon, ) self._watchdog_thread = threading.Thread( target=self._watchdog_loop, daemon=self._daemon diff --git a/torch/distributed/elastic/timer/debug_info_logging.py b/torch/distributed/elastic/timer/debug_info_logging.py index 55a1a9e9bcdf76..3dce543220d83d 100644 --- a/torch/distributed/elastic/timer/debug_info_logging.py +++ b/torch/distributed/elastic/timer/debug_info_logging.py @@ -11,6 +11,7 @@ from torch.distributed.elastic.utils.logging import get_logger + logger = get_logger(__name__) __all__ = ["log_debug_info_for_expired_timers"] diff --git a/torch/distributed/elastic/timer/file_based_local_timer.py b/torch/distributed/elastic/timer/file_based_local_timer.py index fce46f053a7e79..74da756d58c99a 100644 --- a/torch/distributed/elastic/timer/file_based_local_timer.py +++ b/torch/distributed/elastic/timer/file_based_local_timer.py @@ -16,13 +16,17 @@ from typing import Callable, Dict, List, Optional, Set, Tuple from torch.distributed.elastic.timer.api import TimerClient, TimerRequest -from torch.distributed.elastic.timer.debug_info_logging import log_debug_info_for_expired_timers +from torch.distributed.elastic.timer.debug_info_logging import ( + log_debug_info_for_expired_timers, +) from torch.distributed.elastic.utils.logging import get_logger + __all__ = ["FileTimerClient", "FileTimerRequest", "FileTimerServer"] logger = get_logger(__name__) + class FileTimerRequest(TimerRequest): """ Data object representing a countdown timer acquisition and release @@ -35,7 +39,9 @@ class FileTimerRequest(TimerRequest): __slots__ = ["version", "worker_pid", "scope_id", "expiration_time", "signal"] - def __init__(self, worker_pid: int, scope_id: str, expiration_time: float, signal: int = 0) -> None: + def __init__( + self, worker_pid: int, scope_id: str, expiration_time: float, signal: int = 0 + ) -> None: self.version = 1 self.worker_pid = worker_pid self.scope_id = scope_id @@ -60,7 +66,7 @@ def to_json(self) -> str: "pid": self.worker_pid, "scope_id": self.scope_id, "expiration_time": self.expiration_time, - "signal": self.signal + "signal": self.signal, }, ) @@ -83,8 +89,12 @@ class FileTimerClient(TimerClient): signal: signal, the signal to use to kill the process. Using a negative or zero signal will not kill the process. """ - def __init__(self, file_path: str, signal=(signal.SIGKILL if sys.platform != "win32" else - signal.CTRL_C_EVENT)) -> None: # type: ignore[attr-defined] + + def __init__( + self, + file_path: str, + signal=(signal.SIGKILL if sys.platform != "win32" else signal.CTRL_C_EVENT), # type: ignore[attr-defined] + ) -> None: super().__init__() self._file_path = file_path self.signal = signal @@ -103,7 +113,9 @@ def _send_request(self, request: FileTimerRequest) -> None: # be raised if the server is not there. file = self._open_non_blocking() if file is None: - raise BrokenPipeError("Could not send the FileTimerRequest because FileTimerServer is not available.") + raise BrokenPipeError( + "Could not send the FileTimerRequest because FileTimerServer is not available." + ) with file: json_request = request.to_json() # Write request with no greater than select.PIPE_BUF is guarantee to be atomic. @@ -120,17 +132,14 @@ def acquire(self, scope_id: str, expiration_time: float) -> None: worker_pid=os.getpid(), scope_id=scope_id, expiration_time=expiration_time, - signal=self.signal + signal=self.signal, ), ) def release(self, scope_id: str) -> None: self._send_request( request=FileTimerRequest( - worker_pid=os.getpid(), - scope_id=scope_id, - expiration_time=-1, - signal=0 + worker_pid=os.getpid(), scope_id=scope_id, expiration_time=-1, signal=0 ), ) @@ -161,7 +170,7 @@ def __init__( run_id: str, max_interval: float = 10, daemon: bool = True, - log_event: Optional[Callable[[str, Optional[FileTimerRequest]], None]] = None + log_event: Optional[Callable[[str, Optional[FileTimerRequest]], None]] = None, ) -> None: self._file_path = file_path self._run_id = run_id @@ -177,18 +186,21 @@ def __init__( self._request_count = 0 # For test only. Process all requests and stop the server. self._run_once = False - self._log_event = log_event if log_event is not None else lambda name, request: None + self._log_event = ( + log_event if log_event is not None else lambda name, request: None + ) self._last_progress_time = int(time.time()) - def start(self) -> None: logger.info( - "Starting %s..." - " max_interval=%s," - " daemon=%s", - type(self).__name__, self._max_interval, self._daemon + "Starting %s..." " max_interval=%s," " daemon=%s", + type(self).__name__, + self._max_interval, + self._daemon, + ) + self._watchdog_thread = threading.Thread( + target=self._watchdog_loop, daemon=self._daemon ) - self._watchdog_thread = threading.Thread(target=self._watchdog_loop, daemon=self._daemon) logger.info("Starting watchdog thread...") self._watchdog_thread.start() self._log_event("watchdog started", None) @@ -255,11 +267,18 @@ def _run_watchdog(self, fd: io.TextIOWrapper) -> None: all_expired_timers = self.get_expired_timers(now) log_debug_info_for_expired_timers( self._run_id, - {pid: self._get_scopes(expired_timers) for pid, expired_timers in all_expired_timers.items()}, + { + pid: self._get_scopes(expired_timers) + for pid, expired_timers in all_expired_timers.items() + }, ) for worker_pid, expired_timers in all_expired_timers.items(): - logger.info("Reaping worker_pid=[%s]. Expired timers: %s", worker_pid, self._get_scopes(expired_timers)) + logger.info( + "Reaping worker_pid=[%s]. Expired timers: %s", + worker_pid, + self._get_scopes(expired_timers), + ) reaped_worker_pids.add(worker_pid) # In case we have multiple expired timers, we find the first timer # with a valid signal (>0) in the expiration time order. @@ -273,19 +292,28 @@ def _run_watchdog(self, fd: io.TextIOWrapper) -> None: expired_timer = timer break if signal <= 0: - logger.info("No signal specified with worker=[%s]. Do not reap it.", worker_pid) + logger.info( + "No signal specified with worker=[%s]. Do not reap it.", worker_pid + ) continue if self._reap_worker(worker_pid, signal): - logger.info("Successfully reaped worker=[%s] with signal=%s", worker_pid, signal) + logger.info( + "Successfully reaped worker=[%s] with signal=%s", worker_pid, signal + ) self._log_event("kill worker process", expired_timer) else: - logger.error("Error reaping worker=[%s]. Will retry on next watchdog.", worker_pid) + logger.error( + "Error reaping worker=[%s]. Will retry on next watchdog.", + worker_pid, + ) self.clear_timers(reaped_worker_pids) def _get_scopes(self, timer_requests: List[FileTimerRequest]) -> List[str]: return [r.scope_id for r in timer_requests] - def _get_requests(self, fd: io.TextIOWrapper, max_interval: float) -> List[FileTimerRequest]: + def _get_requests( + self, fd: io.TextIOWrapper, max_interval: float + ) -> List[FileTimerRequest]: start = time.time() requests = [] while not self._stop_signaled or self._run_once: @@ -309,7 +337,10 @@ def _get_requests(self, fd: io.TextIOWrapper, max_interval: float) -> List[FileT signal = request["signal"] requests.append( FileTimerRequest( - worker_pid=pid, scope_id=scope_id, expiration_time=expiration_time, signal=signal + worker_pid=pid, + scope_id=scope_id, + expiration_time=expiration_time, + signal=signal, ) ) now = time.time() @@ -333,7 +364,7 @@ def register_timers(self, timer_requests: List[FileTimerRequest]) -> None: self._timers[key] = request def clear_timers(self, worker_pids: Set[int]) -> None: - for (pid, scope_id) in list(self._timers.keys()): + for pid, scope_id in list(self._timers.keys()): if pid in worker_pids or not FileTimerServer.is_process_running(pid): del self._timers[(pid, scope_id)] diff --git a/torch/distributed/elastic/timer/local_timer.py b/torch/distributed/elastic/timer/local_timer.py index b6a54896fc5ef0..fe784b7de46d27 100644 --- a/torch/distributed/elastic/timer/local_timer.py +++ b/torch/distributed/elastic/timer/local_timer.py @@ -14,10 +14,12 @@ from .api import RequestQueue, TimerClient, TimerRequest, TimerServer -__all__ = ['LocalTimerClient', 'MultiprocessingRequestQueue', 'LocalTimerServer'] + +__all__ = ["LocalTimerClient", "MultiprocessingRequestQueue", "LocalTimerServer"] logger = logging.getLogger(__name__) + class LocalTimerClient(TimerClient): """ Client side of ``LocalTimerServer``. This client is meant to be used @@ -101,7 +103,7 @@ def register_timers(self, timer_requests: List[TimerRequest]) -> None: self._timers[(pid, scope_id)] = request def clear_timers(self, worker_ids: Set[int]) -> None: - for (pid, scope_id) in list(self._timers.keys()): + for pid, scope_id in list(self._timers.keys()): if pid in worker_ids: self._timers.pop((pid, scope_id)) diff --git a/torch/distributed/elastic/utils/api.py b/torch/distributed/elastic/utils/api.py index e0607e9c0d5dc7..bdb8f02e0176fd 100644 --- a/torch/distributed/elastic/utils/api.py +++ b/torch/distributed/elastic/utils/api.py @@ -9,7 +9,7 @@ import os import socket from string import Template -from typing import List, Any +from typing import Any, List def get_env_variable_or_raise(env_name: str) -> str: diff --git a/torch/distributed/elastic/utils/distributed.py b/torch/distributed/elastic/utils/distributed.py index 04ff2fe680f1ef..1a7ea81451f7c1 100644 --- a/torch/distributed/elastic/utils/distributed.py +++ b/torch/distributed/elastic/utils/distributed.py @@ -16,6 +16,7 @@ from torch.distributed.elastic.utils.logging import get_logger from torch.distributed.elastic.utils.store import barrier + __all__ = ["create_c10d_store", "get_free_port", "get_socket_with_port"] logger = get_logger(__name__) @@ -58,7 +59,12 @@ def create_c10d_store( " is_server : %s\n" " timeout(sec): %s\n" " use_libuv : %s\n", - server_addr, port, world_size, is_server, timeout, use_libuv, + server_addr, + port, + world_size, + is_server, + timeout, + use_libuv, ) try: @@ -90,7 +96,10 @@ def create_c10d_store( if str(e) == _ADDRESS_IN_USE: # this will only happen on the server if attempt < retries: logger.warning( - "port: %s already in use, attempt: [%s/%s]", port, attempt, retries + "port: %s already in use, attempt: [%s/%s]", + port, + attempt, + retries, ) attempt += 1 else: diff --git a/torch/distributed/elastic/utils/store.py b/torch/distributed/elastic/utils/store.py index 6d2e1f046502b3..a94010c432b185 100644 --- a/torch/distributed/elastic/utils/store.py +++ b/torch/distributed/elastic/utils/store.py @@ -7,15 +7,17 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +from contextlib import contextmanager from datetime import timedelta from typing import List -from contextlib import contextmanager + _NUM_MEMBERS = "/num_members" _LAST_MEMBER_CHECKIN = "/last_member" __all__ = ["store_timeout", "get_all", "synchronize", "barrier"] + @contextmanager def store_timeout(store, timeout: float): """ @@ -52,9 +54,7 @@ def get_all(store, rank: int, prefix: str, world_size: int): value3 = values[2] # retrieves the data for key torchelastic/data2 """ - data_arr = store.multi_get( - [f"{prefix}{idx}" for idx in range(world_size)] - ) + data_arr = store.multi_get([f"{prefix}{idx}" for idx in range(world_size)]) barrier_key = _barrier_nonblocking( store=store, @@ -101,7 +101,6 @@ def _barrier_nonblocking(store, world_size: int, key_prefix: str) -> str: num_members_key = key_prefix + _NUM_MEMBERS last_member_key = key_prefix + _LAST_MEMBER_CHECKIN - idx = store.add(num_members_key, 1) if idx == world_size: store.set(last_member_key, "") @@ -126,5 +125,7 @@ def barrier( """ with store_timeout(store, barrier_timeout): - last_member_key = _barrier_nonblocking(store=store, world_size=world_size, key_prefix=key_prefix) + last_member_key = _barrier_nonblocking( + store=store, world_size=world_size, key_prefix=key_prefix + ) store.get(last_member_key) From 22d258427baf226fe67f888de044a62941c66dd7 Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Tue, 18 Jun 2024 14:31:39 +0800 Subject: [PATCH 131/171] [BE][Easy] enable UFMT for `torch/distributed/_shard/` (#128867) Part of #123062 - #123062 Pull Request resolved: https://github.com/pytorch/pytorch/pull/128867 Approved by: https://github.com/fegin ghstack dependencies: #128866 --- .lintrunner.toml | 34 -- torch/distributed/_shard/__init__.py | 7 +- torch/distributed/_shard/_utils.py | 18 +- torch/distributed/_shard/api.py | 107 +++--- .../distributed/_shard/checkpoint/__init__.py | 4 +- torch/distributed/_shard/common_op_utils.py | 9 +- torch/distributed/_shard/metadata.py | 18 +- torch/distributed/_shard/op_registry_utils.py | 11 +- .../_shard/sharded_optim/__init__.py | 12 +- torch/distributed/_shard/sharded_optim/api.py | 16 +- .../_shard/sharded_tensor/__init__.py | 192 ++++++----- .../_shard/sharded_tensor/_ops/__init__.py | 14 +- .../_shard/sharded_tensor/_ops/_common.py | 9 +- .../_shard/sharded_tensor/_ops/binary_cmp.py | 32 +- .../_shard/sharded_tensor/_ops/init.py | 31 +- .../_shard/sharded_tensor/_ops/misc_ops.py | 5 +- .../_shard/sharded_tensor/_ops/tensor_ops.py | 10 +- .../distributed/_shard/sharded_tensor/api.py | 320 ++++++++++-------- .../_shard/sharded_tensor/logger.py | 5 +- .../_shard/sharded_tensor/logging_handlers.py | 1 + .../_shard/sharded_tensor/metadata.py | 22 +- .../_shard/sharded_tensor/reshard.py | 19 +- .../_shard/sharded_tensor/shard.py | 16 +- .../_shard/sharded_tensor/utils.py | 154 ++++++--- torch/distributed/_shard/sharder.py | 3 + .../_shard/sharding_plan/__init__.py | 5 +- torch/distributed/_shard/sharding_plan/api.py | 6 +- .../_shard/sharding_spec/__init__.py | 10 +- .../_shard/sharding_spec/_internals.py | 25 +- torch/distributed/_shard/sharding_spec/api.py | 77 +++-- .../sharding_spec/chunk_sharding_spec.py | 77 +++-- torch/distributed/_spmd/api.py | 3 - torch/distributed/_spmd/batch_dim_utils.py | 5 +- torch/distributed/_spmd/config.py | 1 + torch/distributed/_spmd/data_parallel.py | 6 +- torch/distributed/_spmd/distribute.py | 2 - torch/distributed/_spmd/experimental_ops.py | 2 +- torch/distributed/_spmd/graph_optimization.py | 1 + torch/distributed/_spmd/parallel_mode.py | 1 - torch/distributed/_spmd/partial_lower.py | 3 +- 40 files changed, 736 insertions(+), 557 deletions(-) diff --git a/.lintrunner.toml b/.lintrunner.toml index 2ea1579ee64c27..dc9f9ddd46c7ce 100644 --- a/.lintrunner.toml +++ b/.lintrunner.toml @@ -1392,40 +1392,6 @@ exclude_patterns = [ 'torch/cuda/_memory_viz.py', # mypy: Value of type "object" is not indexable 'torch/distributed/__init__.py', 'torch/distributed/_composable_state.py', - 'torch/distributed/_shard/__init__.py', - 'torch/distributed/_shard/_utils.py', - 'torch/distributed/_shard/api.py', - 'torch/distributed/_shard/checkpoint/__init__.py', - 'torch/distributed/_shard/common_op_utils.py', - 'torch/distributed/_shard/metadata.py', - 'torch/distributed/_shard/op_registry_utils.py', - 'torch/distributed/_shard/sharded_optim/__init__.py', - 'torch/distributed/_shard/sharded_optim/api.py', - 'torch/distributed/_shard/sharded_tensor/__init__.py', - 'torch/distributed/_shard/sharded_tensor/_ops/__init__.py', - 'torch/distributed/_shard/sharded_tensor/_ops/_common.py', - 'torch/distributed/_shard/sharded_tensor/_ops/binary_cmp.py', - 'torch/distributed/_shard/sharded_tensor/_ops/init.py', - 'torch/distributed/_shard/sharded_tensor/_ops/misc_ops.py', - 'torch/distributed/_shard/sharded_tensor/_ops/tensor_ops.py', - 'torch/distributed/_shard/sharded_tensor/api.py', - 'torch/distributed/_shard/sharded_tensor/logger.py', - 'torch/distributed/_shard/sharded_tensor/logging_handlers.py', - 'torch/distributed/_shard/sharded_tensor/metadata.py', - 'torch/distributed/_shard/sharded_tensor/reshard.py', - 'torch/distributed/_shard/sharded_tensor/shard.py', - 'torch/distributed/_shard/sharded_tensor/utils.py', - 'torch/distributed/_shard/sharder.py', - 'torch/distributed/_shard/sharding_plan/__init__.py', - 'torch/distributed/_shard/sharding_plan/api.py', - 'torch/distributed/_shard/sharding_spec/__init__.py', - 'torch/distributed/_shard/sharding_spec/_internals.py', - 'torch/distributed/_shard/sharding_spec/api.py', - 'torch/distributed/_shard/sharding_spec/chunk_sharding_spec.py', - 'torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/__init__.py', - 'torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/_common.py', - 'torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/embedding.py', - 'torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/embedding_bag.py', 'torch/distributed/_sharded_tensor/__init__.py', 'torch/distributed/_sharding_spec/__init__.py', 'torch/distributed/_tools/__init__.py', diff --git a/torch/distributed/_shard/__init__.py b/torch/distributed/_shard/__init__.py index 34539d633f8fa0..85a313c779e7aa 100644 --- a/torch/distributed/_shard/__init__.py +++ b/torch/distributed/_shard/__init__.py @@ -1,6 +1 @@ -from .api import ( - _shard_tensor, - load_with_process_group, - shard_module, - shard_parameter, -) +from .api import _shard_tensor, load_with_process_group, shard_module, shard_parameter diff --git a/torch/distributed/_shard/_utils.py b/torch/distributed/_shard/_utils.py index 26305b99cce306..d06fc4dc961447 100644 --- a/torch/distributed/_shard/_utils.py +++ b/torch/distributed/_shard/_utils.py @@ -1,10 +1,17 @@ +from typing import Sequence + import torch from torch.distributed._shard.metadata import ShardMetadata -from typing import Sequence + DEPRECATE_MSG = "Please use DTensor instead and we are deprecating ShardedTensor." -def narrow_tensor_by_index(tensor: torch.Tensor, offsets: Sequence[int], sizes: Sequence[int]) -> torch.Tensor: + +def narrow_tensor_by_index( + tensor: torch.Tensor, + offsets: Sequence[int], + sizes: Sequence[int], +) -> torch.Tensor: """ Narrow the tensor according to ``offsets`` and ``sizes``. """ @@ -14,13 +21,10 @@ def narrow_tensor_by_index(tensor: torch.Tensor, offsets: Sequence[int], sizes: # Reshape to get shard for this rank and we don't want autograd # recording here for the narrow op and 'local_shard' should be a # leaf variable in the autograd graph. - narrowed_tensor = narrowed_tensor.narrow( - idx, - offset, - size - ) + narrowed_tensor = narrowed_tensor.narrow(idx, offset, size) return narrowed_tensor + def narrow_tensor(tensor: torch.Tensor, metadata: ShardMetadata) -> torch.Tensor: """ Narrow the tensor according to the metadata diff --git a/torch/distributed/_shard/api.py b/torch/distributed/_shard/api.py index 441bb421b195b1..5d1e1179c9a780 100644 --- a/torch/distributed/_shard/api.py +++ b/torch/distributed/_shard/api.py @@ -1,21 +1,17 @@ # mypy: allow-untyped-defs from contextlib import contextmanager from typing import Optional + import torch import torch.distributed as dist import torch.nn as nn from torch.distributed import distributed_c10d -from torch.distributed._shard.sharded_tensor import ( - ShardedTensor, -) -from .sharding_spec import ( - ShardingSpec, - ChunkShardingSpec -) -from .sharding_plan import ( - ShardingPlan -) +from torch.distributed._shard.sharded_tensor import ShardedTensor + from .sharder import Sharder +from .sharding_plan import ShardingPlan +from .sharding_spec import ChunkShardingSpec, ShardingSpec + def _shard_tensor( tensor: torch.Tensor, sharding_spec: ShardingSpec, src_rank=0, process_group=None @@ -47,9 +43,13 @@ def _shard_tensor( currently supported as the ``sharding_spec``. """ if not tensor.is_contiguous(): - raise ValueError('input tensor is not a contiguous Tensor') + raise ValueError("input tensor is not a contiguous Tensor") - pg = process_group if process_group is not None else distributed_c10d._get_default_group() + pg = ( + process_group + if process_group is not None + else distributed_c10d._get_default_group() + ) world_size = dist.get_world_size(pg) current_rank = dist.get_rank(pg) @@ -60,23 +60,27 @@ def _shard_tensor( for idx, entry in enumerate(gathered_list): if src_rank != entry[0]: # type: ignore[index] raise ValueError( - f'src_rank={src_rank} on rank: {current_rank} does not ' # type: ignore[index] - f'match with src_rank={entry[0]} on rank: {idx}') + f"src_rank={src_rank} on rank: {current_rank} does not " + f"match with src_rank={entry[0]} on rank: {idx}" # type: ignore[index] + ) if sharding_spec != entry[1]: # type: ignore[index] raise ValueError( - f'sharding_spec={sharding_spec} on rank: {current_rank} does not ' # type: ignore[index] - f'match with sharding_spec={entry[1]} on rank: {idx}') + f"sharding_spec={sharding_spec} on rank: {current_rank} does not " + f"match with sharding_spec={entry[1]} on rank: {idx}" # type: ignore[index] + ) st = sharding_spec.shard(tensor, src_rank=src_rank, process_group=pg) return st + def shard_parameter( - module: torch.nn.Module, - param_name: str, - sharding_spec: ShardingSpec, - src_rank=0, - process_group=None): + module: torch.nn.Module, + param_name: str, + sharding_spec: ShardingSpec, + src_rank=0, + process_group=None, +): """ Given a :class:`torch.nn.Module`, a ``param_name`` for a parameter in that module, it shards that parameter according to the provided @@ -107,23 +111,27 @@ def shard_parameter( """ # Perform some validation first. if not hasattr(module, param_name): - raise AttributeError(f'{module._get_name()} has no attribute `{param_name}`') + raise AttributeError(f"{module._get_name()} has no attribute `{param_name}`") tensor = getattr(module, param_name) if not isinstance(tensor, torch.Tensor): - raise ValueError(f'Expected {type(module).__name__}.{param_name} to be a Tensor, but found {type(tensor).__name__}') + raise ValueError( + f"Expected {type(module).__name__}.{param_name} to be a Tensor, but found {type(tensor).__name__}" + ) if not tensor.is_contiguous(): - raise ValueError(f'param: {param_name} is not a contiguous Tensor') + raise ValueError(f"param: {param_name} is not a contiguous Tensor") st = _shard_tensor(tensor, sharding_spec, src_rank, process_group) # Replace param with ShardedTensor. module.register_parameter(param_name, nn.Parameter(st)) + # Tracks the current process group in the load context manager. _CURRENT_PROCESS_GROUP: Optional[dist.ProcessGroup] = None + @contextmanager def load_with_process_group(process_group): """ @@ -133,13 +141,15 @@ def load_with_process_group(process_group): if _CURRENT_PROCESS_GROUP is not None: raise RuntimeError( 'ProcessGroup already set by previous "load_with_process_group" ' - 'context manager') + "context manager" + ) _CURRENT_PROCESS_GROUP = process_group try: yield process_group finally: _CURRENT_PROCESS_GROUP = None + def _get_current_process_group(): """ Retrieves the current process group set by ``load_with_process_group``. @@ -151,9 +161,10 @@ def _get_current_process_group(): else: return _CURRENT_PROCESS_GROUP + def _reshard_output( - module: torch.nn.Module, - resharding_spec: ShardingSpec) -> torch.nn.Module: + module: torch.nn.Module, resharding_spec: ShardingSpec +) -> torch.nn.Module: """ Hook a module with output resharding in the forward pass according to the given ``resharding_spec``. @@ -166,13 +177,16 @@ def _reshard_output( Returns: A :class:`torch.nn.Module` object with reshard API hooked. """ + def hook_func(_module, _input, output): if isinstance(output, ShardedTensor): return output.reshard(resharding_spec) return output + module.register_forward_hook(hook_func) return module + def _collect_local_shard(module: torch.nn.Module) -> torch.nn.Module: """ Hook a module with local shards collection in the forward pass. @@ -196,21 +210,20 @@ def hook_func(_module, _input, output): local_tensor = output.local_tensor() # Squeeze the # of dimensions manually, only applicable to ChunkShardingSpec sharding_spec = output._sharding_spec - if isinstance(sharding_spec, ChunkShardingSpec) \ - and local_tensor.size(sharding_spec.dim) == 1: # type: ignore[attr-defined, arg-type] + if ( + isinstance(sharding_spec, ChunkShardingSpec) + and local_tensor.size(sharding_spec.dim) == 1 # type: ignore[attr-defined, arg-type] + ): local_tensor = local_tensor.squeeze( output._sharding_spec.dim # type: ignore[attr-defined] ) return local_tensor + module.register_forward_hook(hook_func) return module -def shard_module( - module: nn.Module, - plan: ShardingPlan, - src_rank=0, - process_group=None -): + +def shard_module(module: nn.Module, plan: ShardingPlan, src_rank=0, process_group=None): """ Shards a given module according to the provided sharding `plan`. This method first shards all the parameters according to the given sharding `plan`. Then if @@ -249,18 +262,16 @@ def shard_module( for sharder_path in sharder_paths: if module_path.startswith(sharder_path): - raise RuntimeError(f"ShardingPlan is in-valid, trying to shard a parameter: {name}," - f" but there's already a Sharder entry for module {sharder_path}," - f" parameter sharding should not conflict with the submodule tree" - f" that a Sharder is working with!") + raise RuntimeError( + f"ShardingPlan is in-valid, trying to shard a parameter: {name}," + f" but there's already a Sharder entry for module {sharder_path}," + f" parameter sharding should not conflict with the submodule tree" + f" that a Sharder is working with!" + ) mod = module.get_submodule(module_path) shard_parameter( - mod, - param_name, - spec, - src_rank=src_rank, - process_group=process_group + mod, param_name, spec, src_rank=src_rank, process_group=process_group ) elif isinstance(spec, Sharder): parent_mod_path, _, mod_name = name.rpartition(".") @@ -272,7 +283,9 @@ def shard_module( # swap this submodule with the sharded module parent_mod.mod_name = sharded_mod else: - raise TypeError(f"Only `ShardingSpec` and `Sharder` are supported to shard '{name}'") + raise TypeError( + f"Only `ShardingSpec` and `Sharder` are supported to shard '{name}'" + ) # reshard output if there's an entry in `reshard_output` for this module if plan.output_plan is not None: @@ -281,7 +294,9 @@ def shard_module( mod = module.get_submodule(module_path) _reshard_output(mod, output_spec) else: - raise TypeError(f"Only `ShardingSpec` is supported as output_plan for '{module_path}'") + raise TypeError( + f"Only `ShardingSpec` is supported as output_plan for '{module_path}'" + ) # convert the output back to data parallel for the modules appears in # `return_local_tensor` of the plan, we will call `_collect_local_shard` # to collect the local tensor for output of modules diff --git a/torch/distributed/_shard/checkpoint/__init__.py b/torch/distributed/_shard/checkpoint/__init__.py index 161a43f276d661..85915636a01464 100644 --- a/torch/distributed/_shard/checkpoint/__init__.py +++ b/torch/distributed/_shard/checkpoint/__init__.py @@ -1,9 +1,9 @@ # Keep old package for BC purposes, this file should be removed once # everything moves to the `torch.distributed.checkpoint` package. import sys -import torch import warnings +import torch from torch.distributed.checkpoint import * # noqa: F403 @@ -16,4 +16,4 @@ stacklevel=2, ) -sys.modules['torch.distributed._shard.checkpoint'] = torch.distributed.checkpoint +sys.modules["torch.distributed._shard.checkpoint"] = torch.distributed.checkpoint diff --git a/torch/distributed/_shard/common_op_utils.py b/torch/distributed/_shard/common_op_utils.py index 7506f17b046d45..e2573998712b5f 100644 --- a/torch/distributed/_shard/common_op_utils.py +++ b/torch/distributed/_shard/common_op_utils.py @@ -1,7 +1,9 @@ # mypy: allow-untyped-defs +from typing import Optional + import torch from torch.utils import _pytree as pytree -from typing import Optional + def _basic_validation(op, args=(), kwargs=None): """ @@ -37,14 +39,15 @@ def validate_pg(e): if isinstance(e, ShardedTensor): if cur_pg is not None and e._process_group is not cur_pg: raise RuntimeError( - 'All distributed tensors should use the ' - 'same ProcessGroup if used together in an op.' + "All distributed tensors should use the " + "same ProcessGroup if used together in an op." ) cur_pg = e._process_group pytree.tree_map_(validate_pg, args) pytree.tree_map_(validate_pg, kwargs) + def _register_default_op(op, decorator): @decorator(op) def tensor_default_op(types, args=(), kwargs=None, pg=None): diff --git a/torch/distributed/_shard/metadata.py b/torch/distributed/_shard/metadata.py index 850b065e4dab09..2611d13ef3aafe 100644 --- a/torch/distributed/_shard/metadata.py +++ b/torch/distributed/_shard/metadata.py @@ -1,10 +1,11 @@ # mypy: allow-untyped-defs from dataclasses import dataclass -from typing import List, Union, Optional from functools import reduce +from typing import List, Optional, Union from torch.distributed.remote_device import _remote_device + @dataclass class ShardMetadata: """ @@ -22,7 +23,7 @@ class ShardMetadata: Specifies the placement of this shard. """ - __slots__ = ['shard_offsets', 'shard_sizes', 'placement'] + __slots__ = ["shard_offsets", "shard_sizes", "placement"] shard_offsets: List[int] shard_sizes: List[int] @@ -32,7 +33,7 @@ def __init__( self, shard_offsets: List[int], shard_sizes: List[int], - placement: Optional[Union[str, _remote_device]] = None + placement: Optional[Union[str, _remote_device]] = None, ): self.shard_offsets = shard_offsets self.shard_sizes = shard_sizes @@ -42,15 +43,16 @@ def __init__( self.placement = placement if len(self.shard_offsets) != len(self.shard_sizes): raise ValueError( - f'shard_offsets and shard_sizes should have ' - f'the same number of elements, found {len(self.shard_offsets)} ' - f'and {self.shard_sizes} respectively') + f"shard_offsets and shard_sizes should have " + f"the same number of elements, found {len(self.shard_offsets)} " + f"and {self.shard_sizes} respectively" + ) for i in range(len(self.shard_offsets)): if self.shard_offsets[i] < 0: - raise ValueError('shard_offsets should be >=0') + raise ValueError("shard_offsets should be >=0") if self.shard_sizes[i] < 0: - raise ValueError('shard_sizes should be >= 0') + raise ValueError("shard_sizes should be >= 0") def __hash__(self): def _hash_reduce(a, b): diff --git a/torch/distributed/_shard/op_registry_utils.py b/torch/distributed/_shard/op_registry_utils.py index 033dc7c58e0ad3..12e0b1895e2f05 100644 --- a/torch/distributed/_shard/op_registry_utils.py +++ b/torch/distributed/_shard/op_registry_utils.py @@ -1,13 +1,16 @@ # mypy: allow-untyped-defs import functools from inspect import signature + from .common_op_utils import _basic_validation + """ Common utilities to register ops on ShardedTensor and PartialTensor. """ + def _register_op(op, func, op_table): """ Performs basic validation and registers the provided op in the given @@ -15,12 +18,14 @@ def _register_op(op, func, op_table): """ if len(signature(func).parameters) != 4: raise TypeError( - f'Custom sharded op function expects signature: ' - f'(types, args, kwargs, process_group), but received ' - f'signature: {signature(func)}') + f"Custom sharded op function expects signature: " + f"(types, args, kwargs, process_group), but received " + f"signature: {signature(func)}" + ) op_table[op] = func + def _decorator_func(wrapped_func, op, op_table): """ Decorator function to register the given ``op`` in the provided diff --git a/torch/distributed/_shard/sharded_optim/__init__.py b/torch/distributed/_shard/sharded_optim/__init__.py index 172213fb0c1713..d1508208c16907 100644 --- a/torch/distributed/_shard/sharded_optim/__init__.py +++ b/torch/distributed/_shard/sharded_optim/__init__.py @@ -1,18 +1,16 @@ from typing import Iterator, Tuple, Union -from .api import ShardedOptimizer import torch.nn as nn +from torch.distributed._shard.sharded_tensor import ShardedTensor + +from .api import ShardedOptimizer -from torch.distributed._shard.sharded_tensor import ( - ShardedTensor -) def named_params_with_sharded_tensor( module: nn.Module, - prefix: str = '', + prefix: str = "", recurse: bool = True, ) -> Iterator[Tuple[str, Union[nn.Parameter, ShardedTensor]]]: - r"""Returns an iterator over module parameters (together with the ShardedTensor parameters), yielding both the name of the parameter as well as the parameter itself. This is typically passed to a @@ -46,7 +44,7 @@ def named_params_with_sharded_tensor( for name, val in vars(mod).items(): if isinstance(val, ShardedTensor) and val not in memo: memo.add(val) - name = mod_prefix + ('.' if mod_prefix else '') + name + name = mod_prefix + ("." if mod_prefix else "") + name yield name, val # find all nn.Parameters diff --git a/torch/distributed/_shard/sharded_optim/api.py b/torch/distributed/_shard/sharded_optim/api.py index e1acf7dc17a871..1c7c632f22b59a 100644 --- a/torch/distributed/_shard/sharded_optim/api.py +++ b/torch/distributed/_shard/sharded_optim/api.py @@ -1,5 +1,5 @@ # mypy: allow-untyped-defs -from typing import List, Union, Mapping, Dict, Any +from typing import Any, Dict, List, Mapping, Union import torch.optim as optim from torch import Tensor @@ -12,7 +12,7 @@ def __init__( named_params: Mapping[str, Union[Tensor, ShardedTensor]], optimizer_class, *optimizer_args, - **optimizer_kwargs + **optimizer_kwargs, ): """ ShardedOptimizer collects all tensors and local shard tensors of @@ -80,7 +80,6 @@ def state_dict(self) -> Dict[str, Any]: # TODO: implement state_dict raise NotImplementedError("ShardedOptimizer state_dict not implemented yet!") - def load_state_dict(self, state_dict: Mapping[str, Any]): r"""Loads the ShardedOptimizer state. @@ -89,10 +88,13 @@ def load_state_dict(self, state_dict: Mapping[str, Any]): from a call to :meth:`state_dict`. """ # TODO: implement load_state_dict - raise NotImplementedError("ShardedOptimizer load_state_dict not implemented yet!") + raise NotImplementedError( + "ShardedOptimizer load_state_dict not implemented yet!" + ) def add_param_group(self, param_group: Any): - r"""Add a new param group - """ + r"""Add a new param group""" # TODO: implement add_param_group - raise NotImplementedError("ShardedOptimizer add_param_group not implemented yet!") + raise NotImplementedError( + "ShardedOptimizer add_param_group not implemented yet!" + ) diff --git a/torch/distributed/_shard/sharded_tensor/__init__.py b/torch/distributed/_shard/sharded_tensor/__init__.py index 1b846a8dabb497..db7090820ea0a9 100644 --- a/torch/distributed/_shard/sharded_tensor/__init__.py +++ b/torch/distributed/_shard/sharded_tensor/__init__.py @@ -3,34 +3,37 @@ from typing import List, TYPE_CHECKING import torch - -if TYPE_CHECKING: - from torch.distributed._shard.sharding_spec import ShardingSpec -else: - ShardingSpec = "ShardingSpec" +from torch.distributed._shard.op_registry_utils import _decorator_func from .api import ( _CUSTOM_SHARDED_OPS, _SHARDED_OPS, Shard, - ShardedTensorBase, ShardedTensor, + ShardedTensorBase, ShardedTensorMetadata, TensorProperties, ) from .metadata import ShardMetadata # noqa: F401 -from torch.distributed._shard.op_registry_utils import _decorator_func -def empty(sharding_spec: ShardingSpec, - *size, - dtype=None, - layout=torch.strided, - requires_grad=False, - pin_memory=False, - memory_format=torch.contiguous_format, - process_group=None, - init_rrefs=False) -> ShardedTensor: +if TYPE_CHECKING: + from torch.distributed._shard.sharding_spec import ShardingSpec +else: + ShardingSpec = "ShardingSpec" + + +def empty( + sharding_spec: ShardingSpec, + *size, + dtype=None, + layout=torch.strided, + requires_grad=False, + pin_memory=False, + memory_format=torch.contiguous_format, + process_group=None, + init_rrefs=False, +) -> ShardedTensor: """ Returns a :class:`ShardedTensor` filled with uninitialized data. Needs to be called on all ranks in an SPMD fashion. @@ -74,15 +77,18 @@ def empty(sharding_spec: ShardingSpec, init_rrefs=init_rrefs, ) -def ones(sharding_spec: ShardingSpec, - *size, - dtype=None, - layout=torch.strided, - requires_grad=False, - pin_memory=False, - memory_format=torch.contiguous_format, - process_group=None, - init_rrefs=False) -> ShardedTensor: + +def ones( + sharding_spec: ShardingSpec, + *size, + dtype=None, + layout=torch.strided, + requires_grad=False, + pin_memory=False, + memory_format=torch.contiguous_format, + process_group=None, + init_rrefs=False, +) -> ShardedTensor: """ Returns a :class:`ShardedTensor` with the scalar value 1. Needs to be called on all ranks in an SPMD fashion. @@ -122,18 +128,21 @@ def ones(sharding_spec: ShardingSpec, pin_memory=pin_memory, memory_format=memory_format, process_group=process_group, - init_rrefs=init_rrefs + init_rrefs=init_rrefs, ) -def zeros(sharding_spec: ShardingSpec, - *size, - dtype=None, - layout=torch.strided, - requires_grad=False, - pin_memory=False, - memory_format=torch.contiguous_format, - process_group=None, - init_rrefs=False) -> ShardedTensor: + +def zeros( + sharding_spec: ShardingSpec, + *size, + dtype=None, + layout=torch.strided, + requires_grad=False, + pin_memory=False, + memory_format=torch.contiguous_format, + process_group=None, + init_rrefs=False, +) -> ShardedTensor: """ Returns a :class:`ShardedTensor` filled with the scalar value 0. Needs to be called on all ranks in an SPMD fashion. @@ -173,20 +182,23 @@ def zeros(sharding_spec: ShardingSpec, pin_memory=pin_memory, memory_format=memory_format, process_group=process_group, - init_rrefs=init_rrefs + init_rrefs=init_rrefs, ) -def full(sharding_spec: ShardingSpec, - size, - fill_value, - *, - dtype=None, - layout=torch.strided, - requires_grad=False, - pin_memory=False, - memory_format=torch.contiguous_format, - process_group=None, - init_rrefs=False) -> ShardedTensor: + +def full( + sharding_spec: ShardingSpec, + size, + fill_value, + *, + dtype=None, + layout=torch.strided, + requires_grad=False, + pin_memory=False, + memory_format=torch.contiguous_format, + process_group=None, + init_rrefs=False, +) -> ShardedTensor: """ Creates a :class:`ShardedTensor` filled with fill_value. The tensor's dtype is inferred from fill_value. If dtype is specified, it will override the @@ -229,15 +241,18 @@ def full(sharding_spec: ShardingSpec, torch.nn.init.constant_(sharded_tensor, fill_value) # type: ignore[arg-type] return sharded_tensor -def rand(sharding_spec: ShardingSpec, - *size, - dtype=None, - layout=torch.strided, - requires_grad=False, - pin_memory=False, - memory_format=torch.contiguous_format, - process_group=None, - init_rrefs=False) -> ShardedTensor: + +def rand( + sharding_spec: ShardingSpec, + *size, + dtype=None, + layout=torch.strided, + requires_grad=False, + pin_memory=False, + memory_format=torch.contiguous_format, + process_group=None, + init_rrefs=False, +) -> ShardedTensor: """ Creates a :class:`ShardedTensor` filled with random numbers from a uniform distribution on the interval :math:`[0, 1)`. The shape of the tensor is defined by the @@ -282,15 +297,18 @@ def rand(sharding_spec: ShardingSpec, torch.nn.init.uniform_(sharded_tensor, 0, 1) # type: ignore[arg-type] return sharded_tensor -def randn(sharding_spec: ShardingSpec, - *size, - dtype=None, - layout=torch.strided, - requires_grad=False, - pin_memory=False, - memory_format=torch.contiguous_format, - process_group=None, - init_rrefs=False) -> ShardedTensor: + +def randn( + sharding_spec: ShardingSpec, + *size, + dtype=None, + layout=torch.strided, + requires_grad=False, + pin_memory=False, + memory_format=torch.contiguous_format, + process_group=None, + init_rrefs=False, +) -> ShardedTensor: """ Creates a :class:`ShardedTensor` filled with random numbers from a uniform distribution with mean `0` and variance `1` (also called standard normal distribution). The shape @@ -336,11 +354,10 @@ def randn(sharding_spec: ShardingSpec, torch.nn.init.normal_(sharded_tensor, 0, 1) # type: ignore[arg-type] return sharded_tensor + def init_from_local_shards( - local_shards: List[Shard], - *global_size, - process_group=None, - init_rrefs=False) -> ShardedTensor: + local_shards: List[Shard], *global_size, process_group=None, init_rrefs=False +) -> ShardedTensor: """ Creates an :class:`ShardedTensor` from local shards and the global metadata. Needs to be called on all ranks in an SPMD fashion. @@ -388,12 +405,10 @@ def init_from_local_shards( >>> sharded_tensor = init_from_local_shards(local_shards, [10, 5]) """ return ShardedTensor._init_from_local_shards( - local_shards, - *global_size, - process_group=process_group, - init_rrefs=init_rrefs + local_shards, *global_size, process_group=process_group, init_rrefs=init_rrefs ) + def state_dict_hook(module, destination, prefix, local_metadata): """ Hook to add ShardedTensor to Module's ``state_dict``. Needs to be @@ -404,21 +419,32 @@ def state_dict_hook(module, destination, prefix, local_metadata): for attr_name, attr in submodule.__dict__.items(): if isinstance(attr, ShardedTensor): mod_prefix = prefix + submodule_name - key = mod_prefix + ('.' if mod_prefix else '') + attr_name + key = mod_prefix + ("." if mod_prefix else "") + attr_name destination[key] = attr -def pre_load_state_dict_hook(module, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs): + +def pre_load_state_dict_hook( + module, + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, +): """ Pre-load state dict hook to add ShardedTensor to the module. """ for submodule_name, submodule in module.named_modules(): for attr_name in submodule.__dict__.keys(): mod_prefix = prefix + submodule_name - key = mod_prefix + ('.' if mod_prefix else '') + attr_name + key = mod_prefix + ("." if mod_prefix else "") + attr_name if key in state_dict: if isinstance(state_dict[key], ShardedTensor): setattr(submodule, attr_name, state_dict[key]) + def custom_sharded_op_impl(func): """ Provides a way for users to write their own custom sharded operator. This @@ -450,21 +476,15 @@ def custom_sharded_op_impl(func): func(Callable): Torch function for which we want to provide a sharded implementation (ex: torch.nn.functional.linear) """ - return functools.partial( - _decorator_func, - op=func, - op_table=_CUSTOM_SHARDED_OPS - ) + return functools.partial(_decorator_func, op=func, op_table=_CUSTOM_SHARDED_OPS) + def _sharded_op_impl(func): """ Decorator to register a default sharded op. """ - return functools.partial( - _decorator_func, - op=func, - op_table=_SHARDED_OPS - ) + return functools.partial(_decorator_func, op=func, op_table=_SHARDED_OPS) + # Import all builtin sharded ops from ._ops import * # noqa: F403 diff --git a/torch/distributed/_shard/sharded_tensor/_ops/__init__.py b/torch/distributed/_shard/sharded_tensor/_ops/__init__.py index c233840f1eccee..be6d01fc8e54ee 100644 --- a/torch/distributed/_shard/sharded_tensor/_ops/__init__.py +++ b/torch/distributed/_shard/sharded_tensor/_ops/__init__.py @@ -1,9 +1,13 @@ import torch.distributed._shard.sharded_tensor._ops.misc_ops import torch.distributed._shard.sharded_tensor._ops.tensor_ops -from .binary_cmp import equal, allclose -from .init import kaiming_uniform_, normal_, uniform_, constant_ - # Import all ChunkShardingSpec ops -from torch.distributed._shard.sharding_spec.chunk_sharding_spec_ops.embedding import sharded_embedding -from torch.distributed._shard.sharding_spec.chunk_sharding_spec_ops.embedding_bag import sharded_embedding_bag +from torch.distributed._shard.sharding_spec.chunk_sharding_spec_ops.embedding import ( + sharded_embedding, +) +from torch.distributed._shard.sharding_spec.chunk_sharding_spec_ops.embedding_bag import ( + sharded_embedding_bag, +) + +from .binary_cmp import allclose, equal +from .init import constant_, kaiming_uniform_, normal_, uniform_ diff --git a/torch/distributed/_shard/sharded_tensor/_ops/_common.py b/torch/distributed/_shard/sharded_tensor/_ops/_common.py index 4d35d24ecafcad..502e0ac9a8552d 100644 --- a/torch/distributed/_shard/sharded_tensor/_ops/_common.py +++ b/torch/distributed/_shard/sharded_tensor/_ops/_common.py @@ -1,11 +1,13 @@ # mypy: allow-untyped-defs import functools + +from torch.distributed._shard.common_op_utils import _basic_validation from torch.distributed._shard.sharded_tensor import ( _sharded_op_impl, Shard, ShardedTensor, ) -from torch.distributed._shard.common_op_utils import _basic_validation + def _sharded_op_common(op, early_stop_func, extra_check): """ @@ -35,6 +37,7 @@ def _sharded_op_common(op, early_stop_func, extra_check): func (Callable): Torch function for which we want to provide a sharded implementation (ex: torch.transpose) """ + def decorator_sharded_func(wrapped_func): @functools.wraps(wrapped_func) def wrapper(types, args=(), kwargs=None, pg=None): @@ -55,6 +58,7 @@ def wrapper(types, args=(), kwargs=None, pg=None): return decorator_sharded_func + def _register_sharded_op_on_local_shards( op, early_stop_func=None, extra_check=None, customized_func=None ): @@ -84,6 +88,7 @@ def _register_sharded_op_on_local_shards( func (Callable): registered implementation for sharded op for ``__torch_function__`` dispatch. """ + @_sharded_op_impl(op) @_sharded_op_common(op, early_stop_func, extra_check) def sharded_tensor_op_on_local_shards(types, args=(), kwargs=None, pg=None): @@ -104,5 +109,5 @@ def sharded_tensor_op_on_local_shards(types, args=(), kwargs=None, pg=None): st_metadata, process_group=pg, init_rrefs=st._init_rrefs, - sharding_spec=st.sharding_spec() + sharding_spec=st.sharding_spec(), ) diff --git a/torch/distributed/_shard/sharded_tensor/_ops/binary_cmp.py b/torch/distributed/_shard/sharded_tensor/_ops/binary_cmp.py index 034f9149816120..f8db8b6ebe96fd 100644 --- a/torch/distributed/_shard/sharded_tensor/_ops/binary_cmp.py +++ b/torch/distributed/_shard/sharded_tensor/_ops/binary_cmp.py @@ -2,10 +2,8 @@ import torch import torch.distributed as dist import torch.distributed.distributed_c10d as distributed_c10d -from torch.distributed._shard.sharded_tensor import ( - ShardedTensor, - _sharded_op_impl -) +from torch.distributed._shard.sharded_tensor import _sharded_op_impl, ShardedTensor + def _communicate_result(result, pg): # Gather results from all ranks. @@ -16,26 +14,35 @@ def _communicate_result(result, pg): dist.all_reduce(result_tensor, group=pg) - expected_result = torch.ones(1, device=torch.device(torch.cuda.current_device())) * dist.get_world_size(pg) + expected_result = torch.ones( + 1, device=torch.device(torch.cuda.current_device()) + ) * dist.get_world_size(pg) return torch.equal(result_tensor, expected_result) + def binary_cmp(cmp_fun, types, args, kwargs=None, process_group=None): if len(args) != 2: - raise ValueError(f'Expected two arguments for torch.{cmp_fun.__name__}') + raise ValueError(f"Expected two arguments for torch.{cmp_fun.__name__}") result = True st1 = args[0] st2 = args[1] if not (isinstance(st1, ShardedTensor) and isinstance(st2, ShardedTensor)): - raise TypeError(f'Both arguments to torch.{cmp_fun.__name__} need to be of type ShardedTensor') + raise TypeError( + f"Both arguments to torch.{cmp_fun.__name__} need to be of type ShardedTensor" + ) # Verify same PG if st1._process_group != st2._process_group: return False - if distributed_c10d._rank_not_in_group(st1._process_group) or distributed_c10d._rank_not_in_group(st2._process_group): - return distributed_c10d._rank_not_in_group(st1._process_group) == distributed_c10d._rank_not_in_group(st2._process_group) + if distributed_c10d._rank_not_in_group( + st1._process_group + ) or distributed_c10d._rank_not_in_group(st2._process_group): + return distributed_c10d._rank_not_in_group( + st1._process_group + ) == distributed_c10d._rank_not_in_group(st2._process_group) # Verify metadata if st1.metadata() != st2.metadata(): @@ -54,16 +61,19 @@ def binary_cmp(cmp_fun, types, args, kwargs=None, process_group=None): for idx in range(len(st1_local_shards)): if st1_local_shards[idx].metadata != st2_local_shards[idx].metadata: return _communicate_result(False, st1._process_group) - if not cmp_fun(st1_local_shards[idx].tensor, st2_local_shards[idx].tensor, **kwargs): + if not cmp_fun( + st1_local_shards[idx].tensor, st2_local_shards[idx].tensor, **kwargs + ): return _communicate_result(False, st1._process_group) - return _communicate_result(True, st1._process_group) + @_sharded_op_impl(torch.equal) def equal(types, args, kwargs, process_group): return binary_cmp(torch.equal, types, args, kwargs, process_group) + @_sharded_op_impl(torch.allclose) def allclose(types, args, kwargs, process_group): return binary_cmp(torch.allclose, types, args, kwargs, process_group) diff --git a/torch/distributed/_shard/sharded_tensor/_ops/init.py b/torch/distributed/_shard/sharded_tensor/_ops/init.py index 736190d491e1e2..71a9c20b45352c 100644 --- a/torch/distributed/_shard/sharded_tensor/_ops/init.py +++ b/torch/distributed/_shard/sharded_tensor/_ops/init.py @@ -1,14 +1,14 @@ # mypy: allow-untyped-defs import torch import torch.distributed._shard.sharded_tensor as sharded_tensor -from torch.distributed._shard.sharded_tensor import ( - _sharded_op_impl, -) +from torch.distributed._shard.sharded_tensor import _sharded_op_impl + def validate_param(param, param_name): if param is None: raise ValueError(f"param: {param_name} shouldn't be None!") + @_sharded_op_impl(torch.nn.init.uniform_) def uniform_(types, args=(), kwargs=None, pg=None): r""" @@ -22,15 +22,16 @@ def uniform_(types, args=(), kwargs=None, pg=None): validate_param(kwargs, "kwargs") sharded_tensor = kwargs["tensor"] validate_param(sharded_tensor, "tensor") - a = kwargs['a'] + a = kwargs["a"] validate_param(a, "a") - b = kwargs['b'] + b = kwargs["b"] validate_param(b, "b") for shard in sharded_tensor.local_shards(): torch.nn.init.uniform_(shard.tensor, a=a, b=b) return sharded_tensor + @_sharded_op_impl(torch.nn.init.normal_) def normal_(types, args=(), kwargs=None, pg=None): r""" @@ -44,15 +45,16 @@ def normal_(types, args=(), kwargs=None, pg=None): validate_param(kwargs, "kwargs") sharded_tensor = kwargs["tensor"] validate_param(sharded_tensor, "tensor") - mean = kwargs['mean'] + mean = kwargs["mean"] validate_param(mean, "mean") - std = kwargs['std'] + std = kwargs["std"] validate_param(std, "std") for shard in sharded_tensor.local_shards(): torch.nn.init.normal_(shard.tensor, mean=mean, std=std) return sharded_tensor + @_sharded_op_impl(torch.nn.init.kaiming_uniform_) def kaiming_uniform_(types, args=(), kwargs=None, pg=None): r""" @@ -78,17 +80,20 @@ def kaiming_uniform_(types, args=(), kwargs=None, pg=None): validate_param(kwargs, "kwargs") sharded_tensor = kwargs["tensor"] validate_param(sharded_tensor, "tensor") - a = kwargs['a'] + a = kwargs["a"] validate_param(a, "a") - mode = kwargs['mode'] + mode = kwargs["mode"] validate_param(mode, "mode") - nonlinearity = kwargs['nonlinearity'] + nonlinearity = kwargs["nonlinearity"] validate_param(nonlinearity, "nonlinearity") for shard in sharded_tensor.local_shards(): - torch.nn.init.kaiming_uniform_(shard.tensor, a=a, mode=mode, nonlinearity=nonlinearity) + torch.nn.init.kaiming_uniform_( + shard.tensor, a=a, mode=mode, nonlinearity=nonlinearity + ) return sharded_tensor + @_sharded_op_impl(torch.nn.init.constant_) def constant_(types, args=(), kwargs=None, pg=None): r""" @@ -100,12 +105,13 @@ def constant_(types, args=(), kwargs=None, pg=None): validate_param(kwargs, "kwargs") sharded_tensor = kwargs["tensor"] validate_param(sharded_tensor, "tensor") - val = kwargs['val'] + val = kwargs["val"] validate_param(val, "val") for shard in sharded_tensor.local_shards(): torch.nn.init.constant_(shard.tensor, val=val) return sharded_tensor + tensor_like_creation_op_map = { torch.full_like: sharded_tensor.full, torch.empty_like: sharded_tensor.empty, @@ -115,6 +121,7 @@ def constant_(types, args=(), kwargs=None, pg=None): torch.randn_like: sharded_tensor.randn, } + # tensor ops that behave the same as the default tensor def register_tensor_creation_op(op): @_sharded_op_impl(op) diff --git a/torch/distributed/_shard/sharded_tensor/_ops/misc_ops.py b/torch/distributed/_shard/sharded_tensor/_ops/misc_ops.py index 82737f82de5339..8b84c1684c3245 100644 --- a/torch/distributed/_shard/sharded_tensor/_ops/misc_ops.py +++ b/torch/distributed/_shard/sharded_tensor/_ops/misc_ops.py @@ -1,8 +1,7 @@ # mypy: allow-untyped-defs import torch -from torch.distributed._shard.sharded_tensor import ( - _sharded_op_impl, -) +from torch.distributed._shard.sharded_tensor import _sharded_op_impl + # This is used by `_apply()` within module.py to set new # parameters after apply a certain method, we should follow diff --git a/torch/distributed/_shard/sharded_tensor/_ops/tensor_ops.py b/torch/distributed/_shard/sharded_tensor/_ops/tensor_ops.py index 7de78bf61f3f1b..93902d6f314c51 100644 --- a/torch/distributed/_shard/sharded_tensor/_ops/tensor_ops.py +++ b/torch/distributed/_shard/sharded_tensor/_ops/tensor_ops.py @@ -1,15 +1,15 @@ # mypy: allow-untyped-defs import copy + import torch +from torch.distributed._shard.common_op_utils import _register_default_op from torch.distributed._shard.sharded_tensor import ( _sharded_op_impl, Shard, ShardedTensor, ) -from ._common import ( - _register_sharded_op_on_local_shards, -) -from torch.distributed._shard.common_op_utils import _register_default_op + +from ._common import _register_sharded_op_on_local_shards # Tensor properties access @@ -33,6 +33,7 @@ _register_default_op(torch.Tensor.grad_fn.__get__, _sharded_op_impl) # type: ignore[union-attr] _register_default_op(torch.Tensor.is_leaf.__get__, _sharded_op_impl) # type: ignore[attr-defined] + # device property is ambiguous as from a global prospective, # ShardedTensor.device consists of multiple devices (might even across hosts) # We choose to return the current device of the local tensor to represent @@ -52,6 +53,7 @@ def tensor_device(types, args=(), kwargs=None, pg=None): dev = torch.device(torch.cuda.current_device()) return dev + @_sharded_op_impl(torch.Tensor.is_meta.__get__) # type: ignore[attr-defined] def st_is_meta(types, args=(), kwargs=None, pg=None): return args[0].local_tensor().is_meta diff --git a/torch/distributed/_shard/sharded_tensor/api.py b/torch/distributed/_shard/sharded_tensor/api.py index bf5db21b9a16c9..68df582cd5145e 100644 --- a/torch/distributed/_shard/sharded_tensor/api.py +++ b/torch/distributed/_shard/sharded_tensor/api.py @@ -1,57 +1,48 @@ # mypy: allow-untyped-defs from __future__ import annotations # type: ignore[attr-defined] -from dataclasses import dataclass -from typing import ( - Callable, - Dict, - List, - Optional, - Sequence, - Tuple, - cast, - TYPE_CHECKING, -) -from typing_extensions import deprecated + import copy +import operator +import threading import warnings -from functools import reduce import weakref +from dataclasses import dataclass +from functools import reduce +from typing import Callable, cast, Dict, List, Optional, Sequence, Tuple, TYPE_CHECKING +from typing_extensions import deprecated -import threading import torch import torch.distributed as dist -from torch.distributed import rpc -from torch.distributed import distributed_c10d import torch.distributed._shard.sharding_spec as shard_spec -from torch.distributed._shard.sharding_spec.api import ( - _dispatch_custom_op, - _has_custom_op, -) +from torch.distributed import distributed_c10d, rpc +from torch.distributed._shard._utils import DEPRECATE_MSG from torch.distributed._shard.sharding_spec._internals import ( check_tensor, validate_non_overlapping_shards_metadata, ) -from torch.distributed._shard._utils import ( - DEPRECATE_MSG, +from torch.distributed._shard.sharding_spec.api import ( + _dispatch_custom_op, + _has_custom_op, ) +from torch.distributed.remote_device import _remote_device +from torch.utils import _pytree as pytree -from .metadata import TensorProperties, ShardedTensorMetadata +from .metadata import ShardedTensorMetadata, TensorProperties +from .reshard import reshard_local_shard, reshuffle_local_shard from .shard import Shard -from .reshard import reshuffle_local_shard, reshard_local_shard from .utils import ( _flatten_tensor_size, _parse_and_validate_remote_device, _validate_output_tensor_for_gather, + build_global_metadata, build_metadata_from_local_shards, - build_global_metadata ) -from torch.distributed.remote_device import _remote_device -from torch.utils import _pytree as pytree -import operator + if TYPE_CHECKING: from torch.distributed._shard.metadata import ShardMetadata + # Tracking for sharded tensor objects. _sharded_tensor_lock = threading.Lock() _sharded_tensor_current_id = 0 @@ -63,18 +54,23 @@ # Customized user ops _CUSTOM_SHARDED_OPS: Dict[Callable, Callable] = {} -def _register_remote_shards(sharded_tensor_id: int, rrefs: List[rpc.RRef[Shard]], rpc_rank: int): + +def _register_remote_shards( + sharded_tensor_id: int, rrefs: List[rpc.RRef[Shard]], rpc_rank: int +): with _sharded_tensor_lock: if sharded_tensor_id not in _sharded_tensor_map: raise RuntimeError( - f'Could not find sharded_tensor_id: {sharded_tensor_id} in map: {_sharded_tensor_map.keys()}') + f"Could not find sharded_tensor_id: {sharded_tensor_id} in map: {_sharded_tensor_map.keys()}" + ) sharded_tensor = _sharded_tensor_map[sharded_tensor_id]() if sharded_tensor is None: - raise RuntimeError('ShardedTensor weakref has been deallocated') + raise RuntimeError("ShardedTensor weakref has been deallocated") else: sharded_tensor._register_remote_shards(rrefs, rpc_rank) + class ShardedTensorBase(torch.Tensor): _sharding_spec: shard_spec.ShardingSpec _metadata: ShardedTensorMetadata @@ -191,6 +187,7 @@ def __torch_dispatch__(cls, func, types, args=(), kwargs=None): "but the there is no custom __torch_dispatch__ implementation for it." ) + class ShardedTensor(ShardedTensorBase): """ ShardedTensor is an torch.Tensor subclass to represent Tensors that are sharded @@ -239,6 +236,7 @@ class ShardedTensor(ShardedTensorBase): individual GPU, via ``torch.cuda.set_device()`` """ + def __new__(cls, sharding_spec: shard_spec.ShardingSpec, *size, **kwargs): self = super().__new__(cls, sharding_spec, *size, **kwargs) return self @@ -260,22 +258,26 @@ def __init__( self._prepare_init(process_group=process_group, init_rrefs=init_rrefs) if layout != torch.strided: - raise ValueError('Only torch.strided layout is currently supported') + raise ValueError("Only torch.strided layout is currently supported") if memory_format != torch.contiguous_format: - raise ValueError('Only torch.contiguous_format memory_format is currently supported') + raise ValueError( + "Only torch.contiguous_format memory_format is currently supported" + ) self._metadata.tensor_properties.memory_format = memory_format current_rank = dist.get_rank() # global rank for shard_metadata in self._metadata.shards_metadata: - rank, device = _parse_and_validate_remote_device(self._process_group, shard_metadata.placement) + rank, device = _parse_and_validate_remote_device( + self._process_group, shard_metadata.placement + ) if rank == current_rank: local_tensor = _create_tensor_from_params( shard_metadata.shard_sizes, local_device=device, - tensor_properties=self._metadata.tensor_properties + tensor_properties=self._metadata.tensor_properties, ) self._local_shards.append(Shard(local_tensor, shard_metadata)) @@ -300,8 +302,9 @@ def _post_init(self): if not rpc._is_current_rpc_agent_set(): raise RuntimeError( - 'RPC Framework needs to be initialized using' - ' torch.distributed.rpc.init_rpc if init_rrefs is set to True') + "RPC Framework needs to be initialized using" + " torch.distributed.rpc.init_rpc if init_rrefs is set to True" + ) self._init_rpc() def __del__(self): @@ -320,9 +323,9 @@ def _init_rpc(self): rpc_rank = rpc.get_worker_info().id if pg_rank != rpc_rank: raise ValueError( - f'Default ProcessGroup and RPC ranks must be ' - f'the same for ShardedTensor, found process group rank: ' - f'{pg_rank} and RPC rank: {rpc_rank}' + f"Default ProcessGroup and RPC ranks must be " + f"the same for ShardedTensor, found process group rank: " + f"{pg_rank} and RPC rank: {rpc_rank}" ) self._remote_shards = {} @@ -347,11 +350,14 @@ def _init_rpc(self): continue if len(self.local_shards()) != 0: - rrefs: List[rpc.RRef[Shard]] = [rpc.RRef(shard) for shard in self.local_shards()] + rrefs: List[rpc.RRef[Shard]] = [ + rpc.RRef(shard) for shard in self.local_shards() + ] fut = rpc.rpc_async( rank, _register_remote_shards, - args=(all_tensor_ids[rank_to_name[rank]], rrefs, rpc_rank)) + args=(all_tensor_ids[rank_to_name[rank]], rrefs, rpc_rank), + ) futs.append(fut) torch.futures.wait_all(futs) @@ -394,6 +400,7 @@ def gather( # type: ignore[override] dtype (torch.dtype): Force the gathered tensors to be this dtype. Default: ``None`` """ + def shard_size(shard_md): return reduce(operator.mul, shard_md.shard_sizes) # type: ignore[attr-defined] @@ -429,7 +436,10 @@ def shard_size(shard_md): # enforce_dtype is deprecated. Do it for backward compatibility. dtype = out.dtype # TODO make it as a view of out tensor - gather_list = [torch.empty((max_rank_size,), device=out.device, dtype=dtype) for _ in range(world_size)] + gather_list = [ + torch.empty((max_rank_size,), device=out.device, dtype=dtype) + for _ in range(world_size) + ] else: gather_list = None @@ -437,15 +447,19 @@ def shard_size(shard_md): if enforce_dtype and len(local_shards) > 0: # enforce_dtype is deprecated. Do it for backward compatibility. dtype = local_shards[0].tensor.dtype - data = torch.empty(max_rank_size, device=self._get_preferred_device(), dtype=dtype) + data = torch.empty( + max_rank_size, device=self._get_preferred_device(), dtype=dtype + ) for shard in local_shards: src = shard.tensor.flatten() - if src.nelement() == 0 : - warnings.warn("Gathering a tensor with zero elements on rank " + str(rank)) + if src.nelement() == 0: + warnings.warn( + "Gathering a tensor with zero elements on rank " + str(rank) + ) return shard_offset = shard_placement[shard.metadata][1] - data[shard_offset: shard_offset + src.numel()].copy_(src) + data[shard_offset : shard_offset + src.numel()].copy_(src) dist.gather( tensor=data, @@ -478,9 +492,7 @@ def shard_size(shard_md): out_narrow_view.copy_(tensor) def cpu( - self, - memory_format=torch.preserve_format, - process_group=None + self, memory_format=torch.preserve_format, process_group=None ) -> ShardedTensor: """ Returns a copy of this object in CPU memory. @@ -495,13 +507,17 @@ def cpu( """ # TODO: make this a __torch_function__ op once ShardedTensor becomes a # torch.Tensor subclass, see https://github.com/pytorch/pytorch/issues/75402 - if memory_format != torch.preserve_format and \ - memory_format != torch.contiguous_format: - raise RuntimeError("Only `torch.contiguous_format` or " - "`torch.preserve_format` is supported!") + if ( + memory_format != torch.preserve_format + and memory_format != torch.contiguous_format + ): + raise RuntimeError( + "Only `torch.contiguous_format` or " + "`torch.preserve_format` is supported!" + ) all_on_cpu = True for meta in self.metadata().shards_metadata: - all_on_cpu &= (meta.placement.device().type == "cpu") # type: ignore[union-attr] + all_on_cpu &= meta.placement.device().type == "cpu" # type: ignore[union-attr] # if every shard is already on CPU, return the original object if all_on_cpu: @@ -514,9 +530,7 @@ def cpu( cpu_tensor = shard.tensor.cpu(memory_format=memory_format) # type: ignore[call-arg] metadata = copy.deepcopy(shard.metadata) metadata.placement._device = torch.device("cpu") # type: ignore[union-attr] - list_shards.append( - Shard(cpu_tensor, metadata) - ) + list_shards.append(Shard(cpu_tensor, metadata)) st_meta = copy.deepcopy(self.metadata()) for meta in st_meta.shards_metadata: @@ -528,7 +542,7 @@ def cpu( list_shards, sharded_tensor_metadata=st_meta, process_group=pg, - init_rrefs=self._init_rrefs + init_rrefs=self._init_rrefs, ) return st_cpu @@ -537,7 +551,7 @@ def cuda( device=None, non_blocking=False, memory_format=torch.preserve_format, - process_group=None + process_group=None, ) -> ShardedTensor: """ Returns a copy of this object in CUDA memory, if the original ShardedTensor @@ -551,15 +565,21 @@ def cuda( it is the user's responsiblity to explicitly pass in a new process_group that is compatible with GPU. """ - if memory_format != torch.preserve_format and \ - memory_format != torch.contiguous_format: - raise RuntimeError("Only `torch.contiguous_format` or " - "`torch.preserve_format` is supported!") + if ( + memory_format != torch.preserve_format + and memory_format != torch.contiguous_format + ): + raise RuntimeError( + "Only `torch.contiguous_format` or " + "`torch.preserve_format` is supported!" + ) if device is not None: device = torch.device(device) if isinstance(device, str) else device - assert isinstance(device, torch.device) and device.index == torch.cuda.current_device(), \ - '''Only device without device id (e.g. "cpu" or "cuda") is expected for ShardedTensor!''' + assert ( + isinstance(device, torch.device) + and device.index == torch.cuda.current_device() + ), """Only device without device id (e.g. "cpu" or "cuda") is expected for ShardedTensor!""" current_device = torch.device(torch.cuda.current_device()) # returns a copy of ShardedTensor on CUDA current device @@ -571,14 +591,12 @@ def cuda( cuda_tensor = shard.tensor.cuda( device=current_device, non_blocking=non_blocking, - memory_format=memory_format + memory_format=memory_format, ) # type: ignore[call-arg] metadata = copy.deepcopy(shard.metadata) metadata.placement._device = current_device # type: ignore[union-attr] - list_shards.append( - Shard(cuda_tensor, metadata) - ) + list_shards.append(Shard(cuda_tensor, metadata)) st_meta = copy.deepcopy(self.metadata()) for meta in st_meta.shards_metadata: @@ -592,7 +610,7 @@ def cuda( list_shards, sharded_tensor_metadata=st_meta, process_group=pg, - init_rrefs=self._init_rrefs + init_rrefs=self._init_rrefs, ) return st_cuda @@ -625,15 +643,19 @@ def to(self, *args, **kwargs) -> ShardedTensor: dtype_to = kwargs.get("dtype", current_dtype) device_to = kwargs.get("device", current_device) - device_to = torch.device(device_to) if isinstance(device_to, (str, int)) else device_to + device_to = ( + torch.device(device_to) if isinstance(device_to, (str, int)) else device_to + ) if device_to.type == "cuda": # if device_to set to cuda, set to current device even # if user specify the device index. current_idx = torch.cuda.current_device() if device_to.index != current_idx: - warnings.warn("ShardedTensor.to only move tensor to its current device" - "If you want to put to different device, use `reshard` instead.") + warnings.warn( + "ShardedTensor.to only move tensor to its current device" + "If you want to put to different device, use `reshard` instead." + ) device_to = torch.device(current_idx) copy_tensor = kwargs.get("copy", False) @@ -641,7 +663,11 @@ def to(self, *args, **kwargs) -> ShardedTensor: memory_format = kwargs.get("memory_format", torch.preserve_format) process_group = kwargs.get("process_group", None) - if not copy_tensor and dtype_to == current_dtype and device_to == current_device: + if ( + not copy_tensor + and dtype_to == current_dtype + and device_to == current_device + ): # already have correct dtype and device, return itself return self @@ -654,7 +680,7 @@ def to(self, *args, **kwargs) -> ShardedTensor: dtype=dtype_to, non_blocking=non_blocking, copy=copy_tensor, - memory_format=memory_format + memory_format=memory_format, ) metadata = copy.deepcopy(shard.metadata) if metadata.placement is not None: @@ -674,12 +700,14 @@ def to(self, *args, **kwargs) -> ShardedTensor: list_shards, sharded_tensor_metadata=st_meta, process_group=pg, - init_rrefs=self._init_rrefs + init_rrefs=self._init_rrefs, ) return st_to @classmethod - def _normalize_pg(cls, process_group: Optional[dist.ProcessGroup]) -> dist.ProcessGroup: + def _normalize_pg( + cls, process_group: Optional[dist.ProcessGroup] + ) -> dist.ProcessGroup: if process_group is not None: return process_group return distributed_c10d._get_default_group() @@ -701,8 +729,9 @@ def _init_from_local_shards( global_tensor_size = _flatten_tensor_size(global_size) if len(local_shards) > 0: - local_sharded_tensor_metadata = \ - build_metadata_from_local_shards(local_shards, global_tensor_size, current_rank, process_group) + local_sharded_tensor_metadata = build_metadata_from_local_shards( + local_shards, global_tensor_size, current_rank, process_group + ) # STEP 2. Validate metadata across ranks, and build a global sharded tensor # metadata by gathering local ShardedTensorMetadata @@ -711,9 +740,7 @@ def _init_from_local_shards( gathered_metadatas = [None for _ in range(world_size)] dist.all_gather_object( - gathered_metadatas, - local_sharded_tensor_metadata, - group=process_group + gathered_metadatas, local_sharded_tensor_metadata, group=process_group ) else: gathered_metadatas = [local_sharded_tensor_metadata] @@ -726,13 +753,15 @@ def _init_from_local_shards( spec = shard_spec._infer_sharding_spec_from_shards_metadata( global_sharded_tensor_metadata.shards_metadata ) - sharded_tensor = cls.__new__(cls, - spec, - global_sharded_tensor_metadata.size, - dtype=tensor_properties.dtype, - layout=tensor_properties.layout, - pin_memory=tensor_properties.pin_memory, - requires_grad=tensor_properties.requires_grad) + sharded_tensor = cls.__new__( + cls, + spec, + global_sharded_tensor_metadata.size, + dtype=tensor_properties.dtype, + layout=tensor_properties.layout, + pin_memory=tensor_properties.pin_memory, + requires_grad=tensor_properties.requires_grad, + ) sharded_tensor._prepare_init(process_group=process_group, init_rrefs=init_rrefs) # attach local_shards to the ShardedTensor created @@ -809,7 +838,7 @@ def _init_from_local_tensor( sharding spec. """ if not local_tensor.is_contiguous(): - raise ValueError('local_tensor is not a contiguous Tensor.') + raise ValueError("local_tensor is not a contiguous Tensor.") global_tensor_size = _flatten_tensor_size(global_size) tensor_properties = TensorProperties( @@ -817,10 +846,10 @@ def _init_from_local_tensor( layout=local_tensor.layout, requires_grad=local_tensor.requires_grad, memory_format=torch.contiguous_format, - pin_memory=local_tensor.is_pinned()) + pin_memory=local_tensor.is_pinned(), + ) sharded_tensor_metadata = sharding_spec.build_metadata( - global_tensor_size, - tensor_properties + global_tensor_size, tensor_properties ) process_group = cls._normalize_pg(process_group) @@ -828,7 +857,9 @@ def _init_from_local_tensor( local_shards: List[Shard] = [] for shard_metadata in sharded_tensor_metadata.shards_metadata: - rank, device = _parse_and_validate_remote_device(process_group, shard_metadata.placement) + rank, device = _parse_and_validate_remote_device( + process_group, shard_metadata.placement + ) if rank == current_rank: local_shards.append(Shard(local_tensor, shard_metadata)) @@ -868,16 +899,18 @@ def _init_from_local_shards_and_global_metadata( # type: ignore[override] # collect local shard metadatas from the global sharded_tensor_metadata for shard_metadata in shards_metadata: # type: ignore[attr-defined] - rank, local_device = _parse_and_validate_remote_device(process_group, shard_metadata.placement) + rank, local_device = _parse_and_validate_remote_device( + process_group, shard_metadata.placement + ) if current_rank == rank: local_shard_metadatas.append(shard_metadata) if len(local_shards) != len(local_shard_metadatas): raise RuntimeError( - f'Number of local shards ({len(local_shards)}) does not match number of local ' - f'shards metadata in sharded_tensor_metadata ({len(local_shard_metadatas)}) ' - f'on rank ({current_rank}) ' + f"Number of local shards ({len(local_shards)}) does not match number of local " + f"shards metadata in sharded_tensor_metadata ({len(local_shard_metadatas)}) " + f"on rank ({current_rank}) " ) shards_metadata = sharded_tensor_metadata.shards_metadata @@ -1056,12 +1089,11 @@ def reshard(self, resharding_spec: shard_spec.ShardingSpec) -> ShardedTensor: tensor([[3], [3], [5], [5], [7], [7], [9], [9]]) # Rank 2 tensor([[4], [4], [6], [6], [8], [8], [10], [10]]) # Rank 3 """ - if ( - not isinstance(resharding_spec, shard_spec.ChunkShardingSpec) or - not isinstance(self._sharding_spec, shard_spec.ChunkShardingSpec) - ): + if not isinstance( + resharding_spec, shard_spec.ChunkShardingSpec + ) or not isinstance(self._sharding_spec, shard_spec.ChunkShardingSpec): raise NotImplementedError("Only ChunkShardingSpec supported for reshard.") - if (len(self.local_shards()) != 1): + if len(self.local_shards()) != 1: raise NotImplementedError("Only single local shard supported for reshard.") if self._sharding_spec.dim == resharding_spec.dim: # type: ignore[attr-defined] @@ -1110,12 +1142,7 @@ def dispatch(st: ShardedTensor, func: Callable): # Dispatch to custom sharding spec op if it has one. if _has_custom_op(st._sharding_spec, func): return _dispatch_custom_op( - st._sharding_spec, - func, - types, - args, - kwargs, - st._process_group + st._sharding_spec, func, types, args, kwargs, st._process_group ) if func in _SHARDED_OPS: @@ -1123,7 +1150,8 @@ def dispatch(st: ShardedTensor, func: Callable): raise RuntimeError( f"torch function '{func.__name__}', with args: {args} and " - f"kwargs: {kwargs} not supported for ShardedTensor!") + f"kwargs: {kwargs} not supported for ShardedTensor!" + ) # Find ShardedTensor instance to get process_group and sharding_spec. st_instance = None @@ -1141,7 +1169,8 @@ def find_sharded_tensor(e): raise RuntimeError( f"torch function '{func.__name__}', with args: {args} and " - f"kwargs: {kwargs} not supported for ShardedTensor!") + f"kwargs: {kwargs} not supported for ShardedTensor!" + ) def is_pinned(self) -> bool: # type: ignore[override] """ @@ -1149,7 +1178,9 @@ def is_pinned(self) -> bool: # type: ignore[override] """ return self._metadata.tensor_properties.pin_memory - def _register_remote_shards(self, remote_shards: List[rpc.RRef[Shard]], rpc_rank: int): + def _register_remote_shards( + self, remote_shards: List[rpc.RRef[Shard]], rpc_rank: int + ): self._remote_shards[rpc_rank] = remote_shards def remote_shards(self) -> Dict[int, List[rpc.RRef[Shard]]]: @@ -1162,7 +1193,7 @@ def remote_shards(self) -> Dict[int, List[rpc.RRef[Shard]]]: """ if not self._init_rrefs: raise RuntimeError( - 'ShardedTensor created with init_rrefs=False, no RRefs to remote shards available' + "ShardedTensor created with init_rrefs=False, no RRefs to remote shards available" ) return self._remote_shards @@ -1170,13 +1201,14 @@ def __hash__(self): return id(self) def __repr__(self): - return f'ShardedTensor({self._metadata})' + return f"ShardedTensor({self._metadata})" @dataclass class ProcessGroupState: """ State for ser-de of process group """ + local_rank: int global_rank: int local_world_size: int @@ -1190,51 +1222,71 @@ def __getstate__(self): distributed_c10d.get_world_size(), ) - return self._local_shards, self._metadata, pg_state, self._sharding_spec, self._init_rrefs + return ( + self._local_shards, + self._metadata, + pg_state, + self._sharding_spec, + self._init_rrefs, + ) def __setstate__(self, state): self._sharded_tensor_id = None if not distributed_c10d.is_initialized(): raise RuntimeError( - 'Need to initialize default process group using ' - '"init_process_group" before loading ShardedTensor') + "Need to initialize default process group using " + '"init_process_group" before loading ShardedTensor' + ) - self._local_shards, self._metadata, pg_state, self._sharding_spec, self._init_rrefs = state + ( + self._local_shards, + self._metadata, + pg_state, + self._sharding_spec, + self._init_rrefs, + ) = state # Setup process group from torch.distributed._shard.api import _get_current_process_group + self._process_group = _get_current_process_group() # Validate process group. local_rank = distributed_c10d.get_rank(self._process_group) if pg_state.local_rank != local_rank: raise RuntimeError( - f'Local rank at save time was {pg_state.local_rank}, but at ' - f'load time was {local_rank}') + f"Local rank at save time was {pg_state.local_rank}, but at " + f"load time was {local_rank}" + ) global_rank = distributed_c10d.get_rank() if pg_state.global_rank != global_rank: raise RuntimeError( - f'Global rank at save time was {pg_state.global_rank}, but at ' - f'load time was {global_rank}') + f"Global rank at save time was {pg_state.global_rank}, but at " + f"load time was {global_rank}" + ) local_world_size = distributed_c10d.get_world_size(self._process_group) if pg_state.local_world_size != local_world_size: raise RuntimeError( - f'Local world size at save time was {pg_state.local_world_size}, ' - f'but at load time was {local_world_size}') + f"Local world size at save time was {pg_state.local_world_size}, " + f"but at load time was {local_world_size}" + ) global_world_size = distributed_c10d.get_world_size() if pg_state.global_world_size != global_world_size: raise RuntimeError( - f'Global world size at save time was {pg_state.global_world_size}, ' - f'but at load time was {global_world_size}') + f"Global world size at save time was {pg_state.global_world_size}, " + f"but at load time was {global_world_size}" + ) self._post_init() -def _create_tensor_from_params(*size, local_device, tensor_properties: TensorProperties): - """ Helper to construct tensor from size, device and common params. """ +def _create_tensor_from_params( + *size, local_device, tensor_properties: TensorProperties +): + """Helper to construct tensor from size, device and common params.""" dtype = tensor_properties.dtype layout = tensor_properties.layout requires_grad = tensor_properties.requires_grad @@ -1242,7 +1294,11 @@ def _create_tensor_from_params(*size, local_device, tensor_properties: TensorPro pin_memory = tensor_properties.pin_memory return torch.empty( - *size, dtype=dtype, layout=layout, - device=local_device, requires_grad=requires_grad, - memory_format=memory_format, pin_memory=pin_memory + *size, + dtype=dtype, + layout=layout, + device=local_device, + requires_grad=requires_grad, + memory_format=memory_format, + pin_memory=pin_memory, ) diff --git a/torch/distributed/_shard/sharded_tensor/logger.py b/torch/distributed/_shard/sharded_tensor/logger.py index 87cb74fbd01d20..ebb749dc7d5c73 100644 --- a/torch/distributed/_shard/sharded_tensor/logger.py +++ b/torch/distributed/_shard/sharded_tensor/logger.py @@ -9,9 +9,8 @@ import logging from typing import List, Tuple -from torch.distributed._shard.sharded_tensor.logging_handlers import ( - _log_handlers, -) +from torch.distributed._shard.sharded_tensor.logging_handlers import _log_handlers + __all__: List[str] = [] diff --git a/torch/distributed/_shard/sharded_tensor/logging_handlers.py b/torch/distributed/_shard/sharded_tensor/logging_handlers.py index 3c607fe45da771..021ad100f06a89 100644 --- a/torch/distributed/_shard/sharded_tensor/logging_handlers.py +++ b/torch/distributed/_shard/sharded_tensor/logging_handlers.py @@ -9,6 +9,7 @@ import logging from typing import Dict, List + __all__: List[str] = [] _log_handlers: Dict[str, logging.Handler] = { diff --git a/torch/distributed/_shard/sharded_tensor/metadata.py b/torch/distributed/_shard/sharded_tensor/metadata.py index 8b3257240e3834..e53ac25fa55d9c 100644 --- a/torch/distributed/_shard/sharded_tensor/metadata.py +++ b/torch/distributed/_shard/sharded_tensor/metadata.py @@ -6,14 +6,16 @@ import torch from torch.distributed._shard.metadata import ShardMetadata + class MEM_FORMAT_ENCODING(Enum): TORCH_CONTIGUOUS_FORMAT = 0 TORCH_CHANNELS_LAST = 1 TORCH_PRESERVE_FORMAT = 2 + @dataclass class TensorProperties: - """ Properties used to create :class:`Tensor` """ + """Properties used to create :class:`Tensor`""" # Regular tensor fields dtype: torch.dtype = field(default=torch.get_default_dtype()) @@ -32,7 +34,7 @@ def __getstate__(self): elif memory_format == torch.preserve_format: mem_format_encoding = MEM_FORMAT_ENCODING.TORCH_PRESERVE_FORMAT else: - raise RuntimeError(f'Invalid torch.memory_format: {memory_format}') + raise RuntimeError(f"Invalid torch.memory_format: {memory_format}") return ( self.dtype, @@ -46,7 +48,13 @@ def __setstate__( self, state, ): - (self.dtype, self.layout, self.requires_grad, mem_format_encoding, self.pin_memory) = state + ( + self.dtype, + self.layout, + self.requires_grad, + mem_format_encoding, + self.pin_memory, + ) = state if mem_format_encoding == MEM_FORMAT_ENCODING.TORCH_CONTIGUOUS_FORMAT: memory_format = torch.contiguous_format @@ -55,7 +63,9 @@ def __setstate__( elif mem_format_encoding == MEM_FORMAT_ENCODING.TORCH_PRESERVE_FORMAT: memory_format = torch.preserve_format else: - raise RuntimeError(f'Invalid torch.memory_format encoding: {mem_format_encoding}') + raise RuntimeError( + f"Invalid torch.memory_format encoding: {mem_format_encoding}" + ) self.memory_format = memory_format @@ -66,8 +76,10 @@ def create_from_tensor(tensor: torch.Tensor) -> "TensorProperties": layout=tensor.layout, requires_grad=tensor.requires_grad, memory_format=torch.contiguous_format, - pin_memory=tensor.is_pinned() + pin_memory=tensor.is_pinned(), ) + + @dataclass class ShardedTensorMetadata: """ diff --git a/torch/distributed/_shard/sharded_tensor/reshard.py b/torch/distributed/_shard/sharded_tensor/reshard.py index 549dde38cdf8a8..9a82012d59cd3f 100644 --- a/torch/distributed/_shard/sharded_tensor/reshard.py +++ b/torch/distributed/_shard/sharded_tensor/reshard.py @@ -4,19 +4,14 @@ import torch import torch.distributed as dist -from torch._C._distributed_c10d import ( - ProcessGroup, -) import torch.distributed._shard.sharding_spec as shard_spec +from torch._C._distributed_c10d import ProcessGroup +from torch.distributed._shard.metadata import ShardMetadata from torch.distributed._shard.sharding_spec._internals import ( - get_split_size, get_chunked_dim_size, + get_split_size, ) -from torch.distributed.nn.functional import ( - all_to_all, - all_to_all_single, -) -from torch.distributed._shard.metadata import ShardMetadata +from torch.distributed.nn.functional import all_to_all, all_to_all_single from .shard import Shard @@ -42,7 +37,7 @@ def get_idx_from_placements(placements, current_rank) -> int: for idx, placement in enumerate(placements): # type: ignore[attr-defined] if current_rank == placement.rank(): # type: ignore[union-attr] return idx - raise RuntimeError('current_rank not in the placement.') + raise RuntimeError("current_rank not in the placement.") def build_reshard_metadata( @@ -138,7 +133,9 @@ def reshuffle_local_shard( local_shard = local_shard.transpose(0, reshard_dim).contiguous() gathered_input_size = list(local_shard.size()) gathered_input_size[0] = sharded_dim_size - gathered_input = torch.empty(gathered_input_size, device=local_shard.device, dtype=local_shard.dtype) + gathered_input = torch.empty( + gathered_input_size, device=local_shard.device, dtype=local_shard.dtype + ) # all2all. local_shard = all_to_all_single( gathered_input, diff --git a/torch/distributed/_shard/sharded_tensor/shard.py b/torch/distributed/_shard/sharded_tensor/shard.py index ac1e881370e814..dcb6b3b5d62674 100644 --- a/torch/distributed/_shard/sharded_tensor/shard.py +++ b/torch/distributed/_shard/sharded_tensor/shard.py @@ -18,7 +18,8 @@ class Shard: metadata(:class `torch.distributed._shard.sharded_tensor.ShardMetadata`): The metadata for the shard, including offsets, lengths and device placement. """ - __slots__ = ['tensor', 'metadata'] + + __slots__ = ["tensor", "metadata"] tensor: torch.Tensor metadata: ShardMetadata @@ -31,7 +32,10 @@ def __post_init__(self): f"metadata.shard_lengths: {self.metadata.shard_sizes}, " ) placement_device = self.metadata.placement - if placement_device is not None and placement_device.device() != self.tensor.device: + if ( + placement_device is not None + and placement_device.device() != self.tensor.device + ): raise ValueError( f"Local shard tensor device does not match with local Shard's placement! " f"Found local shard tensor device: {self.tensor.device}, " @@ -39,7 +43,9 @@ def __post_init__(self): ) @classmethod - def from_tensor_and_offsets(cls, tensor: torch.Tensor, shard_offsets: List[int], rank: int): + def from_tensor_and_offsets( + cls, tensor: torch.Tensor, shard_offsets: List[int], rank: int + ): """ Creates a Shard of a ShardedTensor from a local torch.Tensor, shard_offsets and rank. @@ -52,8 +58,6 @@ def from_tensor_and_offsets(cls, tensor: torch.Tensor, shard_offsets: List[int], shard_sizes = list(tensor.size()) placement = _remote_device(f"rank:{rank}/{str(tensor.device)}") shard_meta = ShardMetadata( - shard_offsets=shard_offsets, - shard_sizes=shard_sizes, - placement=placement + shard_offsets=shard_offsets, shard_sizes=shard_sizes, placement=placement ) return Shard(tensor, shard_meta) diff --git a/torch/distributed/_shard/sharded_tensor/utils.py b/torch/distributed/_shard/sharded_tensor/utils.py index 782def0e4d4c2f..a6954813f82b27 100644 --- a/torch/distributed/_shard/sharded_tensor/utils.py +++ b/torch/distributed/_shard/sharded_tensor/utils.py @@ -1,22 +1,23 @@ # mypy: allow-untyped-defs import collections.abc import copy -from typing import Optional, List, Sequence, TYPE_CHECKING +from typing import List, Optional, Sequence, TYPE_CHECKING import torch -from torch.distributed import distributed_c10d as c10d -from torch.distributed import rpc +from torch.distributed import distributed_c10d as c10d, rpc from torch.distributed._shard.sharding_spec._internals import ( check_tensor, validate_non_overlapping_shards_metadata, ) -from .metadata import TensorProperties, ShardedTensorMetadata +from .metadata import ShardedTensorMetadata, TensorProperties from .shard import Shard + if TYPE_CHECKING: from torch.distributed._shard.metadata import ShardMetadata + def _parse_and_validate_remote_device(pg, remote_device): if remote_device is None: raise ValueError("remote device is None") @@ -48,6 +49,7 @@ def _parse_and_validate_remote_device(pg, remote_device): return rank, device + def _validate_output_tensor_for_gather( my_rank: int, dst_rank: int, @@ -66,10 +68,10 @@ def _validate_output_tensor_for_gather( ) elif dst_tensor: raise ValueError( - "Argument ``dst_tensor`` must NOT be specified " - "on non-destination ranks." + "Argument ``dst_tensor`` must NOT be specified " "on non-destination ranks." ) + def _flatten_tensor_size(size) -> torch.Size: """ Checks if tensor size is valid, then flatten/return a torch.Size object. @@ -81,33 +83,37 @@ def _flatten_tensor_size(size) -> torch.Size: for dim in dims: if not isinstance(dim, int): - raise TypeError(f'size has to be a sequence of ints, found: {dims}') + raise TypeError(f"size has to be a sequence of ints, found: {dims}") return torch.Size(dims) + def _raise_if_mismatch(expected, actual, prop_name, ranks, is_local=True): if is_local: assert isinstance(ranks, int) if expected != actual: - raise ValueError(f"Local shards' tensor {prop_name} property need to be the same on rank:{ranks}! " - f"Found one local shard tensor {prop_name}={expected}, " - f"the other local shard tensor {prop_name}={actual}.") + raise ValueError( + f"Local shards' tensor {prop_name} property need to be the same on rank:{ranks}! " + f"Found one local shard tensor {prop_name}={expected}, " + f"the other local shard tensor {prop_name}={actual}." + ) else: # compare failure check across ranks, ranks list should have two rank assert len(ranks) == 2 if expected != actual: - raise ValueError(f"ShardedTensor {prop_name} property does not match from different ranks! " - f"Found {prop_name}={expected} on rank:{ranks[0]}, " - f"and {prop_name}={actual} on rank:{ranks[1]}.") + raise ValueError( + f"ShardedTensor {prop_name} property does not match from different ranks! " + f"Found {prop_name}={expected} on rank:{ranks[0]}, " + f"and {prop_name}={actual} on rank:{ranks[1]}." + ) def build_metadata_from_local_shards( local_shards: List[Shard], global_size: torch.Size, current_rank: int, - pg: c10d.ProcessGroup + pg: c10d.ProcessGroup, ) -> ShardedTensorMetadata: - assert len(local_shards) > 0, "must have local shards!" local_shard_metadatas: List[ShardMetadata] = [] @@ -121,21 +127,28 @@ def build_metadata_from_local_shards( local_shard_tensor = local_shard.tensor local_shard_meta = local_shard.metadata local_shard_metadatas.append(local_shard_meta) - rank, local_device = _parse_and_validate_remote_device(pg, local_shard_meta.placement) + rank, local_device = _parse_and_validate_remote_device( + pg, local_shard_meta.placement + ) - if local_shard_tensor.layout != torch.strided or local_shard_tensor.layout != first_shard_layout: + if ( + local_shard_tensor.layout != torch.strided + or local_shard_tensor.layout != first_shard_layout + ): raise ValueError( - f'Only torch.strided layout is currently supported, but found ' - f'{local_shard_tensor.layout} on rank:{current_rank}!' + f"Only torch.strided layout is currently supported, but found " + f"{local_shard_tensor.layout} on rank:{current_rank}!" ) if not local_shard_tensor.is_contiguous(): - raise ValueError('Only torch.contiguous_format memory_format is currently supported!') + raise ValueError( + "Only torch.contiguous_format memory_format is currently supported!" + ) if rank != current_rank: raise ValueError( f"Local shard metadata's rank does not match with the rank in its process group! " - f'Found current rank in the process group: {current_rank}, ' + f"Found current rank in the process group: {current_rank}, " f"local ShardMetadata placement's rank: {rank}" ) if local_shard_tensor.device != local_device: @@ -145,10 +158,27 @@ def build_metadata_from_local_shards( f"local shard metadata placement device: {local_device}" ) - _raise_if_mismatch(local_shard_meta.shard_sizes, list(local_shard_tensor.size()), "size", current_rank) - _raise_if_mismatch(local_shard_tensor.is_pinned(), first_shard_is_pinned, "pin_memory", current_rank) - _raise_if_mismatch(local_shard_tensor.dtype, first_shard_dtype, "dtype", current_rank) - _raise_if_mismatch(local_shard_tensor.requires_grad, first_shard_requires_grad, "requires_grad", current_rank) + _raise_if_mismatch( + local_shard_meta.shard_sizes, + list(local_shard_tensor.size()), + "size", + current_rank, + ) + _raise_if_mismatch( + local_shard_tensor.is_pinned(), + first_shard_is_pinned, + "pin_memory", + current_rank, + ) + _raise_if_mismatch( + local_shard_tensor.dtype, first_shard_dtype, "dtype", current_rank + ) + _raise_if_mismatch( + local_shard_tensor.requires_grad, + first_shard_requires_grad, + "requires_grad", + current_rank, + ) # 2). Build a "local" ShardedTensorMetadata with all local shards on this rank, then # do all_gather to collect local_sharded_tensor_metadata from all ranks @@ -157,18 +187,21 @@ def build_metadata_from_local_shards( layout=first_shard_layout, requires_grad=first_shard_requires_grad, memory_format=torch.contiguous_format, - pin_memory=first_shard_is_pinned + pin_memory=first_shard_is_pinned, ) local_sharded_tensor_metadata = ShardedTensorMetadata( shards_metadata=local_shard_metadatas, size=global_size, - tensor_properties=local_tensor_properties) + tensor_properties=local_tensor_properties, + ) return local_sharded_tensor_metadata -def build_global_metadata(gathered_metadatas: Sequence[Optional[ShardedTensorMetadata]]): +def build_global_metadata( + gathered_metadatas: Sequence[Optional[ShardedTensorMetadata]], +): global_sharded_tensor_metadata = None global_metadata_rank = 0 @@ -180,39 +213,54 @@ def build_global_metadata(gathered_metadatas: Sequence[Optional[ShardedTensorMet global_sharded_tensor_metadata = copy.deepcopy(rank_metadata) global_metadata_rank = rank else: - _raise_if_mismatch(global_sharded_tensor_metadata.size, - rank_metadata.size, - "global_size", - [global_metadata_rank, rank], - is_local=False) + _raise_if_mismatch( + global_sharded_tensor_metadata.size, + rank_metadata.size, + "global_size", + [global_metadata_rank, rank], + is_local=False, + ) # don't need to check layout and memory format as we already checked in local shards validation stage - _raise_if_mismatch(global_sharded_tensor_metadata.tensor_properties.dtype, - rank_metadata.tensor_properties.dtype, - "dtype", - [global_metadata_rank, rank], - is_local=False) - - _raise_if_mismatch(global_sharded_tensor_metadata.tensor_properties.requires_grad, - rank_metadata.tensor_properties.requires_grad, - "requires_grad", - [global_metadata_rank, rank], - is_local=False) - - _raise_if_mismatch(global_sharded_tensor_metadata.tensor_properties.pin_memory, - rank_metadata.tensor_properties.pin_memory, - "pin_memory", - [global_metadata_rank, rank], - is_local=False) + _raise_if_mismatch( + global_sharded_tensor_metadata.tensor_properties.dtype, + rank_metadata.tensor_properties.dtype, + "dtype", + [global_metadata_rank, rank], + is_local=False, + ) + + _raise_if_mismatch( + global_sharded_tensor_metadata.tensor_properties.requires_grad, + rank_metadata.tensor_properties.requires_grad, + "requires_grad", + [global_metadata_rank, rank], + is_local=False, + ) + + _raise_if_mismatch( + global_sharded_tensor_metadata.tensor_properties.pin_memory, + rank_metadata.tensor_properties.pin_memory, + "pin_memory", + [global_metadata_rank, rank], + is_local=False, + ) # pass all validations, extend shards metadata - global_sharded_tensor_metadata.shards_metadata.extend(rank_metadata.shards_metadata) + global_sharded_tensor_metadata.shards_metadata.extend( + rank_metadata.shards_metadata + ) if global_sharded_tensor_metadata is not None: # check if shards_metadata have overlap shards - validate_non_overlapping_shards_metadata(global_sharded_tensor_metadata.shards_metadata) + validate_non_overlapping_shards_metadata( + global_sharded_tensor_metadata.shards_metadata + ) # check if the shards_metadata is compatible with global size of the sharded tensor. - check_tensor(global_sharded_tensor_metadata.shards_metadata, global_sharded_tensor_metadata.size) + check_tensor( + global_sharded_tensor_metadata.shards_metadata, + global_sharded_tensor_metadata.size, + ) else: raise ValueError("ShardedTensor have no local shards on all ranks!") diff --git a/torch/distributed/_shard/sharder.py b/torch/distributed/_shard/sharder.py index bf3b3596d1beea..6fbf6a2e5ff9d7 100644 --- a/torch/distributed/_shard/sharder.py +++ b/torch/distributed/_shard/sharder.py @@ -1,6 +1,8 @@ import abc + import torch.nn as nn + class Sharder(abc.ABC): """ This is an interface which allows user to create more advanced @@ -11,6 +13,7 @@ class Sharder(abc.ABC): take an object of the `Sharder` and call `shard` to shard the module, then replace the original module with sharded module returned. """ + @abc.abstractmethod def shard(self, module: nn.Module) -> nn.Module: """ diff --git a/torch/distributed/_shard/sharding_plan/__init__.py b/torch/distributed/_shard/sharding_plan/__init__.py index 269dfd8af76052..325f7d7eb47b96 100644 --- a/torch/distributed/_shard/sharding_plan/__init__.py +++ b/torch/distributed/_shard/sharding_plan/__init__.py @@ -1,4 +1 @@ -from .api import ( - ShardingPlan, - ShardingPlanner -) +from .api import ShardingPlan, ShardingPlanner diff --git a/torch/distributed/_shard/sharding_plan/api.py b/torch/distributed/_shard/sharding_plan/api.py index fa92bf70788876..a7552c5a68f88e 100644 --- a/torch/distributed/_shard/sharding_plan/api.py +++ b/torch/distributed/_shard/sharding_plan/api.py @@ -1,12 +1,12 @@ import abc -import torch.nn as nn - from dataclasses import dataclass from typing import Dict, List, Optional, Union +import torch.nn as nn from torch.distributed._shard.sharder import Sharder from torch.distributed._shard.sharding_spec import ShardingSpec + @dataclass class ShardingPlan: """ @@ -61,6 +61,7 @@ class ShardingPlan: >>> return_local_tensor=["fc2"] >>> ) """ + plan: Dict[str, Union[ShardingSpec, Sharder]] output_plan: Optional[Dict[str, ShardingSpec]] = None return_local_tensor: Optional[List[str]] = None @@ -71,6 +72,7 @@ class ShardingPlanner(abc.ABC): Default ShardingPlanner interface, can be extended and implement advanced sharding strategies. """ + @abc.abstractmethod def build_plan(self, module: nn.Module) -> ShardingPlan: """ diff --git a/torch/distributed/_shard/sharding_spec/__init__.py b/torch/distributed/_shard/sharding_spec/__init__.py index 8dd38105c53ba4..bfd3f0a7581e8c 100644 --- a/torch/distributed/_shard/sharding_spec/__init__.py +++ b/torch/distributed/_shard/sharding_spec/__init__.py @@ -1,12 +1,10 @@ +from torch.distributed._shard.metadata import ShardMetadata + from .api import ( + _infer_sharding_spec_from_shards_metadata, DevicePlacementSpec, EnumerableShardingSpec, PlacementSpec, ShardingSpec, - _infer_sharding_spec_from_shards_metadata, ) -from .chunk_sharding_spec import ( - ChunkShardingSpec as ChunkShardingSpec, -) - -from torch.distributed._shard.metadata import ShardMetadata +from .chunk_sharding_spec import ChunkShardingSpec as ChunkShardingSpec diff --git a/torch/distributed/_shard/sharding_spec/_internals.py b/torch/distributed/_shard/sharding_spec/_internals.py index 07d3c2e19bc00a..8a439c447eff07 100644 --- a/torch/distributed/_shard/sharding_spec/_internals.py +++ b/torch/distributed/_shard/sharding_spec/_internals.py @@ -86,8 +86,8 @@ def validate_non_overlapping_shards_metadata(shards: List[ShardMetadata]): for dim in range(len(shards[0].shard_offsets)): for i in range(1, len(shards)): if ( - shards[i].shard_offsets[dim] != shards[0].shard_offsets[dim] or - shards[i].shard_sizes[dim] != shards[0].shard_sizes[dim] + shards[i].shard_offsets[dim] != shards[0].shard_offsets[dim] + or shards[i].shard_sizes[dim] != shards[0].shard_sizes[dim] ): sharded_dims.append(dim) break @@ -108,7 +108,7 @@ def validate_non_overlapping_shards_metadata(shards: List[ShardMetadata]): pair = _find_nd_overlapping_shards(shards, sharded_dims) if pair: - raise ValueError(f'Shards {shards[pair[0]]} and {shards[pair[1]]} overlap') + raise ValueError(f"Shards {shards[pair[0]]} and {shards[pair[1]]} overlap") def check_tensor(shards_metadata, tensor_dims) -> None: @@ -130,7 +130,9 @@ def check_tensor(shards_metadata, tensor_dims) -> None: tensor_rank = len(tensor_dims) shards_rank = len(shards_metadata[0].shard_offsets) if tensor_rank != shards_rank: - raise ValueError(f'Rank of tensor is {tensor_rank}, but shards rank is {shards_rank}') + raise ValueError( + f"Rank of tensor is {tensor_rank}, but shards rank is {shards_rank}" + ) total_shard_volume = 0 for shard in shards_metadata: @@ -139,8 +141,9 @@ def check_tensor(shards_metadata, tensor_dims) -> None: shard_volume *= shard_length if shard.shard_offsets[i] + shard.shard_sizes[i] > tensor_dims[i]: raise ValueError( - f'Shard offset {shard.shard_offsets[i]} and length ' - f'{shard.shard_sizes[i]} exceeds tensor dim: {tensor_dims[i]} for shard {shard}') + f"Shard offset {shard.shard_offsets[i]} and length " + f"{shard.shard_sizes[i]} exceeds tensor dim: {tensor_dims[i]} for shard {shard}" + ) total_shard_volume += shard_volume tensor_volume = 1 @@ -150,9 +153,11 @@ def check_tensor(shards_metadata, tensor_dims) -> None: if total_shard_volume != tensor_volume: # TODO: Can we improve this error message to point out the gaps? raise ValueError( - f'Total volume of shards: {total_shard_volume} ' - f'does not match tensor volume: {tensor_volume}, in other words ' - f'all the individual shards do not cover the entire tensor') + f"Total volume of shards: {total_shard_volume} " + f"does not match tensor volume: {tensor_volume}, in other words " + f"all the individual shards do not cover the entire tensor" + ) + def get_split_size(dim_size, chunks): """ @@ -167,6 +172,7 @@ def get_split_size(dim_size, chunks): """ return (dim_size + chunks - 1) // chunks + def get_chunked_dim_size(dim_size, split_size, idx): """ Computes the dim size of the chunk for provided ``idx`` given ``dim_size`` @@ -182,6 +188,7 @@ def get_chunked_dim_size(dim_size, split_size, idx): """ return max(min(dim_size, split_size * (idx + 1)) - split_size * idx, 0) + def get_chunk_sharding_params(sharding_dim_size, world_size, spec, rank): """ Generate the start pos and offset length for the current rank for diff --git a/torch/distributed/_shard/sharding_spec/api.py b/torch/distributed/_shard/sharding_spec/api.py index 7493eccdf01581..e22e8b569e03ca 100644 --- a/torch/distributed/_shard/sharding_spec/api.py +++ b/torch/distributed/_shard/sharding_spec/api.py @@ -1,34 +1,36 @@ # mypy: allow-untyped-defs +import functools +import operator from abc import ABC, abstractmethod from dataclasses import dataclass -import functools from typing import Callable, Dict, List, TYPE_CHECKING import torch +import torch.distributed._shard.sharded_tensor.metadata as sharded_tensor_meta +from torch.distributed._shard.metadata import ShardMetadata +from torch.distributed._shard.op_registry_utils import _decorator_func from ._internals import ( check_tensor, get_chunked_dim_size, get_split_size, - validate_non_overlapping_shards_metadata + validate_non_overlapping_shards_metadata, ) -from torch.distributed._shard.metadata import ShardMetadata -import torch.distributed._shard.sharded_tensor.metadata as sharded_tensor_meta -from torch.distributed._shard.op_registry_utils import _decorator_func -import operator if TYPE_CHECKING: # Only include ShardedTensor when do type checking, exclude it # from run-time to resolve circular dependency. from torch.distributed._shard.sharded_tensor import ShardedTensor + class PlacementSpec(ABC): # noqa: B024 """ Base class representing the placement of an entity. Subclasses of this class can be used to specify customized placements which might not be covered by existing APIs. """ + pass @@ -47,15 +49,18 @@ def __post_init__(self): if not isinstance(self.device, torch.distributed._remote_device): self.device = torch.distributed._remote_device(self.device) + class ShardingSpec(ABC): """ Base class representing sharding specifications. """ + @abstractmethod - def build_metadata(self, - tensor_sizes: torch.Size, - tensor_properties: sharded_tensor_meta.TensorProperties, - ) -> sharded_tensor_meta.ShardedTensorMetadata: + def build_metadata( + self, + tensor_sizes: torch.Size, + tensor_properties: sharded_tensor_meta.TensorProperties, + ) -> sharded_tensor_meta.ShardedTensorMetadata: """ Given a global tensor size, define how to shard a tensor like this shape across ranks, return ShardedTensorMetadata @@ -71,7 +76,9 @@ def build_metadata(self, """ @abstractmethod - def shard(self, tensor: torch.Tensor, src_rank: int = 0, process_group=None) -> "ShardedTensor": + def shard( + self, tensor: torch.Tensor, src_rank: int = 0, process_group=None + ) -> "ShardedTensor": """ Given a global tensor on src_rank, shard this tensor across ranks within the process group, return a ShardedTensor. @@ -88,26 +95,35 @@ def shard(self, tensor: torch.Tensor, src_rank: int = 0, process_group=None) -> A :class:`ShardedTensor` sharded from the given tensor. """ + # Ops customized for a particular ShardingSpec. _CUSTOM_SHARDING_SPEC_OPS: Dict[str, Dict[Callable, Callable]] = {} + def _has_custom_op(sharding_spec, op): """ Returns whether or not the ShardingSpec has a custom op implementation. """ class_name = type(sharding_spec).__qualname__ - return class_name in _CUSTOM_SHARDING_SPEC_OPS and op in _CUSTOM_SHARDING_SPEC_OPS[class_name] + return ( + class_name in _CUSTOM_SHARDING_SPEC_OPS + and op in _CUSTOM_SHARDING_SPEC_OPS[class_name] + ) + -def _dispatch_custom_op(sharding_spec, op: Callable, types, args, kwargs, process_group): +def _dispatch_custom_op( + sharding_spec, op: Callable, types, args, kwargs, process_group +): """ Calls the custom op for this ShardingSpec if it exists. """ class_name = type(sharding_spec).__qualname__ if not _has_custom_op(sharding_spec, op): - raise RuntimeError(f'Custom op: {op} not registered for {class_name}') + raise RuntimeError(f"Custom op: {op} not registered for {class_name}") func = _CUSTOM_SHARDING_SPEC_OPS[class_name][op] return func(types, args, kwargs, process_group) + def custom_sharding_spec_op(sharding_spec_class, func): """ Decorator to allow custom registration of ops. @@ -119,9 +135,7 @@ def custom_sharding_spec_op(sharding_spec_class, func): if class_name not in _CUSTOM_SHARDING_SPEC_OPS: _CUSTOM_SHARDING_SPEC_OPS[class_name] = {} return functools.partial( - _decorator_func, - op=func, - op_table=_CUSTOM_SHARDING_SPEC_OPS[class_name] + _decorator_func, op=func, op_table=_CUSTOM_SHARDING_SPEC_OPS[class_name] ) @@ -140,30 +154,33 @@ class EnumerableShardingSpec(ShardingSpec): def __post_init__(self): if len(self.shards) == 0: - raise ValueError(f'Empty shard list provided: {self.shards}') + raise ValueError(f"Empty shard list provided: {self.shards}") # Validate each shard has same rank. rank = -1 for shard in self.shards: if rank != -1 and rank != len(shard.shard_offsets): - raise ValueError(f'Found inconsistent ranks for shards: {rank} and {len(shard.shard_offsets)}') + raise ValueError( + f"Found inconsistent ranks for shards: {rank} and {len(shard.shard_offsets)}" + ) rank = len(shard.shard_offsets) validate_non_overlapping_shards_metadata(self.shards) - def build_metadata(self, - tensor_sizes: torch.Size, - tensor_properties: sharded_tensor_meta.TensorProperties, - ) -> sharded_tensor_meta.ShardedTensorMetadata: + def build_metadata( + self, + tensor_sizes: torch.Size, + tensor_properties: sharded_tensor_meta.TensorProperties, + ) -> sharded_tensor_meta.ShardedTensorMetadata: # check if shards form a valid tensor check_tensor(self.shards, tensor_sizes) return sharded_tensor_meta.ShardedTensorMetadata( - self.shards, - tensor_sizes, - tensor_properties + self.shards, tensor_sizes, tensor_properties ) - def shard(self, tensor: torch.Tensor, src_rank: int = 0, process_group=None) -> "ShardedTensor": + def shard( + self, tensor: torch.Tensor, src_rank: int = 0, process_group=None + ) -> "ShardedTensor": # TODO: figure out a generic and efficient way to scatter the shards for EnumerableShardingSpec raise NotImplementedError("EnumerableShardingSpec.shard not implemented yet!") @@ -216,10 +233,14 @@ def _infer_sharding_spec_from_shards_metadata(shards_metadata): if chunk_sharding_dim is not None: # Ensure we infer the correct placement order from offsets placements = [ - x for _, x in sorted(zip(chunk_offset_list, placements), key=operator.itemgetter(0)) + x + for _, x in sorted( + zip(chunk_offset_list, placements), key=operator.itemgetter(0) + ) ] from .chunk_sharding_spec import ChunkShardingSpec + chunk_spec = ChunkShardingSpec( dim=chunk_sharding_dim, placements=placements, diff --git a/torch/distributed/_shard/sharding_spec/chunk_sharding_spec.py b/torch/distributed/_shard/sharding_spec/chunk_sharding_spec.py index bd2c960f7f60cf..dd0e354dfc25cc 100644 --- a/torch/distributed/_shard/sharding_spec/chunk_sharding_spec.py +++ b/torch/distributed/_shard/sharding_spec/chunk_sharding_spec.py @@ -1,28 +1,28 @@ # mypy: allow-untyped-defs from dataclasses import dataclass +from typing import cast, List, Optional, TYPE_CHECKING, Union + import torch +import torch.distributed as dist import torch.distributed._shard.sharded_tensor.metadata as sharded_tensor_meta +import torch.distributed.distributed_c10d as distributed_c10d +from torch.distributed._shard._utils import narrow_tensor from torch.distributed._shard.metadata import ShardMetadata from torch.distributed._shard.sharded_tensor.shard import Shard from torch.distributed._shard.sharded_tensor.utils import ( - _parse_and_validate_remote_device -) -from torch.distributed._shard._utils import narrow_tensor -import torch.distributed as dist -import torch.distributed.distributed_c10d as distributed_c10d -from typing import cast, List, Optional, Union, TYPE_CHECKING -from ._internals import ( - get_chunked_dim_size, - get_split_size, + _parse_and_validate_remote_device, ) +from ._internals import get_chunked_dim_size, get_split_size from .api import ShardingSpec + if TYPE_CHECKING: # Only include ShardedTensor when do type checking, exclude it # from run-time to resolve circular dependency. from torch.distributed._shard.sharded_tensor import ShardedTensor + @dataclass class ChunkShardingSpec(ShardingSpec): """ @@ -71,14 +71,13 @@ def _verify_dim(dim): ) if not isinstance(dim, int): - raise ValueError( - f"Sharding dim needs to be an integer, found: {dim}" - ) + raise ValueError(f"Sharding dim needs to be an integer, found: {dim}") - def build_metadata(self, - tensor_sizes: torch.Size, - tensor_properties: sharded_tensor_meta.TensorProperties, - ) -> sharded_tensor_meta.ShardedTensorMetadata: + def build_metadata( + self, + tensor_sizes: torch.Size, + tensor_properties: sharded_tensor_meta.TensorProperties, + ) -> sharded_tensor_meta.ShardedTensorMetadata: tensor_num_dim = len(tensor_sizes) self._verify_dim(self.dim) @@ -105,13 +104,12 @@ def build_metadata(self, shards_metadata.append(shard_metadata) return sharded_tensor_meta.ShardedTensorMetadata( - shards_metadata, - tensor_sizes, - tensor_properties + shards_metadata, tensor_sizes, tensor_properties ) - - def shard(self, tensor: torch.Tensor, src_rank: int = 0, process_group=None) -> "ShardedTensor": + def shard( + self, tensor: torch.Tensor, src_rank: int = 0, process_group=None + ) -> "ShardedTensor": """ Args: src_rank: group rank relative to ``process_group`` @@ -119,15 +117,14 @@ def shard(self, tensor: torch.Tensor, src_rank: int = 0, process_group=None) -> N.B. If ``process_group`` is None, ``src_rank`` is a global rank. """ # relative imports to avoid circular dependency - from torch.distributed._shard.sharded_tensor import ( - ShardedTensor - ) + from torch.distributed._shard.sharded_tensor import ShardedTensor + tensor_properties = sharded_tensor_meta.TensorProperties( dtype=tensor.dtype, layout=tensor.layout, requires_grad=tensor.requires_grad, memory_format=torch.contiguous_format, - pin_memory=tensor.is_pinned() + pin_memory=tensor.is_pinned(), ) current_rank = dist.get_rank(process_group) current_global_rank = dist.get_rank() @@ -147,7 +144,9 @@ def shard(self, tensor: torch.Tensor, src_rank: int = 0, process_group=None) -> scatter_shape[self.dim] = split_size # type: ignore[index] for shard_meta in tensor_meta.shards_metadata: - remote_global_rank, device = _parse_and_validate_remote_device(process_group, shard_meta.placement) + remote_global_rank, device = _parse_and_validate_remote_device( + process_group, shard_meta.placement + ) if current_rank == src_rank: # Reshape to get shard for this rank and we don't want autograd # recording here for the narrow op and 'local_shard' should be a @@ -158,7 +157,9 @@ def shard(self, tensor: torch.Tensor, src_rank: int = 0, process_group=None) -> # resize the narrowed tensor to the same size and use it for # the scatter collective as dist.scatter requires same size # inputs on every rank - tensor_to_scatter = narrowed_tensor.detach().clone().resize_(scatter_shape) + tensor_to_scatter = ( + narrowed_tensor.detach().clone().resize_(scatter_shape) + ) else: tensor_to_scatter = narrowed_tensor.detach().clone().contiguous() @@ -168,7 +169,11 @@ def shard(self, tensor: torch.Tensor, src_rank: int = 0, process_group=None) -> if current_global_rank == remote_global_rank: local_tensor = torch.empty( - scatter_shape, dtype=tensor.dtype, layout=tensor.layout, device=device) + scatter_shape, + dtype=tensor.dtype, + layout=tensor.layout, + device=device, + ) local_metadata = shard_meta # each rank should have local_tensor and local_metadata initialized if we build @@ -179,14 +184,19 @@ def shard(self, tensor: torch.Tensor, src_rank: int = 0, process_group=None) -> # Scatter the shards to all ranks in the pg # scatter takes the global rank as ``src`` src_for_scatter = src_rank - if process_group is not None and process_group is not distributed_c10d._get_default_group(): - src_for_scatter = distributed_c10d.get_global_rank(process_group, src_for_scatter) + if ( + process_group is not None + and process_group is not distributed_c10d._get_default_group() + ): + src_for_scatter = distributed_c10d.get_global_rank( + process_group, src_for_scatter + ) dist.scatter( local_tensor, scatter_list=tensors_to_scatter if current_rank == src_rank else None, src=src_for_scatter, - group=process_group + group=process_group, ) if list(local_tensor.size()) != local_metadata.shard_sizes: @@ -199,9 +209,8 @@ def shard(self, tensor: torch.Tensor, src_rank: int = 0, process_group=None) -> local_shards.append(Shard(tensor=local_tensor, metadata=local_metadata)) st = ShardedTensor._init_from_local_shards_and_global_metadata( - local_shards, - tensor_meta, - process_group=process_group) + local_shards, tensor_meta, process_group=process_group + ) # Manually set sharding_spec st._sharding_spec = self diff --git a/torch/distributed/_spmd/api.py b/torch/distributed/_spmd/api.py index ce9984efac6e86..ab5136978f6681 100644 --- a/torch/distributed/_spmd/api.py +++ b/torch/distributed/_spmd/api.py @@ -13,12 +13,9 @@ import torch.distributed._functional_collectives import torch.nn as nn import torch.utils._pytree as pytree - from functorch import make_fx - from torch import fx from torch._decomp.decompositions import native_layer_norm_backward - from torch._subclasses.fake_tensor import FakeTensorMode from torch.distributed._spmd.data_parallel import gradients_tagging from torch.distributed._spmd.parallel_mode import ( diff --git a/torch/distributed/_spmd/batch_dim_utils.py b/torch/distributed/_spmd/batch_dim_utils.py index d3c39295c0e660..244cc26c55ed49 100644 --- a/torch/distributed/_spmd/batch_dim_utils.py +++ b/torch/distributed/_spmd/batch_dim_utils.py @@ -2,17 +2,14 @@ from typing import Callable, Dict, List, Set import torch - import torch.fx as fx - import torch.utils._pytree as pytree - from torch import Tensor - from torch.distributed._tensor import DeviceMesh, Replicate, Shard from torch.distributed._tensor.ops.view_ops import dim_maps, DimSpec, InputDim from torch.distributed._tensor.placement_types import _Partial, DTensorSpec + aten = torch.ops.aten diff --git a/torch/distributed/_spmd/config.py b/torch/distributed/_spmd/config.py index 73ee19e803dc84..3fc45bc27a3a1b 100644 --- a/torch/distributed/_spmd/config.py +++ b/torch/distributed/_spmd/config.py @@ -4,6 +4,7 @@ from types import ModuleType from typing import Set + # log level (levels print what it says + all levels listed below it) # DEBUG print full traces <-- lowest level + print tracing of every instruction # INFO print compiler functions + distributed graphs diff --git a/torch/distributed/_spmd/data_parallel.py b/torch/distributed/_spmd/data_parallel.py index 8b18c6c86763a1..835cdb9fa8efd5 100644 --- a/torch/distributed/_spmd/data_parallel.py +++ b/torch/distributed/_spmd/data_parallel.py @@ -2,17 +2,13 @@ import operator from contextlib import contextmanager from enum import Enum - from typing import Any, cast, Dict, List, Optional, Tuple import torch - import torch.fx as fx import torch.library import torch.nn as nn - import torch.utils._pytree as pytree - from torch.distributed._spmd.batch_dim_utils import BatchDimAnalyzer from torch.distributed._tensor import DeviceMesh, distribute_tensor, Replicate, Shard from torch.distributed._tensor._op_schema import ( @@ -22,7 +18,6 @@ TupleStrategy, ) from torch.distributed._tensor._redistribute import redistribute_local_tensor - from torch.distributed._tensor._utils import compute_local_shape from torch.distributed._tensor.placement_types import _Partial, DTensorSpec, Placement from torch.fx import GraphModule @@ -30,6 +25,7 @@ from torch.fx.passes.shape_prop import _extract_tensor_metadata from torch.nn.utils._named_member_accessor import NamedMemberAccessor + aten = torch.ops.aten # Dummy op used by data parallel to tag gradients. diff --git a/torch/distributed/_spmd/distribute.py b/torch/distributed/_spmd/distribute.py index 5fb5ff766799ad..839b58bf03e039 100644 --- a/torch/distributed/_spmd/distribute.py +++ b/torch/distributed/_spmd/distribute.py @@ -9,11 +9,9 @@ import torch import torch.distributed._spmd.experimental_ops import torch.fx as fx - from torch.distributed._spmd.comm_tensor import _get_tracer from torch.distributed._spmd.graph_utils import OP from torch.distributed._spmd.log_utils import get_logger - from torch.distributed._tensor import DeviceMesh, DTensor from torch.distributed._tensor._op_schema import OpSchema from torch.distributed._tensor._redistribute import redistribute_local_tensor diff --git a/torch/distributed/_spmd/experimental_ops.py b/torch/distributed/_spmd/experimental_ops.py index 94a0da82244963..7039822c41ccdf 100644 --- a/torch/distributed/_spmd/experimental_ops.py +++ b/torch/distributed/_spmd/experimental_ops.py @@ -6,7 +6,6 @@ from torch.distributed._tensor._op_schema import OpSchema, OutputSharding from torch.distributed._tensor.ops.common_rules import pointwise_rule from torch.distributed._tensor.ops.utils import register_prop_rule - from torch.distributed._tensor.placement_types import ( _Partial, DTensorSpec, @@ -16,6 +15,7 @@ TensorMeta, ) + aten = torch.ops.aten # pyre-ignore diff --git a/torch/distributed/_spmd/graph_optimization.py b/torch/distributed/_spmd/graph_optimization.py index 4a5cad7917d882..a50e266eb12826 100644 --- a/torch/distributed/_spmd/graph_optimization.py +++ b/torch/distributed/_spmd/graph_optimization.py @@ -37,6 +37,7 @@ from torch.utils import _pytree as pytree from torch.utils._pytree import tree_flatten, tree_unflatten + logger: logging.Logger = logging.getLogger("graph_optimization") aten = torch.ops.aten fake_tensor_mode = FakeTensorMode() diff --git a/torch/distributed/_spmd/parallel_mode.py b/torch/distributed/_spmd/parallel_mode.py index 2e9c15258d0926..65a55377ac82f2 100644 --- a/torch/distributed/_spmd/parallel_mode.py +++ b/torch/distributed/_spmd/parallel_mode.py @@ -11,7 +11,6 @@ ) from torch.distributed._spmd.distribute import _convert_to_distributed, Schema from torch.distributed._tensor import DeviceMesh, Placement, Replicate, Shard - from torch.fx import GraphModule diff --git a/torch/distributed/_spmd/partial_lower.py b/torch/distributed/_spmd/partial_lower.py index bb1f1e2e085fc9..7899a8e143f622 100644 --- a/torch/distributed/_spmd/partial_lower.py +++ b/torch/distributed/_spmd/partial_lower.py @@ -7,12 +7,11 @@ from typing import Callable, List, Optional, Set, Tuple import torch - from functorch import make_fx - from torch._inductor.compile_fx import compile_fx_inner from torch._inductor.decomposition import select_decomp_table + MIN_ATEN_OPS_TO_LOWER = 10 logger: logging.Logger = logging.getLogger(__name__) From 4817180601016f706ee0cce76b6d52b9cfc51ef5 Mon Sep 17 00:00:00 2001 From: Isuru Fernando Date: Mon, 17 Jun 2024 23:06:28 +0000 Subject: [PATCH 132/171] make fallback for aten.argsort.stable (#128907) Pull Request resolved: https://github.com/pytorch/pytorch/pull/128907 Approved by: https://github.com/lezcano ghstack dependencies: #128343 --- test/inductor/test_torchinductor_opinfo.py | 1 - torch/_inductor/lowering.py | 1 + 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/test/inductor/test_torchinductor_opinfo.py b/test/inductor/test_torchinductor_opinfo.py index 5f97c2f0fd7121..29be591dc006c1 100644 --- a/test/inductor/test_torchinductor_opinfo.py +++ b/test/inductor/test_torchinductor_opinfo.py @@ -411,7 +411,6 @@ def wrapper_noop_set_seed(op, *args, **kwargs): "_segment_reduce.lengths": {f16}, "_segment_reduce.offsets": {f16}, "addmv": {f16}, - "argsort": {b8, f16, f32, f64, i32, i64}, "as_strided.partial_views": {f16}, "corrcoef": {f16}, "diff": {f16}, diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py index 449e512352fa1e..44a6d05d1f44d0 100644 --- a/torch/_inductor/lowering.py +++ b/torch/_inductor/lowering.py @@ -2183,6 +2183,7 @@ def is_aligned(x): # Sorting / Sorting-like make_fallback(aten.sort) make_fallback(aten.sort.stable) +make_fallback(aten.argsort.stable) make_fallback(aten.kthvalue) make_fallback(aten.topk) make_fallback(aten.mode) From 108318ad1038f4f3ad0da4f54f53effdd9ef365a Mon Sep 17 00:00:00 2001 From: David Berard Date: Tue, 18 Jun 2024 15:40:45 +0000 Subject: [PATCH 133/171] [BE][JIT] Handle case where codegen object can be unset (#128951) Summary: Unblocks a test that's failing. `codegen` can be unset until `compile` is called. If `codegen` is not set, then just use the kernel name directly. Test Plan: ``` buck2 run //caffe2/test:tensorexpr -- --regex test_simple_add ``` Differential Revision: D58727391 Pull Request resolved: https://github.com/pytorch/pytorch/pull/128951 Approved by: https://github.com/aaronenyeshi --- torch/csrc/jit/tensorexpr/kernel.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch/csrc/jit/tensorexpr/kernel.h b/torch/csrc/jit/tensorexpr/kernel.h index d7c737d8f8f2c2..e5ea5bb46e0e79 100644 --- a/torch/csrc/jit/tensorexpr/kernel.h +++ b/torch/csrc/jit/tensorexpr/kernel.h @@ -181,7 +181,7 @@ class TORCH_API TensorExprKernel { } const std::string& getKernelName() const { - return codegen_->kernel_func_name(); + return (codegen_ ? codegen_->kernel_func_name() : kernel_func_name_); } const std::vector& getSymbolicShapeInputs() const { From ec616da51848bcfa9d0bd9c693c62b50fbe84c0f Mon Sep 17 00:00:00 2001 From: eqy Date: Tue, 18 Jun 2024 16:16:38 +0000 Subject: [PATCH 134/171] RNN API cleanup for cuDNN 9.1 (#122011) Can potentially avoid a bit of boilerplate if we move directly to cuDNN 9.1's RNN API... Co-authored-by: Aaron Gokaslan Pull Request resolved: https://github.com/pytorch/pytorch/pull/122011 Approved by: https://github.com/Skylion007 --- aten/src/ATen/native/cudnn/RNN.cpp | 32 +++++++++++------------------- 1 file changed, 12 insertions(+), 20 deletions(-) diff --git a/aten/src/ATen/native/cudnn/RNN.cpp b/aten/src/ATen/native/cudnn/RNN.cpp index 55c666eeca83c2..c90a6fd7a6c9cc 100644 --- a/aten/src/ATen/native/cudnn/RNN.cpp +++ b/aten/src/ATen/native/cudnn/RNN.cpp @@ -614,8 +614,6 @@ void add_projection_weights( /*linLayerMatDesc=*/lin_layer_mat_desc.mut_desc(), /*linLayerMat=*/&matrix_pointer)); #else - void* unused_pointer; - TensorDescriptor unused_desc; TensorDescriptor lin_layer_mat_desc; AT_CUDNN_CHECK(cudnnGetRNNWeightParams( /*handle=*/handle, @@ -626,8 +624,8 @@ void add_projection_weights( /*linLayerID=*/linear_id, /*linLayerMatDesc=*/lin_layer_mat_desc.mut_desc(), /*linLayerMat=*/&matrix_pointer, - unused_desc.mut_desc(), - &unused_pointer)); + nullptr, + nullptr)); #endif cudnnDataType_t data_type; @@ -735,8 +733,6 @@ get_parameters( lin_layer_mat_desc.mut_desc(), &matrix_pointer)); #else - void* unused_pointer = nullptr; - TensorDescriptor unused_desc; TensorDescriptor lin_layer_mat_desc; for (int stateless = 0; stateless < 100; stateless++) { if (cudnn_method) { // matrix @@ -749,8 +745,8 @@ get_parameters( linear_id, lin_layer_mat_desc.mut_desc(), &matrix_pointer, - unused_desc.mut_desc(), - &unused_pointer)); + nullptr, + nullptr)); } else { // bias AT_CUDNN_CHECK(cudnnGetRNNWeightParams( handle, @@ -759,8 +755,8 @@ get_parameters( weight_buf.numel() * weight_buf.element_size(), weight_buf.data_ptr(), linear_id, - unused_desc.mut_desc(), - &unused_pointer, + nullptr, + nullptr, lin_layer_mat_desc.mut_desc(), &matrix_pointer)); } @@ -922,8 +918,6 @@ std::vector get_expected_data_ptrs( lin_layer_mat_desc.mut_desc(), &matrix_pointer)); #else - void* unused_pointer = nullptr; - TensorDescriptor unused_desc; TensorDescriptor lin_layer_mat_desc; if (cudnn_method) { // matrix AT_CUDNN_CHECK(cudnnGetRNNWeightParams( @@ -935,8 +929,8 @@ std::vector get_expected_data_ptrs( linear_id, lin_layer_mat_desc.mut_desc(), &matrix_pointer, - unused_desc.mut_desc(), - &unused_pointer)); + nullptr, + nullptr)); } else { // bias AT_CUDNN_CHECK(cudnnGetRNNWeightParams( handle, @@ -945,8 +939,8 @@ std::vector get_expected_data_ptrs( weight_buf.numel() * weight_buf.element_size(), weight_buf.data_ptr(), linear_id, - unused_desc.mut_desc(), - &unused_pointer, + nullptr, + nullptr, lin_layer_mat_desc.mut_desc(), &matrix_pointer)); } @@ -972,8 +966,6 @@ std::vector get_expected_data_ptrs( lin_layer_mat_desc.mut_desc(), &matrix_pointer)); #else - void* unused_pointer; - TensorDescriptor unused_desc; TensorDescriptor lin_layer_mat_desc; AT_CUDNN_CHECK(cudnnGetRNNWeightParams( @@ -985,8 +977,8 @@ std::vector get_expected_data_ptrs( linear_id, lin_layer_mat_desc.mut_desc(), &matrix_pointer, - unused_desc.mut_desc(), - &unused_pointer)); + nullptr, + nullptr)); #endif data_ptrs.push_back(matrix_pointer); } From 9818283da18de00047760ec4431870d3f8e620a6 Mon Sep 17 00:00:00 2001 From: Guilherme Leobas Date: Fri, 14 Jun 2024 19:12:10 +0000 Subject: [PATCH 135/171] re-enable jacrev/jacfwd/hessian after #128028 landed (#128622) Pull Request resolved: https://github.com/pytorch/pytorch/pull/128622 Approved by: https://github.com/zou3519 --- test/dynamo/test_higher_order_ops.py | 69 ------------------- ...ion_no_setup_context_transform_hessian_cpu | 0 ...tion_no_setup_context_transform_jacfwd_cpu | 0 ...essianCPU.test_jacfwd_different_levels_cpu | 0 test/functorch/test_eager_transforms.py | 4 +- torch/_functorch/eager_transforms.py | 4 -- torch/testing/_internal/common_utils.py | 1 - 7 files changed, 2 insertions(+), 76 deletions(-) create mode 100644 test/dynamo_expected_failures/TestComposabilityCPU.test_autograd_function_no_setup_context_transform_hessian_cpu create mode 100644 test/dynamo_expected_failures/TestComposabilityCPU.test_autograd_function_no_setup_context_transform_jacfwd_cpu create mode 100644 test/dynamo_expected_failures/TestHessianCPU.test_jacfwd_different_levels_cpu diff --git a/test/dynamo/test_higher_order_ops.py b/test/dynamo/test_higher_order_ops.py index dca6d28d1912dd..f2df33bdda67cf 100644 --- a/test/dynamo/test_higher_order_ops.py +++ b/test/dynamo/test_higher_order_ops.py @@ -2746,26 +2746,6 @@ def _compile_check(self, fn, inputs, fullgraph=True, graph_idx=0): wrapped_gm = backend.graphs[graph_idx] return wrapped_gm - def test_hessian_graph_break(self): - counters.clear() - - def wrapper_fn(x): - return torch.func.hessian(torch.sin)(x) - - x = torch.randn(4, 3) - expected = wrapper_fn(x) - got = torch.compile(wrapper_fn, backend="aot_eager", fullgraph=False)(x) - self.assertEqual(expected, got) - self.assertEqual(len(counters["graph_break"]), 2) - self.assertEqual( - { - "'skip function disable in file _dynamo/decorators.py'": 1, - "call torch._dynamo.disable() wrapped function .wrapper_fn at 0xN>": 1, - }, - {munge_exc(k): v for k, v in counters["graph_break"].items()}, - ) - - @unittest.expectedFailure def test_hessian(self): counters.clear() @@ -2900,7 +2880,6 @@ def forward(self, L_x_: "f32[4, 3]"): """, ) - @unittest.expectedFailure def test_hessian_argnums(self): counters.clear() @@ -3046,7 +3025,6 @@ def forward(self, L_x_: "f32[4, 3]", L_y_: "f32[3, 4]"): """ return (unflatten,)""", ) - @unittest.expectedFailure def test_hessian_disable_capture(self): counters.clear() @@ -3073,26 +3051,6 @@ def wrapper_fn(x): ) self.assertEqual(actual, expected) - def test_jacrev_graph_break(self): - counters.clear() - - def wrapper_fn(x): - return torch.func.jacrev(torch.sin)(x) - - x = torch.randn(4, 3) - expected = wrapper_fn(x) - got = torch.compile(wrapper_fn, backend="aot_eager", fullgraph=False)(x) - self.assertEqual(expected, got) - self.assertEqual(len(counters["graph_break"]), 2) - self.assertEqual( - { - "'skip function disable in file _dynamo/decorators.py'": 1, - "call torch._dynamo.disable() wrapped function .wrapper_fn at 0xN>": 1, - }, - {munge_exc(k): v for k, v in counters["graph_break"].items()}, - ) - - @unittest.expectedFailure def test_jacrev(self): counters.clear() @@ -3169,7 +3127,6 @@ def forward(self, L_x_: "f32[4, 3]"): """, ) - @unittest.expectedFailure def test_jacrev_two_tensors_argnums(self): counters.clear() @@ -3252,7 +3209,6 @@ def forward(self, L_x_: "f32[4, 3]", L_y_: "f32[3, 4]"): """, ) - @unittest.expectedFailure def test_jacrev_has_aux(self): counters.clear() @@ -3337,7 +3293,6 @@ def forward(self, L_x_: "f32[4, 3]", L_y_: "f32[3, 4]"): """, ) - @unittest.expectedFailure def test_jacrev_disable_capture(self): counters.clear() @@ -4284,26 +4239,6 @@ def wrapper_fn(x, y): self.assertEqual(len(counters["graph_break"]), 0) self.assertEqual(actual, expected) - def test_jacfwd_graph_break(self): - counters.clear() - - def wrapper_fn(x): - return torch.func.jacfwd(torch.sin)(x) - - x = torch.randn(4, 3) - expected = wrapper_fn(x) - got = torch.compile(wrapper_fn, backend="aot_eager", fullgraph=False)(x) - self.assertEqual(expected, got) - self.assertEqual(len(counters["graph_break"]), 2) - self.assertEqual( - { - "'skip function disable in file _dynamo/decorators.py'": 1, - "call torch._dynamo.disable() wrapped function .wrapper_fn at 0xN>": 1, - }, - {munge_exc(k): v for k, v in counters["graph_break"].items()}, - ) - - @unittest.expectedFailure def test_jacfwd(self): counters.clear() @@ -4387,7 +4322,6 @@ def forward(self, L_x_: "f32[4, 3]"): """, ) - @unittest.expectedFailure def test_jacfwd_two_tensors_argnums(self): counters.clear() @@ -4477,7 +4411,6 @@ def forward(self, L_x_: "f32[4, 3]", L_y_: "f32[3, 4]"): """, ) - @unittest.expectedFailure def test_jacfwd_has_aux(self): counters.clear() @@ -4572,7 +4505,6 @@ def forward(self, L_x_: "f32[4, 3]", L_y_: "f32[3, 4]"): """, ) - @unittest.expectedFailure def test_jacfwd_randomness(self): counters.clear() @@ -4676,7 +4608,6 @@ def forward(self, L_x_: "f32[4, 3]", L_y_: "f32[3, 4]"): """, ) - @unittest.expectedFailure def test_jacfwd_disable_capture(self): counters.clear() diff --git a/test/dynamo_expected_failures/TestComposabilityCPU.test_autograd_function_no_setup_context_transform_hessian_cpu b/test/dynamo_expected_failures/TestComposabilityCPU.test_autograd_function_no_setup_context_transform_hessian_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/TestComposabilityCPU.test_autograd_function_no_setup_context_transform_jacfwd_cpu b/test/dynamo_expected_failures/TestComposabilityCPU.test_autograd_function_no_setup_context_transform_jacfwd_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/dynamo_expected_failures/TestHessianCPU.test_jacfwd_different_levels_cpu b/test/dynamo_expected_failures/TestHessianCPU.test_jacfwd_different_levels_cpu new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/test/functorch/test_eager_transforms.py b/test/functorch/test_eager_transforms.py index 8107f865f7bc54..c767810beb85a5 100644 --- a/test/functorch/test_eager_transforms.py +++ b/test/functorch/test_eager_transforms.py @@ -77,6 +77,7 @@ subtest, TEST_WITH_TORCHDYNAMO, TestCase, + xfailIfTorchDynamo, ) from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten @@ -2341,8 +2342,7 @@ def f(x): self.assertEqual(actual, expected) # https://github.com/pytorch/pytorch/issues/127036 - # it won't fail as jacrev/jacfwd were not inlined (see #128255) - # @xfailIfTorchDynamo + @xfailIfTorchDynamo @parametrize("_preallocate_and_copy", (True, False)) def test_chunk_jacrev_chunksize_one(self, device, _preallocate_and_copy): # With chunk_size=1, we shouldn't `vmap` and hence not be limited diff --git a/torch/_functorch/eager_transforms.py b/torch/_functorch/eager_transforms.py index fbea5164014bcd..fff6bd67838f01 100644 --- a/torch/_functorch/eager_transforms.py +++ b/torch/_functorch/eager_transforms.py @@ -767,8 +767,6 @@ def compute_jacobian_preallocate_and_copy(): # wraps only if we're not tracing with dynamo. if not torch._dynamo.is_compiling(): wrapper_fn = wraps(func)(wrapper_fn) - else: - wrapper_fn = torch._dynamo.disable(wrapper_fn) return wrapper_fn @@ -1350,8 +1348,6 @@ def push_jvp(basis): # wraps only if we're not tracing with dynamo. if not torch._dynamo.is_compiling(): wrapper_fn = wraps(func)(wrapper_fn) - else: - wrapper_fn = torch._dynamo.disable(wrapper_fn) return wrapper_fn diff --git a/torch/testing/_internal/common_utils.py b/torch/testing/_internal/common_utils.py index 2d5ea4a6c64ffd..8daeefdee9d855 100644 --- a/torch/testing/_internal/common_utils.py +++ b/torch/testing/_internal/common_utils.py @@ -5008,7 +5008,6 @@ def repl_frame(m): return m.group(0) s = re.sub(r' File "([^"]+)", line \d+, in (.+)\n .+\n( +[~^]+ *\n)?', repl_frame, s) - s = re.sub(r'( Date: Tue, 18 Jun 2024 17:15:05 +0000 Subject: [PATCH 136/171] [EZ] Fix typos in RELEASE.md (#128769) This PR fixes typo in `RELEASE.md` Pull Request resolved: https://github.com/pytorch/pytorch/pull/128769 Approved by: https://github.com/yumium, https://github.com/mikaylagawarecki --- RELEASE.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/RELEASE.md b/RELEASE.md index 3c9d68f9a6cdcc..7091052c85bd10 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -290,7 +290,7 @@ After the final RC is created. The following tasks should be performed : * Create validation issue for the release, see for example [Validations for 2.1.2 release](https://github.com/pytorch/pytorch/issues/114904) and perform required validations. -* Run performance tests in [benchmark repository](https://github.com/pytorch/benchmark). Make sure there are no prerformance regressions. +* Run performance tests in [benchmark repository](https://github.com/pytorch/benchmark). Make sure there are no performance regressions. * Prepare and stage PyPI binaries for promotion. This is done with this script: [`pytorch/builder:release/pypi/promote_pypi_to_staging.sh`](https://github.com/pytorch/builder/blob/main/release/pypi/promote_pypi_to_staging.sh) @@ -429,12 +429,12 @@ need to support these particular versions of software. ## Operating Systems Supported OS flavors are summarized in the table below: -| Operating System family | Architectrue | Notes | +| Operating System family | Architecture | Notes | | --- | --- | --- | | Linux | aarch64, x86_64 | Wheels are manylinux2014 compatible, i.e. they should be runnable on any Linux system with glibc-2.17 or above. | | MacOS | arm64 | Builds should be compatible with MacOS 11 (Big Sur) or newer, but are actively tested against MacOS 14 (Sonoma). | | MacOS | x86_64 | Requires MacOS Catalina or above, not supported after 2.2, see https://github.com/pytorch/pytorch/issues/114602 | -| Windows | x86_64 | Buils are compatible with Windows-10 or newer. | +| Windows | x86_64 | Builds are compatible with Windows-10 or newer. | # Submitting Tutorials From 4e03263224af813fbf5e0e745e84c13268c48dc7 Mon Sep 17 00:00:00 2001 From: eqy Date: Tue, 18 Jun 2024 17:26:23 +0000 Subject: [PATCH 137/171] [CUDA][Convolution] Add missing launch bounds to `vol2col_kernel` (#128740) Fix "too many resources requested" that can happen with recent toolkits on V100. Pull Request resolved: https://github.com/pytorch/pytorch/pull/128740 Approved by: https://github.com/mikaylagawarecki --- aten/src/ATen/native/cuda/vol2col.cuh | 1 + 1 file changed, 1 insertion(+) diff --git a/aten/src/ATen/native/cuda/vol2col.cuh b/aten/src/ATen/native/cuda/vol2col.cuh index 98ec2c3522d541..222270e8621606 100644 --- a/aten/src/ATen/native/cuda/vol2col.cuh +++ b/aten/src/ATen/native/cuda/vol2col.cuh @@ -14,6 +14,7 @@ using namespace at::cuda::detail; // Kernel for fast unfold+copy on volumes template +C10_LAUNCH_BOUNDS_1(1024) __global__ void vol2col_kernel( const int64_t n, const T* data_vol, From 84c86e56bd8b86ae47c18b77141c1fe46188c5b7 Mon Sep 17 00:00:00 2001 From: Huy Do Date: Tue, 18 Jun 2024 17:48:47 +0000 Subject: [PATCH 138/171] Update tracker issues after successfully cherry-picking a PR (#128924) This extends the capacity of the cherry-pick bot to automatically update the tracker issue with the information. For this to work, the tracker issue needs to be an open one with a `release tracker` label, i.e. https://github.com/pytorch/pytorch/issues/128436. The version from the release branch, i.e. `release/2.4`, will be match with the title of the tracker issue, i.e. `[v.2.4.0] Release Tracker` or `[v.2.4.1] Release Tracker` ### Testing `python cherry_pick.py --onto-branch release/2.4 --classification release --fixes "DEBUG DEBUG" --github-actor huydhn 128718` * On the PR https://github.com/pytorch/pytorch/pull/128718#issuecomment-2174846771 * On the tracker issue https://github.com/pytorch/pytorch/issues/128436#issuecomment-2174846757 Pull Request resolved: https://github.com/pytorch/pytorch/pull/128924 Approved by: https://github.com/atalman --- .github/scripts/cherry_pick.py | 114 ++++++++++++++++++++++++++++---- .github/scripts/github_utils.py | 9 +++ 2 files changed, 111 insertions(+), 12 deletions(-) diff --git a/.github/scripts/cherry_pick.py b/.github/scripts/cherry_pick.py index 4c892de21da8af..2650a5060d0ff7 100755 --- a/.github/scripts/cherry_pick.py +++ b/.github/scripts/cherry_pick.py @@ -3,11 +3,11 @@ import json import os import re -from typing import Any, Optional +from typing import Any, cast, Dict, List, Optional from urllib.error import HTTPError -from github_utils import gh_fetch_url, gh_post_pr_comment +from github_utils import gh_fetch_url, gh_post_pr_comment, gh_query_issues_by_labels from gitutils import get_git_remote_name, get_git_repo_dir, GitRepo from trymerge import get_pr_commit_sha, GitHubPR @@ -19,6 +19,7 @@ "critical", "fixnewfeature", } +RELEASE_BRANCH_REGEX = re.compile(r"release/(?P.+)") def parse_args() -> Any: @@ -58,6 +59,33 @@ def get_merge_commit_sha(repo: GitRepo, pr: GitHubPR) -> Optional[str]: return commit_sha if pr.is_closed() else None +def get_release_version(onto_branch: str) -> Optional[str]: + """ + Return the release version if the target branch is a release branch + """ + m = re.match(RELEASE_BRANCH_REGEX, onto_branch) + return m.group("version") if m else "" + + +def get_tracker_issues( + org: str, project: str, onto_branch: str +) -> List[Dict[str, Any]]: + """ + Find the tracker issue from the repo. The tracker issue needs to have the title + like [VERSION] Release Tracker following the convention on PyTorch + """ + version = get_release_version(onto_branch) + if not version: + return [] + + tracker_issues = gh_query_issues_by_labels(org, project, labels=["release tracker"]) + if not tracker_issues: + return [] + + # Figure out the tracker issue from the list by looking at the title + return [issue for issue in tracker_issues if version in issue.get("title", "")] + + def cherry_pick( github_actor: str, repo: GitRepo, @@ -77,17 +105,49 @@ def cherry_pick( ) try: + org, project = repo.gh_owner_and_name() + + cherry_pick_pr = "" if not dry_run: - org, project = repo.gh_owner_and_name() cherry_pick_pr = submit_pr(repo, pr, cherry_pick_branch, onto_branch) - msg = f"The cherry pick PR is at {cherry_pick_pr}" - if fixes: - msg += f" and it is linked with issue {fixes}" - elif classification in REQUIRES_ISSUE: - msg += f" and it is recommended to link a {classification} cherry pick PR with an issue" + tracker_issues_comments = [] + tracker_issues = get_tracker_issues(org, project, onto_branch) + for issue in tracker_issues: + issue_number = int(str(issue.get("number", "0"))) + if not issue_number: + continue + + res = cast( + Dict[str, Any], + post_tracker_issue_comment( + org, + project, + issue_number, + pr.pr_num, + cherry_pick_pr, + classification, + fixes, + dry_run, + ), + ) + + comment_url = res.get("html_url", "") + if comment_url: + tracker_issues_comments.append(comment_url) - post_comment(org, project, pr.pr_num, msg) + msg = f"The cherry pick PR is at {cherry_pick_pr}" + if fixes: + msg += f" and it is linked with issue {fixes}." + elif classification in REQUIRES_ISSUE: + msg += f" and it is recommended to link a {classification} cherry pick PR with an issue." + + if tracker_issues_comments: + msg += " The following tracker issues are updated:\n" + for tracker_issues_comment in tracker_issues_comments: + msg += f"* {tracker_issues_comment}\n" + + post_pr_comment(org, project, pr.pr_num, msg, dry_run) finally: if current_branch: @@ -159,7 +219,9 @@ def submit_pr( raise RuntimeError(msg) from error -def post_comment(org: str, project: str, pr_num: int, msg: str) -> None: +def post_pr_comment( + org: str, project: str, pr_num: int, msg: str, dry_run: bool = False +) -> List[Dict[str, Any]]: """ Post a comment on the PR itself to point to the cherry picking PR when success or print the error when failure @@ -182,7 +244,35 @@ def post_comment(org: str, project: str, pr_num: int, msg: str) -> None: comment = "\n".join( (f"### Cherry picking #{pr_num}", f"{msg}", "", f"{internal_debugging}") ) - gh_post_pr_comment(org, project, pr_num, comment) + return gh_post_pr_comment(org, project, pr_num, comment, dry_run) + + +def post_tracker_issue_comment( + org: str, + project: str, + issue_num: int, + pr_num: int, + cherry_pick_pr: str, + classification: str, + fixes: str, + dry_run: bool = False, +) -> List[Dict[str, Any]]: + """ + Post a comment on the tracker issue (if any) to record the cherry pick + """ + comment = "\n".join( + ( + "Link to landed trunk PR (if applicable):", + f"* https://github.com/{org}/{project}/pull/{pr_num}", + "", + "Link to release branch PR:", + f"* {cherry_pick_pr}", + "", + "Criteria Category:", + " - ".join((classification.capitalize(), fixes.capitalize())), + ) + ) + return gh_post_pr_comment(org, project, issue_num, comment, dry_run) def main() -> None: @@ -214,7 +304,7 @@ def main() -> None: except RuntimeError as error: if not args.dry_run: - post_comment(org, project, pr_num, str(error)) + post_pr_comment(org, project, pr_num, str(error)) else: raise error diff --git a/.github/scripts/github_utils.py b/.github/scripts/github_utils.py index d76d32f624d8a9..f804c6e197dd46 100644 --- a/.github/scripts/github_utils.py +++ b/.github/scripts/github_utils.py @@ -202,3 +202,12 @@ def gh_update_pr_state(org: str, repo: str, pr_num: int, state: str = "open") -> ) else: raise + + +def gh_query_issues_by_labels( + org: str, repo: str, labels: List[str], state: str = "open" +) -> List[Dict[str, Any]]: + url = f"{GITHUB_API_URL}/repos/{org}/{repo}/issues" + return gh_fetch_json( + url, method="GET", params={"labels": ",".join(labels), "state": state} + ) From 77830d509fcae41be37f5b3a2fa05faabc778e29 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Tue, 18 Jun 2024 18:11:43 +0000 Subject: [PATCH 139/171] Revert "Introduce a prototype for SymmetricMemory (#128582)" This reverts commit 7a39755da28d5a109bf0c37f72b364d3a83137b1. Reverted https://github.com/pytorch/pytorch/pull/128582 on behalf of https://github.com/fbgheith due to breaking internal builds ([comment](https://github.com/pytorch/pytorch/pull/128582#issuecomment-2176685232)) --- .lintrunner.toml | 1 - BUILD.bazel | 1 - build_variables.bzl | 2 - c10/cuda/driver_api.h | 19 +- caffe2/CMakeLists.txt | 1 - test/distributed/test_symmetric_memory.py | 156 ----- torch/_C/_distributed_c10d.pyi | 30 - .../distributed/c10d/CUDASymmetricMemory.cu | 539 ------------------ .../distributed/c10d/CUDASymmetricMemory.cuh | 109 ---- .../distributed/c10d/ProcessGroupCudaP2P.hpp | 1 - .../csrc/distributed/c10d/SymmetricMemory.cpp | 189 ------ .../csrc/distributed/c10d/SymmetricMemory.hpp | 152 ----- torch/csrc/distributed/c10d/init.cpp | 39 -- .../csrc/distributed/c10d/intra_node_comm.cpp | 99 +++- .../csrc/distributed/c10d/intra_node_comm.cu | 18 +- .../csrc/distributed/c10d/intra_node_comm.hpp | 9 +- 16 files changed, 111 insertions(+), 1254 deletions(-) delete mode 100644 test/distributed/test_symmetric_memory.py delete mode 100644 torch/csrc/distributed/c10d/CUDASymmetricMemory.cu delete mode 100644 torch/csrc/distributed/c10d/CUDASymmetricMemory.cuh delete mode 100644 torch/csrc/distributed/c10d/SymmetricMemory.cpp delete mode 100644 torch/csrc/distributed/c10d/SymmetricMemory.hpp diff --git a/.lintrunner.toml b/.lintrunner.toml index dc9f9ddd46c7ce..a7bbdc884415ee 100644 --- a/.lintrunner.toml +++ b/.lintrunner.toml @@ -68,7 +68,6 @@ include_patterns = [ 'aten/src/ATen/native/cudnn/*.cpp', 'c10/**/*.h', 'c10/**/*.cpp', - 'distributed/c10d/*SymmetricMemory.*', 'torch/csrc/**/*.h', 'torch/csrc/**/*.hpp', 'torch/csrc/**/*.cpp', diff --git a/BUILD.bazel b/BUILD.bazel index c563c52d861e67..10c065f5084c7e 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -744,7 +744,6 @@ cc_library( "torch/csrc/cuda/python_nccl.cpp", "torch/csrc/cuda/nccl.cpp", "torch/csrc/distributed/c10d/intra_node_comm.cu", - "torch/csrc/distributed/c10d/CUDASymmetricMemory.cu", "torch/csrc/distributed/c10d/Utils.cu", "torch/csrc/distributed/c10d/quantization/quantization_gpu.cu", ], diff --git a/build_variables.bzl b/build_variables.bzl index 793b611a0a6f07..ceb28707897e56 100644 --- a/build_variables.bzl +++ b/build_variables.bzl @@ -501,7 +501,6 @@ libtorch_distributed_base_sources = [ "torch/csrc/distributed/c10d/ProcessGroupMPI.cpp", "torch/csrc/distributed/c10d/ProcessGroupWrapper.cpp", "torch/csrc/distributed/c10d/Store.cpp", - "torch/csrc/distributed/c10d/SymmetricMemory.cpp", "torch/csrc/distributed/c10d/TCPStore.cpp", "torch/csrc/distributed/c10d/TCPStoreBackend.cpp", "torch/csrc/distributed/c10d/TCPStoreLibUvBackend.cpp", @@ -685,7 +684,6 @@ libtorch_cuda_distributed_extra_sources = [ "torch/csrc/distributed/c10d/UCCUtils.cpp", "torch/csrc/distributed/c10d/intra_node_comm.cpp", "torch/csrc/distributed/c10d/intra_node_comm.cu", - "torch/csrc/distributed/c10d/CUDASymmetricMemory.cu", "torch/csrc/distributed/c10d/Utils.cu", "torch/csrc/distributed/rpc/tensorpipe_cuda.cpp", "torch/csrc/distributed/c10d/quantization/quantization_gpu.cu", diff --git a/c10/cuda/driver_api.h b/c10/cuda/driver_api.h index cbbdf16823ec76..43bcbd1d70bace 100644 --- a/c10/cuda/driver_api.h +++ b/c10/cuda/driver_api.h @@ -18,17 +18,14 @@ } \ } while (0) -#define C10_LIBCUDA_DRIVER_API(_) \ - _(cuMemAddressReserve) \ - _(cuMemRelease) \ - _(cuMemMap) \ - _(cuMemAddressFree) \ - _(cuMemSetAccess) \ - _(cuMemUnmap) \ - _(cuMemCreate) \ - _(cuMemGetAllocationGranularity) \ - _(cuMemExportToShareableHandle) \ - _(cuMemImportFromShareableHandle) \ +#define C10_LIBCUDA_DRIVER_API(_) \ + _(cuMemAddressReserve) \ + _(cuMemRelease) \ + _(cuMemMap) \ + _(cuMemAddressFree) \ + _(cuMemSetAccess) \ + _(cuMemUnmap) \ + _(cuMemCreate) \ _(cuGetErrorString) #define C10_NVML_DRIVER_API(_) \ diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt index 8426741609fe7f..89c31fab113473 100644 --- a/caffe2/CMakeLists.txt +++ b/caffe2/CMakeLists.txt @@ -560,7 +560,6 @@ if(USE_CUDA) append_filelist("libtorch_cuda_distributed_extra_sources" Caffe2_GPU_SRCS) set_source_files_properties( ${TORCH_SRC_DIR}/csrc/distributed/c10d/intra_node_comm.cpp - ${TORCH_SRC_DIR}/csrc/distributed/c10d/CUDASymmetricMemory.cu PROPERTIES COMPILE_FLAGS "-DPYTORCH_C10_DRIVER_API_SUPPORTED=1" ) endif() diff --git a/test/distributed/test_symmetric_memory.py b/test/distributed/test_symmetric_memory.py deleted file mode 100644 index a768e059044f79..00000000000000 --- a/test/distributed/test_symmetric_memory.py +++ /dev/null @@ -1,156 +0,0 @@ -# Owner(s): ["module: c10d"] - -import torch - -import torch.distributed as dist -from torch._C._distributed_c10d import _SymmetricMemory -from torch.distributed.distributed_c10d import _get_process_group_store - -from torch.testing._internal.common_distributed import ( - MultiProcessTestCase, - skip_if_lt_x_gpu, -) -from torch.testing._internal.common_utils import ( - instantiate_parametrized_tests, - run_tests, - skip_but_pass_in_sandcastle_if, - skipIfRocm, -) - - -def requires_cuda_p2p_access(): - cuda_p2p_access_available = ( - torch.cuda.is_available() and torch.cuda.device_count() >= 2 - ) - num_devices = torch.cuda.device_count() - for i in range(num_devices - 1): - for j in range(i + 1, num_devices): - if not torch.cuda.can_device_access_peer(i, j): - cuda_p2p_access_available = False - break - if not cuda_p2p_access_available: - break - - return skip_but_pass_in_sandcastle_if( - not cuda_p2p_access_available, - "cuda p2p access is not available", - ) - - -@instantiate_parametrized_tests -@requires_cuda_p2p_access() -class SymmetricMemoryTest(MultiProcessTestCase): - def setUp(self) -> None: - super().setUp() - self._spawn_processes() - - @property - def world_size(self) -> int: - return 2 - - @property - def device(self) -> torch.device: - return torch.device(f"cuda:{self.rank}") - - def _init_process(self): - torch.cuda.set_device(self.device) - store = dist.FileStore(self.file_name, self.world_size) - dist.init_process_group( - backend="nccl", - world_size=self.world_size, - rank=self.rank, - store=store, - ) - _SymmetricMemory.set_group_info( - "0", - self.rank, - self.world_size, - _get_process_group_store(dist.GroupMember.WORLD), - ) - - def _verify_symmetric_memory(self, symm_mem): - self.assertEqual(symm_mem.world_size, 2) - - buf = symm_mem.get_buffer(0, (64, 64), torch.float32) - if symm_mem.rank == 0: - symm_mem.wait_signal(src_rank=1) - self.assertTrue(buf.eq(42).all()) - else: - buf.fill_(42) - symm_mem.put_signal(dst_rank=0) - - symm_mem.barrier() - - if symm_mem.rank == 0: - symm_mem.barrier() - self.assertTrue(buf.eq(43).all()) - else: - buf.fill_(43) - symm_mem.barrier() - - symm_mem.barrier() - - @skipIfRocm - @skip_if_lt_x_gpu(2) - def test_empty_strided_p2p(self) -> None: - self._init_process() - - shape = (64, 64) - stride = (64, 1) - dtype = torch.float32 - device = self.device - group_name = "0" - alloc_args = (shape, stride, dtype, device, group_name) - - t = torch.empty(shape, dtype=dtype, device=device) - with self.assertRaises(RuntimeError): - _SymmetricMemory.rendezvous(t) - - t = _SymmetricMemory.empty_strided_p2p(*alloc_args) - symm_mem = _SymmetricMemory.rendezvous(t) - - del t - self._verify_symmetric_memory(symm_mem) - - @skipIfRocm - @skip_if_lt_x_gpu(2) - def test_empty_strided_p2p_persistent(self) -> None: - self._init_process() - - shape = (64, 64) - stride = (64, 1) - dtype = torch.float32 - device = self.device - alloc_id = 42 # Persistent allocation - group_name = "0" - alloc_args = (shape, stride, dtype, device, group_name, alloc_id) - - t = _SymmetricMemory.empty_strided_p2p(*alloc_args) - data_ptr = t.data_ptr() - - # Verify that persistent allocation would fail if there's an active - # allocation with the same alloc_id. - with self.assertRaises(RuntimeError): - _SymmetricMemory.empty_strided_p2p(*alloc_args) - - # Verify that persistent allocation would succeed in lieu of activate - # allocations with the same alloc_id, and the returned tensor would - # have the same data pointer. - del t - t = _SymmetricMemory.empty_strided_p2p(*alloc_args) - self.assertEqual(t.data_ptr(), data_ptr) - - # Verify that get_symmetric_memory would fail if called before - # rendezvous. - with self.assertRaises(RuntimeError): - _SymmetricMemory.get_symmetric_memory(t) - - symm_mem_0 = _SymmetricMemory.rendezvous(t) - symm_mem_1 = _SymmetricMemory.get_symmetric_memory(t) - self.assertEqual(id(symm_mem_0), id(symm_mem_1)) - - self._verify_symmetric_memory(symm_mem_0) - - -if __name__ == "__main__": - run_tests() diff --git a/torch/_C/_distributed_c10d.pyi b/torch/_C/_distributed_c10d.pyi index 0095b5af434b5c..cffbf22219c8e7 100644 --- a/torch/_C/_distributed_c10d.pyi +++ b/torch/_C/_distributed_c10d.pyi @@ -637,33 +637,3 @@ class ProcessGroupCudaP2P(Backend): storage_offset: Optional[int] = 0, ) -> torch.Tensor: ... def _shutdown(self) -> None: ... - -class _SymmetricMemory: - @staticmethod - def set_group_info( - group_name: str, rank: int, world_size: int, store: Store - ) -> None: ... - @staticmethod - def empty_strided_p2p( - size: torch.types._size, - stride: torch.types._size, - dtype: torch.dtype, - device: torch.device, - group_name: str, - ) -> torch.Tensor: ... - @property - def rank(self) -> int: ... - @property - def world_size(self) -> int: ... - @staticmethod - def rendezvous(tensor: torch.Tensor) -> _SymmetricMemory: ... - def get_buffer( - self, - rank: int, - sizes: torch.Size, - dtype: torch.dtype, - storage_offset: Optional[int] = 0, - ) -> torch.Tensor: ... - def barrier(self, channel: int = 0) -> None: ... - def put_signal(self, dst_rank: int, channel: int = 0) -> None: ... - def wait_signal(self, src_rank: int, channel: int = 0) -> None: ... diff --git a/torch/csrc/distributed/c10d/CUDASymmetricMemory.cu b/torch/csrc/distributed/c10d/CUDASymmetricMemory.cu deleted file mode 100644 index f27db85f7ff85d..00000000000000 --- a/torch/csrc/distributed/c10d/CUDASymmetricMemory.cu +++ /dev/null @@ -1,539 +0,0 @@ -#include - -#include -#include -#include -#include - -#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED) -#include -#endif - -#include -#include - -namespace { - -constexpr size_t signal_pad_size = 2048; -const std::string store_comm_prefix = "CUDASymmetricMemory"; - -static size_t store_comm_seq_id = 0; - -template -std::vector store_all_gather( - const c10::intrusive_ptr& store, - int rank, - int world_size, - T val) { - static_assert(std::is_trivially_copyable_v); - - std::vector peer_keys; - for (int r = 0; r < world_size; ++r) { - std::ostringstream oss; - oss << store_comm_prefix << "/" << store_comm_seq_id << "/" << r; - peer_keys.push_back(oss.str()); - } - ++store_comm_seq_id; - - { - std::vector payload( - reinterpret_cast(&val), - reinterpret_cast(&val) + sizeof(T)); - store->set(peer_keys[rank], payload); - } - - std::vector peer_vals; - for (int r = 0; r < world_size; ++r) { - if (r == rank) { - peer_vals.push_back(val); - continue; - } - store->wait({peer_keys[r]}); - auto payload = store->get(peer_keys[r]); - TORCH_CHECK(payload.size() == sizeof(T)); - T peer_val{}; - std::memcpy(&peer_val, payload.data(), sizeof(T)); - peer_vals.push_back(peer_val); - } - return peer_vals; -} - -void store_barrier( - const c10::intrusive_ptr& store, - int rank, - int world_size) { - store_all_gather(store, rank, world_size, 0); -} - -int import_remote_fd(int pid, int fd) { -#if defined(SYS_pidfd_open) and defined(SYS_pidfd_getfd) - int pidfd = syscall(SYS_pidfd_open, pid, 0); - return syscall(SYS_pidfd_getfd, pidfd, fd, 0); -#else - TORCH_CHECK( - false, - "CUDASymmetricMemory requires pidfd_open ", - "and pidfd_getfd support"); -#endif -} - -void map_block( - void** ptr, - c10d::symmetric_memory::HandleType handle, - size_t size, - int device_idx) { -#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED) - auto driver_api = c10::cuda::DriverAPI::get(); - auto dev_ptr = reinterpret_cast(ptr); - C10_CUDA_DRIVER_CHECK( - driver_api->cuMemAddressReserve_(dev_ptr, size, 0ULL, 0, 0ULL)); - C10_CUDA_DRIVER_CHECK(driver_api->cuMemMap_(*dev_ptr, size, 0, handle, 0ULL)); - - CUmemAccessDesc desc; - desc.location.type = CU_MEM_LOCATION_TYPE_DEVICE; - // NOLINTNEXTLINE(bugprone-signed-char-misuse) - desc.location.id = static_cast(device_idx); - desc.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE; - C10_CUDA_DRIVER_CHECK(driver_api->cuMemSetAccess_(*dev_ptr, size, &desc, 1)); -#else - TORCH_CHECK( - false, "CUDASymmetricMemory requires PYTORCH_C10_DRIVER_API_SUPPORTED"); -#endif -} - -} // namespace - -namespace c10d { -namespace symmetric_memory { - -CUDASymmetricMemory::CUDASymmetricMemory( - std::vector handles, - size_t block_size, - std::vector buffers, - std::vector signal_pads, - size_t buffer_size, - int local_device_idx, - int rank, - int world_size) - : handles_(std::move(handles)), - block_size_(block_size), - buffers_(std::move(buffers)), - signal_pads_(std::move(signal_pads)), - buffer_size_(buffer_size), - local_device_idx_(local_device_idx), - rank_(rank), - world_size_(world_size) { - const size_t arr_size = sizeof(void*) * world_size_; - buffers_dev_ = reinterpret_cast( - c10::cuda::CUDACachingAllocator::raw_alloc(arr_size)); - signal_pads_dev_ = reinterpret_cast( - c10::cuda::CUDACachingAllocator::raw_alloc(arr_size)); - - c10::cuda::CUDAGuard guard(local_device_idx); - AT_CUDA_CHECK(cudaMemcpy( - buffers_dev_, buffers_.data(), arr_size, cudaMemcpyHostToDevice)); - AT_CUDA_CHECK(cudaMemcpy( - signal_pads_dev_, signal_pads_.data(), arr_size, cudaMemcpyHostToDevice)); -} - -CUDASymmetricMemory::~CUDASymmetricMemory() { -#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED) - c10::cuda::CUDAGuard guard(local_device_idx_); - C10_CUDA_CHECK(cudaDeviceSynchronize()); - - auto driver_api = c10::cuda::DriverAPI::get(); - for (int r = 0; r < world_size_; ++r) { - C10_CUDA_DRIVER_CHECK(driver_api->cuMemUnmap_( - reinterpret_cast(buffers_[r]), block_size_)); - C10_CUDA_DRIVER_CHECK(driver_api->cuMemRelease_(handles_[r])); - } - c10::cuda::CUDACachingAllocator::raw_delete(buffers_dev_); - c10::cuda::CUDACachingAllocator::raw_delete(signal_pads_dev_); -#else - TORCH_CHECK( - false, "CUDASymmetricMemory requires PYTORCH_C10_DRIVER_API_SUPPORTED"); -#endif -} - -std::vector CUDASymmetricMemory::get_buffer_ptrs() { - return buffers_; -} - -std::vector CUDASymmetricMemory::get_signal_pad_ptrs() { - return signal_pads_; -} - -void** CUDASymmetricMemory::get_buffer_ptrs_dev() { - return buffers_dev_; -} - -void** CUDASymmetricMemory::get_signal_pad_ptrs_dev() { - return signal_pads_dev_; -} - -size_t CUDASymmetricMemory::get_buffer_size() { - return buffer_size_; -} - -size_t CUDASymmetricMemory::get_signal_pad_size() { - return signal_pad_size; -} - -at::Tensor CUDASymmetricMemory::get_buffer( - int rank, - c10::IntArrayRef sizes, - c10::ScalarType dtype, - int64_t storage_offset) { - const auto numel = - std::accumulate(sizes.begin(), sizes.end(), 1, std::multiplies()); - const auto element_size = c10::elementSize(dtype); - const auto req_size = (numel + storage_offset) * element_size; - TORCH_CHECK( - req_size <= buffer_size_, - "CUDASymmetricMemory::get_buffer: the requested size (", - req_size, - " bytes) exceeds the allocated size (", - buffer_size_, - " bytes)"); - auto device = c10::Device(c10::DeviceType::CUDA, local_device_idx_); - auto options = at::TensorOptions().dtype(dtype).device(device); - return at::for_blob(buffers_[rank], sizes) - .storage_offset(storage_offset) - .options(options) - .target_device(device) - .make_tensor(); -} - -void check_channel(int channel, int world_size) { - TORCH_CHECK( - channel >= 0, - "channel for barrier(), put_signal() and wait_signal() ", - "must be greater than 0 (got ", - channel, - ")"); - const size_t num_channels = signal_pad_size / sizeof(uint32_t) * world_size; - TORCH_CHECK( - static_cast(channel) < num_channels, - "The maximum supported channel for barrier(), put_signal() and wait_signal() is ", - num_channels - 1, - " (got ", - channel, - ")"); -} - -__device__ __forceinline__ void release_signal(uint32_t* addr) { -#if defined(USE_ROCM) || (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)) - CUDA_KERNEL_ASSERT(false); -#else - volatile uint32_t* signal = addr; - uint32_t val; - do { - val = *signal; - } while (val != 0 || atomicCAS_system(addr, 0, 1) != 0); -#endif -} - -__device__ __forceinline__ void acquire_signal(uint32_t* addr) { -#if defined(USE_ROCM) || (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)) - CUDA_KERNEL_ASSERT(false); -#else - volatile uint32_t* signal = addr; - uint32_t val; - do { - val = *signal; - } while (val != 1 || atomicCAS_system(addr, 1, 0) != 1); -#endif -} - -static __global__ void barrier_kernel( - uint32_t** signal_pads, - int channel, - int rank, - int world_size) { - if (threadIdx.x < world_size) { - auto target_rank = threadIdx.x; - release_signal(signal_pads[target_rank] + world_size * channel + rank); - acquire_signal(signal_pads[rank] + world_size * channel + target_rank); - } -} - -void CUDASymmetricMemory::barrier(int channel) { - check_channel(channel, world_size_); - c10::cuda::CUDAGuard guard(local_device_idx_); - barrier_kernel<<<1, C10_WARP_SIZE, 0, at::cuda::getCurrentCUDAStream()>>>( - reinterpret_cast(signal_pads_dev_), - channel, - rank_, - world_size_); - C10_CUDA_KERNEL_LAUNCH_CHECK(); -} - -static __global__ void put_signal_kernel( - uint32_t** signal_pads, - int dst_rank, - int channel, - int rank, - int world_size) { - if (threadIdx.x == 0) { - release_signal(signal_pads[dst_rank] + world_size * channel + rank); - } -} - -void CUDASymmetricMemory::put_signal(int dst_rank, int channel) { - check_channel(channel, world_size_); - c10::cuda::CUDAGuard guard(local_device_idx_); - put_signal_kernel<<<1, C10_WARP_SIZE, 0, at::cuda::getCurrentCUDAStream()>>>( - reinterpret_cast(signal_pads_dev_), - dst_rank, - channel, - rank_, - world_size_); - C10_CUDA_KERNEL_LAUNCH_CHECK(); -} - -static __global__ void wait_signal_kernel( - uint32_t** signal_pads, - int src_rank, - int channel, - int rank, - int world_size) { - if (threadIdx.x == 0) { - acquire_signal(signal_pads[rank] + world_size * channel + src_rank); - } - __threadfence_system(); -} - -void CUDASymmetricMemory::wait_signal(int src_rank, int channel) { - check_channel(channel, world_size_); - c10::cuda::CUDAGuard guard(local_device_idx_); - wait_signal_kernel<<<1, C10_WARP_SIZE, 0, at::cuda::getCurrentCUDAStream()>>>( - reinterpret_cast(signal_pads_dev_), - src_rank, - channel, - rank_, - world_size_); - C10_CUDA_KERNEL_LAUNCH_CHECK(); -} - -int CUDASymmetricMemory::get_rank() { - return rank_; -} - -int CUDASymmetricMemory::get_world_size() { - return world_size_; -} - -void* CUDASymmetricMemoryAllocator::alloc( - size_t size, - int device_idx, - const std::string& group_name) { -#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED) - auto driver_api = c10::cuda::DriverAPI::get(); - - CUmemAllocationProp prop = {}; - prop.type = CU_MEM_ALLOCATION_TYPE_PINNED; - prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE; - // NOLINTNEXTLINE(bugprone-signed-char-misuse) - prop.location.id = device_idx; - prop.requestedHandleTypes = CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR; - - size_t signal_pad_offset = at::round_up(size, 16UL); - size_t block_size = signal_pad_offset + signal_pad_size; - - size_t granularity; - C10_CUDA_DRIVER_CHECK(driver_api->cuMemGetAllocationGranularity_( - &granularity, &prop, CU_MEM_ALLOC_GRANULARITY_RECOMMENDED)); - block_size = at::round_up(block_size, granularity); - - HandleType handle; - C10_CUDA_DRIVER_CHECK( - driver_api->cuMemCreate_(&handle, block_size, &prop, 0)); - - void* ptr = nullptr; - map_block(&ptr, handle, block_size, device_idx); - - c10::cuda::CUDAGuard guard(device_idx); - AT_CUDA_CHECK(cudaMemset(ptr, 0, block_size)); - - auto block = c10::make_intrusive( - handle, device_idx, block_size, size, signal_pad_offset, group_name); - { - std::unique_lock lock(mutex_); - ptr_to_block_.emplace(ptr, std::move(block)); - } - return ptr; -#else - TORCH_CHECK( - false, "CUDASymmetricMemory requires PYTORCH_C10_DRIVER_API_SUPPORTED"); -#endif -} - -void CUDASymmetricMemoryAllocator::free(void* ptr) { -#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED) - auto block = find_block(ptr); - if (block == nullptr) { - return; - } - // Initializing CUDASymmetricMemory with an allocation transfers its - // ownership to the CUDASymmetricMemory object. - if (block->symm_mem == nullptr) { - auto driver_api = c10::cuda::DriverAPI::get(); - C10_CUDA_DRIVER_CHECK(driver_api->cuMemUnmap_( - reinterpret_cast(ptr), block->block_size)); - C10_CUDA_DRIVER_CHECK(driver_api->cuMemRelease_(block->handle)); - } - { - std::unique_lock lock(mutex_); - ptr_to_block_.erase(ptr); - } -#else - TORCH_CHECK( - false, "CUDASymmetricMemory requires PYTORCH_C10_DRIVER_API_SUPPORTED"); -#endif -} - -size_t CUDASymmetricMemoryAllocator::get_alloc_size(void* ptr) { - auto block = find_block(ptr); - TORCH_CHECK( - block != nullptr, - "CUDASymmetricMemoryAllocator::get_alloc_size: input must be allocated ", - "via CUDASymmetricMemoryAllocator::alloc"); - return block->buffer_size; -} - -struct RendezvousRequest { - int device_idx; - int block_fd; - int pid; - size_t block_size; - size_t buffer_size; - size_t signal_pad_offset; -}; - -void validate_rendezvous_requests( - const std::vector reqs, - int world_size) { - TORCH_CHECK(reqs.size() == (size_t)world_size); - - std::unordered_set device_indices; - device_indices.reserve(world_size); - for (auto req : reqs) { - device_indices.insert(req.device_idx); - } - if (device_indices.size() < (size_t)world_size) { - TORCH_CHECK( - false, - "CUDASymmetricMemoryAllocator::rendezvous: ", - "detected allocations from overlapping devices ", - "from different ranks."); - } - - for (int r = 1; r < world_size; ++r) { - TORCH_CHECK(reqs[r].block_size == reqs[0].block_size); - TORCH_CHECK(reqs[r].buffer_size == reqs[0].buffer_size); - TORCH_CHECK(reqs[r].signal_pad_offset == reqs[0].signal_pad_offset); - } -} - -c10::intrusive_ptr CUDASymmetricMemoryAllocator::rendezvous( - void* ptr) { -#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED) - auto block = find_block(ptr); - TORCH_CHECK( - block != nullptr, - "CUDASymmetricMemoryAllocator::rendezvous: input must be allocated ", - "via CUDASymmetricMemoryAllocator::alloc"); - - if (block->symm_mem != nullptr) { - return block->symm_mem; - } - - auto group_info = get_group_info(block->group_name); - auto driver_api = c10::cuda::DriverAPI::get(); - int block_fd; - C10_CUDA_DRIVER_CHECK(driver_api->cuMemExportToShareableHandle_( - &block_fd, block->handle, CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR, 0)); - - auto local_req = RendezvousRequest{ - .device_idx = block->device_idx, - .block_fd = block_fd, - .pid = getpid(), - .block_size = block->block_size, - .buffer_size = block->buffer_size, - .signal_pad_offset = block->signal_pad_offset}; - auto reqs = store_all_gather( - group_info.store, group_info.rank, group_info.world_size, local_req); - validate_rendezvous_requests(reqs, group_info.world_size); - - std::vector handles(group_info.world_size); - std::vector buffers(group_info.world_size, nullptr); - std::vector signal_pads(group_info.world_size, nullptr); - for (int r = 0; r < group_info.world_size; ++r) { - if (r == group_info.rank) { - handles[r] = block->handle; - buffers[r] = ptr; - signal_pads[r] = (void*)((uintptr_t)ptr + block->signal_pad_offset); - continue; - } - int imported_fd = import_remote_fd(reqs[r].pid, reqs[r].block_fd); - C10_CUDA_DRIVER_CHECK(driver_api->cuMemImportFromShareableHandle_( - &handles[r], - (void*)(uintptr_t)imported_fd, - CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR)); - map_block(&buffers[r], handles[r], block->block_size, block->device_idx); - signal_pads[r] = (void*)((uintptr_t)buffers[r] + block->signal_pad_offset); - close(imported_fd); - } - store_barrier(group_info.store, group_info.rank, group_info.world_size); - close(block_fd); - - // Initializing CUDASymmetricMemory with an allocation transfers its - // ownership to the CUDASymmetricMemory object. So that outstanding - // references to the CUDASymmetricMemory object can keep the allocation - // alive. - block->symm_mem = c10::make_intrusive( - std::move(handles), - block->block_size, - std::move(buffers), - std::move(signal_pads), - block->buffer_size, - block->device_idx, - group_info.rank, - group_info.world_size); - return block->symm_mem; -#else - TORCH_CHECK( - false, "CUDASymmetricMemory requires PYTORCH_C10_DRIVER_API_SUPPORTED"); -#endif -} - -bool CUDASymmetricMemoryAllocator::is_rendezvous_completed(void* ptr) { - auto block = find_block(ptr); - TORCH_CHECK( - block != nullptr, - "CUDASymmetricMemoryAllocator::is_rendezvous_completed: input must be allocated ", - "via CUDASymmetricMemoryAllocator::alloc"); - return block->symm_mem != nullptr; -} - -c10::intrusive_ptr CUDASymmetricMemoryAllocator::find_block(void* ptr) { - std::shared_lock lock(mutex_); - auto it = ptr_to_block_.find(ptr); - if (it == ptr_to_block_.end()) { - return nullptr; - } - return it->second; -} - -struct RegisterCUDASymmetricMemoryAllocator { - RegisterCUDASymmetricMemoryAllocator() { - register_allocator( - c10::DeviceType::CUDA, - c10::make_intrusive()); - } -}; - -static RegisterCUDASymmetricMemoryAllocator register_allocator_; - -} // namespace symmetric_memory -} // namespace c10d diff --git a/torch/csrc/distributed/c10d/CUDASymmetricMemory.cuh b/torch/csrc/distributed/c10d/CUDASymmetricMemory.cuh deleted file mode 100644 index 0e0e40a6bd0910..00000000000000 --- a/torch/csrc/distributed/c10d/CUDASymmetricMemory.cuh +++ /dev/null @@ -1,109 +0,0 @@ -#pragma once - -#include -#include -#include - -namespace c10d { -namespace symmetric_memory { - -#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED) -using HandleType = CUmemGenericAllocationHandle; -#else -using HandleType = void*; -#endif - -class CUDASymmetricMemory : public SymmetricMemory { - public: - CUDASymmetricMemory( - std::vector handles, - size_t block_size, - std::vector buffers, - std::vector signal_pads, - size_t buffer_size, - int local_device_idx, - int rank, - int world_size); - - ~CUDASymmetricMemory() override; - - std::vector get_buffer_ptrs() override; - std::vector get_signal_pad_ptrs() override; - void** get_buffer_ptrs_dev() override; - void** get_signal_pad_ptrs_dev() override; - size_t get_buffer_size() override; - size_t get_signal_pad_size() override; - - at::Tensor get_buffer( - int rank, - c10::IntArrayRef sizes, - c10::ScalarType dtype, - int64_t storage_offset) override; - - void barrier(int channel) override; - void put_signal(int dst_rank, int channel) override; - void wait_signal(int src_rank, int channel) override; - - int get_rank() override; - int get_world_size() override; - - private: - std::vector handles_; - size_t block_size_; - std::vector buffers_; - std::vector signal_pads_; - size_t buffer_size_; - int local_device_idx_; - int rank_; - int world_size_; - void** buffers_dev_; - void** signal_pads_dev_; - std::optional> finalizer_; -}; - -struct Block : public c10::intrusive_ptr_target { - HandleType handle; - int device_idx; - size_t block_size; - size_t buffer_size; - size_t signal_pad_offset; - std::string group_name; - c10::intrusive_ptr symm_mem = nullptr; - - Block( - HandleType handle, - int device_idx, - size_t block_size, - size_t buffer_size, - size_t signal_pad_offset, - const std::string& group_name) - : handle(handle), - device_idx(device_idx), - block_size(block_size), - buffer_size(buffer_size), - signal_pad_offset(signal_pad_offset), - group_name(group_name), - symm_mem(nullptr) {} -}; - -class CUDASymmetricMemoryAllocator : public SymmetricMemoryAllocator { - public: - void* alloc( - size_t size, - int device_idx, - const std::string& group_name) override; - - void free(void *ptr) override; - size_t get_alloc_size(void* ptr) override; - c10::intrusive_ptr rendezvous(void* ptr) override; - bool is_rendezvous_completed(void* ptr) override; - - private: - c10::intrusive_ptr find_block(void* ptr); - - std::shared_mutex mutex_; - std::unordered_map> ptr_to_block_; -}; - -} // namespace symmetric_memory -} // namespace c10d diff --git a/torch/csrc/distributed/c10d/ProcessGroupCudaP2P.hpp b/torch/csrc/distributed/c10d/ProcessGroupCudaP2P.hpp index 7c41414c4e4e17..cff4ad09b70648 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupCudaP2P.hpp +++ b/torch/csrc/distributed/c10d/ProcessGroupCudaP2P.hpp @@ -10,7 +10,6 @@ constexpr auto kProcessGroupCudaP2PDefaultTimeout = namespace c10d { -// NOTE: this class will be be removed soon in favor of SymmetricMemory class TORCH_API ProcessGroupCudaP2P : public Backend { public: struct Options : Backend::Options { diff --git a/torch/csrc/distributed/c10d/SymmetricMemory.cpp b/torch/csrc/distributed/c10d/SymmetricMemory.cpp deleted file mode 100644 index b3d9f31bb03420..00000000000000 --- a/torch/csrc/distributed/c10d/SymmetricMemory.cpp +++ /dev/null @@ -1,189 +0,0 @@ -#include - -namespace { - -using namespace c10d::symmetric_memory; - -class AllocatorMap { - public: - static AllocatorMap& get() { - static AllocatorMap instance; - return instance; - } - - void register_allocator( - c10::DeviceType device_type, - c10::intrusive_ptr allocator) { - map_[device_type] = std::move(allocator); - } - - c10::intrusive_ptr get_allocator( - c10::DeviceType device_type) { - auto it = map_.find(device_type); - TORCH_CHECK( - it != map_.end(), - "SymmetricMemory does not support device type ", - device_type); - return it->second; - } - - ~AllocatorMap() { - for (auto& it : map_) { - it.second.release(); - } - } - - private: - AllocatorMap() = default; - AllocatorMap(const AllocatorMap&) = delete; - AllocatorMap& operator=(const AllocatorMap&) = delete; - - std::unordered_map< - c10::DeviceType, - c10::intrusive_ptr> - map_; -}; - -static std::unordered_map group_info_map{}; - -// Data structures for tracking persistent allocations -static std::unordered_map alloc_id_to_dev_ptr{}; -static std::unordered_map> - alloc_id_to_storage{}; - -static at::Tensor empty_strided_p2p_persistent( - c10::IntArrayRef size, - c10::IntArrayRef stride, - c10::ScalarType dtype, - c10::Device device, - const std::string& group_name, - uint64_t alloc_id) { - // Make the allocation fails if a previous allocation with the same alloc_id - // is still active. - auto storage = alloc_id_to_storage.find(alloc_id); - if (storage != alloc_id_to_storage.end() && storage->second.use_count() > 0) { - TORCH_CHECK( - false, - "SymmetricMemory::empty_strided_p2p_persistent: ", - "can not allocate with alloc_id == ", - alloc_id, - " because a previous allocation with the same alloc_id " - "is still active."); - } - - const size_t numel = - std::accumulate(size.begin(), size.end(), 1, std::multiplies()); - const size_t element_size = c10::elementSize(dtype); - const size_t alloc_size = numel * element_size; - - auto allocator = get_allocator(device.type()); - void* dev_ptr = nullptr; - if (alloc_id_to_dev_ptr.find(alloc_id) != alloc_id_to_dev_ptr.end()) { - dev_ptr = alloc_id_to_dev_ptr[alloc_id]; - TORCH_CHECK( - alloc_size == allocator->get_alloc_size(dev_ptr), - "SymmetricMemory::empty_strided_p2p_persistent: ", - "requested allocation size (", - alloc_size, - ") is different from the size of a previous allocation ", - "with the same alloc_id ", - allocator->get_alloc_size(dev_ptr)); - } else { - dev_ptr = allocator->alloc(alloc_size, device.index(), group_name); - alloc_id_to_dev_ptr[alloc_id] = dev_ptr; - } - - auto options = at::TensorOptions().dtype(dtype).device(device); - auto allocated = at::from_blob(dev_ptr, size, stride, options); - - // Track the allocation's activeness - alloc_id_to_storage.erase(alloc_id); - alloc_id_to_storage.emplace( - alloc_id, allocated.storage().getWeakStorageImpl()); - return allocated; -} - -} // namespace - -namespace c10d { -namespace symmetric_memory { - -void register_allocator( - c10::DeviceType device_type, - c10::intrusive_ptr allocator) { - return AllocatorMap::get().register_allocator( - device_type, std::move(allocator)); -} - -c10::intrusive_ptr get_allocator( - c10::DeviceType device_type) { - return AllocatorMap::get().get_allocator(device_type); -} - -void set_group_info( - const std::string& group_name, - int rank, - int world_size, - c10::intrusive_ptr store) { - TORCH_CHECK(group_info_map.find(group_name) == group_info_map.end()); - GroupInfo group_info; - group_info.rank = rank; - group_info.world_size = world_size; - group_info.store = std::move(store); - group_info_map.emplace(group_name, std::move(group_info)); -} - -const GroupInfo& get_group_info(const std::string& group_name) { - TORCH_CHECK( - group_info_map.find(group_name) != group_info_map.end(), - "get_group_info: no group info associated with the group name ", - group_name); - return group_info_map[group_name]; -} - -at::Tensor empty_strided_p2p( - c10::IntArrayRef size, - c10::IntArrayRef stride, - c10::ScalarType dtype, - c10::Device device, - const std::string& group_name, - std::optional alloc_id) { - if (alloc_id.has_value()) { - return empty_strided_p2p_persistent( - size, stride, dtype, device, group_name, *alloc_id); - } - const size_t numel = - std::accumulate(size.begin(), size.end(), 1, std::multiplies()); - const size_t element_size = c10::elementSize(dtype); - const size_t alloc_size = numel * element_size; - - auto allocator = get_allocator(device.type()); - void* dev_ptr = allocator->alloc(alloc_size, device.index(), group_name); - - auto options = at::TensorOptions().dtype(dtype).device(device); - return at::from_blob( - dev_ptr, - size, - stride, - [allocator = std::move(allocator)](void* ptr) { allocator->free(ptr); }, - options); -} - -TORCH_API c10::intrusive_ptr rendezvous( - const at::Tensor& tensor) { - auto allocator = get_allocator(tensor.device().type()); - return allocator->rendezvous(tensor.data_ptr()); -} - -c10::intrusive_ptr get_symmetric_memory( - const at::Tensor& tensor) { - auto allocator = get_allocator(tensor.device().type()); - TORCH_CHECK( - allocator->is_rendezvous_completed(tensor.data_ptr()), - "SymmetricMemory: must invoke rendezvous on a tensor ", - "before calling get_symmetric_memory on it"); - return allocator->rendezvous(tensor.data_ptr()); -} - -} // namespace symmetric_memory -} // namespace c10d diff --git a/torch/csrc/distributed/c10d/SymmetricMemory.hpp b/torch/csrc/distributed/c10d/SymmetricMemory.hpp deleted file mode 100644 index 344b86ea5c7e3a..00000000000000 --- a/torch/csrc/distributed/c10d/SymmetricMemory.hpp +++ /dev/null @@ -1,152 +0,0 @@ -#pragma once - -#include -#include - -namespace c10d { -namespace symmetric_memory { - -// SymmetricMemory represents symmetric allocations across a group of devices. -// The allocations represented by a SymmetricMemory object are accessible by -// all devices in the group. The class can be used for op-level custom -// communication patterns (via the get_buffer APIs and the synchronization -// primitives), as well as custom communication kernels (via the buffer and -// signal_pad device pointers). -// -// To acquire a SymmetricMemory object, each rank first allocates -// identical-sized memory via SymmetricMemoryAllocator::alloc(), then invokes -// SymmetricMemoryAllocator::rendezvous() on the memory to establish the -// association across peer buffers. The rendezvous is a one-time process, and -// the mapping between a local memory memory and the associated SymmetricMemory -// object is unique. -// -// NOTE [symmetric memory signal pad] -// Signal pads are P2P-accessible memory regions designated for -// synchronization. SymmetricMemory offers built-in synchronization primitives -// such as barriers, put_signal, and wait_signal, which are all based on signal -// pads. Users may utilize signal pads for their own synchronization logic, -// provided that the signal pads remain zero-filled following successful -// synchronization. -// -// NOTE [symmetric memory synchronization channel] -// Synchronization channels allow users to use a single SymmetricMemory object -// to perform isolated synchronizations on different streams. For example, -// consider the case in which two barriers are issued on two streams for -// different purposes. Without the concept of channels, we cannot guarantee the -// correctness of the barriers since signals issued from barrier on stream A -// can be received by the barrier on stream B. By specifying different channels -// for these two barriers, they can operate correctly in parallel. -class TORCH_API SymmetricMemory : public c10::intrusive_ptr_target { - public: - virtual ~SymmetricMemory() {} - - virtual std::vector get_buffer_ptrs() = 0; - virtual std::vector get_signal_pad_ptrs() = 0; - - // get_buffer_ptrs_dev() and get_signal_pad_ptrs_dev() each return a pointer - // to a device array of size world_size, containing buffer pointers and - // signal pad pointers, respectively. - virtual void** get_buffer_ptrs_dev() = 0; - virtual void** get_signal_pad_ptrs_dev() = 0; - virtual size_t get_buffer_size() = 0; - virtual size_t get_signal_pad_size() = 0; - - virtual at::Tensor get_buffer( - int rank, - c10::IntArrayRef sizes, - c10::ScalarType dtype, - int64_t storage_offset) = 0; - - virtual void barrier(int channel) = 0; - virtual void put_signal(int dst_rank, int channel) = 0; - virtual void wait_signal(int src_rank, int channel) = 0; - - virtual int get_rank() = 0; - virtual int get_world_size() = 0; -}; - -class SymmetricMemoryAllocator : public c10::intrusive_ptr_target { - public: - virtual ~SymmetricMemoryAllocator(){}; - - virtual void* alloc( - size_t size, - int device_idx, - const std::string& group_name) = 0; - - virtual void free(void* ptr) = 0; - virtual size_t get_alloc_size(void* ptr) = 0; - virtual c10::intrusive_ptr rendezvous(void* ptr) = 0; - virtual bool is_rendezvous_completed(void* ptr) = 0; -}; - -C10_EXPORT void register_allocator( - c10::DeviceType device_type, - c10::intrusive_ptr allocator); - -C10_EXPORT c10::intrusive_ptr get_allocator( - c10::DeviceType device_type); - -// Set a store for rendezvousing symmetric allocations on a group of devices -// identified by `group_name`. The concept of groups is logical; users can -// utilize predefined groups (e.g., a group of device identified by a -// ProcessGroup) or create custom ones. Note that a SymmetricMemoryAllocator -// backends might employ a more efficient communication channel for the actual -// rendezvous process and only use the store for bootstrapping purposes. -TORCH_API void set_group_info( - const std::string& group_name, - int rank, - int world_size, - c10::intrusive_ptr store); - -struct GroupInfo { - int rank; - int world_size; - c10::intrusive_ptr store; -}; - -C10_EXPORT const GroupInfo& get_group_info(const std::string& group_name); - -// Identical to empty_strided, but allows symmetric memory access to be -// established for the allocated tensor via SymmetricMemory::rendezvous(). This -// function itself is not a collective operation. It invokes -// SymmetricMemoryAllocator::alloc() for the requested device under the hood. -// -// NOTE [symmetric memory persistent allocation] -// If an `alloc_id` is supplied, empty_strided_p2p will perform persistent -// allocation. This makes the function cache allocated memory and ensure that -// invocations with the same `alloc_id` receive tensors backed by the same -// memory address. For safety, if a previous persistent allocation is still -// active (i.e., the storage of the returned tensor is still alive), persistent -// allocations with the same `alloc_id` will fail. This determinism coupled -// with memory planning of communication buffers (e.g., by Inductor) allows -// communication algorithms to reliably reuse previously established remote -// memory access. -TORCH_API at::Tensor empty_strided_p2p( - c10::IntArrayRef size, - c10::IntArrayRef stride, - c10::ScalarType dtype, - c10::Device device, - const std::string& group_name, - std::optional alloc_id); - -// Establishes symmetric memory access on tensors allocated via -// empty_strided_p2p() and empty_strided_p2p_persistent(). rendezvous() is a -// one-time process, and the mapping between a local memory region and the -// associated SymmetricMemory object is unique. Subsequent calls to -// rendezvous() with the same tensor, or tensors allocated with -// empty_strided_p2p_persistent() using the same alloc_id, will receive the -// cached SymmetricMemory object. -// -// The function has a collective semantic and must be invoked simultaneously -// from all rendezvous participants. -TORCH_API c10::intrusive_ptr rendezvous( - const at::Tensor& tensor); - -// Returns the SymmetricMemory object associated with the tensor. It can only -// be invoked after rendezvous() but does not need to be invoked collectively. -TORCH_API c10::intrusive_ptr get_symmetric_memory( - const at::Tensor& tensor); - -} // namespace symmetric_memory -} // namespace c10d diff --git a/torch/csrc/distributed/c10d/init.cpp b/torch/csrc/distributed/c10d/init.cpp index db5778efcf3547..6f1b28886b989b 100644 --- a/torch/csrc/distributed/c10d/init.cpp +++ b/torch/csrc/distributed/c10d/init.cpp @@ -41,7 +41,6 @@ #include #include #include -#include #include #include @@ -976,44 +975,6 @@ This class does not support ``__members__`` property.)"); "global_ranks_in_group", &::c10d::DistributedBackendOptions::global_ranks_in_group); - using SymmetricMemory = ::c10d::symmetric_memory::SymmetricMemory; - py::class_>( - module, "_SymmetricMemory") - .def_static("set_group_info", &::c10d::symmetric_memory::set_group_info) - .def_static( - "empty_strided_p2p", - ::c10d::symmetric_memory::empty_strided_p2p, - py::arg("size"), - py::arg("stride"), - py::arg("dtype"), - py::arg("device"), - py::arg("group_name"), - py::arg("alloc_id") = py::none()) - .def_static("rendezvous", &::c10d::symmetric_memory::rendezvous) - .def_static( - "get_symmetric_memory", - &::c10d::symmetric_memory::get_symmetric_memory) - .def_property_readonly("rank", &SymmetricMemory::get_rank) - .def_property_readonly("world_size", &SymmetricMemory::get_world_size) - .def( - "get_buffer", - &SymmetricMemory::get_buffer, - py::arg("rank"), - py::arg("sizes"), - py::arg("dtype"), - py::arg("storage_offset") = 0) - .def("barrier", &SymmetricMemory::barrier, py::arg("channel") = 0) - .def( - "put_signal", - &SymmetricMemory::put_signal, - py::arg("dst_rank"), - py::arg("channel") = 0) - .def( - "wait_signal", - &SymmetricMemory::wait_signal, - py::arg("src_rank"), - py::arg("channel") = 0); - auto store = py::class_<::c10d::Store, c10::intrusive_ptr<::c10d::Store>, PythonStore>( module, diff --git a/torch/csrc/distributed/c10d/intra_node_comm.cpp b/torch/csrc/distributed/c10d/intra_node_comm.cpp index 9d7ba5abf951dd..85136a91e02564 100644 --- a/torch/csrc/distributed/c10d/intra_node_comm.cpp +++ b/torch/csrc/distributed/c10d/intra_node_comm.cpp @@ -218,8 +218,23 @@ IntraNodeComm::~IntraNodeComm() { if (!isInitialized_) { return; } - auto allocator = get_allocator(c10::DeviceType::CUDA); - allocator->free(symmetricMemoryPtr_); + // Intentionally releasing resources without synchronizing devices. The + // teardown logic is safe for propoerly sync'd user program. We don't want + // improperly sync'd user program to hang here. + for (size_t r = 0; r < worldSize_; ++r) { + if (r == rank_) { + continue; + } + AT_CUDA_CHECK(cudaIpcCloseMemHandle(p2pStates_[r])); + AT_CUDA_CHECK(cudaIpcCloseMemHandle(buffers_[r])); + } + AT_CUDA_CHECK(cudaFree(p2pStates_[rank_])); + AT_CUDA_CHECK(cudaFree(buffers_[rank_])); + if (topoInfo_ != nullptr) { + AT_CUDA_CHECK(cudaFree(topoInfo_)); + } + AT_CUDA_CHECK(cudaFree(p2pStatesDev_)); + AT_CUDA_CHECK(cudaFree(buffersDev_)); } bool IntraNodeComm::isEnabled() { @@ -329,19 +344,83 @@ bool IntraNodeComm::rendezvous() { // Detect topology Topology topology = detectTopology(nvlMesh, worldSize_); - set_group_info("IntraNodeComm", rank_, worldSize_, store_); - auto allocator = get_allocator(c10::DeviceType::CUDA); - symmetricMemoryPtr_ = - allocator->alloc(bufferSize_, deviceIdx, "IntraNodeComm"); - symmetricMemory_ = allocator->rendezvous(symmetricMemoryPtr_); - TORCH_CHECK(symmetricMemory_->get_signal_pad_size() >= kP2pStateSize); + // Initialize p2p state + auto p2pState = initP2pState(); + + // Allocate buffer + void* buffer = nullptr; + AT_CUDA_CHECK(cudaMalloc(&buffer, bufferSize_)); + + // Second handshake: exchange topology and CUDA IPC handles + struct IpcInfo { + NvlMesh nvlMesh; + Topology topology; + cudaIpcMemHandle_t p2pStateHandle, bufferHandle; + }; + + // Make p2p state and buffer available for IPC + cudaIpcMemHandle_t p2pStateHandle, bufferHandle; + AT_CUDA_CHECK(cudaIpcGetMemHandle(&p2pStateHandle, p2pState)); + AT_CUDA_CHECK(cudaIpcGetMemHandle(&bufferHandle, buffer)); + + IpcInfo ipcInfo{ + .nvlMesh = nvlMesh, + .topology = topology, + .p2pStateHandle = p2pStateHandle, + .bufferHandle = bufferHandle}; + + auto peerIpcInfos = + storeAllGather(store_, "handshake-1", rank_, worldSize_, ipcInfo); + + for (const auto& info : peerIpcInfos) { + if (!isSame(info.nvlMesh, peerIpcInfos.front().nvlMesh) || + info.topology != peerIpcInfos.front().topology) { + LOG(WARNING) << "Aborting IntraNodeComm::rendezvous because some " + "participants are observing different topologies (" + << int(info.topology) << " and " << int(topology) << ")"; + AT_CUDA_CHECK(cudaFree(p2pState)); + AT_CUDA_CHECK(cudaFree(buffer)); + return false; + } + } + + std::array p2pStates = {}, buffers = {}; + for (size_t r = 0; r < peerIpcInfos.size(); ++r) { + if (r == rank_) { + p2pStates[r] = p2pState; + buffers[r] = buffer; + } else { + AT_CUDA_CHECK(cudaIpcOpenMemHandle( + &p2pStates[r], + peerIpcInfos[r].p2pStateHandle, + cudaIpcMemLazyEnablePeerAccess)); + AT_CUDA_CHECK(cudaIpcOpenMemHandle( + &buffers[r], + peerIpcInfos[r].bufferHandle, + cudaIpcMemLazyEnablePeerAccess)); + } + } + void* p2pStatesDev = nullptr; + AT_CUDA_CHECK(cudaMalloc(&p2pStatesDev, sizeof(p2pStates))); + AT_CUDA_CHECK(cudaMemcpy( + p2pStatesDev, + p2pStates.data(), + sizeof(p2pStates), + cudaMemcpyHostToDevice)); + + void* buffersDev = nullptr; + AT_CUDA_CHECK(cudaMalloc(&buffersDev, sizeof(buffers))); + AT_CUDA_CHECK(cudaMemcpy( + buffersDev, buffers.data(), sizeof(buffers), cudaMemcpyHostToDevice)); void* topoInfo = initTopoInfo(topology, nvlMesh, rank_); isInitialized_ = true; topology_ = topology; - p2pStatesDev_ = symmetricMemory_->get_signal_pad_ptrs_dev(); - buffersDev_ = symmetricMemory_->get_buffer_ptrs_dev(); + std::copy(p2pStates.begin(), p2pStates.end(), p2pStates_.begin()); + std::copy(buffers.begin(), buffers.end(), buffers_.begin()); + p2pStatesDev_ = p2pStatesDev; + buffersDev_ = buffersDev; topoInfo_ = topoInfo; return true; #endif diff --git a/torch/csrc/distributed/c10d/intra_node_comm.cu b/torch/csrc/distributed/c10d/intra_node_comm.cu index ac751ff7be1e09..51fc6252d2235b 100644 --- a/torch/csrc/distributed/c10d/intra_node_comm.cu +++ b/torch/csrc/distributed/c10d/intra_node_comm.cu @@ -132,8 +132,6 @@ struct P2pState { uint32_t signals1[kMaxAllReduceBlocks][kMaxDevices]; }; -static_assert(sizeof(P2pState) <= kP2pStateSize); - template static __global__ void oneShotAllReduceKernel( at::BFloat16* input, @@ -524,7 +522,7 @@ at::Tensor IntraNodeComm::oneShotAllReduce( const bool fuseInputCopy = isAligned && blocks.x < kMaxAllReduceBlocks; if (!fuseInputCopy) { AT_CUDA_CHECK(cudaMemcpyAsync( - symmetricMemory_->get_buffer_ptrs_dev()[rank_], + buffers_[rank_], input.data_ptr(), input.numel() * input.element_size(), cudaMemcpyDeviceToDevice, @@ -584,7 +582,7 @@ at::Tensor IntraNodeComm::twoShotAllReduce( at::cuda::OptionalCUDAGuard guard(input.get_device()); AT_CUDA_CHECK(cudaMemcpyAsync( - symmetricMemory_->get_buffer_ptrs_dev()[rank_], + buffers_[rank_], input.data_ptr(), input.numel() * input.element_size(), cudaMemcpyDeviceToDevice, @@ -634,7 +632,7 @@ at::Tensor IntraNodeComm::hybridCubeMeshAllReduce( at::cuda::OptionalCUDAGuard guard(input.get_device()); AT_CUDA_CHECK(cudaMemcpyAsync( - symmetricMemory_->get_buffer_ptrs_dev()[rank_], + buffers_[rank_], input.data_ptr(), input.numel() * input.element_size(), cudaMemcpyDeviceToDevice, @@ -757,7 +755,15 @@ at::Tensor IntraNodeComm::getBuffer( const std::vector& sizes, c10::ScalarType dtype, int64_t storageOffset) { - return symmetricMemory_->get_buffer(rank, sizes, dtype, storageOffset); + const auto numel = std::accumulate(sizes.begin(), sizes.end(), 0); + const auto elementSize = c10::elementSize(dtype); + TORCH_CHECK((numel + storageOffset) * elementSize <= bufferSize_); + auto options = at::TensorOptions().dtype(dtype).device( + at::kCUDA, at::cuda::current_device()); + return at::for_blob(buffers_[rank], sizes) + .storage_offset(storageOffset) + .options(options) + .make_tensor(); } } // namespace intra_node_comm diff --git a/torch/csrc/distributed/c10d/intra_node_comm.hpp b/torch/csrc/distributed/c10d/intra_node_comm.hpp index a67df5c34586a0..5d7e2d426d30a1 100644 --- a/torch/csrc/distributed/c10d/intra_node_comm.hpp +++ b/torch/csrc/distributed/c10d/intra_node_comm.hpp @@ -4,16 +4,12 @@ #include #include #include -#include #include namespace c10d::intra_node_comm { -using namespace c10d::symmetric_memory; - constexpr size_t kMaxDevices = 8; constexpr size_t kDefaultBufferSize = 10ull * 1024 * 1024; -constexpr size_t kP2pStateSize = 2048; using NvlMesh = std::array, kMaxDevices>; using HybridCubeMesh = std::array, kMaxDevices>; @@ -31,7 +27,6 @@ enum class AllReduceAlgo : uint8_t { HCM = 3 }; -// NOTE: this class will be be removed soon in favor of SymmetricMemory class TORCH_API IntraNodeComm : public c10::intrusive_ptr_target { public: IntraNodeComm( @@ -102,8 +97,8 @@ class TORCH_API IntraNodeComm : public c10::intrusive_ptr_target { */ bool isInitialized_ = false; Topology topology_ = Topology::UNKNOWN; - void* symmetricMemoryPtr_ = nullptr; - c10::intrusive_ptr symmetricMemory_ = nullptr; + std::array p2pStates_{}; + std::array buffers_{}; void* p2pStatesDev_{}; void* buffersDev_{}; void* topoInfo_{}; From 1877b7896c237567285804ecc138bc86180a7ced Mon Sep 17 00:00:00 2001 From: soulitzer Date: Tue, 18 Jun 2024 07:49:05 -0700 Subject: [PATCH 140/171] [checkpoint] Clean up selective activation checkpoint and make public (#125795) ### bc-breaking for existing users of the private API: - Existing policy functions must now change their return value to be [CheckpointPolicy](https://github.com/pytorch/pytorch/blob/c0b40ab42e38a208351911496b7153511304f8da/torch/utils/checkpoint.py#L1204-L1230) Enum instead of bool. - To restore previous behavior, return `PREFER_RECOMPUTE` instead of `False` and `{PREFER,MUST}_SAVE` instead of `True` depending whether you prefer the compiler to override your policy. - Policy function now accepts a `ctx` object instead of `mode` for its first argument. - To restore previous behavior, `mode = "recompute" if ctx.is_recompute else "forward"`. - Existing calls to `_pt2_selective_checkpoint_context_fn_gen` must be renamed to `create_selective_checkpoint_contexts `. The way you use the API remains the same. It would've been nice to do something different (not make the user have to use functools.partial?), but this was the easiest to compile (idk if this should actually be a constraint). Related doc: https://docs.google.com/document/d/1BKyizkZPdri9mHqdDOLAUpkI7SbbKfLHRFVVpK9ZWqo/edit Memory considerations: - As with the existing SAC, cached values are cleared upon first use. - We error if the user wishes to backward a second time on a region forwarded with SAC enabled. In-place: - We use version counting to enforce that if any cached tensor has been mutated. In-place operations not mutating cached tensors are allowed. - `allow_cache_entry_mutation=True` can be passed to disable this check (useful in the case of auto AC where the user is cleverly also saves the output of the in-place) Randomness, views - Currently in this PR, we don't do anything special for randomness or views, the author of the policy function is expected to handle them properly. (Would it would be beneficial to error? - we either want to save all or recompute all random tensors) Tensor object preservation - ~We guarantee that if a tensor does not requires grad, and it is saved, then what you get out is the same tensor object.~ UPDATE: We guarantee that if a tensor is of non-differentiable dtype AND it is not a view, and it is saved, then what you get out is the same tensor object. This is a nice guarantee for nested tensors which care about the object identity of of the offsets tensor. Policy function - Enum values are `{MUST,PREFER}_{SAVE,RECOMPUTE}` (bikeshed welcome). Alternatively there was `{SAVE,RECOMPUTE}_{NON_,}OVERRIDABLE`. The former was preferred bc it seemed clearer that two `MUST` clashing should error, versus it is ambiguous whether two `NON_OVERRIDABLE` being stacked should silently ignore or error. - The usage of Enum today. There actually is NO API to stack SAC policies today. The only thing the Enum should matter for in the near term is the compiler. The stacking SAC policy would be useful if someone wants to implement something like simple FSDP, but it is not perfect because with a policy of `PREFER_SAVE` you are actually saving more than autograd would save normally (would be fixed with AC v3). - The number of times we call the policy_fn is something that should be documented as part of public API. We call the policy function for all ops except ~~detach~~ UPDATE : metadata ops listed in `torch.utils.checkpoint.SAC_IGNORED_OPS`) because these ops may be called a different number of times by AC itself between forward and recompute. - The policy function can be a stateful object (we do NOT make separate copies of this object for forward/recompute, the user is expected to handle that via is_recompute see below). Tensors guaranteed to be the same tensor as-is - Policy function signature takes ctx object as its first argument. The ctx function is an object encapsulating info that may be useful to the user, it currently only holds "is_recompute". Adding this indirection gives us flexibility to add more attrs later if necessary. Pull Request resolved: https://github.com/pytorch/pytorch/pull/125795 Approved by: https://github.com/Chillee, https://github.com/fmassa --- docs/source/checkpoint.rst | 3 + test/dynamo/test_activation_checkpointing.py | 27 +- test/test_autograd.py | 416 ++++++++++++++++++- torch/_higher_order_ops/wrap.py | 6 +- torch/utils/checkpoint.py | 316 +++++++++----- 5 files changed, 643 insertions(+), 125 deletions(-) diff --git a/docs/source/checkpoint.rst b/docs/source/checkpoint.rst index f7bc160fa98bd2..8559d8bd73663c 100644 --- a/docs/source/checkpoint.rst +++ b/docs/source/checkpoint.rst @@ -35,3 +35,6 @@ torch.utils.checkpoint .. autofunction:: checkpoint .. autofunction:: checkpoint_sequential .. autofunction:: set_checkpoint_debug_enabled +.. autoclass:: CheckpointPolicy +.. autoclass:: SelectiveCheckpointContext +.. autofunction:: create_selective_checkpoint_contexts diff --git a/test/dynamo/test_activation_checkpointing.py b/test/dynamo/test_activation_checkpointing.py index 14851e51895b40..274e033028451a 100644 --- a/test/dynamo/test_activation_checkpointing.py +++ b/test/dynamo/test_activation_checkpointing.py @@ -19,7 +19,11 @@ from torch.testing._internal.common_utils import IS_WINDOWS, skipIfRocm from torch.testing._internal.inductor_utils import HAS_CUDA from torch.testing._internal.two_tensor import TwoTensor -from torch.utils.checkpoint import _pt2_selective_checkpoint_context_fn_gen, checkpoint +from torch.utils.checkpoint import ( + checkpoint, + CheckpointPolicy, + create_selective_checkpoint_contexts, +) requires_cuda = unittest.skipUnless(HAS_CUDA, "requires cuda") requires_distributed = functools.partial( @@ -105,8 +109,11 @@ def op_count(gm): def _get_custom_policy(no_recompute_list=None): - def _custom_policy(mode, func, *args, **kwargs): - return func in no_recompute_list + def _custom_policy(ctx, func, *args, **kwargs): + if func in no_recompute_list: + return CheckpointPolicy.MUST_SAVE + else: + return CheckpointPolicy.PREFER_RECOMPUTE return _custom_policy @@ -530,7 +537,7 @@ def selective_checkpointing_context_fn(): no_recompute_list = [ torch.ops.aten.mm.default, ] - return _pt2_selective_checkpoint_context_fn_gen( + return create_selective_checkpoint_contexts( _get_custom_policy(no_recompute_list=no_recompute_list) ) @@ -580,7 +587,7 @@ def selective_checkpointing_context_fn(): no_recompute_list = [ torch.ops.aten.mm.default, ] - return _pt2_selective_checkpoint_context_fn_gen( + return create_selective_checkpoint_contexts( _get_custom_policy(no_recompute_list=no_recompute_list) ) @@ -650,7 +657,7 @@ def _custom_policy(mode, func, *args, **kwargs): def selective_checkpointing_context_fn(): meta = {} - return _pt2_selective_checkpoint_context_fn_gen(_get_custom_policy(meta)) + return create_selective_checkpoint_contexts(_get_custom_policy(meta)) def gn(x, y): return torch.sigmoid( @@ -698,7 +705,7 @@ def fn(x, y): ) def test_compile_selective_checkpoint_partial_ctx_fn(self): def selective_checkpointing_context_fn(no_recompute_list): - return _pt2_selective_checkpoint_context_fn_gen( + return create_selective_checkpoint_contexts( _get_custom_policy(no_recompute_list=no_recompute_list) ) @@ -751,7 +758,7 @@ def selective_checkpointing_context_fn(): torch.ops.aten.mm.default, torch.ops.aten.sigmoid.default, ] - return _pt2_selective_checkpoint_context_fn_gen( + return create_selective_checkpoint_contexts( _get_custom_policy(no_recompute_list=no_recompute_list), ) @@ -803,7 +810,7 @@ def selective_checkpointing_context_fn(): torch.ops.aten.mm.default, torch.ops.aten.sigmoid.default, ] - return _pt2_selective_checkpoint_context_fn_gen( + return create_selective_checkpoint_contexts( _get_custom_policy(no_recompute_list=no_recompute_list) ) @@ -854,7 +861,7 @@ def selective_checkpointing_context_fn(): no_recompute_list = [ torch.ops.aten.sigmoid.default, ] - return _pt2_selective_checkpoint_context_fn_gen( + return create_selective_checkpoint_contexts( _get_custom_policy(no_recompute_list=no_recompute_list) ) diff --git a/test/test_autograd.py b/test/test_autograd.py index c133ae95b4b3da..e45f5d47c69253 100644 --- a/test/test_autograd.py +++ b/test/test_autograd.py @@ -2,6 +2,7 @@ import collections import contextlib +import functools import gc import io import math @@ -79,8 +80,14 @@ ) from torch.utils._mode_utils import no_dispatch from torch.utils._python_dispatch import TorchDispatchMode -from torch.utils.checkpoint import checkpoint, checkpoint_sequential +from torch.utils.checkpoint import ( + checkpoint, + checkpoint_sequential, + CheckpointPolicy, + create_selective_checkpoint_contexts, +) from torch.utils.cpp_extension import load_inline +from torch.utils.flop_counter import FlopCounterMode from torch.utils.hooks import RemovableHandle # noqa: TCH001 @@ -13215,6 +13222,413 @@ def fn2(x): self.assertEqual(counter[0], 1) +class TestSelectiveActivationCheckpoint(TestCase): + @unittest.skipIf(not TEST_CUDA, "requires CUDA") + def test_flops_and_mem(self): + # From https://github.com/pytorch/pytorch/pull/126320 + def get_act_mem(f): + out = f() + out.backward() + # Why do one forward and backward? + start_mem = torch.cuda.memory_stats()["requested_bytes.all.current"] + out = f() + cur_mem = torch.cuda.memory_stats()["requested_bytes.all.current"] + act_mem = (cur_mem - start_mem) / (1024 * 1024) + out.backward() + return act_mem + + def get_bw_flops(f): + # Normalized so that a 512 square matmul returns 1 + f().backward() + out = f() + # NB: FlopCounterMode is pushed onto the mode stack before CachedMode, so + # it will be able to observe whether an op is cached or not. + with FlopCounterMode(display=False) as mode: + out.backward() + return mode.get_total_flops() / (512**3 * 2) + + x = torch.randn(512, 512, requires_grad=True, device="cuda") + y = torch.randn(512, 512, requires_grad=True, device="cuda") + + def fn(x, y): + return torch.mm(x.cos(), y).sin().sum() + + def fn_ac(x, y): + return checkpoint(fn, x, y, use_reentrant=False) + + def fn_sac(x, y): + context_fn = functools.partial( + create_selective_checkpoint_contexts, + [ + torch.ops.aten.mm.default, + ], + ) + out = checkpoint(fn, x, y, use_reentrant=False, context_fn=context_fn) + return out + + def policy_fn(ctx, op, *args, **kwargs): + if op == torch.ops.aten.mm.default: + return CheckpointPolicy.MUST_SAVE + else: + return CheckpointPolicy.PREFER_RECOMPUTE + + def fn_sac2(x, y): + context_fn = functools.partial( + create_selective_checkpoint_contexts, + policy_fn, + ) + out = checkpoint(fn, x, y, use_reentrant=False, context_fn=context_fn) + return out + + act_mem_noac = get_act_mem(lambda: fn(x, y)) + bw_flops_noac = get_bw_flops(lambda: fn(x, y)) + + self.assertEqual(act_mem_noac, 2.0) + self.assertEqual(bw_flops_noac, 2.0) + + act_mem_ac = get_act_mem(lambda: fn_ac(x, y)) + bw_flops_ac = get_bw_flops(lambda: fn_ac(x, y)) + + self.assertEqual(act_mem_ac, 0.0) + self.assertEqual(bw_flops_ac, 3.0) + + act_mem_sac = get_act_mem(lambda: fn_sac(x, y)) + bw_flops_sac = get_bw_flops(lambda: fn_sac(x, y)) + + self.assertEqual(act_mem_sac, 1.0) + self.assertEqual(bw_flops_sac, 2.0) + + act_mem_sac2 = get_act_mem(lambda: fn_sac2(x, y)) + bw_flops_sac2 = get_bw_flops(lambda: fn_sac2(x, y)) + + self.assertEqual(act_mem_sac2, 1.0) + self.assertEqual(bw_flops_sac2, 2.0) + + @skipIfTorchDynamo("compile tested in test/dynamo/test_activation_checkpointing.py") + def test_output_already_has_autograd_meta(self): + # View of tensor of non-differentiable dtype still has AutogradMeta + def fn(x, y): + return x.view(-1), y.sin().cos() + + x = torch.tensor([1, 2, 3], dtype=torch.int64) + y = torch.randn(3, requires_grad=True) + + context_fn = functools.partial( + create_selective_checkpoint_contexts, + [ + torch.ops.aten.view.default, + ], + ) + out = checkpoint(fn, x, y, use_reentrant=False, context_fn=context_fn) + out[1].sum().backward() + + @skipIfTorchDynamo("compile tested in test/dynamo/test_activation_checkpointing.py") + def test_subclass_dispatching_sizes(self): + # Test that we ignore ops that grab metadata like torch.ops.aten.sym_size.default + # Caching such metadata ops can be problematic when the following are satisfied: + # + # 1. size/strides are dispatched upon + # 2. our policy saves sizes + ta = torch.randn(6, 2) + + class CustomSizeDynamicShapesTensor(torch.Tensor): + @staticmethod + def __new__(cls, inner): + return torch.Tensor._make_wrapper_subclass( + # TODO: right now, _make_wrapper_subclass's dynamic shape interaction is not great. + # Calling the overload that has kwargs causes us to go down the first overload path, + # which will **always** specialize sizes. + # We should probably eventually fix this so that the first overload can just handle dynamic shapes. + cls, + inner.size(), + inner.stride(), + None, + None, + inner.dtype, + inner.layout, + inner.device, + False, + inner.requires_grad, + "sizes", + ) + + def __init__(self, inner): + self.inner = inner + + @classmethod + def __torch_dispatch__(cls, func, types, args, kwargs): + if kwargs is None: + kwargs = {} + args_inner = torch.utils._pytree.tree_map_only( + cls, lambda x: x.inner, args + ) + out_inner = func(*args_inner, **kwargs) + return torch.utils._pytree.tree_map_only( + torch.Tensor, lambda x: cls(x), out_inner + ) + + def policy_fn(ctx, op, *args, **kwargs): + if op is torch.ops.aten.sym_size.default: + # Silently ignored! + return CheckpointPolicy.MUST_SAVE + else: + return CheckpointPolicy.PREFER_RECOMPUTE + + def fn(x): + # We avoid the following case + # + # saved :[4, 3], [], [], [4, 3], [4, 3], [4, 3], [12] + # forward :sum ,sum,mul, mul , mul ,view , view + # recompute :sum ,sum,mul, view , view + # + # Views save the shape of their input, so we expect the second + # view to save 12, but because during AC packing during forward + # saves the shapes of the input for metadata checks later, + # we would save the wrong shape during the recompute. + view_out = (x * x.sum()).view(-1).view(4, 3) + self.assertEqual(view_out.grad_fn._saved_self_sym_sizes, [12]) + return view_out.exp() + + x = torch.randn(4, 3, requires_grad=True) + x_wrapper = CustomSizeDynamicShapesTensor(x) + context_fn = functools.partial(create_selective_checkpoint_contexts, policy_fn) + out = checkpoint(fn, x_wrapper, use_reentrant=False, context_fn=context_fn) + out.sum().backward() + + def test_bad_inputs(self): + bad_op_list1 = [2] + + with self.assertRaisesRegex( + ValueError, "Expected op in `op_list` to be an OpOverload" + ): + create_selective_checkpoint_contexts(bad_op_list1) + + bad_op_list2 = [torch.ops.aten.sin] + + with self.assertRaisesRegex( + ValueError, "update the OpOverloadPacket to a specific OpOverload" + ): + create_selective_checkpoint_contexts(bad_op_list2) + + with self.assertRaisesRegex(TypeError, "either a function or a list of ops."): + create_selective_checkpoint_contexts(2) + + # Dynamo fails for various reasons: + # - some tests using custom op that does not implement Fake + # - dynamo is trying to trace into saved variable hooks unpack hook for some reason + @skipIfTorchDynamo("compile tested in test/dynamo/test_activation_checkpointing.py") + def test_policy_with_state(self): + # If I have a stateful callable, state is shared between the original + # forward and the recompute. + counters = [] + + class Policy: + def __init__(self): + self.counter = [0] + self.recompute_counter = [0] + + def __call__(self, ctx, func, *args, **kwargs): + counter = self.recompute_counter if ctx.is_recompute else self.counter + counter[0] += 1 + counters.append(counter[0]) + if counter == 1 and func is torch.ops.aten.mm.default: + return CheckpointPolicy.MUST_SAVE + return CheckpointPolicy.PREFER_RECOMPUTE + + def fn(x): + return x.sin().sin().sin() + + x = torch.randn(3, requires_grad=True) + context_fn = functools.partial( + create_selective_checkpoint_contexts, + Policy(), + allow_cache_entry_mutation=True, + ) + out = checkpoint(fn, x, use_reentrant=False, context_fn=context_fn) + out.sum().backward() + # 1. counter properly reset to 0 for the recompute + # 2. due to early-stop we do not recompute the final op + self.assertEqual(counters, [1, 2, 3, 1, 2]) + + @skipIfTorchDynamo("compile tested in test/dynamo/test_activation_checkpointing.py") + def test_storage_lifetime(self): + from torch.utils._python_dispatch import _get_current_dispatch_mode + from torch.utils.checkpoint import ( + _CachedTorchDispatchMode, + _CachingTorchDispatchMode, + ) + + def policy_fn(ctx, op, *args, **kwargs): + return CheckpointPolicy.MUST_SAVE + + ref = None + + def fn(x): + nonlocal ref + + self.assertIsInstance( + _get_current_dispatch_mode(), + (_CachingTorchDispatchMode, _CachedTorchDispatchMode), + ) + + out = x.cos().exp() + + if isinstance(_get_current_dispatch_mode(), _CachingTorchDispatchMode): + raw_val = ( + _get_current_dispatch_mode() + .storage[torch.ops.aten.exp.default][0] + .val + ) + # ref should've been detached + # to avoid graph -> the saved variable hooks -> recompute_context -> storage -> graph + self.assertFalse(raw_val.requires_grad) + ref = weakref.ref(raw_val) + + # Careful for early-stop + return out.sin() + + with disable_gc(): + # Case 1: If graph goes away without backward, make sure there's no reference cycle + # keeping storage alive. + x = torch.randn(3, requires_grad=True) + context_fn = functools.partial( + create_selective_checkpoint_contexts, policy_fn + ) + out = checkpoint(fn, x, use_reentrant=False, context_fn=context_fn) + self.assertIsNotNone(ref()) + del out + self.assertIsNone(ref()) + + # Case 2: After backward, even if retain_graph=True, the storage should go away + x = torch.randn(3, requires_grad=True) + context_fn = functools.partial( + create_selective_checkpoint_contexts, policy_fn + ) + out = checkpoint(fn, x, use_reentrant=False, context_fn=context_fn) + self.assertIsNotNone(ref()) + out.sum().backward(retain_graph=True) + # The dispatch mode's storage should still be alive, but the entries should've + # been cleared. + self.assertIsNone(ref()) + + @skipIfTorchDynamo("compile tested in test/dynamo/test_activation_checkpointing.py") + def test_version_counter(self): + def policy_fn(ctx, op, *args, **kwargs): + if op == torch.ops.aten.sin.default: + return CheckpointPolicy.MUST_SAVE + else: + return CheckpointPolicy.PREFER_RECOMPUTE + + def fn(x): + return x.sin().mul_(2).cos().exp() + + x = torch.randn(3, requires_grad=True) + context_fn = functools.partial(create_selective_checkpoint_contexts, policy_fn) + out = checkpoint(fn, x, use_reentrant=False, context_fn=context_fn) + + # 1) Error because the output of sin is saved and mutated by mul_ + with self.assertRaisesRegex(RuntimeError, "has been mutated"): + out.sum().backward() + + x = torch.randn(3, requires_grad=True) + context_fn = functools.partial( + create_selective_checkpoint_contexts, + policy_fn, + allow_cache_entry_mutation=True, + ) + out = checkpoint(fn, x, use_reentrant=False, context_fn=context_fn) + + # 2) No longer should be an error because of allow_cache_entry_mutation + out.sum().backward() + + @skipIfTorchDynamo("compile tested in test/dynamo/test_activation_checkpointing.py") + def test_function_with_more_than_one_output(self): + # maybe there is a more systematic way: + counter = [0] + + def policy_fn(ctx, op, *args, **kwargs): + if op == torch.ops.aten.var_mean.correction: + counter[0] += 1 + return CheckpointPolicy.MUST_SAVE + else: + return CheckpointPolicy.PREFER_RECOMPUTE + + # var_mean has two outputs + def fn(x): + a, b = torch.var_mean(x) + return a * b + + x = torch.randn(3, requires_grad=True) + context_fn = functools.partial(create_selective_checkpoint_contexts, policy_fn) + out = checkpoint(fn, x, use_reentrant=False, context_fn=context_fn) + x_grad = torch.autograd.grad(out.sum(), (x,)) + x_grad_ref = torch.autograd.grad(fn(x).sum(), (x,)) + self.assertEqual(x_grad, x_grad_ref) + self.assertEqual(counter[0], 2) + + @skipIfTorchDynamo("compile tested in test/dynamo/test_activation_checkpointing.py") + def test_function_with_non_tensor_output(self): + # When SAC is enabled, the op is not computed a second time + with torch.library._scoped_library("mylib", "FRAGMENT") as lib: + counter = [0] + + @torch.library.custom_op("mylib::sin_with_extra", mutates_args=()) + def sin_with_extra(x: torch.Tensor) -> Tuple[torch.Tensor, int]: + counter[0] += 1 + return x.sin(), 2 + + def setup_context(ctx, inputs, output) -> torch.Tensor: + (x,) = inputs + ctx.save_for_backward(x) + + def backward(ctx, grad, _unused): + (x,) = ctx.saved_tensors + return grad * x.cos() + + torch.library.register_autograd( + "mylib::sin_with_extra", backward, setup_context=setup_context + ) + + x = torch.randn(3, requires_grad=True) + + def fn(x): + return (torch.ops.mylib.sin_with_extra(x)[0] * x.sin().exp()).sin() + + ops_list = [torch.ops.mylib.sin_with_extra.default] + + x = torch.randn(3, requires_grad=True) + context_fn = functools.partial( + create_selective_checkpoint_contexts, ops_list + ) + out = checkpoint(fn, x, use_reentrant=False, context_fn=context_fn) + x_grad = torch.autograd.grad(out.sum(), (x,)) + self.assertEqual(counter[0], 1) + x_grad_ref = torch.autograd.grad(fn(x).sum(), (x,)) + self.assertEqual(x_grad, x_grad_ref) + + @skipIfTorchDynamo("compile tested in test/dynamo/test_activation_checkpointing.py") + def test_can_only_trigger_recompute_once(self): + # We don't support this to avoid adding extra complexity for now. + # If there's a need, we could probably do some kind of use_count tracking. + # TODO: have a nice error message here. + def policy_fn(ctx, op, *args, **kwargs): + if op == torch.ops.aten.sin.default: + return CheckpointPolicy.MUST_SAVE + else: + return CheckpointPolicy.PREFER_RECOMPUTE + + def fn(x): + return x.sin().cos().exp() + + x = torch.randn(3, requires_grad=True) + context_fn = functools.partial(create_selective_checkpoint_contexts, policy_fn) + out = checkpoint(fn, x, use_reentrant=False, context_fn=context_fn) + out.sum().backward(retain_graph=True) + + with self.assertRaisesRegex(RuntimeError, "Trying to backward an extra time"): + out.sum().backward(retain_graph=True) + + class TestAutogradMultipleDispatch(TestCase): def test_autograd_multiple_dispatch_registrations(self, device): t = torch.randn(3, 3, device=device, requires_grad=True) diff --git a/torch/_higher_order_ops/wrap.py b/torch/_higher_order_ops/wrap.py index 6d83a44e752a06..e7fe553387d1c8 100644 --- a/torch/_higher_order_ops/wrap.py +++ b/torch/_higher_order_ops/wrap.py @@ -1,15 +1,17 @@ # mypy: allow-untyped-defs import inspect +import itertools import logging import torch from torch._ops import HigherOrderOperator -from torch.utils.checkpoint import checkpoint, uid +from torch.utils.checkpoint import checkpoint + import torch._dynamo.config log = logging.getLogger(__name__) - +uid = itertools.count(1) # Used for testing the HigherOrderOperator mechanism class Wrap(HigherOrderOperator): diff --git a/torch/utils/checkpoint.py b/torch/utils/checkpoint.py index 5cbfd1543cf423..dab7730d84397d 100644 --- a/torch/utils/checkpoint.py +++ b/torch/utils/checkpoint.py @@ -5,18 +5,8 @@ import warnings import weakref from collections import defaultdict -from itertools import count -from typing import ( - Any, - Callable, - ContextManager, - DefaultDict, - Dict, - Iterable, - List, - Optional, - Tuple, -) +from typing import * # noqa: F403 +import enum from weakref import ReferenceType import torch @@ -39,6 +29,10 @@ "set_checkpoint_early_stop", "DefaultDeviceType", "set_checkpoint_debug_enabled", + "CheckpointPolicy", + "SelectiveCheckpointContext", + "create_selective_checkpoint_contexts", + "SAC_IGNORED_OPS", ] _DEFAULT_DETERMINISM_MODE = "default" @@ -1153,149 +1147,247 @@ def _is_compiling(func, args, kwargs): return False -def _detach(x): - if isinstance(x, torch.Tensor): - return x.detach() +class _VersionWrapper: + # Check that cached tensors are not mutated. + def __init__(self, val): + self.val: Union[torch.Tensor, Any] = val + self.version: Optional[int] = val._version if isinstance(val, torch.Tensor) else None + + def get_val(self, allow_cache_entry_mutation): + if self.version is not None and not allow_cache_entry_mutation: + if self.val._version != self.version: + # Can we give user a stack trace of where the mutation happened? + raise RuntimeError( + "Tensor cached during selective activation checkpoint has been mutated" + ) + return self.val + + +def _maybe_detach(x, any_ret_has_alias_info): + # We detach for two separate reasons: + # - For view ops, we need to ensure that when the tensor is returned from + # CachedDispatchMode, as_view sees that the AutogradMeta is nullptr + # - Avoid reference cycles + # For case 1, it is not enough to check whether x has differentiable dtype + # because non-differentiable dtype can have non-nullptr AutogradMeta, e.g. + # when the tensor is a view. + if isinstance(x, torch.Tensor) and (x.is_floating_point() or x.is_complex() or any_ret_has_alias_info): + with torch._C._SetExcludeDispatchKeyGuard(torch._C.DispatchKey.ADInplaceOrView, False): + # Ensure that view performed beneath autograd properly propagates + # version counter. TODO: Use reentrant_dispatch instead of + # manually manipulating dispatch keys. Using reentrant_dispatch + # would respect inference_mode, though that is not relevant for + # this case. + x = x.detach() return x -uid = count(1) +class SelectiveCheckpointContext: + """ + Context passed to policy function during selective checkpointing. + This class is used to pass relevant metadata to the policy function during + selective checkpointing. The metadata includes whether the current invocation + of the policy function is during recomputation or not. -# NOTE: torch.utils.checkpoint internal logic will call these two functions unknown number of times -# (i.e. there could be _CachedTorchDispatchMode calls that doesn't map to a _CachingTorchDispatchMode call), -# so we ignore these ops and just always recompute them. -_ignored_ops = { - torch.ops.prim.device.default, + Example: + >>> # xdoctest: +SKIP(stub) + >>> + >>> def policy_fn(ctx, op, *args, **kwargs): + >>> print(ctx.is_recompute) + >>> + >>> context_fn = functools.partial(create_selective_checkpoint_contexts, policy_fn) + >>> + >>> out = torch.utils.checkpoint.checkpoint( + >>> fn, x, y, + >>> use_reentrant=False, + >>> context_fn=context_fn, + >>> ) + """ + def __init__(self, *, is_recompute): + self.is_recompute = is_recompute + + +class CheckpointPolicy(enum.Enum): + """ + Enum for specifying the policy for checkpointing during backpropagation. + + The following policies are supported: + + - ``{MUST,PREFER}_SAVE``: The operation's output will be saved during the forward + pass and will not be recomputed during the backward pass + - ``{MUST,PREFER}_RECOMPUTE``: The operation's output will not be saved during the + forward pass and will be recomputed during the backward pass + + Use ``MUST_*`` over ``PREFER_*`` to indicate that the policy should not be overridden + by other subsystems like `torch.compile`. + + .. note:: + A policy function that always returns ``PREFER_RECOMPUTE`` is + equivalent to vanilla checkpointing. + + A policy function that returns ``PREFER_SAVE`` every op is + NOT equivalent to not using checkpointing. Using such a policy would + save additional tensors not limited to ones that are actually needed for + gradient computation. + """ + MUST_SAVE = 0 + PREFER_SAVE = 1 + MUST_RECOMPUTE = 2 + PREFER_RECOMPUTE = 3 + + +SAC_IGNORED_OPS = { + # AC inserts different number of detach during forward and recompute. torch.ops.aten.detach.default, + # AC's determinism check invokes additional metadata ops during forward. + # With subclasses involved, these metadata ops become dispatchable, this + # can result in incorrectness if these ops are selected cached. + torch.ops.prim.device.default, } | set(torch._subclasses.functional_tensor.FunctionalTensor.metadata_fns) class _CachingTorchDispatchMode(TorchDispatchMode): - r""" - A :class:`TorchDispatchMode` to implement selective activation checkpointing - that's compatible with torch.compile. Used together with _CachedTorchDispatchMode. - """ + # Used together with _CachedTorchDispatchMode to implement SAC. def __init__(self, policy_fn, storage): self.policy_fn = policy_fn self.storage = storage - def push_into_storage(self, out, func, args, kwargs): - out_detached = tree_map(_detach, out) - self.storage[func].append(out_detached) + def __torch_dispatch__(self, func, types, args=(), kwargs=None): + if func in SAC_IGNORED_OPS: + return func(*args, **kwargs) - def _handle_compile_in_forward_ctx(self, should_not_recompute, func, args, kwargs): - if should_not_recompute: + kwargs = {} if kwargs is None else kwargs + policy = self.policy_fn(SelectiveCheckpointContext(is_recompute=False), + func, *args, **kwargs) + is_compiling = _is_compiling(func, args, kwargs) + + if is_compiling and policy == CheckpointPolicy.MUST_SAVE: fx_traceback.current_meta["recompute"] = 0 - # NOTE: Here we just store and reuse output of all ops, since in torch.compile mode - # we decide and handle recomputation in the partitioner. + out = func(*args, **kwargs) - self.push_into_storage(out, func, args, kwargs) - return out - def __torch_dispatch__(self, func, types, args=(), kwargs=None): - if kwargs is None: - kwargs = {} - if func in _ignored_ops: - return func(*args, **kwargs) - should_not_recompute = self.policy_fn("forward", func, *args, **kwargs) - if _is_compiling(func, args, kwargs): - return self._handle_compile_in_forward_ctx(should_not_recompute, func, args, kwargs) - else: - if should_not_recompute: - out = func(*args, **kwargs) - self.push_into_storage(out, func, args, kwargs) - else: - out = func(*args, **kwargs) - return out + any_ret_has_alias_info = any(ret.alias_info is not None for ret in func._schema.returns) + + if policy in (CheckpointPolicy.MUST_SAVE, CheckpointPolicy.PREFER_SAVE) or is_compiling: + self.storage[func].append(tree_map(lambda x: _VersionWrapper(_maybe_detach(x, any_ret_has_alias_info)), out)) + return out class _CachedTorchDispatchMode(TorchDispatchMode): - r""" - A :class:`TorchDispatchMode` to implement selective activation checkpointing - that's compatible with torch.compile. Used together with _CachingTorchDispatchMode. - """ - def __init__(self, policy_fn, storage): + # Used together with _CachedTorchDispatchMode to implement SAC. + def __init__(self, policy_fn, storage, allow_cache_entry_mutation): self.policy_fn = policy_fn self.storage = storage - - def pop_from_storage(self, func, args, kwargs): - assert func in self.storage - out = self.storage[func].pop(0) - return out - - def _handle_compile_in_recompute_ctx(self, should_not_recompute, func, args, kwargs): - out = self.pop_from_storage(func, args, kwargs) - return out + self.allow_cache_entry_mutation = allow_cache_entry_mutation def __torch_dispatch__(self, func, types, args=(), kwargs=None): - if kwargs is None: - kwargs = {} - if func in _ignored_ops: + if func in SAC_IGNORED_OPS: return func(*args, **kwargs) - should_not_recompute = self.policy_fn("recompute", func, *args, **kwargs) - if _is_compiling(func, args, kwargs): - return self._handle_compile_in_recompute_ctx(should_not_recompute, func, args, kwargs) + + kwargs = {} if kwargs is None else kwargs + policy = self.policy_fn(SelectiveCheckpointContext(is_recompute=True), + func, *args, **kwargs) + is_compiling = _is_compiling(func, args, kwargs) + + if policy in (CheckpointPolicy.MUST_SAVE, CheckpointPolicy.PREFER_SAVE) or is_compiling: + storage = self.storage.get(func) + if storage is None: + raise RuntimeError(f"{func} encountered during backward, but not found in storage") + if len(storage) == 0: + raise RuntimeError( + "Trying to backward an extra time. You are only allowed to backward once " + "on any region computed under selective activation checkpoint." + ) + out = tree_map(lambda x: x.get_val(self.allow_cache_entry_mutation), storage.pop(0)) else: - if should_not_recompute: - out = self.pop_from_storage(func, args, kwargs) - else: - out = func(*args, **kwargs) - return out + out = func(*args, **kwargs) + return out -def _pt2_selective_checkpoint_context_fn_gen(policy_fn): + +def create_selective_checkpoint_contexts(policy_fn_or_list, allow_cache_entry_mutation=False): """ - A helper function that generates a pair of contexts to be later passed into - `torch.utils.checkpoint` API to implment selective checkpointing. + Helper to avoid recomputing certain ops during activation checkpointing. - .. warning:: - This is context_fn is intended for use with torch.compile only. + Use this with `torch.utils.checkpoint.checkpoint` to control which + operations are recomputed during the backward pass. Args: - policy_fn (Callable[[Callable, List[Any], Dict[str, Any]], bool]): Policy function - to decide whether a particular op should be recomputed in backward pass or not. - In eager mode: - If policy_fn(...) returns True, the op is guaranteed to NOT be recomputed. - If policy_fn(...) returns False, the op is guaranteed to be recomputed. - In torch.compile mode: - If policy_fn(...) returns True, the op is guaranteed to NOT be recomputed. - If policy_fn(...) returns False, the op may or may not be recomputed - (it's up to the partitioner to decide). - + policy_fn_or_list (Callable or List): + - If a policy function is provided, it should accept a + :class:`SelectiveCheckpointContext`, the :class:`OpOverload`, args and + kwargs to the op, and return a :class:`CheckpointPolicy` enum value + indicating whether the execution of the op should be recomputed or not. + - If a list of operations is provided, it is equivalent to a policy + returning `CheckpointPolicy.MUST_SAVE` for the specified + operations and `CheckpointPolicy.PREFER_RECOMPUTE` for all other + operations. + allow_cache_entry_mutation (bool, optional): By default, an error is + raised if any tensors cached by selective activation checkpoint are + mutated in order to ensure correctness. If set to `True`, this check + is disabled. Returns: - A pair of generated contexts. + A tuple of two context managers. Example: >>> # xdoctest: +REQUIRES(LINUX) + >>> import functools >>> - >>> def get_custom_policy(): - >>> no_recompute_list = [ - >>> torch.ops.aten.mm.default, - >>> ] - >>> def custom_policy(mode, func, *args, **kwargs): - >>> return func in no_recompute_list - >>> return custom_policy + >>> x = torch.rand(10, 10, requires_grad=True) + >>> y = torch.rand(10, 10, requires_grad=True) >>> - >>> def selective_checkpointing_context_fn(): - >>> return _pt2_selective_checkpoint_context_fn_gen(get_custom_policy()) + >>> ops_to_save = [ + >>> torch.ops.aten.mm.default, + >>> ] >>> - >>> def gn(x, y): - >>> return torch.sigmoid(torch.matmul(torch.matmul(x, y), y)) * y + >>> def policy_fn(ctx, op, *args, **kwargs): + >>> if op in ops_to_save: + >>> return CheckpointPolicy.MUST_SAVE + >>> else: + >>> return CheckpointPolicy.PREFER_RECOMPUTE >>> - >>> def fn(x, y): - >>> return torch.utils.checkpoint.checkpoint( - >>> gn, x, y, - >>> use_reentrant=False, - >>> context_fn=selective_checkpointing_context_fn, - >>> ) + >>> context_fn = functools.partial(create_selective_checkpoint_contexts, policy_fn) + >>> + >>> # or equivalently + >>> context_fn = functools.partial(create_selective_checkpoint_contexts, ops_to_save) >>> - >>> x = torch.randn(4, 4, requires_grad=True) - >>> y = torch.randn(4, 4, requires_grad=True) + >>> def fn(x, y): + >>> return torch.sigmoid(torch.matmul(torch.matmul(x, y), y)) * y >>> - >>> compiled_fn = torch.compile(fn) + >>> out = torch.utils.checkpoint.checkpoint( + >>> fn, x, y, + >>> use_reentrant=False, + >>> context_fn=context_fn, + >>> ) """ - storage: Dict[Any, List[Any]] = defaultdict(list) - return _CachingTorchDispatchMode(policy_fn, storage), _CachedTorchDispatchMode(policy_fn, storage) + # NB: If grad_mode is disabled, checkpoint would not run forward under + # context_fn anyway, so proceed as usual. + if isinstance(policy_fn_or_list, list): + for op in policy_fn_or_list: + if not isinstance(op, torch._ops.OpOverload): + _extra_msg = ( + "Please update the OpOverloadPacket to a specific OpOverload." + "For example, if you have `torch.ops.aten.mm`, change it to `torch.ops.aten.mm.default`." + ) if isinstance(op, torch._ops.OpOverloadPacket) else "" + raise ValueError( + f"Expected op in `op_list` to be an OpOverload but got: {op} " + f"of type {type(op)}. {_extra_msg}" + ) + def policy_fn(ctx, op, *args, **kwargs): + if op in policy_fn_or_list: + return CheckpointPolicy.MUST_SAVE + else: + return CheckpointPolicy.PREFER_RECOMPUTE + elif callable(policy_fn_or_list): + policy_fn = policy_fn_or_list + else: + raise TypeError("policy_fn_or_list must be either a function or a list of ops.") + + storage: Dict[Any, List[Any]] = defaultdict(list) + return ( + _CachingTorchDispatchMode(policy_fn, storage), + _CachedTorchDispatchMode(policy_fn, storage, allow_cache_entry_mutation), + ) # NB: this helper wraps fn before calling checkpoint_impl. kwargs and # saving/restoring of global state is handled here. From d77a1aaa8623ba5e70f4f147362d84769784cf43 Mon Sep 17 00:00:00 2001 From: loganthomas Date: Tue, 18 Jun 2024 18:26:07 +0000 Subject: [PATCH 141/171] DOC: add note about same sized tensors to dist.gather() (#128676) Fixes #103305 Pull Request resolved: https://github.com/pytorch/pytorch/pull/128676 Approved by: https://github.com/wconstab --- torch/distributed/distributed_c10d.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/torch/distributed/distributed_c10d.py b/torch/distributed/distributed_c10d.py index bd81fd61b02f91..d44c3733a214e6 100644 --- a/torch/distributed/distributed_c10d.py +++ b/torch/distributed/distributed_c10d.py @@ -3041,11 +3041,12 @@ def all_gather(tensor_list, tensor, group=None, async_op=False): """ Gathers tensors from the whole group in a list. - Complex tensors are supported. + Complex and uneven sized tensors are supported. Args: tensor_list (list[Tensor]): Output list. It should contain correctly-sized tensors to be used for output of the collective. + Uneven sized tensors are supported. tensor (Tensor): Tensor to be broadcast from current process. group (ProcessGroup, optional): The process group to work on. If None, the default process group will be used. @@ -3118,6 +3119,8 @@ def all_gather_into_tensor(output_tensor, input_tensor, group=None, async_op=Fal """ Gather tensors from all ranks and put them in a single output tensor. + This function requires all tensors to be the same size on each process. + Args: output_tensor (Tensor): Output tensor to accommodate tensor elements from all ranks. It must be correctly sized to have one of the @@ -3341,11 +3344,13 @@ def gather(tensor, gather_list=None, dst=0, group=None, async_op=False): """ Gathers a list of tensors in a single process. + This function requires all tensors to be the same size on each process. + Args: tensor (Tensor): Input tensor. - gather_list (list[Tensor], optional): List of appropriately-sized - tensors to use for gathered data (default is None, must be specified - on the destination rank) + gather_list (list[Tensor], optional): List of appropriately, + same-sized tensors to use for gathered data + (default is None, must be specified on the destination rank) dst (int, optional): Destination rank on global process group (regardless of ``group`` argument). (default is 0) group (ProcessGroup, optional): The process group to work on. If None, the default process group will be used. From 1a527915a64b8e5f60951715b09fa294b1a8844f Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Mon, 17 Jun 2024 09:54:11 -0700 Subject: [PATCH 142/171] [DSD] Correctly handle shared parameters for optimizer state_dict (#128685) * Fixes https://github.com/pytorch/pytorch/issues/128011 See the discussion in https://github.com/pytorch/pytorch/pull/128076 Current implementation of `set_optimizer_state_dict()` assumes that all the fqns returned by `_get_fqns()` must exist in the optimizer state_dict. This is not true if the model has shared parameters. In such a case, only one fqn of the shared parameters will appear in the optimizer state_dict. This PR addresses the issue. Differential Revision: [D58573487](https://our.internmc.facebook.com/intern/diff/D58573487/) Pull Request resolved: https://github.com/pytorch/pytorch/pull/128685 Approved by: https://github.com/LucasLLC --- .../distributed/checkpoint/test_state_dict.py | 27 ++++++++++++ torch/distributed/checkpoint/state_dict.py | 42 ++++++++++++++++--- 2 files changed, 63 insertions(+), 6 deletions(-) diff --git a/test/distributed/checkpoint/test_state_dict.py b/test/distributed/checkpoint/test_state_dict.py index 3da18ea5cc600f..ac6263569af45d 100644 --- a/test/distributed/checkpoint/test_state_dict.py +++ b/test/distributed/checkpoint/test_state_dict.py @@ -851,6 +851,33 @@ def test_deprecate_fsdp_api(self) -> None: ): get_model_state_dict(model) + @with_comms + @skip_if_lt_x_gpu(2) + def test_shared_weight(self): + class TiedEmbeddingModel(nn.Module): + def __init__(self, vocab_size, embedding_dim): + super().__init__() + self.embedding = nn.Embedding(vocab_size, embedding_dim) + self.decoder = nn.Linear(embedding_dim, vocab_size) + self.decoder.weight = self.embedding.weight # Tying weights + + def forward(self, input): + input = (input * 10).to(torch.int) + embedded = self.embedding(input) + output = self.decoder(embedded) + return output + + def init_model_optim(): + device_mesh = init_device_mesh("cuda", (self.world_size,)) + orig_model = TiedEmbeddingModel(10000, 300).to(torch.device("cuda")) + orig_optim = torch.optim.AdamW(orig_model.parameters(), lr=1e-3) + copy_optim = torch.optim.AdamW(orig_model.parameters(), lr=1e-3) + dist_model = FSDP(copy.deepcopy(orig_model), device_mesh=device_mesh) + dist_optim = torch.optim.AdamW(dist_model.parameters(), lr=1e-3) + return orig_model, orig_optim, copy_optim, dist_model, dist_optim + + self._test_save_load(init_model_optim) + class TestNoComm(MultiProcessTestCase): def setUp(self) -> None: diff --git a/torch/distributed/checkpoint/state_dict.py b/torch/distributed/checkpoint/state_dict.py index 16a1ddde215869..6bdeb389e8a0c7 100644 --- a/torch/distributed/checkpoint/state_dict.py +++ b/torch/distributed/checkpoint/state_dict.py @@ -153,6 +153,9 @@ class _StateDictInfo(StateDictOptions): fqn_param_mapping: Dict[ Union[str, torch.Tensor], Union[FQNS_T, torch.Tensor] ] = field(default_factory=dict) + shared_params_mapping: Dict[ + Union[str, torch.Tensor], Union[FQNS_T, torch.Tensor] + ] = field(default_factory=dict) submodule_prefixes: Set[str] = field(default_factory=set) handle_model: bool = True handle_optim: bool = True @@ -286,14 +289,29 @@ def _verify_options( fqn_param_mapping: Dict[ Union[str, torch.Tensor], Union[Set[str], torch.Tensor] ] = {} + shared_params_mapping: Dict[ + Union[str, torch.Tensor], Union[Set[str], torch.Tensor] + ] = {} for name, param in _iterate_valid_model_state(model): + if isinstance(param, _EXTRA_STATE): + continue + fqns = _get_fqns(model, name) - if not isinstance(param, _EXTRA_STATE): - fqn_param_mapping[param] = fqns + fqn = fqn_param_mapping.get(param, None) + if fqn is not None: + cast(Set[str], fqn_param_mapping[param]).update(fqns) + shared_params_mapping[param] = fqn_param_mapping[param] + else: + # We need to do copy as _get_fqns is lru_cached + fqn_param_mapping[param] = fqns.copy() for fqn in fqns: if not isinstance(param, _EXTRA_STATE): fqn_param_mapping[fqn] = param + for param_, fqns_ in list(shared_params_mapping.items()): + for fqn in fqns_: + shared_params_mapping[fqn] = cast(torch.Tensor, param_) + submodule_prefixes: Set[str] = set() if submodules: submodules = set(submodules) @@ -361,6 +379,7 @@ def fsdp_state_dict_type_without_warning( return _StateDictInfo( **asdict(options), fqn_param_mapping=fqn_param_mapping, + shared_params_mapping=shared_params_mapping, submodule_prefixes=submodule_prefixes, fsdp_context=fsdp_context, fsdp_modules=cast(List[nn.Module], fsdp_modules), @@ -450,7 +469,7 @@ def _get_model_state_dict( for key in list(state_dict.keys()): fqns = _get_fqns(model, key) - assert len(fqns) == 1 + assert len(fqns) == 1, (key, fqns) fqn = next(iter(fqns)) if fqn != key: # As we only support FSDP, DDP, and TP, the only cases are @@ -797,6 +816,19 @@ def _split_optim_state_dict( pg_state.append({_PARAMS: []}) for param in param_group[_PARAMS]: for fqn in info.fqn_param_mapping[param]: + if fqn in info.shared_params_mapping: + in_params = False + for loaded_param_group in cast( + ListDictValueType, optim_state_dict[_PG] + ): + if fqn in cast(List[str], loaded_param_group[_PARAMS]): + in_params = True + break + else: + in_params = True + if not in_params: + continue + params = pg_state[-1][_PARAMS] assert isinstance(params, list) params.append(fqn) @@ -805,9 +837,7 @@ def _split_optim_state_dict( for loaded_param_group in cast( ListDictValueType, optim_state_dict[_PG] ): - params = loaded_param_group[_PARAMS] - assert isinstance(params, list) - if fqn in params: + if fqn in cast(List[str], loaded_param_group[_PARAMS]): pg_mapping[id(loaded_param_group)] = len(return_osd[_PG]) - 1 for param_group in cast(ListDictValueType, optim_state_dict[_PG]): From bdffd9f0c6f4564ee0cdd15d030215b5df58b2a9 Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Mon, 17 Jun 2024 23:10:58 -0700 Subject: [PATCH 143/171] [export] Graph break on nn.Parameter construction (#128935) Fixes https://github.com/pytorch/pytorch/issues/126109 Pull Request resolved: https://github.com/pytorch/pytorch/pull/128935 Approved by: https://github.com/angelayi --- torch/_dynamo/variables/torch.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/torch/_dynamo/variables/torch.py b/torch/_dynamo/variables/torch.py index 74c2193646bc0b..1cc4622dea529c 100644 --- a/torch/_dynamo/variables/torch.py +++ b/torch/_dynamo/variables/torch.py @@ -877,6 +877,9 @@ def handle_ntuple(value): @classmethod def call_nn_parameter(cls, tx, data=None, requires_grad=True): """A call to torch.nn.Parameter() gets lifted to before the graph""" + if tx.export: + unimplemented("nn parameter construction not supported with export") + if isinstance(requires_grad, variables.VariableTracker): try: requires_grad = requires_grad.as_python_constant() From 44483972bdd3dcd0c047020694817210846b5d70 Mon Sep 17 00:00:00 2001 From: Nikita Shulga Date: Tue, 18 Jun 2024 06:51:37 -0700 Subject: [PATCH 144/171] [EZ] Keep weight_norm var name aligned (#128955) To keep it aligned with https://github.com/pytorch/pytorch/blob/e6d4451ae8987bf8d6ad85eb7cde685fac746f6f/aten/src/ATen/native/native_functions.yaml#L6484 I.e. `x`->`v`, `y`->`g` Pull Request resolved: https://github.com/pytorch/pytorch/pull/128955 Approved by: https://github.com/albanD, https://github.com/Skylion007 --- torch/_decomp/decompositions.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/torch/_decomp/decompositions.py b/torch/_decomp/decompositions.py index 7ebc69462fa1c1..dca552137ca6d3 100644 --- a/torch/_decomp/decompositions.py +++ b/torch/_decomp/decompositions.py @@ -4770,11 +4770,11 @@ def squeeze_default(self: Tensor, dim: Optional[int] = None): @register_decomposition(torch.ops.aten._weight_norm_interface) -def _weight_norm_interface(x, y, dim=0): +def _weight_norm_interface(v, g, dim=0): # https://github.com/pytorch/pytorch/blob/852f8526c52190125446adc9a6ecbcc28fb66182/aten/src/ATen/native/WeightNorm.cpp#L58 - keep_dim = tuple(i for i in range(len(x.shape)) if i != dim) - norm = x.norm(2, keep_dim, keepdim=True) - return x * (y / norm), norm + keep_dim = tuple(i for i in range(len(v.shape)) if i != dim) + norm = v.norm(2, keep_dim, keepdim=True) + return v * (g / norm), norm @register_decomposition(aten.isin) From 04a5d3228ecd5af790dabcfeb27c8c4f86742e11 Mon Sep 17 00:00:00 2001 From: Boyuan Feng Date: Tue, 18 Jun 2024 19:11:04 +0000 Subject: [PATCH 145/171] [ts migration] Support prim::tolist and aten::len (#128894) Support prim::tolist and aten::len. Add unit tests for prim::min. Pull Request resolved: https://github.com/pytorch/pytorch/pull/128894 Approved by: https://github.com/angelayi --- test/export/test_converter.py | 106 +++++++++++++++++++++++++++++++++- torch/_export/converter.py | 12 +++- 2 files changed, 116 insertions(+), 2 deletions(-) diff --git a/test/export/test_converter.py b/test/export/test_converter.py index 300f70223a26b7..8ea6a8089ae8b0 100644 --- a/test/export/test_converter.py +++ b/test/export/test_converter.py @@ -111,13 +111,102 @@ def forward(self, x): def test_aten_len(self): class Module(torch.nn.Module): - def forward(self, x): + def forward(self, x: torch.Tensor): length = len(x) return torch.ones(length) + # aten::len.Tensor inp = (torch.ones(2, 3),) self._check_equal_ts_ep_converter(Module(), inp) + class Module(torch.nn.Module): + def forward(self, x: List[int]): + length = len(x) + return torch.ones(length) + + # aten::len.t + inp = ([1, 2, 3],) + self._check_equal_ts_ep_converter(Module(), inp, ["script"]) + + class Module(torch.nn.Module): + def forward(self, x: Dict[int, str]): + length = len(x) + return torch.ones(length) + + # aten::len.Dict_int + inp = ({1: "a", 2: "b", 3: "c"},) + self._check_equal_ts_ep_converter(Module(), inp, ["script"]) + + class Module(torch.nn.Module): + def forward(self, x: Dict[bool, str]): + length = len(x) + return torch.ones(length) + + # aten::len.Dict_bool + inp = ({True: "a", False: "b"},) + self._check_equal_ts_ep_converter(Module(), inp, ["script"]) + + class Module(torch.nn.Module): + def forward(self, x: Dict[float, str]): + length = len(x) + return torch.ones(length) + + # aten::len.Dict_float + inp = ({1.2: "a", 3.4: "b"},) + self._check_equal_ts_ep_converter(Module(), inp, ["script"]) + + class Module(torch.nn.Module): + def forward(self, x: Dict[torch.Tensor, str]): + length = len(x) + return torch.ones(length) + + # aten::len.Dict_Tensor + inp = ({torch.zeros(2, 3): "a", torch.ones(2, 3): "b"},) + self._check_equal_ts_ep_converter(Module(), inp, ["script"]) + + # aten::len.str and aten::len.Dict_str are not supported + # since torch._C._jit_flatten does not support str + # inp = ("abcdefg",) + # self._check_equal_ts_ep_converter(Module(), inp) + # inp = ({"a": 1, "b": 2},) + # self._check_equal_ts_ep_converter(Module(), inp) + + def test_prim_min(self): + class Module(torch.nn.Module): + def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + x_len = len(x) + y_len = len(y) + + # prim::min.int + len_int = min(x_len, y_len) + + # prim::min.float + len_float = int(min(x_len * 2.0, y_len * 2.0)) + + # prim::min.self_int + len_self_int = min([x_len, y_len]) + + # prim::min.self_float + len_self_float = int(min([x_len * 2.0, y_len * 2.0])) + + # prim::min.float_int + len_float_int = int(min(x_len * 2.0, y_len)) + + # prim::min.int_float + len_int_float = int(min(x_len, y_len * 2.0)) + + return torch.ones( + len_int + + len_float + + len_self_int + + len_self_float + + len_float_int + + len_int_float + ) + + inp = (torch.randn(10, 2), torch.randn(5)) + self._check_equal_ts_ep_converter(Module(), inp) + def test_aten___getitem___list(self): class Module(torch.nn.Module): def forward(self, x): @@ -659,6 +748,21 @@ def forward(self, x): # inp = (torch.randn([2, 3, 4]),) # self._check_equal_ts_ep_converter(func6, inp) + def test_prim_tolist(self): + class Module(torch.nn.Module): + def forward(self, x: torch.Tensor) -> List[int]: + return x.tolist() + + inp = (torch.tensor([1, 2, 3]),) + self._check_equal_ts_ep_converter(Module(), inp, ["script"]) + + class Module(torch.nn.Module): + def forward(self, x: torch.Tensor) -> List[List[int]]: + return x.tolist() + + inp = (torch.tensor([[1, 2, 3], [4, 5, 6]]),) + self._check_equal_ts_ep_converter(Module(), inp, ["script"]) + if __name__ == "__main__": run_tests() diff --git a/torch/_export/converter.py b/torch/_export/converter.py index 2c54db38dee8b8..48f983b2917ef8 100644 --- a/torch/_export/converter.py +++ b/torch/_export/converter.py @@ -91,6 +91,7 @@ def get_dtype_as_int(tensor): "aten::__not__": operator.not_, "aten::__contains__": operator.contains, "prim::dtype": get_dtype_as_int, + "aten::len": len, } @@ -187,7 +188,7 @@ def _map_blocks_to_lifted_attrs(entry): def get_op_overload(node: torch._C.Node): schema_str = node.schema() - schema = torch._C.parse_schema(schema_str) + schema: torch._C.FunctionSchema = torch._C.parse_schema(schema_str) ns, op_name = str(schema.name).split("::") override = schema.overload_name @@ -651,6 +652,15 @@ def convert_profiler__record_function_exit(self, node: torch._C.Node): args = tuple(self.get_fx_value(input) for input in node.inputs()) self.fx_graph.call_function(target, args) + def convert_prim_tolist(self, node: torch._C.Node): + # prim::tolist cannot be supported by `_convert_standard_operators` + # since it requires call_method instead of call_function. + target = "tolist" + args = (self.get_fx_value(next(node.inputs())),) + fx_node = self.fx_graph.call_method(target, args) + output_name = node.output().debugName() + self.name_to_node[output_name] = fx_node + def _convert_standard_operators(self, node: torch._C.Node): target = kind_to_standard_operators[node.kind()] args = tuple(self.get_fx_value(input) for input in node.inputs()) From abde6cab4c7f972672ae008223000c16fd3964cd Mon Sep 17 00:00:00 2001 From: Sam Larsen Date: Wed, 12 Jun 2024 19:33:15 -0700 Subject: [PATCH 146/171] Remove compile_threads=1 in test_inductor_collectives.py (#128580) Summary: I believe https://github.com/pytorch/pytorch/issues/125235 should be fixed after switching to subprocess-based parallel compile. Test Plan: Ran locally with python-3.9 Pull Request resolved: https://github.com/pytorch/pytorch/pull/128580 Approved by: https://github.com/eellison --- test/distributed/test_inductor_collectives.py | 26 ------------------- 1 file changed, 26 deletions(-) diff --git a/test/distributed/test_inductor_collectives.py b/test/distributed/test_inductor_collectives.py index 35e44b19bedd55..ee4535fd5a73f1 100644 --- a/test/distributed/test_inductor_collectives.py +++ b/test/distributed/test_inductor_collectives.py @@ -60,8 +60,6 @@ def world_size(self) -> int: @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") @skip_if_lt_x_gpu(2) - # TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor - @patch.object(torch._inductor.config, "compile_threads", 1) def test_broadcast_inductor(self): """ Testing if broadcast works correctly when using inductor @@ -94,8 +92,6 @@ def compile(func, example_inputs): @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") @skip_if_lt_x_gpu(2) - # TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor - @patch.object(torch._inductor.config, "compile_threads", 1) def test_allreduce_inductor(self): """ This is matmul/cat/allreduce is a pattern we aim to optimize. @@ -129,8 +125,6 @@ def compile(func, example_inputs): @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") @skip_if_lt_x_gpu(2) - # TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor - @patch.object(torch._inductor.config, "compile_threads", 1) def test_allreduce_inductor_cudagraph_trees(self): """ Tests whether cudagraph trees support all_reduce from nccl @@ -177,8 +171,6 @@ def test_c10d_functional_tagged_pt2_compliant(self): @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") @skip_if_lt_x_gpu(2) - # TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor - @patch.object(torch._inductor.config, "compile_threads", 1) def test_eager_allreduce_inductor_wait(self): def eager_func(a, b, c, d, *, tag, ranks, group_size): x = torch.matmul(a, b) @@ -218,8 +210,6 @@ def compile(func, example_inputs): @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") @skip_if_lt_x_gpu(2) - # TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor - @patch.object(torch._inductor.config, "compile_threads", 1) def test_inductor_allreduce_eager_wait(self): def inductor_func(a, b, c, d, *, tag, ranks, group_size): x = torch.matmul(a, b) @@ -256,8 +246,6 @@ def compile(func, example_inputs): @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") @skip_if_lt_x_gpu(2) @patch.object(torch._inductor.config, "allow_buffer_reuse", True) - # TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor - @patch.object(torch._inductor.config, "compile_threads", 1) def test_allreduce_input_buffer_reuse(self): def func(a, *, tag, ranks, group_size): ar = _functional_collectives.all_reduce(a, "sum", ranks, tag) @@ -275,8 +263,6 @@ def func(a, *, tag, ranks, group_size): @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") @skip_if_lt_x_gpu(2) - # TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor - @patch.object(torch._inductor.config, "compile_threads", 1) def test_permute_tensor(self): def func(tensor, src_dst_pairs, *, tag, ranks, group_size): return _functional_collectives.permute_tensor( @@ -304,8 +290,6 @@ def func(tensor, src_dst_pairs, *, tag, ranks, group_size): @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") @skip_if_lt_x_gpu(2) @patch.object(torch._inductor.config, "allow_buffer_reuse", True) - # TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor - @patch.object(torch._inductor.config, "compile_threads", 1) def test_allgather_output_buffer_reuse(self): class Model(torch.nn.Module): def __init__(self, *args, **kwargs) -> None: @@ -329,8 +313,6 @@ def forward(self, x, world_size, tag, ranks, group_size): @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") @skip_if_lt_x_gpu(2) - # TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor - @patch.object(torch._inductor.config, "compile_threads", 1) def test_allgather_contiguous_input(self): class Model(torch.nn.Module): def __init__(self, *args, **kwargs) -> None: @@ -355,8 +337,6 @@ def forward(self, x, world_size, tag, ranks, group_size): @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") @skip_if_lt_x_gpu(2) - # TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor - @patch.object(torch._inductor.config, "compile_threads", 1) def test_allgather_into_tensor_inductor(self): """ This is matmul/cat/allreduce is a pattern we aim to optimize. @@ -388,8 +368,6 @@ def compile(func, example_inputs): @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") @skip_if_lt_x_gpu(2) - # TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor - @patch.object(torch._inductor.config, "compile_threads", 1) def test_reduce_scatter_tensor_inductor(self): def example(a, b, *, tag, ranks, group_size): c = torch.matmul(a, b) @@ -418,8 +396,6 @@ def compile(func, example_inputs): @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") @skip_if_lt_x_gpu(2) @patch.object(torch._dynamo.config, "capture_scalar_outputs", True) - # TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor - @patch.object(torch._inductor.config, "compile_threads", 1) def test_all_to_all_single_inductor(self): def example( inp, @@ -488,8 +464,6 @@ def example( @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") @skip_if_lt_x_gpu(2) - # TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor - @patch.object(torch._inductor.config, "compile_threads", 1) def test_all_to_all_single_inductor_split_sizes_none(self): def example(inp, *, tag, ranks, group_size): a2a = torch.ops.c10d_functional.all_to_all_single( From fe8558b7aa4ce55d06893c48d5cb00b7a7eb7dae Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Mon, 17 Jun 2024 10:37:13 -0700 Subject: [PATCH 147/171] [DSD] Add unittest to verify HSDP1 + broadcast_from_rank0 (#128755) HSDP1 + broadcast_from_rank0 actually behaves differently from FSDP1 + broadcast_from_rank0. So we need an unittest to cover this use case. This test relies on the fix from https://github.com/pytorch/pytorch/pull/128446. Differential Revision: [D58621436](https://our.internmc.facebook.com/intern/diff/D58621436/) Pull Request resolved: https://github.com/pytorch/pytorch/pull/128755 Approved by: https://github.com/Skylion007, https://github.com/wz337 ghstack dependencies: #128685 --- .../distributed/checkpoint/test_state_dict.py | 157 ++++++++++-------- 1 file changed, 87 insertions(+), 70 deletions(-) diff --git a/test/distributed/checkpoint/test_state_dict.py b/test/distributed/checkpoint/test_state_dict.py index ac6263569af45d..77363506288027 100644 --- a/test/distributed/checkpoint/test_state_dict.py +++ b/test/distributed/checkpoint/test_state_dict.py @@ -33,7 +33,11 @@ set_optimizer_state_dict, StateDictOptions, ) -from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, StateDictType +from torch.distributed.fsdp import ( + FullyShardedDataParallel as FSDP, + ShardingStrategy, + StateDictType, +) from torch.distributed.fsdp.wrap import ModuleWrapPolicy from torch.distributed.optim import _apply_optimizer_in_backward from torch.nn.parallel import DistributedDataParallel as DDP @@ -70,7 +74,7 @@ class TestStateDict(DTensorTestBase, VerifyStateDictMixin): @property def world_size(self) -> int: - return 2 + return min(4, torch.cuda.device_count()) def _test_save_load( self, @@ -567,55 +571,71 @@ def test_non_persistent_buffers(self) -> None: set_model_state_dict(ddp_model, get_model_state_dict(ddp_model)) self.assertEqual(model.state_dict(), get_model_state_dict(ddp_model)) - @with_comms - @skip_if_lt_x_gpu(2) - def test_broadcast_from_rank0(self) -> None: - def inner_test(wrapper): - model = CompositeParamModel(device=torch.device("cuda")) - optim = torch.optim.Adam(model.parameters()) - fsdp_model = wrapper(copy.deepcopy(model)) - fsdp_optim = torch.optim.Adam(fsdp_model.parameters()) + def _test_broadcast_from_rank0(self, wrapper) -> None: + model = CompositeParamModel(device=torch.device("cuda")) + optim = torch.optim.Adam(model.parameters()) + fsdp_model = wrapper(copy.deepcopy(model)) + fsdp_optim = torch.optim.Adam(fsdp_model.parameters()) - batch = torch.rand(8, 100, device="cuda") - model(batch).sum().backward() - optim.step() - states, optim_states = get_state_dict(model, optim) + batch = torch.rand(8, 100, device="cuda") + model(batch).sum().backward() + optim.step() + states, optim_states = get_state_dict(model, optim) - fsdp_model(batch).sum().backward() - fsdp_optim.step() + fsdp_model(batch).sum().backward() + fsdp_optim.step() - def check(equal): - fsdp_states = get_model_state_dict( - fsdp_model, - options=StateDictOptions(full_state_dict=True), - ) - fsdp_optim_states = get_optimizer_state_dict( - fsdp_model, - fsdp_optim, - options=StateDictOptions(full_state_dict=True), - ) - if equal: - self.assertEqual(states, fsdp_states) - self.assertEqual(optim_states, fsdp_optim_states) - else: - self.assertNotEqual(states, fsdp_states) - self.assertNotEqual(optim_states, fsdp_optim_states) - - check(equal=True) - fsdp_model(batch).sum().backward() - fsdp_optim.step() - check(equal=False) - - # Drop the states to simulate loading from rank0 - if dist.get_rank() > 0: - load_states = {} - load_states2 = {} - load_optim_states = {} + def check(equal): + fsdp_states = get_model_state_dict( + fsdp_model, + options=StateDictOptions(full_state_dict=True), + ) + fsdp_optim_states = get_optimizer_state_dict( + fsdp_model, + fsdp_optim, + options=StateDictOptions(full_state_dict=True), + ) + if equal: + self.assertEqual(states, fsdp_states) + self.assertEqual(optim_states, fsdp_optim_states) else: - load_states = copy.deepcopy(states) - load_states2 = copy.deepcopy(states) - load_optim_states = copy.deepcopy(optim_states) + self.assertNotEqual(states, fsdp_states) + self.assertNotEqual(optim_states, fsdp_optim_states) + + check(equal=True) + fsdp_model(batch).sum().backward() + fsdp_optim.step() + check(equal=False) + + # Drop the states to simulate loading from rank0 + if dist.get_rank() > 0: + load_states = {} + load_states2 = {} + load_optim_states = {} + else: + load_states = copy.deepcopy(states) + load_states2 = copy.deepcopy(states) + load_optim_states = copy.deepcopy(optim_states) + set_model_state_dict( + fsdp_model, + model_state_dict=load_states, + options=StateDictOptions(broadcast_from_rank0=True, full_state_dict=True), + ) + set_optimizer_state_dict( + fsdp_model, + fsdp_optim, + optim_state_dict=load_optim_states, + options=StateDictOptions(broadcast_from_rank0=True, full_state_dict=True), + ) + + check(equal=True) + # Verify the `strict` flag. + load_states = load_states2 + if load_states: + key = next(iter(load_states.keys())) + load_states.pop(key) + with self.assertRaisesRegex(RuntimeError, "Missing key"): set_model_state_dict( fsdp_model, model_state_dict=load_states, @@ -623,30 +643,10 @@ def check(equal): broadcast_from_rank0=True, full_state_dict=True ), ) - set_optimizer_state_dict( - fsdp_model, - fsdp_optim, - optim_state_dict=load_optim_states, - options=StateDictOptions( - broadcast_from_rank0=True, full_state_dict=True - ), - ) - - check(equal=True) - # Verify the `strict` flag. - load_states = load_states2 - if load_states: - key = next(iter(load_states.keys())) - load_states.pop(key) - with self.assertRaisesRegex(RuntimeError, "Missing key"): - set_model_state_dict( - fsdp_model, - model_state_dict=load_states, - options=StateDictOptions( - broadcast_from_rank0=True, full_state_dict=True - ), - ) + @with_comms + @skip_if_lt_x_gpu(2) + def test_broadcast_from_rank0(self) -> None: device_mesh = init_device_mesh("cuda", (self.world_size,)) self.run_subtests( { @@ -655,7 +655,24 @@ def check(equal): functools.partial(FSDP, device_mesh=device_mesh), ] }, - inner_test, + self._test_broadcast_from_rank0, + ) + + @with_comms + @skip_if_lt_x_gpu(4) + def test_broadcast_from_rank0_hsdp(self) -> None: + device_mesh = init_device_mesh("cuda", (2, self.world_size // 2)) + self.run_subtests( + { + "wrapper": [ + functools.partial( + FSDP, + device_mesh=device_mesh, + sharding_strategy=ShardingStrategy.HYBRID_SHARD, + ), + ] + }, + self._test_broadcast_from_rank0, ) @with_comms From 9a7e2519d3d15f8d469b71cab914fcdaf071ebd6 Mon Sep 17 00:00:00 2001 From: "Li-Huai (Allan) Lin" Date: Tue, 18 Jun 2024 19:59:50 +0000 Subject: [PATCH 148/171] [MPS] Fused Adam & AdamW (#127242) Summary: This PR adds fused Adam and AdamW implementations. Benchmark on Macbook Pro with M1 Max chip and 64GB unified memory: **Fast math enabled:** ``` [---------------------------------------------- Fused Adam ----------------------------------------------] | Fused: True | Fused: False 1 threads: ----------------------------------------------------------------------------------------------- amsgrad: True, adamWflag: True, numel: 1024, num_tensors: 100 | 10 | 100 amsgrad: False, adamWflag: True, numel: 1024, num_tensors: 100 | 9 | 89 amsgrad: True, adamWflag: False, numel: 1024, num_tensors: 100 | 9 | 90 amsgrad: False, adamWflag: False, numel: 1024, num_tensors: 100 | 9 | 83 amsgrad: True, adamWflag: True, numel: 65536, num_tensors: 100 | 12 | 94 amsgrad: False, adamWflag: True, numel: 65536, num_tensors: 100 | 11 | 88 amsgrad: True, adamWflag: False, numel: 65536, num_tensors: 100 | 12 | 90 amsgrad: False, adamWflag: False, numel: 65536, num_tensors: 100 | 11 | 100 amsgrad: True, adamWflag: True, numel: 1048576, num_tensors: 100 | 27 | 100 amsgrad: False, adamWflag: True, numel: 1048576, num_tensors: 100 | 23 | 100 amsgrad: True, adamWflag: False, numel: 1048576, num_tensors: 100 | 27 | 100 amsgrad: False, adamWflag: False, numel: 1048576, num_tensors: 100 | 23 | 98 amsgrad: True, adamWflag: True, numel: 1024, num_tensors: 500 | 82 | 480 amsgrad: False, adamWflag: True, numel: 1024, num_tensors: 500 | 72 | 450 amsgrad: True, adamWflag: False, numel: 1024, num_tensors: 500 | 82 | 450 amsgrad: False, adamWflag: False, numel: 1024, num_tensors: 500 | 73 | 420 amsgrad: True, adamWflag: True, numel: 65536, num_tensors: 500 | 91 | 500 amsgrad: False, adamWflag: True, numel: 65536, num_tensors: 500 | 83 | 400 amsgrad: True, adamWflag: False, numel: 65536, num_tensors: 500 | 94 | 500 amsgrad: False, adamWflag: False, numel: 65536, num_tensors: 500 | 78 | 400 amsgrad: True, adamWflag: True, numel: 1048576, num_tensors: 500 | 170 | 500 amsgrad: False, adamWflag: True, numel: 1048576, num_tensors: 500 | 140 | 600 amsgrad: True, adamWflag: False, numel: 1048576, num_tensors: 500 | 170 | 600 amsgrad: False, adamWflag: False, numel: 1048576, num_tensors: 500 | 140 | 500 amsgrad: True, adamWflag: True, numel: 1024, num_tensors: 1000 | 250 | 890 amsgrad: False, adamWflag: True, numel: 1024, num_tensors: 1000 | 220 | 850 amsgrad: True, adamWflag: False, numel: 1024, num_tensors: 1000 | 250 | 830 amsgrad: False, adamWflag: False, numel: 1024, num_tensors: 1000 | 220 | 770 amsgrad: True, adamWflag: True, numel: 65536, num_tensors: 1000 | 270 | 870 amsgrad: False, adamWflag: True, numel: 65536, num_tensors: 1000 | 230 | 840 amsgrad: True, adamWflag: False, numel: 65536, num_tensors: 1000 | 270 | 810 amsgrad: False, adamWflag: False, numel: 65536, num_tensors: 1000 | 240 | 800 amsgrad: True, adamWflag: True, numel: 1048576, num_tensors: 1000 | 400 | 1000 amsgrad: False, adamWflag: True, numel: 1048576, num_tensors: 1000 | 360 | 2000 amsgrad: True, adamWflag: False, numel: 1048576, num_tensors: 1000 | 430 | 2000 amsgrad: False, adamWflag: False, numel: 1048576, num_tensors: 1000 | 360 | 1300 Times are in milliseconds (ms). ``` **Fast math disabled:** ``` [---------------------------------------------- Fused Adam ----------------------------------------------] | Fused: True | Fused: False 1 threads: ----------------------------------------------------------------------------------------------- amsgrad: True, adamWflag: True, numel: 1024, num_tensors: 100 | 10 | 100 amsgrad: False, adamWflag: True, numel: 1024, num_tensors: 100 | 9 | 84 amsgrad: True, adamWflag: False, numel: 1024, num_tensors: 100 | 9 | 84 amsgrad: False, adamWflag: False, numel: 1024, num_tensors: 100 | 9 | 79 amsgrad: True, adamWflag: True, numel: 65536, num_tensors: 100 | 11 | 93 amsgrad: False, adamWflag: True, numel: 65536, num_tensors: 100 | 10 | 90 amsgrad: True, adamWflag: False, numel: 65536, num_tensors: 100 | 11 | 91 amsgrad: False, adamWflag: False, numel: 65536, num_tensors: 100 | 11 | 81 amsgrad: True, adamWflag: True, numel: 1048576, num_tensors: 100 | 34 | 100 amsgrad: False, adamWflag: True, numel: 1048576, num_tensors: 100 | 31 | 100 amsgrad: True, adamWflag: False, numel: 1048576, num_tensors: 100 | 34 | 95 amsgrad: False, adamWflag: False, numel: 1048576, num_tensors: 100 | 31 | 100 amsgrad: True, adamWflag: True, numel: 1024, num_tensors: 500 | 94 | 500 amsgrad: False, adamWflag: True, numel: 1024, num_tensors: 500 | 82 | 430 amsgrad: True, adamWflag: False, numel: 1024, num_tensors: 500 | 92 | 430 amsgrad: False, adamWflag: False, numel: 1024, num_tensors: 500 | 81 | 390 amsgrad: True, adamWflag: True, numel: 65536, num_tensors: 500 | 98 | 500 amsgrad: False, adamWflag: True, numel: 65536, num_tensors: 500 | 88 | 430 amsgrad: True, adamWflag: False, numel: 65536, num_tensors: 500 | 100 | 500 amsgrad: False, adamWflag: False, numel: 65536, num_tensors: 500 | 88 | 400 amsgrad: True, adamWflag: True, numel: 1048576, num_tensors: 500 | 210 | 500 amsgrad: False, adamWflag: True, numel: 1048576, num_tensors: 500 | 190 | 610 amsgrad: True, adamWflag: False, numel: 1048576, num_tensors: 500 | 210 | 510 amsgrad: False, adamWflag: False, numel: 1048576, num_tensors: 500 | 190 | 500 amsgrad: True, adamWflag: True, numel: 1024, num_tensors: 1000 | 300 | 900 amsgrad: False, adamWflag: True, numel: 1024, num_tensors: 1000 | 260 | 850 amsgrad: True, adamWflag: False, numel: 1024, num_tensors: 1000 | 295 | 900 amsgrad: False, adamWflag: False, numel: 1024, num_tensors: 1000 | 260 | 800 amsgrad: True, adamWflag: True, numel: 65536, num_tensors: 1000 | 320 | 910 amsgrad: False, adamWflag: True, numel: 65536, num_tensors: 1000 | 280 | 900 amsgrad: True, adamWflag: False, numel: 65536, num_tensors: 1000 | 320 | 900 amsgrad: False, adamWflag: False, numel: 65536, num_tensors: 1000 | 300 | 900 amsgrad: True, adamWflag: True, numel: 1048576, num_tensors: 1000 | 500 | 2000 amsgrad: False, adamWflag: True, numel: 1048576, num_tensors: 1000 | 480 | 2000 amsgrad: True, adamWflag: False, numel: 1048576, num_tensors: 1000 | 540 | 1500 amsgrad: False, adamWflag: False, numel: 1048576, num_tensors: 1000 | 480 | 1200 Times are in milliseconds (ms). ``` ```python def profile_fused_adam(): from torch.optim import adam, adamw import torch.utils.benchmark as benchmark import itertools def profile(fn, params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, amsgrad, fused): fn( params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, foreach=False, capturable=False, fused=fused, amsgrad=amsgrad, beta1=0.9, beta2=0.99, lr=1e-3, weight_decay=.0, eps=1e-5, maximize=False, grad_scale=None, found_inf=None, ) torch.mps.synchronize() device = "mps" results = [] for num_tensors, numel, adamWflag, amsgrad in itertools.product([100, 500, 1000], [1024, 65536, 1048576], [True, False], [True, False]): print(f"amsgrad: {amsgrad}, adamWflag: {adamWflag}, numel: {numel}, num_tensors: {num_tensors}") params, grads, exp_avgs, exp_avg_sqs = [[torch.arange(numel, dtype=torch.float32, device=device) + (numel * i) for i in range(num_tensors)] for _ in range(4)] max_exp_avg_sqs = [torch.arange(numel, dtype=torch.float32, device=device) for _ in range(num_tensors)] if amsgrad else [] state_steps = [torch.tensor([5], dtype=torch.float32, device=device) for _ in range(num_tensors)] if adamWflag: fn = adamw.adamw else: fn = adam.adam for fused in [True, False]: t = benchmark.Timer( stmt='profile(fn, params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, amsgrad, fused)', label='Fused Adam', sub_label=f"amsgrad: {amsgrad}, adamWflag: {adamWflag}, numel: {numel}, num_tensors: {num_tensors}", globals=locals(), description= f"Fused: {fused}", ).blocked_autorange(min_run_time=5) results.append(t) compare = benchmark.Compare(results) compare.trim_significant_figures() compare.colorize(rowwise=True) compare.print() ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/127242 Approved by: https://github.com/kulinseth, https://github.com/janeyx99 --- aten/src/ATen/native/mps/OperationUtils.h | 19 +- aten/src/ATen/native/mps/OperationUtils.mm | 31 +- .../operations/FusedAdamAmsgradKernelImpl.h | 24 ++ .../operations/FusedAdamAmsgradKernelImpl.mm | 37 +++ .../native/mps/operations/FusedAdamKernel.mm | 69 +++++ .../mps/operations/FusedAdamKernelImpl.h | 23 ++ .../mps/operations/FusedAdamKernelImpl.mm | 35 +++ .../operations/FusedAdamWAmsgradKernelImpl.h | 24 ++ .../operations/FusedAdamWAmsgradKernelImpl.mm | 37 +++ .../native/mps/operations/FusedAdamWKernel.mm | 68 +++++ .../mps/operations/FusedAdamWKernelImpl.h | 23 ++ .../mps/operations/FusedAdamWKernelImpl.mm | 35 +++ .../native/mps/operations/FusedOptimizerOps.h | 274 ++++++++++++++++++ .../native/mps/operations/MultiTensorApply.h | 190 ++++++++++++ aten/src/ATen/native/native_functions.yaml | 2 + test/test_mps.py | 34 +-- test/test_optim.py | 31 +- torch/optim/adam.py | 6 + torch/optim/adamw.py | 6 + torch/testing/_internal/common_optimizers.py | 4 +- torch/utils/_foreach_utils.py | 2 +- 21 files changed, 911 insertions(+), 63 deletions(-) create mode 100644 aten/src/ATen/native/mps/operations/FusedAdamAmsgradKernelImpl.h create mode 100644 aten/src/ATen/native/mps/operations/FusedAdamAmsgradKernelImpl.mm create mode 100644 aten/src/ATen/native/mps/operations/FusedAdamKernel.mm create mode 100644 aten/src/ATen/native/mps/operations/FusedAdamKernelImpl.h create mode 100644 aten/src/ATen/native/mps/operations/FusedAdamKernelImpl.mm create mode 100644 aten/src/ATen/native/mps/operations/FusedAdamWAmsgradKernelImpl.h create mode 100644 aten/src/ATen/native/mps/operations/FusedAdamWAmsgradKernelImpl.mm create mode 100644 aten/src/ATen/native/mps/operations/FusedAdamWKernel.mm create mode 100644 aten/src/ATen/native/mps/operations/FusedAdamWKernelImpl.h create mode 100644 aten/src/ATen/native/mps/operations/FusedAdamWKernelImpl.mm create mode 100644 aten/src/ATen/native/mps/operations/FusedOptimizerOps.h create mode 100644 aten/src/ATen/native/mps/operations/MultiTensorApply.h diff --git a/aten/src/ATen/native/mps/OperationUtils.h b/aten/src/ATen/native/mps/OperationUtils.h index 25e86e6d262f98..a9493cbce3adad 100644 --- a/aten/src/ATen/native/mps/OperationUtils.h +++ b/aten/src/ATen/native/mps/OperationUtils.h @@ -336,25 +336,34 @@ inline bool is_dense_in_storage(const at::Tensor& t) { class MetalShaderLibrary { public: - MetalShaderLibrary(const std::string& src, unsigned nparams_ = 0): shaderSource(src), nparams(nparams_) {} + MetalShaderLibrary(const std::string& src): shaderSource(src), nparams(0), compile_options(nullptr){} + MetalShaderLibrary(const std::string& src, unsigned nparams_): shaderSource(src), nparams(nparams_), compile_options(nullptr){} + MetalShaderLibrary(const std::string& src, unsigned nparams_, MTLCompileOptions* compile_options_): shaderSource(src), nparams(nparams_), compile_options(compile_options_) {} MetalShaderLibrary(const MetalShaderLibrary&) = delete; inline id getPipelineStateForFunc(const std::string& fname) { - return getLibraryPipelineState(getLibrary(), fname); + return getLibraryPipelineState(getLibrary(), fname).first; } id getPipelineStateForFunc(const std::string& fname, const std::initializer_list& params) { - return getLibraryPipelineState(getLibrary(params), fname); + return getLibraryPipelineState(getLibrary(params), fname).first; + } + inline id getMTLFunction(const std::string& fname) { + return getLibraryPipelineState(getLibrary(), fname).second; + } + id getMTLFunction(const std::string& fname, const std::initializer_list& params) { + return getLibraryPipelineState(getLibrary(params), fname).second; } private: - id getLibraryPipelineState(id lib, const std::string& fname); + std::pair, id> getLibraryPipelineState(id lib, const std::string& fname); id getLibrary(); id getLibrary(const std::initializer_list& params); id compileLibrary(const std::string& src); std::string shaderSource; unsigned nparams; + MTLCompileOptions* compile_options; id library = nil; std::unordered_map> libMap; - std::unordered_map> cplMap; + std::unordered_map, id>> cplMap; }; static inline void mtl_setBuffer(id encoder, const Tensor& t, unsigned idx) { diff --git a/aten/src/ATen/native/mps/OperationUtils.mm b/aten/src/ATen/native/mps/OperationUtils.mm index 82d1fe9d92f48c..8dc90e497fe4ee 100644 --- a/aten/src/ATen/native/mps/OperationUtils.mm +++ b/aten/src/ATen/native/mps/OperationUtils.mm @@ -656,31 +656,38 @@ void executeMPSAllocatorCallback(void* ptr, EventType event) override {} id MetalShaderLibrary::compileLibrary(const std::string& src) { NSError* error = nil; - MTLCompileOptions* options = [[MTLCompileOptions new] autorelease]; - [options setLanguageVersion:is_macos_13_or_newer(MacOSVersion::MACOS_VER_14_0_PLUS) ? MTLLanguageVersion3_1 - : MTLLanguageVersion2_3]; - // [options setFastMathEnabled: NO]; - auto str = [NSString stringWithCString:src.c_str() encoding:NSASCIIStringEncoding]; + MTLCompileOptions* options = compile_options; + if (!options) { + options = [[MTLCompileOptions new] autorelease]; + [options setLanguageVersion:is_macos_13_or_newer(MacOSVersion::MACOS_VER_14_0_PLUS) ? MTLLanguageVersion3_1 + : MTLLanguageVersion2_3]; + [options setFastMathEnabled:NO]; + } + + const auto str = [NSString stringWithCString:src.c_str() encoding:NSASCIIStringEncoding]; auto device = MPSDevice::getInstance()->device(); library = [device newLibraryWithSource:str options:options error:&error]; TORCH_CHECK(library, "Failed to create metal library, error: ", [[error description] UTF8String]); return library; } -id MetalShaderLibrary::getLibraryPipelineState(id lib, const std::string& fname) { - auto key = fmt::format("{}:{}", reinterpret_cast(lib), fname); - auto cpl = cplMap[key]; - if (cpl) { - return cpl; +std::pair, id> MetalShaderLibrary::getLibraryPipelineState( + id lib, + const std::string& fname) { + const auto key = fmt::format("{}:{}", reinterpret_cast(lib), fname); + auto found_cpl = cplMap.find(key); + if (found_cpl != cplMap.end()) { + return found_cpl->second; } NSError* error = nil; id func = [lib newFunctionWithName:[NSString stringWithUTF8String:fname.c_str()]]; TORCH_CHECK(func, "Failed to create function state object for: ", fname); - cpl = [[lib device] newComputePipelineStateWithFunction:func error:&error]; + auto cpl = [[lib device] newComputePipelineStateWithFunction:func error:&error]; TORCH_CHECK(cpl, "Failed to created pipeline state object, error: ", [[error description] UTF8String]); - return cplMap[key] = cpl; + cplMap[key] = std::make_pair(cpl, func); + return cplMap[key]; } } // namespace at::native::mps diff --git a/aten/src/ATen/native/mps/operations/FusedAdamAmsgradKernelImpl.h b/aten/src/ATen/native/mps/operations/FusedAdamAmsgradKernelImpl.h new file mode 100644 index 00000000000000..8711cb228ee9f4 --- /dev/null +++ b/aten/src/ATen/native/mps/operations/FusedAdamAmsgradKernelImpl.h @@ -0,0 +1,24 @@ +#pragma once +#include + +namespace at::native { +namespace mps { + +void _fused_adam_amsgrad_mps_impl_( + at::TensorList params, + at::TensorList grads, + at::TensorList exp_avgs, + at::TensorList exp_avg_sqs, + at::TensorList max_exp_avg_sqs, + at::TensorList state_steps, + const double lr, + const double beta1, + const double beta2, + const double weight_decay, + const double eps, + const bool maximize, + const c10::optional& grad_scale, + const c10::optional& found_inf +); +} //namespace mps +}// namespace at::native diff --git a/aten/src/ATen/native/mps/operations/FusedAdamAmsgradKernelImpl.mm b/aten/src/ATen/native/mps/operations/FusedAdamAmsgradKernelImpl.mm new file mode 100644 index 00000000000000..be6069ad9694b8 --- /dev/null +++ b/aten/src/ATen/native/mps/operations/FusedAdamAmsgradKernelImpl.mm @@ -0,0 +1,37 @@ +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS +#include + +#include +#include +#include +#include +#include + +namespace at::native { +namespace mps { + +void _fused_adam_amsgrad_mps_impl_(at::TensorList params, + at::TensorList grads, + at::TensorList exp_avgs, + at::TensorList exp_avg_sqs, + at::TensorList max_exp_avg_sqs, + at::TensorList state_steps, + const double lr, + const double beta1, + const double beta2, + const double weight_decay, + const double eps, + const bool maximize, + const c10::optional& grad_scale, + const c10::optional& found_inf) { + std::vector> tensor_lists{ + params.vec(), grads.vec(), exp_avgs.vec(), exp_avg_sqs.vec(), max_exp_avg_sqs.vec()}; + + const std::string kernel_name = "fused_adam_amsgrad_" + scalarToMetalTypeString(params[0].scalar_type()) + "_" + + scalarToMetalTypeString(state_steps[0].scalar_type()); + + multi_tensor_apply_for_fused_adam<5, 512>( + kernel_name, tensor_lists, state_steps, lr, beta1, beta2, weight_decay, eps, maximize); +} +} // namespace mps +} // namespace at::native \ No newline at end of file diff --git a/aten/src/ATen/native/mps/operations/FusedAdamKernel.mm b/aten/src/ATen/native/mps/operations/FusedAdamKernel.mm new file mode 100644 index 00000000000000..2e4d89ff851c37 --- /dev/null +++ b/aten/src/ATen/native/mps/operations/FusedAdamKernel.mm @@ -0,0 +1,69 @@ +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS +#include +#include +#include +#include +#include +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#include +#else +#include +#endif + +namespace at::native { + +void _fused_adam_kernel_mps_(at::TensorList params, + at::TensorList grads, + at::TensorList exp_avgs, + at::TensorList exp_avg_sqs, + at::TensorList max_exp_avg_sqs, + at::TensorList state_steps, + const double lr, + const double beta1, + const double beta2, + const double weight_decay, + const double eps, + const bool amsgrad, + const bool maximize, + const c10::optional& grad_scale, + const c10::optional& found_inf) { + if (amsgrad) { + TORCH_CHECK(at::native::check_fast_path_restrictions({params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs}), + "params, grads, exp_avgs, exp_avg_sqs, and max_exp_avg_sqs must have same dtype, device, and layout"); + mps::_fused_adam_amsgrad_mps_impl_(params, + grads, + exp_avgs, + exp_avg_sqs, + max_exp_avg_sqs, + state_steps, + lr, + beta1, + beta2, + weight_decay, + eps, + maximize, + grad_scale, + found_inf); + } else { + TORCH_CHECK(at::native::check_fast_path_restrictions({params, grads, exp_avgs, exp_avg_sqs}), + "params, grads, exp_avgs, and exp_avg_sqs must have same dtype, device, and layout"); + mps::_fused_adam_mps_impl_(params, + grads, + exp_avgs, + exp_avg_sqs, + state_steps, + lr, + beta1, + beta2, + weight_decay, + eps, + maximize, + grad_scale, + found_inf); + } +} + +} // namespace at::native diff --git a/aten/src/ATen/native/mps/operations/FusedAdamKernelImpl.h b/aten/src/ATen/native/mps/operations/FusedAdamKernelImpl.h new file mode 100644 index 00000000000000..90d1ee1509323d --- /dev/null +++ b/aten/src/ATen/native/mps/operations/FusedAdamKernelImpl.h @@ -0,0 +1,23 @@ +#pragma once +#include + +namespace at::native { +namespace mps { + +void _fused_adam_mps_impl_( + at::TensorList params, + at::TensorList grads, + at::TensorList exp_avgs, + at::TensorList exp_avg_sqs, + at::TensorList state_steps, + const double lr, + const double beta1, + const double beta2, + const double weight_decay, + const double eps, + const bool maximize, + const c10::optional& grad_scale, + const c10::optional& found_inf +); +} //namespace mps +}// namespace at::native diff --git a/aten/src/ATen/native/mps/operations/FusedAdamKernelImpl.mm b/aten/src/ATen/native/mps/operations/FusedAdamKernelImpl.mm new file mode 100644 index 00000000000000..e3c87ae9bc7872 --- /dev/null +++ b/aten/src/ATen/native/mps/operations/FusedAdamKernelImpl.mm @@ -0,0 +1,35 @@ +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS +#include + +#include +#include +#include +#include +#include + +namespace at::native { +namespace mps { + +void _fused_adam_mps_impl_(at::TensorList params, + at::TensorList grads, + at::TensorList exp_avgs, + at::TensorList exp_avg_sqs, + at::TensorList state_steps, + const double lr, + const double beta1, + const double beta2, + const double weight_decay, + const double eps, + const bool maximize, + const c10::optional& grad_scale, + const c10::optional& found_inf) { + std::vector> tensor_lists{params.vec(), grads.vec(), exp_avgs.vec(), exp_avg_sqs.vec()}; + + const std::string kernel_name = "fused_adam_" + scalarToMetalTypeString(params[0].scalar_type()) + "_" + + scalarToMetalTypeString(state_steps[0].scalar_type()); + + multi_tensor_apply_for_fused_adam<4, 512>( + kernel_name, tensor_lists, state_steps, lr, beta1, beta2, weight_decay, eps, maximize); +} +} // namespace mps +} // namespace at::native \ No newline at end of file diff --git a/aten/src/ATen/native/mps/operations/FusedAdamWAmsgradKernelImpl.h b/aten/src/ATen/native/mps/operations/FusedAdamWAmsgradKernelImpl.h new file mode 100644 index 00000000000000..f03fcdb5741398 --- /dev/null +++ b/aten/src/ATen/native/mps/operations/FusedAdamWAmsgradKernelImpl.h @@ -0,0 +1,24 @@ +#pragma once +#include + +namespace at::native { +namespace mps { + +void _fused_adamw_amsgrad_mps_impl_( + at::TensorList params, + at::TensorList grads, + at::TensorList exp_avgs, + at::TensorList exp_avg_sqs, + at::TensorList max_exp_avg_sqs, + at::TensorList state_steps, + const double lr, + const double beta1, + const double beta2, + const double weight_decay, + const double eps, + const bool maximize, + const c10::optional& grad_scale, + const c10::optional& found_inf +); +} //namespace mps +}// namespace at::native diff --git a/aten/src/ATen/native/mps/operations/FusedAdamWAmsgradKernelImpl.mm b/aten/src/ATen/native/mps/operations/FusedAdamWAmsgradKernelImpl.mm new file mode 100644 index 00000000000000..fd94e9686fbce0 --- /dev/null +++ b/aten/src/ATen/native/mps/operations/FusedAdamWAmsgradKernelImpl.mm @@ -0,0 +1,37 @@ +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS +#include + +#include +#include +#include +#include +#include + +namespace at::native { +namespace mps { + +void _fused_adamw_amsgrad_mps_impl_(at::TensorList params, + at::TensorList grads, + at::TensorList exp_avgs, + at::TensorList exp_avg_sqs, + at::TensorList max_exp_avg_sqs, + at::TensorList state_steps, + const double lr, + const double beta1, + const double beta2, + const double weight_decay, + const double eps, + const bool maximize, + const c10::optional& grad_scale, + const c10::optional& found_inf) { + std::vector> tensor_lists{ + params.vec(), grads.vec(), exp_avgs.vec(), exp_avg_sqs.vec(), max_exp_avg_sqs.vec()}; + + const std::string kernel_name = "fused_adamw_amsgrad_" + scalarToMetalTypeString(params[0].scalar_type()) + "_" + + scalarToMetalTypeString(state_steps[0].scalar_type()); + + multi_tensor_apply_for_fused_adam<5, 512>( + kernel_name, tensor_lists, state_steps, lr, beta1, beta2, weight_decay, eps, maximize); +} +} // namespace mps +} // namespace at::native \ No newline at end of file diff --git a/aten/src/ATen/native/mps/operations/FusedAdamWKernel.mm b/aten/src/ATen/native/mps/operations/FusedAdamWKernel.mm new file mode 100644 index 00000000000000..ce08972ef9adf5 --- /dev/null +++ b/aten/src/ATen/native/mps/operations/FusedAdamWKernel.mm @@ -0,0 +1,68 @@ +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS +#include +#include +#include +#include +#include +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#include +#else +#include +#endif + +namespace at::native { + +void _fused_adamw_kernel_mps_(at::TensorList params, + at::TensorList grads, + at::TensorList exp_avgs, + at::TensorList exp_avg_sqs, + at::TensorList max_exp_avg_sqs, + at::TensorList state_steps, + const double lr, + const double beta1, + const double beta2, + const double weight_decay, + const double eps, + const bool amsgrad, + const bool maximize, + const c10::optional& grad_scale, + const c10::optional& found_inf) { + if (amsgrad) { + TORCH_CHECK(at::native::check_fast_path_restrictions({params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs}), + "params, grads, exp_avgs, exp_avg_sqs, and max_exp_avg_sqs must have same dtype, device, and layout"); + mps::_fused_adamw_amsgrad_mps_impl_(params, + grads, + exp_avgs, + exp_avg_sqs, + max_exp_avg_sqs, + state_steps, + lr, + beta1, + beta2, + weight_decay, + eps, + maximize, + grad_scale, + found_inf); + } else { + TORCH_CHECK(at::native::check_fast_path_restrictions({params, grads, exp_avgs, exp_avg_sqs}), + "params, grads, exp_avgs, and exp_avg_sqs must have same dtype, device, and layout"); + mps::_fused_adamw_mps_impl_(params, + grads, + exp_avgs, + exp_avg_sqs, + state_steps, + lr, + beta1, + beta2, + weight_decay, + eps, + maximize, + grad_scale, + found_inf); + } +} +} // namespace at::native diff --git a/aten/src/ATen/native/mps/operations/FusedAdamWKernelImpl.h b/aten/src/ATen/native/mps/operations/FusedAdamWKernelImpl.h new file mode 100644 index 00000000000000..284516e0b89ce4 --- /dev/null +++ b/aten/src/ATen/native/mps/operations/FusedAdamWKernelImpl.h @@ -0,0 +1,23 @@ +#pragma once +#include + +namespace at::native { +namespace mps { + +void _fused_adamw_mps_impl_( + at::TensorList params, + at::TensorList grads, + at::TensorList exp_avgs, + at::TensorList exp_avg_sqs, + at::TensorList state_steps, + const double lr, + const double beta1, + const double beta2, + const double weight_decay, + const double eps, + const bool maximize, + const c10::optional& grad_scale, + const c10::optional& found_inf +); +} //namespace mps +}// namespace at::native diff --git a/aten/src/ATen/native/mps/operations/FusedAdamWKernelImpl.mm b/aten/src/ATen/native/mps/operations/FusedAdamWKernelImpl.mm new file mode 100644 index 00000000000000..8899f6a5e9e130 --- /dev/null +++ b/aten/src/ATen/native/mps/operations/FusedAdamWKernelImpl.mm @@ -0,0 +1,35 @@ +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS +#include + +#include +#include +#include +#include +#include + +namespace at::native { +namespace mps { + +void _fused_adamw_mps_impl_(at::TensorList params, + at::TensorList grads, + at::TensorList exp_avgs, + at::TensorList exp_avg_sqs, + at::TensorList state_steps, + const double lr, + const double beta1, + const double beta2, + const double weight_decay, + const double eps, + const bool maximize, + const c10::optional& grad_scale, + const c10::optional& found_inf) { + std::vector> tensor_lists{params.vec(), grads.vec(), exp_avgs.vec(), exp_avg_sqs.vec()}; + + const std::string kernel_name = "fused_adamw_" + scalarToMetalTypeString(params[0].scalar_type()) + "_" + + scalarToMetalTypeString(state_steps[0].scalar_type()); + + multi_tensor_apply_for_fused_adam<4, 512>( + kernel_name, tensor_lists, state_steps, lr, beta1, beta2, weight_decay, eps, maximize); +} +} // namespace mps +} // namespace at::native \ No newline at end of file diff --git a/aten/src/ATen/native/mps/operations/FusedOptimizerOps.h b/aten/src/ATen/native/mps/operations/FusedOptimizerOps.h new file mode 100644 index 00000000000000..00a75067b7f4b3 --- /dev/null +++ b/aten/src/ATen/native/mps/operations/FusedOptimizerOps.h @@ -0,0 +1,274 @@ +#pragma once +#include + +namespace at::native { +namespace mps { + +static const char* FUSED_ADAM_OPS = R"METAL( +#include + +#define kmaxThreadGroups 32 +#define kmaxTensors 32 +#define chunk_size 65536 + +constexpr constant uint kParamIdx = 0; +constexpr constant uint kGradIdx = kParamIdx + kmaxTensors; +constexpr constant uint kExpAvgIdx = kGradIdx + kmaxTensors; +constexpr constant uint kExpAvgSqIdx = kExpAvgIdx + kmaxTensors; +constexpr constant uint kMaxExpAvgSqIdx = kExpAvgSqIdx + kmaxTensors; +constexpr constant uint kStateStepsIdx = kExpAvgSqIdx + kmaxTensors; +constexpr constant uint kStateStepsIdxForAmsgrad = kMaxExpAvgSqIdx + kmaxTensors; + +template +struct AdamArguments { + metal::array params [[ id(kParamIdx) ]]; + metal::array grads [[ id(kGradIdx) ]]; + metal::array exp_avgs [[ id(kExpAvgIdx) ]]; + metal::array exp_avg_sqs [[ id(kExpAvgSqIdx) ]]; + metal::array state_steps [[ id(kStateStepsIdx) ]]; +}; + +template +struct AdamAmsgradArguments { + metal::array params [[ id(kParamIdx) ]]; + metal::array grads [[ id(kGradIdx) ]]; + metal::array exp_avgs [[ id(kExpAvgIdx) ]]; + metal::array exp_avg_sqs [[ id(kExpAvgSqIdx) ]]; + metal::array max_exp_avg_sqs [[ id(kMaxExpAvgSqIdx) ]]; + metal::array state_steps [[ id(kStateStepsIdxForAmsgrad) ]]; +}; + +struct MetadataArguments { + uint32_t numels[kmaxTensors]; + uint32_t threadgroup_to_tensor[kmaxThreadGroups]; + uint32_t threadgroup_to_chunk[kmaxThreadGroups]; +}; + +enum ADAM_MODE : uint8_t { + ORIGINAL = 0, + ADAMW = 1 +}; + +template +inline void adam_math_amsgrad( + device T & param, + device T & grad, + device T & exp_avg, + device T & exp_avg_sq, + device T & max_exp_avg_sq, + device state_steps_t & state_steps, + const float lr, + const float beta1, + const float beta2, + const float weight_decay, + const float eps, + const uint8_t maximize +) { + T grad_ = grad; + + if (maximize) { + grad = -grad; + } + + // Update param, grad, 1st and 2nd order momentum. + if (weight_decay != 0) { + switch (adam_mode) { + case ADAM_MODE::ORIGINAL: + grad += param * weight_decay; + break; + case ADAM_MODE::ADAMW: + param -= lr * weight_decay * param; + break; + } + } + + exp_avg = beta1 * exp_avg + (1 - beta1) * grad; + exp_avg_sq = beta2 * exp_avg_sq + (1 - beta2) * grad * grad; + const float casted_state_steps = static_cast(state_steps); + const T bias_correction1 = 1 - metal::precise::pow(beta1, casted_state_steps); + const T step_size = lr / bias_correction1; + const T bias_correction2 = 1 - metal::precise::pow(beta2, casted_state_steps); + const T bias_correction2_sqrt = metal::precise::sqrt(bias_correction2); + max_exp_avg_sq = metal::max(max_exp_avg_sq, exp_avg_sq); + + const T denom = (metal::precise::sqrt(max_exp_avg_sq) / bias_correction2_sqrt) + eps; + param -= step_size * exp_avg / denom; + grad = grad_; +} + +template +inline void adam_math( + device T & param, + device T & grad, + device T & exp_avg, + device T & exp_avg_sq, + device state_steps_t & state_steps, + const float lr, + const float beta1, + const float beta2, + const float weight_decay, + const float eps, + const uint8_t maximize +) { + T grad_ = grad; + + if (maximize) { + grad = -grad; + } + + // Update param, grad, 1st and 2nd order momentum. + if (weight_decay != 0) { + switch (adam_mode) { + case ADAM_MODE::ORIGINAL: + grad += param * weight_decay; + break; + case ADAM_MODE::ADAMW: + param -= lr * weight_decay * param; + break; + } + } + + exp_avg = beta1 * exp_avg + (1 - beta1) * grad; + exp_avg_sq = beta2 * exp_avg_sq + (1 - beta2) * grad * grad; + const float casted_state_steps = static_cast(state_steps); + const T bias_correction1 = 1 - metal::precise::pow(beta1, casted_state_steps); + const T step_size = lr / bias_correction1; + const T bias_correction2 = 1 - metal::precise::pow(beta2, casted_state_steps); + const T bias_correction2_sqrt = metal::precise::sqrt(bias_correction2); + const T denom = (metal::precise::sqrt(exp_avg_sq) / bias_correction2_sqrt) + eps; + param -= step_size * exp_avg / denom; + grad = grad_; +} + +template +kernel void fused_adam_amsgrad( + device AdamAmsgradArguments & args [[buffer(0)]], + constant MetadataArguments & metadata_args [[buffer(1)]], + constant float & lr [[buffer(2)]], + constant float & beta1 [[buffer(3)]], + constant float & beta2 [[buffer(4)]], + constant float & weight_decay [[buffer(5)]], + constant float & eps [[buffer(6)]], + constant uint8_t & maximize [[buffer(7)]], + uint tid [[thread_position_in_threadgroup]], + uint tgid [[threadgroup_position_in_grid]], + uint tptg [[threads_per_threadgroup]]) { + + const uint32_t tensor_loc = metadata_args.threadgroup_to_tensor[tgid]; + const uint32_t chunk_idx = metadata_args.threadgroup_to_chunk[tgid]; + const uint32_t chunk_offset = chunk_idx * chunk_size; + const uint32_t numel = metadata_args.numels[tensor_loc] - chunk_offset; + + const auto step_count = args.state_steps[tensor_loc]; + + // each chunk is a threadgroup + auto param = args.params[tensor_loc] + chunk_offset; + auto grad = args.grads[tensor_loc] + chunk_offset; + auto exp_avg = args.exp_avgs[tensor_loc] + chunk_offset; + auto exp_avg_sq = args.exp_avg_sqs[tensor_loc] + chunk_offset; + auto max_exp_avg_sq = args.max_exp_avg_sqs[tensor_loc] + chunk_offset; + + for (uint32_t i_start = tid; i_start < numel && i_start < chunk_size; i_start += tptg) { + adam_math_amsgrad( + *(param + i_start), + *(grad + i_start), + *(exp_avg + i_start), + *(exp_avg_sq + i_start), + *(max_exp_avg_sq + i_start), + *step_count, + lr, + beta1, + beta2, + weight_decay, + eps, + maximize + ); + } +} + +template +kernel void fused_adam( + device AdamArguments & args [[buffer(0)]], + constant MetadataArguments & metadata_args [[buffer(1)]], + constant float & lr [[buffer(2)]], + constant float & beta1 [[buffer(3)]], + constant float & beta2 [[buffer(4)]], + constant float & weight_decay [[buffer(5)]], + constant float & eps [[buffer(6)]], + constant uint8_t & maximize [[buffer(7)]], + uint tid [[thread_position_in_threadgroup]], + uint tgid [[threadgroup_position_in_grid]], + uint tptg [[threads_per_threadgroup]]) { + + const uint32_t tensor_loc = metadata_args.threadgroup_to_tensor[tgid]; + const uint32_t chunk_idx = metadata_args.threadgroup_to_chunk[tgid]; + const uint32_t chunk_offset = chunk_idx * chunk_size; + const uint32_t numel = metadata_args.numels[tensor_loc] - chunk_offset; + + const auto step_count = args.state_steps[tensor_loc]; + + // each chunk is a threadgroup + auto param = args.params[tensor_loc] + chunk_offset; + auto grad = args.grads[tensor_loc] + chunk_offset; + auto exp_avg = args.exp_avgs[tensor_loc] + chunk_offset; + auto exp_avg_sq = args.exp_avg_sqs[tensor_loc] + chunk_offset; + + for (uint32_t i_start = tid; i_start < numel && i_start < chunk_size; i_start += tptg) { + adam_math( + *(param + i_start), + *(grad + i_start), + *(exp_avg + i_start), + *(exp_avg_sq + i_start), + *step_count, + lr, + beta1, + beta2, + weight_decay, + eps, + maximize + ); + } +} + +#define REGISTER_FUSED_ADAM_OP(DTYPE, STATE_STEPS_DTYPE, ADAM_MODE_DTYPE, HOST_NAME, KERNEL_NAME, ARGUMENTS_STRUCT) \ +template \ +[[host_name(#HOST_NAME "_" #DTYPE "_" #STATE_STEPS_DTYPE)]] \ +kernel void KERNEL_NAME( \ + device ARGUMENTS_STRUCT & args [[buffer(0)]],\ + constant MetadataArguments & metadata_args [[buffer(1)]],\ + constant float & lr [[buffer(2)]],\ + constant float & beta1 [[buffer(3)]],\ + constant float & beta2 [[buffer(4)]],\ + constant float & weight_decay [[buffer(5)]],\ + constant float & eps [[buffer(6)]],\ + constant uint8_t & maximize [[buffer(7)]],\ + uint tid [[thread_position_in_threadgroup]],\ + uint tgid [[threadgroup_position_in_grid]],\ + uint tptg [[threads_per_threadgroup]]) + +REGISTER_FUSED_ADAM_OP(float, float, ADAM_MODE::ORIGINAL, fused_adam, fused_adam, AdamArguments); +REGISTER_FUSED_ADAM_OP(float, half, ADAM_MODE::ORIGINAL, fused_adam, fused_adam, AdamArguments); +REGISTER_FUSED_ADAM_OP(half, float, ADAM_MODE::ORIGINAL, fused_adam, fused_adam, AdamArguments); +REGISTER_FUSED_ADAM_OP(half, half, ADAM_MODE::ORIGINAL, fused_adam, fused_adam, AdamArguments); +REGISTER_FUSED_ADAM_OP(float, float, ADAM_MODE::ADAMW, fused_adamw, fused_adam, AdamArguments); +REGISTER_FUSED_ADAM_OP(float, half, ADAM_MODE::ADAMW, fused_adamw, fused_adam, AdamArguments); +REGISTER_FUSED_ADAM_OP(half, float, ADAM_MODE::ADAMW, fused_adamw, fused_adam, AdamArguments); +REGISTER_FUSED_ADAM_OP(half, half, ADAM_MODE::ADAMW, fused_adamw, fused_adam, AdamArguments); +REGISTER_FUSED_ADAM_OP(float, float, ADAM_MODE::ORIGINAL, fused_adam_amsgrad, fused_adam_amsgrad, AdamAmsgradArguments); +REGISTER_FUSED_ADAM_OP(float, half, ADAM_MODE::ORIGINAL, fused_adam_amsgrad, fused_adam_amsgrad, AdamAmsgradArguments); +REGISTER_FUSED_ADAM_OP(half, float, ADAM_MODE::ORIGINAL, fused_adam_amsgrad, fused_adam_amsgrad, AdamAmsgradArguments); +REGISTER_FUSED_ADAM_OP(half, half, ADAM_MODE::ORIGINAL, fused_adam_amsgrad, fused_adam_amsgrad, AdamAmsgradArguments); +REGISTER_FUSED_ADAM_OP(float, float, ADAM_MODE::ADAMW, fused_adamw_amsgrad, fused_adam_amsgrad, AdamAmsgradArguments); +REGISTER_FUSED_ADAM_OP(float, half, ADAM_MODE::ADAMW, fused_adamw_amsgrad, fused_adam_amsgrad, AdamAmsgradArguments); +REGISTER_FUSED_ADAM_OP(half, float, ADAM_MODE::ADAMW, fused_adamw_amsgrad, fused_adam_amsgrad, AdamAmsgradArguments); +REGISTER_FUSED_ADAM_OP(half, half, ADAM_MODE::ADAMW, fused_adamw_amsgrad, fused_adam_amsgrad, AdamAmsgradArguments); + +)METAL"; + +static std::pair, id> getCPLState(const std::string& fname) { + static MetalShaderLibrary lib(FUSED_ADAM_OPS, 0); + return std::make_pair(lib.getPipelineStateForFunc(fname), lib.getMTLFunction(fname)); +} + +} //namespace mps +} // namespace at::native \ No newline at end of file diff --git a/aten/src/ATen/native/mps/operations/MultiTensorApply.h b/aten/src/ATen/native/mps/operations/MultiTensorApply.h new file mode 100644 index 00000000000000..fe9296cc0db79d --- /dev/null +++ b/aten/src/ATen/native/mps/operations/MultiTensorApply.h @@ -0,0 +1,190 @@ +#pragma once +#include +#include +#include + +namespace at::native { +namespace mps { + +static constexpr int64_t kChunkSize = 65536; +static constexpr int64_t kmaxThreadGroups = 32; +static constexpr int64_t kmaxTensors = 32; + +struct MetadataArguments { // the size of this struct must be less than 4 bytes + uint numels[kmaxTensors]; + uint threadgroup_to_tensor[kmaxThreadGroups]; + uint threadgroup_to_chunk[kmaxThreadGroups]; +}; + +template +static void multi_tensor_apply_for_fused_adam( + const std::string& kernel_name, + std::vector>& tensor_lists, + at::TensorList state_steps, + const double lr, + const double beta1, + const double beta2, + const double weight_decay, + const double eps, + const bool maximize + ) { + const auto num_tensors = tensor_lists[0].size(); + + if (num_tensors == 0) { + return; + } + + TORCH_CHECK( + tensor_lists.size() == depth, + "Number of tensor lists has to match the depth"); + for (const auto& d : c10::irange(depth)) { + TORCH_CHECK( + tensor_lists[d][0].scalar_type() == at::ScalarType::Float || tensor_lists[d][0].scalar_type() == at::ScalarType::Half, "Only float and half are supported"); + } + + id device = MPSDevice::getInstance()->device(); + MPSStream* mpsStream = getCurrentMPSStream(); + + float lr_lv = lr; + float beta1_lv = beta1; + float beta2_lv = beta2; + float weight_decay_lv = weight_decay; + float eps_lv = eps; + uint8_t maximize_lv = maximize; + + // Remove comment for debugging + /* + mpsStream->addCompletedHandler(^(id cb) { + [cb.logs enumerateObjectsUsingBlock:^(NSString* log, NSUInteger idx, BOOL* stop) { + NSLog(@"MPSStream: %@", log); + } + ]; + }); + */ + + dispatch_sync_with_rethrow(mpsStream->queue(), ^() { + @autoreleasepool { + id computeEncoder = mpsStream->commandEncoder(); + auto [fusedOptimizerPSO, fusedOptimizerFunc] = getCPLState(kernel_name); + + // this function call is a no-op if MPS Profiler is not enabled + getMPSProfiler().beginProfileKernel(fusedOptimizerPSO, kernel_name, {tensor_lists[0]}); + + [computeEncoder setComputePipelineState:fusedOptimizerPSO]; + + // BufferIndex is the index in the kernel function + auto tensorArgumentEncoder = [[fusedOptimizerFunc newArgumentEncoderWithBufferIndex:0] autorelease]; + id tensorArgumentBuffer = [[device newBufferWithLength:tensorArgumentEncoder.encodedLength options:0] autorelease]; + [tensorArgumentEncoder setArgumentBuffer:tensorArgumentBuffer offset:0]; + + int64_t tensor_loc = 0; + int64_t threadgroup_loc = 0; + MetadataArguments metadata_arguments; + + for (const auto tensor_index : c10::irange(num_tensors)) { + // short-circuit to avoid adding empty tensors to tensorListMeta + if (tensor_lists[0][tensor_index].numel() == 0) { + continue; + } + + for (const auto& d : c10::irange(depth)) { + [tensorArgumentEncoder setBuffer:getMTLBufferStorage(tensor_lists[d][tensor_index]) + offset:tensor_lists[d][tensor_index].storage_offset() * tensor_lists[d][tensor_index].element_size() + atIndex:d * kmaxTensors + tensor_loc]; + [computeEncoder useResource:getMTLBufferStorage(tensor_lists[d][tensor_index]) usage:MTLResourceUsageRead | MTLResourceUsageWrite]; + } + [tensorArgumentEncoder setBuffer:getMTLBufferStorage(state_steps[tensor_index]) + offset:state_steps[tensor_index].storage_offset() * state_steps[tensor_index].element_size() + atIndex:depth * kmaxTensors + tensor_loc]; + [computeEncoder useResource:getMTLBufferStorage(state_steps[tensor_index]) usage:MTLResourceUsageRead]; + metadata_arguments.numels[tensor_loc] = tensor_lists[0][tensor_index].numel(); + + tensor_loc++; + + const auto numel = tensor_lists[0][tensor_index].numel(); + const auto chunks = numel / kChunkSize + (numel % kChunkSize != 0); + TORCH_CHECK(chunks > -1); + + for (const auto& chunk : c10::irange(chunks)) { + metadata_arguments.threadgroup_to_tensor[threadgroup_loc] = tensor_loc - 1; + metadata_arguments.threadgroup_to_chunk[threadgroup_loc] = chunk; + + threadgroup_loc++; + + const auto tensor_full = tensor_loc == kmaxTensors && chunk == chunks - 1; + // Reach the maximum threadgroups per dispatch + const auto blocks_full = threadgroup_loc == kmaxThreadGroups; + + if (tensor_full || blocks_full){ + [computeEncoder setBuffer:tensorArgumentBuffer + offset:0 + atIndex:0]; + [computeEncoder setBytes:&metadata_arguments + length:sizeof(MetadataArguments) + atIndex:1]; + [computeEncoder setBytes:&lr_lv length:sizeof(float) atIndex:2]; + [computeEncoder setBytes:&beta1_lv length:sizeof(float) atIndex:3]; + [computeEncoder setBytes:&beta2_lv length:sizeof(float) atIndex:4]; + [computeEncoder setBytes:&weight_decay_lv length:sizeof(float) atIndex:5]; + [computeEncoder setBytes:&eps_lv length:sizeof(float) atIndex:6]; + [computeEncoder setBytes:&maximize_lv length:sizeof(uint8_t) atIndex:7]; + MTLSize gridSize = MTLSizeMake(threadgroup_loc, 1, 1); + uint32_t maxThreadsPerGroup = [fusedOptimizerPSO maxTotalThreadsPerThreadgroup]; + MTLSize threadGroupSize = MTLSizeMake(std::min(maxThreadsPerGroup, kThreadGroupSize), 1, 1); + [computeEncoder dispatchThreadgroups:gridSize threadsPerThreadgroup:threadGroupSize]; + + // Reset + threadgroup_loc = 0; + if (chunk == chunks - 1) { + // last chunk + tensor_loc = 0; + tensorArgumentBuffer = [[device newBufferWithLength:tensorArgumentEncoder.encodedLength options:0] autorelease]; + [tensorArgumentEncoder setArgumentBuffer:tensorArgumentBuffer offset:0]; + } else { + // reuse the current tensor since the current one isn't done. + metadata_arguments.numels[0] = metadata_arguments.numels[tensor_loc - 1]; + + tensorArgumentBuffer = [[device newBufferWithLength:tensorArgumentEncoder.encodedLength options:0] autorelease]; + [tensorArgumentEncoder setArgumentBuffer:tensorArgumentBuffer offset:0]; + + for (const auto& d : c10::irange(depth)) { + [tensorArgumentEncoder setBuffer:getMTLBufferStorage(tensor_lists[d][tensor_index]) + offset:tensor_lists[d][tensor_index].storage_offset() * tensor_lists[d][tensor_index].element_size() + atIndex:d * kmaxTensors + 0]; + [computeEncoder useResource:getMTLBufferStorage(tensor_lists[d][tensor_index]) usage:MTLResourceUsageWrite | MTLResourceUsageRead]; + } + [tensorArgumentEncoder setBuffer:getMTLBufferStorage(state_steps[tensor_index]) + offset:state_steps[tensor_index].storage_offset() * state_steps[tensor_index].element_size() + atIndex:depth * kmaxTensors + 0]; + [computeEncoder useResource:getMTLBufferStorage(state_steps[tensor_index]) usage:MTLResourceUsageRead]; + + tensor_loc = 1; + } + } + } + } + + if (threadgroup_loc != 0) { + + [computeEncoder setBuffer:tensorArgumentBuffer offset:0 atIndex:0]; + [computeEncoder setBytes:&metadata_arguments length:sizeof(MetadataArguments) atIndex:1]; + [computeEncoder setBytes:&lr_lv length:sizeof(float) atIndex:2]; + [computeEncoder setBytes:&beta1_lv length:sizeof(float) atIndex:3]; + [computeEncoder setBytes:&beta2_lv length:sizeof(float) atIndex:4]; + [computeEncoder setBytes:&weight_decay_lv length:sizeof(float) atIndex:5]; + [computeEncoder setBytes:&eps_lv length:sizeof(float) atIndex:6]; + [computeEncoder setBytes:&maximize_lv length:sizeof(uint8_t) atIndex:7]; + MTLSize gridSize = MTLSizeMake(threadgroup_loc, 1, 1); + uint32_t maxThreadsPerGroup = [fusedOptimizerPSO maxTotalThreadsPerThreadgroup]; + MTLSize threadGroupSize = MTLSizeMake(std::min(maxThreadsPerGroup, kThreadGroupSize), 1, 1); + [computeEncoder dispatchThreadgroups:gridSize threadsPerThreadgroup:threadGroupSize]; + } + + getMPSProfiler().endProfileKernel(fusedOptimizerPSO); + + } + }); +} + +} // namespace mps +} // namespace at::native \ No newline at end of file diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 7474e0bc55d8b8..b030141882c86e 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -15575,6 +15575,7 @@ dispatch: CPU: _fused_adam_kernel_cpu_ CUDA: _fused_adam_kernel_cuda_ + MPS: _fused_adam_kernel_mps_ autogen: _fused_adam, _fused_adam.out - func: _fused_adam_.tensor_lr(Tensor(a!)[] self, Tensor(b!)[] grads, Tensor(c!)[] exp_avgs, Tensor(d!)[] exp_avg_sqs, Tensor(e!)[] max_exp_avg_sqs, Tensor[] state_steps, *, Tensor lr, float beta1, float beta2, float weight_decay, float eps, bool amsgrad, bool maximize, Tensor? grad_scale=None, Tensor? found_inf=None) -> () @@ -15593,6 +15594,7 @@ dispatch: CPU: _fused_adamw_kernel_cpu_ CUDA: _fused_adamw_kernel_cuda_ + MPS: _fused_adamw_kernel_mps_ autogen: _fused_adamw, _fused_adamw.out - func: _fused_adamw_.tensor_lr(Tensor(a!)[] self, Tensor(b!)[] grads, Tensor(c!)[] exp_avgs, Tensor(d!)[] exp_avg_sqs, Tensor(e!)[] max_exp_avg_sqs, Tensor[] state_steps, *, Tensor lr, float beta1, float beta2, float weight_decay, float eps, bool amsgrad, bool maximize, Tensor? grad_scale=None, Tensor? found_inf=None) -> () diff --git a/test/test_mps.py b/test/test_mps.py index 311cf8245c4f3a..a97b8fb8d6b137 100644 --- a/test/test_mps.py +++ b/test/test_mps.py @@ -76,7 +76,6 @@ def mps_ops_grad_modifier(ops): XFAILLIST_GRAD = { # precision issues - 'digamma': [torch.float32], 'special.polygammaspecial_polygamma_n_0': [torch.float16], 'polygammapolygamma_n_0': [torch.float16], 'nn.functional.binary_cross_entropy': [torch.float16], @@ -95,7 +94,6 @@ def mps_ops_grad_modifier(ops): 'masked.scatter': [torch.float16, torch.float32], 'index_fill': [torch.float16, torch.float32], # missing `aten::_unique`. 'aminmax': [torch.float32, torch.float16], - 'polar': [torch.float32], # Correctness issues 'atanh': [torch.float32], @@ -569,7 +567,6 @@ def mps_ops_modifier(ops): 'special.ndtr': [torch.uint8], 'sqrt': [torch.uint8], 'sub': [torch.uint8], - 'tanh': [torch.uint8], 'trapezoid': [torch.uint8], 'trapz': [torch.uint8], 'true_divide': [torch.uint8], @@ -586,28 +583,13 @@ def mps_ops_modifier(ops): 'square': [torch.bool, torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8], # cpu not giving nan for x/0.0 - 'atan2': [torch.bool, torch.float16, torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8], + 'atan2': [torch.bool, torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8], # inconsistency errors between cpu and mps, max seen atol is 2 'nn.functional.interpolatebilinear': [torch.uint8], } MACOS_BEFORE_13_3_XFAILLIST = { - # Failure due to precision issues (still present on 13.3+) as well as non-standard behavior of - # cpu ops for the negative integers. - # Example for torch.polygamma(1, tensor([-0.9, -1.0], dtype=torch.float32)): - # - CPU output: tensor([102.668, 1.129e+15]) - # - MPS output: tensor([102.6681, inf]) - # In the latter case, inf is probably correct (this is what scipy does). - 'polygamma': [torch.float32, torch.uint8], - 'polygammapolygamma_n_0': [torch.float32, torch.int16, torch.int8], - 'polygammapolygamma_n_2': [torch.float32, torch.int16, torch.int8], - 'polygammapolygamma_n_1': [torch.float32, torch.int16, torch.int8], - 'polygammapolygamma_n_3': [torch.float32, torch.int16, torch.int8], - 'polygammapolygamma_n_4': [torch.float32, torch.int16, torch.int8], - 'special.polygamma': [torch.float32, torch.int16, torch.int32, torch.int8], - 'special.polygammaspecial_polygamma_n_0': [torch.float32, torch.int16, torch.int8], - # Failures due to precision issues (due to fast-math). These has been fixed in MacOS 13.3+ 'tan': [torch.float32], 'cdist': [torch.float32], @@ -656,20 +638,6 @@ def mps_ops_modifier(ops): # Same issue as `argsort` with duplicate indices. This test checks both the sorted values and the indices. # The values of the sorted tensor match the CPU, but in case of the returned indices this results in undefined behaviour. 'sort': [torch.int8, torch.uint8, torch.bool, torch.float16], - - # Failure due to precision issues as well as non-standard behavior of cpu ops for the - # negative integers. Example for torch.polygamma(1, tensor([-0.9, -1.0], dtype=torch.float32)): - # - CPU output: tensor([102.668, 1.129e+15]) - # - MPS output: tensor([102.6681, inf]) - # In the latter case, inf is probably correct (this is what scipy does). - 'polygamma': [torch.float32, torch.uint8], - 'polygammapolygamma_n_0': [torch.float32, torch.int16, torch.int8], - 'polygammapolygamma_n_2': [torch.float32, torch.int16, torch.int8], - 'polygammapolygamma_n_1': [torch.float32, torch.int16, torch.int8], - 'polygammapolygamma_n_3': [torch.float32, torch.int16, torch.int8], - 'polygammapolygamma_n_4': [torch.float32, torch.int16, torch.int8], - 'special.polygamma': [torch.float32, torch.int16, torch.int32, torch.int8], - 'special.polygammaspecial_polygamma_n_0': [torch.float32, torch.int16, torch.int8], } MACOS_BEFORE_14_4_XFAILLIST = { diff --git a/test/test_optim.py b/test/test_optim.py index d61c33e2adcead..fb655ce36a5338 100644 --- a/test/test_optim.py +++ b/test/test_optim.py @@ -32,6 +32,7 @@ ) from torch.testing._internal.common_dtype import floating_types_and from torch.testing._internal.common_optimizers import ( + _get_device_type, _get_optim_inputs_including_global_cliquey_kwargs, optim_db, OptimizerErrorEnum, @@ -1004,7 +1005,6 @@ def test_peak_memory_foreach(self, device, dtype, optim_info): self.assertLessEqual(mt_max_mem, expected_max_mem) - @onlyNativeDeviceTypes @optims( [optim for optim in optim_db if "fused" in optim.supported_impls], dtypes=floating_types_and( @@ -1013,10 +1013,15 @@ def test_peak_memory_foreach(self, device, dtype, optim_info): ), ) def test_fused_matches_forloop(self, device, dtype, optim_info): - if device not in optim_info.supports_fused_on: + if _get_device_type(device) not in optim_info.supports_fused_on: self.skipTest( f"{device} is not supported for fused on {optim_info.optim_cls.__name__}" ) + if _get_device_type(device) == "mps" and dtype not in ( + torch.float16, + torch.float32, + ): + self.skipTest("MPS supports only torch.float16 and torch.float32") self._test_derived_optimizers(device, dtype, optim_info, "fused") @onlyNativeDeviceTypes @@ -1076,7 +1081,6 @@ def test_fused_does_not_step_if_foundinf(self, device, dtype, optim_info): ) self.assertEqual(params, params_c) - @onlyCUDA @parametrize("impl", ["fused", "capturable"]) @optims( [optim for optim in optim_db if "fused" in optim.supported_impls], @@ -1100,8 +1104,15 @@ def test_cpu_load_state_dict(self, device, dtype, impl, optim_info): ): # Capturable SGD/Adagrad does not exist self.skipTest("SGD does not currently support capturable") - if impl == "fused" and device not in optim_info.supports_fused_on: + if _get_device_type(device) == "cpu": + self.skipTest("Test is only for non-cpu devices") + elif ( + impl == "fused" + and _get_device_type(device) not in optim_info.supports_fused_on + ): self.skipTest(f"{device} is not supported for fused on {opt_name}") + elif impl == "capturable" and _get_device_type(device) == "mps": + self.skipTest("MPS does not support capturable") cpu_optim_inputs = optim_info.optim_inputs_func(device="cpu") for optim_input in cpu_optim_inputs: @@ -1114,12 +1125,12 @@ def test_cpu_load_state_dict(self, device, dtype, impl, optim_info): # load optim_input.kwargs[impl] = True - param_cuda = param.clone().detach().to(device="cuda") - optimizer_cuda = optim_cls([param_cuda], **optim_input.kwargs) - optimizer_cuda.load_state_dict(optim_state_dict_cpu) - optimizer_cuda.zero_grad() - param_cuda.grad = torch.rand_like(param_cuda) - optimizer_cuda.step() + param_device = param.clone().detach().to(device=device) + optimizer_device = optim_cls([param_device], **optim_input.kwargs) + optimizer_device.load_state_dict(optim_state_dict_cpu) + optimizer_device.zero_grad() + param_device.grad = torch.rand_like(param_device) + optimizer_device.step() @optims(optim_db, dtypes=[torch.float32]) def test_param_groups_weight_decay(self, device, dtype, optim_info): diff --git a/torch/optim/adam.py b/torch/optim/adam.py index 86785be4ed1795..fa7397e02b4240 100644 --- a/torch/optim/adam.py +++ b/torch/optim/adam.py @@ -309,6 +309,8 @@ def step(self, closure=None): {_capturable_doc} {_differentiable_doc} {_fused_doc} + .. Note:: + A prototype implementation of Adam and AdamW for MPS supports `torch.float32` and `torch.float16`. .. _Adam\: A Method for Stochastic Optimization: https://arxiv.org/abs/1412.6980 .. _On the Convergence of Adam and Beyond: @@ -660,6 +662,10 @@ def _fused_adam( ), _, ) in grouped_tensors.items(): + if device.type == "mps": # type: ignore[union-attr] + assert found_inf is None and grad_scale is None + assert not isinstance(lr, Tensor) + device_grad_scale, device_found_inf = None, None if grad_scale is not None: device_grad_scale = grad_scale_dict.setdefault( diff --git a/torch/optim/adamw.py b/torch/optim/adamw.py index 00931bed022727..20ab827552491e 100644 --- a/torch/optim/adamw.py +++ b/torch/optim/adamw.py @@ -310,6 +310,8 @@ def step(self, closure=None): {_capturable_doc} {_differentiable_doc} {_fused_doc} + .. Note:: + A prototype implementation of Adam and AdamW for MPS supports `torch.float32` and `torch.float16`. .. _Decoupled Weight Decay Regularization: https://arxiv.org/abs/1711.05101 .. _On the Convergence of Adam and Beyond: @@ -662,6 +664,10 @@ def _fused_adamw( ), _, ) in grouped_tensors.items(): + if device.type == "mps": # type: ignore[union-attr] + assert found_inf is None and grad_scale is None + assert not isinstance(lr, Tensor) + device_grad_scale, device_found_inf = None, None if grad_scale is not None: device_grad_scale = grad_scale_dict.setdefault( diff --git a/torch/testing/_internal/common_optimizers.py b/torch/testing/_internal/common_optimizers.py index 628bedad313dc1..b7d06e7dc80833 100644 --- a/torch/testing/_internal/common_optimizers.py +++ b/torch/testing/_internal/common_optimizers.py @@ -1232,7 +1232,7 @@ def _get_optim_inputs_including_global_cliquey_kwargs( ), optim_error_inputs_func=optim_error_inputs_func_adam, supported_impls=("foreach", "differentiable", "fused"), - supports_fused_on=("cpu", "cuda"), + supports_fused_on=("cpu", "cuda", "mps"), decorators=( # Expected floating point error between fused and compiled forloop DecorateInfo( @@ -1354,7 +1354,7 @@ def _get_optim_inputs_including_global_cliquey_kwargs( optim_inputs_func=optim_inputs_func_adamw, optim_error_inputs_func=optim_error_inputs_func_adamw, supported_impls=("foreach", "differentiable", "fused"), - supports_fused_on=("cpu", "cuda"), + supports_fused_on=("cpu", "cuda", "mps"), decorators=( # Expected error between compiled forloop and fused optimizers DecorateInfo( diff --git a/torch/utils/_foreach_utils.py b/torch/utils/_foreach_utils.py index bcc274579ad014..c3100d41b6c0f5 100644 --- a/torch/utils/_foreach_utils.py +++ b/torch/utils/_foreach_utils.py @@ -11,7 +11,7 @@ def _get_foreach_kernels_supported_devices() -> List[str]: def _get_fused_kernels_supported_devices() -> List[str]: r"""Return the device type list that supports fused kernels in optimizer.""" - return ["cuda", "xpu", "cpu", torch._C._get_privateuse1_backend_name()] + return ["mps", "cuda", "xpu", "cpu", torch._C._get_privateuse1_backend_name()] TensorListList: TypeAlias = List[List[Optional[Tensor]]] Indices: TypeAlias = List[int] From 5bc9835d64eb5592cb606252ccf19212872cefc7 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Tue, 18 Jun 2024 20:09:00 +0000 Subject: [PATCH 149/171] Revert "[dynamo][trace_rules] Remove incorrectly classified Ingraph functions (#128428)" This reverts commit c52eda896eb3ec7f8d04b6321861f4c5614a40bb. Reverted https://github.com/pytorch/pytorch/pull/128428 on behalf of https://github.com/anijain2305 due to luca saw bad compile time ([comment](https://github.com/pytorch/pytorch/pull/128453#issuecomment-2176877667)) --- test/dynamo/test_repros.py | 2 +- torch/_dynamo/trace_rules.py | 28 ++++++++++++++++++++++++++++ 2 files changed, 29 insertions(+), 1 deletion(-) diff --git a/test/dynamo/test_repros.py b/test/dynamo/test_repros.py index 2329ab305e763c..dbcb259241fcbd 100644 --- a/test/dynamo/test_repros.py +++ b/test/dynamo/test_repros.py @@ -1674,7 +1674,7 @@ def test_issue175(self): self.assertEqual(cnt.frame_count, 1) self.assertEqual( - 15 if torch._dynamo.config.inline_inbuilt_nn_modules else 12, cnt.op_count + 18 if torch._dynamo.config.inline_inbuilt_nn_modules else 12, cnt.op_count ) def test_exec_import(self): diff --git a/torch/_dynamo/trace_rules.py b/torch/_dynamo/trace_rules.py index abbef02e63c682..b5b12435a931a7 100644 --- a/torch/_dynamo/trace_rules.py +++ b/torch/_dynamo/trace_rules.py @@ -2669,6 +2669,26 @@ "torch.nn._reduction.legacy_get_enum", "torch.nn._reduction.legacy_get_string", "torch.nn.factory_kwargs", + "torch.nn.functional._adaptive_max_pool1d", + "torch.nn.functional._adaptive_max_pool2d", + "torch.nn.functional._adaptive_max_pool3d", + "torch.nn.functional._canonical_mask", + "torch.nn.functional._fractional_max_pool2d", + "torch.nn.functional._fractional_max_pool3d", + "torch.nn.functional._get_softmax_dim", + "torch.nn.functional._in_projection_packed", + "torch.nn.functional._in_projection", + "torch.nn.functional._is_integer", + "torch.nn.functional._max_pool1d", + "torch.nn.functional._max_pool2d", + "torch.nn.functional._max_pool3d", + "torch.nn.functional._mha_shape_check", + "torch.nn.functional._no_grad_embedding_renorm_", + "torch.nn.functional._none_or_dtype", + "torch.nn.functional._threshold", + "torch.nn.functional._unpool_output_size", + "torch.nn.functional._verify_batch_size", + "torch.nn.functional._verify_spatial_size", "torch.nn.functional.adaptive_avg_pool2d", "torch.nn.functional.adaptive_avg_pool3d", "torch.nn.functional.adaptive_max_pool1d_with_indices", @@ -2766,7 +2786,15 @@ "torch.nn.grad.conv2d_weight", "torch.nn.grad.conv3d_input", "torch.nn.grad.conv3d_weight", + "torch.nn.modules.activation._arg_requires_grad", + "torch.nn.modules.activation._check_arg_device", "torch.nn.modules.activation._is_make_fx_tracing", + "torch.nn.modules.container._addindent", + "torch.nn.modules.transformer._detect_is_causal_mask", + "torch.nn.modules.transformer._generate_square_subsequent_mask", + "torch.nn.modules.transformer._get_activation_fn", + "torch.nn.modules.transformer._get_clones", + "torch.nn.modules.transformer._get_seq_len", "torch.nn.modules.utils._list_with_default", "torch.nn.modules.utils._ntuple", "torch.nn.modules.utils._quadruple", From 1babeddbbf3a44318d13cf3b8afaac2a6d657115 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Tue, 18 Jun 2024 20:09:00 +0000 Subject: [PATCH 150/171] Revert "[inductor][mkldnn] Use floats instead of ints for pattern matcher test (#128484)" This reverts commit 1f6e84fa6852805e15ddc9583c5f36c3a7f93df8. Reverted https://github.com/pytorch/pytorch/pull/128484 on behalf of https://github.com/anijain2305 due to luca saw bad compile time ([comment](https://github.com/pytorch/pytorch/pull/128453#issuecomment-2176877667)) --- test/inductor/test_mkldnn_pattern_matcher.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/test/inductor/test_mkldnn_pattern_matcher.py b/test/inductor/test_mkldnn_pattern_matcher.py index a80d7239876028..810c22d037c548 100644 --- a/test/inductor/test_mkldnn_pattern_matcher.py +++ b/test/inductor/test_mkldnn_pattern_matcher.py @@ -37,8 +37,7 @@ torch.nn.Tanh(): 2, torch.nn.Hardswish(): 6, torch.nn.LeakyReLU(0.1, inplace=False): 4, - # Use floats for min/max, otherwise they can get converted to symints - torch.nn.Hardtanh(min_val=-0.5, max_val=4.0, inplace=False): 3, + torch.nn.Hardtanh(min_val=-0.5, max_val=4, inplace=False): 3, torch.nn.Hardtanh(min_val=-0.5, max_val=float("inf"), inplace=False): 3, torch.nn.GELU(approximate="none"): 6, torch.nn.GELU(approximate="tanh"): 10, From 44722c6b1085611e0f20917a76fcf3f8f2776e13 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Tue, 18 Jun 2024 20:09:00 +0000 Subject: [PATCH 151/171] Revert "[dynamo][fsdp] Dont take unspecializedNNModuleVariable path for FSDP modules (#128453)" This reverts commit 2b28b107dbafeec18d1095a2002e79511aa241df. Reverted https://github.com/pytorch/pytorch/pull/128453 on behalf of https://github.com/anijain2305 due to luca saw bad compile time ([comment](https://github.com/pytorch/pytorch/pull/128453#issuecomment-2176877667)) --- torch/_dynamo/variables/builder.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/torch/_dynamo/variables/builder.py b/torch/_dynamo/variables/builder.py index af91edb432c887..8a201410d6be3c 100644 --- a/torch/_dynamo/variables/builder.py +++ b/torch/_dynamo/variables/builder.py @@ -1164,11 +1164,7 @@ def wrap_module(self, value: torch.nn.Module): and not config.allow_rnn ): unimplemented("TorchDynamo purposely graph breaks on RNN, GRU, LSTMs") - - # Dont take this path for FSDP - if not getattr( - value, "_is_fsdp_managed_module", None - ) and mutation_guard.is_dynamic_nn_module(value, self.tx.export): + if mutation_guard.is_dynamic_nn_module(value, self.tx.export): # created dynamically, don't specialize on it self.install_guards(GuardBuilder.TYPE_MATCH) if ( From 5dc4f652bc5c068ef15130c955e3f2ffe11f4b74 Mon Sep 17 00:00:00 2001 From: Joel Schlosser Date: Tue, 18 Jun 2024 13:35:49 -0400 Subject: [PATCH 152/171] Backward support for unbind() with NJT (#128032) Pull Request resolved: https://github.com/pytorch/pytorch/pull/128032 Approved by: https://github.com/soulitzer --- test/test_nestedtensor.py | 19 +++++++++++++++++++ tools/autograd/derivatives.yaml | 2 +- torch/csrc/autograd/FunctionsManual.cpp | 17 +++++++++++++++++ torch/csrc/autograd/FunctionsManual.h | 4 ++++ torch/nested/_internal/ops.py | 11 +++++++++++ 5 files changed, 52 insertions(+), 1 deletion(-) diff --git a/test/test_nestedtensor.py b/test/test_nestedtensor.py index 50d6deea92911e..fa33a13ed495db 100644 --- a/test/test_nestedtensor.py +++ b/test/test_nestedtensor.py @@ -5610,6 +5610,25 @@ def f(nt): for dynamic in [False, True, None]: self.assertFalse(_recompiles_for_inputs(f, (nt,), (nt2,), dynamic=dynamic)) + @dtypes(torch.float32, torch.double, torch.half) + def test_unbind_backward(self, device, dtype): + nt = torch.nested.nested_tensor( + [ + torch.randn(2, 4, device=device), + torch.randn(5, 4, device=device), + torch.randn(3, 4, device=device), + ], + layout=torch.jagged, + requires_grad=True, + ) + + a, b, c = nt.unbind() + b.sum().backward() + + expected_grad = torch.zeros_like(nt) + expected_grad.unbind()[1].add_(1.0) + torch._dynamo.disable(self.assertEqual)(nt.grad, expected_grad) + instantiate_parametrized_tests(TestNestedTensor) instantiate_device_type_tests(TestNestedTensorDeviceType, globals()) diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml index 76a7a0a1e42a4f..02a3e6c518ad80 100644 --- a/tools/autograd/derivatives.yaml +++ b/tools/autograd/derivatives.yaml @@ -2847,7 +2847,7 @@ self: unbind_backward(grads, dim) result: auto_linear AutogradNestedTensor: - self: unbind_backward_nested(grads, at::native::get_nested_tensor_impl(self)->get_nested_sizes(), dim, self.options()) + self: "self.layout() == c10::kJagged ? unbind_backward_nested_jagged(grads, self, dim) : unbind_backward_nested(grads, at::native::get_nested_tensor_impl(self)->get_nested_sizes(), dim, self.options())" result: auto_linear - name: stack(Tensor[] tensors, int dim=0) -> Tensor diff --git a/torch/csrc/autograd/FunctionsManual.cpp b/torch/csrc/autograd/FunctionsManual.cpp index 9d897c667c906f..f51c2f047f9351 100644 --- a/torch/csrc/autograd/FunctionsManual.cpp +++ b/torch/csrc/autograd/FunctionsManual.cpp @@ -1014,6 +1014,23 @@ Tensor unbind_backward_nested( return at::_nested_tensor_from_tensor_list(grads_tensors); } +Tensor unbind_backward_nested_jagged( + const variable_list& grads, + const Tensor& self, + int64_t dim) { + TORCH_INTERNAL_ASSERT( + dim == 0, "unbind_backward_nested_jagged() only supports dim=0") + auto grad_nt = at::zeros_like(self); + auto unbound_grads = grad_nt.unbind(); + for (int64_t i : c10::irange(static_cast(grads.size()))) { + if (grads[i].defined()) { + unbound_grads[i].copy_(static_cast(grads[i])); + } + } + + return grad_nt; +} + Tensor unsqueeze_to(const Tensor& self, c10::SymIntArrayRef sym_sizes) { auto result = self; diff --git a/torch/csrc/autograd/FunctionsManual.h b/torch/csrc/autograd/FunctionsManual.h index dedff70be1ba34..ecf99bd098057b 100644 --- a/torch/csrc/autograd/FunctionsManual.h +++ b/torch/csrc/autograd/FunctionsManual.h @@ -244,6 +244,10 @@ at::Tensor unbind_backward_nested( const Tensor& nt_sizes, int64_t dim, const at::TensorOptions& options); +at::Tensor unbind_backward_nested_jagged( + const variable_list& grads, + const Tensor& self, + int64_t dim); at::Tensor unsqueeze_to(const at::Tensor& self, c10::SymIntArrayRef sym_sizes); at::Tensor unsqueeze_to( const at::Tensor& self, diff --git a/torch/nested/_internal/ops.py b/torch/nested/_internal/ops.py index 6f1c47dd694712..8458f03717130c 100644 --- a/torch/nested/_internal/ops.py +++ b/torch/nested/_internal/ops.py @@ -472,6 +472,17 @@ def to_copy_default(func, *args, **kwargs): )(jagged_unary_pointwise) +@register_jagged_func(torch.ops.aten.zero_.default, "self: jt_all") +def zero__default(func, *args, **kwargs): + _, new_kwargs = normalize_function( + func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True + ) + + inp = new_kwargs.pop("input") + func(inp._values) + return inp + + @register_jagged_func( torch.ops.aten._softmax.default, "self: jt, dim: any, half_to_float: any" ) From 4cc3fb5ee2296e1178cec710a945c99aa303170d Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 18 Jun 2024 13:38:22 -0700 Subject: [PATCH 153/171] Bump urllib3 from 2.2.1 to 2.2.2 in /tools/build/bazel (#128908) Bumps [urllib3](https://github.com/urllib3/urllib3) from 2.2.1 to 2.2.2. - [Release notes](https://github.com/urllib3/urllib3/releases) - [Changelog](https://github.com/urllib3/urllib3/blob/main/CHANGES.rst) - [Commits](https://github.com/urllib3/urllib3/compare/2.2.1...2.2.2) --- updated-dependencies: - dependency-name: urllib3 dependency-type: indirect ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- tools/build/bazel/requirements.txt | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/tools/build/bazel/requirements.txt b/tools/build/bazel/requirements.txt index cd95aeeec5c6fd..fea6221c9b7ca9 100644 --- a/tools/build/bazel/requirements.txt +++ b/tools/build/bazel/requirements.txt @@ -145,7 +145,7 @@ numpy==1.26.4 \ --hash=sha256:edd8b5fe47dab091176d21bb6de568acdd906d1887a4584a15a9a96a1dca06ef \ --hash=sha256:f870204a840a60da0b12273ef34f7051e98c3b5961b61b0c2c1be6dfd64fbcd3 \ --hash=sha256:ffa75af20b44f8dba823498024771d5ac50620e6915abac414251bd971b4529f - # via -r tools/build/bazel/requirements.in + # via -r requirements.in pyyaml==6.0.1 \ --hash=sha256:04ac92ad1925b2cff1db0cfebffb6ffc43457495c9b3c39d3fcae417d7125dc5 \ --hash=sha256:062582fca9fabdd2c8b54a3ef1c978d786e0f6b3a1510e0ac93ef59e0ddae2bc \ @@ -198,26 +198,26 @@ pyyaml==6.0.1 \ --hash=sha256:fca0e3a251908a499833aa292323f32437106001d436eca0e6e7833256674585 \ --hash=sha256:fd1592b3fdf65fff2ad0004b5e363300ef59ced41c2e6b3a99d4089fa8c5435d \ --hash=sha256:fd66fc5d0da6d9815ba2cebeb4205f95818ff4b79c3ebe268e75d961704af52f - # via -r tools/build/bazel/requirements.in + # via -r requirements.in requests==2.32.2 \ --hash=sha256:dd951ff5ecf3e3b3aa26b40703ba77495dab41da839ae72ef3c8e5d8e2433289 \ --hash=sha256:fc06670dd0ed212426dfeb94fc1b983d917c4f9847c863f313c9dfaaffb7c23c - # via -r tools/build/bazel/requirements.in + # via -r requirements.in sympy==1.12 \ --hash=sha256:c3588cd4295d0c0f603d0f2ae780587e64e2efeedb3521e46b9bb1d08d184fa5 \ --hash=sha256:ebf595c8dac3e0fdc4152c51878b498396ec7f30e7a914d6071e674d49420fb8 - # via -r tools/build/bazel/requirements.in + # via -r requirements.in typing-extensions==4.11.0 \ --hash=sha256:83f085bd5ca59c80295fc2a82ab5dac679cbe02b9f33f7d83af68e241bea51b0 \ --hash=sha256:c1f94d72897edaf4ce775bb7558d5b79d8126906a14ea5ed1635921406c0387a - # via -r tools/build/bazel/requirements.in -urllib3==2.2.1 \ - --hash=sha256:450b20ec296a467077128bff42b73080516e71b56ff59a60a02bef2232c4fa9d \ - --hash=sha256:d0570876c61ab9e520d776c38acbbb5b05a776d3f9ff98a5c8fd5162a444cf19 + # via -r requirements.in +urllib3==2.2.2 \ + --hash=sha256:a448b2f64d686155468037e1ace9f2d2199776e17f0a46610480d311f73e3472 \ + --hash=sha256:dd505485549a7a552833da5e6063639d0d177c04f23bc3864e41e5dc5f612168 # via requests # The following packages are considered to be unsafe in a requirements file: setuptools==69.5.1 \ --hash=sha256:6c1fccdac05a97e598fb0ae3bbed5904ccb317337a51139dcd51453611bbb987 \ --hash=sha256:c636ac361bc47580504644275c9ad802c50415c7522212252c033bd15f301f32 - # via -r tools/build/bazel/requirements.in + # via -r requirements.in From 2227da44317f4ea836aaad96337b53533aed2770 Mon Sep 17 00:00:00 2001 From: Aaron Enye Shi Date: Tue, 18 Jun 2024 21:01:01 +0000 Subject: [PATCH 154/171] [Profiler] Clean up use_mtia to follow standard use_device instead (#126284) Summary: use_mtia should instead set use_device='mtia' similar to cuda, xpu, and privateuseone. Avoid an ever-growing list of use_* arguments. Since use_mtia is specific to FBCode, we don't need a deprecation warning. Test Plan: CI. Differential Revision: D57338005 Pulled By: aaronenyeshi Pull Request resolved: https://github.com/pytorch/pytorch/pull/126284 Approved by: https://github.com/fenypatel99 --- torch/autograd/profiler.py | 13 +++++++------ torch/profiler/profiler.py | 3 ++- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/torch/autograd/profiler.py b/torch/autograd/profiler.py index 0392a876984632..f847fc13ff8ad9 100644 --- a/torch/autograd/profiler.py +++ b/torch/autograd/profiler.py @@ -118,7 +118,7 @@ class profile: use_device (str, optional): Enables timing of device events. Adds approximately 4us of overhead to each tensor operation when use cuda. - The valid devices options are 'cuda', 'xpu' and 'privateuseone'. + The valid devices options are 'cuda', 'xpu', 'mtia' and 'privateuseone'. record_shapes (bool, optional): If shapes recording is set, information about input dimensions will be collected. This allows one to see which @@ -205,7 +205,6 @@ def __init__( with_modules=False, use_kineto=False, use_cpu=True, - use_mtia=False, experimental_config=None, ): self.enabled: bool = enabled @@ -231,7 +230,6 @@ def __init__( self.with_stack = with_stack self.with_modules = with_modules self.use_cpu = use_cpu - self.use_mtia = use_mtia if experimental_config is None: experimental_config = _ExperimentalConfig() self.experimental_config = experimental_config @@ -246,7 +244,7 @@ def __init__( ), "Device-only events supported only with Kineto (use_kineto=True)" if self.use_device is not None: - VALID_DEVICE_OPTIONS = ["cuda", "xpu"] + VALID_DEVICE_OPTIONS = ["cuda", "xpu", "mtia"] if _get_privateuse1_backend_name() != "privateuseone": VALID_DEVICE_OPTIONS.append(_get_privateuse1_backend_name()) if self.use_device not in VALID_DEVICE_OPTIONS: @@ -265,8 +263,6 @@ def __init__( self.kineto_activities = set() if self.use_cpu: self.kineto_activities.add(ProfilerActivity.CPU) - if self.use_mtia: - self.kineto_activities.add(ProfilerActivity.MTIA) self.profiler_kind = ProfilerState.KINETO if self.use_device == "cuda": @@ -280,6 +276,11 @@ def __init__( use_kineto and ProfilerActivity.XPU in _supported_activities() ), "Legacy XPU profiling is not supported. Requires use_kineto=True on XPU devices." self.kineto_activities.add(ProfilerActivity.XPU) + elif self.use_device == "mtia": + assert ( + use_kineto and ProfilerActivity.MTIA in _supported_activities() + ), "Legacy MTIA profiling is not supported. Requires use_kineto=True on MTIA devices." + self.kineto_activities.add(ProfilerActivity.MTIA) elif self.use_device is not None and self.use_device != "privateuseone": if ( not use_kineto diff --git a/torch/profiler/profiler.py b/torch/profiler/profiler.py index f43dcc06de2099..2fd3ab9be6b80b 100644 --- a/torch/profiler/profiler.py +++ b/torch/profiler/profiler.py @@ -132,6 +132,8 @@ def __init__( self.use_device = "cuda" elif ProfilerActivity.XPU in self.activities: self.use_device = "xpu" + elif ProfilerActivity.MTIA in self.activities: + self.use_device = "mtia" elif ProfilerActivity.PrivateUse1 in self.activities: self.use_device = _get_privateuse1_backend_name() @@ -149,7 +151,6 @@ def prepare_trace(self): if self.profiler is None: self.profiler = prof.profile( use_cpu=(ProfilerActivity.CPU in self.activities), - use_mtia=(ProfilerActivity.MTIA in self.activities), use_device=self.use_device, record_shapes=self.record_shapes, with_flops=self.with_flops, From e47603a5495b33d59be0b770ac9b243877c993ad Mon Sep 17 00:00:00 2001 From: Nikita Shulga Date: Tue, 18 Jun 2024 06:51:41 -0700 Subject: [PATCH 155/171] Fix weight_norm decomposition behavior (#128956) By upcasting norm to float32 to align with CUDA and CPU behaviors https://github.com/pytorch/pytorch/blob/e6d4451ae8987bf8d6ad85eb7cde685fac746f6f/aten/src/ATen/native/WeightNorm.cpp#L56-L59 Discovered this when started running OpInfo tests, see https://github.com/pytorch/pytorch/actions/runs/9552858711/job/26332062502#step:20:1060 ``` File "/var/lib/jenkins/workspace/test/test_decomp.py", line 185, in op_assert_ref assert orig.dtype == decomp.dtype, f"{i} Operation: {op}" AssertionError: 1 Operation: aten._weight_norm_interface.default ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/128956 Approved by: https://github.com/albanD ghstack dependencies: #128955 --- torch/_decomp/decompositions.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/torch/_decomp/decompositions.py b/torch/_decomp/decompositions.py index dca552137ca6d3..42d1cb9a152703 100644 --- a/torch/_decomp/decompositions.py +++ b/torch/_decomp/decompositions.py @@ -4773,8 +4773,10 @@ def squeeze_default(self: Tensor, dim: Optional[int] = None): def _weight_norm_interface(v, g, dim=0): # https://github.com/pytorch/pytorch/blob/852f8526c52190125446adc9a6ecbcc28fb66182/aten/src/ATen/native/WeightNorm.cpp#L58 keep_dim = tuple(i for i in range(len(v.shape)) if i != dim) - norm = v.norm(2, keep_dim, keepdim=True) - return v * (g / norm), norm + # align with cuda behavior, keep norm in 'float' when g is 'bfloat16' + norm_dtype = torch.float if g.dtype == torch.bfloat16 else None + norm = v.norm(2, keep_dim, keepdim=True, dtype=norm_dtype) + return v * (g / norm.to(g.dtype)), norm @register_decomposition(aten.isin) From cec31050b4609a4bbdcd332c823139666ad57224 Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Tue, 18 Jun 2024 23:21:43 +0800 Subject: [PATCH 156/171] [BE][Easy] enable UFMT for `torch/distributed/{tensor,_tensor}/` (#128868) Part of #123062 - #123062 Pull Request resolved: https://github.com/pytorch/pytorch/pull/128868 Approved by: https://github.com/fegin --- .lintrunner.toml | 9 - .../distributed/_tensor/_collective_utils.py | 2 +- torch/distributed/_tensor/_dispatch.py | 2 +- torch/distributed/_tensor/_op_schema.py | 1 + torch/distributed/_tensor/_sharding_prop.py | 1 + torch/distributed/_tensor/_tp_conv.py | 1 + torch/distributed/_tensor/api.py | 1 - torch/distributed/_tensor/debug/__init__.py | 1 - .../distributed/_tensor/debug/_op_coverage.py | 1 - torch/distributed/_tensor/debug/comm_mode.py | 3 +- .../_tensor/debug/visualize_sharding.py | 1 - .../_tensor/examples/checkpoint_example.py | 2 - .../examples/comm_mode_features_example.py | 3 - .../examples/torchrec_sharding_example.py | 2 +- .../examples/visualize_sharding_example.py | 1 + .../_tensor/experimental/__init__.py | 1 + .../_tensor/experimental/attention.py | 1 + .../_tensor/experimental/local_map.py | 1 + torch/distributed/_tensor/ops/__init__.py | 8 +- .../distributed/_tensor/ops/basic_strategy.py | 2 - torch/distributed/_tensor/ops/conv_ops.py | 1 + .../distributed/_tensor/ops/embedding_ops.py | 3 +- .../_tensor/ops/experimental_ops.py | 12 +- torch/distributed/_tensor/ops/math_ops.py | 1 - torch/distributed/_tensor/ops/matrix_ops.py | 2 +- .../distributed/_tensor/ops/pointwise_ops.py | 2 - torch/distributed/_tensor/ops/random_ops.py | 1 + torch/distributed/_tensor/ops/tensor_ops.py | 1 - torch/distributed/_tensor/ops/view_ops.py | 3 +- torch/distributed/_tensor/placement_types.py | 1 - torch/distributed/_tensor/random.py | 1 - torch/distributed/tensor/parallel/__init__.py | 4 +- torch/distributed/tensor/parallel/_utils.py | 10 +- torch/distributed/tensor/parallel/api.py | 21 +- torch/distributed/tensor/parallel/ddp.py | 1 + torch/distributed/tensor/parallel/fsdp.py | 8 +- .../tensor/parallel/input_reshard.py | 13 +- torch/distributed/tensor/parallel/loss.py | 1 + torch/distributed/tensor/parallel/style.py | 226 ++++++++++++------ 39 files changed, 213 insertions(+), 143 deletions(-) diff --git a/.lintrunner.toml b/.lintrunner.toml index a7bbdc884415ee..e3f1b58027c3ec 100644 --- a/.lintrunner.toml +++ b/.lintrunner.toml @@ -1443,15 +1443,6 @@ exclude_patterns = [ 'torch/distributed/rpc/rref_proxy.py', 'torch/distributed/rpc/server_process_global_profiler.py', 'torch/distributed/run.py', - 'torch/distributed/tensor/__init__.py', - 'torch/distributed/tensor/parallel/__init__.py', - 'torch/distributed/tensor/parallel/_utils.py', - 'torch/distributed/tensor/parallel/_view_with_dim_change.py', - 'torch/distributed/tensor/parallel/api.py', - 'torch/distributed/tensor/parallel/fsdp.py', - 'torch/distributed/tensor/parallel/input_reshard.py', - 'torch/distributed/tensor/parallel/multihead_attention_tp.py', - 'torch/distributed/tensor/parallel/style.py', 'torch/fft/__init__.py', 'torch/func/__init__.py', 'torch/futures/__init__.py', diff --git a/torch/distributed/_tensor/_collective_utils.py b/torch/distributed/_tensor/_collective_utils.py index 4c1d18403666f2..15644ac798731e 100644 --- a/torch/distributed/_tensor/_collective_utils.py +++ b/torch/distributed/_tensor/_collective_utils.py @@ -3,7 +3,6 @@ import math from dataclasses import dataclass from functools import lru_cache - from typing import List, Optional import torch @@ -21,6 +20,7 @@ Work, ) + logger = logging.getLogger(__name__) diff --git a/torch/distributed/_tensor/_dispatch.py b/torch/distributed/_tensor/_dispatch.py index 1739243a5d3ba4..a659c54a3d932c 100644 --- a/torch/distributed/_tensor/_dispatch.py +++ b/torch/distributed/_tensor/_dispatch.py @@ -6,7 +6,6 @@ from typing import cast, Dict, List, Optional, Sequence, Tuple, TYPE_CHECKING import torch - import torch.distributed as dist import torch.distributed._tensor.api as dtensor import torch.distributed._tensor.random as random @@ -27,6 +26,7 @@ from torch.distributed._tensor.placement_types import DTensorSpec, Replicate, TensorMeta from torch.distributed._tensor.random import is_rng_supported_mesh + if TYPE_CHECKING: from torch.distributed.device_mesh import DeviceMesh diff --git a/torch/distributed/_tensor/_op_schema.py b/torch/distributed/_tensor/_op_schema.py index 071c2ac4748f13..6e6884f47306a6 100644 --- a/torch/distributed/_tensor/_op_schema.py +++ b/torch/distributed/_tensor/_op_schema.py @@ -8,6 +8,7 @@ from torch.distributed._tensor.placement_types import DTensorSpec from torch.distributed.device_mesh import DeviceMesh + try: from torch.utils._cxx_pytree import tree_leaves, tree_map_only, TreeSpec except ImportError: diff --git a/torch/distributed/_tensor/_sharding_prop.py b/torch/distributed/_tensor/_sharding_prop.py index 449cf6c23775af..8f1cabeb0c43cf 100644 --- a/torch/distributed/_tensor/_sharding_prop.py +++ b/torch/distributed/_tensor/_sharding_prop.py @@ -25,6 +25,7 @@ from torch.distributed._tensor.placement_types import DTensorSpec, TensorMeta from torch.distributed.device_mesh import DeviceMesh + aten = torch.ops.aten diff --git a/torch/distributed/_tensor/_tp_conv.py b/torch/distributed/_tensor/_tp_conv.py index d480e9d7f79ecc..cc6f1968e6ef99 100644 --- a/torch/distributed/_tensor/_tp_conv.py +++ b/torch/distributed/_tensor/_tp_conv.py @@ -7,6 +7,7 @@ import torch.distributed as dist import torch.distributed._tensor.api as dtensor + aten = torch.ops.aten diff --git a/torch/distributed/_tensor/api.py b/torch/distributed/_tensor/api.py index 22f7e690022a96..e1c01040a9094a 100644 --- a/torch/distributed/_tensor/api.py +++ b/torch/distributed/_tensor/api.py @@ -5,7 +5,6 @@ from typing import Any, Callable, cast, Optional, Sequence, Tuple import torch - import torch.distributed._tensor._dispatch as op_dispatch import torch.distributed._tensor.random as random import torch.nn as nn diff --git a/torch/distributed/_tensor/debug/__init__.py b/torch/distributed/_tensor/debug/__init__.py index b7bde685fd1e1f..b70529f203e1de 100644 --- a/torch/distributed/_tensor/debug/__init__.py +++ b/torch/distributed/_tensor/debug/__init__.py @@ -1,6 +1,5 @@ # mypy: allow-untyped-defs from torch.distributed._tensor.api import DTensor - from torch.distributed._tensor.debug.comm_mode import CommDebugMode diff --git a/torch/distributed/_tensor/debug/_op_coverage.py b/torch/distributed/_tensor/debug/_op_coverage.py index 4f54246332351e..214c4f003ff2d6 100644 --- a/torch/distributed/_tensor/debug/_op_coverage.py +++ b/torch/distributed/_tensor/debug/_op_coverage.py @@ -5,7 +5,6 @@ import torch import torch.fx import torch.nn as nn - from functorch.compile import make_boxed_func from torch._functorch.compilers import aot_module from torch._inductor.decomposition import select_decomp_table diff --git a/torch/distributed/_tensor/debug/comm_mode.py b/torch/distributed/_tensor/debug/comm_mode.py index 5b69454828f3c6..0241c739fb7013 100644 --- a/torch/distributed/_tensor/debug/comm_mode.py +++ b/torch/distributed/_tensor/debug/comm_mode.py @@ -5,16 +5,15 @@ import torch from torch.autograd.graph import register_multi_grad_hook from torch.distributed._tensor.api import DTensor - from torch.nn.modules.module import ( register_module_forward_hook, register_module_forward_pre_hook, ) from torch.utils._python_dispatch import TorchDispatchMode - from torch.utils._pytree import tree_flatten from torch.utils.module_tracker import ModuleTracker + funcol_native = torch.ops._c10d_functional funcol_py = torch.ops.c10d_functional funcol_autograd = torch.ops._c10d_functional_autograd diff --git a/torch/distributed/_tensor/debug/visualize_sharding.py b/torch/distributed/_tensor/debug/visualize_sharding.py index 76cd8f3e920886..8eae86e5c0ab5e 100644 --- a/torch/distributed/_tensor/debug/visualize_sharding.py +++ b/torch/distributed/_tensor/debug/visualize_sharding.py @@ -5,7 +5,6 @@ from torch._prims_common import ShapeType from torch.distributed._tensor import DeviceMesh - from torch.distributed._tensor.placement_types import Placement, Shard diff --git a/torch/distributed/_tensor/examples/checkpoint_example.py b/torch/distributed/_tensor/examples/checkpoint_example.py index 1cb292f12c4140..1701e28ac2ca76 100644 --- a/torch/distributed/_tensor/examples/checkpoint_example.py +++ b/torch/distributed/_tensor/examples/checkpoint_example.py @@ -5,7 +5,6 @@ checkpoint save/load the model. """ import os - from typing import cast, List import torch @@ -13,7 +12,6 @@ import torch.multiprocessing as mp import torch.nn as nn import torch.nn.functional as F - from torch.distributed._tensor import ( DeviceMesh, distribute_module, diff --git a/torch/distributed/_tensor/examples/comm_mode_features_example.py b/torch/distributed/_tensor/examples/comm_mode_features_example.py index 106a5db7351078..93155687cf920c 100644 --- a/torch/distributed/_tensor/examples/comm_mode_features_example.py +++ b/torch/distributed/_tensor/examples/comm_mode_features_example.py @@ -1,16 +1,13 @@ import os import torch - from torch.distributed._tensor import DeviceMesh from torch.distributed._tensor.debug import CommDebugMode - from torch.distributed.tensor.parallel import ( ColwiseParallel, parallelize_module, RowwiseParallel, ) - from torch.testing._internal.distributed._tensor.common_dtensor import ( MLPModule, MLPStacked, diff --git a/torch/distributed/_tensor/examples/torchrec_sharding_example.py b/torch/distributed/_tensor/examples/torchrec_sharding_example.py index 3e6c63dd18eb99..33f8c7017f5be4 100644 --- a/torch/distributed/_tensor/examples/torchrec_sharding_example.py +++ b/torch/distributed/_tensor/examples/torchrec_sharding_example.py @@ -9,7 +9,6 @@ from typing import List, TYPE_CHECKING import torch - from torch.distributed._tensor import ( DeviceMesh, DTensor, @@ -24,6 +23,7 @@ TensorStorageMetadata, ) + if TYPE_CHECKING: from torch.distributed._tensor.placement_types import Placement diff --git a/torch/distributed/_tensor/examples/visualize_sharding_example.py b/torch/distributed/_tensor/examples/visualize_sharding_example.py index 6e295e147b38be..0f839688915910 100644 --- a/torch/distributed/_tensor/examples/visualize_sharding_example.py +++ b/torch/distributed/_tensor/examples/visualize_sharding_example.py @@ -4,6 +4,7 @@ from torch.distributed._tensor import DeviceMesh, distribute_tensor, Replicate, Shard from torch.distributed._tensor.debug.visualize_sharding import visualize_sharding + world_size = int(os.environ["WORLD_SIZE"]) rank = int(os.environ["RANK"]) diff --git a/torch/distributed/_tensor/experimental/__init__.py b/torch/distributed/_tensor/experimental/__init__.py index 2dd21605ffcc5e..bee73667e1eaf3 100644 --- a/torch/distributed/_tensor/experimental/__init__.py +++ b/torch/distributed/_tensor/experimental/__init__.py @@ -5,6 +5,7 @@ from torch.distributed._tensor.api import DTensor from torch.distributed._tensor.experimental.local_map import local_map + __all__ = ["local_map", "implicit_replication"] diff --git a/torch/distributed/_tensor/experimental/attention.py b/torch/distributed/_tensor/experimental/attention.py index eb7703a96ba5fb..b7738cb2dee543 100644 --- a/torch/distributed/_tensor/experimental/attention.py +++ b/torch/distributed/_tensor/experimental/attention.py @@ -11,6 +11,7 @@ from torch.distributed.device_mesh import DeviceMesh from torch.distributed.tensor.parallel.style import ParallelStyle + aten = torch.ops.aten diff --git a/torch/distributed/_tensor/experimental/local_map.py b/torch/distributed/_tensor/experimental/local_map.py index 0fc6ce96e6e02a..60d1796fdec4c3 100644 --- a/torch/distributed/_tensor/experimental/local_map.py +++ b/torch/distributed/_tensor/experimental/local_map.py @@ -7,6 +7,7 @@ from torch.distributed._tensor import DeviceMesh, DTensor from torch.distributed._tensor.placement_types import Placement + try: from torch.utils import _cxx_pytree as pytree except ImportError: diff --git a/torch/distributed/_tensor/ops/__init__.py b/torch/distributed/_tensor/ops/__init__.py index d19fdfa50cb704..eaccc8aa8d3f61 100644 --- a/torch/distributed/_tensor/ops/__init__.py +++ b/torch/distributed/_tensor/ops/__init__.py @@ -1,10 +1,10 @@ # Copyright (c) Meta Platforms, Inc. and affiliates +from .conv_ops import * # noqa: F403 from .embedding_ops import * # noqa: F403 -from .matrix_ops import * # noqa: F403 +from .experimental_ops import * # noqa: F403 from .math_ops import * # noqa: F403 -from .tensor_ops import * # noqa: F403 +from .matrix_ops import * # noqa: F403 from .pointwise_ops import * # noqa: F403 from .random_ops import * # noqa: F403 +from .tensor_ops import * # noqa: F403 from .view_ops import * # noqa: F403 -from .conv_ops import * # noqa: F403 -from .experimental_ops import * # noqa: F403 diff --git a/torch/distributed/_tensor/ops/basic_strategy.py b/torch/distributed/_tensor/ops/basic_strategy.py index cc28cc19d370a7..97dd43b1524dc8 100644 --- a/torch/distributed/_tensor/ops/basic_strategy.py +++ b/torch/distributed/_tensor/ops/basic_strategy.py @@ -1,6 +1,5 @@ import itertools from dataclasses import dataclass - from typing import List, Set, Tuple from torch.distributed._tensor._op_schema import OpStrategy, PlacementStrategy @@ -11,7 +10,6 @@ Replicate, Shard, ) - from torch.distributed.device_mesh import DeviceMesh diff --git a/torch/distributed/_tensor/ops/conv_ops.py b/torch/distributed/_tensor/ops/conv_ops.py index f466a13aa46373..24e75593064eed 100644 --- a/torch/distributed/_tensor/ops/conv_ops.py +++ b/torch/distributed/_tensor/ops/conv_ops.py @@ -7,6 +7,7 @@ from torch.distributed._tensor.ops.utils import register_prop_rule from torch.distributed._tensor.placement_types import DTensorSpec, TensorMeta + aten = torch.ops.aten diff --git a/torch/distributed/_tensor/ops/embedding_ops.py b/torch/distributed/_tensor/ops/embedding_ops.py index 6f8cc8c67851ef..5af79562adcb25 100644 --- a/torch/distributed/_tensor/ops/embedding_ops.py +++ b/torch/distributed/_tensor/ops/embedding_ops.py @@ -11,16 +11,15 @@ expand_to_full_mesh_op_strategy, register_op_strategy, ) - from torch.distributed._tensor.placement_types import ( Partial, Placement, Replicate, Shard, ) - from torch.distributed.device_mesh import DeviceMesh + aten = torch.ops.aten diff --git a/torch/distributed/_tensor/ops/experimental_ops.py b/torch/distributed/_tensor/ops/experimental_ops.py index 546945acd6220a..6d6967d4ea8d15 100644 --- a/torch/distributed/_tensor/ops/experimental_ops.py +++ b/torch/distributed/_tensor/ops/experimental_ops.py @@ -2,19 +2,21 @@ # implement matrix related ops for distributed tensor from typing import List -try: - import numpy as np -except ModuleNotFoundError: - np = None # type: ignore[assignment] - import torch from torch.distributed._tensor._op_schema import OpSchema, OutputSharding from torch.distributed._tensor.ops.utils import register_prop_rule from torch.distributed._tensor.placement_types import DTensorSpec, TensorMeta + aten = torch.ops.aten +try: + import numpy as np +except ModuleNotFoundError: + np = None # type: ignore[assignment] + + @register_prop_rule(aten.slice_backward.default) def slice_backward_rules(op_schema: OpSchema) -> OutputSharding: grad_output_spec, input_sizes, dim, start, end, step = op_schema.args_schema diff --git a/torch/distributed/_tensor/ops/math_ops.py b/torch/distributed/_tensor/ops/math_ops.py index 377c50dffa13e8..412c566253ab1b 100644 --- a/torch/distributed/_tensor/ops/math_ops.py +++ b/torch/distributed/_tensor/ops/math_ops.py @@ -6,7 +6,6 @@ from typing import cast, List, Optional, Sequence, Tuple, Union import torch - from torch.distributed._tensor._op_schema import ( OpSchema, OpStrategy, diff --git a/torch/distributed/_tensor/ops/matrix_ops.py b/torch/distributed/_tensor/ops/matrix_ops.py index 15f00af670d273..128a73a59ffecd 100644 --- a/torch/distributed/_tensor/ops/matrix_ops.py +++ b/torch/distributed/_tensor/ops/matrix_ops.py @@ -19,9 +19,9 @@ Replicate, Shard, ) - from torch.distributed.device_mesh import DeviceMesh + aten = torch.ops.aten diff --git a/torch/distributed/_tensor/ops/pointwise_ops.py b/torch/distributed/_tensor/ops/pointwise_ops.py index ab80f783cf5b38..96bfb808c10068 100644 --- a/torch/distributed/_tensor/ops/pointwise_ops.py +++ b/torch/distributed/_tensor/ops/pointwise_ops.py @@ -2,7 +2,6 @@ from typing import List, Sequence, Tuple import torch - from torch.distributed._tensor._op_schema import ( _is_inplace_op, _is_out_variant_op, @@ -13,7 +12,6 @@ StrategyType, TupleStrategy, ) - from torch.distributed._tensor.ops.utils import ( generate_redistribute_costs, infer_broadcast_dims_map, diff --git a/torch/distributed/_tensor/ops/random_ops.py b/torch/distributed/_tensor/ops/random_ops.py index 390dc419ecd784..d4b533aae09ac5 100644 --- a/torch/distributed/_tensor/ops/random_ops.py +++ b/torch/distributed/_tensor/ops/random_ops.py @@ -9,6 +9,7 @@ from torch.distributed._tensor.ops.utils import is_tensor_partial, register_op_strategy from torch.distributed.device_mesh import DeviceMesh + aten = torch.ops.aten diff --git a/torch/distributed/_tensor/ops/tensor_ops.py b/torch/distributed/_tensor/ops/tensor_ops.py index d2feb19ba2f95e..a91d6261c51dc4 100644 --- a/torch/distributed/_tensor/ops/tensor_ops.py +++ b/torch/distributed/_tensor/ops/tensor_ops.py @@ -3,7 +3,6 @@ from typing import cast, List, Optional, Sequence, Tuple import torch - from torch.distributed._tensor._op_schema import ( _is_inplace_op, OpSchema, diff --git a/torch/distributed/_tensor/ops/view_ops.py b/torch/distributed/_tensor/ops/view_ops.py index 7161988adf25ca..ea088b7377a9bc 100644 --- a/torch/distributed/_tensor/ops/view_ops.py +++ b/torch/distributed/_tensor/ops/view_ops.py @@ -15,7 +15,6 @@ ) import torch - from torch import Tensor from torch.distributed._tensor._op_schema import ( OpSchema, @@ -32,10 +31,10 @@ prod, register_op_strategy, ) - from torch.distributed._tensor.placement_types import DTensorSpec, Placement, Replicate from torch.distributed.device_mesh import DeviceMesh + aten = torch.ops.aten Shape = Tuple[int, ...] diff --git a/torch/distributed/_tensor/placement_types.py b/torch/distributed/_tensor/placement_types.py index 31e280c2f5b8b4..352e12640bd744 100644 --- a/torch/distributed/_tensor/placement_types.py +++ b/torch/distributed/_tensor/placement_types.py @@ -6,7 +6,6 @@ import torch import torch.distributed._functional_collectives as funcol - from torch.distributed._tensor._collective_utils import ( fill_empty_tensor_to_shards, mesh_broadcast, diff --git a/torch/distributed/_tensor/random.py b/torch/distributed/_tensor/random.py index ed331736c5ce49..3e43a9119ac202 100644 --- a/torch/distributed/_tensor/random.py +++ b/torch/distributed/_tensor/random.py @@ -6,7 +6,6 @@ import torch import torch.distributed as dist - from torch import Tensor from torch.distributed._tensor.placement_types import DTensorSpec, Shard from torch.distributed.device_mesh import _get_device_handle, DeviceMesh diff --git a/torch/distributed/tensor/parallel/__init__.py b/torch/distributed/tensor/parallel/__init__.py index 990550414ca47a..9fe378c51b0d40 100644 --- a/torch/distributed/tensor/parallel/__init__.py +++ b/torch/distributed/tensor/parallel/__init__.py @@ -1,6 +1,5 @@ # Copyright (c) Meta Platforms, Inc. and affiliates from torch.distributed.tensor.parallel.api import parallelize_module - from torch.distributed.tensor.parallel.loss import loss_parallel from torch.distributed.tensor.parallel.style import ( ColwiseParallel, @@ -11,6 +10,7 @@ SequenceParallel, ) + __all__ = [ "ColwiseParallel", "ParallelStyle", @@ -19,5 +19,5 @@ "RowwiseParallel", "SequenceParallel", "parallelize_module", - "loss_parallel" + "loss_parallel", ] diff --git a/torch/distributed/tensor/parallel/_utils.py b/torch/distributed/tensor/parallel/_utils.py index 394fde457bb21f..3f47ec6f1ef34e 100644 --- a/torch/distributed/tensor/parallel/_utils.py +++ b/torch/distributed/tensor/parallel/_utils.py @@ -5,12 +5,16 @@ from torch.distributed._tensor import DeviceMesh from torch.distributed._tensor.placement_types import Placement from torch.distributed.device_mesh import _mesh_resources + + try: from torch._dynamo.external_utils import is_compiling as is_torchdynamo_compiling except Exception: + def is_torchdynamo_compiling(): # type: ignore[misc] return False + LayoutsType = Union[Placement, Tuple[Placement, ...]] @@ -46,8 +50,10 @@ def _validate_tp_mesh_dim( is valid, `False` otherwise. """ if device_mesh.ndim > 1: - raise ValueError(f"Tensor Parallel only accepts a 1D DeviceMesh, but found {device_mesh.ndim}D!" - 'If you have a 2-D or N-D device_mesh, consider passing in device_mesh["tp"]') + raise ValueError( + f"Tensor Parallel only accepts a 1D DeviceMesh, but found {device_mesh.ndim}D!" + 'If you have a 2-D or N-D device_mesh, consider passing in device_mesh["tp"]' + ) parent_mesh = _mesh_resources.get_parent_mesh(device_mesh) if parent_mesh: diff --git a/torch/distributed/tensor/parallel/api.py b/torch/distributed/tensor/parallel/api.py index f78e9712d304bc..e0fc4d2ef2b725 100644 --- a/torch/distributed/tensor/parallel/api.py +++ b/torch/distributed/tensor/parallel/api.py @@ -1,21 +1,17 @@ # Copyright (c) Meta Platforms, Inc. and affiliates -from typing import Dict, Union from fnmatch import fnmatch +from typing import Dict, Union import torch import torch.distributed._tensor.random as random import torch.nn as nn -from torch.distributed._tensor import ( - DeviceMesh, -) +from torch.distributed._tensor import DeviceMesh from torch.distributed._tensor.random import ( is_rng_supported_mesh, TensorParallelRNGTracker, ) from torch.distributed.tensor.parallel._utils import _validate_tp_mesh_dim -from torch.distributed.tensor.parallel.style import ( - ParallelStyle, -) +from torch.distributed.tensor.parallel.style import ParallelStyle __all__ = [ @@ -98,14 +94,19 @@ def parallelize_module( # type: ignore[return] atom = path_splits.pop(0) matched_children = filter( # `t[0]` is child name - lambda t: fnmatch(t[0], atom), module.named_children() + lambda t: fnmatch(t[0], atom), + module.named_children(), ) # apply the plan to all matched submodules for _, submodule in matched_children: if path_splits: # we haven't reached the leaf, apply in dict style - leaf_path = ".".join(path_splits) # rest of the path after `atom` - parallelize_module(submodule, device_mesh, {leaf_path: parallelize_style}) + leaf_path = ".".join( + path_splits + ) # rest of the path after `atom` + parallelize_module( + submodule, device_mesh, {leaf_path: parallelize_style} + ) else: # otherwise, directly apply style to this submodule parallelize_module(submodule, device_mesh, parallelize_style) diff --git a/torch/distributed/tensor/parallel/ddp.py b/torch/distributed/tensor/parallel/ddp.py index baa9d638037d3f..6c4d6f80167555 100644 --- a/torch/distributed/tensor/parallel/ddp.py +++ b/torch/distributed/tensor/parallel/ddp.py @@ -7,6 +7,7 @@ _unflatten_tensor, ) + __all__ = [] # type: ignore[var-annotated] diff --git a/torch/distributed/tensor/parallel/fsdp.py b/torch/distributed/tensor/parallel/fsdp.py index c38771ae86e2b4..df51efaf87f54b 100644 --- a/torch/distributed/tensor/parallel/fsdp.py +++ b/torch/distributed/tensor/parallel/fsdp.py @@ -4,7 +4,6 @@ import torch import torch.distributed as dist - import torch.distributed._shard.sharding_spec as shard_spec import torch.distributed.distributed_c10d as c10d from torch.distributed._shard.sharded_tensor import ( @@ -13,12 +12,10 @@ ShardedTensorMetadata, TensorProperties, ) - from torch.distributed._shard.sharding_spec import ShardMetadata from torch.distributed._shard.sharding_spec.chunk_sharding_spec import ChunkShardingSpec from torch.distributed._tensor import DeviceMesh, DTensor, Replicate, Shard as DShard from torch.distributed.device_mesh import _mesh_resources - from torch.distributed.fsdp._common_utils import _set_fsdp_flattened from torch.distributed.fsdp._fsdp_extensions import FSDPExtensions from torch.distributed.fsdp._shard_utils import _create_chunk_sharded_tensor @@ -28,6 +25,7 @@ _unflatten_tensor, ) + __all__ = ["DTensorExtensions"] @@ -245,7 +243,6 @@ def _chunk_dtensor( # e.g. When a layer is not sppecified in the parallelize_plan, TP will have no effect on the layer. # e.g. When you do PairwiseParallel on a 3 layer model, TP will have no effect on the third layer. if isinstance(tensor, torch.Tensor) and not isinstance(tensor, DTensor): - # For tensors, it is replicated across tp dimension and sharded across FSDP dimension. # TP is the inner dimension and FSDP is the outer dimension. # Therefore, shard placements for tensor is (Shard(0), Replicate()). @@ -324,6 +321,7 @@ class DTensorExtensions(FSDPExtensions): This is the implementation for FSDPExtensions defined in https://github.com/pytorch/pytorch/blob/main/torch/distributed/fsdp/_fsdp_extensions.py """ + def __init__(self, device_handle) -> None: super().__init__() self.compute_stream = None @@ -352,7 +350,7 @@ def post_unflatten_transform( tensor, param_extension, device_handle=self.device_handle, - compute_stream=self.compute_stream + compute_stream=self.compute_stream, ) _set_fsdp_flattened(result) return result diff --git a/torch/distributed/tensor/parallel/input_reshard.py b/torch/distributed/tensor/parallel/input_reshard.py index 3ea97846e313a4..4e7af55d32c356 100644 --- a/torch/distributed/tensor/parallel/input_reshard.py +++ b/torch/distributed/tensor/parallel/input_reshard.py @@ -5,6 +5,7 @@ import torch from torch.distributed._tensor import DeviceMesh, DTensor, Replicate, Shard + __all__ = [ "input_reshard", ] @@ -49,7 +50,9 @@ def input_reshard_forward_pre_hook(_: torch.nn.Module, _i: Tuple[Any, ...]) -> N nonlocal cx cx = saved_tensor_hooks # type: ignore[name-defined] - def input_reshard_backward_hook(_: torch.nn.Module, _i: Tuple[Any, ...], _o: Any) -> Any: + def input_reshard_backward_hook( + _: torch.nn.Module, _i: Tuple[Any, ...], _o: Any + ) -> Any: nonlocal cx cx.__exit__() # type: ignore[name-defined, union-attr] @@ -60,7 +63,9 @@ def input_reshard_backward_hook(_: torch.nn.Module, _i: Tuple[Any, ...], _o: Any return module -def _pack_hook_tp(mesh: DeviceMesh, input_reshard_dim: int, x: torch.Tensor) -> Any: # noqa: D401 +def _pack_hook_tp( + mesh: DeviceMesh, input_reshard_dim: int, x: torch.Tensor +) -> Any: # noqa: D401 """Hook function called after FWD to shard input.""" if isinstance(x, DTensor) and all(p.is_replicate() for p in x._spec.placements): return x.redistribute(device_mesh=mesh, placements=[Shard(input_reshard_dim)]) @@ -78,7 +83,9 @@ def _pack_hook_tp(mesh: DeviceMesh, input_reshard_dim: int, x: torch.Tensor) -> return x -def _unpack_hook_tp(mesh: DeviceMesh, input_reshard_dim: int, x: Any) -> torch.Tensor: # noqa: D401 +def _unpack_hook_tp( + mesh: DeviceMesh, input_reshard_dim: int, x: Any +) -> torch.Tensor: # noqa: D401 """Hook function called before activation recomputing in BWD to restore input.""" if ( isinstance(x, DTensor) diff --git a/torch/distributed/tensor/parallel/loss.py b/torch/distributed/tensor/parallel/loss.py index f2776c5123b47e..a51d14b0efbd5c 100644 --- a/torch/distributed/tensor/parallel/loss.py +++ b/torch/distributed/tensor/parallel/loss.py @@ -18,6 +18,7 @@ from torch.distributed._tensor.placement_types import DTensorSpec, Placement, TensorMeta from torch.distributed.device_mesh import DeviceMesh + aten = torch.ops.aten diff --git a/torch/distributed/tensor/parallel/style.py b/torch/distributed/tensor/parallel/style.py index a4f4d4de0b985a..42437a7084758c 100644 --- a/torch/distributed/tensor/parallel/style.py +++ b/torch/distributed/tensor/parallel/style.py @@ -1,12 +1,20 @@ # mypy: allow-untyped-defs # Copyright (c) Meta Platforms, Inc. and affiliates from abc import ABC, abstractmethod -from typing import Optional, Union, Tuple, Dict, Any from functools import partial +from typing import Any, Dict, Optional, Tuple, Union import torch import torch.nn as nn -from torch.distributed._tensor import DeviceMesh, DTensor, Placement, Replicate, Shard, distribute_tensor, distribute_module +from torch.distributed._tensor import ( + DeviceMesh, + distribute_module, + distribute_tensor, + DTensor, + Placement, + Replicate, + Shard, +) __all__ = [ @@ -74,29 +82,35 @@ def __init__( *, input_layouts: Optional[Placement] = None, output_layouts: Optional[Placement] = None, - use_local_output: bool = True + use_local_output: bool = True, ): super().__init__() - self.input_layouts = (input_layouts or Replicate(), ) - self.output_layouts = (output_layouts or Shard(-1), ) + self.input_layouts = (input_layouts or Replicate(),) + self.output_layouts = (output_layouts or Shard(-1),) # colwise linear runtime sharding (desired sharding): # 1. requires replicate input # 2. shard output on last dim - self.desired_input_layouts = (Replicate(), ) + self.desired_input_layouts = (Replicate(),) self.use_local_output = use_local_output @staticmethod - def _prepare_input_fn(input_layouts, desired_input_layouts, mod, inputs, device_mesh): + def _prepare_input_fn( + input_layouts, desired_input_layouts, mod, inputs, device_mesh + ): # TODO: figure out dynamo support for instance method and switch this to instance method # annotate module input placements/sharding with input_layouts input_tensor = inputs[0] if not isinstance(input_tensor, DTensor): - input_tensor = DTensor.from_local(input_tensor, device_mesh, input_layouts, run_check=False) + input_tensor = DTensor.from_local( + input_tensor, device_mesh, input_layouts, run_check=False + ) # transform the input layouts to the desired layouts of ColwiseParallel if input_layouts != desired_input_layouts: - input_tensor = input_tensor.redistribute(placements=desired_input_layouts, async_op=True) + input_tensor = input_tensor.redistribute( + placements=desired_input_layouts, async_op=True + ) return input_tensor def _partition_linear_fn(self, name, module, device_mesh): @@ -104,17 +118,13 @@ def _partition_linear_fn(self, name, module, device_mesh): # means Colwise as Linear is input * weight^T + bias, where # weight would become Shard(1) for name, param in module.named_parameters(): - dist_param = nn.Parameter( - distribute_tensor(param, device_mesh, [Shard(0)]) - ) + dist_param = nn.Parameter(distribute_tensor(param, device_mesh, [Shard(0)])) module.register_parameter(name, dist_param) def _partition_embedding_fn(self, name, module, device_mesh): # colwise shard embedding.weight is straight forward as Shard(1) for name, param in module.named_parameters(): - dist_param = nn.Parameter( - distribute_tensor(param, device_mesh, [Shard(1)]) - ) + dist_param = nn.Parameter(distribute_tensor(param, device_mesh, [Shard(1)])) module.register_parameter(name, dist_param) @staticmethod @@ -131,14 +141,20 @@ def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module: elif isinstance(module, nn.Embedding): partition_fn = self._partition_embedding_fn else: - raise NotImplementedError("ColwiseParallel currently only support nn.Linear and nn.Embedding!") + raise NotImplementedError( + "ColwiseParallel currently only support nn.Linear and nn.Embedding!" + ) return distribute_module( module, device_mesh, partition_fn, - partial(self._prepare_input_fn, self.input_layouts, self.desired_input_layouts), - partial(self._prepare_output_fn, self.output_layouts, self.use_local_output), + partial( + self._prepare_input_fn, self.input_layouts, self.desired_input_layouts + ), + partial( + self._prepare_output_fn, self.output_layouts, self.use_local_output + ), ) @@ -180,41 +196,49 @@ def __init__( *, input_layouts: Optional[Placement] = None, output_layouts: Optional[Placement] = None, - use_local_output: bool = True + use_local_output: bool = True, ): super().__init__() - self.input_layouts = (input_layouts or Shard(-1), ) - self.output_layouts = (output_layouts or Replicate(), ) + self.input_layouts = (input_layouts or Shard(-1),) + self.output_layouts = (output_layouts or Replicate(),) self.use_local_output = use_local_output @staticmethod - def _prepare_input_fn(input_layouts, desired_input_layouts, mod, inputs, device_mesh): + def _prepare_input_fn( + input_layouts, desired_input_layouts, mod, inputs, device_mesh + ): input_tensor = inputs[0] if not isinstance(input_tensor, DTensor): - input_tensor = DTensor.from_local(input_tensor, device_mesh, input_layouts, run_check=False) + input_tensor = DTensor.from_local( + input_tensor, device_mesh, input_layouts, run_check=False + ) if input_layouts != desired_input_layouts: - input_tensor = input_tensor.redistribute(placements=desired_input_layouts, async_op=True) + input_tensor = input_tensor.redistribute( + placements=desired_input_layouts, async_op=True + ) return input_tensor def _partition_linear_fn(self, name, module, device_mesh): # Rowwise shard weight to Shard(1), bias to Replicate(), weight be Shard(1) # means Rowwise as nn.Linear is input * weight^T + bias, where # weight would become Shard(0) - module.register_parameter("weight", nn.Parameter( - distribute_tensor(module.weight, device_mesh, [Shard(1)]) - )) + module.register_parameter( + "weight", + nn.Parameter(distribute_tensor(module.weight, device_mesh, [Shard(1)])), + ) if module.bias is not None: - module.register_parameter("bias", nn.Parameter( - distribute_tensor(module.bias, device_mesh, [Replicate()]) - )) + module.register_parameter( + "bias", + nn.Parameter( + distribute_tensor(module.bias, device_mesh, [Replicate()]) + ), + ) def _partition_embedding_fn(self, name, module, device_mesh): # rowwise shard embedding.weight is Shard(0) for name, param in module.named_parameters(): - dist_param = nn.Parameter( - distribute_tensor(param, device_mesh, [Shard(0)]) - ) + dist_param = nn.Parameter(distribute_tensor(param, device_mesh, [Shard(0)])) module.register_parameter(name, dist_param) @staticmethod @@ -231,20 +255,26 @@ def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module: if isinstance(module, nn.Linear): partition_fn = self._partition_linear_fn # rowwise linear runtime sharding requires input tensor shard on last dim - self.desired_input_layouts: Tuple[Placement, ...] = (Shard(-1), ) + self.desired_input_layouts: Tuple[Placement, ...] = (Shard(-1),) elif isinstance(module, nn.Embedding): partition_fn = self._partition_embedding_fn # rowwise embedding runtime sharding requires input tensor replicated - self.desired_input_layouts = (Replicate(), ) + self.desired_input_layouts = (Replicate(),) else: - raise NotImplementedError("RowwiseParallel currently only support nn.Linear and nn.Embedding!") + raise NotImplementedError( + "RowwiseParallel currently only support nn.Linear and nn.Embedding!" + ) return distribute_module( module, device_mesh, partition_fn, - partial(self._prepare_input_fn, self.input_layouts, self.desired_input_layouts), - partial(self._prepare_output_fn, self.output_layouts, self.use_local_output), + partial( + self._prepare_input_fn, self.input_layouts, self.desired_input_layouts + ), + partial( + self._prepare_output_fn, self.output_layouts, self.use_local_output + ), ) @@ -287,17 +317,15 @@ class SequenceParallel(ParallelStyle): inits for the weights on those modules, you need to broadcast the weights before/after parallelizing to ensure that they are replicated. """ - def __init__( - self, - *, - sequence_dim: int = 1, - use_local_output: bool = False - ): + + def __init__(self, *, sequence_dim: int = 1, use_local_output: bool = False): super().__init__() self.sequence_dim = sequence_dim self.use_local_output = use_local_output - def _replicate_module_fn(self, name: str, module: nn.Module, device_mesh: DeviceMesh): + def _replicate_module_fn( + self, name: str, module: nn.Module, device_mesh: DeviceMesh + ): for p_name, param in module.named_parameters(): # simple replication with fixed ones_ init from LayerNorm/RMSNorm, which allow # us to simply just use from_local @@ -312,9 +340,13 @@ def _prepare_input_fn(sequence_dim, mod, inputs, device_mesh): if isinstance(input_tensor, DTensor): return inputs elif isinstance(input_tensor, torch.Tensor): - return DTensor.from_local(input_tensor, device_mesh, [Shard(sequence_dim)], run_check=False) + return DTensor.from_local( + input_tensor, device_mesh, [Shard(sequence_dim)], run_check=False + ) else: - raise ValueError(f"expecting input of {mod} to be a torch.Tensor or DTensor, but got {input_tensor}") + raise ValueError( + f"expecting input of {mod} to be a torch.Tensor or DTensor, but got {input_tensor}" + ) @staticmethod def _prepare_output_fn(use_local_output, mod, outputs, device_mesh): @@ -380,32 +412,43 @@ def __init__( self, *, input_layouts: Optional[Union[Placement, Tuple[Optional[Placement]]]] = None, - desired_input_layouts: Optional[Union[Placement, Tuple[Optional[Placement]]]] = None, + desired_input_layouts: Optional[ + Union[Placement, Tuple[Optional[Placement]]] + ] = None, input_kwarg_layouts: Optional[Dict[str, Placement]] = None, desired_input_kwarg_layouts: Optional[Dict[str, Placement]] = None, - use_local_output: bool = False + use_local_output: bool = False, ): - self.input_layouts = (input_layouts,) if isinstance(input_layouts, Placement) else input_layouts - self.desired_input_layouts = \ - (desired_input_layouts,) if isinstance(desired_input_layouts, Placement) else desired_input_layouts + self.input_layouts = ( + (input_layouts,) if isinstance(input_layouts, Placement) else input_layouts + ) + self.desired_input_layouts = ( + (desired_input_layouts,) + if isinstance(desired_input_layouts, Placement) + else desired_input_layouts + ) self.use_local_output = use_local_output if self.input_layouts is not None: - assert self.desired_input_layouts is not None, "desired module inputs should not be None!" - assert len(self.input_layouts) == len(self.desired_input_layouts), \ - "input_layouts and desired_input_layouts should have same length!" + assert ( + self.desired_input_layouts is not None + ), "desired module inputs should not be None!" + assert len(self.input_layouts) == len( + self.desired_input_layouts + ), "input_layouts and desired_input_layouts should have same length!" self.with_kwargs = input_kwarg_layouts is not None self.input_kwarg_layouts = input_kwarg_layouts or {} self.desired_input_kwarg_layouts = desired_input_kwarg_layouts or {} if self.with_kwargs: - assert len(self.input_kwarg_layouts) == len(self.desired_input_kwarg_layouts), \ - "input_kwarg_layouts and desired_input_kwarg_layouts should have same length!" + assert len(self.input_kwarg_layouts) == len( + self.desired_input_kwarg_layouts + ), "input_kwarg_layouts and desired_input_kwarg_layouts should have same length!" def _prepare_input_arg( self, input: Any, mesh: DeviceMesh, input_layout: Optional[Placement], - desired_layout: Optional[Placement] + desired_layout: Optional[Placement], ): if input_layout is not None: if isinstance(input, DTensor): @@ -413,8 +456,12 @@ def _prepare_input_arg( # assert inp.placements[0] == input_layout dt_inp = input else: - assert isinstance(input, torch.Tensor), "expecting input to be a torch.Tensor!" - dt_inp = DTensor.from_local(input, mesh, (input_layout,), run_check=False) + assert isinstance( + input, torch.Tensor + ), "expecting input to be a torch.Tensor!" + dt_inp = DTensor.from_local( + input, mesh, (input_layout,), run_check=False + ) if desired_layout is not None and input_layout != desired_layout: dt_inp = dt_inp.redistribute(placements=(desired_layout,)) @@ -432,9 +479,15 @@ def _prepare_input_fn(self, inputs, device_mesh): if len(inputs) != len(self.input_layouts): raise ValueError("module inputs and input_layouts should have same length!") - assert self.desired_input_layouts is not None, "desired module inputs should not be None!" - for inp, input_layout, desired_layout in zip(inputs, self.input_layouts, self.desired_input_layouts): - prepared_inputs.append(self._prepare_input_arg(inp, device_mesh, input_layout, desired_layout)) + assert ( + self.desired_input_layouts is not None + ), "desired module inputs should not be None!" + for inp, input_layout, desired_layout in zip( + inputs, self.input_layouts, self.desired_input_layouts + ): + prepared_inputs.append( + self._prepare_input_arg(inp, device_mesh, input_layout, desired_layout) + ) return tuple(prepared_inputs) def _prepare_input_kwarg_fn(self, inputs, kwarg_inputs, device_mesh): @@ -445,15 +498,19 @@ def _prepare_input_kwarg_fn(self, inputs, kwarg_inputs, device_mesh): input_layout = self.input_kwarg_layouts.get(kwarg_key) desired_input_layout = self.desired_input_kwarg_layouts.get(kwarg_key) - prepared_kwarg_inputs[kwarg_key] = self._prepare_input_arg(kwarg_val, device_mesh, input_layout, desired_input_layout) + prepared_kwarg_inputs[kwarg_key] = self._prepare_input_arg( + kwarg_val, device_mesh, input_layout, desired_input_layout + ) return (prepared_arg_inputs, prepared_kwarg_inputs) def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module: if self.with_kwargs: module.register_forward_pre_hook( - lambda _, inputs, kwargs: self._prepare_input_kwarg_fn(inputs, kwargs, device_mesh), - with_kwargs=True + lambda _, inputs, kwargs: self._prepare_input_kwarg_fn( + inputs, kwargs, device_mesh + ), + with_kwargs=True, ) # type: ignore[misc] else: module.register_forward_pre_hook(lambda _, inputs: self._prepare_input_fn(inputs, device_mesh)) # type: ignore[misc, call-arg] @@ -497,38 +554,55 @@ class PrepareModuleOutput(ParallelStyle): >>> ) >>> ) """ + def __init__( self, *, output_layouts: Union[Placement, Tuple[Placement]], desired_output_layouts: Union[Placement, Tuple[Placement]], - use_local_output: bool = True + use_local_output: bool = True, ): - self.output_layouts = (output_layouts,) if isinstance(output_layouts, Placement) else output_layouts - self.desired_output_layouts = \ - (desired_output_layouts,) if isinstance(desired_output_layouts, Placement) else desired_output_layouts + self.output_layouts = ( + (output_layouts,) + if isinstance(output_layouts, Placement) + else output_layouts + ) + self.desired_output_layouts = ( + (desired_output_layouts,) + if isinstance(desired_output_layouts, Placement) + else desired_output_layouts + ) self.use_local_output = use_local_output - assert len(self.output_layouts) == len(self.desired_output_layouts), \ - "output_layouts and desired_output_layouts should have same length!" + assert len(self.output_layouts) == len( + self.desired_output_layouts + ), "output_layouts and desired_output_layouts should have same length!" def _prepare_out_fn(self, outputs, device_mesh): prepared_outputs = [] if not isinstance(outputs, tuple): outputs = (outputs,) if len(outputs) != len(self.output_layouts): - raise ValueError("module outputs and output_layouts should have same length!") - for out, out_layout, desired_out_layout in zip(outputs, self.output_layouts, self.desired_output_layouts): + raise ValueError( + "module outputs and output_layouts should have same length!" + ) + for out, out_layout, desired_out_layout in zip( + outputs, self.output_layouts, self.desired_output_layouts + ): if out_layout is not None: if isinstance(out, DTensor): # TODO: re-enable the check once we fix the compile path # assert out.placements[0] == out_layout dt_out = out else: - dt_out = DTensor.from_local(out, device_mesh, (out_layout,), run_check=False) + dt_out = DTensor.from_local( + out, device_mesh, (out_layout,), run_check=False + ) if out_layout != desired_out_layout: dt_out = dt_out.redistribute(placements=(desired_out_layout,)) - prepared_outputs.append(dt_out.to_local() if self.use_local_output else dt_out) + prepared_outputs.append( + dt_out.to_local() if self.use_local_output else dt_out + ) else: prepared_outputs.append(out) if len(prepared_outputs) == 1: From 3b798df853444d66077ffa846f5682e621b07388 Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Tue, 18 Jun 2024 23:21:44 +0800 Subject: [PATCH 157/171] [BE][Easy] enable UFMT for `torch/distributed/{fsdp,optim,rpc}/` (#128869) Part of #123062 - #123062 Pull Request resolved: https://github.com/pytorch/pytorch/pull/128869 Approved by: https://github.com/fegin ghstack dependencies: #128868 --- .lintrunner.toml | 27 ---- torch/distributed/fsdp/__init__.py | 1 + torch/distributed/fsdp/_common_utils.py | 2 + torch/distributed/fsdp/_debug_utils.py | 1 + torch/distributed/fsdp/_flat_param.py | 1 + torch/distributed/fsdp/_init_utils.py | 2 +- torch/distributed/fsdp/_optim_utils.py | 1 + torch/distributed/fsdp/_runtime_utils.py | 1 + torch/distributed/fsdp/_state_dict_utils.py | 3 - .../distributed/fsdp/_unshard_param_utils.py | 1 + torch/distributed/fsdp/_wrap_utils.py | 1 - torch/distributed/fsdp/api.py | 2 +- .../fsdp/fully_sharded_data_parallel.py | 2 +- torch/distributed/fsdp/sharded_grad_scaler.py | 1 + torch/distributed/fsdp/wrap.py | 1 + torch/distributed/optim/__init__.py | 10 +- .../optim/apply_optimizer_in_backward.py | 10 +- .../distributed/optim/functional_adadelta.py | 5 +- torch/distributed/optim/functional_adagrad.py | 3 +- torch/distributed/optim/functional_adam.py | 3 +- torch/distributed/optim/functional_adamax.py | 3 +- torch/distributed/optim/functional_adamw.py | 3 +- torch/distributed/optim/functional_rmsprop.py | 3 +- torch/distributed/optim/functional_rprop.py | 3 +- torch/distributed/optim/functional_sgd.py | 3 +- torch/distributed/optim/named_optimizer.py | 13 +- torch/distributed/optim/optimizer.py | 5 +- torch/distributed/optim/utils.py | 2 + .../optim/zero_redundancy_optimizer.py | 37 +++--- torch/distributed/rpc/__init__.py | 69 +++++----- torch/distributed/rpc/_testing/__init__.py | 5 +- .../_testing/faulty_agent_backend_registry.py | 11 +- torch/distributed/rpc/_utils.py | 19 ++- torch/distributed/rpc/api.py | 118 ++++++++++-------- torch/distributed/rpc/backend_registry.py | 99 ++++++++++----- torch/distributed/rpc/constants.py | 3 +- torch/distributed/rpc/functions.py | 2 + torch/distributed/rpc/internal.py | 5 +- torch/distributed/rpc/options.py | 2 + torch/distributed/rpc/rref_proxy.py | 17 ++- .../rpc/server_process_global_profiler.py | 13 +- 41 files changed, 300 insertions(+), 213 deletions(-) diff --git a/.lintrunner.toml b/.lintrunner.toml index e3f1b58027c3ec..99c04cac4fbb39 100644 --- a/.lintrunner.toml +++ b/.lintrunner.toml @@ -1413,35 +1413,8 @@ exclude_patterns = [ 'torch/distributed/nn/jit/instantiator.py', 'torch/distributed/nn/jit/templates/__init__.py', 'torch/distributed/nn/jit/templates/remote_module_template.py', - 'torch/distributed/optim/__init__.py', - 'torch/distributed/optim/apply_optimizer_in_backward.py', - 'torch/distributed/optim/functional_adadelta.py', - 'torch/distributed/optim/functional_adagrad.py', - 'torch/distributed/optim/functional_adam.py', - 'torch/distributed/optim/functional_adamax.py', - 'torch/distributed/optim/functional_adamw.py', - 'torch/distributed/optim/functional_rmsprop.py', - 'torch/distributed/optim/functional_rprop.py', - 'torch/distributed/optim/functional_sgd.py', - 'torch/distributed/optim/named_optimizer.py', - 'torch/distributed/optim/optimizer.py', - 'torch/distributed/optim/post_localSGD_optimizer.py', - 'torch/distributed/optim/utils.py', - 'torch/distributed/optim/zero_redundancy_optimizer.py', 'torch/distributed/remote_device.py', 'torch/distributed/rendezvous.py', - 'torch/distributed/rpc/__init__.py', - 'torch/distributed/rpc/_testing/__init__.py', - 'torch/distributed/rpc/_testing/faulty_agent_backend_registry.py', - 'torch/distributed/rpc/_utils.py', - 'torch/distributed/rpc/api.py', - 'torch/distributed/rpc/backend_registry.py', - 'torch/distributed/rpc/constants.py', - 'torch/distributed/rpc/functions.py', - 'torch/distributed/rpc/internal.py', - 'torch/distributed/rpc/options.py', - 'torch/distributed/rpc/rref_proxy.py', - 'torch/distributed/rpc/server_process_global_profiler.py', 'torch/distributed/run.py', 'torch/fft/__init__.py', 'torch/func/__init__.py', diff --git a/torch/distributed/fsdp/__init__.py b/torch/distributed/fsdp/__init__.py index d887730f442f6d..6180dbb3df299e 100644 --- a/torch/distributed/fsdp/__init__.py +++ b/torch/distributed/fsdp/__init__.py @@ -18,6 +18,7 @@ StateDictType, ) + __all__ = [ "BackwardPrefetch", "CPUOffload", diff --git a/torch/distributed/fsdp/_common_utils.py b/torch/distributed/fsdp/_common_utils.py index aae2405d0bb50d..10d0f821265119 100644 --- a/torch/distributed/fsdp/_common_utils.py +++ b/torch/distributed/fsdp/_common_utils.py @@ -44,9 +44,11 @@ StateDictType, ) + if TYPE_CHECKING: from torch.distributed.device_mesh import DeviceMesh from torch.distributed.fsdp._fsdp_extensions import FSDPExtensions + from ._flat_param import FlatParamHandle FSDP_WRAPPED_MODULE = "_fsdp_wrapped_module" diff --git a/torch/distributed/fsdp/_debug_utils.py b/torch/distributed/fsdp/_debug_utils.py index 523330e5580dfd..163d9a045b68ea 100644 --- a/torch/distributed/fsdp/_debug_utils.py +++ b/torch/distributed/fsdp/_debug_utils.py @@ -15,6 +15,7 @@ clean_tensor_name, ) + logger = logging.getLogger(__name__) diff --git a/torch/distributed/fsdp/_flat_param.py b/torch/distributed/fsdp/_flat_param.py index 816b91433063af..8bc975dc72fd5a 100644 --- a/torch/distributed/fsdp/_flat_param.py +++ b/torch/distributed/fsdp/_flat_param.py @@ -50,6 +50,7 @@ FSDPExtensions, ) + __all__ = [ "FlatParameter", "FlatParamHandle", diff --git a/torch/distributed/fsdp/_init_utils.py b/torch/distributed/fsdp/_init_utils.py index c8b58091bf89b5..aaeedf22397a42 100644 --- a/torch/distributed/fsdp/_init_utils.py +++ b/torch/distributed/fsdp/_init_utils.py @@ -58,9 +58,9 @@ from torch.distributed.fsdp.wrap import _Policy from torch.distributed.tensor.parallel.fsdp import DTensorExtensions from torch.distributed.utils import _sync_params_and_buffers - from torch.utils._python_dispatch import is_traceable_wrapper_subclass + if TYPE_CHECKING: from torch.utils.hooks import RemovableHandle diff --git a/torch/distributed/fsdp/_optim_utils.py b/torch/distributed/fsdp/_optim_utils.py index 54f800a168653a..4cfe761769a3b9 100644 --- a/torch/distributed/fsdp/_optim_utils.py +++ b/torch/distributed/fsdp/_optim_utils.py @@ -55,6 +55,7 @@ ) from torch.utils._pytree import tree_map_only + if TYPE_CHECKING: from torch.distributed._shard.sharded_tensor import ShardedTensor diff --git a/torch/distributed/fsdp/_runtime_utils.py b/torch/distributed/fsdp/_runtime_utils.py index 833c1d45697aef..f84e7dd3e5055e 100644 --- a/torch/distributed/fsdp/_runtime_utils.py +++ b/torch/distributed/fsdp/_runtime_utils.py @@ -39,6 +39,7 @@ ) from torch.utils import _pytree as pytree + logger = logging.getLogger(__name__) # Do not include "process_group" to enable hybrid shard and MoE cases diff --git a/torch/distributed/fsdp/_state_dict_utils.py b/torch/distributed/fsdp/_state_dict_utils.py index 797a0116587bb3..815cfb2dd4a1ff 100644 --- a/torch/distributed/fsdp/_state_dict_utils.py +++ b/torch/distributed/fsdp/_state_dict_utils.py @@ -17,9 +17,7 @@ import torch import torch.distributed as dist - import torch.distributed.algorithms._checkpoint.checkpoint_wrapper as checkpoint_wrapper - import torch.nn as nn import torch.nn.functional as F from torch.distributed._shard.sharded_tensor import ( @@ -29,7 +27,6 @@ ) from torch.distributed._tensor import DTensor from torch.distributed.device_mesh import _mesh_resources - from torch.distributed.fsdp._common_utils import ( _FSDPState, _get_module_fsdp_state_if_fully_sharded_module, diff --git a/torch/distributed/fsdp/_unshard_param_utils.py b/torch/distributed/fsdp/_unshard_param_utils.py index 435193a88703a1..4143d2928c8b83 100644 --- a/torch/distributed/fsdp/_unshard_param_utils.py +++ b/torch/distributed/fsdp/_unshard_param_utils.py @@ -26,6 +26,7 @@ from ._flat_param import FlatParamHandle + FLAT_PARAM = "_flat_param" diff --git a/torch/distributed/fsdp/_wrap_utils.py b/torch/distributed/fsdp/_wrap_utils.py index 84cdf250d8ae1e..895bcbd8e967b4 100644 --- a/torch/distributed/fsdp/_wrap_utils.py +++ b/torch/distributed/fsdp/_wrap_utils.py @@ -11,7 +11,6 @@ _get_module_fsdp_state, _override_module_mixed_precision, ) - from torch.distributed.fsdp.wrap import ( _construct_wrap_fn, _or_policy, diff --git a/torch/distributed/fsdp/api.py b/torch/distributed/fsdp/api.py index 0272ee0c57c9fc..f2e4bdb7ea0231 100644 --- a/torch/distributed/fsdp/api.py +++ b/torch/distributed/fsdp/api.py @@ -5,12 +5,12 @@ from dataclasses import dataclass from enum import auto, Enum - from typing import Optional, Sequence, Type import torch from torch.nn.modules.batchnorm import _BatchNorm + __all__ = [ "ShardingStrategy", "BackwardPrefetch", diff --git a/torch/distributed/fsdp/fully_sharded_data_parallel.py b/torch/distributed/fsdp/fully_sharded_data_parallel.py index 9edd057a8f371e..1567bb973b22a6 100644 --- a/torch/distributed/fsdp/fully_sharded_data_parallel.py +++ b/torch/distributed/fsdp/fully_sharded_data_parallel.py @@ -85,8 +85,8 @@ StateDictType, ) from torch.distributed.utils import _p_assert -from ._flat_param import FlatParameter, FlatParamHandle +from ._flat_param import FlatParameter, FlatParamHandle from ._optim_utils import ( _flatten_optim_state_dict, _get_param_id_to_param_from_optim_input, diff --git a/torch/distributed/fsdp/sharded_grad_scaler.py b/torch/distributed/fsdp/sharded_grad_scaler.py index 3487e01263c719..7c1b2f83528683 100644 --- a/torch/distributed/fsdp/sharded_grad_scaler.py +++ b/torch/distributed/fsdp/sharded_grad_scaler.py @@ -8,6 +8,7 @@ from torch.amp.grad_scaler import _MultiDeviceReplicator, GradScaler, OptState from torch.distributed.distributed_c10d import ProcessGroup + logger = logging.getLogger(__name__) diff --git a/torch/distributed/fsdp/wrap.py b/torch/distributed/fsdp/wrap.py index acb5a6f1f642ad..f8604bbb1bb048 100644 --- a/torch/distributed/fsdp/wrap.py +++ b/torch/distributed/fsdp/wrap.py @@ -24,6 +24,7 @@ import torch.nn as nn + __all__ = [ "always_wrap_policy", "lambda_auto_wrap_policy", diff --git a/torch/distributed/optim/__init__.py b/torch/distributed/optim/__init__.py index fe33265fd532f4..924b993ec8414b 100644 --- a/torch/distributed/optim/__init__.py +++ b/torch/distributed/optim/__init__.py @@ -15,7 +15,6 @@ _get_in_backward_optimizers, ) from .functional_adadelta import _FunctionalAdadelta - from .functional_adagrad import _FunctionalAdagrad from .functional_adam import _FunctionalAdam from .functional_adamax import _FunctionalAdamax @@ -26,6 +25,7 @@ from .named_optimizer import _NamedOptimizer from .utils import as_functional_optim + with warnings.catch_warnings(): warnings.simplefilter("always") warnings.warn( @@ -44,4 +44,10 @@ from .post_localSGD_optimizer import PostLocalSGDOptimizer from .zero_redundancy_optimizer import ZeroRedundancyOptimizer -__all__ = ["as_functional_optim", "DistributedOptimizer", "PostLocalSGDOptimizer", "ZeroRedundancyOptimizer"] + +__all__ = [ + "as_functional_optim", + "DistributedOptimizer", + "PostLocalSGDOptimizer", + "ZeroRedundancyOptimizer", +] diff --git a/torch/distributed/optim/apply_optimizer_in_backward.py b/torch/distributed/optim/apply_optimizer_in_backward.py index 6bd182cca5736f..36f679f4eba49b 100644 --- a/torch/distributed/optim/apply_optimizer_in_backward.py +++ b/torch/distributed/optim/apply_optimizer_in_backward.py @@ -2,6 +2,7 @@ import torch + __all__: List[str] = [] # WeakTensorKeyDictionary to store relevant meta-data for the Tensor/Parameter @@ -11,6 +12,7 @@ param_to_optim_hook_handle_map = torch.utils.weak.WeakTensorKeyDictionary() param_to_acc_grad_map = torch.utils.weak.WeakTensorKeyDictionary() + @no_type_check def _apply_optimizer_in_backward( optimizer_class: Type[torch.optim.Optimizer], @@ -48,9 +50,7 @@ def _apply_optimizer_in_backward( # have their registered optimizer(s) applied. """ - torch._C._log_api_usage_once( - "torch.distributed.optim.apply_optimizer_in_backward" - ) + torch._C._log_api_usage_once("torch.distributed.optim.apply_optimizer_in_backward") @no_type_check def _apply_optimizer_in_backward_to_param(param: torch.nn.Parameter) -> None: @@ -62,7 +62,9 @@ def _apply_optimizer_in_backward_to_param(param: torch.nn.Parameter) -> None: # Don't create a new acc_grad if we already have one # i.e. for shared parameters or attaching multiple optimizers to a param. if param not in param_to_acc_grad_map: - param_to_acc_grad_map[param] = param.view_as(param).grad_fn.next_functions[0][0] + param_to_acc_grad_map[param] = param.view_as(param).grad_fn.next_functions[ + 0 + ][0] optimizer = optimizer_class([param], **optimizer_kwargs) diff --git a/torch/distributed/optim/functional_adadelta.py b/torch/distributed/optim/functional_adadelta.py index bc5f7c63dd1751..3ad51348b6afab 100644 --- a/torch/distributed/optim/functional_adadelta.py +++ b/torch/distributed/optim/functional_adadelta.py @@ -3,11 +3,12 @@ import torch import torch.optim._functional as F - from torch import Tensor + __all__: List[str] = [] + # Define a TorchScript compatible Functional Adadelta Optimizer # where we use these optimizer in a functional way. # Instead of using the `param.grad` when updating parameters, @@ -102,5 +103,5 @@ def step(self, gradients: List[Optional[Tensor]]): weight_decay=weight_decay, foreach=self.foreach, maximize=self.maximize, - has_complex=has_complex + has_complex=has_complex, ) diff --git a/torch/distributed/optim/functional_adagrad.py b/torch/distributed/optim/functional_adagrad.py index 93a1fe2b2240df..67f7328489ed21 100644 --- a/torch/distributed/optim/functional_adagrad.py +++ b/torch/distributed/optim/functional_adagrad.py @@ -3,11 +3,12 @@ import torch import torch.optim._functional as F - from torch import Tensor + __all__: List[str] = [] + # Define a TorchScript compatible Functional Adagrad Optimizer # where we use these optimizer in a functional way. # Instead of using the `param.grad` when updating parameters, diff --git a/torch/distributed/optim/functional_adam.py b/torch/distributed/optim/functional_adam.py index 34868d23d8a53c..3ed271765170c6 100644 --- a/torch/distributed/optim/functional_adam.py +++ b/torch/distributed/optim/functional_adam.py @@ -3,11 +3,12 @@ import torch import torch.optim._functional as F - from torch import Tensor + __all__: List[str] = [] + # Define a TorchScript compatible Functional Adam Optimizer # where we use these optimizer in a functional way. # Instead of using the `param.grad` when updating parameters, diff --git a/torch/distributed/optim/functional_adamax.py b/torch/distributed/optim/functional_adamax.py index 32bce65dfe1f50..8f1fdc0ccc02be 100644 --- a/torch/distributed/optim/functional_adamax.py +++ b/torch/distributed/optim/functional_adamax.py @@ -3,11 +3,12 @@ import torch import torch.optim._functional as F - from torch import Tensor + __all__: List[str] = [] + # Define a TorchScript compatible Functional Adamax Optimizer # where we use these optimizer in a functional way. # Instead of using the `param.grad` when updating parameters, diff --git a/torch/distributed/optim/functional_adamw.py b/torch/distributed/optim/functional_adamw.py index 43addd0508221f..d3f1f80e9209bd 100644 --- a/torch/distributed/optim/functional_adamw.py +++ b/torch/distributed/optim/functional_adamw.py @@ -3,11 +3,12 @@ import torch import torch.optim._functional as F - from torch import Tensor + __all__: List[str] = [] + # Define a TorchScript compatible Functional AdamW Optimizer # where we use these optimizer in a functional way. # Instead of using the `param.grad` when updating parameters, diff --git a/torch/distributed/optim/functional_rmsprop.py b/torch/distributed/optim/functional_rmsprop.py index 851119c8600c0e..7a03e8e9f462f8 100644 --- a/torch/distributed/optim/functional_rmsprop.py +++ b/torch/distributed/optim/functional_rmsprop.py @@ -3,11 +3,12 @@ import torch import torch.optim._functional as F - from torch import Tensor + __all__: List[str] = [] + # Define a TorchScript compatible Functional RMSprop Optimizer # where we use these optimizer in a functional way. # Instead of using the `param.grad` when updating parameters, diff --git a/torch/distributed/optim/functional_rprop.py b/torch/distributed/optim/functional_rprop.py index 60742bc68896fc..615015a95a316b 100644 --- a/torch/distributed/optim/functional_rprop.py +++ b/torch/distributed/optim/functional_rprop.py @@ -3,11 +3,12 @@ import torch import torch.optim._functional as F - from torch import Tensor + __all__: List[str] = [] + # Define a TorchScript compatible Functional Rprop Optimizer # where we use these optimizer in a functional way. # Instead of using the `param.grad` when updating parameters, diff --git a/torch/distributed/optim/functional_sgd.py b/torch/distributed/optim/functional_sgd.py index 3a8176e877057c..32381855db6b55 100644 --- a/torch/distributed/optim/functional_sgd.py +++ b/torch/distributed/optim/functional_sgd.py @@ -3,11 +3,12 @@ import torch import torch.optim._functional as F - from torch import Tensor + __all__: List[str] = [] + # Define a TorchScript compatible Functional SGD Optimizer # where we use these optimizer in a functional way. # Instead of using the `param.grad` when updating parameters, diff --git a/torch/distributed/optim/named_optimizer.py b/torch/distributed/optim/named_optimizer.py index 9e1e5377873d10..8e0b539b148264 100644 --- a/torch/distributed/optim/named_optimizer.py +++ b/torch/distributed/optim/named_optimizer.py @@ -1,9 +1,18 @@ # mypy: allow-untyped-defs import logging import warnings - from copy import deepcopy -from typing import Any, Callable, Collection, Dict, List, Mapping, Optional, Union, overload +from typing import ( + Any, + Callable, + Collection, + Dict, + List, + Mapping, + Optional, + overload, + Union, +) import torch import torch.nn as nn diff --git a/torch/distributed/optim/optimizer.py b/torch/distributed/optim/optimizer.py index f2eca606c02611..65df14770c21c4 100644 --- a/torch/distributed/optim/optimizer.py +++ b/torch/distributed/optim/optimizer.py @@ -1,6 +1,5 @@ # mypy: allow-untyped-defs import logging - from collections import defaultdict from threading import Lock from typing import List, Optional @@ -12,8 +11,10 @@ import torch.nn as nn from torch import Tensor from torch.distributed.rpc import RRef + from .utils import functional_optim_map + __all__ = ["DistributedOptimizer"] logger = logging.getLogger(__name__) @@ -205,7 +206,7 @@ def __init__(self, optimizer_class, params_rref, *args, **kwargs): "(i.e. Distributed Model Parallel training on CPU) due to the Python's " "Global Interpreter Lock (GIL). Please file an issue if you need this " "optimizer in TorchScript. ", - optimizer_class + optimizer_class, ) optimizer_new_func = _new_local_optimizer diff --git a/torch/distributed/optim/utils.py b/torch/distributed/optim/utils.py index af2220ca557493..d2c75eee7e39bc 100644 --- a/torch/distributed/optim/utils.py +++ b/torch/distributed/optim/utils.py @@ -2,6 +2,7 @@ from typing import Type from torch import optim + from .functional_adadelta import _FunctionalAdadelta from .functional_adagrad import _FunctionalAdagrad from .functional_adam import _FunctionalAdam @@ -11,6 +12,7 @@ from .functional_rprop import _FunctionalRprop from .functional_sgd import _FunctionalSGD + # dict to map a user passed in optimizer_class to a functional # optimizer class if we have already defined inside the # distributed.optim package, this is so that we hide the diff --git a/torch/distributed/optim/zero_redundancy_optimizer.py b/torch/distributed/optim/zero_redundancy_optimizer.py index 8a3be3b0181536..f664d11afb79c0 100644 --- a/torch/distributed/optim/zero_redundancy_optimizer.py +++ b/torch/distributed/optim/zero_redundancy_optimizer.py @@ -20,11 +20,12 @@ from torch.optim import Optimizer -logger = logging.getLogger(__name__) - __all__ = ["ZeroRedundancyOptimizer"] +logger = logging.getLogger(__name__) + + # Credits: classy_vision/generic/distributed_util.py def _recursive_copy_to_device( value: Any, @@ -925,9 +926,9 @@ def _bucket_assignments_per_rank(self) -> List[Dict[int, _DDPBucketAssignment]]: mapping bucket indices to :class:`_DDPBucketAssignment` s for each rank. """ - assert self._overlap_with_ddp, ( - "`_bucket_assignments_per_rank` only be used if `overlap_with_ddp=True`" - ) + assert ( + self._overlap_with_ddp + ), "`_bucket_assignments_per_rank` only be used if `overlap_with_ddp=True`" if len(self._bucket_assignments_per_rank_cache) > 0: return self._bucket_assignments_per_rank_cache @@ -1074,9 +1075,9 @@ def _local_step( "Specifying `gradients` should not " "be used when `overlap_with_ddp=False`" ) - assert closure is None, ( - "`closure` is not supported when using a local functional optimizer" - ) + assert ( + closure is None + ), "`closure` is not supported when using a local functional optimizer" loss = self.optim.step(gradients=gradients) # Sync any updated attributes in the local optimizer to the exposed @@ -1504,7 +1505,7 @@ def _init_local_optimizer(self) -> None: "%s does not support the argument " "`_allow_empty_param_list`; ZeroRedundancyOptimizer may " "error due to an empty parameter list", - self._optim_constructor + self._optim_constructor, ) self.optim: Any = self._optim_constructor(params, **self._optim_defaults) # type: ignore[no-redef] @@ -1515,17 +1516,16 @@ def _init_local_optimizer(self) -> None: self._bucket_assignments_per_rank[self.global_rank] ) logger.info( - "rank %s with %s parameters " - "across %s buckets", - self.global_rank, local_numel, num_assigned_buckets + "rank %s with %s parameters " "across %s buckets", + self.global_rank, + local_numel, + num_assigned_buckets, ) if self.global_rank == 0: logger.info( - "%s DDP " - "buckets and " - "%s bucket " - "assignments", - len(self._overlap_info.params_per_bucket), self._overlap_info.num_bucket_assignments + "%s DDP " "buckets and " "%s bucket " "assignments", + len(self._overlap_info.params_per_bucket), + self._overlap_info.num_bucket_assignments, ) else: # NOTE: Passing `param_groups` into the local optimizer constructor @@ -1640,7 +1640,8 @@ def _get_optimizer_constructor(self, optimizer_class: Any) -> Any: "Using the functional optimizer %s " "instead of %s since " "`overlap_with_ddp=True`", - optim_constructor, optimizer_class + optim_constructor, + optimizer_class, ) return optim_constructor else: diff --git a/torch/distributed/rpc/__init__.py b/torch/distributed/rpc/__init__.py index 581433d220c63e..6c6608a2a773f3 100644 --- a/torch/distributed/rpc/__init__.py +++ b/torch/distributed/rpc/__init__.py @@ -1,22 +1,25 @@ # mypy: allow-untyped-defs -from datetime import timedelta import logging import os import threading import warnings +from datetime import timedelta from typing import Generator, Tuple from urllib.parse import urlparse import torch import torch.distributed as dist + +__all__ = ["is_available"] + + logger = logging.getLogger(__name__) _init_counter = 0 _init_counter_lock = threading.Lock() -__all__ = ["is_available"] def is_available() -> bool: return hasattr(torch._C, "_rpc_init") @@ -27,54 +30,51 @@ def is_available() -> bool: if is_available(): + import numbers + + import torch.distributed.autograd as dist_autograd from torch._C._distributed_c10d import Store - from torch._C._distributed_rpc import ( + from torch._C._distributed_rpc import ( # noqa: F401 + _cleanup_python_rpc_handler, + _DEFAULT_INIT_METHOD, + _DEFAULT_NUM_WORKER_THREADS, + _DEFAULT_RPC_TIMEOUT_SEC, + _delete_all_user_and_unforked_owner_rrefs, + _destroy_rref_context, _disable_jit_rref_pickle, - _enable_jit_rref_pickle, _disable_server_process_global_profiler, + _enable_jit_rref_pickle, _enable_server_process_global_profiler, - _set_and_start_rpc_agent, - _reset_current_rpc_agent, - _delete_all_user_and_unforked_owner_rrefs, - _destroy_rref_context, - _set_profiler_node_id, - _is_current_rpc_agent_set, - _rref_context_get_debug_info, - _cleanup_python_rpc_handler, - _invoke_rpc_builtin, - _invoke_rpc_python_udf, - _invoke_rpc_torchscript, + _get_current_rpc_agent, _invoke_remote_builtin, _invoke_remote_python_udf, _invoke_remote_torchscript, + _invoke_rpc_builtin, + _invoke_rpc_python_udf, + _invoke_rpc_torchscript, + _is_current_rpc_agent_set, + _reset_current_rpc_agent, + _rref_context_get_debug_info, + _set_and_start_rpc_agent, + _set_profiler_node_id, _set_rpc_timeout, - _get_current_rpc_agent, - get_rpc_timeout, - enable_gil_profiling, - RpcBackendOptions, _TensorPipeRpcBackendOptionsBase, - RpcAgent, + _UNSET_RPC_TIMEOUT, + enable_gil_profiling, + get_rpc_timeout, PyRRef, - TensorPipeAgent, RemoteProfilerManager, + RpcAgent, + RpcBackendOptions, + TensorPipeAgent, WorkerInfo, - _DEFAULT_INIT_METHOD, - _DEFAULT_NUM_WORKER_THREADS, - _UNSET_RPC_TIMEOUT, - _DEFAULT_RPC_TIMEOUT_SEC, - ) # noqa: F401 + ) from . import api, backend_registry, functions from .api import * # noqa: F401,F403 - import numbers - - import torch.distributed.autograd as dist_autograd - from .backend_registry import BackendType from .options import TensorPipeRpcBackendOptions # noqa: F401 - from .server_process_global_profiler import ( - _server_process_global_profile, - ) + from .server_process_global_profiler import _server_process_global_profile rendezvous_iterator: Generator[Tuple[Store, int, int], None, None] @@ -153,7 +153,7 @@ def init_rpc( "corresponding to %(backend)s, hence that backend will be used " "instead of the default BackendType.TENSORPIPE. To silence this " "warning pass `backend=%(backend)s` explicitly.", - {'backend': backend} + {"backend": backend}, ) if backend is None: @@ -224,7 +224,6 @@ def _init_rpc_backend( world_size=None, rpc_backend_options=None, ): - _validate_rpc_args(backend, store, name, rank, world_size, rpc_backend_options) if _is_current_rpc_agent_set(): diff --git a/torch/distributed/rpc/_testing/__init__.py b/torch/distributed/rpc/_testing/__init__.py index 640c4d09f06281..8ac1c02f4cee4c 100644 --- a/torch/distributed/rpc/_testing/__init__.py +++ b/torch/distributed/rpc/_testing/__init__.py @@ -12,8 +12,9 @@ def is_available(): if is_available(): # Registers FAULTY_TENSORPIPE RPC backend. - from . import faulty_agent_backend_registry from torch._C._distributed_rpc_testing import ( - FaultyTensorPipeRpcBackendOptions, FaultyTensorPipeAgent, + FaultyTensorPipeRpcBackendOptions, ) + + from . import faulty_agent_backend_registry diff --git a/torch/distributed/rpc/_testing/faulty_agent_backend_registry.py b/torch/distributed/rpc/_testing/faulty_agent_backend_registry.py index 9e8660989e5a7c..d04882e16e79a9 100644 --- a/torch/distributed/rpc/_testing/faulty_agent_backend_registry.py +++ b/torch/distributed/rpc/_testing/faulty_agent_backend_registry.py @@ -4,6 +4,7 @@ import torch.distributed as dist import torch.distributed.rpc as rpc + def _faulty_tensorpipe_construct_rpc_backend_options_handler( rpc_timeout, init_method, @@ -11,7 +12,7 @@ def _faulty_tensorpipe_construct_rpc_backend_options_handler( messages_to_fail, messages_to_delay, num_fail_sends, - **kwargs + **kwargs, ): from . import FaultyTensorPipeRpcBackendOptions @@ -28,16 +29,14 @@ def _faulty_tensorpipe_construct_rpc_backend_options_handler( def _faulty_tensorpipe_init_backend_handler( store, name, rank, world_size, rpc_backend_options ): - from . import FaultyTensorPipeAgent - from . import FaultyTensorPipeRpcBackendOptions from torch.distributed.rpc import api + from . import FaultyTensorPipeAgent, FaultyTensorPipeRpcBackendOptions + if not isinstance(store, dist.Store): raise TypeError(f"`store` must be a c10d::Store. {store}") - if not isinstance( - rpc_backend_options, FaultyTensorPipeRpcBackendOptions - ): + if not isinstance(rpc_backend_options, FaultyTensorPipeRpcBackendOptions): raise TypeError( f"`rpc_backend_options` must be a `FaultyTensorPipeRpcBackendOptions`. {rpc_backend_options}" ) diff --git a/torch/distributed/rpc/_utils.py b/torch/distributed/rpc/_utils.py index 6499a80e0e1724..8925bc662b5f97 100644 --- a/torch/distributed/rpc/_utils.py +++ b/torch/distributed/rpc/_utils.py @@ -1,12 +1,14 @@ # mypy: allow-untyped-defs +import logging from contextlib import contextmanager from typing import cast -import logging -from . import api -from . import TensorPipeAgent + +from . import api, TensorPipeAgent + logger = logging.getLogger(__name__) + @contextmanager def _group_membership_management(store, name, is_join): token_key = "RpcGroupManagementToken" @@ -29,10 +31,17 @@ def _group_membership_management(store, name, is_join): try: store.wait([returned]) except RuntimeError: - logger.error("Group membership token %s timed out waiting for %s to be released.", my_token, returned) + logger.error( + "Group membership token %s timed out waiting for %s to be released.", + my_token, + returned, + ) raise + def _update_group_membership(worker_info, my_devices, reverse_device_map, is_join): agent = cast(TensorPipeAgent, api._get_current_rpc_agent()) - ret = agent._update_group_membership(worker_info, my_devices, reverse_device_map, is_join) + ret = agent._update_group_membership( + worker_info, my_devices, reverse_device_map, is_join + ) return ret diff --git a/torch/distributed/rpc/api.py b/torch/distributed/rpc/api.py index a33358eb0dc674..5fc9e61aa5592a 100644 --- a/torch/distributed/rpc/api.py +++ b/torch/distributed/rpc/api.py @@ -1,6 +1,4 @@ # mypy: allow-untyped-defs -__all__ = ["shutdown", "get_worker_info", "remote", "rpc_sync", - "rpc_async", "RRef", "AllGatherStates", "method_factory", "new_method"] import collections import contextlib @@ -8,17 +6,10 @@ import inspect import logging import threading -from typing import Dict, Generic, TypeVar, Set, Any, TYPE_CHECKING +from typing import Any, Dict, Generic, Set, TYPE_CHECKING, TypeVar import torch -from torch.futures import Future - from torch._C._distributed_rpc import ( - PyRRef, - RemoteProfilerManager, - WorkerInfo, - TensorPipeAgent, - get_rpc_timeout, _cleanup_python_rpc_handler, _delete_all_user_and_unforked_owner_rrefs, _destroy_rref_context, @@ -32,18 +23,36 @@ _is_current_rpc_agent_set, _reset_current_rpc_agent, _set_and_start_rpc_agent, + get_rpc_timeout, + PyRRef, + RemoteProfilerManager, + TensorPipeAgent, + WorkerInfo, ) +from torch.futures import Future +from ._utils import _group_membership_management, _update_group_membership +from .constants import DEFAULT_SHUTDOWN_TIMEOUT, UNSET_RPC_TIMEOUT from .internal import ( + _build_rpc_profiling_key, + _internal_rpc_pickler, PythonUDF, RPCExecMode, - _internal_rpc_pickler, - _build_rpc_profiling_key, ) -from .constants import DEFAULT_SHUTDOWN_TIMEOUT, UNSET_RPC_TIMEOUT -from ._utils import _group_membership_management, _update_group_membership +__all__ = [ + "shutdown", + "get_worker_info", + "remote", + "rpc_sync", + "rpc_async", + "RRef", + "AllGatherStates", + "method_factory", + "new_method", +] + logger = logging.getLogger(__name__) @@ -59,6 +68,7 @@ _ignore_rref_leak = True _default_pickler = _internal_rpc_pickler + @contextlib.contextmanager def _use_rpc_pickler(rpc_pickler): r""" @@ -107,7 +117,9 @@ def __init__(self): _ALL_WORKER_NAMES: Set[Any] = set() _all_gather_dict_lock = threading.RLock() _all_gather_sequence_id: Dict[str, int] = {} -_all_gather_sequence_id_to_states: collections.defaultdict = collections.defaultdict(AllGatherStates) +_all_gather_sequence_id_to_states: collections.defaultdict = collections.defaultdict( + AllGatherStates +) def _init_rpc_states(agent): @@ -146,6 +158,7 @@ def _broadcast_to_followers(sequence_id, objects_map): states.gathered_objects = objects_map states.proceed_signal.set() + _thread_local_var = threading.local() @@ -245,7 +258,7 @@ def _all_gather(obj, worker_names=None, timeout: float = UNSET_RPC_TIMEOUT): follower_name, _broadcast_to_followers, args=(sequence_id, states.gathered_objects), - timeout=rpc_timeout + timeout=rpc_timeout, ) worker_name_to_response_future_dict[follower_name] = fut @@ -283,9 +296,7 @@ def _barrier(worker_names): try: _all_gather(None, set(worker_names)) except RuntimeError as ex: - logger.error( - "Failed to complete barrier, got error %s", ex - ) + logger.error("Failed to complete barrier, got error %s", ex) @_require_initialized @@ -371,7 +382,11 @@ def shutdown(graceful=True, timeout=DEFAULT_SHUTDOWN_TIMEOUT): all_worker_infos = agent.get_worker_infos() for worker in all_worker_infos: if worker.name != my_name: - rpc_sync(worker.name, _update_group_membership, args=(my_worker_info, [], {}, False)) + rpc_sync( + worker.name, + _update_group_membership, + args=(my_worker_info, [], {}, False), + ) agent.join(shutdown=True, timeout=timeout) finally: # In case of errors, continue to complete the local shutdown. @@ -445,13 +460,10 @@ def _rref_typeof_on_owner(rref, blocking: bool = True): return future -def _rref_typeof_on_user(rref, timeout: float = UNSET_RPC_TIMEOUT, blocking: bool = True): - fut = rpc_async( - rref.owner(), - _rref_typeof_on_owner, - args=(rref,), - timeout=timeout - ) +def _rref_typeof_on_user( + rref, timeout: float = UNSET_RPC_TIMEOUT, blocking: bool = True +): + fut = rpc_async(rref.owner(), _rref_typeof_on_owner, args=(rref,), timeout=timeout) if blocking: return fut.wait() else: @@ -463,13 +475,16 @@ def _rref_typeof_on_user(rref, timeout: float = UNSET_RPC_TIMEOUT, blocking: boo if TYPE_CHECKING: + class RRef(PyRRef[T], Generic[T]): pass + else: try: # Combine the implementation class and the type class. class RRef(PyRRef, Generic[T]): pass + except TypeError: # TypeError: metaclass conflict: the metaclass of a derived class # must be a (non-strict) subclass of the metaclasses of all its bases @@ -517,7 +532,9 @@ def method(self, *args, **kwargs): assert docstring is not None, "RRef user-facing methods should all have docstrings." # Do surgery on pybind11 generated docstrings. - docstring = docstring.replace("torch.distributed.rpc.PyRRef", "torch.distributed.rpc.RRef") + docstring = docstring.replace( + "torch.distributed.rpc.PyRRef", "torch.distributed.rpc.RRef" + ) # Attach user-facing RRef method with modified docstring. new_method = method_factory(method_name, docstring) @@ -633,7 +650,9 @@ def remote(to, func, args=None, kwargs=None, timeout=UNSET_RPC_TIMEOUT): dst_worker_info = _to_worker_info(to) should_profile = _get_should_profile() - ctx_manager = _enable_rpc_profiler(should_profile, qualified_name, func, RPCExecMode.REMOTE, dst_worker_info) + ctx_manager = _enable_rpc_profiler( + should_profile, qualified_name, func, RPCExecMode.REMOTE, dst_worker_info + ) with ctx_manager as rf: args = args if args else () @@ -647,7 +666,9 @@ def remote(to, func, args=None, kwargs=None, timeout=UNSET_RPC_TIMEOUT): func = wrapped if qualified_name is not None: - rref = _invoke_remote_builtin(dst_worker_info, qualified_name, timeout, *args, **kwargs) + rref = _invoke_remote_builtin( + dst_worker_info, qualified_name, timeout, *args, **kwargs + ) elif isinstance(func, torch.jit.ScriptFunction): rref = _invoke_remote_torchscript( dst_worker_info.name, @@ -662,11 +683,7 @@ def remote(to, func, args=None, kwargs=None, timeout=UNSET_RPC_TIMEOUT): PythonUDF(func, args, kwargs) ) rref = _invoke_remote_python_udf( - dst_worker_info, - pickled_python_udf, - tensors, - timeout, - is_async_exec + dst_worker_info, pickled_python_udf, tensors, timeout, is_async_exec ) # attach profiling information if should_profile: @@ -678,7 +695,9 @@ def remote(to, func, args=None, kwargs=None, timeout=UNSET_RPC_TIMEOUT): return rref -def _invoke_rpc(to, func, rpc_type, args=None, kwargs=None, rpc_timeout: float = UNSET_RPC_TIMEOUT): +def _invoke_rpc( + to, func, rpc_type, args=None, kwargs=None, rpc_timeout: float = UNSET_RPC_TIMEOUT +): if not callable(func): raise TypeError("function should be callable.") @@ -687,7 +706,9 @@ def _invoke_rpc(to, func, rpc_type, args=None, kwargs=None, rpc_timeout: float = should_profile = _get_should_profile() - ctx_manager = _enable_rpc_profiler(should_profile, qualified_name, func, rpc_type, dst_worker_info) + ctx_manager = _enable_rpc_profiler( + should_profile, qualified_name, func, rpc_type, dst_worker_info + ) with ctx_manager as rf: args = args if args else () @@ -702,11 +723,7 @@ def _invoke_rpc(to, func, rpc_type, args=None, kwargs=None, rpc_timeout: float = if qualified_name is not None: fut = _invoke_rpc_builtin( - dst_worker_info, - qualified_name, - rpc_timeout, - *args, - **kwargs + dst_worker_info, qualified_name, rpc_timeout, *args, **kwargs ) elif isinstance(func, torch.jit.ScriptFunction): fut = _invoke_rpc_torchscript( @@ -715,18 +732,14 @@ def _invoke_rpc(to, func, rpc_type, args=None, kwargs=None, rpc_timeout: float = args, kwargs, rpc_timeout, - is_async_exec + is_async_exec, ) else: (pickled_python_udf, tensors) = _default_pickler.serialize( PythonUDF(func, args, kwargs) ) fut = _invoke_rpc_python_udf( - dst_worker_info, - pickled_python_udf, - tensors, - rpc_timeout, - is_async_exec + dst_worker_info, pickled_python_udf, tensors, rpc_timeout, is_async_exec ) if should_profile: assert torch.autograd._profiler_enabled() @@ -915,12 +928,15 @@ def _get_should_profile(): # Kineto profiler. ActiveProfilerType = torch._C._profiler.ActiveProfilerType return ( - torch.autograd._profiler_enabled() and - torch._C._autograd._profiler_type() == ActiveProfilerType.LEGACY # type: ignore[attr-defined] + torch.autograd._profiler_enabled() + and torch._C._autograd._profiler_type() + == ActiveProfilerType.LEGACY # type: ignore[attr-defined] ) -def _enable_rpc_profiler(should_profile, qualified_name, func, rpc_type, dst_worker_info): +def _enable_rpc_profiler( + should_profile, qualified_name, func, rpc_type, dst_worker_info +): ctx_manager = contextlib.nullcontext() if should_profile: diff --git a/torch/distributed/rpc/backend_registry.py b/torch/distributed/rpc/backend_registry.py index 6290f9e8e2054b..a06f0276ede95a 100644 --- a/torch/distributed/rpc/backend_registry.py +++ b/torch/distributed/rpc/backend_registry.py @@ -1,5 +1,5 @@ # mypy: allow-untyped-defs -__all__ = ["init_backend", "backend_registered", "construct_rpc_backend_options", "register_backend", "BackendType", "BackendValue"] + import collections import enum @@ -7,13 +7,19 @@ import torch import torch.distributed as dist + +from . import api, constants as rpc_constants from ._utils import _group_membership_management, _update_group_membership -from . import api -from . import constants as rpc_constants -__all__ = ["backend_registered", "register_backend", "construct_rpc_backend_options", "init_backend", - "BackendValue", "BackendType"] +__all__ = [ + "backend_registered", + "register_backend", + "construct_rpc_backend_options", + "init_backend", + "BackendValue", + "BackendType", +] BackendValue = collections.namedtuple( "BackendValue", ["construct_rpc_backend_options_handler", "init_backend_handler"] @@ -41,6 +47,7 @@ def _backend_type_repr(self): if BackendType.__doc__: BackendType.__doc__ = _backend_type_doc + def backend_registered(backend_name): """ Checks if backend_name is registered as an RPC backend. @@ -80,7 +87,7 @@ def register_backend( init_backend_handler=init_backend_handler, ) }, - **existing_enum_dict + **existing_enum_dict, ) # Can't handle Function Enum API (mypy bug #9079) BackendType = enum.Enum(value="BackendType", names=extended_enum_dict) # type: ignore[misc] @@ -90,20 +97,22 @@ def register_backend( BackendType.__doc__ = _backend_type_doc return BackendType[backend_name] + def construct_rpc_backend_options( backend, rpc_timeout=rpc_constants.DEFAULT_RPC_TIMEOUT_SEC, init_method=rpc_constants.DEFAULT_INIT_METHOD, - **kwargs + **kwargs, ): - return backend.value.construct_rpc_backend_options_handler( rpc_timeout, init_method, **kwargs ) + def init_backend(backend, *args, **kwargs): return backend.value.init_backend_handler(*args, **kwargs) + def _init_process_group(store, rank, world_size): # Initialize ProcessGroup. process_group_timeout = rpc_constants.DEFAULT_PROCESS_GROUP_TIMEOUT @@ -115,22 +124,21 @@ def _init_process_group(store, rank, world_size): assert group is not None, "Failed to initialize default ProcessGroup." if (rank != -1) and (rank != group.rank()): - raise RuntimeError( - f"rank argument {rank} doesn't match pg rank {group.rank()}" - ) + raise RuntimeError(f"rank argument {rank} doesn't match pg rank {group.rank()}") if (world_size != -1) and (world_size != group.size()): raise RuntimeError( f"world_size argument {world_size} doesn't match pg size {group.size()}" ) return group + def _tensorpipe_construct_rpc_backend_options_handler( rpc_timeout, init_method, num_worker_threads=rpc_constants.DEFAULT_NUM_WORKER_THREADS, _transports=None, _channels=None, - **kwargs + **kwargs, ): from . import TensorPipeRpcBackendOptions @@ -155,9 +163,9 @@ def _tensorpipe_validate_devices(devices, device_count): def _tensorpipe_exchange_and_check_all_device_maps( my_name, my_device_count, my_device_maps, my_devices, group ): - gathered: List[Tuple[ - str, int, Dict[str, Dict[torch.device, torch.device]], List[torch.device] - ]] = [("", 0, {}, []) for _ in range(group.size())] + gathered: List[ + Tuple[str, int, Dict[str, Dict[torch.device, torch.device]], List[torch.device]] + ] = [("", 0, {}, []) for _ in range(group.size())] dist.all_gather_object( gathered, (my_name, my_device_count, my_device_maps, my_devices), group ) @@ -173,13 +181,15 @@ def _tensorpipe_exchange_and_check_all_device_maps( my_devices = _create_device_list(my_devices, my_device_maps, reverse_device_maps) return reverse_device_maps, my_devices -def _validate_device_maps(all_names, all_device_counts, all_device_maps, all_devices, is_static_group=True): + +def _validate_device_maps( + all_names, all_device_counts, all_device_maps, all_devices, is_static_group=True +): for node in all_names: devices = all_devices[node] if len(set(devices)) != len(devices): raise ValueError( - f"Node {node} has duplicated devices\n" - f"devices = {devices}" + f"Node {node} has duplicated devices\n" f"devices = {devices}" ) if not _tensorpipe_validate_devices(devices, all_device_counts[node]): raise ValueError( @@ -190,7 +200,9 @@ def _validate_device_maps(all_names, all_device_counts, all_device_maps, all_dev for source_node in all_names: # For dynamic group (non-static) do not check the target node name since it may not have joined yet - if is_static_group and not set(all_device_maps[source_node].keys()).issubset(all_names): + if is_static_group and not set(all_device_maps[source_node].keys()).issubset( + all_names + ): raise ValueError( f"Node {source_node} has invalid target node names in its device maps\n" f"device maps = {all_device_maps[source_node].keys()}\n" @@ -238,6 +250,7 @@ def _validate_device_maps(all_names, all_device_counts, all_device_maps, all_dev f"device count = {all_device_counts[target_node]}" ) + def _create_device_list(my_devices, my_device_maps, reverse_device_maps): if not my_devices: devices_set: Set[torch.device] = set() @@ -250,6 +263,7 @@ def _create_device_list(my_devices, my_device_maps, reverse_device_maps): my_devices = sorted(my_devices, key=lambda d: d.index) return my_devices + def _create_reverse_mapping(my_name, all_names, all_device_maps): reverse_device_maps: Dict[str, Dict[torch.device, torch.device]] = {} for node in all_names: @@ -259,8 +273,10 @@ def _create_reverse_mapping(my_name, all_names, all_device_maps): } return reverse_device_maps + def _get_device_infos(): from . import TensorPipeAgent + agent = cast(TensorPipeAgent, api._get_current_rpc_agent()) opts = agent._get_backend_options() device_count = torch.cuda.device_count() @@ -268,8 +284,10 @@ def _get_device_infos(): torch.cuda.init() return device_count, opts.device_maps, opts.devices + def _set_devices_and_reverse_device_map(agent): from . import TensorPipeAgent + agent = cast(TensorPipeAgent, agent) # Group state is retrieved from local agent # On initialization, tensorpipe agent retrieves information from all existing workers, so group state is valid @@ -282,34 +300,52 @@ def _set_devices_and_reverse_device_map(agent): worker_name = worker_info.name if worker_name != my_name: # TODO: make async? - device_count, device_map, devices = api.rpc_sync(worker_name, _get_device_infos) + device_count, device_map, devices = api.rpc_sync( + worker_name, _get_device_infos + ) else: opts = agent._get_backend_options() - device_count, device_map, devices = torch.cuda.device_count(), opts.device_maps, opts.devices + device_count, device_map, devices = ( + torch.cuda.device_count(), + opts.device_maps, + opts.devices, + ) all_device_counts[worker_name] = device_count all_device_maps[worker_name] = device_map all_devices[worker_name] = devices all_names.append(worker_name) - _validate_device_maps(all_names, all_device_counts, all_device_maps, all_devices, is_static_group=False) + _validate_device_maps( + all_names, + all_device_counts, + all_device_maps, + all_devices, + is_static_group=False, + ) reverse_device_maps = _create_reverse_mapping(my_name, all_names, all_device_maps) # Perform RPC call to all workers, including itself, to include newly joined worker information and device maps for worker_name in all_names: # Set device list for each worker - all_devices[worker_name] = _create_device_list(all_devices[worker_name], all_device_maps[worker_name], reverse_device_maps) - api.rpc_sync(worker_name, _update_group_membership, - args=(my_worker_info, all_devices[worker_name], reverse_device_maps, True)) + all_devices[worker_name] = _create_device_list( + all_devices[worker_name], all_device_maps[worker_name], reverse_device_maps + ) + api.rpc_sync( + worker_name, + _update_group_membership, + args=(my_worker_info, all_devices[worker_name], reverse_device_maps, True), + ) + + +def _tensorpipe_init_backend_handler( + store, name, rank, world_size, rpc_backend_options +): + from . import TensorPipeAgent, TensorPipeRpcBackendOptions -def _tensorpipe_init_backend_handler(store, name, rank, world_size, rpc_backend_options): - from . import TensorPipeAgent - from . import TensorPipeRpcBackendOptions if not isinstance(store, dist.Store): raise TypeError(f"`store` must be a c10d::Store. {store}") - if not isinstance( - rpc_backend_options, TensorPipeRpcBackendOptions - ): + if not isinstance(rpc_backend_options, TensorPipeRpcBackendOptions): raise TypeError( f"`rpc_backend_options` must be a `TensorPipeRpcBackendOptions`. {rpc_backend_options}" ) @@ -389,6 +425,7 @@ def _tensorpipe_init_backend_handler(store, name, rank, world_size, rpc_backend_ raise return agent + register_backend( "TENSORPIPE", _tensorpipe_construct_rpc_backend_options_handler, diff --git a/torch/distributed/rpc/constants.py b/torch/distributed/rpc/constants.py index 3bc525b70d9bb1..56f6db4db259df 100644 --- a/torch/distributed/rpc/constants.py +++ b/torch/distributed/rpc/constants.py @@ -1,5 +1,6 @@ from datetime import timedelta from typing import List + from torch._C._distributed_rpc import ( _DEFAULT_INIT_METHOD, _DEFAULT_NUM_WORKER_THREADS, @@ -17,7 +18,7 @@ DEFAULT_NUM_WORKER_THREADS: int = _DEFAULT_NUM_WORKER_THREADS # Ensure that we don't time out when there are long periods of time without # any operations against the underlying ProcessGroup. -DEFAULT_PROCESS_GROUP_TIMEOUT: timedelta = timedelta(milliseconds=2 ** 31 - 1) +DEFAULT_PROCESS_GROUP_TIMEOUT: timedelta = timedelta(milliseconds=2**31 - 1) # Value indicating that timeout is not set for RPC call, and the default should be used. UNSET_RPC_TIMEOUT: float = _UNSET_RPC_TIMEOUT diff --git a/torch/distributed/rpc/functions.py b/torch/distributed/rpc/functions.py index c9e92980cf5662..e48ea8cc534ab8 100644 --- a/torch/distributed/rpc/functions.py +++ b/torch/distributed/rpc/functions.py @@ -159,9 +159,11 @@ def async_execution(fn): >>> ret = rref.remote().static_async_add("worker2", torch.ones(2), 1, 2).to_here() >>> print(ret) # prints tensor([4., 4.]) """ + @functools.wraps(fn) def wrapper(*args, **kwargs): return fn(*args, **kwargs) + # Can't declare and use attributes of function objects (mypy#2087) wrapper._wrapped_async_rpc_function = fn # type: ignore[attr-defined] return wrapper diff --git a/torch/distributed/rpc/internal.py b/torch/distributed/rpc/internal.py index 2fc647c414d969..5faf7d14d0da57 100644 --- a/torch/distributed/rpc/internal.py +++ b/torch/distributed/rpc/internal.py @@ -12,6 +12,7 @@ import torch.distributed as dist from torch._C._distributed_rpc import _get_current_rpc_agent + __all__ = ["RPCExecMode", "serialize", "deserialize", "PythonUDF", "RemoteException"] # Thread local tensor tables to store tensors while pickling torch.Tensor @@ -251,7 +252,9 @@ def _build_rpc_profiling_key( Returns: String representing profiling key """ - profile_key = f"rpc_{exec_type.value}#{func_name}({current_worker_name} -> {dst_worker_name})" + profile_key = ( + f"rpc_{exec_type.value}#{func_name}({current_worker_name} -> {dst_worker_name})" + ) return profile_key diff --git a/torch/distributed/rpc/options.py b/torch/distributed/rpc/options.py index 70328f34596958..53bf473ba56287 100644 --- a/torch/distributed/rpc/options.py +++ b/torch/distributed/rpc/options.py @@ -3,6 +3,7 @@ import torch from torch._C._distributed_rpc import _TensorPipeRpcBackendOptionsBase + from . import constants as rpc_contants @@ -10,6 +11,7 @@ __all__ = ["TensorPipeRpcBackendOptions"] + def _to_device(device: DeviceType) -> torch.device: device = torch.device(device) if device.type != "cuda": diff --git a/torch/distributed/rpc/rref_proxy.py b/torch/distributed/rpc/rref_proxy.py index cdb0a5d22b7423..85927b68bacb9c 100644 --- a/torch/distributed/rpc/rref_proxy.py +++ b/torch/distributed/rpc/rref_proxy.py @@ -1,20 +1,22 @@ # mypy: allow-untyped-defs from functools import partial -from . import functions -from . import rpc_async - import torch -from .constants import UNSET_RPC_TIMEOUT from torch.futures import Future +from . import functions, rpc_async +from .constants import UNSET_RPC_TIMEOUT + + def _local_invoke(rref, func_name, args, kwargs): return getattr(rref.local_value(), func_name)(*args, **kwargs) + @functions.async_execution def _local_invoke_async_execution(rref, func_name, args, kwargs): return getattr(rref.local_value(), func_name)(*args, **kwargs) + def _invoke_rpc(rref, rpc_api, func_name, timeout, *args, **kwargs): def _rref_type_cont(rref_fut): rref_type = rref_fut.value() @@ -33,7 +35,7 @@ def _rref_type_cont(rref_fut): rref.owner(), _invoke_func, args=(rref, func_name, args, kwargs), - timeout=timeout + timeout=timeout, ) rref_fut = rref._get_type(timeout=timeout, blocking=False) @@ -63,6 +65,7 @@ def _complete_op(fut): rref_fut.then(_wrap_rref_type_cont) return result + # This class manages proxied RPC API calls for RRefs. It is entirely used from # C++ (see python_rpc_handler.cpp). class RRefProxy: @@ -72,4 +75,6 @@ def __init__(self, rref, rpc_api, timeout=UNSET_RPC_TIMEOUT): self.rpc_timeout = timeout def __getattr__(self, func_name): - return partial(_invoke_rpc, self.rref, self.rpc_api, func_name, self.rpc_timeout) + return partial( + _invoke_rpc, self.rref, self.rpc_api, func_name, self.rpc_timeout + ) diff --git a/torch/distributed/rpc/server_process_global_profiler.py b/torch/distributed/rpc/server_process_global_profiler.py index 0543ab56a877fb..b5d089d305253f 100644 --- a/torch/distributed/rpc/server_process_global_profiler.py +++ b/torch/distributed/rpc/server_process_global_profiler.py @@ -2,18 +2,20 @@ # mypy: allow-untyped-defs import itertools +from typing import List import torch from torch.autograd.profiler_legacy import profile -from typing import List from . import ( _disable_server_process_global_profiler, _enable_server_process_global_profiler, ) + __all__: List[str] = [] + class _server_process_global_profile(profile): """ It has the same API as ``torch.autograd.profiler.profile`` class, @@ -123,7 +125,8 @@ def __enter__(self): False, False, False, - torch.profiler._ExperimentalConfig()) + torch.profiler._ExperimentalConfig(), + ) _enable_server_process_global_profiler(profiler_config) return self @@ -152,8 +155,10 @@ def __exit__(self, exc_type, exc_val, exc_tb): process_global_function_events = [] for thread_local_events in process_global_events: # Parse from ``Event``s to ``FunctionEvent``s. - thread_local_function_events = torch.autograd.profiler_legacy._parse_legacy_records( - thread_local_events + thread_local_function_events = ( + torch.autograd.profiler_legacy._parse_legacy_records( + thread_local_events + ) ) thread_local_function_events.sort( key=lambda function_event: [ From a0e1e20c4157bb3e537fc784a51d7aef1e754157 Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Tue, 18 Jun 2024 23:21:45 +0800 Subject: [PATCH 158/171] [BE][Easy] enable UFMT for `torch/distributed/` (#128870) Part of #123062 - #123062 Pull Request resolved: https://github.com/pytorch/pytorch/pull/128870 Approved by: https://github.com/fegin ghstack dependencies: #128868, #128869 --- .lintrunner.toml | 27 - torch/distributed/__init__.py | 64 +-- .../_composable/fsdp/_fsdp_collectives.py | 1 + .../_composable/fsdp/_fsdp_common.py | 1 - .../_composable/fsdp/_fsdp_init.py | 2 +- .../_composable/fsdp/_fsdp_param.py | 3 +- .../_composable/fsdp/_fsdp_param_group.py | 3 +- .../_composable/fsdp/_fsdp_state.py | 3 +- .../_composable/fsdp/fully_shard.py | 1 - torch/distributed/_composable/fully_shard.py | 1 - torch/distributed/_composable/replicate.py | 1 + torch/distributed/_cuda_p2p/__init__.py | 3 +- torch/distributed/_functional_collectives.py | 2 + .../_functional_collectives_impl.py | 1 + torch/distributed/_sharded_tensor/__init__.py | 7 +- torch/distributed/_sharding_spec/__init__.py | 7 +- torch/distributed/_state_dict_utils.py | 1 + torch/distributed/_tools/memory_tracker.py | 19 +- torch/distributed/c10d_logger.py | 12 +- torch/distributed/collective_utils.py | 14 +- torch/distributed/constants.py | 7 +- torch/distributed/device_mesh.py | 3 +- torch/distributed/distributed_c10d.py | 520 +++++++++++++----- .../examples/memory_tracker_example.py | 2 +- torch/distributed/launcher/__init__.py | 2 +- torch/distributed/launcher/api.py | 13 +- torch/distributed/logging_handlers.py | 1 + torch/distributed/nn/__init__.py | 5 +- torch/distributed/nn/api/remote_module.py | 27 +- torch/distributed/nn/functional.py | 21 +- torch/distributed/pipelining/_IR.py | 6 +- torch/distributed/pipelining/__init__.py | 1 + torch/distributed/remote_device.py | 17 +- torch/distributed/rendezvous.py | 33 +- torch/distributed/run.py | 49 +- torch/distributed/utils.py | 1 + 36 files changed, 583 insertions(+), 298 deletions(-) diff --git a/.lintrunner.toml b/.lintrunner.toml index 99c04cac4fbb39..2c3da39f80ccfa 100644 --- a/.lintrunner.toml +++ b/.lintrunner.toml @@ -1389,33 +1389,6 @@ exclude_patterns = [ 'torch/contrib/_tensorboard_vis.py', "torch/cuda/_gpu_trace.py", 'torch/cuda/_memory_viz.py', # mypy: Value of type "object" is not indexable - 'torch/distributed/__init__.py', - 'torch/distributed/_composable_state.py', - 'torch/distributed/_sharded_tensor/__init__.py', - 'torch/distributed/_sharding_spec/__init__.py', - 'torch/distributed/_tools/__init__.py', - 'torch/distributed/_tools/memory_tracker.py', - 'torch/distributed/argparse_util.py', - 'torch/distributed/c10d_logger.py', - 'torch/distributed/collective_utils.py', - 'torch/distributed/constants.py', - 'torch/distributed/distributed_c10d.py', - 'torch/distributed/examples/memory_tracker_example.py', - 'torch/distributed/launch.py', - 'torch/distributed/launcher/__init__.py', - 'torch/distributed/launcher/api.py', - 'torch/distributed/logging_handlers.py', - 'torch/distributed/nn/__init__.py', - 'torch/distributed/nn/api/__init__.py', - 'torch/distributed/nn/api/remote_module.py', - 'torch/distributed/nn/functional.py', - 'torch/distributed/nn/jit/__init__.py', - 'torch/distributed/nn/jit/instantiator.py', - 'torch/distributed/nn/jit/templates/__init__.py', - 'torch/distributed/nn/jit/templates/remote_module_template.py', - 'torch/distributed/remote_device.py', - 'torch/distributed/rendezvous.py', - 'torch/distributed/run.py', 'torch/fft/__init__.py', 'torch/func/__init__.py', 'torch/futures/__init__.py', diff --git a/torch/distributed/__init__.py b/torch/distributed/__init__.py index eb339000e89e7e..93b701732206fc 100644 --- a/torch/distributed/__init__.py +++ b/torch/distributed/__init__.py @@ -1,9 +1,10 @@ # mypy: allow-untyped-defs -import sys import pdb +import sys import torch + def is_available() -> bool: """ Return ``True`` if the distributed package is available. @@ -29,31 +30,31 @@ def is_available() -> bool: if is_available(): from torch._C._distributed_c10d import ( - Store, - FileStore, - TCPStore, - ProcessGroup as ProcessGroup, - Backend as _Backend, - PrefixStore, - Reducer, - Logger, - BuiltinCommHookType, - GradBucket, - Work as _Work, - _DEFAULT_FIRST_BUCKET_BYTES, - _register_comm_hook, - _register_builtin_comm_hook, _broadcast_coalesced, _compute_bucket_assignment_by_size, - _verify_params_across_processes, + _ControlCollectives, + _DEFAULT_FIRST_BUCKET_BYTES, + _make_nccl_premul_sum, + _register_builtin_comm_hook, + _register_comm_hook, + _StoreCollectives, _test_python_store, + _verify_params_across_processes, + Backend as _Backend, + BuiltinCommHookType, DebugLevel, + FileStore, get_debug_level, + GradBucket, + Logger, + PrefixStore, + ProcessGroup as ProcessGroup, + Reducer, set_debug_level, set_debug_level_from_env, - _make_nccl_premul_sum, - _ControlCollectives, - _StoreCollectives, + Store, + TCPStore, + Work as _Work, ) class _DistributedPdb(pdb.Pdb): @@ -63,10 +64,11 @@ class _DistributedPdb(pdb.Pdb): Usage: _DistributedPdb().set_trace() """ + def interaction(self, *args, **kwargs): _stdin = sys.stdin try: - sys.stdin = open('/dev/stdin') + sys.stdin = open("/dev/stdin") pdb.Pdb.interaction(self, *args, **kwargs) finally: sys.stdin = _stdin @@ -98,37 +100,31 @@ def breakpoint(rank: int = 0): del guard if sys.platform != "win32": - from torch._C._distributed_c10d import ( - HashStore, - _round_robin_process_groups, - ) + from torch._C._distributed_c10d import _round_robin_process_groups, HashStore - from .distributed_c10d import * # noqa: F403 + from .device_mesh import DeviceMesh, init_device_mesh # Variables prefixed with underscore are not auto imported # See the comment in `distributed_c10d.py` above `_backend` on why we expose # this. - + from .distributed_c10d import * # noqa: F403 from .distributed_c10d import ( _all_gather_base, - _reduce_scatter_base, - _create_process_group_wrapper, - _rank_not_in_group, _coalescing_manager, _CoalescingManager, + _create_process_group_wrapper, _get_process_group_name, + _rank_not_in_group, + _reduce_scatter_base, get_node_local_rank, ) - + from .remote_device import _remote_device from .rendezvous import ( - rendezvous, _create_store_from_options, register_rendezvous_handler, + rendezvous, ) - from .remote_device import _remote_device - from .device_mesh import init_device_mesh, DeviceMesh - set_debug_level_from_env() else: diff --git a/torch/distributed/_composable/fsdp/_fsdp_collectives.py b/torch/distributed/_composable/fsdp/_fsdp_collectives.py index 1423cfd600fc88..14f7f8a313faf4 100644 --- a/torch/distributed/_composable/fsdp/_fsdp_collectives.py +++ b/torch/distributed/_composable/fsdp/_fsdp_collectives.py @@ -5,6 +5,7 @@ import torch.distributed as dist from torch.distributed._tensor import DTensor from torch.distributed.distributed_c10d import ReduceOp + from ._fsdp_common import ( _get_dim0_padded_size, _raise_assert_with_print, diff --git a/torch/distributed/_composable/fsdp/_fsdp_common.py b/torch/distributed/_composable/fsdp/_fsdp_common.py index 594ec483bd3bf6..36b181250f28da 100644 --- a/torch/distributed/_composable/fsdp/_fsdp_common.py +++ b/torch/distributed/_composable/fsdp/_fsdp_common.py @@ -1,7 +1,6 @@ # mypy: allow-untyped-defs import math import traceback - from dataclasses import dataclass from enum import auto, Enum from typing import Any, cast, List, Optional diff --git a/torch/distributed/_composable/fsdp/_fsdp_init.py b/torch/distributed/_composable/fsdp/_fsdp_init.py index 07fd45e9e3d71e..141addc6b71913 100644 --- a/torch/distributed/_composable/fsdp/_fsdp_init.py +++ b/torch/distributed/_composable/fsdp/_fsdp_init.py @@ -4,10 +4,10 @@ import torch import torch.distributed as dist import torch.nn as nn - from torch.distributed._tensor import DeviceMesh, DTensor, init_device_mesh from torch.distributed.device_mesh import _get_device_handle from torch.utils._python_dispatch import is_traceable_wrapper_subclass + from ._fsdp_common import _is_composable_with_fsdp, FSDPMeshInfo, HSDPMeshInfo from ._fsdp_state import _get_module_fsdp_state diff --git a/torch/distributed/_composable/fsdp/_fsdp_param.py b/torch/distributed/_composable/fsdp/_fsdp_param.py index c56dc79e266bb7..6e0e815f7a537e 100644 --- a/torch/distributed/_composable/fsdp/_fsdp_param.py +++ b/torch/distributed/_composable/fsdp/_fsdp_param.py @@ -7,12 +7,12 @@ import torch import torch._dynamo.compiled_autograd as ca import torch.nn as nn - from torch._prims_common import make_contiguous_strides_for from torch.distributed._functional_collectives import AsyncCollectiveTensor from torch.distributed._tensor import DTensor, Replicate, Shard from torch.distributed._tensor.device_mesh import _mesh_resources from torch.distributed._tensor.placement_types import DTensorSpec, Placement, TensorMeta + from ._fsdp_api import CPUOffloadPolicy, MixedPrecisionPolicy, OffloadPolicy from ._fsdp_common import ( _chunk_with_empty, @@ -24,6 +24,7 @@ HSDPMeshInfo, ) + """ [Note: FSDP tensors] FSDP considers the following tensors: diff --git a/torch/distributed/_composable/fsdp/_fsdp_param_group.py b/torch/distributed/_composable/fsdp/_fsdp_param_group.py index 06fa90e060e70d..6592a815bacfa6 100644 --- a/torch/distributed/_composable/fsdp/_fsdp_param_group.py +++ b/torch/distributed/_composable/fsdp/_fsdp_param_group.py @@ -1,6 +1,5 @@ # mypy: allow-untyped-defs import contextlib - from typing import Any, cast, Dict, List, NamedTuple, Optional, Set, Tuple import torch @@ -11,6 +10,7 @@ from torch.profiler import record_function from torch.utils._pytree import tree_flatten, tree_unflatten from torch.utils.hooks import RemovableHandle + from ._fsdp_api import MixedPrecisionPolicy, OffloadPolicy from ._fsdp_collectives import ( AllGatherResult, @@ -21,6 +21,7 @@ from ._fsdp_common import FSDPMeshInfo, HSDPMeshInfo, TrainingState from ._fsdp_param import FSDPParam, ParamModuleInfo, ShardedState + _ModuleToHandleDict = Dict[nn.Module, RemovableHandle] # for state dict diff --git a/torch/distributed/_composable/fsdp/_fsdp_state.py b/torch/distributed/_composable/fsdp/_fsdp_state.py index 79a09342704ff1..c6cdb2b29880bf 100644 --- a/torch/distributed/_composable/fsdp/_fsdp_state.py +++ b/torch/distributed/_composable/fsdp/_fsdp_state.py @@ -1,6 +1,5 @@ # mypy: allow-untyped-defs import functools - from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING import torch @@ -13,10 +12,12 @@ ) from torch.distributed.utils import _to_kwargs from torch.utils._pytree import tree_flatten, tree_map + from ._fsdp_api import MixedPrecisionPolicy from ._fsdp_common import _cast_fp_tensor, TrainingState from ._fsdp_param_group import FSDPCommContext, FSDPParamGroup + if TYPE_CHECKING: from ._fsdp_param import FSDPParam diff --git a/torch/distributed/_composable/fsdp/fully_shard.py b/torch/distributed/_composable/fsdp/fully_shard.py index 61b7878d467ff2..e8ab3466118bc7 100644 --- a/torch/distributed/_composable/fsdp/fully_shard.py +++ b/torch/distributed/_composable/fsdp/fully_shard.py @@ -1,6 +1,5 @@ # mypy: allow-untyped-defs import functools - from typing import Any, cast, Iterable, List, NoReturn, Optional, Union import torch diff --git a/torch/distributed/_composable/fully_shard.py b/torch/distributed/_composable/fully_shard.py index 950a034071a43d..06b121aef80a8e 100644 --- a/torch/distributed/_composable/fully_shard.py +++ b/torch/distributed/_composable/fully_shard.py @@ -8,7 +8,6 @@ from torch.distributed._composable_state import _get_module_state, _insert_module_state from torch.distributed.fsdp._common_utils import _FSDPState from torch.distributed.fsdp._dynamo_utils import _annotate_modules_for_dynamo - from torch.distributed.fsdp._init_utils import ( _init_buffer_state, _init_core_state, diff --git a/torch/distributed/_composable/replicate.py b/torch/distributed/_composable/replicate.py index 0cb4ea79bc7d1c..6ba70cf7bfc936 100644 --- a/torch/distributed/_composable/replicate.py +++ b/torch/distributed/_composable/replicate.py @@ -9,6 +9,7 @@ from .contract import _get_registry, contract + _ROOT_MODULE_PREFIX = "" diff --git a/torch/distributed/_cuda_p2p/__init__.py b/torch/distributed/_cuda_p2p/__init__.py index 1d3f24c80f08a4..a3998c8e1d3b47 100644 --- a/torch/distributed/_cuda_p2p/__init__.py +++ b/torch/distributed/_cuda_p2p/__init__.py @@ -1,15 +1,14 @@ # mypy: allow-untyped-defs from collections import defaultdict from contextlib import contextmanager - from functools import partial from typing import Callable, cast, Dict, List, Optional, Tuple, TYPE_CHECKING, Union import torch import torch.distributed._functional_collectives as funcol - import torch.distributed.distributed_c10d as c10d + if TYPE_CHECKING: from torch._C._distributed_c10d import _DistributedBackendOptions, Backend diff --git a/torch/distributed/_functional_collectives.py b/torch/distributed/_functional_collectives.py index 9ac89166b25fd4..82ca3cb8b07385 100644 --- a/torch/distributed/_functional_collectives.py +++ b/torch/distributed/_functional_collectives.py @@ -11,6 +11,7 @@ from . import _functional_collectives_impl as fun_col_impl + try: from torch.utils._cxx_pytree import tree_map_only except ImportError: @@ -1134,6 +1135,7 @@ def all_gather_inplace( reduce_scatter_tensor as legacy_reducescatter, ) + # This dict should contain sets of functions that dynamo is allowed to remap. # Functions in this set should accept the same args/kwargs 1:1 as their mapping. traceable_collective_remaps = { diff --git a/torch/distributed/_functional_collectives_impl.py b/torch/distributed/_functional_collectives_impl.py index c39cb4a9d50d10..4bd193d662bd6d 100644 --- a/torch/distributed/_functional_collectives_impl.py +++ b/torch/distributed/_functional_collectives_impl.py @@ -4,6 +4,7 @@ import torch import torch.distributed.distributed_c10d as c10d + """ This file contains the op impls for the legacy (c10d_functional) functional collectives. These impls simply call into the native (_c10d_functional) functional collectives. diff --git a/torch/distributed/_sharded_tensor/__init__.py b/torch/distributed/_sharded_tensor/__init__.py index 6c6694cfb08139..5e6f4d2a1a6ec4 100644 --- a/torch/distributed/_sharded_tensor/__init__.py +++ b/torch/distributed/_sharded_tensor/__init__.py @@ -1,11 +1,12 @@ # Keep old package for BC purposes, this file should be removed once # everything moves to the `torch.distributed._shard` package. import sys -import torch import warnings +import torch from torch.distributed._shard.sharded_tensor import * # noqa: F403 + with warnings.catch_warnings(): warnings.simplefilter("always") warnings.warn( @@ -15,4 +16,6 @@ stacklevel=2, ) -sys.modules['torch.distributed._sharded_tensor'] = torch.distributed._shard.sharded_tensor +sys.modules[ + "torch.distributed._sharded_tensor" +] = torch.distributed._shard.sharded_tensor diff --git a/torch/distributed/_sharding_spec/__init__.py b/torch/distributed/_sharding_spec/__init__.py index 21c56d5dc849eb..c74dd3633e0f5e 100644 --- a/torch/distributed/_sharding_spec/__init__.py +++ b/torch/distributed/_sharding_spec/__init__.py @@ -1,11 +1,12 @@ # Keep old package for BC purposes, this file should be removed once # everything moves to the `torch.distributed._shard` package. import sys -import torch import warnings +import torch from torch.distributed._shard.sharding_spec import * # noqa: F403 + with warnings.catch_warnings(): warnings.simplefilter("always") warnings.warn( @@ -16,4 +17,6 @@ ) import torch.distributed._shard.sharding_spec as _sharding_spec -sys.modules['torch.distributed._sharding_spec'] = _sharding_spec + + +sys.modules["torch.distributed._sharding_spec"] = _sharding_spec diff --git a/torch/distributed/_state_dict_utils.py b/torch/distributed/_state_dict_utils.py index 2f9f0555be641a..cb9def721686cf 100644 --- a/torch/distributed/_state_dict_utils.py +++ b/torch/distributed/_state_dict_utils.py @@ -22,6 +22,7 @@ import torch.nn.functional as F from torch.distributed._functional_collectives import AsyncCollectiveTensor + if dist.is_available() or TYPE_CHECKING: from torch.distributed import distributed_c10d from torch.distributed._shard.sharded_tensor import ShardedTensor diff --git a/torch/distributed/_tools/memory_tracker.py b/torch/distributed/_tools/memory_tracker.py index 10f70c9ce18e79..e4d8aa6e762b8f 100644 --- a/torch/distributed/_tools/memory_tracker.py +++ b/torch/distributed/_tools/memory_tracker.py @@ -1,24 +1,14 @@ # mypy: allow-untyped-defs +import operator +import pickle from collections import defaultdict - from itertools import chain - -import pickle - -from typing import ( - Any, - Callable, - Dict, - List, - no_type_check, - Sequence, - TYPE_CHECKING, -) +from typing import Any, Callable, Dict, List, no_type_check, Sequence, TYPE_CHECKING import torch import torch.nn as nn from torch.utils._python_dispatch import TorchDispatchMode -import operator + if TYPE_CHECKING: from torch.utils.hooks import RemovableHandle @@ -234,6 +224,7 @@ def load(self, path: str) -> None: def _create_pre_forward_hook(self, name: str) -> Callable: """Prefix operator name with current module and 'forward', and insert 'fw_start' marker at forward pass start.""" + def _pre_forward_hook(module: nn.Module, inputs: Any) -> None: self._cur_module_name = f"{name}.forward" if ( diff --git a/torch/distributed/c10d_logger.py b/torch/distributed/c10d_logger.py index c1cc67b40681b1..2c92176c53eb2a 100644 --- a/torch/distributed/c10d_logger.py +++ b/torch/distributed/c10d_logger.py @@ -15,9 +15,9 @@ import torch import torch.distributed as dist - from torch.distributed.logging_handlers import _log_handlers + __all__: List[str] = [] _DEFAULT_DESTINATION = "default" @@ -36,7 +36,9 @@ def _get_or_create_logger(destination: str = _DEFAULT_DESTINATION) -> logging.Lo return logger -def _get_logging_handler(destination: str = _DEFAULT_DESTINATION) -> Tuple[logging.Handler, str]: +def _get_logging_handler( + destination: str = _DEFAULT_DESTINATION, +) -> Tuple[logging.Handler, str]: log_handler = _log_handlers[destination] log_handler_name = type(log_handler).__name__ return (log_handler, log_handler_name) @@ -69,8 +71,10 @@ def _get_msg_dict(func_name, *args, **kwargs) -> Dict[str, Any]: } return msg_dict -_T = TypeVar('_T') -_P = ParamSpec('_P') + +_T = TypeVar("_T") +_P = ParamSpec("_P") + def _exception_logger(func: Callable[_P, _T]) -> Callable[_P, _T]: @functools.wraps(func) diff --git a/torch/distributed/collective_utils.py b/torch/distributed/collective_utils.py index ed6c93078299a4..78199e7a26f22c 100644 --- a/torch/distributed/collective_utils.py +++ b/torch/distributed/collective_utils.py @@ -14,8 +14,10 @@ import torch.distributed as dist + T = TypeVar("T") + @dataclass class SyncPayload(Generic[T]): stage_name: Optional[str] @@ -23,6 +25,7 @@ class SyncPayload(Generic[T]): payload: T exception: Optional[Exception] = None + def broadcast( data_or_fn: Union[T, Callable[[], T]], *, @@ -55,10 +58,12 @@ def broadcast( """ if not success and data_or_fn is not None: - raise AssertionError("Data or Function is expected to be None if not successful") + raise AssertionError( + "Data or Function is expected to be None if not successful" + ) payload: Optional[T] = None - exception : Optional[Exception] = None + exception: Optional[Exception] = None # if no pg is passed then execute if rank is 0 if (pg is None and rank == 0) or (pg is not None and pg.rank() == rank): # determine if it is an executable function or data payload only @@ -119,7 +124,7 @@ def all_gather( >> all_ids = all_gather(data_or_fn=allocate_id, pg=ext_pg.my_pg) """ payload: Optional[T] = None - exception : Optional[Exception] = None + exception: Optional[Exception] = None success = True # determine if it is an executable function or data payload only if callable(data_or_fn): @@ -161,7 +166,8 @@ def all_gather( if len(exception_list) > 0: raise RuntimeError( # type: ignore[misc] - error_msg, exception_list) from exception_list[0] + error_msg, exception_list + ) from exception_list[0] return ret_list else: if not sync_obj.success: diff --git a/torch/distributed/constants.py b/torch/distributed/constants.py index 47b1f90e406c5e..b3754043644b8c 100644 --- a/torch/distributed/constants.py +++ b/torch/distributed/constants.py @@ -1,8 +1,10 @@ -from torch._C._distributed_c10d import _DEFAULT_PG_TIMEOUT from datetime import timedelta from typing import Optional -__all__ = ['default_pg_timeout', 'default_pg_nccl_timeout'] +from torch._C._distributed_c10d import _DEFAULT_PG_TIMEOUT + + +__all__ = ["default_pg_timeout", "default_pg_nccl_timeout"] # Default process group wide timeout, if applicable. # This only applies to the non-nccl backends @@ -16,6 +18,7 @@ try: from torch._C._distributed_c10d import _DEFAULT_PG_NCCL_TIMEOUT + default_pg_nccl_timeout: Optional[timedelta] = _DEFAULT_PG_NCCL_TIMEOUT except ImportError: # if C++ NCCL support is not compiled, we don't have access to the default nccl value. diff --git a/torch/distributed/device_mesh.py b/torch/distributed/device_mesh.py index e46356a3689422..a1fee846d2545c 100644 --- a/torch/distributed/device_mesh.py +++ b/torch/distributed/device_mesh.py @@ -6,10 +6,9 @@ from typing import Dict, List, Optional, Tuple, TYPE_CHECKING, Union import torch - from torch.distributed import is_available +from torch.utils._typing_utils import not_none -from ..utils._typing_utils import not_none __all__ = ["init_device_mesh", "DeviceMesh"] diff --git a/torch/distributed/distributed_c10d.py b/torch/distributed/distributed_c10d.py index d44c3733a214e6..91e4cf9f540c8d 100644 --- a/torch/distributed/distributed_c10d.py +++ b/torch/distributed/distributed_c10d.py @@ -1,11 +1,11 @@ # mypy: allow-untyped-defs """Distributed Collective Communication (c10d).""" -import itertools import collections.abc import contextlib import hashlib import io +import itertools import logging import os import pickle @@ -14,19 +14,26 @@ import warnings from collections import namedtuple from datetime import timedelta -from typing import Any, Callable, Dict, Optional, Tuple, Union, List, TYPE_CHECKING +from typing import Any, Callable, Dict, List, Optional, Tuple, TYPE_CHECKING, Union from typing_extensions import deprecated import torch +from torch._C import _DistStoreError as DistStoreError from torch._C._distributed_c10d import ( + _DistributedBackendOptions, + _register_process_group, + _resolve_process_group, + _unregister_all_process_groups, + _unregister_process_group, AllgatherOptions, AllreduceCoalescedOptions, AllreduceOptions, AllToAllOptions, - _DistributedBackendOptions, BarrierOptions, BroadcastOptions, + DebugLevel, GatherOptions, + get_debug_level, PrefixStore, ProcessGroup, ReduceOp, @@ -34,41 +41,88 @@ ReduceScatterOptions, ScatterOptions, Store, - DebugLevel, - get_debug_level, Work, - _register_process_group, - _resolve_process_group, - _unregister_all_process_groups, - _unregister_process_group, ) from torch._utils_internal import set_pytorch_distributed_envs_from_justknobs -from .constants import default_pg_timeout, default_pg_nccl_timeout +from torch.utils._typing_utils import not_none + from .c10d_logger import _exception_logger, _time_logger +from .constants import default_pg_nccl_timeout, default_pg_timeout from .rendezvous import register_rendezvous_handler, rendezvous # noqa: F401 -from ..utils._typing_utils import not_none -DistStoreError = torch._C._DistStoreError + __all__ = [ - 'Backend', 'BackendConfig', 'GroupMember', 'P2POp', 'all_gather', 'all_gather_coalesced', - 'all_gather_object', 'all_reduce', - 'all_reduce_coalesced', 'all_to_all', - 'all_to_all_single', 'barrier', 'batch_isend_irecv', 'broadcast', 'send_object_list', - 'recv_object_list', 'broadcast_object_list', 'destroy_process_group', - 'gather', 'gather_object', 'get_backend_config', 'get_backend', 'get_rank', - 'get_world_size', 'get_pg_count', 'group', 'init_process_group', 'irecv', - 'is_gloo_available', 'is_initialized', 'is_mpi_available', 'is_backend_available', - 'is_nccl_available', 'is_torchelastic_launched', 'is_ucc_available', - 'isend', 'monitored_barrier', 'new_group', 'new_subgroups', - 'new_subgroups_by_enumeration', 'recv', 'reduce', - 'reduce_scatter', 'scatter', - 'scatter_object_list', 'send', 'supports_complex', - 'AllreduceCoalescedOptions', 'AllreduceOptions', 'AllToAllOptions', - 'BarrierOptions', 'BroadcastOptions', 'GatherOptions', 'PrefixStore', - 'ProcessGroup', 'ReduceOp', 'ReduceOptions', 'ReduceScatterOptions', - 'ScatterOptions', 'Store', 'DebugLevel', 'get_debug_level', 'Work', - 'default_pg_timeout', 'get_group_rank', 'get_global_rank', 'get_process_group_ranks', - 'reduce_op', 'all_gather_into_tensor', 'reduce_scatter_tensor', 'get_node_local_rank', + "Backend", + "BackendConfig", + "GroupMember", + "P2POp", + "all_gather", + "all_gather_coalesced", + "all_gather_object", + "all_reduce", + "all_reduce_coalesced", + "all_to_all", + "all_to_all_single", + "barrier", + "batch_isend_irecv", + "broadcast", + "send_object_list", + "recv_object_list", + "broadcast_object_list", + "destroy_process_group", + "gather", + "gather_object", + "get_backend_config", + "get_backend", + "get_rank", + "get_world_size", + "get_pg_count", + "group", + "init_process_group", + "irecv", + "is_gloo_available", + "is_initialized", + "is_mpi_available", + "is_backend_available", + "is_nccl_available", + "is_torchelastic_launched", + "is_ucc_available", + "isend", + "monitored_barrier", + "new_group", + "new_subgroups", + "new_subgroups_by_enumeration", + "recv", + "reduce", + "reduce_scatter", + "scatter", + "scatter_object_list", + "send", + "supports_complex", + "AllreduceCoalescedOptions", + "AllreduceOptions", + "AllToAllOptions", + "BarrierOptions", + "BroadcastOptions", + "GatherOptions", + "PrefixStore", + "ProcessGroup", + "ReduceOp", + "ReduceOptions", + "ReduceScatterOptions", + "ScatterOptions", + "Store", + "DebugLevel", + "get_debug_level", + "Work", + "default_pg_timeout", + "get_group_rank", + "get_global_rank", + "get_process_group_ranks", + "reduce_op", + "all_gather_into_tensor", + "reduce_scatter_tensor", + "get_node_local_rank", ] _MPI_AVAILABLE = True @@ -79,6 +133,7 @@ _pickler = pickle.Pickler _unpickler = pickle.Unpickler + # Change __module__ of all imported types from torch._C._distributed_c10d that are public def _export_c_types() -> None: _public_types_to_change_module = [ @@ -97,22 +152,25 @@ def _export_c_types() -> None: Store, DebugLevel, get_debug_level, - Work + Work, ] for type in _public_types_to_change_module: type.__module__ = "torch.distributed.distributed_c10d" + + _export_c_types() try: from torch._C._distributed_c10d import ProcessGroupMPI + ProcessGroupMPI.__module__ = "torch.distributed.distributed_c10d" __all__ += ["ProcessGroupMPI"] except ImportError: _MPI_AVAILABLE = False try: - from torch._C._distributed_c10d import ProcessGroupNCCL - from torch._C._distributed_c10d import ProcessGroupCudaP2P + from torch._C._distributed_c10d import ProcessGroupCudaP2P, ProcessGroupNCCL + ProcessGroupNCCL.__module__ = "torch.distributed.distributed_c10d" ProcessGroupCudaP2P.__module__ = "torch.distributed.distributed_c10d" __all__ += ["ProcessGroupNCCL", "ProcessGroupCudaP2P"] @@ -120,8 +178,8 @@ def _export_c_types() -> None: _NCCL_AVAILABLE = False try: - from torch._C._distributed_c10d import ProcessGroupGloo - from torch._C._distributed_c10d import _ProcessGroupWrapper + from torch._C._distributed_c10d import _ProcessGroupWrapper, ProcessGroupGloo + ProcessGroupGloo.__module__ = "torch.distributed.distributed_c10d" __all__ += ["ProcessGroupGloo"] except ImportError: @@ -129,6 +187,7 @@ def _export_c_types() -> None: try: from torch._C._distributed_c10d import ProcessGroupUCC + ProcessGroupUCC.__module__ = "torch.distributed.distributed_c10d" __all__ += ["ProcessGroupUCC"] except ImportError: @@ -191,20 +250,20 @@ class Backend(str): backend_list = [UNDEFINED, GLOO, NCCL, UCC, MPI] default_device_backend_map: Dict[str, str] = { - 'cpu' : GLOO, - 'cuda' : NCCL, + "cpu": GLOO, + "cuda": NCCL, } backend_capability: Dict[str, List[str]] = { - GLOO : ["cpu", "cuda"], - NCCL : ["cuda"], - UCC : ["cpu", "cuda"], - MPI : ["cpu", "cuda"], + GLOO: ["cpu", "cuda"], + NCCL: ["cuda"], + UCC: ["cpu", "cuda"], + MPI: ["cpu", "cuda"], } backend_type_map: Dict[str, ProcessGroup.BackendType] = { UNDEFINED: ProcessGroup.BackendType.UNDEFINED, - GLOO : ProcessGroup.BackendType.GLOO, + GLOO: ProcessGroup.BackendType.GLOO, NCCL: ProcessGroup.BackendType.NCCL, UCC: ProcessGroup.BackendType.UCC, } @@ -220,7 +279,13 @@ def __new__(cls, name: str): return value @classmethod - def register_backend(cls, name, func, extended_api=False, devices: Optional[Union[str, List[str]]] = None) -> None: + def register_backend( + cls, + name, + func, + extended_api=False, + devices: Optional[Union[str, List[str]]] = None, + ) -> None: """ Register a new backend with the given name and instantiating function. @@ -247,19 +312,19 @@ def register_backend(cls, name, func, extended_api=False, devices: Optional[Unio """ # Allow UCC plugin if Pytorch is not built with native support. # TODO: remove this exception once UCC plugin is fully deprecated. - if (name != Backend.UCC or (name == Backend.UCC and is_ucc_available())): - assert not hasattr(Backend, name.upper()), ( - f"{name.upper()} c10d backend already exist" - ) - assert name.upper() not in Backend._plugins, ( - f"{name.upper()} c10d backend creator function already exist" - ) + if name != Backend.UCC or (name == Backend.UCC and is_ucc_available()): + assert not hasattr( + Backend, name.upper() + ), f"{name.upper()} c10d backend already exist" + assert ( + name.upper() not in Backend._plugins + ), f"{name.upper()} c10d backend creator function already exist" setattr(Backend, name.upper(), name.lower()) Backend.backend_list.append(name.lower()) if devices is not None: for device in devices: - if device != 'cpu' and device != 'cuda': + if device != "cpu" and device != "cuda": Backend.default_device_backend_map[device] = name.lower() Backend.backend_type_map[name.lower()] = ProcessGroup.BackendType.CUSTOM @@ -281,6 +346,7 @@ def register_backend(cls, name, func, extended_api=False, devices: Optional[Unio Backend._plugins[name.upper()] = Backend._BackendPlugin(func, extended_api) + class BackendConfig: """Backend configuration class.""" @@ -294,7 +360,10 @@ def __init__(self, backend: Backend): # supported since PyTorch 2.0 for device, default_backend in Backend.default_device_backend_map.items(): if is_backend_available(default_backend): - if default_backend == Backend.NCCL and not torch.cuda.is_available(): + if ( + default_backend == Backend.NCCL + and not torch.cuda.is_available() + ): continue self.device_backend_map[device] = Backend(default_backend) elif backend.lower() in Backend.backend_list: @@ -316,12 +385,16 @@ def __init__(self, backend: Backend): for device_backend_pair_str in backend.lower().split(","): device_backend_pair = device_backend_pair_str.split(":") if len(device_backend_pair) != 2: - raise ValueError(f"Invalid device:backend pairing: \ - {device_backend_pair_str}. {backend_str_error_message}") + raise ValueError( + f"Invalid device:backend pairing: \ + {device_backend_pair_str}. {backend_str_error_message}" + ) device, backend = device_backend_pair if device in self.device_backend_map: - raise ValueError(f"Duplicate device type {device} \ - in backend string: {backend}. {backend_str_error_message}") + raise ValueError( + f"Duplicate device type {device} \ + in backend string: {backend}. {backend_str_error_message}" + ) self.device_backend_map[device] = Backend(backend) else: # User specified a single backend name whose device capability is @@ -334,23 +407,24 @@ def __init__(self, backend: Backend): ) backend_val = Backend(backend) self.device_backend_map = { - "cpu" : backend_val, - "cuda" : backend_val, - "xpu" : backend_val, + "cpu": backend_val, + "cuda": backend_val, + "xpu": backend_val, } - logger.info( - "Using backend config: %s", self.device_backend_map - ) + logger.info("Using backend config: %s", self.device_backend_map) def __repr__(self): """Return all the device:backend pairs separated by commas.""" - return ",".join(f"{device}:{backend}" for device, backend in self.device_backend_map.items()) + return ",".join( + f"{device}:{backend}" for device, backend in self.device_backend_map.items() + ) def get_device_backend_map(self) -> Dict[str, Backend]: """Return backend map of the device.""" return self.device_backend_map + class _reduce_op: r""" Deprecated enum-like class. @@ -397,8 +471,14 @@ class P2POp: tag (int, optional): Tag to match send with recv. """ - def __init__(self, op: Callable, tensor: torch.Tensor, peer: int, - group: Optional[ProcessGroup] = None, tag: int = 0): + def __init__( + self, + op: Callable, + tensor: torch.Tensor, + peer: int, + group: Optional[ProcessGroup] = None, + tag: int = 0, + ): """Init.""" self.op = op self.tensor = tensor @@ -406,8 +486,14 @@ def __init__(self, op: Callable, tensor: torch.Tensor, peer: int, self.group = group self.tag = tag - def __new__(cls, op: Callable, tensor: torch.Tensor, peer: int, - group: Optional[ProcessGroup] = None, tag: int = 0): + def __new__( + cls, + op: Callable, + tensor: torch.Tensor, + peer: int, + group: Optional[ProcessGroup] = None, + tag: int = 0, + ): """Create and return a new instance of the class.""" _check_op(op) _check_single_tensor(tensor, "tensor") @@ -415,7 +501,9 @@ def __new__(cls, op: Callable, tensor: torch.Tensor, peer: int, def __repr__(self): my_group_rank = get_rank(self.group) - peer_group_rank = get_group_rank(self.group, self.peer) if self.group else self.peer + peer_group_rank = ( + get_group_rank(self.group, self.peer) if self.group else self.peer + ) op_name = self.op.__name__ group_name = self.group.group_name if self.group else "default_pg" if "send" in op_name: @@ -429,6 +517,7 @@ def __repr__(self): return f"P2POp({op_name} pg={group_name}, s={s}, d={d}, {self.tensor.shape}, {self.tensor.dtype})" + class _CollOp: """ A class to capture collective operations. @@ -441,8 +530,14 @@ class _CollOp: root (int, optional): root of broadcast or reduce. """ - def __init__(self, op: Callable, tensor: torch.Tensor, dst_tensor: Optional[torch.Tensor] = None, - redop: Optional[ReduceOp] = None, root: Optional[int] = None): + def __init__( + self, + op: Callable, + tensor: torch.Tensor, + dst_tensor: Optional[torch.Tensor] = None, + redop: Optional[ReduceOp] = None, + root: Optional[int] = None, + ): self.op = op self.tensor = tensor self.dst_tensor = dst_tensor @@ -462,6 +557,7 @@ def __init__(self, op: Callable, tensor: torch.Tensor, dst_tensor: Optional[torc _pg_to_tag: Dict[ProcessGroup, str] = {} _backend: Optional[str] = None + class _World: """ Container class for c10d process group state. @@ -597,6 +693,7 @@ def pg_config_info(self) -> List[Dict[str, Any]]: _world = _World() """Holds the singleton instance of ``_World`` used by c10. Experimental extension point to override it""" + class _WorldMeta(type): """ Meta class of ``group`` and ``GroupMember``. @@ -613,11 +710,13 @@ def WORLD(cls) -> Optional[ProcessGroup]: def WORLD(cls, pg: Optional[ProcessGroup]): _world.default_pg = pg + class group(metaclass=_WorldMeta): """Group class. Placeholder.""" pass + class GroupMember(metaclass=_WorldMeta): """Group member class.""" @@ -630,23 +729,28 @@ def _get_default_timeout(backend: Backend) -> timedelta: if not isinstance(default_pg_nccl_timeout, timedelta): # TODO moco benchmark on CPU initializes pgnccl backend today, triggered this assert in CI before it was # changed to be a warning. We should fix the moco model. - warnings.warn("Attempted to get default timeout for nccl backend, but NCCL support is not compiled") + warnings.warn( + "Attempted to get default timeout for nccl backend, but NCCL support is not compiled" + ) return default_pg_timeout return default_pg_nccl_timeout else: return default_pg_timeout + def _check_valid_timeout(timeout: Any) -> None: if not isinstance(timeout, timedelta): raise TypeError( f"Expected timeout argument to be of type datetime.timedelta, got {timeout}" ) + # Default process group state _default_pg_init_method: Optional[str] = None STORE_BASED_BARRIER_PREFIX = "store_based_barrier_key" + def _get_pg_default_device(group: Optional[ProcessGroup] = None) -> torch.device: """ Return the device to use with ``group`` for control flow usage (object collectives, barrier). @@ -711,14 +815,20 @@ def _get_pg_default_device(group: Optional[ProcessGroup] = None) -> torch.device _world.pg_default_device[group] = devices[0] logger.info( - "Using device %s for object " - "collectives.", _world.pg_default_device[group] + "Using device %s for object " "collectives.", _world.pg_default_device[group] ) return _world.pg_default_device[group] @_time_logger -def _store_based_barrier(rank, store, group_name, rendezvous_count, timeout, logging_interval=timedelta(seconds=10)) -> None: +def _store_based_barrier( + rank, + store, + group_name, + rendezvous_count, + timeout, + logging_interval=timedelta(seconds=10), +) -> None: """ Store based barrier for synchronizing processes. @@ -755,7 +865,12 @@ def _store_based_barrier(rank, store, group_name, rendezvous_count, timeout, log logger.debug( "Waiting in store based barrier to initialize process group for " "rank: %s, key: %s (world_size=%s, num_workers_joined=%s, timeout=%s error=%s)", - rank, store_key, world_size, worker_count, timeout, e + rank, + store_key, + world_size, + worker_count, + timeout, + e, ) if timedelta(seconds=(time.time() - start)) > timeout: @@ -766,7 +881,10 @@ def _store_based_barrier(rank, store, group_name, rendezvous_count, timeout, log ) logger.info( - "Rank %s: Completed store-based barrier for key:%s with %s nodes.", rank, store_key, world_size + "Rank %s: Completed store-based barrier for key:%s with %s nodes.", + rank, + store_key, + world_size, ) @@ -803,13 +921,16 @@ def get_group_rank(group: ProcessGroup, global_rank: int) -> int: if group is GroupMember.WORLD: return global_rank if group not in _world.pg_group_ranks: - raise ValueError(f"Group {group} is not registered, please create group with torch.distributed.new_group API") + raise ValueError( + f"Group {group} is not registered, please create group with torch.distributed.new_group API" + ) group_ranks = _world.pg_group_ranks[group] if global_rank not in group_ranks: raise ValueError(f"Global rank {global_rank} is not part of group {group}") return group_ranks[global_rank] + def get_global_rank(group: ProcessGroup, group_rank: int) -> int: """ Translate a group rank into a global rank. @@ -828,7 +949,9 @@ def get_global_rank(group: ProcessGroup, group_rank: int) -> int: if group is GroupMember.WORLD: return group_rank if group not in _world.pg_group_ranks: - raise ValueError(f"Group {group} is not registered, please create group with torch.distributed.new_group API") + raise ValueError( + f"Group {group} is not registered, please create group with torch.distributed.new_group API" + ) for rank, grp_rank in _world.pg_group_ranks[group].items(): if grp_rank == group_rank: return rank @@ -858,6 +981,7 @@ def get_process_group_ranks(group: ProcessGroup) -> List[int]: """ return list(_world.pg_group_ranks[group].keys()) + def _get_group_size(group) -> int: """Get a given group's world size.""" if group is GroupMember.WORLD or group is None: @@ -906,13 +1030,16 @@ def _check_tensor_list(param, param_name) -> None: def _as_iterable(obj) -> collections.abc.Iterable: return obj if isinstance(obj, list) else (obj,) + def _ensure_all_tensors_same_dtype(*tensors) -> None: last_dtype = None for tensor in itertools.chain.from_iterable(map(_as_iterable, tensors)): tensor_dtype = tensor.dtype # Mixing complex and its element type is allowed if tensor_dtype.is_complex: - tensor_dtype = torch.float32 if tensor_dtype == torch.complex64 else torch.complex128 + tensor_dtype = ( + torch.float32 if tensor_dtype == torch.complex64 else torch.complex128 + ) if last_dtype is None: last_dtype = tensor_dtype @@ -1049,6 +1176,7 @@ def _update_default_pg(pg) -> None: rank = pg.rank() if pg is not None and pg != GroupMember.NON_GROUP_MEMBER else -1 torch._C._distributed_c10d._set_global_rank(rank) + def get_backend_config(group: Optional[ProcessGroup] = None) -> str: """ Return the backend configuration of the given process group. @@ -1071,6 +1199,7 @@ def get_backend_config(group: Optional[ProcessGroup] = None) -> str: backend_config = _world.pg_backend_config.get(pg) return str(not_none(backend_config)) + def get_backend(group: Optional[ProcessGroup] = None) -> Backend: """ Return the backend of the given process group. @@ -1093,6 +1222,7 @@ def get_backend(group: Optional[ProcessGroup] = None) -> Backend: pg_store = _world.pg_map[pg] if pg in _world.pg_map else None return Backend(not_none(pg_store)[0]) + def _get_process_group_uid(pg: ProcessGroup) -> int: backend = None try: @@ -1103,6 +1233,7 @@ def _get_process_group_uid(pg: ProcessGroup) -> int: return backend.uid return -1 + def _get_pg_config(group: Optional[ProcessGroup] = None) -> Dict[str, Any]: """ Return the pg configuration of the given process group. @@ -1120,6 +1251,7 @@ def _get_pg_config(group: Optional[ProcessGroup] = None) -> Dict[str, Any]: "ranks": get_process_group_ranks(pg), } + def _get_all_pg_configs() -> List[Dict[str, Any]]: """ Return the pg configuration of all the process groups. @@ -1130,6 +1262,7 @@ def _get_all_pg_configs() -> List[Dict[str, Any]]: config_info.append(_get_pg_config(pg)) return config_info + def get_pg_count() -> int: """ Return the number of process groups. @@ -1137,6 +1270,7 @@ def get_pg_count() -> int: """ return _world.group_count + def get_node_local_rank(fallback_rank: Optional[int] = None) -> int: """ Return the local rank of the current process relative to the node. @@ -1162,6 +1296,7 @@ def get_node_local_rank(fallback_rank: Optional[int] = None) -> int: "assuming you are not running in a multi-device context and want the code to run locally instead." ) + def _set_pg_timeout(timeout: timedelta, group: Optional[ProcessGroup] = None) -> None: """ Set the timeout for the given process group when users want to use a different timeout instead of @@ -1349,7 +1484,14 @@ def init_process_group( ) default_pg, _ = _new_process_group_helper( - -1, -1, [], backend, None, group_name, timeout=timeout, group_desc="default_pg" + -1, + -1, + [], + backend, + None, + group_name, + timeout=timeout, + group_desc="default_pg", ) _update_default_pg(default_pg) else: @@ -1375,7 +1517,7 @@ def init_process_group( pg_options=pg_options, timeout=timeout, device_id=device_id, - group_desc="default_pg" + group_desc="default_pg", ) _update_default_pg(default_pg) @@ -1394,7 +1536,9 @@ def _distributed_excepthook(*args): finally: sys.stderr = old_stderr msg = buf.getvalue() - msg = "\n".join(f"{excepthook_prefix}: {s}" if s != "" else "" for s in msg.split("\n")) + msg = "\n".join( + f"{excepthook_prefix}: {s}" if s != "" else "" for s in msg.split("\n") + ) sys.stderr.write(msg) sys.stderr.flush() @@ -1421,6 +1565,7 @@ def _distributed_excepthook(*args): # default devices and messes up NCCL internal state. _store_based_barrier(rank, store, group_name, world_size, timeout) + def _get_split_source(pg): split_from = None if pg.bound_device_id: @@ -1442,6 +1587,7 @@ def _get_split_source(pg): return split_from + def _shutdown_backend(pg): """ Try to shut down the backend of a process group. @@ -1453,10 +1599,13 @@ def _shutdown_backend(pg): backend = pg._get_backend(torch.device("cuda")) except RuntimeError: pass - if is_nccl_available() and isinstance(backend, (ProcessGroupNCCL, ProcessGroupCudaP2P)): + if is_nccl_available() and isinstance( + backend, (ProcessGroupNCCL, ProcessGroupCudaP2P) + ): # explictly call shutdown to ensure that NCCL resources are released backend._shutdown() + def _new_process_group_helper( group_size, group_rank, @@ -1487,9 +1636,11 @@ def _new_process_group_helper( "created, please use a different group name" ) - if device_id is not None and (device_id.index is None or device_id.type != 'cuda'): - raise ValueError("init_process_group device_id parameter must be a cuda device with an " - "id, e.g. cuda:0, not just cuda or cpu") + if device_id is not None and (device_id.index is None or device_id.type != "cuda"): + raise ValueError( + "init_process_group device_id parameter must be a cuda device with an " + "id, e.g. cuda:0, not just cuda or cpu" + ) # Note: _new_process_group_helper is only called from init_process_group, which always provides a timeout value _check_valid_timeout(timeout) @@ -1514,8 +1665,10 @@ def _new_process_group_helper( # ranks_. We can only know this if the group we are making is the # entire world or if we have bound a device id to the world (which # causes early connection initialization). - if (is_initialized() and - (len(global_ranks_in_group) == _get_default_group().size() or _get_default_group().bound_device_id)): + if is_initialized() and ( + len(global_ranks_in_group) == _get_default_group().size() + or _get_default_group().bound_device_id + ): split_from = _get_split_source(_get_default_group()) else: split_from = None @@ -1538,7 +1691,9 @@ def _new_process_group_helper( prefix_store = PrefixStore(f"{group_name}/", store) base_pg_options = ProcessGroup.Options(backend=str(backend)) base_pg_options._timeout = timeout - pg: ProcessGroup = ProcessGroup(prefix_store, group_rank, group_size, base_pg_options) + pg: ProcessGroup = ProcessGroup( + prefix_store, group_rank, group_size, base_pg_options + ) if device_id: pg.bound_device_id = device_id backend_config = BackendConfig(backend) @@ -1561,12 +1716,19 @@ def _new_process_group_helper( return GroupMember.NON_GROUP_MEMBER, None # create new process group with accurate rank and size if pg.rank() == -1 and pg.size() == -1: - pg = ProcessGroup(backend_prefix_store, backend_class.rank(), backend_class.size(), base_pg_options) + pg = ProcessGroup( + backend_prefix_store, + backend_class.rank(), + backend_class.size(), + base_pg_options, + ) elif backend_str == Backend.GLOO: # TODO: remove this check after lazy initialization is supported # if pg_options is not None: # raise RuntimeError("GLOO options not supported") - backend_class = ProcessGroupGloo(backend_prefix_store, group_rank, group_size, timeout=timeout) + backend_class = ProcessGroupGloo( + backend_prefix_store, group_rank, group_size, timeout=timeout + ) backend_type = ProcessGroup.BackendType.GLOO elif backend_str == Backend.NCCL: if not is_nccl_available(): @@ -1592,19 +1754,22 @@ def _new_process_group_helper( pg_options.global_ranks_in_group = global_ranks_in_group pg_options.group_name = group_name backend_class = ProcessGroupNCCL( - backend_prefix_store, group_rank, group_size, pg_options) + backend_prefix_store, group_rank, group_size, pg_options + ) backend_type = ProcessGroup.BackendType.NCCL elif backend_str == Backend.UCC and is_ucc_available(): # TODO: once UCC plugin is fully deprecated, remove # is_ucc_available() from above elif-condition and raise # RuntimeError if is_ucc_available() returns false. - backend_class = ProcessGroupUCC(backend_prefix_store, group_rank, group_size, timeout=timeout) + backend_class = ProcessGroupUCC( + backend_prefix_store, group_rank, group_size, timeout=timeout + ) backend_type = ProcessGroup.BackendType.UCC else: - assert backend_str.upper() in Backend._plugins, ( - f"Unknown c10d backend type {backend_str.upper()}" - ) + assert ( + backend_str.upper() in Backend._plugins + ), f"Unknown c10d backend type {backend_str.upper()}" backend_plugin = Backend._plugins[backend_str.upper()] creator_fn = backend_plugin.creator_fn @@ -1612,7 +1777,9 @@ def _new_process_group_helper( backend_type = ProcessGroup.BackendType.CUSTOM if not extended_api: - backend_class = creator_fn(backend_prefix_store, group_rank, group_size, timeout) + backend_class = creator_fn( + backend_prefix_store, group_rank, group_size, timeout + ) else: dist_backend_opts = _DistributedBackendOptions() dist_backend_opts.store = backend_prefix_store @@ -1640,7 +1807,10 @@ def _new_process_group_helper( break # Process group wrapper initialization for supported PGs when TORCH_DISTRIBUTED_DEBUG is set - if backend_str in [Backend.GLOO, Backend.NCCL, Backend.UCC] or backend_str.upper() in Backend._plugins: + if ( + backend_str in [Backend.GLOO, Backend.NCCL, Backend.UCC] + or backend_str.upper() in Backend._plugins + ): # In debug mode and if GLOO is available, wrap in a wrapper PG that # enables enhanced collective checking for debuggability. if get_debug_level() == DebugLevel.DETAIL: @@ -1698,6 +1868,7 @@ def _new_process_group_helper( _world.pg_to_tag[pg] = pg_tag return pg, prefix_store + def destroy_process_group(group: Optional[ProcessGroup] = None): """ Destroy a given process group, and deinitialize the distributed package. @@ -1736,7 +1907,9 @@ def destroy_process_group(group: Optional[ProcessGroup] = None): if group is None or group == GroupMember.WORLD: # shutdown all backends in the order of pg names. shutting down in order because # ncclCommAbort() was a 'collective' call in some versions of NCCL. - for pg_to_shutdown in sorted(_world.pg_names, key=lambda x: _world.pg_names[x], reverse=True): + for pg_to_shutdown in sorted( + _world.pg_names, key=lambda x: _world.pg_names[x], reverse=True + ): _shutdown_backend(pg_to_shutdown) _update_default_pg(None) @@ -1832,7 +2005,9 @@ def get_world_size(group: Optional[ProcessGroup] = None) -> int: return _get_group_size(group) -def isend(tensor: torch.Tensor, dst: int, group: Optional[ProcessGroup] = None, tag: int = 0) -> Optional[Work]: +def isend( + tensor: torch.Tensor, dst: int, group: Optional[ProcessGroup] = None, tag: int = 0 +) -> Optional[Work]: """ Send a tensor asynchronously. @@ -1871,7 +2046,13 @@ def isend(tensor: torch.Tensor, dst: int, group: Optional[ProcessGroup] = None, return pg.send([tensor], dst, tag) -def irecv(tensor: torch.Tensor, src: Optional[int] = None, group: Optional[ProcessGroup] = None, tag: int = 0) -> Optional[Work]: + +def irecv( + tensor: torch.Tensor, + src: Optional[int] = None, + group: Optional[ProcessGroup] = None, + tag: int = 0, +) -> Optional[Work]: """ Receives a tensor asynchronously. @@ -1913,8 +2094,11 @@ def irecv(tensor: torch.Tensor, src: Optional[int] = None, group: Optional[Proce group_src_rank = get_group_rank(pg, src) return pg.recv([tensor], group_src_rank, tag) + @_exception_logger -def send(tensor: torch.Tensor, dst: int, group: Optional[ProcessGroup] = None, tag: int = 0) -> None: +def send( + tensor: torch.Tensor, dst: int, group: Optional[ProcessGroup] = None, tag: int = 0 +) -> None: """ Send a tensor synchronously. @@ -1951,8 +2135,14 @@ def send(tensor: torch.Tensor, dst: int, group: Optional[ProcessGroup] = None, t group_dst_rank = get_group_rank(group, dst) group.send([tensor], group_dst_rank, tag).wait() + @_exception_logger -def recv(tensor: torch.Tensor, src: Optional[int] = None, group: Optional[ProcessGroup] = None, tag: int = 0) -> int: +def recv( + tensor: torch.Tensor, + src: Optional[int] = None, + group: Optional[ProcessGroup] = None, + tag: int = 0, +) -> int: """ Receives a tensor synchronously. @@ -2004,7 +2194,15 @@ def recv(tensor: torch.Tensor, src: Optional[int] = None, group: Optional[Proces class _IllegalWork(Work): def __getattribute__(self, name): - if name in ["is_success", "exception", "wait", "source_rank", "_source_rank", "result", "synchronize"]: + if name in [ + "is_success", + "exception", + "wait", + "source_rank", + "_source_rank", + "result", + "synchronize", + ]: raise ValueError(f"Illegal to call {name} on IllegalWork object") @@ -2057,7 +2255,9 @@ def _coalescing_manager( group = group or _get_default_group() op_list = _world.pg_coalesce_state.setdefault(group, []) if op_list: - raise ValueError("ProcessGroup has non-empty op list at the start of coalescing") + raise ValueError( + "ProcessGroup has non-empty op list at the start of coalescing" + ) if device: group._start_coalescing(device) cm = _CoalescingManager() @@ -2212,6 +2412,7 @@ def broadcast(tensor, src, group=None, async_op=False): else: work.wait() + @_exception_logger def all_reduce(tensor, op=ReduceOp.SUM, group=None, async_op=False): """ @@ -2292,6 +2493,7 @@ def all_reduce(tensor, op=ReduceOp.SUM, group=None, async_op=False): else: work.wait() + @_exception_logger @deprecated( "`torch.distributed.all_reduce_coalesced` will be deprecated. If you must " @@ -2359,6 +2561,7 @@ def all_reduce_coalesced(tensors, op=ReduceOp.SUM, group=None, async_op=False): else: work.wait() + @_exception_logger def reduce(tensor, dst, op=ReduceOp.SUM, group=None, async_op=False): """ @@ -2404,6 +2607,7 @@ def reduce(tensor, dst, op=ReduceOp.SUM, group=None, async_op=False): else: work.wait() + def _object_to_tensor(obj, device, group): f = io.BytesIO() _pickler(f).dump(obj) @@ -2416,7 +2620,9 @@ def _object_to_tensor(obj, device, group): backend = get_backend(group) if backend == Backend.NCCL: hash = torch._C._distributed_c10d._hash_tensors([byte_tensor]) - logger.warning("_object_to_tensor size: %s hash value: %s", byte_tensor.numel(), hash) + logger.warning( + "_object_to_tensor size: %s hash value: %s", byte_tensor.numel(), hash + ) local_size = torch.LongTensor([byte_tensor.numel()]).to(device) return byte_tensor, local_size @@ -2426,7 +2632,9 @@ def _tensor_to_object(tensor, tensor_size, group): backend = get_backend(group) if backend == Backend.NCCL: hash = torch._C._distributed_c10d._hash_tensors([tensor]) - logger.warning("_tensor_to_object size: %s hash value: %s", tensor.numel(), hash) + logger.warning( + "_tensor_to_object size: %s hash value: %s", tensor.numel(), hash + ) tensor = tensor.cpu() buf = tensor.numpy().tobytes()[:tensor_size] return _unpickler(io.BytesIO(buf)).load() @@ -2709,7 +2917,9 @@ def send_object_list(object_list, dst, group=None, device=None): # sent to this device. current_device = device or _get_pg_default_device(group) # Serialize object_list elements to tensors on src rank. - tensor_list, size_list = zip(*[_object_to_tensor(obj, current_device, group) for obj in object_list]) + tensor_list, size_list = zip( + *[_object_to_tensor(obj, current_device, group) for obj in object_list] + ) object_sizes_tensor = torch.cat(size_list) # Send object sizes @@ -2793,7 +3003,9 @@ def recv_object_list(object_list, src=None, group=None, device=None): # case it is not ``None`` we move the size and object tensors to be # received to this device. current_device = device or _get_pg_default_device(group) - object_sizes_tensor = torch.empty(len(object_list), dtype=torch.long, device=current_device) + object_sizes_tensor = torch.empty( + len(object_list), dtype=torch.long, device=current_device + ) # Receive object sizes rank_sizes = recv(object_sizes_tensor, src=src, group=group) @@ -2802,11 +3014,13 @@ def recv_object_list(object_list, src=None, group=None, device=None): object_tensor = torch.empty( # type: ignore[call-overload] torch.sum(object_sizes_tensor).item(), # type: ignore[arg-type] dtype=torch.uint8, - device=current_device + device=current_device, ) rank_objects = recv(object_tensor, src=src, group=group) - assert rank_sizes == rank_objects, "Mismatch in return ranks for object sizes and objects." + assert ( + rank_sizes == rank_objects + ), "Mismatch in return ranks for object sizes and objects." # Deserialize objects using their stored sizes. offset = 0 for i, obj_size in enumerate(object_sizes_tensor): @@ -2816,6 +3030,7 @@ def recv_object_list(object_list, src=None, group=None, device=None): object_list[i] = _tensor_to_object(obj_view, obj_size, group) return rank_objects + @_exception_logger def broadcast_object_list(object_list, src=0, group=None, device=None): """ @@ -2892,10 +3107,14 @@ def broadcast_object_list(object_list, src=0, group=None, device=None): my_rank = get_rank() # Serialize object_list elements to tensors on src rank. if my_rank == src: - tensor_list, size_list = zip(*[_object_to_tensor(obj, current_device, group) for obj in object_list]) + tensor_list, size_list = zip( + *[_object_to_tensor(obj, current_device, group) for obj in object_list] + ) object_sizes_tensor = torch.cat(size_list) else: - object_sizes_tensor = torch.empty(len(object_list), dtype=torch.long, device=current_device) + object_sizes_tensor = torch.empty( + len(object_list), dtype=torch.long, device=current_device + ) # Broadcast object sizes broadcast(object_sizes_tensor, src=src, group=group) @@ -2912,7 +3131,7 @@ def broadcast_object_list(object_list, src=0, group=None, device=None): object_tensor = torch.empty( # type: ignore[call-overload] torch.sum(object_sizes_tensor).item(), # type: ignore[arg-type] dtype=torch.uint8, - device=current_device + device=current_device, ) broadcast(object_tensor, src=src, group=group) @@ -3000,7 +3219,10 @@ def scatter_object_list( pg_device = _get_pg_default_device(group) if my_rank == src: tensor_list, tensor_sizes = zip( - *[_object_to_tensor(obj, pg_device, group) for obj in scatter_object_input_list] + *[ + _object_to_tensor(obj, pg_device, group) + for obj in scatter_object_input_list + ] ) tensor_list, tensor_sizes = list(tensor_list), list(tensor_sizes) @@ -3015,7 +3237,9 @@ def scatter_object_list( broadcast(max_tensor_size, src=src, group=group) # Scatter actual serialized objects - output_tensor = torch.empty(max_tensor_size.item(), dtype=torch.uint8, device=pg_device) + output_tensor = torch.empty( + max_tensor_size.item(), dtype=torch.uint8, device=pg_device + ) scatter( output_tensor, scatter_list=None if my_rank != src else tensor_list, # type: ignore[possibly-undefined] @@ -3033,7 +3257,9 @@ def scatter_object_list( ) # Deserialize back to object - scatter_object_output_list[0] = _tensor_to_object(output_tensor, obj_tensor_size, group) + scatter_object_output_list[0] = _tensor_to_object( + output_tensor, obj_tensor_size, group + ) @_exception_logger @@ -3900,6 +4126,7 @@ def all_to_all(output_tensor_list, input_tensor_list, group=None, async_op=False else: work.wait() + @_exception_logger def barrier(group=GroupMember.WORLD, async_op=False, device_ids=None): """ @@ -4041,15 +4268,18 @@ def _create_process_group_wrapper( wrapped_pg = _ProcessGroupWrapper(wrapped_pg, helper_pg) return wrapped_pg + # helper function for deterministically hashing a list of ranks def _hash_ranks(ranks: List[int]): return hashlib.sha1(bytes("_".join(map(str, ranks)), "utf-8")).hexdigest() + # Takes a list of ranks and computes an integer color def _process_group_color(ranks: List[int]) -> int: # Convert our hash to an int, but avoid negative numbers by shifting a bit. return int(_hash_ranks(ranks), 16) % (sys.maxsize >> 1) + def _process_group_name(ranks, use_hashed_name): global _world if use_hashed_name: @@ -4061,6 +4291,7 @@ def _process_group_name(ranks, use_hashed_name): _world.group_count += 1 return pg_name + def _get_backend_from_str(backend: Optional[str] = None) -> Backend: # Default to the same backend as the global process group # if backend is not specified. @@ -4070,7 +4301,14 @@ def _get_backend_from_str(backend: Optional[str] = None) -> Backend: @_time_logger -def new_group(ranks=None, timeout=None, backend=None, pg_options=None, use_local_synchronization=False, group_desc=None): +def new_group( + ranks=None, + timeout=None, + backend=None, + pg_options=None, + use_local_synchronization=False, + group_desc=None, +): """ Create a new distributed group. @@ -4137,6 +4375,7 @@ def new_group(ranks=None, timeout=None, backend=None, pg_options=None, use_local group_desc=group_desc, ) + def _new_group_with_tag( ranks=None, timeout=None, @@ -4144,7 +4383,7 @@ def _new_group_with_tag( pg_options=None, pg_tag=None, use_local_synchronization=False, - group_desc=None + group_desc=None, ): """ Variant of ``new_group`` that exposes tag creation. @@ -4159,7 +4398,6 @@ def _new_group_with_tag( global_rank = default_pg.rank() global_world_size = default_pg.size() - # Default to the same backend as the global process group # if the backend is not specified. if not backend: @@ -4175,7 +4413,9 @@ def _new_group_with_tag( if use_local_synchronization: # MPI backend doesn't have have a way for us to perform a partial sync if backend == Backend.MPI: - raise ValueError("MPI backend doesn't support use_local_synchronization=True") + raise ValueError( + "MPI backend doesn't support use_local_synchronization=True" + ) if ranks is not None and get_rank() not in ranks: return None @@ -4217,7 +4457,7 @@ def _new_group_with_tag( pg_options=pg_options, timeout=timeout, pg_tag=pg_tag, - group_desc=group_desc + group_desc=group_desc, ) # Create the global rank to group rank mapping @@ -4246,7 +4486,9 @@ def _new_group_with_tag( world_size = len(ranks) if use_local_synchronization else get_world_size() # Use store based barrier here since barrier() used a bunch of # default devices and messes up NCCL internal state. - _store_based_barrier(global_rank, barrier_store, group_name, world_size, timeout) + _store_based_barrier( + global_rank, barrier_store, group_name, world_size, timeout + ) return pg @@ -4332,16 +4574,20 @@ def new_subgroups( """ if group_size is None: if not torch.cuda.is_available(): - raise ValueError("Default group size only takes effect when CUDA is available." - "If your subgroup using a backend that does not depend on CUDA," - "please pass in 'group_size' correctly.") + raise ValueError( + "Default group size only takes effect when CUDA is available." + "If your subgroup using a backend that does not depend on CUDA," + "please pass in 'group_size' correctly." + ) group_size = torch.cuda.device_count() if group_size <= 0: raise ValueError(f"The arg 'group_size' ({group_size}) must be positive") world_size = get_world_size() if world_size < group_size: - raise ValueError(f"The arg 'group_size' ({group_size}) must not exceed the world size ({world_size})") + raise ValueError( + f"The arg 'group_size' ({group_size}) must not exceed the world size ({world_size})" + ) if world_size % group_size != 0: raise ValueError("The world size must be divisible by 'group_size'") @@ -4364,10 +4610,7 @@ def new_subgroups( rank = get_rank() if rank in ranks_in_subgroup: cur_subgroup = subgroup - logger.info( - "Rank %s is assigned to subgroup %s", - rank, ranks_in_subgroup - ) + logger.info("Rank %s is assigned to subgroup %s", rank, ranks_in_subgroup) return cur_subgroup, subgroups @@ -4479,8 +4722,13 @@ def _find_pg_by_ranks_and_tag(tag: str, ranks: List[int]) -> Optional[ProcessGro return group return None -def _find_or_create_pg_by_ranks_and_tag(tag: str, ranks: List[int], stride: int) -> ProcessGroup: - assert len(ranks) % stride == 0, f"Ranks length ({len(ranks)}) must be divisible by stride ({stride})" + +def _find_or_create_pg_by_ranks_and_tag( + tag: str, ranks: List[int], stride: int +) -> ProcessGroup: + assert ( + len(ranks) % stride == 0 + ), f"Ranks length ({len(ranks)}) must be divisible by stride ({stride})" my_rank = get_rank() my_ranks = None @@ -4505,6 +4753,7 @@ def _find_or_create_pg_by_ranks_and_tag(tag: str, ranks: List[int], stride: int) # TODO copy settings and timeout from default PG return _new_group_with_tag(my_ranks, pg_tag=tag) + def _get_group_tag(pg: ProcessGroup) -> str: """Return the tag associated with ``pg``.""" tag = _world.pg_to_tag[pg] @@ -4512,12 +4761,15 @@ def _get_group_tag(pg: ProcessGroup) -> str: tag = tag[5:] return tag + def _get_process_group_name(pg: ProcessGroup) -> str: return _world.pg_names.get(pg, "None") + def _get_process_group_store(pg: ProcessGroup) -> Store: return _world.pg_map[pg][1] + # This ops are not friendly to TorchDynamo. So, we decide to disallow these ops # in FX graph, allowing them to run them on eager, with torch.compile. dynamo_unsupported_distributed_c10d_ops = [ diff --git a/torch/distributed/examples/memory_tracker_example.py b/torch/distributed/examples/memory_tracker_example.py index cb2ba03777d8f7..e40cfb8b3f5943 100644 --- a/torch/distributed/examples/memory_tracker_example.py +++ b/torch/distributed/examples/memory_tracker_example.py @@ -1,7 +1,7 @@ # mypy: allow-untyped-defs -import torch import torchvision +import torch from torch.distributed._tools import MemoryTracker diff --git a/torch/distributed/launcher/__init__.py b/torch/distributed/launcher/__init__.py index f0d25f8080c269..fb744a2b93615b 100644 --- a/torch/distributed/launcher/__init__.py +++ b/torch/distributed/launcher/__init__.py @@ -8,7 +8,7 @@ from torch.distributed.launcher.api import ( # noqa: F401 - LaunchConfig, elastic_launch, launch_agent, + LaunchConfig, ) diff --git a/torch/distributed/launcher/api.py b/torch/distributed/launcher/api.py index 937647f77828f1..a3bcd4073c9baf 100644 --- a/torch/distributed/launcher/api.py +++ b/torch/distributed/launcher/api.py @@ -15,13 +15,18 @@ from torch.distributed.elastic import events, metrics from torch.distributed.elastic.agent.server.api import WorkerSpec from torch.distributed.elastic.agent.server.local_elastic_agent import LocalElasticAgent -from torch.distributed.elastic.multiprocessing import DefaultLogsSpecs, LogsSpecs, SignalException +from torch.distributed.elastic.multiprocessing import ( + DefaultLogsSpecs, + LogsSpecs, + SignalException, +) from torch.distributed.elastic.multiprocessing.errors import ChildFailedError from torch.distributed.elastic.rendezvous import RendezvousParameters from torch.distributed.elastic.rendezvous.utils import parse_rendezvous_endpoint from torch.distributed.elastic.utils.logging import get_logger -__all__ = ['LaunchConfig', 'elastic_launch', 'launch_agent'] + +__all__ = ["LaunchConfig", "elastic_launch", "launch_agent"] logger = get_logger(__name__) @@ -212,8 +217,8 @@ def launch_agent( "max_restarts": config.max_restarts, "monitor_interval": config.monitor_interval, "log_dir": config.logs_specs.root_log_dir, # type: ignore[union-attr] - "metrics_cfg": config.metrics_cfg - } + "metrics_cfg": config.metrics_cfg, + }, ) rdzv_parameters = RendezvousParameters( diff --git a/torch/distributed/logging_handlers.py b/torch/distributed/logging_handlers.py index 3c607fe45da771..021ad100f06a89 100644 --- a/torch/distributed/logging_handlers.py +++ b/torch/distributed/logging_handlers.py @@ -9,6 +9,7 @@ import logging from typing import Dict, List + __all__: List[str] = [] _log_handlers: Dict[str, logging.Handler] = { diff --git a/torch/distributed/nn/__init__.py b/torch/distributed/nn/__init__.py index 3ed1b42cbe1582..e15fb517052e4a 100644 --- a/torch/distributed/nn/__init__.py +++ b/torch/distributed/nn/__init__.py @@ -1,4 +1,7 @@ import torch + +from .functional import * # noqa: F403 + + if torch.distributed.rpc.is_available(): from .api.remote_module import RemoteModule -from .functional import * # noqa: F403 diff --git a/torch/distributed/nn/api/remote_module.py b/torch/distributed/nn/api/remote_module.py index de8a15dd65da5b..5583da8c3e8d4e 100644 --- a/torch/distributed/nn/api/remote_module.py +++ b/torch/distributed/nn/api/remote_module.py @@ -21,14 +21,15 @@ import torch import torch.distributed.rpc as rpc -from torch import Tensor, device, dtype, nn -from torch.distributed.nn.jit import instantiator +from torch import device, dtype, nn, Tensor from torch.distributed import _remote_device +from torch.distributed.nn.jit import instantiator from torch.distributed.rpc.internal import _internal_rpc_pickler from torch.nn import Module from torch.nn.parameter import Parameter from torch.utils.hooks import RemovableHandle + __all__ = ["RemoteModule"] _grad_t = Union[Tuple[Tensor, ...], Tensor] @@ -120,7 +121,6 @@ def _raise_not_supported(name: str) -> None: class _RemoteModule(nn.Module): - def __new__(cls, *args, **kwargs): # Use __new__ for logging purposes. torch._C._log_api_usage_once("torch.distributed.nn.api.remote_module") @@ -370,7 +370,10 @@ def register_forward_pre_hook( # type: ignore[return] self, hook: Union[ Callable[[T, Tuple[Any, ...]], Optional[Any]], - Callable[[T, Tuple[Any, ...], Dict[str, Any]], Optional[Tuple[Any, Dict[str, Any]]]], + Callable[ + [T, Tuple[Any, ...], Dict[str, Any]], + Optional[Tuple[Any, Dict[str, Any]]], + ], ], prepend: bool = False, with_kwargs: bool = False, @@ -405,10 +408,7 @@ def parameters(self, recurse: bool = True) -> Iterator[Parameter]: ) def named_parameters( # type: ignore[return] - self, - prefix: str = "", - recurse: bool = True, - remove_duplicate: bool = True + self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True ) -> Iterator[Tuple[str, Parameter]]: _raise_not_supported(self.named_parameters.__name__) @@ -416,10 +416,7 @@ def buffers(self, recurse: bool = True) -> Iterator[Tensor]: # type: ignore[ret _raise_not_supported(self.buffers.__name__) def named_buffers( # type: ignore[return] - self, - prefix: str = "", - recurse: bool = True, - remove_duplicate: bool = True + self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True ) -> Iterator[Tuple[str, Tensor]]: _raise_not_supported(self.named_buffers.__name__) @@ -464,7 +461,11 @@ def _prepare_init(self, remote_device_str: str) -> bool: assert rpc._is_current_rpc_agent_set(), "RemoteModule only works in RPC." remote_device = _remote_device(remote_device_str) - self.on = remote_device.worker_name() if remote_device.worker_name() is not None else remote_device.rank() + self.on = ( + remote_device.worker_name() + if remote_device.worker_name() is not None + else remote_device.rank() + ) self.device = str(remote_device.device()) agent = rpc._get_current_rpc_agent() # If the device map of the remote worker is set, diff --git a/torch/distributed/nn/functional.py b/torch/distributed/nn/functional.py index e90a78a69324b3..110df578552a59 100644 --- a/torch/distributed/nn/functional.py +++ b/torch/distributed/nn/functional.py @@ -2,11 +2,13 @@ import torch import torch.distributed as dist from torch.autograd import Function + # The two imports below are not always available depending on the # USE_DISTRIBUTED compile flag. Make sure they raise import error # if we're trying to use them. from torch.distributed import group, ReduceOp + def broadcast(tensor, src, group=group.WORLD): """ Broadcasts the tensor to the whole group. @@ -116,6 +118,7 @@ def all_gather(tensor, group=group.WORLD): """ return _AllGather.apply(group, tensor) + def _all_gather_base(output_tensor, input_tensor, group=group.WORLD): """ Single tensor all gather. Gathers a single tensor from all ranks, and puts them in a single output tensor. @@ -340,6 +343,7 @@ def backward(ctx, *grad_outputs): gx = torch.sum(torch.stack(gxs), dim=0) return (None, gx) + class _AllGatherBase(Function): @staticmethod def forward(ctx, output_tensor, input_tensor, group): @@ -354,16 +358,19 @@ def backward(ctx, grad_output): out_size = list(grad_output.size()) if out_size[0] % world_size != 0: raise RuntimeError( - f'Tensor with dimensions: {out_size} does ' - f'not have first dimension divisible by world_size: {world_size}' + f"Tensor with dimensions: {out_size} does " + f"not have first dimension divisible by world_size: {world_size}" ) out_size[0] = out_size[0] // dist.get_world_size(group=ctx.group) - gx = torch.empty(out_size, device=grad_output.device, dtype=grad_output.dtype) + gx = torch.empty( + out_size, device=grad_output.device, dtype=grad_output.dtype + ) dist._reduce_scatter_base(gx, grad_output, ReduceOp.SUM, ctx.group) else: raise RuntimeError("Backend not supported!") return (None, gx, None) + class _AlltoAll(Function): @staticmethod def forward(ctx, group, out_tensor_list, *tensors): @@ -391,7 +398,9 @@ def forward(ctx, group, out_tensor_list, *tensors): @staticmethod def backward(ctx, *grad_outputs): tensor_list = [ - torch.empty(size, device=grad_outputs[0].device, dtype=grad_outputs[0].dtype) + torch.empty( + size, device=grad_outputs[0].device, dtype=grad_outputs[0].dtype + ) for size in ctx.input_tensor_size_list ] return (None, None) + _AlltoAll.apply(ctx.group, tensor_list, *grad_outputs) @@ -415,7 +424,9 @@ def forward(ctx, group, output, output_split_sizes, input_split_sizes, input): @staticmethod def backward(ctx, grad_output): - tensor = torch.empty(ctx.input_size, device=grad_output.device, dtype=grad_output.dtype) + tensor = torch.empty( + ctx.input_size, device=grad_output.device, dtype=grad_output.dtype + ) return (None, None, None, None) + ( _AlltoAllSingle.apply( ctx.group, diff --git a/torch/distributed/pipelining/_IR.py b/torch/distributed/pipelining/_IR.py index 7d0aede8943ebb..81ddeb8bfe0ad8 100644 --- a/torch/distributed/pipelining/_IR.py +++ b/torch/distributed/pipelining/_IR.py @@ -5,7 +5,7 @@ import operator from collections import defaultdict from enum import Enum -from inspect import Parameter, signature, Signature +from inspect import Parameter, Signature, signature from types import MethodType from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union @@ -21,6 +21,7 @@ ) from torch.fx.node import map_aggregate from torch.fx.passes.split_module import split_module + from ._backward import _null_coalesce_accumulate, stage_backward from ._unflatten import _outline_submodules from ._utils import PipeInfo @@ -1176,7 +1177,8 @@ def annotate_split_points(mod: torch.nn.Module, spec: Dict[str, SplitPoint]): predecessor_module = getattr(predecessor_module, atom) except AttributeError as e: raise AttributeError( - f'Specified target {qualname} referenced nonexistent module {".".join(atoms[:i+1])}' + f"Specified target {qualname} referenced " + f'nonexistent module {".".join(atoms[: i + 1])}' ) from e mod_to_wrap = getattr(predecessor_module, atoms[-1]) diff --git a/torch/distributed/pipelining/__init__.py b/torch/distributed/pipelining/__init__.py index 18b3191add5b64..5b1843a33f6fd6 100644 --- a/torch/distributed/pipelining/__init__.py +++ b/torch/distributed/pipelining/__init__.py @@ -8,6 +8,7 @@ ) from .stage import build_stage, PipelineStage + __all__ = [ "Pipe", "pipe_split", diff --git a/torch/distributed/remote_device.py b/torch/distributed/remote_device.py index da664f7408bb20..bdb1974b1b37c3 100644 --- a/torch/distributed/remote_device.py +++ b/torch/distributed/remote_device.py @@ -47,7 +47,7 @@ def __init__(self, remote_device: Union[str, torch.device]): else: raise ValueError(PARSE_ERROR) else: - raise TypeError(f'Invalid type for remote_device: {type(remote_device)}') + raise TypeError(f"Invalid type for remote_device: {type(remote_device)}") # Do some basic sanity check (no empty string) if self._worker_name is not None and not self._worker_name: @@ -96,18 +96,18 @@ def device(self) -> torch.device: def __repr__(self): if self._device is not None: if self._worker_name is not None: - return f'{self._worker_name}/{self._device}' + return f"{self._worker_name}/{self._device}" elif self._rank is not None: - return f'rank:{self._rank}/{self._device}' + return f"rank:{self._rank}/{self._device}" else: return str(self._device) else: if self._worker_name is not None: - return f'{self._worker_name}' + return f"{self._worker_name}" elif self._rank is not None: - return f'{self._rank}' + return f"{self._rank}" else: - raise RuntimeError('Invalid state!') + raise RuntimeError("Invalid state!") def __eq__(self, other): if not isinstance(other, _remote_device): @@ -122,8 +122,5 @@ def __eq__(self, other): return False - def __hash__(self): - return hash(self._worker_name) ^ \ - hash(self._device) ^ \ - hash(self._rank) + return hash(self._worker_name) ^ hash(self._device) ^ hash(self._rank) diff --git a/torch/distributed/rendezvous.py b/torch/distributed/rendezvous.py index e3266cb238acad..a944a75271b0d6 100644 --- a/torch/distributed/rendezvous.py +++ b/torch/distributed/rendezvous.py @@ -10,7 +10,7 @@ import os import sys from datetime import timedelta -from typing import Dict, Optional, Callable, Iterator, Tuple +from typing import Callable, Dict, Iterator, Optional, Tuple from torch.distributed import FileStore, PrefixStore, Store, TCPStore @@ -21,6 +21,7 @@ __all__ = ["register_rendezvous_handler", "rendezvous"] + def register_rendezvous_handler(scheme, handler): """ Register a new rendezvous handler. @@ -47,16 +48,17 @@ def register_rendezvous_handler(scheme, handler): """ global _rendezvous_handlers if scheme in _rendezvous_handlers: - raise RuntimeError( - f"Rendezvous handler for {scheme}:// already registered" - ) + raise RuntimeError(f"Rendezvous handler for {scheme}:// already registered") _rendezvous_handlers[scheme] = handler # Query will have format "rank=0&world_size=1" and is # converted into {"rank": 0, "world_size": 1} def _query_to_dict(query: str) -> Dict[str, str]: - return {pair[0]: pair[1] for pair in (pair.split("=") for pair in filter(None, query.split("&")))} + return { + pair[0]: pair[1] + for pair in (pair.split("=") for pair in filter(None, query.split("&"))) + } def _get_use_libuv_from_query_dict(query_dict: Dict[str, str]) -> bool: @@ -152,7 +154,9 @@ def _torchelastic_use_agent_store() -> bool: return os.environ.get("TORCHELASTIC_USE_AGENT_STORE", None) == str(True) -def _create_c10d_store(hostname, port, rank, world_size, timeout, use_libuv=True) -> Store: +def _create_c10d_store( + hostname, port, rank, world_size, timeout, use_libuv=True +) -> Store: """ Smartly creates a c10d Store object on ``rank`` based on whether we need to re-use agent store. @@ -183,7 +187,13 @@ def _create_c10d_store(hostname, port, rank, world_size, timeout, use_libuv=True else: start_daemon = rank == 0 return TCPStore( - hostname, port, world_size, start_daemon, timeout, multi_tenant=True, use_libuv=use_libuv + hostname, + port, + world_size, + start_daemon, + timeout, + multi_tenant=True, + use_libuv=use_libuv, ) @@ -208,7 +218,9 @@ def _error(msg): assert result.hostname is not None - store = _create_c10d_store(result.hostname, result.port, rank, world_size, timeout, use_libuv) + store = _create_c10d_store( + result.hostname, result.port, rank, world_size, timeout, use_libuv + ) yield (store, rank, world_size) @@ -250,12 +262,13 @@ def _get_env_or_raise(env_var: str) -> str: else: world_size = int(_get_env_or_raise("WORLD_SIZE")) - master_addr = _get_env_or_raise("MASTER_ADDR") master_port = int(_get_env_or_raise("MASTER_PORT")) use_libuv = _get_use_libuv_from_query_dict(query_dict) - store = _create_c10d_store(master_addr, master_port, rank, world_size, timeout, use_libuv) + store = _create_c10d_store( + master_addr, master_port, rank, world_size, timeout, use_libuv + ) yield (store, rank, world_size) diff --git a/torch/distributed/run.py b/torch/distributed/run.py index 5654693f3dfca9..aa34891d1ecd27 100644 --- a/torch/distributed/run.py +++ b/torch/distributed/run.py @@ -397,9 +397,9 @@ def main(): import os import sys import uuid -import importlib.metadata as metadata -from argparse import REMAINDER, ArgumentParser -from typing import Callable, List, Tuple, Type, Union, Optional, Set +from argparse import ArgumentParser, REMAINDER +from importlib import metadata +from typing import Callable, List, Optional, Set, Tuple, Type, Union import torch from torch.distributed.argparse_util import check_env, env @@ -408,9 +408,9 @@ def main(): from torch.distributed.elastic.rendezvous.utils import _parse_rendezvous_config from torch.distributed.elastic.utils import macros from torch.distributed.elastic.utils.logging import get_logger -from torch.distributed.launcher.api import LaunchConfig, elastic_launch +from torch.distributed.launcher.api import elastic_launch, LaunchConfig from torch.utils.backend_registration import _get_custom_mod_func -import torch.multiprocessing + logger = get_logger(__name__) @@ -693,21 +693,26 @@ def determine_local_world_size(nproc_per_node: str): if torch.cuda.is_available(): num_proc = torch.cuda.device_count() device_type = "gpu" - elif hasattr(torch, torch._C._get_privateuse1_backend_name()) and \ - _get_custom_mod_func("is_available")(): + elif ( + hasattr(torch, torch._C._get_privateuse1_backend_name()) + and _get_custom_mod_func("is_available")() + ): num_proc = _get_custom_mod_func("device_count")() device_type = torch._C._get_privateuse1_backend_name() else: num_proc = os.cpu_count() device_type = "cpu" else: - raise ValueError(f"Unsupported nproc_per_node value: {nproc_per_node}") from e + raise ValueError( + f"Unsupported nproc_per_node value: {nproc_per_node}" + ) from e logger.info( - "Using nproc_per_node=%s," - " setting to %s since the instance " - "has %s %s", - nproc_per_node, num_proc, os.cpu_count(), device_type + "Using nproc_per_node=%s," " setting to %s since the instance " "has %s %s", + nproc_per_node, + num_proc, + os.cpu_count(), + device_type, ) return num_proc @@ -753,9 +758,13 @@ def _get_logs_specs_class(logs_specs_name: Optional[str]) -> Type[LogsSpecs]: logs_specs_cls = entrypoint_list[0].load() if logs_specs_cls is None: - raise ValueError(f"Could not find entrypoint under 'torchrun.logs_specs[{logs_specs_name}]' key") + raise ValueError( + f"Could not find entrypoint under 'torchrun.logs_specs[{logs_specs_name}]' key" + ) - logging.info("Using logs_spec '%s' mapped to %s", logs_specs_name, str(logs_specs_cls)) + logging.info( + "Using logs_spec '%s' mapped to %s", logs_specs_name, str(logs_specs_cls) + ) else: logs_specs_cls = DefaultLogsSpecs @@ -768,7 +777,11 @@ def config_from_args(args) -> Tuple[LaunchConfig, Union[Callable, str], List[str assert 0 < min_nodes <= max_nodes assert args.max_restarts >= 0 - if hasattr(args, "master_addr") and args.rdzv_backend != "static" and not args.rdzv_endpoint: + if ( + hasattr(args, "master_addr") + and args.rdzv_backend != "static" + and not args.rdzv_endpoint + ): logger.warning( "master_addr is only used for static rdzv_backend and when rdzv_endpoint " "is not specified." @@ -784,7 +797,7 @@ def config_from_args(args) -> Tuple[LaunchConfig, Union[Callable, str], List[str "please further tune the variable for optimal performance in " "your application as needed. \n" "*****************************************", - omp_num_threads + omp_num_threads, ) # This env variable will be passed down to the subprocesses os.environ["OMP_NUM_THREADS"] = str(omp_num_threads) @@ -888,7 +901,9 @@ def run(args): "--rdzv-endpoint=%s " "--rdzv-id=%s\n" "**************************************\n", - args.rdzv_backend, args.rdzv_endpoint, args.rdzv_id + args.rdzv_backend, + args.rdzv_endpoint, + args.rdzv_id, ) config, cmd, cmd_args = config_from_args(args) diff --git a/torch/distributed/utils.py b/torch/distributed/utils.py index f13d066415015b..1a0b849f955d12 100644 --- a/torch/distributed/utils.py +++ b/torch/distributed/utils.py @@ -21,6 +21,7 @@ from torch.nn.parallel.scatter_gather import _is_namedtuple from torch.nn.utils.rnn import PackedSequence + __all__ = [] # type: ignore[var-annotated] From d9c294c6726ec833406dfaf1a2cdee77c4a5785d Mon Sep 17 00:00:00 2001 From: Jokeren Date: Tue, 18 Jun 2024 22:06:53 +0000 Subject: [PATCH 159/171] [Inductor] Fix arguments passed to triton kernel launch hooks (#128732) `binary.launch_enter_hook` is treated as an instance method and will add a `self` argument to the hooks. `CompiledKernel.launch_enter_hook` is a static method, which matches the hook calling convention of profilers (i.e., a single `LazyDict` argument only). Pull Request resolved: https://github.com/pytorch/pytorch/pull/128732 Approved by: https://github.com/shunting314, https://github.com/bertmaher --- test/inductor/test_profiler.py | 4 ++-- torch/_inductor/runtime/triton_heuristics.py | 5 +++-- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/test/inductor/test_profiler.py b/test/inductor/test_profiler.py index d2ff71dd73bb69..9d0270a9aae8d2 100644 --- a/test/inductor/test_profiler.py +++ b/test/inductor/test_profiler.py @@ -158,10 +158,10 @@ def test_inductor_profiling_triton_hooks(self): hooks_called = {"enter": False, "exit": False} - def launch_enter_hook(*args): + def launch_enter_hook(lazy_dict): hooks_called["enter"] = True - def launch_exit_hook(*args): + def launch_exit_hook(lazy_dict): hooks_called["exit"] = True CompiledKernel.launch_enter_hook = launch_enter_hook diff --git a/torch/_inductor/runtime/triton_heuristics.py b/torch/_inductor/runtime/triton_heuristics.py index 5396ccf3e70d53..82a25392b5e950 100644 --- a/torch/_inductor/runtime/triton_heuristics.py +++ b/torch/_inductor/runtime/triton_heuristics.py @@ -50,6 +50,7 @@ if triton is not None: from triton import Config + from triton.compiler import CompiledKernel from triton.runtime.autotuner import OutOfResources from triton.runtime.jit import KernelInterface @@ -453,8 +454,8 @@ def _precompile_config(self, cfg: Config, warm_cache_only: bool): scope = { "grid_meta": cfg.kwargs, "bin": binary, - "launch_enter_hook": binary.launch_enter_hook, - "launch_exit_hook": binary.launch_exit_hook, + "launch_enter_hook": CompiledKernel.launch_enter_hook, + "launch_exit_hook": CompiledKernel.launch_exit_hook, "metadata": binary.packed_metadata if hasattr(binary, "packed_metadata") else binary.metadata, From ac5f565fa7010bd77b9e779415e8709d347234b6 Mon Sep 17 00:00:00 2001 From: Andrew Gu Date: Tue, 18 Jun 2024 11:41:03 -0700 Subject: [PATCH 160/171] [FSDP2] Added `set_post_optim_event` (#128975) This PR adds `set_post_optim_event` that allows power users to provide their own CUDA event that is recorded after the optimizer step for the FSDP root module to wait the all-gather streams on. ``` def set_post_optim_event(self, event: torch.cuda.Event) -> None: ``` By default, the root would have the all-gather streams wait on the current stream (`wait_stream`), which may introduce false dependencies if there is unrelated computation after the optimizer step and before the wait. For example, this pattern can appear in recommendation models. To avoid those false dependencies while preserving the correctness guarantee, we provide this API so that the user can provide their own CUDA event to wait the all-gather streams on. We include both correctness test (`test_fully_shard_training.py`) and overlap test (`test_fully_shard_overlap.py`). --- One possible way to use the API is to register a post-step hook on the optimizer. For example: https://github.com/pytorch/pytorch/blob/12e8d1399b979b45d16f0934017f742d01ab2b8d/test/distributed/_composable/fsdp/test_fully_shard_training.py#L546-L552 Pull Request resolved: https://github.com/pytorch/pytorch/pull/128975 Approved by: https://github.com/sanketpurandare, https://github.com/weifengpy ghstack dependencies: #128884 --- .../fsdp/test_fully_shard_overlap.py | 82 ++++++++++++++++--- .../fsdp/test_fully_shard_training.py | 41 ++++++++++ .../_composable/fsdp/_fsdp_state.py | 14 +++- .../_composable/fsdp/fully_shard.py | 19 +++++ 4 files changed, 142 insertions(+), 14 deletions(-) diff --git a/test/distributed/_composable/fsdp/test_fully_shard_overlap.py b/test/distributed/_composable/fsdp/test_fully_shard_overlap.py index 99823883abfbb7..1fca6c3f3c5a0d 100644 --- a/test/distributed/_composable/fsdp/test_fully_shard_overlap.py +++ b/test/distributed/_composable/fsdp/test_fully_shard_overlap.py @@ -1,5 +1,6 @@ # Owner(s): ["oncall: distributed"] +import functools from typing import Callable import torch @@ -7,6 +8,7 @@ import torch.nn as nn from torch.distributed._composable.fsdp import fully_shard +from torch.distributed._tensor.experimental import implicit_replication from torch.testing._internal.common_distributed import skip_if_lt_x_gpu from torch.testing._internal.common_fsdp import ( FSDPTest, @@ -23,15 +25,6 @@ def world_size(self) -> int: @skip_if_lt_x_gpu(2) def test_fully_shard_training_overlap(self): - class LinearWithSleep(nn.Module): - def __init__(self, dim: int, sleep_ms: int): - super().__init__() - self.weight = nn.Parameter(torch.randn((dim, dim))) - self.sleep_ms = sleep_ms - - def forward(self, x: torch.Tensor) -> torch.Tensor: - return nn.functional.relu(Matmul.apply(x, self.weight, self.sleep_ms)) - torch.manual_seed(42) # Use non-trivial comm. time but still shorter than compute time @@ -44,7 +37,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: fully_shard(model, reshard_after_forward=True) orig_all_gather_into_tensor = dist.all_gather_into_tensor - orig_reduce_scatter = dist.reduce_scatter_tensor + orig_reduce_scatter_tensor = dist.reduce_scatter_tensor comm_stream = torch.cuda.Stream() def delay_collective(): @@ -61,7 +54,7 @@ def delayed_all_gather(*args, **kwargs): def delayed_reduce_scatter(*args, **kwargs): delay_collective() - return orig_reduce_scatter(*args, **kwargs) + return orig_reduce_scatter_tensor(*args, **kwargs) inp = torch.randn((2, dim), device="cuda") loss = model(inp).sum() # warmup CUDA and allocator @@ -92,6 +85,63 @@ def fwd_bwd(): ) self.assertLessEqual(fwd_bwd_time, expected_fwd_time + expected_bwd_time) + @skip_if_lt_x_gpu(2) + def test_fully_shard_post_optim_event_overlap(self): + torch.manual_seed(42) + + # Use non-trivial comm. time but still shorter than compute time + dim, compute_sleep_ms, comm_sleep_ms = (4, 25, 10) + # Define the model to have a high-compute linear followed by a + # low-compute linear, where only the low-compute linear uses FSDP + model = nn.Sequential( + LinearWithSleep(dim, compute_sleep_ms), nn.Linear(dim, dim) + ).cuda() + fully_shard(model[1], reshard_after_forward=False) + optim = torch.optim.AdamW(model.parameters(), lr=1e-2) + + orig_all_gather_into_tensor = dist.all_gather_into_tensor + + def delayed_all_gather(*args, **kwargs): + torch.cuda._sleep(int(comm_sleep_ms * get_cycles_per_ms())) + return orig_all_gather_into_tensor(*args, **kwargs) + + inp = torch.randn((2, dim), device="cuda") + + def run_train_steps(num_iters: int, use_post_optim_event: bool): + for _ in range(num_iters): + optim.zero_grad() + with patch_all_gather(delayed_all_gather): + loss = model(inp).sum() + loss.backward() + with implicit_replication(): + optim.step() + if use_post_optim_event: + post_optim_event = torch.cuda.current_stream().record_event() + model[1].set_post_optim_event(post_optim_event) + + run_train_steps(1, False) # warmup CUDA and allocator + num_iters = 5 + baseline_time = self._time_fn( + functools.partial(run_train_steps, num_iters, False) + ) + test_time = self._time_fn(functools.partial(run_train_steps, num_iters, True)) + + buffer_ms = 4 # CPU delays and copies + # Baseline: FSDP all-gather is exposed since the FSDP module waits for + # the current stream and hence the high-compute linear + self.assertLessEqual( + baseline_time, + num_iters * (3 * compute_sleep_ms + comm_sleep_ms + buffer_ms), + ) + # Test: FSDP all-gather is overlapped with the high-compute linear + # since the FSDP module only waits for the post-optim event (except on + # the 1st iteration when no event has been recorded) + expected_test_time = ( + num_iters * (3 * compute_sleep_ms + buffer_ms) + comm_sleep_ms + ) + self.assertLessEqual(test_time, expected_test_time) + self.assertGreater(baseline_time, expected_test_time) + def _time_fn(self, fn: Callable): start_event = torch.cuda.Event(enable_timing=True) end_event = torch.cuda.Event(enable_timing=True) @@ -123,5 +173,15 @@ def backward(ctx, grad_output: torch.Tensor): return grad_input, grad_weight, None +class LinearWithSleep(nn.Module): + def __init__(self, dim: int, sleep_ms: int): + super().__init__() + self.weight = nn.Parameter(torch.randn((dim, dim))) + self.sleep_ms = sleep_ms + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return nn.functional.relu(Matmul.apply(x, self.weight, self.sleep_ms)) + + if __name__ == "__main__": run_tests() diff --git a/test/distributed/_composable/fsdp/test_fully_shard_training.py b/test/distributed/_composable/fsdp/test_fully_shard_training.py index 3dbaa652437940..abc579b40d6246 100644 --- a/test/distributed/_composable/fsdp/test_fully_shard_training.py +++ b/test/distributed/_composable/fsdp/test_fully_shard_training.py @@ -532,6 +532,47 @@ def test_explicit_prefetching(self): _optim.step() self.assertEqual(losses[0], losses[1]) + @skip_if_lt_x_gpu(2) + def test_post_optim_event(self): + torch.manual_seed(42) + model_args = ModelArgs(dropout_p=0.0) + model = Transformer(model_args) + ref_model = replicate(copy.deepcopy(model).cuda()) + ref_optim = torch.optim.AdamW(ref_model.parameters(), lr=1e-2) + for layer in itertools.chain(model.layers, [model]): + fully_shard(layer) + optim = torch.optim.AdamW(model.parameters(), lr=1e-2) + + def step_post_hook( + fsdp_module: FSDPModule, opt: torch.optim.Optimizer, args, kwargs + ) -> None: + post_optim_event = torch.cuda.current_stream().record_event() + fsdp_module.set_post_optim_event(post_optim_event) + + optim.register_step_post_hook(functools.partial(step_post_hook, model)) + + torch.manual_seed(42 + self.rank) + inp = torch.randint(0, model_args.vocab_size, (2, 8), device="cuda") + # Track all losses and check for equality at the end to avoid a CPU + # sync point after each iteration + ref_losses: List[torch.Tensor] = [] + losses: List[torch.Tensor] = [] + for iter_idx in range(10): + ref_optim.zero_grad() + ref_losses.append(ref_model(inp).sum()) + ref_losses[-1].backward() + ref_optim.step() + for iter_idx in range(10): + optim.zero_grad() + losses.append(model(inp).sum()) + losses[-1].backward() + optim.step() + # Sleep after the optimizer step to allow CPU to run ahead into the + # next iteration's forward, exercising the post-optim stream sync + torch.cuda._sleep(int(25 * get_cycles_per_ms())) + for ref_loss, loss in zip(ref_losses, losses): + self.assertEqual(ref_loss, loss) + class TestFullyShard1DTrainingCompose(FSDPTest): @property diff --git a/torch/distributed/_composable/fsdp/_fsdp_state.py b/torch/distributed/_composable/fsdp/_fsdp_state.py index c6cdb2b29880bf..f04e6f6d09292d 100644 --- a/torch/distributed/_composable/fsdp/_fsdp_state.py +++ b/torch/distributed/_composable/fsdp/_fsdp_state.py @@ -36,6 +36,9 @@ def __init__(self): self.post_backward_final_callback_queued: bool = False # Whether to finalize backward in this backward's final callback self.is_last_backward: bool = True + # Optional user-provided event recorded after optimizer for the + # all-gather streams to wait on in the root pre-forward + self.post_optim_event: Optional[torch.cuda.Event] = None def disable_if_config_true(func): @@ -84,9 +87,14 @@ def _root_pre_forward( self._state_ctx.iter_forward_root = self with torch.profiler.record_function("FSDP::root_pre_forward"): # Wait for optimizer before implicitly prefetched all-gathers - current_stream = torch.cuda.current_stream() - self._comm_ctx.all_gather_copy_in_stream.wait_stream(current_stream) - self._comm_ctx.all_gather_stream.wait_stream(current_stream) + if (event := self._state_ctx.post_optim_event) is not None: + self._comm_ctx.all_gather_copy_in_stream.wait_event(event) + self._comm_ctx.all_gather_stream.wait_event(event) + self._state_ctx.post_optim_event = None + else: + current_stream = torch.cuda.current_stream() + self._comm_ctx.all_gather_copy_in_stream.wait_stream(current_stream) + self._comm_ctx.all_gather_stream.wait_stream(current_stream) if self._device.type == "cuda": with torch.profiler.record_function("FSDP::inputs_to_device"): args_tuple, kwargs_tuple = _to_kwargs( diff --git a/torch/distributed/_composable/fsdp/fully_shard.py b/torch/distributed/_composable/fsdp/fully_shard.py index e8ab3466118bc7..88180f40f792c1 100644 --- a/torch/distributed/_composable/fsdp/fully_shard.py +++ b/torch/distributed/_composable/fsdp/fully_shard.py @@ -309,6 +309,25 @@ def set_modules_to_backward_prefetch(self, modules: List["FSDPModule"]) -> None: module._get_fsdp_state() for module in modules ] + def set_post_optim_event(self, event: torch.cuda.Event) -> None: + """ + Sets a post-optimizer-step event for the root FSDP module to wait the + all-gather streams on. + + By default, the root FSDP module waits the all-gather streams on the + current stream to ensure that the optimizer step has finished before + all-gathering. However, this may introduce false dependencies if + there is unrelated computation after the optimizer step. This API + allows the user to provide their own event to wait on. After the root + waits on the event, the event is discarded, so this API should be + called with a new event each iteration. + + Args: + event (torch.cuda.Event): Event recorded after the optimizer step + to wait all-gather streams on. + """ + self._get_fsdp_state()._state_ctx.post_optim_event = event + def _get_fsdp_state(self) -> FSDPState: if (state := _get_module_fsdp_state(cast(nn.Module, self))) is None: raise AssertionError(f"No FSDP state found on {self}") From cb5e9183c6056a7f929a12f574372e87e879d29e Mon Sep 17 00:00:00 2001 From: cyy Date: Wed, 19 Jun 2024 00:05:50 +0000 Subject: [PATCH 161/171] [Caffe2] [2/N] Remove Caffe2 from tests (#128911) Follows #128675 Pull Request resolved: https://github.com/pytorch/pytorch/pull/128911 Approved by: https://github.com/titaiwangms, https://github.com/r-barnes --- test/jit/test_tracer.py | 45 ----------- test/onnx/pytorch_test_common.py | 4 +- test/onnx/test_operators.py | 27 ------- test/quantization/core/test_quantized_op.py | 47 ------------ test/test_determination.py | 7 -- test/test_public_bindings.py | 1 - test/test_tensorboard.py | 83 +-------------------- test/test_torch.py | 17 +---- 8 files changed, 4 insertions(+), 227 deletions(-) diff --git a/test/jit/test_tracer.py b/test/jit/test_tracer.py index 5da8ab61c5b3c9..d5ef39ba0c8b48 100644 --- a/test/jit/test_tracer.py +++ b/test/jit/test_tracer.py @@ -911,51 +911,6 @@ def forward(self, x): self.assertEqual(len(list(g.inputs())), 2) FileCheck().check("mul").check("add").run(str(g)) - def test_trace_c10_ops(self): - try: - _ = torch.ops._caffe2.GenerateProposals - except AttributeError: - self.skipTest("Skip the test since c2 ops are not registered.") - - class MyModel(torch.nn.Module): - def forward(self, scores, bbox_deltas, im_info, anchors): - a, b = torch.ops._caffe2.GenerateProposals( - (scores), - (bbox_deltas), - (im_info), - (anchors), - 2.0, - 6000, - 300, - 0.7, - 16, - True, - -90, - 90, - 1.0, - True, - ) - return a, b - - model = MyModel() - A = 4 - H = 10 - W = 8 - img_count = 3 - scores = torch.ones(img_count, A, H, W, dtype=torch.float32) - bbox_deltas = torch.linspace( - 0, 10, steps=img_count * 4 * A * H * W, dtype=torch.float32 - ) - bbox_deltas = bbox_deltas.view(img_count, 4 * A, H, W) - im_info = torch.ones(img_count, 3, dtype=torch.float32) - anchors = torch.ones(A, 4, dtype=torch.float32) - inputs = (scores, bbox_deltas, im_info, anchors) - traced_model = torch.jit.trace(model, inputs) - self.assertEqual(traced_model(*inputs), model(*inputs)) - self.assertExportImportModule( - traced_model, (scores, bbox_deltas, im_info, anchors) - ) - def run_ge_tests(self, optimize, use_cuda): with enable_profiling_mode_for_profiling_tests(): with torch.jit.optimized_execution(optimize): diff --git a/test/onnx/pytorch_test_common.py b/test/onnx/pytorch_test_common.py index 6fdbf4e92839c3..3b66750f45d8d8 100644 --- a/test/onnx/pytorch_test_common.py +++ b/test/onnx/pytorch_test_common.py @@ -340,8 +340,8 @@ def inner(self, *args, **kwargs): # skips tests for opset_versions listed in unsupported_opset_versions. -# if the caffe2 test cannot be run for a specific version, add this wrapper -# (for example, an op was modified but the change is not supported in caffe2) +# if the PyTorch test cannot be run for a specific version, add this wrapper +# (for example, an op was modified but the change is not supported in PyTorch) def skipIfUnsupportedOpsetVersion(unsupported_opset_versions): def skip_dec(func): @functools.wraps(func) diff --git a/test/onnx/test_operators.py b/test/onnx/test_operators.py index 87ec424cf65d57..b3c75486450a52 100644 --- a/test/onnx/test_operators.py +++ b/test/onnx/test_operators.py @@ -873,33 +873,6 @@ def test_cumsum(self): x = torch.randn(2, 3, 4, requires_grad=True) self.assertONNX(lambda x: torch.cumsum(x, dim=1), x, opset_version=11) - # Github Issue: https://github.com/pytorch/pytorch/issues/71095 - # def test_c2_op(self): - # class MyModel(torch.nn.Module): - # def __init__(self): - # super().__init__() - # - # def forward(self, scores, bbox_deltas, im_info, anchors): - # a, b = torch.ops._caffe2.GenerateProposals( - # (scores), (bbox_deltas), (im_info), (anchors), - # 2.0, 6000, 300, 0.7, 16, True, -90, 90, 1.0, True, - # ) - # return a, b - # - # model = MyModel() - # A = 4 - # H = 10 - # W = 8 - # img_count = 3 - # scores = torch.ones(img_count, A, H, W, dtype=torch.float32) - # bbox_deltas = torch.linspace(0, 10, steps=img_count * 4 * A * H * W, - # dtype=torch.float32) - # bbox_deltas = bbox_deltas.view(img_count, 4 * A, H, W) - # im_info = torch.ones(img_count, 3, dtype=torch.float32) - # anchors = torch.ones(A, 4, dtype=torch.float32) - # inputs = (scores, bbox_deltas, im_info, anchors) - # self.assertONNX(model, inputs, custom_opsets={"org.pytorch._caffe2": 0}) - def test_dict(self): class MyModel(torch.nn.Module): def forward(self, x_in): diff --git a/test/quantization/core/test_quantized_op.py b/test/quantization/core/test_quantized_op.py index 2e606938192dd2..25b062a7ab13fc 100644 --- a/test/quantization/core/test_quantized_op.py +++ b/test/quantization/core/test_quantized_op.py @@ -4457,54 +4457,7 @@ def _test_embedding_bag_unpack_impl(self, pack_fn, unpack_fn, bit_rate, optimize self.assertEqual(unpacked_weight.q_per_channel_scales(), qweight.q_per_channel_scales()) self.assertEqual(unpacked_weight.q_per_channel_zero_points(), qweight.q_per_channel_zero_points()) - # compare against C2 to ensure numerical equivalency. - from caffe2.python import core, workspace - conversion_op = "FloatToFused8BitRowwiseQuantized" if data_type == torch.float32 else "HalfFloatToFused8BitRowwiseQuantized" - reverse_conversion_op = None - if bit_rate == 4: - conversion_op = "FloatToFused4BitRowwiseQuantized" if data_type == torch.float32 else "HalfToFused4BitRowwiseQuantized" - reverse_conversion_op = "Fused4BitRowwiseQuantizedToFloat" - elif bit_rate == 2: - conversion_op = "FloatToFused2BitRowwiseQuantized" if data_type == torch.float32 else "HalfToFused2BitRowwiseQuantized" - reverse_conversion_op = "Fused2BitRowwiseQuantizedToFloat" - - def get_c2_weights(weights, engine_str): - workspace.ResetWorkspace() - - workspace.FeedBlob("weights", weights) - workspace.RunOperatorOnce( - core.CreateOperator( - conversion_op, ["weights"], ["quantized_weights"], engine=engine_str - ) - ) - emb_q = workspace.FetchBlob("quantized_weights") - if bit_rate == 4 or bit_rate == 2: - workspace.RunOperatorOnce( - core.CreateOperator( - reverse_conversion_op, ["quantized_weights"], ["dequantized_weights"] - ) - ) - dequantized_data = torch.from_numpy(workspace.FetchBlob("dequantized_weights")) - else: - dequantized_data = torch.ops._caffe2.Fused8BitRowwiseQuantizedToFloat( - torch.tensor(emb_q) - ) - return torch.from_numpy(emb_q), dequantized_data - - if optimized_qparams: - engine = "GREEDY" - else: - engine = "" - - # C2 quantization needs the memory format of Tensor to be `continuous`, otherwise it will - # throw exceptions. torch.clone() will make the memory format to be `continuous` - c2_copy = torch.clone(weights) - w_packed_c2, w_unpacked_c2 = get_c2_weights(c2_copy, engine) - # Compare packed weights against C2. - np.testing.assert_allclose(w_packed.numpy(), w_packed_c2.numpy(), atol=1e-6, rtol=1e-6) - # Compare unpacked weights against C2 - np.testing.assert_allclose(w_unpacked.numpy(), w_unpacked_c2.numpy(), atol=1e-6, rtol=1e-6) def _test_embedding_bag_unpack_fn(self, pack_fn, unpack_fn, num_embeddings, embedding_dim, bit_rate, diff --git a/test/test_determination.py b/test/test_determination.py index 50cc2fa9975da0..09a67de45dc694 100644 --- a/test/test_determination.py +++ b/test/test_determination.py @@ -121,13 +121,6 @@ def test_torch_file(self): ], ) - def test_caffe2_file(self): - """Caffe2 files trigger dependent tests""" - self.assertEqual(self.determined_tests(["caffe2/python/brew_test.py"]), []) - self.assertEqual( - self.determined_tests(["caffe2/python/context.py"]), self.TESTS - ) - def test_new_folder(self): """New top-level Python folder triggers all tests""" self.assertEqual(self.determined_tests(["new_module/file.py"]), self.TESTS) diff --git a/test/test_public_bindings.py b/test/test_public_bindings.py index 8ab2ac1f511f02..65a5bf90b9f93b 100644 --- a/test/test_public_bindings.py +++ b/test/test_public_bindings.py @@ -342,7 +342,6 @@ def test_modules_can_be_imported(self): "torch.testing._internal.distributed.rpc.rpc_test", "torch.testing._internal.distributed.rpc.tensorpipe_rpc_agent_test_fixture", "torch.testing._internal.distributed.rpc_utils", - "torch.utils.tensorboard._caffe2_graph", "torch._inductor.codegen.cuda.cuda_template", "torch._inductor.codegen.cuda.gemm_template", "torch._inductor.runtime.triton_helpers", diff --git a/test/test_tensorboard.py b/test/test_tensorboard.py index 3ce2ab2a172c81..1e79a2bf910cec 100644 --- a/test/test_tensorboard.py +++ b/test/test_tensorboard.py @@ -23,15 +23,6 @@ HAS_TORCHVISION = False skipIfNoTorchVision = unittest.skipIf(not HAS_TORCHVISION, "no torchvision") -TEST_CAFFE2 = True -try: - import caffe2.python.caffe2_pybind11_state as _caffe2_pybind11_state # noqa: F401 - from caffe2.python import brew, cnn, core, workspace - from caffe2.python.model_helper import ModelHelper -except ImportError: - TEST_CAFFE2 = False -skipIfNoCaffe2 = unittest.skipIf(not TEST_CAFFE2, "no caffe2") - TEST_MATPLOTLIB = True try: import matplotlib @@ -48,7 +39,6 @@ parametrize, TestCase, run_tests, - TEST_WITH_ASAN, TEST_WITH_CROSSREF, IS_WINDOWS, IS_MACOS, @@ -94,8 +84,6 @@ def tearDown(self): from torch.utils.tensorboard._pytorch_graph import graph from google.protobuf import text_format from PIL import Image -if TEST_TENSORBOARD and TEST_CAFFE2: - from torch.utils.tensorboard import _caffe2_graph as c2_graph class TestTensorBoardPyTorchNumpy(BaseTestCase): def test_pytorch_np(self): @@ -754,80 +742,11 @@ def test_scalar(self): res = make_np(np.int64(100000000000)) self.assertIsInstance(res, np.ndarray) and self.assertEqual(res.shape, (1,)) - @skipIfNoCaffe2 - def test_caffe2_np(self): - workspace.FeedBlob("testBlob", tensor_N(shape=(1, 3, 64, 64))) - self.assertIsInstance(make_np('testBlob'), np.ndarray) - - @skipIfNoCaffe2 - def test_caffe2_np_expect_fail(self): - with self.assertRaises(RuntimeError): - res = make_np('This_blob_does_not_exist') - def test_pytorch_np_expect_fail(self): with self.assertRaises(NotImplementedError): res = make_np({'pytorch': 1.0}) - @skipIfNoCaffe2 - @unittest.skipIf(TEST_WITH_ASAN, "Caffe2 failure with ASAN") - def test_caffe2_simple_model(self): - model = ModelHelper(name="mnist") - # how come those inputs don't break the forward pass =.=a - workspace.FeedBlob("data", np.random.randn(1, 3, 64, 64).astype(np.float32)) - workspace.FeedBlob("label", np.random.randn(1, 1000).astype(int)) - - with core.NameScope("conv1"): - conv1 = brew.conv(model, "data", 'conv1', dim_in=1, dim_out=20, kernel=5) - # Image size: 24 x 24 -> 12 x 12 - pool1 = brew.max_pool(model, conv1, 'pool1', kernel=2, stride=2) - # Image size: 12 x 12 -> 8 x 8 - conv2 = brew.conv(model, pool1, 'conv2', dim_in=20, dim_out=100, kernel=5) - # Image size: 8 x 8 -> 4 x 4 - pool2 = brew.max_pool(model, conv2, 'pool2', kernel=2, stride=2) - with core.NameScope("classifier"): - # 50 * 4 * 4 stands for dim_out from previous layer multiplied by the image size - fc3 = brew.fc(model, pool2, 'fc3', dim_in=100 * 4 * 4, dim_out=500) - relu = brew.relu(model, fc3, fc3) - pred = brew.fc(model, relu, 'pred', 500, 10) - softmax = brew.softmax(model, pred, 'softmax') - xent = model.LabelCrossEntropy([softmax, "label"], 'xent') - # compute the expected loss - loss = model.AveragedLoss(xent, "loss") - model.net.RunAllOnMKL() - model.param_init_net.RunAllOnMKL() - model.AddGradientOperators([loss], skip=1) - blob_name_tracker = {} - graph = c2_graph.model_to_graph_def( - model, - blob_name_tracker=blob_name_tracker, - shapes={}, - show_simplified=False, - ) - compare_proto(graph, self) - - @skipIfNoCaffe2 - def test_caffe2_simple_cnnmodel(self): - model = cnn.CNNModelHelper("NCHW", name="overfeat") - workspace.FeedBlob("data", np.random.randn(1, 3, 64, 64).astype(np.float32)) - workspace.FeedBlob("label", np.random.randn(1, 1000).astype(int)) - with core.NameScope("conv1"): - conv1 = model.Conv("data", "conv1", 3, 96, 11, stride=4) - relu1 = model.Relu(conv1, conv1) - pool1 = model.MaxPool(relu1, "pool1", kernel=2, stride=2) - with core.NameScope("classifier"): - fc = model.FC(pool1, "fc", 4096, 1000) - pred = model.Softmax(fc, "pred") - xent = model.LabelCrossEntropy([pred, "label"], "xent") - loss = model.AveragedLoss(xent, "loss") - - blob_name_tracker = {} - graph = c2_graph.model_to_graph_def( - model, - blob_name_tracker=blob_name_tracker, - shapes={}, - show_simplified=False, - ) - compare_proto(graph, self) + class TestTensorProtoSummary(BaseTestCase): @parametrize( diff --git a/test/test_torch.py b/test/test_torch.py index f252ddf4a5745d..86844c77faf4aa 100644 --- a/test/test_torch.py +++ b/test/test_torch.py @@ -41,7 +41,7 @@ skipCUDAMemoryLeakCheckIf, BytesIOContext, skipIfRocm, skipIfNoSciPy, TemporaryFileName, TemporaryDirectoryName, wrapDeterministicFlagAPITest, DeterministicGuard, CudaSyncGuard, - skipIfNotRegistered, bytes_to_scalar, parametrize, skipIfMps, noncontiguous_like, + bytes_to_scalar, parametrize, skipIfMps, noncontiguous_like, AlwaysWarnTypedStorageRemoval, TEST_WITH_TORCHDYNAMO, xfailIfTorchDynamo) from multiprocessing.reduction import ForkingPickler from torch.testing._internal.common_device_type import ( @@ -8632,21 +8632,6 @@ def test_allow_tensor_metadata_change(self): a = torch.ones(2, 3) # Metadata changes are allowed on view tensors that are created from detach(). - @skipIfNotRegistered("LayerNorm", "Skipping as LayerNorm is not registered") - def test_c10_layer_norm(self): - # test that we can call c10 ops and they return a reasonable result - X = torch.rand(5, 5, dtype=torch.float) - weight = torch.rand(*X.size()[1:], dtype=torch.float) - bias = torch.rand(*X.size()[1:], dtype=torch.float) - epsilon = 1e-4 - - expected_norm = torch.nn.functional.layer_norm( - X, X.size()[1:], weight=weight, bias=bias, eps=epsilon) - actual_norm, actual_mean, actual_stdev = \ - torch.ops._caffe2.LayerNorm(torch.tensor(X), torch.tensor( - weight), torch.tensor(bias), 1, epsilon, True) - torch.testing.assert_close(expected_norm, actual_norm) - def test_memory_format(self): def test_helper(x, memory_format): y = x.contiguous(memory_format=memory_format) From c5e0b844847c5c34ee824b0de2adeda85ce64133 Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Tue, 18 Jun 2024 13:14:24 -0700 Subject: [PATCH 162/171] [dynamo][trace_rules] Remove incorrectly classified Ingraph functions (#128428) Co-authored-by: Laith Sakka Pull Request resolved: https://github.com/pytorch/pytorch/pull/128428 Approved by: https://github.com/yanboliang, https://github.com/mlazos --- test/dynamo/test_repros.py | 2 +- torch/_dynamo/trace_rules.py | 28 ---------------------------- 2 files changed, 1 insertion(+), 29 deletions(-) diff --git a/test/dynamo/test_repros.py b/test/dynamo/test_repros.py index dbcb259241fcbd..2329ab305e763c 100644 --- a/test/dynamo/test_repros.py +++ b/test/dynamo/test_repros.py @@ -1674,7 +1674,7 @@ def test_issue175(self): self.assertEqual(cnt.frame_count, 1) self.assertEqual( - 18 if torch._dynamo.config.inline_inbuilt_nn_modules else 12, cnt.op_count + 15 if torch._dynamo.config.inline_inbuilt_nn_modules else 12, cnt.op_count ) def test_exec_import(self): diff --git a/torch/_dynamo/trace_rules.py b/torch/_dynamo/trace_rules.py index b5b12435a931a7..abbef02e63c682 100644 --- a/torch/_dynamo/trace_rules.py +++ b/torch/_dynamo/trace_rules.py @@ -2669,26 +2669,6 @@ "torch.nn._reduction.legacy_get_enum", "torch.nn._reduction.legacy_get_string", "torch.nn.factory_kwargs", - "torch.nn.functional._adaptive_max_pool1d", - "torch.nn.functional._adaptive_max_pool2d", - "torch.nn.functional._adaptive_max_pool3d", - "torch.nn.functional._canonical_mask", - "torch.nn.functional._fractional_max_pool2d", - "torch.nn.functional._fractional_max_pool3d", - "torch.nn.functional._get_softmax_dim", - "torch.nn.functional._in_projection_packed", - "torch.nn.functional._in_projection", - "torch.nn.functional._is_integer", - "torch.nn.functional._max_pool1d", - "torch.nn.functional._max_pool2d", - "torch.nn.functional._max_pool3d", - "torch.nn.functional._mha_shape_check", - "torch.nn.functional._no_grad_embedding_renorm_", - "torch.nn.functional._none_or_dtype", - "torch.nn.functional._threshold", - "torch.nn.functional._unpool_output_size", - "torch.nn.functional._verify_batch_size", - "torch.nn.functional._verify_spatial_size", "torch.nn.functional.adaptive_avg_pool2d", "torch.nn.functional.adaptive_avg_pool3d", "torch.nn.functional.adaptive_max_pool1d_with_indices", @@ -2786,15 +2766,7 @@ "torch.nn.grad.conv2d_weight", "torch.nn.grad.conv3d_input", "torch.nn.grad.conv3d_weight", - "torch.nn.modules.activation._arg_requires_grad", - "torch.nn.modules.activation._check_arg_device", "torch.nn.modules.activation._is_make_fx_tracing", - "torch.nn.modules.container._addindent", - "torch.nn.modules.transformer._detect_is_causal_mask", - "torch.nn.modules.transformer._generate_square_subsequent_mask", - "torch.nn.modules.transformer._get_activation_fn", - "torch.nn.modules.transformer._get_clones", - "torch.nn.modules.transformer._get_seq_len", "torch.nn.modules.utils._list_with_default", "torch.nn.modules.utils._ntuple", "torch.nn.modules.utils._quadruple", From 670b94c9c826756495b9e1ca34be1d43756d5296 Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Tue, 18 Jun 2024 13:14:25 -0700 Subject: [PATCH 163/171] [inductor][mkldnn] Use floats instead of ints for pattern matcher test (#128484) Pull Request resolved: https://github.com/pytorch/pytorch/pull/128484 Approved by: https://github.com/mlazos ghstack dependencies: #128428 --- test/inductor/test_mkldnn_pattern_matcher.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/test/inductor/test_mkldnn_pattern_matcher.py b/test/inductor/test_mkldnn_pattern_matcher.py index 810c22d037c548..a80d7239876028 100644 --- a/test/inductor/test_mkldnn_pattern_matcher.py +++ b/test/inductor/test_mkldnn_pattern_matcher.py @@ -37,7 +37,8 @@ torch.nn.Tanh(): 2, torch.nn.Hardswish(): 6, torch.nn.LeakyReLU(0.1, inplace=False): 4, - torch.nn.Hardtanh(min_val=-0.5, max_val=4, inplace=False): 3, + # Use floats for min/max, otherwise they can get converted to symints + torch.nn.Hardtanh(min_val=-0.5, max_val=4.0, inplace=False): 3, torch.nn.Hardtanh(min_val=-0.5, max_val=float("inf"), inplace=False): 3, torch.nn.GELU(approximate="none"): 6, torch.nn.GELU(approximate="tanh"): 10, From 99f042d336b53844b509406f1ecf78cb6f5e5714 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Wed, 19 Jun 2024 00:21:21 +0000 Subject: [PATCH 164/171] Revert "Forward fix to skip ROCm tests for #122836 (#128891)" This reverts commit 4061b3b8225f522ae0ed6db00111441e7d3cc3d5. Reverted https://github.com/pytorch/pytorch/pull/128891 on behalf of https://github.com/jbschlosser due to reverting to revert parent PR ([comment](https://github.com/pytorch/pytorch/pull/128891#issuecomment-2177291249)) --- test/test_nestedtensor.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/test/test_nestedtensor.py b/test/test_nestedtensor.py index fa33a13ed495db..6b9b8f3be45d56 100644 --- a/test/test_nestedtensor.py +++ b/test/test_nestedtensor.py @@ -5470,7 +5470,6 @@ def test_jagged_padded_dense_conversion_kernels(self, device, dtype): ) @unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile") @skipCUDAIf(not SM70OrLater, "GPU capability is < SM70") - @skipCUDAIfRocm def test_compile_preserves_metadata_cache(self, device, dtype): # shape (B, *, D) nt = random_nt_from_dims( @@ -5501,7 +5500,6 @@ def f(nt): ) @unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile") @skipCUDAIf(not SM70OrLater, "GPU capability is < SM70") - @skipCUDAIfRocm def test_compile_with_dynamic_max_seq_len(self, device, dtype): # shape (B, *, D) # max seq len: 18 @@ -5538,7 +5536,6 @@ def f(nt): ) @unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile") @skipCUDAIf(not SM70OrLater, "GPU capability is < SM70") - @skipCUDAIfRocm def test_compile_with_dynamic_min_seq_len(self, device, dtype): # shape (B, *, D) # min seq len: 7 @@ -5575,7 +5572,6 @@ def f(nt): ) @unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile") @skipCUDAIf(not SM70OrLater, "GPU capability is < SM70") - @skipCUDAIfRocm def test_compile_with_propagated_dynamic_max_seq_len(self, device, dtype): # shape (B, *, D) # max seq len: 18 From 35c78668b408046e032a1e025b01250875959cc6 Mon Sep 17 00:00:00 2001 From: Jane Xu Date: Tue, 18 Jun 2024 13:37:50 -0700 Subject: [PATCH 165/171] Improve the debugging message for when foreach mta_called (#128991) The hope that lives in this PR: I am currently trying to debug why the foreach tests are so flaky. It looks like every flaky test falls under this pattern: - a test is flaky due to the mta_called assertion, which gathers data from the profiler regarding whether the multi_tensor_apply_kernel has been called. - then, a later test fails deterministically, usually failing to compare two results. ``` ================== 1 failed, 241 deselected, 2 rerun in 1.76s ================== Got exit code 1 Stopping at first consistent failure The following tests failed and then succeeded when run in a new process ['test/test_foreach.py::TestForeachCUDA::test_binary_op_float_inf_nan__foreach_add_cuda_bfloat16'] The following tests failed consistently: ['test/test_foreach.py::TestForeachCUDA::test_binary_op_list_error_cases__foreach_add_cuda_bfloat16'] ``` So my suspicion is that the first causes the second, but what causes the first? Idk! So it would be nice to have the error message tell us what the profiler actually saw in case it's getting muddled. This change would help mostly because I have not been able to repro this flakiness locally. Also undo the useless changes in #128220 which are actually redundant as Joel and I realized that we set the seed during the setUp of every test. Pull Request resolved: https://github.com/pytorch/pytorch/pull/128991 Approved by: https://github.com/clee2000 --- test/test_foreach.py | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/test/test_foreach.py b/test/test_foreach.py index 567d09cff02d78..99d4cbe5ec003a 100644 --- a/test/test_foreach.py +++ b/test/test_foreach.py @@ -90,7 +90,7 @@ def __call__(self, inputs, is_cuda, expect_fastpath, **kwargs): mta_called = any("multi_tensor_apply_kernel" in k for k in keys) assert mta_called == ( expect_fastpath and (not zero_size) - ), f"{mta_called=}, {expect_fastpath=}, {zero_size=}" + ), f"{mta_called=}, {expect_fastpath=}, {zero_size=}, {self.func.__name__=}, {keys=}" else: actual = self.func(*inputs, **kwargs) if self.is_inplace: @@ -205,7 +205,6 @@ def test_all_zero_size_tensors_do_not_launch_kernel(self, device, dtype, op): "failing flakily on non sm86 cuda jobs", ) def test_parity(self, device, dtype, op, noncontiguous, inplace): - torch.manual_seed(2024) if inplace: _, _, func, ref = self._get_funcs(op) else: @@ -585,7 +584,6 @@ def test_binary_op_scalar_with_different_tensor_dtypes(self, device, dtype, op): "failing flakily on non sm86 cuda jobs, ex https://github.com/pytorch/pytorch/issues/125035", ) def test_binary_op_list_error_cases(self, device, dtype, op): - torch.manual_seed(202406) foreach_op, foreach_op_, ref, ref_ = ( op.method_variant, op.inplace_variant, @@ -680,7 +678,6 @@ def test_binary_op_list_error_cases(self, device, dtype, op): "failing flakily on non sm86 cuda jobs, ex https://github.com/pytorch/pytorch/issues/125775", ) def test_binary_op_list_slow_path(self, device, dtype, op): - torch.manual_seed(20240607) foreach_op, native_op, foreach_op_, native_op_ = self._get_funcs(op) # 0-strides tensor1 = make_tensor((10, 10), dtype=dtype, device=device) @@ -799,7 +796,6 @@ def test_binary_op_list_slow_path(self, device, dtype, op): "failing flakily on non sm86 cuda jobs", ) def test_binary_op_float_inf_nan(self, device, dtype, op): - torch.manual_seed(2024) inputs = ( [ torch.tensor([float("inf")], device=device, dtype=dtype), @@ -869,9 +865,6 @@ def test_unary_op_tensors_on_different_devices(self, device, dtype, op): "failing flakily on non sm86 cuda jobs", ) def test_binary_op_tensors_on_different_devices(self, device, dtype, op): - torch.manual_seed(202406) - # `tensors1`: ['cuda', 'cpu'] - # `tensors2`: ['cuda', 'cpu'] _cuda_tensors = next( iter(op.sample_inputs(device, dtype, num_input_tensors=[2], same_size=True)) ).input From 5ffb032be682a34b959c82ce289b457ea6c6e504 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Wed, 19 Jun 2024 00:26:38 +0000 Subject: [PATCH 166/171] Revert "Backward support for unbind() with NJT (#128032)" This reverts commit 5dc4f652bc5c068ef15130c955e3f2ffe11f4b74. Reverted https://github.com/pytorch/pytorch/pull/128032 on behalf of https://github.com/jbschlosser due to reverting to revert parent PR ([comment](https://github.com/pytorch/pytorch/pull/128032#issuecomment-2177296325)) --- test/test_nestedtensor.py | 19 ------------------- tools/autograd/derivatives.yaml | 2 +- torch/csrc/autograd/FunctionsManual.cpp | 17 ----------------- torch/csrc/autograd/FunctionsManual.h | 4 ---- torch/nested/_internal/ops.py | 11 ----------- 5 files changed, 1 insertion(+), 52 deletions(-) diff --git a/test/test_nestedtensor.py b/test/test_nestedtensor.py index 6b9b8f3be45d56..86f58b5a0de3a0 100644 --- a/test/test_nestedtensor.py +++ b/test/test_nestedtensor.py @@ -5606,25 +5606,6 @@ def f(nt): for dynamic in [False, True, None]: self.assertFalse(_recompiles_for_inputs(f, (nt,), (nt2,), dynamic=dynamic)) - @dtypes(torch.float32, torch.double, torch.half) - def test_unbind_backward(self, device, dtype): - nt = torch.nested.nested_tensor( - [ - torch.randn(2, 4, device=device), - torch.randn(5, 4, device=device), - torch.randn(3, 4, device=device), - ], - layout=torch.jagged, - requires_grad=True, - ) - - a, b, c = nt.unbind() - b.sum().backward() - - expected_grad = torch.zeros_like(nt) - expected_grad.unbind()[1].add_(1.0) - torch._dynamo.disable(self.assertEqual)(nt.grad, expected_grad) - instantiate_parametrized_tests(TestNestedTensor) instantiate_device_type_tests(TestNestedTensorDeviceType, globals()) diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml index 02a3e6c518ad80..76a7a0a1e42a4f 100644 --- a/tools/autograd/derivatives.yaml +++ b/tools/autograd/derivatives.yaml @@ -2847,7 +2847,7 @@ self: unbind_backward(grads, dim) result: auto_linear AutogradNestedTensor: - self: "self.layout() == c10::kJagged ? unbind_backward_nested_jagged(grads, self, dim) : unbind_backward_nested(grads, at::native::get_nested_tensor_impl(self)->get_nested_sizes(), dim, self.options())" + self: unbind_backward_nested(grads, at::native::get_nested_tensor_impl(self)->get_nested_sizes(), dim, self.options()) result: auto_linear - name: stack(Tensor[] tensors, int dim=0) -> Tensor diff --git a/torch/csrc/autograd/FunctionsManual.cpp b/torch/csrc/autograd/FunctionsManual.cpp index f51c2f047f9351..9d897c667c906f 100644 --- a/torch/csrc/autograd/FunctionsManual.cpp +++ b/torch/csrc/autograd/FunctionsManual.cpp @@ -1014,23 +1014,6 @@ Tensor unbind_backward_nested( return at::_nested_tensor_from_tensor_list(grads_tensors); } -Tensor unbind_backward_nested_jagged( - const variable_list& grads, - const Tensor& self, - int64_t dim) { - TORCH_INTERNAL_ASSERT( - dim == 0, "unbind_backward_nested_jagged() only supports dim=0") - auto grad_nt = at::zeros_like(self); - auto unbound_grads = grad_nt.unbind(); - for (int64_t i : c10::irange(static_cast(grads.size()))) { - if (grads[i].defined()) { - unbound_grads[i].copy_(static_cast(grads[i])); - } - } - - return grad_nt; -} - Tensor unsqueeze_to(const Tensor& self, c10::SymIntArrayRef sym_sizes) { auto result = self; diff --git a/torch/csrc/autograd/FunctionsManual.h b/torch/csrc/autograd/FunctionsManual.h index ecf99bd098057b..dedff70be1ba34 100644 --- a/torch/csrc/autograd/FunctionsManual.h +++ b/torch/csrc/autograd/FunctionsManual.h @@ -244,10 +244,6 @@ at::Tensor unbind_backward_nested( const Tensor& nt_sizes, int64_t dim, const at::TensorOptions& options); -at::Tensor unbind_backward_nested_jagged( - const variable_list& grads, - const Tensor& self, - int64_t dim); at::Tensor unsqueeze_to(const at::Tensor& self, c10::SymIntArrayRef sym_sizes); at::Tensor unsqueeze_to( const at::Tensor& self, diff --git a/torch/nested/_internal/ops.py b/torch/nested/_internal/ops.py index 8458f03717130c..6f1c47dd694712 100644 --- a/torch/nested/_internal/ops.py +++ b/torch/nested/_internal/ops.py @@ -472,17 +472,6 @@ def to_copy_default(func, *args, **kwargs): )(jagged_unary_pointwise) -@register_jagged_func(torch.ops.aten.zero_.default, "self: jt_all") -def zero__default(func, *args, **kwargs): - _, new_kwargs = normalize_function( - func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True - ) - - inp = new_kwargs.pop("input") - func(inp._values) - return inp - - @register_jagged_func( torch.ops.aten._softmax.default, "self: jt, dim: any, half_to_float: any" ) From b0d2fe6299c4462d28b23ef73d872eb608d73d96 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Wed, 19 Jun 2024 00:28:53 +0000 Subject: [PATCH 167/171] Revert "Short-term fix to preserve NJT metadata cache in torch.compile (#122836)" This reverts commit 2a41fc03903de63270d325bd1886a50faf32d7e4. Reverted https://github.com/pytorch/pytorch/pull/122836 on behalf of https://github.com/jbschlosser due to internal test failures with DEBUG=1 asserts ([comment](https://github.com/pytorch/pytorch/pull/122836#issuecomment-2177298245)) --- aten/src/ATen/FunctionalInverses.cpp | 9 +- aten/src/ATen/native/native_functions.yaml | 14 +- test/dynamo/test_subclasses.py | 6 +- ...asDecompTest.test_has_decomposition.expect | 2 - test/test_nestedtensor.py | 173 +---------------- tools/autograd/derivatives.yaml | 4 +- torch/nested/_internal/nested_tensor.py | 174 ++++-------------- torch/nested/_internal/ops.py | 37 +--- torch/nested/_internal/sdpa.py | 62 ++----- 9 files changed, 69 insertions(+), 412 deletions(-) diff --git a/aten/src/ATen/FunctionalInverses.cpp b/aten/src/ATen/FunctionalInverses.cpp index a1cf449cde7c7f..16b59333f918fb 100644 --- a/aten/src/ATen/FunctionalInverses.cpp +++ b/aten/src/ATen/FunctionalInverses.cpp @@ -303,7 +303,7 @@ Tensor FunctionalInverses::_nested_view_from_buffer_inverse(const Tensor& base, return Tensor(); } -Tensor FunctionalInverses::_nested_view_from_jagged_inverse(const Tensor& base, const Tensor& mutated_view, InverseReturnMode inverse_return_mode, const Tensor& offsets, const Tensor& dummy, const std::optional& lengths, int64_t ragged_idx, const c10::optional& min_seqlen, const c10::optional& max_seqlen) { +Tensor FunctionalInverses::_nested_view_from_jagged_inverse(const Tensor& base, const Tensor& mutated_view, InverseReturnMode inverse_return_mode, const Tensor& offsets, const Tensor& dummy, const std::optional& lengths, int64_t ragged_idx) { auto values = at::_nested_get_values(mutated_view); if (inverse_return_mode != InverseReturnMode::NeverView) { return values; @@ -317,12 +317,7 @@ Tensor FunctionalInverses::_nested_get_values_inverse(const Tensor& base, const auto lengths = at::_nested_get_lengths(base); auto ragged_idx = at::_nested_get_ragged_idx(base); auto dummy = at::_nested_get_jagged_dummy(base); - auto min_seqlen = at::_nested_get_min_seqlen(base); - auto max_seqlen = at::_nested_get_max_seqlen(base); - auto nt = at::_nested_view_from_jagged( - mutated_view, offsets, dummy, lengths, ragged_idx, - (min_seqlen.defined() ? c10::optional(min_seqlen) : c10::nullopt), - (max_seqlen.defined() ? c10::optional(max_seqlen) : c10::nullopt)); + auto nt = at::_nested_view_from_jagged(mutated_view, offsets, dummy, lengths, ragged_idx); if (inverse_return_mode != InverseReturnMode::NeverView) { return nt; diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index b030141882c86e..a2d9095d56a380 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -6185,12 +6185,12 @@ CompositeExplicitAutogradNonFunctional: _nested_view_from_buffer_copy autogen: _nested_view_from_buffer_copy.out -- func: _nested_view_from_jagged(Tensor(a) self, Tensor offsets, Tensor dummy, Tensor? lengths=None, int ragged_idx=1, Tensor? min_seqlen=None, Tensor? max_seqlen=None) -> Tensor(a) +- func: _nested_view_from_jagged(Tensor(a) self, Tensor offsets, Tensor dummy, Tensor? lengths=None, int ragged_idx=1) -> Tensor(a) variants: function device_check: NoCheck dispatch: {} -- func: _nested_view_from_jagged_copy(Tensor self, Tensor offsets, Tensor dummy, Tensor? lengths=None, int ragged_idx=1, Tensor? min_seqlen=None, Tensor? max_seqlen=None) -> Tensor +- func: _nested_view_from_jagged_copy(Tensor self, Tensor offsets, Tensor dummy, Tensor? lengths=None, int ragged_idx=1) -> Tensor variants: function device_check: NoCheck tags: view_copy @@ -6227,16 +6227,6 @@ device_check: NoCheck dispatch: {} -- func: _nested_get_min_seqlen(Tensor self) -> Tensor - variants: function - device_check: NoCheck - dispatch: {} - -- func: _nested_get_max_seqlen(Tensor self) -> Tensor - variants: function - device_check: NoCheck - dispatch: {} - - func: _nested_get_jagged_dummy(Tensor any) -> Tensor category_override: dummy dispatch: {} diff --git a/test/dynamo/test_subclasses.py b/test/dynamo/test_subclasses.py index f16ef15990fd8c..302b07e4ddb78b 100644 --- a/test/dynamo/test_subclasses.py +++ b/test/dynamo/test_subclasses.py @@ -1616,15 +1616,15 @@ def backend(gm, args): guard_str, """\ Eq(s3 - 1, s0) -Eq(zf1, zf6)""", +Eq(zf1, zf4)""", ) else: self.assertExpectedInline( guard_str, """\ Eq(s4 - 1, s1) -Eq(s12 - 1, s7) -Eq(s11, s9)""", +Eq(s10 - 1, s5) +Eq(s9, s7)""", ) return gm diff --git a/test/expect/HasDecompTest.test_has_decomposition.expect b/test/expect/HasDecompTest.test_has_decomposition.expect index 132d25a8b12f98..1179142e15d9e7 100644 --- a/test/expect/HasDecompTest.test_has_decomposition.expect +++ b/test/expect/HasDecompTest.test_has_decomposition.expect @@ -446,8 +446,6 @@ aten::_nested_from_padded_and_nested_example aten::_nested_from_padded_and_nested_example.out aten::_nested_get_jagged_dummy aten::_nested_get_lengths -aten::_nested_get_max_seqlen -aten::_nested_get_min_seqlen aten::_nested_get_offsets aten::_nested_get_ragged_idx aten::_nested_get_values diff --git a/test/test_nestedtensor.py b/test/test_nestedtensor.py index 86f58b5a0de3a0..78d082702aecb0 100644 --- a/test/test_nestedtensor.py +++ b/test/test_nestedtensor.py @@ -67,21 +67,6 @@ def _iter_constructors(): yield torch.nested.nested_tensor -# Returns True if the function recompiles between inputs1 and inputs2 with the -# specified dynamic setting. -def _recompiles_for_inputs(fn, inputs1, inputs2, dynamic=True): - compile_count = [0] - - def counter(gm, example_inputs): - compile_count[0] += 1 - return gm - - compiled_f = torch.compile(fn, fullgraph=True, backend=counter, dynamic=dynamic) - compiled_f(*inputs1) - compiled_f(*inputs2) - return compile_count[0] > 1 - - # Helper function to generate a pair of random nested tensors # one is contiguous, the other is not, but they appear to have same entries # an output nested tensor consists of @@ -4833,18 +4818,19 @@ def fn(values, same_size): check_results(fn, compiled_fn, generate_inp(20)) self.assertEqual(compile_counter.frame_count, frame_count_2) + # Doesn't work until we have real views + @xfailIfTorchDynamo # Note 1: Math fallback doesn't work with bfloat16 on CUDA # Note 2: ROCm doesn't support flash attention or mem_efficient attention for NT @unittest.skipIf( TEST_WITH_ROCM, "ROCm doesn't support flash attention or mem_efficient attention for NT", ) - @dtypes( - *( - [torch.float16, torch.bfloat16, torch.float32] - if SM80OrLater - else [torch.float16, torch.float32] - ) + @parametrize( + "dtype", + [torch.float16, torch.bfloat16, torch.float32] + if SM80OrLater + else [torch.float16, torch.float32], ) def test_sdpa(self, device, dtype): batch_size = 1 @@ -5187,6 +5173,8 @@ def test_sdpa_with_constant_sequence_length(self, device, dtype): ) self.assertEqual(output._values, output_dense) + # Doesn't work until we have real views + @xfailIfTorchDynamo @onlyCUDA @unittest.skipIf( not PLATFORM_SUPPORTS_FUSED_ATTENTION, @@ -5463,149 +5451,6 @@ def test_jagged_padded_dense_conversion_kernels(self, device, dtype): padded, [offsets_wrong], total_L ) - @dtypes(torch.float32) - @skipIfTorchDynamo("Test compiles internally") - @unittest.skipIf( - sys.version_info >= (3, 12), "torch.compile is not supported on python 3.12+" - ) - @unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile") - @skipCUDAIf(not SM70OrLater, "GPU capability is < SM70") - def test_compile_preserves_metadata_cache(self, device, dtype): - # shape (B, *, D) - nt = random_nt_from_dims( - [4, None, 3, 16], - device=device, - dtype=dtype, - layout=torch.jagged, - requires_grad=True, - ) - - # expect min / max seqlen to be stored here - cache = dict(nt._metadata_cache) - - @torch.compile - def f(nt): - q = nt.transpose(-3, -2) - output = F.scaled_dot_product_attention(q, q, q).transpose(-3, -2) - return output - - output = f(nt) - output.backward(torch.ones_like(output)) - self.assertEqual(output._metadata_cache, cache) - - @dtypes(torch.float32) - @skipIfTorchDynamo("Test compiles internally") - @unittest.skipIf( - sys.version_info >= (3, 12), "torch.compile is not supported on python 3.12+" - ) - @unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile") - @skipCUDAIf(not SM70OrLater, "GPU capability is < SM70") - def test_compile_with_dynamic_max_seq_len(self, device, dtype): - # shape (B, *, D) - # max seq len: 18 - nt = torch.nested.nested_tensor( - [ - torch.randn(2, 5), - torch.randn(3, 5), - torch.randn(18, 5), - ], - layout=torch.jagged, - ) - - # max seq len: 19 - nt2 = torch.nested.nested_tensor( - [ - torch.randn(2, 5), - torch.randn(3, 5), - torch.randn(19, 5), - ], - layout=torch.jagged, - ) - - def f(nt): - # TODO: Replace with public API when we can use @properties - return torch.ones_like(nt) * nt._get_max_seqlen() - - for dynamic in [False, True, None]: - self.assertFalse(_recompiles_for_inputs(f, (nt,), (nt2,), dynamic=dynamic)) - - @dtypes(torch.float32) - @skipIfTorchDynamo("Test compiles internally") - @unittest.skipIf( - sys.version_info >= (3, 12), "torch.compile is not supported on python 3.12+" - ) - @unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile") - @skipCUDAIf(not SM70OrLater, "GPU capability is < SM70") - def test_compile_with_dynamic_min_seq_len(self, device, dtype): - # shape (B, *, D) - # min seq len: 7 - nt = torch.nested.nested_tensor( - [ - torch.randn(7, 5), - torch.randn(8, 5), - torch.randn(9, 5), - ], - layout=torch.jagged, - ) - - # min seq len: 8 - nt2 = torch.nested.nested_tensor( - [ - torch.randn(8, 5), - torch.randn(9, 5), - torch.randn(10, 5), - ], - layout=torch.jagged, - ) - - def f(nt): - # TODO: Replace with public API when we can use @properties - return torch.ones_like(nt) * nt._get_min_seqlen() - - for dynamic in [False, True, None]: - self.assertFalse(_recompiles_for_inputs(f, (nt,), (nt2,), dynamic=dynamic)) - - @dtypes(torch.float32) - @skipIfTorchDynamo("Test compiles internally") - @unittest.skipIf( - sys.version_info >= (3, 12), "torch.compile is not supported on python 3.12+" - ) - @unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile") - @skipCUDAIf(not SM70OrLater, "GPU capability is < SM70") - def test_compile_with_propagated_dynamic_max_seq_len(self, device, dtype): - # shape (B, *, D) - # max seq len: 18 - nt = torch.nested.nested_tensor( - [ - torch.randn(2, 5), - torch.randn(3, 5), - torch.randn(18, 5), - ], - layout=torch.jagged, - ) - - # max seq len: 19 - nt2 = torch.nested.nested_tensor( - [ - torch.randn(2, 5), - torch.randn(3, 5), - torch.randn(19, 5), - ], - layout=torch.jagged, - ) - - def f(nt): - nt2 = nt.sin() + 1 - # TODO: Replace with public API when we can use @properties - return torch.ones_like(nt2) * nt2._get_max_seqlen() - - ref = f(nt) - output = torch.compile(f, fullgraph=True, dynamic=False)(nt) - self.assertEqual(ref, output) - - for dynamic in [False, True, None]: - self.assertFalse(_recompiles_for_inputs(f, (nt,), (nt2,), dynamic=dynamic)) - instantiate_parametrized_tests(TestNestedTensor) instantiate_device_type_tests(TestNestedTensorDeviceType, globals()) diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml index 76a7a0a1e42a4f..1e9b9091a20e94 100644 --- a/tools/autograd/derivatives.yaml +++ b/tools/autograd/derivatives.yaml @@ -2794,14 +2794,14 @@ nested_size: non_differentiable nested_strides: non_differentiable -- name: _nested_view_from_jagged(Tensor(a) self, Tensor offsets, Tensor dummy, Tensor? lengths=None, int ragged_idx=1, Tensor? min_seqlen=None, Tensor? max_seqlen=None) -> Tensor(a) +- name: _nested_view_from_jagged(Tensor(a) self, Tensor offsets, Tensor dummy, Tensor? lengths=None, int ragged_idx=1) -> Tensor(a) self: grad.values() offsets: non_differentiable lengths: non_differentiable dummy: non_differentiable - name: _nested_get_values(Tensor(a) self) -> Tensor(a) - self: "_nested_view_from_jagged(grad, at::_nested_get_offsets(self), at::_nested_get_jagged_dummy(self), at::_nested_get_lengths(self), at::_nested_get_ragged_idx(self), at::_nested_get_min_seqlen(self).defined() ? c10::optional(at::_nested_get_min_seqlen(self)) : c10::nullopt, at::_nested_get_max_seqlen(self).defined() ? c10::optional(at::_nested_get_max_seqlen(self)) : c10::nullopt)" + self: _nested_view_from_jagged(grad, at::_nested_get_offsets(self), at::_nested_get_jagged_dummy(self), at::_nested_get_lengths(self), at::_nested_get_ragged_idx(self)) # Transformers - name: _scaled_dot_product_efficient_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_bias, bool compute_log_sumexp, float dropout_p=0.0, bool is_causal=False, *, float? scale=None) -> (Tensor output, Tensor log_sumexp, Tensor philox_seed, Tensor philox_offset) diff --git a/torch/nested/_internal/nested_tensor.py b/torch/nested/_internal/nested_tensor.py index 92423cf32b2fe8..66d25eacc7ad4b 100644 --- a/torch/nested/_internal/nested_tensor.py +++ b/torch/nested/_internal/nested_tensor.py @@ -27,15 +27,6 @@ def _get_sdpa_extreme_seqlen(func, tensor): return int(func(tensor).item()) -def _store_val_in_tensor(val) -> torch.Tensor: - # hack to get dynamic shapes support: store in a (val, 0) shaped tensor - return torch.zeros(val, 0) - - -def _load_val_from_tensor(t: torch.Tensor): - return t.shape[0] - - class NestedTensor(torch.Tensor): _values: torch.Tensor # type: ignore[assignment] _offsets: torch.Tensor @@ -131,14 +122,6 @@ def __init__(self, values, offsets, *, lengths=None, **kwargs): torch._dynamo.maybe_mark_dynamic(self, self._ragged_idx) torch._dynamo.maybe_mark_dynamic(self._values, self._ragged_idx - 1) - # min / max sequence length should be dynamic if present - max_seqlen_tensor = self._metadata_cache.get("max_seqlen", None) - if max_seqlen_tensor is not None: - torch._dynamo.mark_dynamic(max_seqlen_tensor, 0) - min_seqlen_tensor = self._metadata_cache.get("min_seqlen", None) - if min_seqlen_tensor is not None: - torch._dynamo.mark_dynamic(min_seqlen_tensor, 0) - def values(self): # dispatch to get proper view relationship return torch._nested_get_values(self) # type: ignore[attr-defined] @@ -149,56 +132,25 @@ def offsets(self): def lengths(self): return self._lengths - # Private accessor functions for min / max sequence length. They're - # purposefully not @properties because those don't work with PT2 (yet). - # These compute / cache if not present. - # TODO: Revisit this when @properties are better supported by PT2. I think the ideal - # state would be to have public @properties for min / max sequence length that compile - # (including setters). - def _get_max_seqlen(self): - max_seqlen_tensor = self._max_seqlen_tensor - if max_seqlen_tensor is None: + @property + def _max_seqlen(self): + if "max_seqlen" not in self._metadata_cache: # compute & cache - max_val = _get_sdpa_extreme_seqlen( + self._metadata_cache["max_seqlen"] = _get_sdpa_extreme_seqlen( torch.max, self._offsets.diff() if self._lengths is None else self._lengths, ) - max_seqlen_tensor = _store_val_in_tensor(max_val) - self._metadata_cache["max_seqlen"] = max_seqlen_tensor - return _load_val_from_tensor(max_seqlen_tensor) + return self._metadata_cache["max_seqlen"] - def _get_min_seqlen(self): - min_seqlen_tensor = self._min_seqlen_tensor - if min_seqlen_tensor is None: + @property + def _min_seqlen(self): + if "min_seqlen" not in self._metadata_cache: # compute & cache - min_val = _get_sdpa_extreme_seqlen( + self._metadata_cache["min_seqlen"] = _get_sdpa_extreme_seqlen( torch.min, self._offsets.diff() if self._lengths is None else self._lengths, ) - min_seqlen_tensor = _store_val_in_tensor(min_val) - self._metadata_cache["min_seqlen"] = min_seqlen_tensor - return _load_val_from_tensor(min_seqlen_tensor) - - # Private accessors used for treating min / max seqlen as inner tensors for - # flatten / unflatten. These must be properties to work with the traceable wrapper - # subclass logic. These do not compute / cache if not present. - @property - def _max_seqlen_tensor(self) -> Optional[torch.Tensor]: - return self._metadata_cache.get("max_seqlen", None) - - @property - def _min_seqlen_tensor(self) -> Optional[torch.Tensor]: - return self._metadata_cache.get("min_seqlen", None) - - # These are old private @property accessors that are kept around for internal BC - # reasons. TODO: Remove these! - @property - def _max_seqlen(self): - return self._get_max_seqlen() - - @property - def _min_seqlen(self): - return self._get_min_seqlen() + return self._metadata_cache["min_seqlen"] def __repr__(self): # We should implement this in torch/_tensor_str.py instead @@ -218,7 +170,6 @@ def __reduce_ex__(self, proto): del state["_size"] del state["_strides"] - # TODO: Update this to handle the other inner tensors func = NestedTensor args = (self._values, self._offsets) return (torch._tensor._rebuild_from_type_v2, (func, type(self), args, state)) @@ -226,33 +177,22 @@ def __reduce_ex__(self, proto): def __tensor_flatten__(self): ctx = { "requires_grad": self.requires_grad, + # TODO: Don't guard on this! + "metadata_cache": self._metadata_cache, "ragged_idx": self._ragged_idx, } inner_tensors = ["_values", "_offsets"] if self._lengths is not None: inner_tensors.append("_lengths") - if self._min_seqlen_tensor is not None: - inner_tensors.append("_min_seqlen_tensor") - if self._max_seqlen_tensor is not None: - inner_tensors.append("_max_seqlen_tensor") return inner_tensors, ctx @staticmethod def __tensor_unflatten__(inner_tensors: Dict, meta, outer_size, outer_stride): - # inner tensors: _values, _offsets, [_lengths], [_min_seqlen], [_max_seqlen] - assert len(inner_tensors) >= 2 and len(inner_tensors) <= 5 + # inner tensors: _values, _offsets, [_lengths] + assert len(inner_tensors) >= 2 and len(inner_tensors) <= 3 values = inner_tensors["_values"] offsets = inner_tensors["_offsets"] lengths = inner_tensors.get("_lengths", None) - min_seqlen_tensor = inner_tensors.get("_min_seqlen_tensor", None) - max_seqlen_tensor = inner_tensors.get("_max_seqlen_tensor", None) - - metadata_cache = {} - if min_seqlen_tensor is not None: - metadata_cache["min_seqlen"] = min_seqlen_tensor - if max_seqlen_tensor is not None: - metadata_cache["max_seqlen"] = max_seqlen_tensor - ragged_idx = meta["ragged_idx"] # Note that we cannot simply check if is_fake(values) because @@ -271,7 +211,7 @@ def __tensor_unflatten__(inner_tensors: Dict, meta, outer_size, outer_stride): lengths=lengths, requires_grad=meta["requires_grad"], _ragged_idx=ragged_idx, - _metadata_cache=metadata_cache, + _metadata_cache=meta["metadata_cache"], ) @classmethod @@ -336,15 +276,6 @@ def forward( offsets: torch.Tensor, metadata_cache: Optional[Dict[str, Any]] = None, ): # type: ignore[override] - # maintain BC with this usages of this where the seqlens are stuffed - # directly into the metadata cache as non-Tensors / ints - if metadata_cache is not None: - min_seqlen = metadata_cache.get("min_seqlen", None) - max_seqlen = metadata_cache.get("max_seqlen", None) - if min_seqlen is not None and not isinstance(min_seqlen, torch.Tensor): - metadata_cache["min_seqlen"] = _store_val_in_tensor(min_seqlen) - if max_seqlen is not None and not isinstance(max_seqlen, torch.Tensor): - metadata_cache["max_seqlen"] = _store_val_in_tensor(max_seqlen) return NestedTensor( values.detach(), offsets=offsets, @@ -412,12 +343,12 @@ def jagged_from_list( ] ) - # compute this now since it's easy - min_seqlen = min([t.shape[0] for t in tensors]) - max_seqlen = max([t.shape[0] for t in tensors]) - ret_nt = nested_view_from_values_offsets( - values, offsets, min_seqlen=min_seqlen, max_seqlen=max_seqlen - ) + ret_nt = nested_view_from_values_offsets(values, offsets) + ret_nt._metadata_cache = { + # compute this now since it's easy + "max_seqlen": max(t.shape[0] for t in tensors), + "min_seqlen": min(t.shape[0] for t in tensors), + } return (ret_nt, offsets) # type: ignore[return-value] @@ -474,19 +405,16 @@ def jagged_from_tensor_and_lengths( if is_contiguous: ret_nt = nested_view_from_values_offsets( - values[offsets[0] : offsets[-1]], - offsets - offsets[0], - min_seqlen=min_seqlen, - max_seqlen=actual_max_seqlen, + values[offsets[0] : offsets[-1]], offsets - offsets[0] ) else: - ret_nt = nested_view_from_values_offsets_lengths( - values, - offsets, - length_list, - min_seqlen=min_seqlen, - max_seqlen=actual_max_seqlen, - ) + ret_nt = nested_view_from_values_offsets_lengths(values, offsets, length_list) + + # populate metadata cache with computed seqlen extremes + ret_nt._metadata_cache = { + "max_seqlen": actual_max_seqlen, + "min_seqlen": min_seqlen, + } return (ret_nt, offsets, None if is_contiguous else length_list) @@ -508,45 +436,13 @@ def _nt_view_dummy() -> torch.Tensor: return _dummy_instance -def nested_view_from_values_offsets( - values, offsets, ragged_idx=1, min_seqlen=None, max_seqlen=None -): - min_seqlen_tensor = None - if min_seqlen is not None: - min_seqlen_tensor = _store_val_in_tensor(min_seqlen) - - max_seqlen_tensor = None - if max_seqlen is not None: - max_seqlen_tensor = _store_val_in_tensor(max_seqlen) - +def nested_view_from_values_offsets(values, offsets, ragged_idx=1): return torch._nested_view_from_jagged( # type: ignore[attr-defined] - values, - offsets, - _nt_view_dummy(), - None, - ragged_idx, - min_seqlen_tensor, - max_seqlen_tensor, - ) # type: ignore[return-value] - - -def nested_view_from_values_offsets_lengths( - values, offsets, lengths, ragged_idx=1, min_seqlen=None, max_seqlen=None -): - min_seqlen_tensor = None - if min_seqlen is not None: - min_seqlen_tensor = _store_val_in_tensor(min_seqlen) + values, offsets, _nt_view_dummy(), None, ragged_idx + ) - max_seqlen_tensor = None - if max_seqlen is not None: - max_seqlen_tensor = _store_val_in_tensor(max_seqlen) +def nested_view_from_values_offsets_lengths(values, offsets, lengths, ragged_idx=1): return torch._nested_view_from_jagged( # type: ignore[attr-defined] - values, - offsets, - _nt_view_dummy(), - lengths, - ragged_idx, - min_seqlen_tensor, - max_seqlen_tensor, - ) # type: ignore[return-value] + values, offsets, _nt_view_dummy(), lengths, ragged_idx + ) diff --git a/torch/nested/_internal/ops.py b/torch/nested/_internal/ops.py index 6f1c47dd694712..6ec3ba538f9772 100644 --- a/torch/nested/_internal/ops.py +++ b/torch/nested/_internal/ops.py @@ -1088,7 +1088,7 @@ def values_default(func, *args, **kwargs): @register_jagged_func( torch.ops.aten._nested_view_from_jagged.default, - "values: t, offsets: t, dummy: jt_all, lengths: t?, ragged_idx: any?, min_seqlen: t?, max_seqlen: t?", + "values: t, offsets: t, dummy: jt_all, lengths: t?, ragged_idx: any?", ) def _nested_view_from_jagged_default(func, *args, **kwargs): _, new_kwargs = normalize_function( @@ -1101,21 +1101,8 @@ def _nested_view_from_jagged_default(func, *args, **kwargs): new_kwargs["lengths"], ) ragged_idx = new_kwargs["ragged_idx"] - min_seqlen = new_kwargs["min_seqlen"] - max_seqlen = new_kwargs["max_seqlen"] - metadata_cache = {} - if min_seqlen is not None: - metadata_cache["min_seqlen"] = min_seqlen - if max_seqlen is not None: - metadata_cache["max_seqlen"] = max_seqlen - return NestedTensor( - values, - offsets, - lengths=lengths, - _ragged_idx=ragged_idx, - _metadata_cache=metadata_cache, - ) + return NestedTensor(values, offsets, lengths=lengths, _ragged_idx=ragged_idx) @register_jagged_func(torch.ops.aten._nested_get_offsets.default, "self: jt_all") @@ -1148,26 +1135,6 @@ def _nested_get_ragged_idx(func, *args, **kwargs): return inp._ragged_idx -@register_jagged_func(torch.ops.aten._nested_get_min_seqlen.default, "self: jt_all") -def _nested_get_min_seqlen(func, *args, **kwargs): - _, new_kwargs = normalize_function( - func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True - ) - - inp = new_kwargs.pop("input") - return inp._metadata_cache.get("min_seqlen", None) - - -@register_jagged_func(torch.ops.aten._nested_get_max_seqlen.default, "self: jt_all") -def _nested_get_max_seqlen(func, *args, **kwargs): - _, new_kwargs = normalize_function( - func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True - ) - - inp = new_kwargs.pop("input") - return inp._metadata_cache.get("max_seqlen", None) - - # Make the dummy available on the C++ side. @register_jagged_func(torch.ops.aten._nested_get_jagged_dummy.default, "self: any") def _nested_get_jagged_dummy(func, *args, **kwargs): diff --git a/torch/nested/_internal/sdpa.py b/torch/nested/_internal/sdpa.py index 8f2eba4db3e463..b7c69c905e9a86 100644 --- a/torch/nested/_internal/sdpa.py +++ b/torch/nested/_internal/sdpa.py @@ -15,7 +15,7 @@ ) from torch.nn.attention import SDPBackend -from .nested_tensor import NestedTensor +from .nested_tensor import buffer_from_jagged, NestedTensor, ViewNestedFromBuffer log = logging.getLogger(__name__) @@ -125,7 +125,7 @@ def _check_for_seq_len_0_and_consistent_head_dim_nested_helper( return False # This is being called inside sdp with shape [batch, heads, {seq_len}, dim] - if param._get_min_seqlen() == 0: + if param._min_seqlen == 0: if debug: log.warning( "Fused kernels do not support seq_len == 0, %s has a seq len of 0.", @@ -315,7 +315,7 @@ def _cumulative_and_max_seq_len_nnz(qkv: torch.Tensor) -> Tuple[torch.Tensor, in if qkv.lengths() is None: # TODO: Explore performance impact of copying cumulative_seqlen = qkv.offsets().to(dtype=torch.int32, device=qkv.device) - max_seqlen = qkv._get_max_seqlen() + max_seqlen = qkv._max_seqlen n_elem = qkv.values().shape[0] else: # TODO: Explore performance impact of copying @@ -323,7 +323,7 @@ def _cumulative_and_max_seq_len_nnz(qkv: torch.Tensor) -> Tuple[torch.Tensor, in qkv.lengths().cumsum(0).to(dtype=torch.int32, device=qkv.device) ) batch_size = qkv.size(0) - max_seqlen = qkv._get_max_seqlen() + max_seqlen = qkv._max_seqlen # TODO: Explore performance impact when compiling n_elem = int(cumulative_seqlen[-1].item()) return cumulative_seqlen, max_seqlen, n_elem @@ -364,7 +364,7 @@ def _view_as_dense( tensor: torch.Tensor, Nnz: int, num_heads: int, head_dim: int ) -> torch.Tensor: if tensor.is_nested: - return tensor.values() + return buffer_from_jagged(tensor) return tensor.view(Nnz, num_heads, head_dim) @@ -567,8 +567,8 @@ def _sdpa_nested_preprocessing(query, key, value): output_nt_info = { "offsets": q_t.offsets(), - "_max_seqlen": q_t._get_max_seqlen(), - "_min_seqlen": q_t._get_min_seqlen(), + "_max_seqlen": q_t._max_seqlen, + "_min_seqlen": q_t._min_seqlen, } return ( @@ -694,14 +694,9 @@ def jagged_scaled_dot_product_attention( False, scale=og_scale, ) - from torch.nested._internal.nested_tensor import nested_view_from_values_offsets - # Reshape output to convert nnz to batch_size and seq_len - attention = nested_view_from_values_offsets( - attention.squeeze(0), - output_nt_info["offsets"], - min_seqlen=output_nt_info["_min_seqlen"], - max_seqlen=output_nt_info["_max_seqlen"], + attention = ViewNestedFromBuffer.apply( + attention.squeeze(0), output_nt_info["offsets"] ).transpose(1, 2) return _post_process_flash_output(attention, og_size) elif backend_choice == SDPBackend.EFFICIENT_ATTENTION: @@ -737,14 +732,9 @@ def jagged_scaled_dot_product_attention( scale=scale, ) - from torch.nested._internal.nested_tensor import nested_view_from_values_offsets - # Reshape output to convert nnz to batch_size and seq_len - return nested_view_from_values_offsets( - attention.squeeze(0), - output_nt_info["offsets"], - min_seqlen=output_nt_info["_min_seqlen"], - max_seqlen=output_nt_info["_max_seqlen"], + return ViewNestedFromBuffer.apply( + attention.squeeze(0), output_nt_info["offsets"] ).transpose(1, 2) elif backend_choice == SDPBackend.MATH: # save the offsets and shape of the inputs, so we can reshape the final output @@ -754,19 +744,12 @@ def jagged_scaled_dot_product_attention( d1 = query._size[1] d2 = value._size[-1] - min_seqlen_tensor = query._metadata_cache.get( - "min_seqlen", None - ) # type: ignore[attr-defined] - max_seqlen_tensor = query._metadata_cache.get( - "max_seqlen", None - ) # type: ignore[attr-defined] - # convert jagged layout Nested Tensor to strided layout Nested Tensor # which support the math implementation of SDPA def get_strided_layout_nested_tensor(jagged_layout_nt): lengths = jagged_layout_nt._offsets[1:] - jagged_layout_nt._offsets[:-1] transpose = torch.transpose(jagged_layout_nt, 1, 2) - tensor_list = transpose.values().split(list(lengths), dim=0) + tensor_list = buffer_from_jagged(transpose).split(list(lengths), dim=0) strided_nt = torch.nested.as_nested_tensor(list(tensor_list)) strided_nt = strided_nt.transpose(1, 2).contiguous() return strided_nt @@ -779,28 +762,11 @@ def get_strided_layout_nested_tensor(jagged_layout_nt): query, key, value, attn_mask, dropout_p, is_causal, scale=scale )[0] - from torch.nested._internal.nested_tensor import ( - _load_val_from_tensor, - nested_view_from_values_offsets, - ) - # convert strided layout Nested Tensor back to jagged layout Nested Tensor attn_out = attn_out.transpose(1, 2).contiguous().values() attn_out = attn_out.view(-1, d1, d2) - attn_out = nested_view_from_values_offsets( - attn_out, - offsets, - min_seqlen=( - None - if min_seqlen_tensor is None - else _load_val_from_tensor(min_seqlen_tensor) - ), - max_seqlen=( - None - if max_seqlen_tensor is None - else _load_val_from_tensor(max_seqlen_tensor) - ), - ).transpose(1, 2) + attn_out = ViewNestedFromBuffer.apply(attn_out, offsets) + attn_out = attn_out.transpose(1, 2) return attn_out else: From 2458f79f83e865a0469f844e87a64edfcecc7065 Mon Sep 17 00:00:00 2001 From: "xinan.lin" Date: Mon, 17 Jun 2024 12:40:38 -0700 Subject: [PATCH 168/171] [Inductor UT][Intel GPU] Skip newly added test case test_torchinductor_strided_blocks:test_reduction for Intel GPU (#128881) Skip newly added test case test_torchinductor_strided_blocks:test_reduction for Intel GPU because it have not implemented reduction kernel split. Pull Request resolved: https://github.com/pytorch/pytorch/pull/128881 Approved by: https://github.com/blaine-rister, https://github.com/EikanWang, https://github.com/malfet --- test/inductor/test_torchinductor_strided_blocks.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/test/inductor/test_torchinductor_strided_blocks.py b/test/inductor/test_torchinductor_strided_blocks.py index bd859802892df7..bf96ad8d486d8c 100644 --- a/test/inductor/test_torchinductor_strided_blocks.py +++ b/test/inductor/test_torchinductor_strided_blocks.py @@ -14,6 +14,7 @@ from torch.testing._internal.common_utils import ( instantiate_parametrized_tests, parametrize, + skipIfXpu, ) from torch.testing._internal.inductor_utils import ( GPU_TYPE, @@ -214,6 +215,7 @@ def get_input(view_size: Tuple[int]) -> torch.Tensor: # Expect 3 block pointers: 2 inputs one output self.run_and_compare(foo, x, y, expected_num_block_pointers=3) + @skipIfXpu @parametrize( "view_size,num_block_pointers,num_triton_kernels", [ From eda375a49078f5fecc90f28ca8ff949e8e5811e9 Mon Sep 17 00:00:00 2001 From: leslie-fang-intel Date: Mon, 17 Jun 2024 19:54:34 -0700 Subject: [PATCH 169/171] [Inductor] Remove min/max from inductor opinfo test (#128925) **Summary** Remove `max.binary, min.binary, maximum, minimum` from `inductor_one_sample` op list as we fix the bool vectorization issue in https://github.com/pytorch/pytorch/pull/126841. **Test Plan** ``` python -u -m pytest -s -v test/inductor/test_torchinductor_opinfo.py -k test_comprehensive_maximum python -u -m pytest -s -v test/inductor/test_torchinductor_opinfo.py -k test_comprehensive_minimum python -u -m pytest -s -v test/inductor/test_torchinductor_opinfo.py -k test_comprehensive_min_binary python -u -m pytest -s -v test/inductor/test_torchinductor_opinfo.py -k test_comprehensive_max_binary ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/128925 Approved by: https://github.com/isuruf, https://github.com/jgong5, https://github.com/peterbell10 --- test/inductor/test_torchinductor_opinfo.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/test/inductor/test_torchinductor_opinfo.py b/test/inductor/test_torchinductor_opinfo.py index 29be591dc006c1..c7153b5b6d8491 100644 --- a/test/inductor/test_torchinductor_opinfo.py +++ b/test/inductor/test_torchinductor_opinfo.py @@ -425,11 +425,7 @@ def wrapper_noop_set_seed(op, *args, **kwargs): "logspace": {f16}, "logspace.tensor_overload": {f16, f32, f64, i32, i64}, "masked_logsumexp": {i64}, - "max.binary": {b8}, "max_pool2d_with_indices_backward": {f16, f32, f64}, - "maximum": {b8}, - "min.binary": {b8}, - "minimum": {b8}, "new_empty_strided": {f16}, "nn.functional.adaptive_avg_pool3d": {f16}, "nn.functional.adaptive_max_pool1d": {f16, f32}, From 4bc90185fb77438717d59b2d9bb63096ae682935 Mon Sep 17 00:00:00 2001 From: Thanh Ha Date: Wed, 19 Jun 2024 01:17:05 +0000 Subject: [PATCH 170/171] fix: Print statements causing parse error (#128969) The print statements for the get_workflow_type script is problematic because the shell script calling this script is expecting the output to only be JSON. This PR resolves this by removing all print statements to covert them to a message field in the JSON return output so that the output can continue to expect to be JSON while giving us the debug data we are looking for. Pull Request resolved: https://github.com/pytorch/pytorch/pull/128969 Approved by: https://github.com/tylertitsworth, https://github.com/ZainRizvi --- .github/scripts/get_workflow_type.py | 47 ++++++++++++++++------------ 1 file changed, 27 insertions(+), 20 deletions(-) diff --git a/.github/scripts/get_workflow_type.py b/.github/scripts/get_workflow_type.py index 4a5303ae9212fb..5384ef92c12f28 100644 --- a/.github/scripts/get_workflow_type.py +++ b/.github/scripts/get_workflow_type.py @@ -1,6 +1,6 @@ import json from argparse import ArgumentParser -from typing import Any +from typing import Any, Tuple from github import Auth, Github from github.Issue import Issue @@ -9,6 +9,8 @@ WORKFLOW_LABEL_META = "" # use meta runners WORKFLOW_LABEL_LF = "lf." # use runners from the linux foundation LABEL_TYPE_KEY = "label_type" +MESSAGE_KEY = "message" +MESSAGE = "" # Debug message to return to the caller def parse_args() -> Any: @@ -48,45 +50,50 @@ def is_exception_branch(branch: str) -> bool: return branch.split("/")[0] in {"main", "nightly", "release", "landchecks"} -def get_workflow_type(issue: Issue, username: str) -> str: +def get_workflow_type(issue: Issue, username: str) -> Tuple[str, str]: try: user_list = issue.get_comments()[0].body.split() if user_list[0] == "!": - print("LF Workflows are disabled for everyone. Using meta runners.") - return WORKFLOW_LABEL_META + MESSAGE = "LF Workflows are disabled for everyone. Using meta runners." + return WORKFLOW_LABEL_META, MESSAGE elif user_list[0] == "*": - print("LF Workflows are enabled for everyone. Using LF runners.") - return WORKFLOW_LABEL_LF + MESSAGE = "LF Workflows are enabled for everyone. Using LF runners." + return WORKFLOW_LABEL_LF, MESSAGE elif username in user_list: - print(f"LF Workflows are enabled for {username}. Using LF runners.") - return WORKFLOW_LABEL_LF + MESSAGE = f"LF Workflows are enabled for {username}. Using LF runners." + return WORKFLOW_LABEL_LF, MESSAGE else: - print(f"LF Workflows are disabled for {username}. Using meta runners.") - return WORKFLOW_LABEL_META + MESSAGE = f"LF Workflows are disabled for {username}. Using meta runners." + return WORKFLOW_LABEL_META, MESSAGE except Exception as e: - print( - f"Failed to get determine workflow type. Falling back to meta runners. Exception: {e}" - ) - return WORKFLOW_LABEL_META + MESSAGE = f"Failed to get determine workflow type. Falling back to meta runners. Exception: {e}" + return WORKFLOW_LABEL_META, MESSAGE def main() -> None: args = parse_args() if is_exception_branch(args.github_branch): - print(f"Exception branch: '{args.github_branch}', using meta runners") - output = {LABEL_TYPE_KEY: WORKFLOW_LABEL_META} + output = { + LABEL_TYPE_KEY: WORKFLOW_LABEL_META, + MESSAGE_KEY: f"Exception branch: '{args.github_branch}', using meta runners", + } else: try: gh = get_gh_client(args.github_token) # The default issue we use - https://github.com/pytorch/test-infra/issues/5132 issue = get_issue(gh, args.github_repo, args.github_issue) - - output = {LABEL_TYPE_KEY: get_workflow_type(issue, args.github_user)} + label_type, message = get_workflow_type(issue, args.github_user) + output = { + LABEL_TYPE_KEY: label_type, + MESSAGE_KEY: message, + } except Exception as e: - print(f"Failed to get issue. Falling back to meta runners. Exception: {e}") - output = {LABEL_TYPE_KEY: WORKFLOW_LABEL_META} + output = { + LABEL_TYPE_KEY: WORKFLOW_LABEL_META, + MESSAGE_KEY: f"Failed to get issue. Falling back to meta runners. Exception: {e}", + } json_output = json.dumps(output) print(json_output) From df85f34a14dd30f784418624b05bd52b12ab8b0b Mon Sep 17 00:00:00 2001 From: "Wu, Chunyuan" Date: Fri, 14 Jun 2024 01:51:17 -0700 Subject: [PATCH 171/171] Add test to xfail_list only for abi_compatible (#128506) https://github.com/pytorch/pytorch/pull/126717 will skip the tests in both ABI compatible and non-ABI compatible mode. It's not expected to skip them in non-ABI compatible mode since they can actually run successfully in such mode but only have issues in ABI compatible mode. We leverage the existing `xfail_list` for those that will only fail in ABI compatible mode. - `test_qlinear_add` is already in the `xfail_list`. - `test_linear_packed` doesn't fail either in my local run (running with `TORCHINDUCTOR_ABI_COMPATIBLE=1`) or in the CI of this PR so I didn't add it into `xfail_list`. Pull Request resolved: https://github.com/pytorch/pytorch/pull/128506 Approved by: https://github.com/jgong5, https://github.com/desertfire --- test/inductor/test_cpu_cpp_wrapper.py | 14 +++----------- 1 file changed, 3 insertions(+), 11 deletions(-) diff --git a/test/inductor/test_cpu_cpp_wrapper.py b/test/inductor/test_cpu_cpp_wrapper.py index 8bf9b1e6a61f8e..0a2b75ddb55441 100644 --- a/test/inductor/test_cpu_cpp_wrapper.py +++ b/test/inductor/test_cpu_cpp_wrapper.py @@ -95,7 +95,9 @@ class DynamicShapesCppWrapperCpuTests(InductorTestCase): "test_qconv2d_relu_cpu", "test_qlinear_cpu", "test_qlinear_add_cpu", + "test_qlinear_add_relu_cpu", "test_qlinear_dequant_promotion_cpu", + "test_qlinear_gelu_cpu", "test_qlinear_relu_cpu", ] for test_name in xfail_list: @@ -125,7 +127,6 @@ def make_test_case( slow=False, func_inputs=None, code_string_count=None, - skip=None, ): test_name = f"{name}_{device}" if device else name if code_string_count is None: @@ -134,8 +135,6 @@ def make_test_case( func = getattr(tests, test_name) assert callable(func), "not a callable" func = slowTest(func) if slow else func - if skip: - func = unittest.skip(skip)(func) @config.patch(cpp_wrapper=True, search_autotune_cache=False) def fn(self): @@ -183,7 +182,6 @@ class BaseTest(NamedTuple): slow: bool = False func_inputs: list = None code_string_count: dict = {} - skip: str = None for item in [ BaseTest("test_add_complex"), @@ -242,9 +240,7 @@ class BaseTest(NamedTuple): torch.backends.mkldnn.is_available() and torch.ops.mkldnn._is_mkldnn_bf16_supported(), ), - BaseTest( - "test_linear_packed", "", test_cpu_repro.CPUReproTests(), skip="Failing" - ), + BaseTest("test_linear_packed", "", test_cpu_repro.CPUReproTests()), BaseTest( "test_lstm_packed_change_input_sizes", "cpu", @@ -318,21 +314,18 @@ class BaseTest(NamedTuple): "cpu", test_mkldnn_pattern_matcher.TestPatternMatcher(), condition=torch.backends.mkldnn.is_available(), - skip="Failing", ), BaseTest( "test_qlinear_add", "cpu", test_mkldnn_pattern_matcher.TestPatternMatcher(), condition=torch.backends.mkldnn.is_available(), - skip="Failing", ), BaseTest( "test_qlinear_add_relu", "cpu", test_mkldnn_pattern_matcher.TestPatternMatcher(), condition=torch.backends.mkldnn.is_available(), - skip="Failing", ), BaseTest( "test_qlinear_dequant_promotion", @@ -388,7 +381,6 @@ class BaseTest(NamedTuple): item.slow, item.func_inputs, item.code_string_count, - skip=item.skip, ) test_torchinductor.copy_tests(