diff --git a/.ci/docker/ci_commit_pins/pytorch.txt b/.ci/docker/ci_commit_pins/pytorch.txt
index 9182b03d38..2c2d910da9 100644
--- a/.ci/docker/ci_commit_pins/pytorch.txt
+++ b/.ci/docker/ci_commit_pins/pytorch.txt
@@ -1 +1 @@
-0a94bb432ed75cc2d950d81b2921363218a7e459
+27e35de6c288bffad1b4d18b393579c1d1a95547
diff --git a/.ci/docker/conda-env-ci.txt b/.ci/docker/conda-env-ci.txt
index 8f2e65dae7..c675b3d9f6 100644
--- a/.ci/docker/conda-env-ci.txt
+++ b/.ci/docker/conda-env-ci.txt
@@ -1,4 +1,5 @@
cmake=3.22.1
ninja=1.10.2
libuv
+llvm-openmp
pkg-config
diff --git a/.ci/scripts/setup-macos.sh b/.ci/scripts/setup-macos.sh
index 395f0c1767..75f999af41 100755
--- a/.ci/scripts/setup-macos.sh
+++ b/.ci/scripts/setup-macos.sh
@@ -121,6 +121,7 @@ setup_macos_env_variables
# NB: we need buck2 in all cases because cmake build also depends on calling
# buck2 atm
install_buck
+brew install libomp
install_pip_dependencies
# TODO(huydhn): Unlike our self-hosted runner, GitHub runner doesn't have access
diff --git a/.ci/scripts/test_eval_llama_mmlu.sh b/.ci/scripts/test_eval_llama_mmlu.sh
index c3c0a3d1a6..2f4cf1b3b3 100644
--- a/.ci/scripts/test_eval_llama_mmlu.sh
+++ b/.ci/scripts/test_eval_llama_mmlu.sh
@@ -43,6 +43,7 @@ run_and_verify() {
--tasks mmlu \
-f 5 \
--max_seq_length 2048 \
+ --max_context_length 2048 \
--limit 5 > result.txt
# Verify result.txt
diff --git a/.ci/scripts/test_eval_llama_wikitext.sh b/.ci/scripts/test_eval_llama_wikitext.sh
index 77af12270c..8c1713ae12 100644
--- a/.ci/scripts/test_eval_llama_wikitext.sh
+++ b/.ci/scripts/test_eval_llama_wikitext.sh
@@ -41,6 +41,7 @@ run_and_verify() {
-kv \
-d fp32 \
--max_seq_length 2048 \
+ --max_context_length 2048 \
--limit 5 > result.txt
# Verify result.txt
diff --git a/.github/workflows/_android.yml b/.github/workflows/_android.yml
index 96fdfd51fe..36b679eda4 100644
--- a/.github/workflows/_android.yml
+++ b/.github/workflows/_android.yml
@@ -7,7 +7,10 @@ on:
jobs:
build-llm-demo:
name: build-llm-demo
- uses: pytorch/test-infra/.github/workflows/linux_job.yml@main
+ uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
+ permissions:
+ id-token: write
+ contents: read
with:
runner: linux.2xlarge
docker-image: executorch-ubuntu-22.04-clang12-android
diff --git a/.github/workflows/_unittest.yml b/.github/workflows/_unittest.yml
index 74ea5ca7bc..414f86494b 100644
--- a/.github/workflows/_unittest.yml
+++ b/.github/workflows/_unittest.yml
@@ -14,7 +14,10 @@ on:
jobs:
linux:
- uses: pytorch/test-infra/.github/workflows/linux_job.yml@main
+ uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
+ permissions:
+ id-token: write
+ contents: read
with:
runner: linux.2xlarge
docker-image: ${{ inputs.docker-image }}
diff --git a/.github/workflows/android-perf.yml b/.github/workflows/android-perf.yml
index 5d34bd8626..a83d374ab0 100644
--- a/.github/workflows/android-perf.yml
+++ b/.github/workflows/android-perf.yml
@@ -155,7 +155,10 @@ jobs:
export-models:
name: export-models
- uses: pytorch/test-infra/.github/workflows/linux_job.yml@main
+ uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
+ permissions:
+ id-token: write
+ contents: read
needs: set-parameters
secrets: inherit
strategy:
@@ -332,7 +335,10 @@ jobs:
build-benchmark-app:
name: build-benchmark-app
- uses: pytorch/test-infra/.github/workflows/linux_job.yml@main
+ uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
+ permissions:
+ id-token: write
+ contents: read
needs: set-parameters
with:
runner: linux.2xlarge
diff --git a/.github/workflows/android-release-artifacts.yml b/.github/workflows/android-release-artifacts.yml
index a10de79363..d204e121ff 100644
--- a/.github/workflows/android-release-artifacts.yml
+++ b/.github/workflows/android-release-artifacts.yml
@@ -31,7 +31,10 @@ jobs:
build-aar:
name: build-aar
needs: check-if-aar-exists
- uses: pytorch/test-infra/.github/workflows/linux_job.yml@main
+ uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
+ permissions:
+ id-token: write
+ contents: read
with:
runner: linux.2xlarge
docker-image: executorch-ubuntu-22.04-clang12-android
diff --git a/.github/workflows/apple.yml b/.github/workflows/apple.yml
index 8ac755bf5d..8349ddb419 100644
--- a/.github/workflows/apple.yml
+++ b/.github/workflows/apple.yml
@@ -37,7 +37,7 @@ jobs:
id: set_version
shell: bash
run: |
- VERSION="0.4.0.$(TZ='PST8PDT' date +%Y%m%d)"
+ VERSION="0.5.0.$(TZ='PST8PDT' date +%Y%m%d)"
echo "version=$VERSION" >> "$GITHUB_OUTPUT"
build-demo-ios:
diff --git a/.github/workflows/doc-build.yml b/.github/workflows/doc-build.yml
index 7a3b862b21..8d9081615b 100644
--- a/.github/workflows/doc-build.yml
+++ b/.github/workflows/doc-build.yml
@@ -15,7 +15,10 @@ on:
jobs:
build:
- uses: pytorch/test-infra/.github/workflows/linux_job.yml@main
+ uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
+ permissions:
+ id-token: write
+ contents: read
strategy:
matrix:
include:
@@ -81,8 +84,9 @@ jobs:
needs: build
if: github.repository == 'pytorch/executorch' && github.event_name == 'push' && (github.ref == 'refs/heads/main' || startsWith(github.ref, 'refs/tags/v'))
permissions:
+ id-token: write
contents: write
- uses: pytorch/test-infra/.github/workflows/linux_job.yml@main
+ uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
with:
repository: pytorch/executorch
download-artifact: docs
diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml
index 93c89355d7..aab68b3059 100644
--- a/.github/workflows/lint.yml
+++ b/.github/workflows/lint.yml
@@ -16,7 +16,10 @@ concurrency:
jobs:
lintrunner:
- uses: pytorch/test-infra/.github/workflows/linux_job.yml@main
+ uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
+ permissions:
+ id-token: write
+ contents: read
with:
runner: linux.2xlarge
docker-image: executorch-ubuntu-22.04-linter
@@ -62,7 +65,10 @@ jobs:
exit $RC
android-java-format:
- uses: pytorch/test-infra/.github/workflows/linux_job.yml@main
+ uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
+ permissions:
+ id-token: write
+ contents: read
with:
runner: linux.2xlarge
docker-image: executorch-ubuntu-22.04-linter
diff --git a/.github/workflows/periodic.yml b/.github/workflows/periodic.yml
index df13140ca9..6b4644bb52 100644
--- a/.github/workflows/periodic.yml
+++ b/.github/workflows/periodic.yml
@@ -39,7 +39,10 @@ jobs:
test-models-linux:
name: test-models-linux
- uses: pytorch/test-infra/.github/workflows/linux_job.yml@main
+ uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
+ permissions:
+ id-token: write
+ contents: read
needs: gather-models
strategy:
matrix: ${{ fromJSON(needs.gather-models.outputs.models) }}
diff --git a/.github/workflows/pull.yml b/.github/workflows/pull.yml
index dbe0e872ac..16611c09f3 100644
--- a/.github/workflows/pull.yml
+++ b/.github/workflows/pull.yml
@@ -33,7 +33,10 @@ jobs:
test-setup-linux-gcc:
name: test-setup-linux-gcc
- uses: pytorch/test-infra/.github/workflows/linux_job.yml@main
+ uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
+ permissions:
+ id-token: write
+ contents: read
strategy:
fail-fast: false
with:
@@ -55,7 +58,10 @@ jobs:
test-models-linux:
name: test-models-linux
- uses: pytorch/test-infra/.github/workflows/linux_job.yml@main
+ uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
+ permissions:
+ id-token: write
+ contents: read
needs: gather-models
strategy:
matrix: ${{ fromJSON(needs.gather-models.outputs.models) }}
@@ -82,7 +88,10 @@ jobs:
test-llama-runner-linux:
name: test-llama-runner-linux
- uses: pytorch/test-infra/.github/workflows/linux_job.yml@main
+ uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
+ permissions:
+ id-token: write
+ contents: read
strategy:
matrix:
dtype: [fp32]
@@ -121,7 +130,10 @@ jobs:
test-llama-runner-linux-android:
name: test-llama-runner-linux-android
- uses: pytorch/test-infra/.github/workflows/linux_job.yml@main
+ uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
+ permissions:
+ id-token: write
+ contents: read
strategy:
fail-fast: false
with:
@@ -141,7 +153,10 @@ jobs:
test-custom-ops-linux:
name: test-custom-ops-linux
- uses: pytorch/test-infra/.github/workflows/linux_job.yml@main
+ uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
+ permissions:
+ id-token: write
+ contents: read
strategy:
fail-fast: false
with:
@@ -162,7 +177,10 @@ jobs:
test-selective-build-linux:
name: test-selective-build-linux
- uses: pytorch/test-infra/.github/workflows/linux_job.yml@main
+ uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
+ permissions:
+ id-token: write
+ contents: read
strategy:
fail-fast: false
with:
@@ -183,7 +201,10 @@ jobs:
test-llava-runner-linux:
name: test-llava-runner-linux
- uses: pytorch/test-infra/.github/workflows/linux_job.yml@main
+ uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
+ permissions:
+ id-token: write
+ contents: read
strategy:
fail-fast: false
with:
@@ -191,7 +212,7 @@ jobs:
docker-image: executorch-ubuntu-22.04-clang12
submodules: 'true'
ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
- timeout: 90
+ timeout: 180
script: |
# The generic Linux job chooses to use base env, not the one setup by the image
CONDA_ENV=$(conda env list --json | jq -r ".envs | .[-1]")
@@ -200,7 +221,7 @@ jobs:
PYTHON_EXECUTABLE=python bash .ci/scripts/setup-linux.sh "cmake"
# install pybind
- bash install_executorch.sh --pybind xnnpack
+ bash install_executorch.sh --pybind xnnpack --use-pt-pinned-commit
# install Llava requirements
bash examples/models/llama/install_requirements.sh
@@ -214,7 +235,10 @@ jobs:
test-quantized-aot-lib-linux:
name: test-quantized-aot-lib-linux
- uses: pytorch/test-infra/.github/workflows/linux_job.yml@main
+ uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
+ permissions:
+ id-token: write
+ contents: read
strategy:
fail-fast: false
with:
@@ -234,7 +258,10 @@ jobs:
test-pybind-build-linux:
name: test-pybind-build-linux
- uses: pytorch/test-infra/.github/workflows/linux_job.yml@main
+ uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
+ permissions:
+ id-token: write
+ contents: read
strategy:
fail-fast: false
with:
@@ -260,7 +287,10 @@ jobs:
test-binary-size-linux-gcc:
name: test-binary-size-linux-gcc
- uses: pytorch/test-infra/.github/workflows/linux_job.yml@main
+ uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
+ permissions:
+ id-token: write
+ contents: read
strategy:
fail-fast: false
with:
@@ -292,7 +322,10 @@ jobs:
test-binary-size-linux:
name: test-binary-size-linux
- uses: pytorch/test-infra/.github/workflows/linux_job.yml@main
+ uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
+ permissions:
+ id-token: write
+ contents: read
strategy:
fail-fast: false
with:
@@ -324,10 +357,16 @@ jobs:
android:
uses: ./.github/workflows/_android.yml
+ permissions:
+ id-token: write
+ contents: read
needs: test-llama-runner-linux
unittest:
uses: ./.github/workflows/_unittest.yml
+ permissions:
+ id-token: write
+ contents: read
with:
docker-image: executorch-ubuntu-22.04-clang12
@@ -365,7 +404,10 @@ jobs:
test-llama-runner-qnn-linux:
name: test-llama-runner-qnn-linux
- uses: pytorch/test-infra/.github/workflows/linux_job.yml@main
+ uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
+ permissions:
+ id-token: write
+ contents: read
strategy:
matrix:
dtype: [fp32]
@@ -400,7 +442,10 @@ jobs:
test-qnn-models-linux:
name: test-qnn-models-linux
- uses: pytorch/test-infra/.github/workflows/linux_job.yml@main
+ uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
+ permissions:
+ id-token: write
+ contents: read
strategy:
fail-fast: false
with:
@@ -419,7 +464,10 @@ jobs:
test-phi-3-mini-runner-linux:
name: test-phi-3-mini-runner-linux
- uses: pytorch/test-infra/.github/workflows/linux_job.yml@main
+ uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
+ permissions:
+ id-token: write
+ contents: read
strategy:
fail-fast: false
with:
@@ -436,7 +484,7 @@ jobs:
PYTHON_EXECUTABLE=python bash .ci/scripts/setup-linux.sh "cmake"
# install pybind
- bash install_executorch.sh --pybind xnnpack
+ bash install_executorch.sh --pybind xnnpack --use-pt-pinned-commit
# install phi-3-mini requirements
bash examples/models/phi-3-mini/install_requirements.sh
@@ -446,7 +494,10 @@ jobs:
test-eval_llama-wikitext-linux:
name: test-eval_llama-wikitext-linux
- uses: pytorch/test-infra/.github/workflows/linux_job.yml@main
+ uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
+ permissions:
+ id-token: write
+ contents: read
strategy:
fail-fast: false
with:
@@ -463,7 +514,7 @@ jobs:
PYTHON_EXECUTABLE=python bash .ci/scripts/setup-linux.sh "cmake"
# install pybind
- bash install_executorch.sh --pybind xnnpack
+ bash install_executorch.sh --pybind xnnpack --use-pt-pinned-commit
# install llama requirements
bash examples/models/llama/install_requirements.sh
@@ -473,7 +524,10 @@ jobs:
test-eval_llama-mmlu-linux:
name: test-eval_llama-mmlu-linux
- uses: pytorch/test-infra/.github/workflows/linux_job.yml@main
+ uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
+ permissions:
+ id-token: write
+ contents: read
strategy:
fail-fast: false
with:
@@ -481,7 +535,7 @@ jobs:
docker-image: executorch-ubuntu-22.04-clang12
submodules: 'true'
ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
- timeout: 90
+ timeout: 180
script: |
# The generic Linux job chooses to use base env, not the one setup by the image
CONDA_ENV=$(conda env list --json | jq -r ".envs | .[-1]")
@@ -490,7 +544,7 @@ jobs:
PYTHON_EXECUTABLE=python bash .ci/scripts/setup-linux.sh "cmake"
# install pybind
- bash install_executorch.sh --pybind xnnpack
+ bash install_executorch.sh --pybind xnnpack --use-pt-pinned-commit
# install llama requirements
bash examples/models/llama/install_requirements.sh
@@ -500,7 +554,10 @@ jobs:
test-llama_runner_eager-linux:
name: test-llama_runner_eager-linux
- uses: pytorch/test-infra/.github/workflows/linux_job.yml@main
+ uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
+ permissions:
+ id-token: write
+ contents: read
strategy:
fail-fast: false
with:
@@ -517,7 +574,7 @@ jobs:
PYTHON_EXECUTABLE=python bash .ci/scripts/setup-linux.sh "cmake"
# install pybind
- bash install_executorch.sh --pybind xnnpack
+ bash install_executorch.sh --pybind xnnpack --use-pt-pinned-commit
# install llama requirements
bash examples/models/llama/install_requirements.sh
@@ -527,7 +584,10 @@ jobs:
test-mediatek-models-linux:
name: test-mediatek-models-linux
- uses: pytorch/test-infra/.github/workflows/linux_job.yml@main
+ uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
+ permissions:
+ id-token: write
+ contents: read
strategy:
fail-fast: false
with:
diff --git a/.github/workflows/trunk.yml b/.github/workflows/trunk.yml
index 0cbbe6f643..04a6c96f3e 100644
--- a/.github/workflows/trunk.yml
+++ b/.github/workflows/trunk.yml
@@ -107,7 +107,10 @@ jobs:
test-demo-backend-delegation:
name: test-demo-backend-delegation
- uses: pytorch/test-infra/.github/workflows/linux_job.yml@main
+ uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
+ permissions:
+ id-token: write
+ contents: read
strategy:
matrix:
include:
@@ -147,7 +150,7 @@ jobs:
conda activate "${CONDA_ENV}"
source .ci/scripts/utils.sh
- install_executorch
+ install_executorch "use-pt-pinned-commit"
.ci/scripts/setup-arm-baremetal-tools.sh
@@ -177,7 +180,7 @@ jobs:
conda activate "${CONDA_ENV}"
source .ci/scripts/utils.sh
- install_executorch
+ install_executorch "use-pt-pinned-commit"
.ci/scripts/setup-arm-baremetal-tools.sh
@@ -301,7 +304,10 @@ jobs:
test-qnn-model:
name: test-qnn-model
- uses: pytorch/test-infra/.github/workflows/linux_job.yml@main
+ uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
+ permissions:
+ id-token: write
+ contents: read
strategy:
matrix:
dtype: [fp32]
@@ -361,7 +367,10 @@ jobs:
# NB: Don't run this on fork PRs because they won't have access to the secret and would fail anyway
if: ${{ !github.event.pull_request.head.repo.fork }}
name: test-huggingface-transformers
- uses: pytorch/test-infra/.github/workflows/linux_job.yml@main
+ uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
+ permissions:
+ id-token: write
+ contents: read
secrets: inherit
strategy:
matrix:
@@ -445,7 +454,10 @@ jobs:
test-llama-runner-qnn-linux:
name: test-llama-runner-qnn-linux
- uses: pytorch/test-infra/.github/workflows/linux_job.yml@main
+ uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
+ permissions:
+ id-token: write
+ contents: read
strategy:
matrix:
dtype: [fp32]
diff --git a/.lintrunner.toml b/.lintrunner.toml
index dd75ea8f32..093f9cdbcb 100644
--- a/.lintrunner.toml
+++ b/.lintrunner.toml
@@ -1,4 +1,4 @@
-merge_base_with = "origin/main"
+merge_base_with = "main"
[[linter]]
code = 'FLAKE8'
@@ -291,6 +291,7 @@ code = 'MYPY'
include_patterns = [
# TODO(https://github.com/pytorch/executorch/issues/7441): Gradually start enabling all folders.
# 'backends/**/*.py',
+ 'backends/arm/**/*.py',
'build/**/*.py',
'codegen/**/*.py',
# 'devtools/**/*.py',
@@ -312,6 +313,7 @@ exclude_patterns = [
'**/third-party/**',
'scripts/check_binary_dependencies.py',
'profiler/test/test_profiler_e2e.py',
+ 'backends/arm/test/**',
]
command = [
'python',
diff --git a/.mypy.ini b/.mypy.ini
index 43d75e64de..8c1c9dbcad 100644
--- a/.mypy.ini
+++ b/.mypy.ini
@@ -77,6 +77,9 @@ ignore_missing_imports = True
[mypy-ruamel]
ignore_missing_imports = True
+[mypy-serializer.*]
+ignore_missing_imports = True
+
[mypy-setuptools.*]
ignore_missing_imports = True
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 4c1b8e2ec7..ca8d1bbbcf 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -1,4 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
+# Copyright 2024-2025 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
@@ -239,6 +240,13 @@ cmake_dependent_option(
"NOT EXECUTORCH_BUILD_ARM_BAREMETAL" OFF
)
+
+if(EXECUTORCH_BUILD_EXTENSION_TRAINING)
+ set(EXECUTORCH_BUILD_EXTENSION_TENSOR ON)
+ set(EXECUTORCH_BUILD_EXTENSION_DATA_LOADER ON)
+ set(EXECUTORCH_BUILD_EXTENSION_MODULE ON)
+endif()
+
if(EXECUTORCH_BUILD_KERNELS_CUSTOM_AOT)
set(EXECUTORCH_BUILD_EXTENSION_TENSOR ON)
set(EXECUTORCH_BUILD_KERNELS_CUSTOM ON)
@@ -790,6 +798,35 @@ if(EXECUTORCH_BUILD_PYBIND)
install(TARGETS portable_lib
LIBRARY DESTINATION executorch/extension/pybindings
)
+
+ if(EXECUTORCH_BUILD_EXTENSION_TRAINING)
+
+ set(_pybind_training_dep_libs
+ ${TORCH_PYTHON_LIBRARY}
+ etdump
+ executorch
+ util
+ torch
+ extension_training
+ )
+
+ if(EXECUTORCH_BUILD_XNNPACK)
+      # need to explicitly specify XNNPACK and microkernels-prod here,
+      # otherwise the XNNPACK and microkernels-prod symbols from libtorch_cpu are used
+ list(APPEND _pybind_training_dep_libs xnnpack_backend XNNPACK microkernels-prod)
+ endif()
+
+ # pybind training
+ pybind11_add_module(_training_lib SHARED extension/training/pybindings/_training_lib.cpp)
+
+ target_include_directories(_training_lib PRIVATE ${TORCH_INCLUDE_DIRS})
+ target_compile_options(_training_lib PUBLIC ${_pybind_compile_options})
+ target_link_libraries(_training_lib PRIVATE ${_pybind_training_dep_libs})
+
+ install(TARGETS _training_lib
+ LIBRARY DESTINATION executorch/extension/training/pybindings
+ )
+ endif()
endif()
if(EXECUTORCH_BUILD_KERNELS_CUSTOM)
@@ -819,6 +856,14 @@ if(EXECUTORCH_BUILD_EXECUTOR_RUNNER)
list(APPEND _executor_runner_libs quantized_ops_lib)
endif()
+ if(EXECUTORCH_ENABLE_EVENT_TRACER)
+ if(EXECUTORCH_BUILD_DEVTOOLS)
+ list(APPEND _executor_runner_libs etdump flatccrt)
+ else()
+ message(SEND_ERROR "Use of 'EXECUTORCH_ENABLE_EVENT_TRACER' requires 'EXECUTORCH_BUILD_DEVTOOLS' to be enabled.")
+ endif()
+ endif()
+
add_executable(executor_runner ${_executor_runner__srcs})
if(CMAKE_BUILD_TYPE STREQUAL "Release")
if(APPLE)
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index bd943c587b..88f55ef73c 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -44,6 +44,38 @@ Meta has a [bounty program](https://www.facebook.com/whitehat/) for the safe
disclosure of security bugs. In those cases, please go through the process
outlined on that page and do not file a public issue.
+### Issue Labels
+
+#### Module/Partner Labels
+
+[Labels beginning with `module:`](https://github.com/pytorch/executorch/labels?q=%22module%3A+%22)
+indicate the area that the issue relates to. The ExecuTorch oncall will
+typically add this label.
+
+[Labels beginning with `partner:`](https://github.com/pytorch/executorch/labels?q=%22partner%3A+%22)
+indicate the ExecuTorch partner who owns the issue. The ExecuTorch oncall will
+typically add this label.
+
+#### Lifecycle Labels
+
+The ExecuTorch oncall will triage new issues. If the issue requires more
+information from the issue's author, the oncall will add the `need-user-input` label
+and wait for the author to respond.
+
+Once the issue contains enough information, the oncall will:
+- Ensure that the title is descriptive
+- Add one of the labels:
+ - `bug`: The issue describes an unexpected problem
+ - `feature`: The issue describes a request for new functionality
+ - `rfc`: The issue describes a proposed change to functionality
+- Add one `module:` label or one `partner:` label, as described above
+- Add the `triaged` label
+
+After this point, the oncall has finished the triage process, and the
+module owner or partner is responsible for resolving the issue. (See
+https://github.com/pytorch/executorch/issues/7679 for the mapping of labels to
+owners.)
+
### Claiming Issues
We'd love your help closing out [open
issues](https://github.com/pytorch/executorch/issues?q=sort%3Aupdated-desc+is%3Aissue+is%3Aopen)
diff --git a/README-wheel.md b/README-wheel.md
index e04e6dfa6d..9f074ab5ee 100644
--- a/README-wheel.md
+++ b/README-wheel.md
@@ -4,20 +4,21 @@ standard on-device iOS and Android mobile deployments. One of the main goals for
ExecuTorch is to enable wider customization and deployment capabilities of the
PyTorch programs.
-The `executorch` pip package is in alpha.
-* Supported python versions: 3.10, 3.11
+The `executorch` pip package is in beta.
+* Supported python versions: 3.10, 3.11, 3.12
* Compatible systems: Linux x86_64, macOS aarch64
-The prebuilt `executorch.extension.pybindings.portable_lib` module included in
-this package provides a way to run ExecuTorch `.pte` files, with some
-restrictions:
+The prebuilt `executorch.runtime` module included in this package provides a way
+to run ExecuTorch `.pte` files, with some restrictions:
* Only [core ATen
operators](https://pytorch.org/executorch/stable/ir-ops-set-definition.html)
are linked into the prebuilt module
* Only the [XNNPACK backend
delegate](https://pytorch.org/executorch/main/native-delegates-executorch-xnnpack-delegate.html)
- is linked into the prebuilt module
-* [macOS only] [Core ML](https://pytorch.org/executorch/main/build-run-coreml.html) and [MPS](https://pytorch.org/executorch/main/build-run-mps.html) backend delegates are linked into the prebuilt module.
+ is linked into the prebuilt module.
+* \[macOS only] [Core ML](https://pytorch.org/executorch/main/build-run-coreml.html)
+ and [MPS](https://pytorch.org/executorch/main/build-run-mps.html) backend
+ delegates are also linked into the prebuilt module.
Please visit the [ExecuTorch website](https://pytorch.org/executorch/) for
tutorials and documentation. Here are some starting points:
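The hunk above points users at the prebuilt `executorch.runtime` module for running `.pte` files. A minimal sketch of that flow, assuming the `Runtime.get()` / `load_program` / `load_method` API; the model path and input shape are placeholders:

```python
import torch
from executorch.runtime import Runtime

# Assumed usage of the prebuilt executorch.runtime module; "model.pte" and the
# input tensor are placeholders for a real exported program and its inputs.
runtime = Runtime.get()
program = runtime.load_program("model.pte")
method = program.load_method("forward")
outputs = method.execute([torch.randn(1, 3, 224, 224)])
print(outputs)
```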
diff --git a/README.md b/README.md
index aded66bf40..3a2a833e05 100644
--- a/README.md
+++ b/README.md
@@ -1,9 +1,37 @@
-# ExecuTorch
-
-**ExecuTorch** is an end-to-end solution for enabling on-device inference
-capabilities across mobile and edge devices including wearables, embedded
-devices and microcontrollers. It is part of the PyTorch Edge ecosystem and
-enables efficient deployment of PyTorch models to edge devices.
+
+
+# ExecuTorch: A powerful on-device AI Framework
+
+**ExecuTorch** is an end-to-end solution for on-device inference and training. It powers many of Meta's on-device AI experiences across Facebook, Instagram, Meta Quest, Ray-Ban Meta Smart Glasses, WhatsApp, and more.
+
+It supports a wide range of models including LLMs (Large Language Models), CV (Computer Vision), ASR (Automatic Speech Recognition), and TTS (Text to Speech).
+
+Platform Support:
+- Operating Systems:
+ - iOS
+ - Mac
+ - Android
+ - Linux
+ - Microcontrollers
+
+- Hardware Acceleration:
+ - Apple
+ - Arm
+ - Cadence
+ - MediaTek
+ - Qualcomm
+ - Vulkan
+ - XNNPACK
Key value propositions of ExecuTorch are:
@@ -17,35 +45,21 @@ Key value propositions of ExecuTorch are:
experience due to a lightweight runtime and utilizing full hardware
capabilities such as CPUs, NPUs, and DSPs.
-For a comprehensive technical overview of ExecuTorch and step-by-step tutorials,
-please visit our documentation website [for the latest release](https://pytorch.org/executorch/stable/index.html) (or the [main branch](https://pytorch.org/executorch/main/index.html)).
-
-Check out the [Getting Started](https://pytorch.org/executorch/stable/getting-started-setup.html#quick-setup-colab-jupyter-notebook-prototype) page for a quick spin.
-
-Check out the examples of [Llama](./examples/models/llama/README.md), [Llava](./examples/models/llava/README.md) and [other models](./examples/README.md) running on edge devices using ExecuTorch.
+## Getting Started
+To get started, you can:
+- Visit the [Step by Step Tutorial](https://pytorch.org/executorch/main/index.html) on getting things running locally and deploying a model to a device
+- Use this [Colab Notebook](https://pytorch.org/executorch/stable/getting-started-setup.html#quick-setup-colab-jupyter-notebook-prototype) to start playing around right away
+- Jump straight into LLM use cases by following the specific instructions for [Llama](./examples/models/llama/README.md) and [Llava](./examples/models/llava/README.md)
-**[UPDATE - 10/24]** We have added support for running [Llama 3.2 Quantized 1B/3B](./examples/models/llama/README.md) models via ExecuTorch.
-
-## Feedback
+## Feedback and Engagement
We welcome any feedback, suggestions, and bug reports from the community to help
-us improve our technology. Please use the [PyTorch
-Forums](https://discuss.pytorch.org/c/executorch) for discussion and feedback
-about ExecuTorch using the **ExecuTorch** category, and our [GitHub
-repository](https://github.com/pytorch/executorch/issues) for bug reporting.
-
-We recommend using the latest release tag from the
-[Releases](https://github.com/pytorch/executorch/releases) page when developing.
+us improve our technology. Check out the [Discussion Board](https://github.com/pytorch/executorch/discussions) or chat with us in real time on [Discord](https://discord.gg/MeacgB7A).
## Contributing
-See [CONTRIBUTING.md](CONTRIBUTING.md) for details about issues, PRs, code
-style, CI jobs, and other development topics.
-
-To connect with us and other community members, we invite you to join PyTorch Slack community by filling out this [form](https://docs.google.com/forms/d/e/1FAIpQLSeADnUNW36fjKjYzyHDOzEB_abKQE9b6gqqW9NXse6O0MWh0A/viewform). Once you've joined, you can:
-* Head to the `#executorch-general` channel for general questions, discussion, and community support.
-* Join the `#executorch-contributors` channel if you're interested in contributing directly to project development.
+We welcome contributions. To get started, review the [guidelines](CONTRIBUTING.md) and chat with us on [Discord](https://discord.gg/MeacgB7A).
## Directory Structure
diff --git a/backends/arm/README.md b/backends/arm/README.md
index 2079e8ddd8..e28559fb90 100644
--- a/backends/arm/README.md
+++ b/backends/arm/README.md
@@ -122,6 +122,18 @@ The you can run the tests with
pytest -c /dev/null -v -n auto backends/arm/test --arm_run_corstoneFVP
```
+## Passes
+
+With the default passes in the Arm Ethos-U backend, assuming the model lowers fully to the
+Ethos-U, the exported program is composed of a Quantize node, an Ethos-U custom delegate
+and a Dequantize node. In some circumstances, you may want to feed quantized input to the
+neural network straight away, e.g. if you have a camera sensor outputting (u)int8 data and
+want to keep all the arithmetic of the application in the int8 domain. For these cases, you
+can apply `exir/passes/quantize_io_pass.py`. See the unit test in
+`executorch/backends/arm/test/passes/test_ioquantization_pass.py` for an example of how to
+feed quantized inputs and obtain quantized outputs.
+
+
### Code coverage
To get code coverage:
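A hedged sketch of applying the IO quantization pass mentioned in the new Passes section; the `QuantizeInputs`/`QuantizeOutputs` names and constructor arguments are assumptions based on `exir/passes/quantize_io_pass.py`, and the referenced unit test remains the authoritative example:

```python
# Sketch only: after lowering to the Ethos-U delegate, strip the inserted
# Quantize/Dequantize nodes so the program takes int8 inputs and returns int8
# outputs directly. Class names and argument order are assumptions.
from executorch.exir.passes.quantize_io_pass import QuantizeInputs, QuantizeOutputs

def quantize_io(edge_program_manager, input_idxs=(0,), output_idxs=(0,)):
    edge_program_manager = edge_program_manager.transform(
        [
            QuantizeInputs(edge_program_manager, list(input_idxs)),
            QuantizeOutputs(edge_program_manager, list(output_idxs)),
        ]
    )
    return edge_program_manager.to_executorch()
```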
diff --git a/backends/arm/_passes/annotate_channels_last_dim_order_pass.py b/backends/arm/_passes/annotate_channels_last_dim_order_pass.py
index a3d168fb87..ce15d8298c 100644
--- a/backends/arm/_passes/annotate_channels_last_dim_order_pass.py
+++ b/backends/arm/_passes/annotate_channels_last_dim_order_pass.py
@@ -116,7 +116,7 @@ def insert_input_transpose(node, input_node, graph_module):
with graph_module.graph.inserting_before(node):
permute_node = create_node(
graph_module.graph,
- torch.ops.passthrough_to_tosa._transpose,
+ torch.ops.passthrough_to_tosa._transpose.default,
args=(
input_node,
list(AnnotateChannelsLastDimOrder.NHWC_inverse_order),
@@ -129,18 +129,22 @@ def insert_input_transpose(node, input_node, graph_module):
permute_node.meta["tosa_dim_order"] = tuple(
range(len(input_node.meta["val"].size()))
)
+ permute_node.meta["val"] = input_node.meta["val"]
@staticmethod
def insert_output_transpose(node, graph_module):
with graph_module.graph.inserting_after(node):
permute_node = create_node(
graph_module.graph,
- torch.ops.passthrough_to_tosa._transpose,
+ torch.ops.passthrough_to_tosa._transpose.default,
args=(node, list(AnnotateChannelsLastDimOrder.NHWC_order)),
)
permute_node.meta["tosa_dim_order"] = (
AnnotateChannelsLastDimOrder.NHWC_order
)
+ permute_node.meta["val"] = node.meta["val"].permute(
+ AnnotateChannelsLastDimOrder.NHWC_order
+ )
node.meta["tosa_dim_order"] = (0, 1, 2, 3)
users = [user for user in node.users if user != permute_node]
for user in users:
@@ -209,7 +213,7 @@ def call(self, graph_module: torch.fx.GraphModule):
# dim_order = (2, 3, 0, 1) (https://www.mlplatform.org/tosa/tosa_spec.html#_depthwise_conv2d).
dim_order = self.HWCM_order
else:
- dim_order = tuple(range(node_data.dim()))
+ dim_order = tuple(range(node_data.dim())) # type: ignore[assignment]
node.meta["tosa_dim_order"] = dim_order
# Take care of cases when:
# 4D (NHWC) -> >4D (NCH)
diff --git a/backends/arm/_passes/annotate_decomposed_matmul.py b/backends/arm/_passes/annotate_decomposed_matmul.py
index 0846d97372..3feb0a0e05 100644
--- a/backends/arm/_passes/annotate_decomposed_matmul.py
+++ b/backends/arm/_passes/annotate_decomposed_matmul.py
@@ -6,9 +6,12 @@
import itertools
+from typing import List
+
import torch
from executorch.backends.arm._passes.arm_pass_utils import create_node
-from executorch.backends.arm.tosa_quant_utils import dq_op, q_op
+
+from executorch.backends.arm.tosa_quant_utils import dq_op, q_op, QuantArgs
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult
from torch.fx import GraphModule
@@ -24,6 +27,22 @@ class AnnotateDecomposedMatmulPass(ExportPass):
matmul-op (can be mm or bmm).
"""
+ def _match_partition_to_node(
+ self, node: torch.fx.Node, partitioned_inputs: List[torch.fx.Node]
+ ) -> torch.fx.Node:
+ """
+        The order of partition.input_nodes is not guaranteed. Compare them
+        with the incoming matmul node inputs and return the corresponding
+        partition input node.
+ """
+ if not node or node in partitioned_inputs or node.op == "placeholder":
+ return node
+ else:
+ return self._match_partition_to_node(
+ node.all_input_nodes[0], partitioned_inputs
+ )
+ raise RuntimeError(f"Cannot find an input node which matches, {node}.")
+
def call(self, graph_module: GraphModule) -> PassResult:
matmul_partitions = get_source_partitions(
graph_module.graph,
@@ -45,28 +64,36 @@ def call(self, graph_module: GraphModule) -> PassResult:
matmul_node = [
node for node in partition.nodes if node.target in matmul_targets
][0]
+
if quantized_input:
matmul_args = matmul_node.all_input_nodes
- for i in range(len(matmul_args)):
- input_node = partition.input_nodes[i]
- matmul_input_node = matmul_args[i]
+ for node in matmul_args:
+ input_node = self._match_partition_to_node(
+ node, partition.input_nodes
+ )
+
# Remove partition input dq-node
input_node.replace_all_uses_with(input_node.all_input_nodes[0])
graph_module.graph.erase_node(input_node)
- input_node_qargs = input_node.args[1:]
+ input_node_qargs = QuantArgs.from_operator(
+ input_node.target, input_node.args
+ )
+
with graph_module.graph.inserting_before(matmul_node):
# Create new dq-node before matmul
dq_node = create_node(
graph=graph_module.graph,
op_target=dq_op,
)
- dq_node.args = (matmul_input_node, *input_node_qargs)
- matmul_node.replace_input_with(matmul_input_node, dq_node)
+ dq_node.args = (node, *input_node_qargs)
+ matmul_node.replace_input_with(node, dq_node)
partition_output = list(partition.output_nodes[0].users)[0]
quantized_output = partition_output.target == q_op
if quantized_output:
- output_node_qargs = partition_output.args[1:]
+ output_node_qargs = QuantArgs.from_operator(
+ partition_output.target, partition_output.args
+ )
with graph_module.graph.inserting_after(matmul_node):
# Create q-node after matmul
q_node = create_node(
diff --git a/backends/arm/_passes/arm_pass_manager.py b/backends/arm/_passes/arm_pass_manager.py
index 9bac3b037c..686bfbcd8a 100644
--- a/backends/arm/_passes/arm_pass_manager.py
+++ b/backends/arm/_passes/arm_pass_manager.py
@@ -21,26 +21,32 @@
from executorch.backends.arm._passes.convert_split_to_slice import (
ConvertSplitToSlicePass,
)
-from executorch.backends.arm._passes.convert_squeezes_to_view import (
+from executorch.backends.arm._passes.convert_squeezes_to_view import ( # type: ignore[import-not-found]
ConvertSqueezesToViewPass,
)
+from executorch.backends.arm._passes.decompose_batchnorm_pass import (
+ DecomposeBatchNormPass,
+)
from executorch.backends.arm._passes.decompose_div_pass import DecomposeDivPass
from executorch.backends.arm._passes.decompose_layernorm_pass import (
DecomposeLayerNormPass,
)
from executorch.backends.arm._passes.decompose_linear_pass import DecomposeLinearPass
from executorch.backends.arm._passes.decompose_meandim_pass import DecomposeMeanDimPass
-from executorch.backends.arm._passes.decompose_select import DecomposeSelectPass
+from executorch.backends.arm._passes.decompose_select import ( # type: ignore[import-not-found]
+ DecomposeSelectPass,
+)
from executorch.backends.arm._passes.decompose_softmaxes_pass import (
DecomposeSoftmaxesPass,
)
from executorch.backends.arm._passes.decompose_var_pass import DecomposeVarPass
from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
FoldAndAnnotateQParamsPass,
- QuantizeFullArgument,
+ QuantizeOperatorArguments,
RetraceFoldedDtypesPass,
)
-from executorch.backends.arm._passes.fuse_quantized_activation_pass import (
+from executorch.backends.arm._passes.fuse_batchnorm2d_pass import FuseBatchnorm2DPass
+from executorch.backends.arm._passes.fuse_quantized_activation_pass import ( # type: ignore[import-not-found]
FuseQuantizedActivationPass,
)
from executorch.backends.arm._passes.insert_table_ops import InsertTableOpsPass
@@ -48,10 +54,12 @@
KeepDimsFalseToSqueezePass,
)
from executorch.backends.arm._passes.match_arg_ranks_pass import MatchArgRanksPass
-from executorch.backends.arm._passes.meandim_to_averagepool_pass import (
+from executorch.backends.arm._passes.meandim_to_averagepool_pass import ( # type: ignore[attr-defined]
ConvertMeanDimToAveragePoolPass,
)
-from executorch.backends.arm._passes.mm_to_bmm_pass import ConvertMmToBmmPass
+from executorch.backends.arm._passes.mm_to_bmm_pass import ( # type: ignore[import-not-found]
+ ConvertMmToBmmPass,
+)
from executorch.backends.arm._passes.remove_clone_pass import RemoveClonePass
from executorch.backends.arm._passes.scalars_to_attribute_pass import (
ScalarsToAttributePass,
@@ -82,14 +90,15 @@ def _transform(self, graph_module: GraphModule):
def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
self.add_pass(FuseQuantizedActivationPass())
self.add_pass(RemoveGetItemPass())
+ self.add_pass(DecomposeBatchNormPass())
self.add_pass(ConvertSplitToSlicePass())
self.add_pass(ConvertMmToBmmPass())
self.add_pass(DecomposeLinearPass())
self.add_pass(ConvertMeanDimToAveragePoolPass())
self.add_pass(AnnotateDecomposedMatmulPass())
- self.add_pass(QuantizeFullArgument())
- self.add_pass(FoldAndAnnotateQParamsPass())
+ self.add_pass(QuantizeOperatorArguments())
+ self.add_pass(FoldAndAnnotateQParamsPass()) # type: ignore[call-arg]
self.add_pass(RetraceFoldedDtypesPass())
self.add_pass(InsertTableOpsPass(exported_program))
@@ -116,16 +125,18 @@ def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
self.add_pass(ConvertSplitToSlicePass())
self.add_pass(ConvertMmToBmmPass())
self.add_pass(DecomposeLinearPass())
+ self.add_pass(DecomposeBatchNormPass())
self.add_pass(DecomposeLayerNormPass())
self.add_pass(DecomposeVarPass())
self.add_pass(DecomposeMeanDimPass())
self.add_pass(ConvertMeanDimToAveragePoolPass())
self.add_pass(DecomposeDivPass())
self.add_pass(DecomposeSoftmaxesPass())
+ self.add_pass(FuseBatchnorm2DPass(exported_program))
self.add_pass(AnnotateDecomposedMatmulPass())
- self.add_pass(QuantizeFullArgument())
- self.add_pass(FoldAndAnnotateQParamsPass())
+ self.add_pass(QuantizeOperatorArguments())
+ self.add_pass(FoldAndAnnotateQParamsPass()) # type: ignore[call-arg]
self.add_pass(RetraceFoldedDtypesPass())
self.add_pass(InsertTableOpsPass(exported_program))
diff --git a/backends/arm/_passes/arm_pass_utils.py b/backends/arm/_passes/arm_pass_utils.py
index 7377d401ab..cb43acc7fd 100644
--- a/backends/arm/_passes/arm_pass_utils.py
+++ b/backends/arm/_passes/arm_pass_utils.py
@@ -1,5 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
-# Copyright 2024 Arm Limited and/or its affiliates.
+# Copyright 2024-2025 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
@@ -58,9 +58,9 @@ def get_param_tensor(
elif is_get_attr_node(node):
# This is a hack to support both lifted and unlifted graph
try:
- return getattr(node.graph.owning_module, node.target)
+ return getattr(node.graph.owning_module, node.target) # type: ignore[arg-type]
except AttributeError:
- return getattr(exp_prog.graph_module, node.target)
+ return getattr(exp_prog.graph_module, node.target) # type: ignore[arg-type]
raise RuntimeError(f"unsupported param type, {node.op}.")
@@ -156,7 +156,7 @@ def get_node_arg(args: list | dict, key: int | str | type, default_value=None):
f"Out of bounds index {key} for getting value in args (of size {len(args)})"
)
elif isinstance(key, str):
- return args.get(key, default_value) # pyre-ignore[16]
+ return args.get(key, default_value) # type: ignore[union-attr] # pyre-ignore[16]
elif isclass(key):
for arg in args:
if isinstance(arg, key):
diff --git a/backends/arm/_passes/decompose_batchnorm_pass.py b/backends/arm/_passes/decompose_batchnorm_pass.py
new file mode 100644
index 0000000000..d33e8e3b51
--- /dev/null
+++ b/backends/arm/_passes/decompose_batchnorm_pass.py
@@ -0,0 +1,138 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-unsafe
+
+import operator
+
+import torch
+from executorch.backends.arm._passes.arm_pass_utils import create_node
+from executorch.exir.dialects._ops import ops as exir_ops
+from executorch.exir.pass_base import ExportPass, PassResult
+
+
+edge_bn_ops = (exir_ops.edge.aten._native_batch_norm_legit_no_training.default,)
+
+
+def get_bn_decomposition(op) -> tuple:
+ """
+ Returns decomposition of batchnorm in edge ops.
+ Raises RuntimeError if op is not batchnorm edge op.
+ """
+ if op in edge_bn_ops:
+ return (
+ exir_ops.edge.aten.sub.Tensor,
+ exir_ops.edge.aten.add.Tensor,
+ exir_ops.edge.aten.rsqrt.default,
+ exir_ops.edge.aten.mul.Tensor,
+ exir_ops.edge.aten.view_copy.default,
+ exir_ops.edge.aten.full.default,
+ )
+ else:
+ raise RuntimeError(f"Can't get decomposition for {op}")
+
+
+class DecomposeBatchNormPass(ExportPass):
+ """
+ Decompose BatchNorm to:
+ %output = (%x - %E[x]) / SQRT( %Var[x] + %epsilon ) * %gamma + %beta
+ e.g.
+ %output = (%activations - %running_mean) / SQRT( %running_var + %epsilon_const ) * %weights + %bias
+ ->
+ %op1 = sub(%activations, %running_mean)
+ %op2 = add(%running_var, %epsilon_const)
+ %op3 = rsqrt(%op2)
+ %op4 = mul(%op1, %op3)
+ %op5 = mul(%op4, %weights)
+ %output = add(%op5, %bias)
+ """
+
+ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
+ modified = False
+ for node in graph_module.graph.nodes:
+ if node.op != "call_function" or node.target not in edge_bn_ops:
+ continue
+
+ args = node.args
+ meta = node.meta
+ (
+ activations,
+ weights,
+ bias,
+ running_mean,
+ running_var,
+ momentum,
+ epsilon,
+ ) = args
+ if momentum != 0.1:
+                raise RuntimeError(f"Expected momentum=0.1 but got {momentum}")
+
+ shape = meta["val"][0].size()
+ dtype = meta["val"][0].dtype
+ rank = len(shape)
+ running_mean_shape = running_mean.meta["val"].shape
+ running_mean_reshaped_shape = [1] * rank
+ running_mean_reshaped_shape[1] = running_mean_shape[0]
+ epsilon_reshaped_shape = [1] * rank
+
+ sub, add, rsqrt, mul, view, full = get_bn_decomposition(node.target)
+ with graph_module.graph.inserting_before(node):
+ mean_reshaped = create_node(
+ graph_module.graph,
+ view,
+ args=(running_mean, running_mean_reshaped_shape),
+ )
+ op1 = create_node(
+ graph_module.graph, sub, args=(activations, mean_reshaped)
+ )
+ full = create_node(
+ graph_module.graph,
+ full,
+ args=(epsilon_reshaped_shape, epsilon),
+ kwargs={"dtype": dtype},
+ )
+ var_reshaped = create_node(
+ graph_module.graph,
+ view,
+ args=(running_var, running_mean_reshaped_shape),
+ )
+ op2 = create_node(graph_module.graph, add, args=(var_reshaped, full))
+ op3 = create_node(graph_module.graph, rsqrt, args=(op2,))
+ op4 = create_node(graph_module.graph, mul, args=(op1, op3))
+ if weights is not None:
+ weights_reshaped = create_node(
+ graph_module.graph,
+ view,
+ args=(weights, running_mean_reshaped_shape),
+ )
+ op5 = create_node(
+ graph_module.graph, mul, args=(op4, weights_reshaped)
+ )
+ else:
+ op5 = op4
+ output = op5
+ if bias is not None:
+ bias_reshaped_shape = running_mean_reshaped_shape
+ bias_reshaped = create_node(
+ graph_module.graph, view, args=(bias, bias_reshaped_shape)
+ )
+ output = create_node(
+ graph_module.graph, add, args=(op5, bias_reshaped)
+ )
+
+ users = [user for user in node.users if node != user]
+ node.replace_all_uses_with(output)
+ for user in users:
+ if user.target == operator.getitem:
+ user.replace_all_uses_with(output)
+ graph_module.graph.erase_node(node)
+ graph_module.graph.eliminate_dead_code()
+ modified = True
+ if modified:
+ graph_module.recompile()
+ graph_module = super().call(graph_module).graph_module
+
+ return PassResult(graph_module, modified)
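The new `DecomposeBatchNormPass` rewrites inference-mode batch norm into view/sub/add/rsqrt/mul/full edge ops. A small plain-PyTorch check (not part of the pass) that the formula in the docstring reproduces `F.batch_norm` with running statistics:

```python
import torch

# Verify: out = (x - running_mean) * rsqrt(running_var + eps) * weight + bias
# matches torch's inference-mode batch norm, with per-channel stats broadcast
# over a (N, C, H, W) input, mirroring running_mean_reshaped_shape in the pass.
torch.manual_seed(0)
x = torch.randn(2, 4, 8, 8)
weight, bias = torch.randn(4), torch.randn(4)
running_mean, running_var = torch.randn(4), torch.rand(4) + 0.5
eps = 1e-5

reference = torch.nn.functional.batch_norm(
    x, running_mean, running_var, weight, bias, training=False, eps=eps
)

def per_channel(t):
    # Reshape (C,) stats to (1, C, 1, 1) so they broadcast over the (N, C, H, W) input.
    return t.view(1, -1, 1, 1)

decomposed = (x - per_channel(running_mean)) * torch.rsqrt(per_channel(running_var) + eps)
decomposed = decomposed * per_channel(weight) + per_channel(bias)

assert torch.allclose(reference, decomposed, atol=1e-5)
```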
diff --git a/backends/arm/_passes/decompose_layernorm_pass.py b/backends/arm/_passes/decompose_layernorm_pass.py
index 3739337101..cc4a81caae 100644
--- a/backends/arm/_passes/decompose_layernorm_pass.py
+++ b/backends/arm/_passes/decompose_layernorm_pass.py
@@ -1,4 +1,4 @@
-# Copyright 2024 Arm Limited and/or its affiliates.
+# Copyright 2024-2025 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
@@ -82,9 +82,10 @@ def call(self, graph_module: torch.fx.GraphModule):
n_dims = len(normalized_shape)
if isinstance(meta["val"], tuple):
shape = meta["val"][0].size()
+ dtype = meta["val"][0].dtype
else:
shape = meta["val"].size()
- dtype = meta["val"][0].dtype
+ dtype = meta["val"].dtype
rank = len(shape)
dims = list(range(-1, -1 * (n_dims + 1), -1))
dims = [dim % rank for dim in dims]
diff --git a/backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py b/backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py
index b1e680b7bc..29791940d5 100644
--- a/backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py
+++ b/backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py
@@ -105,21 +105,6 @@ def fold_and_annotate_arg(
for arg in arg_list:
if not isinstance(arg, Node):
return
- """
- Make sure arg has requires_grad set to False
- For parameters that are not quantized, sometimes (i.e. convolution)
- the Parameter(FakeTensor(...)) has requires_grad set to True, which
- causes the retracing of the graph to fail with:
-
- E RuntimeError: isDifferentiableType(variable.scalar_type()) INTERNAL ASSERT FAILED at "/Users/runner/work/pytorch/pytorch/pytorch/torch/csrc/autograd/functions/utils.h":74, please report a bug to PyTorch.
- E
- E While executing %aten_convolution_default : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.convolution.default](args = (%quantized_decomposed_quantize_per_tensor_default, %b__frozen_param0, %p__param_constant1, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), kwargs = {})
- E Original traceback:
- E File "/Users/perast01/src/executorch/backends/arm/test/ops/test_conv2d.py", line 110, in forward
- E x = conv(x)
- """
- if arg.op == "placeholder":
- arg.meta["val"].requires_grad = False
arg_quant_params = None
if arg.target == dq_op:
@@ -134,7 +119,7 @@ def fold_and_annotate_arg(
node.meta["input_qparams"][i] = input_qparams
for n in nodes_to_remove:
assert n.target == dq_op
- n.replace_all_uses_with(n.args[0])
+ n.replace_all_uses_with(n.args[0]) # type: ignore[arg-type]
graph_module.graph.erase_node(n)
def call(self, graph_module: GraphModule) -> PassResult:
@@ -182,11 +167,14 @@ def call(self, graph_module: GraphModule) -> PassResult:
return PassResult(graph_module, True)
-class QuantizeFullArgument(ExportPass):
+class QuantizeOperatorArguments(ExportPass):
"""
- Make sure the fill_value for full.default is quantized. This pass needs to be run before
- the folding pass above to make sure that the retraced output of the full.default op is
- the right dtype.
+ This pass makes sure that the arguments to full.default and clamp.default are quantized correctly.
+ More specifically, this pass:
+ - Makes sure the fill_value for full.default is quantized. This pass needs to be run before
+ the folding pass above to make sure that the retraced output of the full.default op is
+ the right dtype.
+ - Makes sure the min and max values to clamp.default are quantized, if it's a quantized operator.
"""
def call(self, graph_module: GraphModule) -> PassResult:
@@ -194,7 +182,10 @@ def call(self, graph_module: GraphModule) -> PassResult:
# Loop over the graph nodes and find full.default nodes.
for n in graph_module.graph.nodes:
n = cast(Node, n)
- if n.target != exir_ops.edge.aten.full.default:
+ if n.target not in {
+ exir_ops.edge.aten.clamp.default,
+ exir_ops.edge.aten.full.default,
+ }:
continue
# Make sure we have a quantized operator
@@ -203,13 +194,29 @@ def call(self, graph_module: GraphModule) -> PassResult:
continue
qargs = QuantArgs.from_operator(user.target, user.args)
- if "dtype" not in n.kwargs.keys() or n.kwargs["dtype"] != qargs.dtype:
- # replace the node arg with a quantized dito and also set dtype
- # to get the right output according to the Edge IR specification:
- # exir/dialects/edge/edge.yaml:3596
- quantized_full_value = qargs.quantize_value(n.args[1]).item()
- n.update_arg(1, quantized_full_value)
- n.update_kwarg("dtype", qargs.dtype)
+
+ if n.target == exir_ops.edge.aten.full.default:
+ if "dtype" not in n.kwargs.keys() or n.kwargs["dtype"] != qargs.dtype:
+                    # replace the node arg with a quantized ditto and also set dtype
+ # to get the right output according to the Edge IR specification:
+ # exir/dialects/edge/edge.yaml:3596
+ quantized_full_value = qargs.quantize_value(n.args[1]).item()
+ n.update_arg(1, quantized_full_value)
+ n.update_kwarg("dtype", qargs.dtype)
+ modified = True
+ elif n.target == exir_ops.edge.aten.clamp.default:
+ # Quantize the min and max arguments of clamp, if they are not None
+ min_val = n.args[1]
+ max_val = None if len(n.args) <= 2 else n.args[2]
+
+ if min_val is not None:
+ quantized_min_val = qargs.quantize_value(min_val).item()
+ n.update_arg(1, quantized_min_val)
+
+ if max_val is not None:
+ quantized_max_val = qargs.quantize_value(max_val).item()
+ n.update_arg(2, quantized_max_val)
+
modified = True
return PassResult(graph_module, modified)
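`QuantizeOperatorArguments` now also quantizes the scalar `min`/`max` arguments of `clamp.default` using the q-params of the consuming quantize node. An illustrative stand-in for what `QuantArgs.quantize_value` is assumed to do with such a scalar (the real helper lives in `tosa_quant_utils`):

```python
import torch

# Assumed affine quantization of a scalar operator argument; scale, zero_point,
# qmin, qmax and dtype would come from the q-node consuming the operator output.
def quantize_scalar(value, scale, zero_point, qmin, qmax, dtype):
    q = torch.round(torch.tensor(value) / scale) + zero_point
    return torch.clamp(q, qmin, qmax).to(dtype)

# Example: clamp(min=0.0, max=6.0) with scale=0.02, zero_point=-128 becomes
# clamp(min=-128, max=127) in the int8 domain (6.0 / 0.02 - 128 = 172 -> 127).
print(quantize_scalar(0.0, 0.02, -128, -128, 127, torch.int8))  # tensor(-128, dtype=torch.int8)
print(quantize_scalar(6.0, 0.02, -128, -128, 127, torch.int8))  # tensor(127, dtype=torch.int8)
```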
diff --git a/backends/arm/_passes/fuse_batchnorm2d_pass.py b/backends/arm/_passes/fuse_batchnorm2d_pass.py
new file mode 100644
index 0000000000..6a5ece2e44
--- /dev/null
+++ b/backends/arm/_passes/fuse_batchnorm2d_pass.py
@@ -0,0 +1,128 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+from executorch.exir import ExportedProgram
+from executorch.exir.dialects._ops import ops as exir_ops
+from executorch.exir.pass_base import ExportPass, PassResult
+from torch._export.utils import get_buffer, get_param
+from torch.fx import Node
+from torch.nn.utils.fusion import fuse_conv_bn_weights
+
+
+class FuseBatchnorm2DPass(ExportPass):
+ """Fuses the pattern convolution -> batchnorm by updating
+ the weights and bias of the convolution and removing the batchnorm.
+ """
+
+ def __init__(self, exported_program: ExportedProgram):
+ self.exported_program = exported_program
+ super().__init__()
+
+ def is_fuseable_conv_bn(self, node: Node):
+ """Returns True if node is a batchnorm that can be fused into
+ a parent convolution."""
+ if node.op != "call_function":
+ return False
+ if node.target not in (
+ exir_ops.edge.aten._native_batch_norm_legit,
+ exir_ops.edge.aten._native_batch_norm_legit_no_training.default,
+ ):
+ return False
+ conv = node.all_input_nodes[0]
+ if conv.target != exir_ops.edge.aten.convolution.default:
+ return False
+        # Batchnorm users are getitem; we can only handle those that get the first element.
+ for user in node.users:
+ get_index = user.args[1]
+ if get_index != 0:
+ return False
+ # Since we change the output of the conv, fuse only if it has single user.
+ if len(conv.users) > 1:
+ return False
+ # For similar reasons, only fuse if conv parameters have single user.
+ if len(conv.all_input_nodes[1].users) > 1:
+ return False
+ if len(conv.all_input_nodes) > 2 and len(conv.all_input_nodes[2].users) > 1:
+ return False
+ return True
+
+ def call(self, graph_module: torch.fx.GraphModule) -> PassResult: # noqa: C901
+ modified = False
+ for node in graph_module.graph.nodes:
+ if not self.is_fuseable_conv_bn(node):
+ continue
+
+ def get_param_or_none(arg) -> torch.nn.Parameter | None:
+                """get_param but check if arg is None first."""
+ return (
+ get_param(self.exported_program, arg) if arg is not None else None
+ )
+
+ # Get weight, bias, mean, var and epsilon from the batchnorm
+ bn = node
+ conv, bn_weight_node, bn_bias_node, bn_mean_node, bn_var_node = bn.args[0:5]
+ bn_weight = get_param_or_none(bn_weight_node)
+ bn_bias = get_param_or_none(bn_bias_node)
+
+ running_mean = get_buffer(self.exported_program, bn_mean_node)
+ running_var = get_buffer(self.exported_program, bn_var_node)
+ if running_mean is None or running_var is None:
+ raise ValueError(
+ "Parameters running_mean and running_var of batchnorm can't be None."
+ )
+ epsilon = bn.args[-1]
+
+ # Get weight and bias from conv
+ conv_weight_node, conv_bias_node = conv.args[1:3]
+ conv_weight = get_param(self.exported_program, conv_weight_node)
+ conv_bias = get_param_or_none(conv_bias_node)
+ if conv_weight is None:
+ raise ValueError("Parameter weight of convolution can't be None.")
+
+ # Compute conv parameters folded with batchnorm
+ fused_conv_weight, fused_conv_bias = fuse_conv_bn_weights(
+ conv_weight,
+ conv_bias,
+ running_mean,
+ running_var,
+ epsilon,
+ bn_weight,
+ bn_bias,
+ )
+
+ # Set the conv parameters to fused value
+ def try_set_param(
+ param_node: Node | None, param_value: torch.nn.Parameter
+ ) -> bool:
+ """set_param but check if param_node is None first. Return True if param was set successfully, otherwise False."""
+ if param_node is not None:
+ param_name = (
+ self.exported_program.graph_signature.inputs_to_parameters[
+ param_node.name
+ ]
+ )
+ self.exported_program.state_dict[param_name] = param_value
+ return True
+ return False
+
+ try_set_param(conv_weight_node, fused_conv_weight)
+ if not try_set_param(conv_bias_node, fused_conv_bias) and try_set_param(
+ bn_bias_node, fused_conv_bias
+ ):
+ # Conv didn't have bias but batchnorm did, steal bias from batchnorm.
+ conv_args = (*conv.args[0:2], bn_bias_node, *conv.args[3:])
+ conv.args = conv_args
+
+ # Erasing nodes is handled by dead-code elimination.
+ for user in bn.users:
+ user.replace_all_uses_with(conv)
+ modified = True
+
+ if modified:
+ graph_module.graph.eliminate_dead_code()
+ graph_module.recompile()
+ graph_module = super().call(graph_module).graph_module
+ return PassResult(graph_module=graph_module, modified=modified)
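The new `FuseBatchnorm2DPass` delegates the arithmetic to `torch.nn.utils.fusion.fuse_conv_bn_weights`. A quick numeric check, independent of the pass, that the folded parameters reproduce convolution followed by inference-mode batch norm:

```python
import torch
from torch.nn.utils.fusion import fuse_conv_bn_weights

# Folding batchnorm statistics into the conv weights should give the same output
# as conv -> batchnorm in eval mode.
torch.manual_seed(0)
conv = torch.nn.Conv2d(3, 8, kernel_size=3, bias=False).eval()
bn = torch.nn.BatchNorm2d(8).eval()
bn.running_mean.uniform_(-1.0, 1.0)
bn.running_var.uniform_(0.5, 1.5)

x = torch.randn(1, 3, 16, 16)
reference = bn(conv(x))

fused_weight, fused_bias = fuse_conv_bn_weights(
    conv.weight, conv.bias, bn.running_mean, bn.running_var, bn.eps, bn.weight, bn.bias
)
fused = torch.nn.functional.conv2d(x, fused_weight, fused_bias)

assert torch.allclose(reference, fused, atol=1e-5)
```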
diff --git a/backends/arm/_passes/insert_table_ops.py b/backends/arm/_passes/insert_table_ops.py
index 57a8376d40..b500540ffb 100644
--- a/backends/arm/_passes/insert_table_ops.py
+++ b/backends/arm/_passes/insert_table_ops.py
@@ -1,4 +1,4 @@
-# Copyright 2024 Arm Limited and/or its affiliates.
+# Copyright 2024-2025 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
@@ -31,7 +31,7 @@ class InsertTableOpsPass(ExportPass):
"""
For ops in self.table_ops they need to be serialized as a TOSA TABLE. This pass replaces these
edge ops with a tosa._table(input: Tensor, target_str: str) where target_str == str(node.target).
- When loweringthe _table node target_str will be used to find the corresponding torch operator
+ When lowering the _table node target_str will be used to find the corresponding torch operator
which will be used to produce the table values in operators/op_table.py.
"""
@@ -42,6 +42,8 @@ class InsertTableOpsPass(ExportPass):
exir_ops.edge.aten.rsqrt.default: torch.rsqrt,
exir_ops.edge.aten.sigmoid.default: torch.sigmoid,
exir_ops.edge.aten.tanh.default: torch.tanh,
+ exir_ops.edge.aten.hardsigmoid.default: torch.nn.functional.hardsigmoid,
+ exir_ops.edge.aten.hardswish.default: torch.nn.functional.hardswish,
}
def __init__(self, exported_program: ExportedProgram) -> None:
@@ -92,7 +94,7 @@ def call(self, graph_module: GraphModule) -> PassResult:
with graph_module.graph.inserting_before(node):
table_node = create_node(
graph=graph_module.graph,
- op_target=torch.ops.tosa._table,
+ op_target=torch.ops.tosa._table.default,
args=(node.args[0],),
)
assert len(input_qparams) == 1
@@ -104,7 +106,11 @@ def call(self, graph_module: GraphModule) -> PassResult:
out_quantargs=output_qparams[0],
)
# Register buffer in self.exported_program.state_dict
- self.register_buffer(buffer_name=table_node.name, buffer=buffer)
+ # When the graph is retraced, the implementation _table is used and the suffix _default disappears from the node name
+ # Remove it here to make it possible to find in the node_visitor
+ self.register_buffer(
+ buffer_name=table_node.name.replace("_default", ""), buffer=buffer
+ )
node.replace_all_uses_with(table_node)
graph_module.graph.erase_node(node)
table_node.meta["input_qparams"] = input_qparams
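`InsertTableOpsPass` now also routes `hardsigmoid` and `hardswish` through TOSA TABLE ops. Conceptually, such a table is the reference torch operator evaluated over every representable int8 input and requantized with the output q-params; a sketch with made-up quantization parameters:

```python
import torch

# Illustrative only: build a 256-entry int8 lookup table for hardsigmoid.
# The scales/zero-points below are invented; the pass takes them from the
# folded input/output q-params attached to the node.
in_scale, in_zp = 0.05, 0
out_scale, out_zp = 1.0 / 255, -128

x_int8 = torch.arange(-128, 128, dtype=torch.int32)        # every int8 input value
x_fp = (x_int8 - in_zp).to(torch.float32) * in_scale       # dequantize
y_fp = torch.nn.functional.hardsigmoid(x_fp)               # reference operator
table = torch.clamp(
    torch.round(y_fp / out_scale) + out_zp, -128, 127
).to(torch.int8)                                            # requantize into the table
print(table.shape)  # torch.Size([256])
```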
diff --git a/backends/arm/_passes/keep_dims_false_to_squeeze_pass.py b/backends/arm/_passes/keep_dims_false_to_squeeze_pass.py
index f4d369a504..ad95379cc8 100644
--- a/backends/arm/_passes/keep_dims_false_to_squeeze_pass.py
+++ b/backends/arm/_passes/keep_dims_false_to_squeeze_pass.py
@@ -1,4 +1,4 @@
-# Copyright 2024 Arm Limited and/or its affiliates.
+# Copyright 2024-2025 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
@@ -66,7 +66,7 @@ def call(self, graph_module: torch.fx.GraphModule):
sum_node = cast(torch.fx.Node, node)
keep_dim = get_node_arg(
# pyre-ignore[6]
- sum_node.args,
+ sum_node.args, # type: ignore[arg-type]
keep_dim_index,
False,
)
@@ -74,7 +74,7 @@ def call(self, graph_module: torch.fx.GraphModule):
if keep_dim:
continue
- dim_list = get_node_arg(sum_node.args, 1, [0]) # pyre-ignore[6]
+ dim_list = get_node_arg(sum_node.args, 1, [0]) # type: ignore[arg-type] # pyre-ignore[6]
# Add keep_dim = True arg to sum node.
set_node_arg(sum_node, 2, True)
diff --git a/backends/arm/_passes/scalars_to_attribute_pass.py b/backends/arm/_passes/scalars_to_attribute_pass.py
index f6fe02b6eb..78865fe33f 100644
--- a/backends/arm/_passes/scalars_to_attribute_pass.py
+++ b/backends/arm/_passes/scalars_to_attribute_pass.py
@@ -1,4 +1,4 @@
-# Copyright 2024 Arm Limited and/or its affiliates.
+# Copyright 2024-2025 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
@@ -54,7 +54,7 @@ def call(self, graph_module: GraphModule) -> PassResult:
if isinstance(arg, int) and not torch.is_floating_point(
get_first_fake_tensor(n)
):
- new_args.append(arg)
+ new_args.append(arg) # type: ignore[arg-type]
continue
prefix = "_tensor_constant_"
diff --git a/backends/arm/arm_backend.py b/backends/arm/arm_backend.py
index d695aec2fd..899bafcf04 100644
--- a/backends/arm/arm_backend.py
+++ b/backends/arm/arm_backend.py
@@ -15,7 +15,7 @@
import os
from typing import cast, final, List, Optional
-import serializer.tosa_serializer as ts
+import serializer.tosa_serializer as ts # type: ignore
from executorch.backends.arm.arm_vela import vela_compile
from executorch.backends.arm.operators.node_visitor import get_node_visitors
@@ -230,7 +230,7 @@ def preprocess( # noqa: C901
# Converted output for this subgraph, serializer needs path early as it emits
# const data directly. Path created and data written only in debug builds.
tosa_graph = ts.TosaSerializer(artifact_path)
- graph_module = ArmPassManager(tosa_spec).transform_to_backend_pipeline(
+ graph_module = ArmPassManager(tosa_spec).transform_to_backend_pipeline( # type: ignore
exported_program=edge_program
)
diff --git a/backends/arm/arm_partitioner.py b/backends/arm/arm_partitioner.py
index cc4058c4c5..8fde8dff61 100644
--- a/backends/arm/arm_partitioner.py
+++ b/backends/arm/arm_partitioner.py
@@ -10,7 +10,7 @@
from typing import Callable, final, List, Optional, Tuple
import torch
-from executorch.backends.arm.arm_backend import (
+from executorch.backends.arm.arm_backend import ( # type: ignore[attr-defined]
ArmBackend,
) # usort: skip
from executorch.backends.arm.operator_support.tosa_supported_operators import (
@@ -113,8 +113,41 @@ def ops_to_not_decompose(
self,
ep: ExportedProgram,
) -> Tuple[List[torch._ops.OpOverload], Optional[Callable[[torch.fx.Node], bool]]]:
+ ops_to_not_decompose_if_quant_op = [
+ torch.ops.aten.hardsigmoid.default,
+ torch.ops.aten.hardswish.default,
+ ]
+
+ def filter_fn(node: torch.fx.Node) -> bool:
+ # This function selects operators that should not be decomposed, i.e. where:
+ # - Its target is in the ops_to_not_decompose_if_quant_op list.
+ # - All its inputs/outputs are quantize/dequantize operators.
+ dq = torch.ops.quantized_decomposed.dequantize_per_tensor.default
+ q = torch.ops.quantized_decomposed.quantize_per_tensor.default
+
+ if node.target in ops_to_not_decompose_if_quant_op:
+ # Assume we should not decompose the operator (it is quantized)
+ should_not_decompose = True
+
+ input_nodes = node.all_input_nodes
+ output_nodes = node.users
+
+ for inp in input_nodes:
+ if inp.target != dq:
+ should_not_decompose = False
+
+ for out in output_nodes:
+ if out.target != q:
+ should_not_decompose = False
+
+ return should_not_decompose
+
+ # By default, do not decompose the operator
+ return True
+
ops_to_not_decompose = [
torch.ops.aten.linear.default,
torch.ops.aten.upsample_nearest2d.vec,
- ]
- return (ops_to_not_decompose, None)
+ ] + ops_to_not_decompose_if_quant_op
+
+ return (ops_to_not_decompose, filter_fn)
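Reading the filter above: hardsigmoid/hardswish are kept un-decomposed only when every producer is a dequantize node and every consumer is a quantize node, so the quantized op can later be lowered through a TOSA TABLE; otherwise the normal decomposition runs. A compact, hedged restatement of the same predicate (assuming the quantized_decomposed ops are registered, as they are in the PT2E quantization flow):

```python
import torch
from torch.ao.quantization.fx import _decomposed  # noqa: F401  # registers quantized_decomposed ops

DQ = torch.ops.quantized_decomposed.dequantize_per_tensor.default
Q = torch.ops.quantized_decomposed.quantize_per_tensor.default
QUANT_ONLY_OPS = (
    torch.ops.aten.hardsigmoid.default,
    torch.ops.aten.hardswish.default,
)


def keep_undecomposed(node: torch.fx.Node) -> bool:
    # Ops outside the list are always kept un-decomposed; listed ops are kept
    # only when fully surrounded by quantize/dequantize nodes.
    if node.target not in QUANT_ONLY_OPS:
        return True
    return all(inp.target == DQ for inp in node.all_input_nodes) and all(
        user.target == Q for user in node.users
    )
```
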
diff --git a/backends/arm/arm_vela.py b/backends/arm/arm_vela.py
index 918d95ba37..f7f0c4b49c 100644
--- a/backends/arm/arm_vela.py
+++ b/backends/arm/arm_vela.py
@@ -1,4 +1,4 @@
-# Copyright 2023-2024 Arm Limited and/or its affiliates.
+# Copyright 2023-2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
@@ -12,7 +12,7 @@
from typing import List
import numpy as np
-from ethosu.vela import vela
+from ethosu.vela import vela # type: ignore
# Pack either input or output tensor block, compose the related arrays into
@@ -96,13 +96,13 @@ def vela_compile(tosa_graph, args: List[str], shape_order=None):
block_name = block_name + b"\x00" * (16 - len(block_name))
# We need the acual unpadded block lengths for hw setup
- block_length = struct.pack(" bool
if input_dtype not in supported_dtypes:
logger.info(
f"Input dtype {input_val.dtype} is not supported in "
- f"{node.target.name()}." # pyre-ignore[16]
+ f"{node.target.name()}." # type: ignore[union-attr] # pyre-ignore[16]
)
return False
@@ -107,7 +107,7 @@ def is_node_supported(self, node: fx.Node, tosa_spec: TosaSpecification) -> bool
if output_val.dtype not in supported_dtypes[input_dtype]:
logger.info(
f"Output dtype {output_val.dtype} is not supported in "
- f"{node.target.name()} for input dtype {input_dtype}. " # pyre-ignore[16]
+ f"{node.target.name()} for input dtype {input_dtype}. " # type: ignore[union-attr] # pyre-ignore[16]
f"Supported output types: "
f"{''.join(str(t) for t in supported_dtypes[input_dtype])}"
)
@@ -118,7 +118,7 @@ def is_node_supported(self, node: fx.Node, tosa_spec: TosaSpecification) -> bool
if node.kwargs["memory_format"] in (torch.preserve_format,):
logger.info(
f"Argument 'memory_format' is not supported for "
- f"{node.target.name()} right now." # pyre-ignore[16]
+ f"{node.target.name()} right now." # type: ignore[union-attr] # pyre-ignore[16]
)
return False
@@ -126,10 +126,10 @@ def is_node_supported(self, node: fx.Node, tosa_spec: TosaSpecification) -> bool
if "dim_order" in node.kwargs:
dim_order = node.kwargs["dim_order"]
# pyre-ignore[6]
- if dim_order != list(range(len(dim_order))):
+ if dim_order != list(range(len(dim_order))): # type: ignore[arg-type]
logger.info(
f"Argument {dim_order=} is not supported for "
- f"{node.target.name()} right now." # pyre-ignore[16]
+ f"{node.target.name()} right now." # type: ignore[union-attr] # pyre-ignore[16]
)
return False
diff --git a/backends/arm/operator_support/tosa_supported_operators.py b/backends/arm/operator_support/tosa_supported_operators.py
index c3102a86a4..36914579fe 100644
--- a/backends/arm/operator_support/tosa_supported_operators.py
+++ b/backends/arm/operator_support/tosa_supported_operators.py
@@ -76,9 +76,12 @@ def is_node_supported(self, submodules, node: fx.Node) -> bool:
exir_ops.edge.aten.add.Tensor,
exir_ops.edge.aten.expand_copy.default,
exir_ops.edge.aten.cat.default,
+ exir_ops.edge.aten.clamp.default,
exir_ops.edge.aten.bmm.default,
exir_ops.edge.aten.permute_copy.default,
+ exir_ops.edge.aten.hardsigmoid.default,
exir_ops.edge.aten.hardtanh.default,
+ exir_ops.edge.aten.hardswish.default,
exir_ops.edge.aten.convolution.default,
exir_ops.edge.aten.div.Tensor,
exir_ops.edge.aten.eq.Tensor,
@@ -137,5 +140,5 @@ def is_node_supported(self, submodules, node: fx.Node) -> bool:
def is_node_supported_custom(self, node: fx.Node) -> bool:
tosa_checks = get_registered_tosa_support_checks(self.tosa_spec)
if node.target in tosa_checks.keys():
- return tosa_checks[node.target].is_node_supported(node, self.tosa_spec)
+ return tosa_checks[node.target].is_node_supported(node, self.tosa_spec) # type: ignore[index]
return False
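With clamp, hardsigmoid and hardswish added to the supported set, a module like the illustrative example below can be picked up in full by the TOSA partitioner: the hard* activations are lowered via TOSA TABLE when quantized, and clamp maps to the new TOSA CLAMP visitor. The names here are made up for the example.

```python
import torch


class SmallBlock(torch.nn.Module):
    """Illustrative module that only uses ops from the supported-operator list."""

    def __init__(self) -> None:
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 8, kernel_size=3, padding=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.conv(x)                         # aten.convolution
        x = torch.nn.functional.hardswish(x)     # aten.hardswish
        return torch.clamp(x, min=0.0, max=6.0)  # aten.clamp


out = SmallBlock()(torch.randn(1, 3, 16, 16))
```
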
diff --git a/backends/arm/operators/__init__.py b/backends/arm/operators/__init__.py
index 5a97d33304..f57ba092bc 100644
--- a/backends/arm/operators/__init__.py
+++ b/backends/arm/operators/__init__.py
@@ -9,9 +9,9 @@
node_visitor,
op_add,
op_avg_pool2d,
- op_batch_norm,
op_bmm,
op_cat,
+ op_clamp,
op_conv2d,
op_eq,
op_exp,
diff --git a/backends/arm/operators/node_visitor.py b/backends/arm/operators/node_visitor.py
index 8609e5e391..afb5f93baa 100644
--- a/backends/arm/operators/node_visitor.py
+++ b/backends/arm/operators/node_visitor.py
@@ -1,4 +1,4 @@
-# Copyright 2023-2024 Arm Limited and/or its affiliates.
+# Copyright 2023-2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
@@ -7,7 +7,7 @@
from typing import Dict, List
-import serializer.tosa_serializer as ts
+import serializer.tosa_serializer as ts # type: ignore
import torch
from executorch.backends.arm.tosa_mapping import TosaArg
from executorch.backends.arm.tosa_specification import TosaSpecification
@@ -44,7 +44,7 @@ def define_node(
# container for all node visitors
-_node_visitor_dicts = {
+_node_visitor_dicts: Dict[TosaSpecification, Dict] = {
TosaSpecification.create_from_string("TOSA-0.80+BI"): {},
TosaSpecification.create_from_string("TOSA-0.80+MI"): {},
}
diff --git a/backends/arm/operators/op_add.py b/backends/arm/operators/op_add.py
index 74f00354ed..ccdeb2c1bc 100644
--- a/backends/arm/operators/op_add.py
+++ b/backends/arm/operators/op_add.py
@@ -1,4 +1,4 @@
-# Copyright 2023-2024 Arm Limited and/or its affiliates.
+# Copyright 2023-2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
@@ -10,7 +10,7 @@
import executorch.backends.arm.tosa_quant_utils as tqutils
import executorch.backends.arm.tosa_utils as tutils
-import serializer.tosa_serializer as ts
+import serializer.tosa_serializer as ts # type: ignore
from executorch.backends.arm.operators.node_visitor import (
NodeVisitor,
register_node_visitor,
@@ -75,7 +75,7 @@ def define_node(
if output.dtype == ts.DType.INT8:
# Scale output back to 8 bit
# pyre-ignore
- tqutils.insert_rescale_op_to_int8(tosa_graph, add_output, scale_back, node)
+ tqutils.insert_rescale_op_to_int8(tosa_graph, add_output, scale_back, node) # type: ignore[possibly-undefined]
@register_node_visitor
diff --git a/backends/arm/operators/op_avg_pool2d.py b/backends/arm/operators/op_avg_pool2d.py
index fecddac659..e300b3ed01 100644
--- a/backends/arm/operators/op_avg_pool2d.py
+++ b/backends/arm/operators/op_avg_pool2d.py
@@ -1,4 +1,4 @@
-# Copyright 2023-2024 Arm Limited and/or its affiliates.
+# Copyright 2023-2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
@@ -6,7 +6,7 @@
# pyre-unsafe
from typing import List
-import serializer.tosa_serializer as ts
+import serializer.tosa_serializer as ts # type: ignore
import torch
# pyre-fixme[21]: ' Could not find a module corresponding to import `executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass`
diff --git a/backends/arm/operators/op_batch_norm.py b/backends/arm/operators/op_batch_norm.py
deleted file mode 100644
index ce5998cb72..0000000000
--- a/backends/arm/operators/op_batch_norm.py
+++ /dev/null
@@ -1,211 +0,0 @@
-# Copyright 2023-2024 Arm Limited and/or its affiliates.
-#
-# This source code is licensed under the BSD-style license found in the
-# LICENSE file in the root directory of this source tree.
-
-# pyre-unsafe
-from typing import List
-
-import serializer.tosa_serializer as ts
-import torch
-from executorch.backends.arm.operators.node_visitor import (
- NodeVisitor,
- register_node_visitor,
-)
-from executorch.backends.arm.tosa_mapping import TosaArg
-from executorch.backends.arm.tosa_specification import TosaSpecification
-from executorch.backends.arm.tosa_utils import promote_shape, tosa_shape
-from serializer.tosa_serializer import TosaOp
-
-
-@register_node_visitor
-class BatchNormVisitor(NodeVisitor):
- target = "aten._native_batch_norm_legit_no_training.default"
-
- tosa_specs = [
- TosaSpecification.create_from_string("TOSA-0.80+MI"),
- ]
-
- def __init__(self, *args):
- super().__init__(*args)
-
- # For BatchNorm2D, mean and var are calculated over the channel dimension
- # But TOSA doesn't allow subtraction of inputs with different ranks
- # Need to augment the shapes to match the ranks with activations
- def augment_shape_rank(self, shape, dim_order):
- nchw_shape = (1, *shape, 1, 1)
- return tosa_shape(nchw_shape, dim_order)
-
- def define_node(
- self,
- node: torch.fx.Node,
- tosa_graph: ts.TosaSerializer,
- inputs: List[TosaArg],
- output: TosaArg,
- ) -> None:
- # Decompose batch norm into sequence
- (activations, weights, bias, running_mean, running_var, momentum, epsilon) = (
- inputs
- )
-
- input_dtype = activations.dtype
-
- assert (
- 0.1 == momentum.number
- ), "Expected 0.1 momentum, not currently encoded into TOSA"
-
- # %output = (%x - %E[x]) / SQRT( %Var[x] + %epsilon ) * %gamma + %beta
- # e.g.
- # %output = (%activations - %running_mean) / SQRT( %running_var + %epsilon_const ) * %weights + %bias
- # ->
- # %op1 = tosa.SUB(%activations, %running_mean)
- # %op2 = tosa.ADD(%running_var, %epsilon_const)
- # %op3 = tosa.RSQRT(%op2)
- # %op4 = tosa.MUL(%op1, %op3)
- # %op5 = tosa.MUL(%op4, %weights)
- # %output = tosa.ADD(%op5, %bias)
-
- # Reshape mean to match rank of activations
- mean_reshaped = promote_shape(
- tosa_graph,
- running_mean,
- self.augment_shape_rank(running_mean.shape, output.dim_order),
- input_dtype,
- )
-
- # Subtract mean
- # %op1 = tosa.SUB(%activations, %running_mean)
- op1 = tosa_graph.addIntermediate(
- tosa_shape(output.shape, output.dim_order), input_dtype
- )
- tosa_graph.addOperator(
- TosaOp.Op().SUB,
- [activations.name, mean_reshaped.name],
- [op1.name],
- )
- # Adding eplison to variance
- # %op2 = tosa.ADD(%running_var, %epsilon_const)
- epsilon_const = tosa_graph.addConst([1], input_dtype, [epsilon.number])
- op2 = tosa_graph.addIntermediate(
- tosa_shape(running_var.shape, running_var.dim_order), input_dtype
- )
- tosa_graph.addOperator(
- TosaOp.Op().ADD,
- [running_var.name, epsilon_const.name],
- [op2.name],
- )
- # Push downward the variance
- # %op3 = tosa.RSQRT(%op2)
- op3 = tosa_graph.addIntermediate(running_var.shape, input_dtype)
- tosa_graph.addOperator(TosaOp.Op().RSQRT, [op2.name], [op3.name])
-
- # Reshape variable to match rank of activations
- op3_reshaped = promote_shape(
- tosa_graph,
- op3,
- self.augment_shape_rank(running_var.shape, output.dim_order),
- input_dtype,
- )
-
- # Handle non existing weights and bias
- if not weights.name and not bias.name:
- # Multiply shifted activations with reciprocal variance
- # %output = tosa.MUL(%op1, %op3) e.g. Now we have %output = (%activations - %running_mean) / SQRT( %running_var + %epsilon_const )
- attr_mul = ts.TosaSerializerAttribute()
- attr_mul.MulAttribute(0)
- tosa_graph.addOperator(
- TosaOp.Op().MUL, [op1.name, op3_reshaped.name], [output.name], attr_mul
- )
- return
- else:
- # Multiply shifted activations with reciprocal variance
- # %op4 = tosa.MUL(%op1, %op3)
- op4 = tosa_graph.addIntermediate(
- tosa_shape(output.shape, output.dim_order), input_dtype
- )
- attr_mul = ts.TosaSerializerAttribute()
- attr_mul.MulAttribute(0)
- tosa_graph.addOperator(
- TosaOp.Op().MUL, [op1.name, op3_reshaped.name], [op4.name], attr_mul
- )
-
- # Now we have %op4 = (%activations - %running_mean) / SQRT( %running_var + %epsilon_const )
-
- if weights.name and not bias.name:
- # Handle only weights but no bias
-
- # Reshape weights to match rank of activations
- weights_reshaped = promote_shape(
- tosa_graph,
- weights,
- self.augment_shape_rank(weights.shape, output.dim_order),
- input_dtype,
- )
-
- # %output = tosa.MUL(%op4, %weights)
- attr_mul = ts.TosaSerializerAttribute()
- attr_mul.MulAttribute(0)
- tosa_graph.addOperator(
- TosaOp.Op().MUL,
- [op4.name, weights_reshaped.name],
- [output.name],
- attr_mul,
- )
- return
-
- if not weights.name and bias.name:
- # Handle only bias but no weights
-
- # Reshape bias to match rank of activations
- bias_reshaped = promote_shape(
- tosa_graph,
- bias,
- self.augment_shape_rank(bias.shape, output.dim_order),
- input_dtype,
- )
-
- # %output = tosa.ADD(%op4, %bias)
- tosa_graph.addOperator(
- TosaOp.Op().ADD,
- [op4.name, bias_reshaped.name],
- [output.name],
- )
- return
-
- # We have both weights and bias
-
- # Reshape weights to match rank of activations
- weights_reshaped = promote_shape(
- tosa_graph,
- weights,
- self.augment_shape_rank(weights.shape, output.dim_order),
- input_dtype,
- )
-
- # %op5 = tosa.MUL(%op4, %weights)
- op5 = tosa_graph.addIntermediate(
- tosa_shape(output.shape, output.dim_order), input_dtype
- )
- attr_mul = ts.TosaSerializerAttribute()
- attr_mul.MulAttribute(0)
- tosa_graph.addOperator(
- TosaOp.Op().MUL,
- [op4.name, weights_reshaped.name],
- [op5.name],
- attr_mul,
- )
-
- # Reshape bias to match rank of activations
- bias_reshaped = promote_shape(
- tosa_graph,
- bias,
- self.augment_shape_rank(bias.shape, output.dim_order),
- input_dtype,
- )
-
- # %output = tosa.ADD(%op5, %bias)
- tosa_graph.addOperator(
- TosaOp.Op().ADD,
- [op5.name, bias_reshaped.name],
- [output.name],
- )
diff --git a/backends/arm/operators/op_bmm.py b/backends/arm/operators/op_bmm.py
index 83d3df2701..d3261ebde0 100644
--- a/backends/arm/operators/op_bmm.py
+++ b/backends/arm/operators/op_bmm.py
@@ -1,4 +1,4 @@
-# Copyright 2024 Arm Limited and/or its affiliates.
+# Copyright 2024-2025 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
@@ -7,7 +7,7 @@
# pyre-unsafe
from typing import List
-import serializer.tosa_serializer as ts
+import serializer.tosa_serializer as ts # type: ignore
import torch
# pyre-fixme[21]: 'Could not find a module corresponding to import `executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass`.'
@@ -75,14 +75,14 @@ def define_node(
if output.dtype == ts.DType.INT8:
output_qparams = get_output_qparams(node)[0] # pyre-ignore[16]
final_output_scale = (
- input_qparams[0].scale * input_qparams[1].scale # pyre-ignore[61]
+ input_qparams[0].scale * input_qparams[1].scale # type: ignore[possibly-undefined] # pyre-ignore[61]
) / output_qparams.scale
build_rescale(
tosa_fb=tosa_graph,
scale=final_output_scale,
# pyre-ignore[61]: Uninitialized local [61]: Local variable `bmm_result` is undefined, or not always defined.
- input_node=bmm_result,
+ input_node=bmm_result, # type: ignore[possibly-undefined]
output_name=output.name,
output_type=ts.DType.INT8,
output_shape=bmm_result.shape,
diff --git a/backends/arm/operators/op_cat.py b/backends/arm/operators/op_cat.py
index e249942d0b..f786395cc3 100644
--- a/backends/arm/operators/op_cat.py
+++ b/backends/arm/operators/op_cat.py
@@ -1,4 +1,4 @@
-# Copyright 2024 Arm Limited and/or its affiliates.
+# Copyright 2024-2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
@@ -7,7 +7,7 @@
from typing import List
-import serializer.tosa_serializer as ts
+import serializer.tosa_serializer as ts # type: ignore
from executorch.backends.arm.operators.node_visitor import (
NodeVisitor,
register_node_visitor,
diff --git a/backends/arm/operators/op_clamp.py b/backends/arm/operators/op_clamp.py
new file mode 100644
index 0000000000..486da27c9a
--- /dev/null
+++ b/backends/arm/operators/op_clamp.py
@@ -0,0 +1,144 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree
+
+from typing import Any, List, Tuple
+
+import serializer.tosa_serializer as ts # type: ignore
+
+import torch
+from executorch.backends.arm.operators.node_visitor import (
+ NodeVisitor,
+ register_node_visitor,
+)
+
+from executorch.backends.arm.tosa_mapping import TosaArg
+from executorch.backends.arm.tosa_specification import TosaSpecification
+from serializer.tosa_serializer import TosaOp
+from torch.fx import Node
+
+
+@register_node_visitor
+class ClampVisitor_080_BI(NodeVisitor):
+ target = "aten.clamp.default"
+
+ tosa_specs = [
+ TosaSpecification.create_from_string("TOSA-0.80+BI"),
+ ]
+
+ def __init__(self, *args):
+ super().__init__(*args)
+
+ def _create_clamp_node(
+ self,
+ tosa_graph: ts.TosaSerializer,
+ input_name: str,
+ output_name: str,
+ min_int: int,
+ max_int: int,
+ min_fp32: float,
+ max_fp32: float,
+ ) -> None:
+ attr = ts.TosaSerializerAttribute()
+ attr.ClampAttribute(
+ tosa_graph.builder,
+ min_int,
+ max_int,
+ min_fp32,
+ max_fp32,
+ )
+ tosa_graph.addOperator(TosaOp.Op().CLAMP, [input_name], [output_name], attr)
+
+ def _get_min_max_arguments(
+ self, node: Node, dtype_min: int | float, dtype_max: int | float
+ ) -> Tuple[int | float, int | float]:
+
+ def cast_type(value: Any) -> int | float:
+ if isinstance(value, int):
+ return value
+ else:
+ # Attempt to cast to float
+ return float(value)
+
+ assert 2 <= len(node.args) <= 3
+
+ min_arg = dtype_min
+ max_arg = dtype_max
+
+ if node.args[1] is not None:
+ min_arg = cast_type(node.args[1])
+
+ if len(node.args) > 2:
+ if node.args[2] is not None:
+ max_arg = cast_type(node.args[2])
+
+ return min_arg, max_arg
+
+ def define_node(
+ self,
+ node: Node,
+ tosa_graph: ts.TosaSerializer,
+ inputs: List[TosaArg],
+ output: TosaArg,
+ ) -> None:
+ assert len(node.all_input_nodes) == 1
+
+ min_int8, max_int8 = self._get_min_max_arguments(
+ node,
+ torch.iinfo(torch.int8).min,
+ torch.iinfo(torch.int8).max,
+ )
+
+ # NOTE: Quantization of the min/max arguments is handled by QuantizeOperatorArguments
+ self._create_clamp_node(
+ tosa_graph,
+ inputs[0].name,
+ output.name,
+ int(min_int8),
+ int(max_int8),
+ 0,
+ 0,
+ )
+
+
+@register_node_visitor
+class ClampVisitor_080_MI(ClampVisitor_080_BI):
+ # inheriting 'target' from BI class
+
+ tosa_specs = [
+ TosaSpecification.create_from_string("TOSA-0.80+MI"),
+ ]
+
+ def __init__(self, *args):
+ super().__init__(*args)
+
+ def define_node(
+ self,
+ node: Node,
+ tosa_graph: ts.TosaSerializer,
+ inputs: List[TosaArg],
+ output: TosaArg,
+ ) -> None:
+ assert len(node.all_input_nodes) == 1
+
+ if inputs[0].dtype == ts.DType.INT8:
+ # Call the inherited define_node for handling integers
+ super().define_node(node, tosa_graph, inputs, output)
+ else:
+ min_fp32, max_fp32 = self._get_min_max_arguments(
+ node,
+ torch.finfo(torch.float32).min,
+ torch.finfo(torch.float32).max,
+ )
+
+ self._create_clamp_node(
+ tosa_graph,
+ inputs[0].name,
+ output.name,
+ 0,
+ 0,
+ min_fp32,
+ max_fp32,
+ )
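The min/max handling in _get_min_max_arguments falls back to the dtype's numeric limits when aten.clamp is called with None bounds, and only one of the int/fp32 attribute pairs is populated depending on the input dtype. A standalone sketch of that fallback (the helper name is hypothetical):

```python
import torch


def resolve_clamp_bounds(min_arg, max_arg, dtype=torch.float32):
    # None means "unbounded", so substitute the dtype's own min/max.
    info = torch.finfo(dtype) if dtype.is_floating_point else torch.iinfo(dtype)
    lo = info.min if min_arg is None else min_arg
    hi = info.max if max_arg is None else max_arg
    return lo, hi


print(resolve_clamp_bounds(None, 6.0))             # (-3.4028234663852886e+38, 6.0)
print(resolve_clamp_bounds(-1, None, torch.int8))  # (-1, 127)
```
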
diff --git a/backends/arm/operators/op_conv2d.py b/backends/arm/operators/op_conv2d.py
index 42156da013..f97e408a02 100644
--- a/backends/arm/operators/op_conv2d.py
+++ b/backends/arm/operators/op_conv2d.py
@@ -1,4 +1,4 @@
-# Copyright 2023-2024 Arm Limited and/or its affiliates.
+# Copyright 2023-2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
@@ -6,7 +6,7 @@
# pyre-unsafe
from typing import List
-import serializer.tosa_serializer as ts
+import serializer.tosa_serializer as ts # type: ignore
import torch
# pyre-fixme[21]: 'Could not find a module corresponding to import `executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass`.'
@@ -165,13 +165,13 @@ def define_node(
# integer value domain of the next op. Otherwise return float32 output.
if inputs[0].dtype == ts.DType.INT8:
# Get scale_factor from input, weight, and output.
- input_scale = input_qparams[0].scale # pyre-ignore [61]
+ input_scale = input_qparams[0].scale # type: ignore[possibly-undefined] # pyre-ignore [61]
weight_scale = input_qparams[1].scale # pyre-ignore [61]
output_qargs = get_output_qparams(node) # pyre-ignore [16]
build_rescale_conv_output(
tosa_graph,
# pyre-fixme[61]: Uninitialized local [61]: Local variable `conv2d_res` is undefined, or not always defined.
- conv2d_res,
+ conv2d_res, # type: ignore[possibly-undefined]
output.name,
output.dtype,
input_scale,
diff --git a/backends/arm/operators/op_eq.py b/backends/arm/operators/op_eq.py
index e6e2492aec..02fc89099e 100644
--- a/backends/arm/operators/op_eq.py
+++ b/backends/arm/operators/op_eq.py
@@ -9,7 +9,7 @@
import executorch.backends.arm.tosa_quant_utils as tqutils
-import serializer.tosa_serializer as ts
+import serializer.tosa_serializer as ts # type: ignore
from executorch.backends.arm.operators.node_visitor import (
NodeVisitor,
register_node_visitor,
diff --git a/backends/arm/operators/op_exp.py b/backends/arm/operators/op_exp.py
index 46f4980975..4b8232ef6e 100644
--- a/backends/arm/operators/op_exp.py
+++ b/backends/arm/operators/op_exp.py
@@ -1,4 +1,4 @@
-# Copyright 2024 Arm Limited and/or its affiliates.
+# Copyright 2024-2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
@@ -6,7 +6,7 @@
# pyre-unsafe
from typing import List
-import serializer.tosa_serializer as ts
+import serializer.tosa_serializer as ts # type: ignore
from executorch.backends.arm.operators.node_visitor import (
NodeVisitor,
register_node_visitor,
diff --git a/backends/arm/operators/op_full.py b/backends/arm/operators/op_full.py
index 7964e58226..f06b9873e6 100644
--- a/backends/arm/operators/op_full.py
+++ b/backends/arm/operators/op_full.py
@@ -1,4 +1,4 @@
-# Copyright 2024 Arm Limited and/or its affiliates.
+# Copyright 2024-2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
@@ -8,7 +8,7 @@
import numpy as np
-import serializer.tosa_serializer as ts
+import serializer.tosa_serializer as ts # type: ignore
from executorch.backends.arm.operators.node_visitor import (
NodeVisitor,
register_node_visitor,
@@ -40,7 +40,7 @@ def define_node(
if output.dtype == ts.DType.INT8:
fill_dtype = np.int8
else:
- fill_dtype = np.float32
+ fill_dtype = np.float32 # type: ignore[assignment]
data = np.full(shape, value, dtype=fill_dtype)
tosa_graph.addConst(shape, output.dtype, data, node.name + "full-const")
diff --git a/backends/arm/operators/op_ge.py b/backends/arm/operators/op_ge.py
index 810b40bb1a..e4de12f332 100644
--- a/backends/arm/operators/op_ge.py
+++ b/backends/arm/operators/op_ge.py
@@ -9,7 +9,7 @@
import executorch.backends.arm.tosa_quant_utils as tqutils
-import serializer.tosa_serializer as ts
+import serializer.tosa_serializer as ts # type: ignore
from executorch.backends.arm.operators.node_visitor import (
NodeVisitor,
register_node_visitor,
diff --git a/backends/arm/operators/op_get_item.py b/backends/arm/operators/op_get_item.py
index f7372262c6..577a8c8d2e 100644
--- a/backends/arm/operators/op_get_item.py
+++ b/backends/arm/operators/op_get_item.py
@@ -1,4 +1,4 @@
-# Copyright 2023-2024 Arm Limited and/or its affiliates.
+# Copyright 2023-2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
@@ -6,7 +6,7 @@
# pyre-unsafe
from typing import List
-import serializer.tosa_serializer as ts
+import serializer.tosa_serializer as ts # type: ignore
import torch
from executorch.backends.arm.operators.node_visitor import (
NodeVisitor,
diff --git a/backends/arm/operators/op_gt.py b/backends/arm/operators/op_gt.py
index 7a22db6686..65cf8197bd 100644
--- a/backends/arm/operators/op_gt.py
+++ b/backends/arm/operators/op_gt.py
@@ -9,7 +9,7 @@
import executorch.backends.arm.tosa_quant_utils as tqutils
-import serializer.tosa_serializer as ts
+import serializer.tosa_serializer as ts # type: ignore
from executorch.backends.arm.operators.node_visitor import (
NodeVisitor,
register_node_visitor,
diff --git a/backends/arm/operators/op_hardtanh.py b/backends/arm/operators/op_hardtanh.py
index c971b50b66..fc0ee552a9 100644
--- a/backends/arm/operators/op_hardtanh.py
+++ b/backends/arm/operators/op_hardtanh.py
@@ -6,7 +6,7 @@
# pyre-unsafe
from typing import List
-import serializer.tosa_serializer as ts
+import serializer.tosa_serializer as ts # type: ignore
import torch
# pyre-fixme[21]: 'Could not find a module corresponding to import `executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass`.'
diff --git a/backends/arm/operators/op_le.py b/backends/arm/operators/op_le.py
index ee6929617e..8fea2b9208 100644
--- a/backends/arm/operators/op_le.py
+++ b/backends/arm/operators/op_le.py
@@ -9,7 +9,7 @@
import executorch.backends.arm.tosa_quant_utils as tqutils
-import serializer.tosa_serializer as ts
+import serializer.tosa_serializer as ts # type: ignore
from executorch.backends.arm.operators.node_visitor import (
NodeVisitor,
register_node_visitor,
diff --git a/backends/arm/operators/op_log.py b/backends/arm/operators/op_log.py
index 868eeb9443..7f664900b3 100644
--- a/backends/arm/operators/op_log.py
+++ b/backends/arm/operators/op_log.py
@@ -1,4 +1,4 @@
-# Copyright 2024 Arm Limited and/or its affiliates.
+# Copyright 2024-2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
@@ -6,7 +6,7 @@
# pyre-unsafe
from typing import List
-import serializer.tosa_serializer as ts
+import serializer.tosa_serializer as ts # type: ignore
from executorch.backends.arm.operators.node_visitor import (
NodeVisitor,
register_node_visitor,
diff --git a/backends/arm/operators/op_lt.py b/backends/arm/operators/op_lt.py
index 20bac97af4..da93ab4179 100644
--- a/backends/arm/operators/op_lt.py
+++ b/backends/arm/operators/op_lt.py
@@ -9,7 +9,7 @@
import executorch.backends.arm.tosa_quant_utils as tqutils
-import serializer.tosa_serializer as ts
+import serializer.tosa_serializer as ts # type: ignore
from executorch.backends.arm.operators.node_visitor import (
NodeVisitor,
register_node_visitor,
diff --git a/backends/arm/operators/op_max.py b/backends/arm/operators/op_max.py
index 660a2cf0af..35a635de13 100644
--- a/backends/arm/operators/op_max.py
+++ b/backends/arm/operators/op_max.py
@@ -1,4 +1,4 @@
-# Copyright 2024 Arm Limited and/or its affiliates.
+# Copyright 2024-2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
@@ -8,7 +8,7 @@
from typing import List
import executorch.backends.arm.tosa_quant_utils as tqutils
-import serializer.tosa_serializer as ts
+import serializer.tosa_serializer as ts # type: ignore
# pyre-fixme[21]: 'Could not find a module corresponding to import `executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass`.'
from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
diff --git a/backends/arm/operators/op_max_pool2d.py b/backends/arm/operators/op_max_pool2d.py
index 6cb5f0490e..f32300f561 100644
--- a/backends/arm/operators/op_max_pool2d.py
+++ b/backends/arm/operators/op_max_pool2d.py
@@ -1,4 +1,4 @@
-# Copyright 2024 Arm Limited and/or its affiliates.
+# Copyright 2024-2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
@@ -6,7 +6,7 @@
# pyre-unsafe
from typing import List
-import serializer.tosa_serializer as ts
+import serializer.tosa_serializer as ts # type: ignore
import torch
# pyre-fixme[21]: 'Could not find a module corresponding to import `executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass`.'
diff --git a/backends/arm/operators/op_min.py b/backends/arm/operators/op_min.py
index 2282d9e1cf..a409acf1ae 100644
--- a/backends/arm/operators/op_min.py
+++ b/backends/arm/operators/op_min.py
@@ -1,4 +1,4 @@
-# Copyright 2024 Arm Limited and/or its affiliates.
+# Copyright 2024-2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
@@ -9,7 +9,7 @@
import executorch.backends.arm.tosa_quant_utils as tqutils
-import serializer.tosa_serializer as ts
+import serializer.tosa_serializer as ts # type: ignore
# pyre-fixme[21]: 'Could not find a module corresponding to import `executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass`.'
from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
diff --git a/backends/arm/operators/op_mul.py b/backends/arm/operators/op_mul.py
index c6a315d445..ef886de11e 100644
--- a/backends/arm/operators/op_mul.py
+++ b/backends/arm/operators/op_mul.py
@@ -1,4 +1,4 @@
-# Copyright 2024 Arm Limited and/or its affiliates.
+# Copyright 2024-2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
@@ -10,7 +10,7 @@
import executorch.backends.arm.tosa_quant_utils as tqutils
import executorch.backends.arm.tosa_utils as tutils
-import serializer.tosa_serializer as ts
+import serializer.tosa_serializer as ts # type: ignore
import torch
# pyre-fixme[21]: 'Could not find a module corresponding to import `executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass`.'
diff --git a/backends/arm/operators/op_permute.py b/backends/arm/operators/op_permute.py
index 16d3d4a04e..103ae1b9a2 100644
--- a/backends/arm/operators/op_permute.py
+++ b/backends/arm/operators/op_permute.py
@@ -1,4 +1,4 @@
-# Copyright 2023-2024 Arm Limited and/or its affiliates.
+# Copyright 2023-2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
@@ -7,7 +7,7 @@
from typing import List
-import serializer.tosa_serializer as ts
+import serializer.tosa_serializer as ts # type: ignore
import torch
from executorch.backends.arm.operators.node_visitor import (
NodeVisitor,
diff --git a/backends/arm/operators/op_reciprocal.py b/backends/arm/operators/op_reciprocal.py
index 121b78fed6..5410e1dd99 100644
--- a/backends/arm/operators/op_reciprocal.py
+++ b/backends/arm/operators/op_reciprocal.py
@@ -1,4 +1,4 @@
-# Copyright 2023-2024 Arm Limited and/or its affiliates.
+# Copyright 2023-2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
@@ -6,7 +6,7 @@
# pyre-unsafe
from typing import List
-import serializer.tosa_serializer as ts
+import serializer.tosa_serializer as ts # type: ignore
import torch
from executorch.backends.arm.operators.node_visitor import (
NodeVisitor,
diff --git a/backends/arm/operators/op_relu.py b/backends/arm/operators/op_relu.py
index b5ffa2aa70..c37e4b3e75 100644
--- a/backends/arm/operators/op_relu.py
+++ b/backends/arm/operators/op_relu.py
@@ -5,7 +5,7 @@
# pyre-unsafe
-import serializer.tosa_serializer as ts
+import serializer.tosa_serializer as ts # type: ignore
import torch.fx
# pyre-fixme[21]: 'Could not find a module corresponding to import `executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass`.'
diff --git a/backends/arm/operators/op_repeat.py b/backends/arm/operators/op_repeat.py
index fd76a52052..b97d7023ef 100644
--- a/backends/arm/operators/op_repeat.py
+++ b/backends/arm/operators/op_repeat.py
@@ -1,11 +1,11 @@
-# Copyright 2024 Arm Limited and/or its affiliates.
+# Copyright 2024-2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
# pyre-unsafe
-import serializer.tosa_serializer as ts
+import serializer.tosa_serializer as ts # type: ignore
import torch
from executorch.backends.arm.operators.node_visitor import (
NodeVisitor,
diff --git a/backends/arm/operators/op_rshift.py b/backends/arm/operators/op_rshift.py
index 2c1f4d5bbe..ac61cca6a9 100644
--- a/backends/arm/operators/op_rshift.py
+++ b/backends/arm/operators/op_rshift.py
@@ -1,4 +1,4 @@
-# Copyright 2024 Arm Limited and/or its affiliates.
+# Copyright 2024-2025 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
@@ -6,7 +6,7 @@
from typing import List
-import serializer.tosa_serializer as ts
+import serializer.tosa_serializer as ts # type: ignore
import torch
from executorch.backends.arm.operators.node_visitor import (
NodeVisitor,
diff --git a/backends/arm/operators/op_rsqrt.py b/backends/arm/operators/op_rsqrt.py
index 1cc3e8fcff..0fbb203b08 100644
--- a/backends/arm/operators/op_rsqrt.py
+++ b/backends/arm/operators/op_rsqrt.py
@@ -1,4 +1,4 @@
-# Copyright 2024 Arm Limited and/or its affiliates.
+# Copyright 2024-2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
@@ -6,7 +6,7 @@
# pyre-unsafe
from typing import List
-import serializer.tosa_serializer as ts
+import serializer.tosa_serializer as ts # type: ignore
import torch
from executorch.backends.arm.operators.node_visitor import (
NodeVisitor,
diff --git a/backends/arm/operators/op_sigmoid.py b/backends/arm/operators/op_sigmoid.py
index 0c28c0ed00..118c813dcf 100644
--- a/backends/arm/operators/op_sigmoid.py
+++ b/backends/arm/operators/op_sigmoid.py
@@ -1,4 +1,4 @@
-# Copyright 2024 Arm Limited and/or its affiliates.
+# Copyright 2024-2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
@@ -6,7 +6,7 @@
# pyre-unsafe
from typing import List
-import serializer.tosa_serializer as ts
+import serializer.tosa_serializer as ts # type: ignore
from executorch.backends.arm.operators.node_visitor import (
NodeVisitor,
register_node_visitor,
diff --git a/backends/arm/operators/op_slice.py b/backends/arm/operators/op_slice.py
index 9327e005b6..7f4804af58 100644
--- a/backends/arm/operators/op_slice.py
+++ b/backends/arm/operators/op_slice.py
@@ -1,4 +1,4 @@
-# Copyright 2024 Arm Limited and/or its affiliates.
+# Copyright 2024-2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
@@ -7,7 +7,7 @@
from typing import List
-import serializer.tosa_serializer as ts
+import serializer.tosa_serializer as ts # type: ignore
from executorch.backends.arm.operators.node_visitor import (
NodeVisitor,
register_node_visitor,
diff --git a/backends/arm/operators/op_sub.py b/backends/arm/operators/op_sub.py
index 0c569a6ffd..6cd422095a 100644
--- a/backends/arm/operators/op_sub.py
+++ b/backends/arm/operators/op_sub.py
@@ -1,4 +1,4 @@
-# Copyright 2023-2024 Arm Limited and/or its affiliates.
+# Copyright 2023-2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
@@ -10,7 +10,7 @@
import executorch.backends.arm.tosa_quant_utils as tqutils
import executorch.backends.arm.tosa_utils as tutils
-import serializer.tosa_serializer as ts
+import serializer.tosa_serializer as ts # type: ignore
from executorch.backends.arm.operators.node_visitor import (
NodeVisitor,
register_node_visitor,
@@ -75,7 +75,7 @@ def define_node(
if output.dtype == ts.DType.INT8:
# Scale output back to 8 bit
# pyre-ignore
- tqutils.insert_rescale_op_to_int8(tosa_graph, sub_output, scale_back, node)
+ tqutils.insert_rescale_op_to_int8(tosa_graph, sub_output, scale_back, node) # type: ignore[possibly-undefined]
@register_node_visitor
diff --git a/backends/arm/operators/op_sum.py b/backends/arm/operators/op_sum.py
index dcc194a656..b5b388b335 100644
--- a/backends/arm/operators/op_sum.py
+++ b/backends/arm/operators/op_sum.py
@@ -1,4 +1,4 @@
-# Copyright 2023-2024 Arm Limited and/or its affiliates.
+# Copyright 2023-2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
@@ -10,7 +10,7 @@
import executorch.backends.arm.tosa_quant_utils as tqutils
import executorch.backends.arm.tosa_utils as tutils
-import serializer.tosa_serializer as ts
+import serializer.tosa_serializer as ts # type: ignore
from executorch.backends.arm.operators.node_visitor import (
NodeVisitor,
register_node_visitor,
diff --git a/backends/arm/operators/op_table.py b/backends/arm/operators/op_table.py
index bfaaf4578e..b411d8b91b 100644
--- a/backends/arm/operators/op_table.py
+++ b/backends/arm/operators/op_table.py
@@ -1,4 +1,4 @@
-# Copyright 2024 Arm Limited and/or its affiliates.
+# Copyright 2024-2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
@@ -9,7 +9,7 @@
import numpy as np
-import serializer.tosa_serializer as ts
+import serializer.tosa_serializer as ts # type: ignore
import torch
from executorch.backends.arm.operators.node_visitor import (
NodeVisitor,
@@ -21,7 +21,7 @@
@register_node_visitor
class TableVisitor(NodeVisitor):
- target = "_table"
+ target = "_table.default"
def define_node(
self,
@@ -30,9 +30,9 @@ def define_node(
inputs: List[TosaArg],
output: TosaArg,
) -> None:
- assert node.name in self._exported_program.state_dict.keys()
+ assert node.name in self._exported_program.state_dict.keys() # type: ignore[union-attr]
assert inputs[0].dtype == output.dtype == ts.DType.INT8
- table = self._exported_program.state_dict[node.name]
+ table = self._exported_program.state_dict[node.name] # type: ignore[union-attr]
table_attr = ts.TosaSerializerAttribute()
table_attr.TableAttribute(np.array(table))
tosa_graph.addOperator(
diff --git a/backends/arm/operators/op_tanh.py b/backends/arm/operators/op_tanh.py
index a1e91be4ff..7961b14f2a 100644
--- a/backends/arm/operators/op_tanh.py
+++ b/backends/arm/operators/op_tanh.py
@@ -1,4 +1,4 @@
-# Copyright 2024 Arm Limited and/or its affiliates.
+# Copyright 2024-2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
@@ -6,7 +6,7 @@
# pyre-unsafe
from typing import List
-import serializer.tosa_serializer as ts
+import serializer.tosa_serializer as ts # type: ignore
from executorch.backends.arm.operators.node_visitor import (
NodeVisitor,
register_node_visitor,
diff --git a/backends/arm/operators/op_to_copy.py b/backends/arm/operators/op_to_copy.py
index 256e54f3a2..feaec3a41e 100644
--- a/backends/arm/operators/op_to_copy.py
+++ b/backends/arm/operators/op_to_copy.py
@@ -1,4 +1,4 @@
-# Copyright 2024 Arm Limited and/or its affiliates.
+# Copyright 2024-2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
@@ -6,9 +6,9 @@
# pyre-unsafe
from typing import List
-import serializer.tosa_serializer as ts
+import serializer.tosa_serializer as ts # type: ignore
import torch
-import tosa.Op as TosaOp
+import tosa.Op as TosaOp # type: ignore
from executorch.backends.arm.operators.node_visitor import (
NodeVisitor,
diff --git a/backends/arm/operators/op_to_dim_order_copy.py b/backends/arm/operators/op_to_dim_order_copy.py
index c2ec620b82..397979a439 100644
--- a/backends/arm/operators/op_to_dim_order_copy.py
+++ b/backends/arm/operators/op_to_dim_order_copy.py
@@ -6,9 +6,9 @@
# pyre-unsafe
from typing import List
-import serializer.tosa_serializer as ts
+import serializer.tosa_serializer as ts # type: ignore
import torch
-import tosa.Op as TosaOp
+import tosa.Op as TosaOp # type: ignore
from executorch.backends.arm.operators.node_visitor import (
NodeVisitor,
diff --git a/backends/arm/operators/op_transpose.py b/backends/arm/operators/op_transpose.py
index 42675be34b..54a79297dd 100644
--- a/backends/arm/operators/op_transpose.py
+++ b/backends/arm/operators/op_transpose.py
@@ -1,4 +1,4 @@
-# Copyright 2024 Arm Limited and/or its affiliates.
+# Copyright 2024-2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
@@ -7,7 +7,7 @@
from typing import List
-import serializer.tosa_serializer as ts
+import serializer.tosa_serializer as ts # type: ignore
import torch
from executorch.backends.arm.operators.node_visitor import (
NodeVisitor,
@@ -25,7 +25,7 @@ class TransposeVisitor(NodeVisitor):
Inserts a TOSA TRANSPOSE.
"""
- target = "_transpose"
+ target = "_transpose.default"
def define_node(
self,
diff --git a/backends/arm/operators/op_upsample_nearest2d.py b/backends/arm/operators/op_upsample_nearest2d.py
index 68fcb521d9..38e4087d38 100644
--- a/backends/arm/operators/op_upsample_nearest2d.py
+++ b/backends/arm/operators/op_upsample_nearest2d.py
@@ -1,4 +1,4 @@
-# Copyright 2024 Arm Limited and/or its affiliates.
+# Copyright 2024-2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
@@ -6,7 +6,7 @@
# pyre-unsafe
from typing import List
-import serializer.tosa_serializer as ts
+import serializer.tosa_serializer as ts # type: ignore
import torch
from executorch.backends.arm.operators.node_visitor import (
NodeVisitor,
@@ -16,7 +16,7 @@
from executorch.backends.arm.tosa_utils import get_resize_parameters, tosa_shape
from serializer.tosa_serializer import TosaOp
-from tosa.ResizeMode import ResizeMode
+from tosa.ResizeMode import ResizeMode # type: ignore
@register_node_visitor
diff --git a/backends/arm/operators/op_view.py b/backends/arm/operators/op_view.py
index 3489795ed5..119e32fa58 100644
--- a/backends/arm/operators/op_view.py
+++ b/backends/arm/operators/op_view.py
@@ -1,4 +1,4 @@
-# Copyright 2023-2024 Arm Limited and/or its affiliates.
+# Copyright 2023-2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
@@ -6,9 +6,9 @@
# pyre-unsafe
from typing import List
-import serializer.tosa_serializer as ts
+import serializer.tosa_serializer as ts # type: ignore
import torch
-import tosa.Op as TosaOp
+import tosa.Op as TosaOp # type: ignore
from executorch.backends.arm.operators.node_visitor import (
NodeVisitor,
diff --git a/backends/arm/process_node.py b/backends/arm/process_node.py
index 36a1567df9..a83ead987e 100644
--- a/backends/arm/process_node.py
+++ b/backends/arm/process_node.py
@@ -8,7 +8,7 @@
from typing import cast, Dict
import numpy as np
-import serializer.tosa_serializer as ts
+import serializer.tosa_serializer as ts # type: ignore
import torch
import torch.fx
from executorch.backends.arm.operators.node_visitor import NodeVisitor
@@ -36,9 +36,9 @@ def process_call_function(
# Visiting each Node
# pyre-ignore[16]: Undefined attribute.
- if node.target.__name__ in node_visitors:
+ if node.target.__name__ in node_visitors: # type: ignore[union-attr]
# pyre-ignore[16]: Undefined attribute.
- node_visitors[node.target.__name__].define_node(
+ node_visitors[node.target.__name__].define_node( # type: ignore[union-attr]
node,
tosa_graph,
inputs,
diff --git a/backends/arm/quantizer/arm_quantizer.py b/backends/arm/quantizer/arm_quantizer.py
index cba66cfe56..c1a017fa1d 100644
--- a/backends/arm/quantizer/arm_quantizer.py
+++ b/backends/arm/quantizer/arm_quantizer.py
@@ -20,8 +20,12 @@
from executorch.backends.arm._passes.arm_pass_manager import ArmPassManager
from executorch.backends.arm.quantizer import arm_quantizer_utils
-from executorch.backends.arm.quantizer.arm_quantizer_utils import mark_node_as_annotated
-from executorch.backends.arm.quantizer.quantization_annotator import annotate_graph
+from executorch.backends.arm.quantizer.arm_quantizer_utils import ( # type: ignore[attr-defined]
+ mark_node_as_annotated,
+)
+from executorch.backends.arm.quantizer.quantization_annotator import ( # type: ignore[import-not-found]
+ annotate_graph,
+)
from executorch.backends.arm.quantizer.quantization_config import QuantizationConfig
from executorch.backends.arm.tosa_specification import TosaSpecification
@@ -253,7 +257,7 @@ def transform_for_annotation(self, model: GraphModule) -> GraphModule:
Currently transforms scalar values to tensor attributes.
"""
- return ArmPassManager(self.tosa_spec).transform_for_annotation_pipeline(
+ return ArmPassManager(self.tosa_spec).transform_for_annotation_pipeline( # type: ignore[arg-type]
graph_module=model
)
diff --git a/backends/arm/quantizer/quantization_annotator.py b/backends/arm/quantizer/quantization_annotator.py
index f2a124f279..32f64963e8 100644
--- a/backends/arm/quantizer/quantization_annotator.py
+++ b/backends/arm/quantizer/quantization_annotator.py
@@ -55,7 +55,7 @@ def _is_ok_for_quantization(
for n_arg in _as_list(node.args[quant_property.index]):
assert isinstance(n_arg, Node)
- if not arm_quantizer_utils.is_ok_for_quantization(n_arg, gm):
+ if not arm_quantizer_utils.is_ok_for_quantization(n_arg, gm): # type: ignore[attr-defined]
return False
return True
@@ -77,7 +77,7 @@ def _annotate_input(node: Node, quant_property: _QuantProperty):
assert isinstance(n_arg, Node)
_annotate_input_qspec_map(node, n_arg, qspec)
if quant_property.mark_annotated:
- arm_quantizer_utils.mark_node_as_annotated(n_arg)
+ arm_quantizer_utils.mark_node_as_annotated(n_arg) # type: ignore[attr-defined]
def _annotate_output(node: Node, quant_property: _QuantProperty):
@@ -107,7 +107,7 @@ def _match_pattern(
child = next(iter(node.users))
elif node.target in pattern[1]:
assert len(node.args) != 0
- parent = node.args[0]
+ parent = node.args[0] # type: ignore[assignment]
child = node
else:
return False
@@ -132,6 +132,8 @@ def _match_pattern(
torch.ops.aten.sigmoid.default,
torch.ops.aten.tanh.default,
torch.ops.aten.sum.dim_IntList,
+ torch.ops.aten.hardsigmoid.default,
+ torch.ops.aten.hardswish.default,
]
_one_to_one_shared_input_qspec = [
@@ -186,6 +188,8 @@ def _match_pattern(
torch.ops.aten.full.default,
torch.ops.aten.flatten.using_ints,
torch.ops.aten.dropout.default,
+ torch.ops.aten.clamp.default,
+ torch.ops.aten.clamp.Tensor,
operator.getitem,
]
@@ -259,23 +263,23 @@ def any_or_hardtanh_min_zero(n: Node):
torch.ops.aten.minimum.default,
torch.ops.aten.maximum.default,
):
- shared_qspec = SharedQuantizationSpec((node.args[0], node))
+ shared_qspec = SharedQuantizationSpec((node.args[0], node)) # type: ignore[arg-type]
quant_properties.quant_inputs = [
_QuantProperty(0, input_act_qspec),
_QuantProperty(
- 1, input_act_qspec if node.args[0] == node.args[1] else shared_qspec
+ 1, input_act_qspec if node.args[0] == node.args[1] else shared_qspec # type: ignore[arg-type]
),
]
- quant_properties.quant_output = _QuantProperty(0, shared_qspec)
+ quant_properties.quant_output = _QuantProperty(0, shared_qspec) # type: ignore[arg-type]
elif node.target == torch.ops.aten.adaptive_avg_pool2d.default:
input_qspec = (
- SharedQuantizationSpec(node.args[0])
- if arm_quantizer_utils.is_output_annotated(node.args[0])
+ SharedQuantizationSpec(node.args[0]) # type: ignore[arg-type]
+ if arm_quantizer_utils.is_output_annotated(node.args[0]) # type: ignore
else input_act_qspec
)
- quant_properties.quant_inputs = [_QuantProperty(0, input_qspec)]
+ quant_properties.quant_inputs = [_QuantProperty(0, input_qspec)] # type: ignore[arg-type]
quant_properties.quant_output = _QuantProperty(
- 0, SharedQuantizationSpec((node.args[0], node))
+ 0, SharedQuantizationSpec((node.args[0], node)) # type: ignore[arg-type]
)
elif node.target in (
torch.ops.aten.cat.default,
@@ -290,19 +294,19 @@ def any_or_hardtanh_min_zero(n: Node):
_QuantProperty(
0,
[
- input_act_qspec if n == node.args[0][0] else shared_qspec
+ input_act_qspec if n == node.args[0][0] else shared_qspec # type: ignore[misc]
for n in node.args[0]
],
)
]
- quant_properties.quant_output = _QuantProperty(0, shared_qspec)
+ quant_properties.quant_output = _QuantProperty(0, shared_qspec) # type: ignore[arg-type]
elif node.target in _one_to_one:
quant_properties.quant_inputs = [_QuantProperty(0, input_act_qspec)]
quant_properties.quant_output = _QuantProperty(0, output_act_qspec)
elif node.target in _one_to_one_shared_input_qspec:
quant_properties.quant_inputs = [_QuantProperty(0, input_act_qspec)]
quant_properties.quant_output = _QuantProperty(
- 0, SharedQuantizationSpec((node.args[0], node))
+ 0, SharedQuantizationSpec((node.args[0], node)) # type: ignore[arg-type]
)
elif node.target in [
torch.ops.aten.eq.Tensor,
@@ -311,26 +315,26 @@ def any_or_hardtanh_min_zero(n: Node):
torch.ops.aten.le.Tensor,
torch.ops.aten.lt.Tensor,
]:
- shared_qspec = SharedQuantizationSpec((node.args[0], node))
+ shared_qspec = SharedQuantizationSpec((node.args[0], node)) # type: ignore[arg-type]
quant_properties.quant_inputs = [
_QuantProperty(0, input_act_qspec),
_QuantProperty(
- 1, input_act_qspec if node.args[0] == node.args[1] else shared_qspec
+ 1, input_act_qspec if node.args[0] == node.args[1] else shared_qspec # type: ignore[arg-type]
),
]
quant_properties.quant_output = None
elif node.target in _parent_shared_qspec:
if not isinstance(node.args[0], Node):
- return None
+ return None # type: ignore[return-value]
- if not arm_quantizer_utils.is_output_annotated(node.args[0]):
- return None
+ if not arm_quantizer_utils.is_output_annotated(node.args[0]): # type: ignore[attr-defined]
+ return None # type: ignore[return-value]
shared_qspec = SharedQuantizationSpec(node.args[0])
- quant_properties.quant_inputs = [_QuantProperty(0, shared_qspec)]
- quant_properties.quant_output = _QuantProperty(0, shared_qspec)
+ quant_properties.quant_inputs = [_QuantProperty(0, shared_qspec)] # type: ignore[arg-type]
+ quant_properties.quant_output = _QuantProperty(0, shared_qspec) # type: ignore[arg-type]
else:
- return None
+ return None # type: ignore[return-value]
# Don't check if operator.getitem is ok for quantization, it's always ok
if node.target == operator.getitem:
@@ -340,16 +344,16 @@ def any_or_hardtanh_min_zero(n: Node):
# provided QuantProperties
for quant_property in quant_properties.quant_inputs:
if not _is_ok_for_quantization(node, quant_property, gm):
- return None
+ return None # type: ignore[return-value]
if quant_properties.quant_output is not None:
if not _is_ok_for_quantization(node, quant_properties.quant_output, gm):
- return None
+ return None # type: ignore[return-value]
return quant_properties
-def annotate_graph(
+def annotate_graph( # type: ignore[return]
gm: torch.fx.GraphModule,
quantization_config: QuantizationConfig,
filter_fn: Optional[Callable[[Node], bool]] = None,
@@ -374,4 +378,4 @@ def annotate_graph(
if quant_properties.quant_output is not None:
_annotate_output(node, quant_properties.quant_output)
- arm_quantizer_utils.mark_node_as_annotated(node)
+ arm_quantizer_utils.mark_node_as_annotated(node) # type: ignore[attr-defined]
diff --git a/backends/arm/quantizer/quantization_config.py b/backends/arm/quantizer/quantization_config.py
index b94d9bda64..394995201e 100644
--- a/backends/arm/quantizer/quantization_config.py
+++ b/backends/arm/quantizer/quantization_config.py
@@ -82,14 +82,14 @@ def _derive_qparams_fn(
input_act = node.args[0]
weight = node.args[1]
quantization_spec = DerivedQuantizationSpec(
- derived_from=[(input_act, node), (weight, node)],
+ derived_from=[(input_act, node), (weight, node)], # type: ignore[list-item]
derive_qparams_fn=_derive_qparams_fn,
dtype=torch.int32,
quant_min=torch.iinfo(torch.int32).min,
quant_max=torch.iinfo(torch.int32).max - 1,
qscheme=torch.per_tensor_symmetric,
)
- return quantization_spec
+ return quantization_spec # type: ignore[return-value]
if self.bias is None:
return None
diff --git a/backends/arm/scripts/pre-commit b/backends/arm/scripts/pre-commit
new file mode 100755
index 0000000000..2000585f93
--- /dev/null
+++ b/backends/arm/scripts/pre-commit
@@ -0,0 +1,13 @@
+#!/bin/bash
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# Check 1: If commit header contains WIP, everything is ok
+git rev-list --format=%s --max-count=1 HEAD | grep -q WIP && exit 0
+
+# Check 2: lintrunner on the latest patch.
+lintrunner -a --revision 'HEAD^' --skip MYPY
+commit_files=$(git diff-tree --no-commit-id --name-only --diff-filter=M HEAD -r)
+git add $commit_files || true
\ No newline at end of file
diff --git a/backends/arm/scripts/pre-push b/backends/arm/scripts/pre-push
new file mode 100755
index 0000000000..c51138b8ec
--- /dev/null
+++ b/backends/arm/scripts/pre-push
@@ -0,0 +1,45 @@
+#!/bin/bash
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# Check 1: If commit header contains WIP, everything is ok
+git rev-list --format=%s --max-count=1 HEAD | grep -q WIP && exit 0
+
+# Check 2: lintrunner on the latest patches.
+lintrunner --revision 'HEAD^'
+if [[ $? != 0 ]]
+ then
+ echo "Failed linting"
+ exit 1
+fi
+
+# Check 3: License headers
+# We do a simple check that the headers of all committed files contain "$current_year Arm".
+# This does not guarantee that CI will pass, but it should be correct most of the time.
+
+current_year=$(date +%Y)
+failed_license_check=false
+commit_files=$(git diff-tree --no-commit-id --name-only --diff-filter=ACMR HEAD -r)
+
+
+for committed_file in $commit_files; do
+    head $committed_file | grep -q "$current_year Arm"
+    if [[ $? != 0 ]]
+        then
+        echo "Header in $committed_file did not contain '$current_year Arm'"
+        failed_license_check=true
+    else
+        echo "$committed_file passed license check"
+    fi
+done
+
+if [[ $failed_license_check == true ]]
+ then
+ exit 1
+ else
+ echo "Passed simple license check"
+fi
+
+exit 0
diff --git a/backends/arm/scripts/setup-dev-env.sh b/backends/arm/scripts/setup-dev-env.sh
new file mode 100755
index 0000000000..b8c9b3b44c
--- /dev/null
+++ b/backends/arm/scripts/setup-dev-env.sh
@@ -0,0 +1,9 @@
+#!/bin/bash
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+git_dir=$(git rev-parse --git-dir)
+ln $git_dir/../backends/arm/scripts/pre-push $git_dir/hooks
+ln $git_dir/../backends/arm/scripts/pre-commit $git_dir/hooks
\ No newline at end of file
diff --git a/backends/arm/test/common.py b/backends/arm/test/common.py
index 7ebf89e392..091b2d5f26 100644
--- a/backends/arm/test/common.py
+++ b/backends/arm/test/common.py
@@ -9,9 +9,17 @@
import tempfile
from datetime import datetime
+
from pathlib import Path
+from typing import Any
+import pytest
from executorch.backends.arm.arm_backend import ArmCompileSpecBuilder
+from executorch.backends.arm.test.runner_utils import (
+ arm_executor_runner_exists,
+ corstone300_installed,
+ corstone320_installed,
+)
from executorch.backends.arm.tosa_specification import TosaSpecification
from executorch.exir.backend.compile_spec_schema import CompileSpec
@@ -41,8 +49,8 @@ def maybe_get_tosa_collate_path() -> str | None:
if tosa_test_base:
current_test = os.environ.get("PYTEST_CURRENT_TEST")
#'backends/arm/test/ops/test_mean_dim.py::TestMeanDim::test_meandim_tosa_BI_0_zeros (call)'
- test_class = current_test.split("::")[1]
- test_name = current_test.split("::")[-1].split(" ")[0]
+ test_class = current_test.split("::")[1] # type: ignore[union-attr]
+ test_name = current_test.split("::")[-1].split(" ")[0] # type: ignore[union-attr]
if "BI" in test_name:
tosa_test_base = os.path.join(tosa_test_base, "tosa-bi")
elif "MI" in test_name:
@@ -100,7 +108,7 @@ def get_u85_compile_spec(
"""
Default compile spec for Ethos-U85 tests.
"""
- return get_u85_compile_spec_unbuilt(
+ return get_u85_compile_spec_unbuilt( # type: ignore[attr-defined]
custom_path=custom_path,
).build()
@@ -144,4 +152,45 @@ def get_u85_compile_spec_unbuilt(
)
.dump_intermediate_artifacts_to(artifact_path)
)
- return compile_spec
+ return compile_spec # type: ignore[return-value]
+
+
+SkipIfNoCorstone300 = pytest.mark.skipif(
+ not corstone300_installed() or not arm_executor_runner_exists("corstone-300"),
+ reason="Did not find Corstone-300 FVP or executor_runner on path",
+)
+"""Skips a test if Corsone300 FVP is not installed, or if the executor runner is not built"""
+
+SkipIfNoCorstone320 = pytest.mark.skipif(
+ not corstone320_installed() or not arm_executor_runner_exists("corstone-320"),
+ reason="Did not find Corstone-320 FVP or executor_runner on path",
+)
+"""Skips a test if Corsone320 FVP is not installed, or if the executor runner is not built."""
+
+
+def parametrize(
+    arg_name: str, test_data: dict[str, Any], xfails: dict[str, str] | None = None
+):
+ """
+    Custom version of pytest.mark.parametrize with some syntactic sugar and added xfail functionality:
+    - test_data is expected as a dict of (id, test_data) pairs
+    - xfails allows specifying a dict of (id, failure_reason) pairs to mark specific tests as xfail
+ """
+ if xfails is None:
+ xfails = {}
+
+ def decorator_func(func):
+ """Test data is transformed from a dict of (id, data) pairs to a list of pytest params to work with the native pytests parametrize function"""
+ pytest_testsuite = []
+ for id, test_parameters in test_data.items():
+ if id in xfails:
+ pytest_param = pytest.param(
+ test_parameters, id=id, marks=pytest.mark.xfail(reason=xfails[id])
+ )
+ else:
+ pytest_param = pytest.param(test_parameters, id=id)
+ pytest_testsuite.append(pytest_param)
+
+ return pytest.mark.parametrize(arg_name, pytest_testsuite)(func)
+
+ return decorator_func
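
For reference, here is a brief usage sketch of the new common.parametrize helper and its xfails argument, roughly as the rewritten op tests later in this patch use it. The test function, ids, and failure reason below are illustrative only and not part of the patch:

```python
import torch

from executorch.backends.arm.test import common

example_data = {
    "ones": (torch.ones(1, 3, 4, 4),),
    "rand": (torch.rand(1, 3, 4, 4),),
}


# xfails maps a test-data id to the reason the case is expected to fail.
@common.parametrize(
    "test_data", example_data, xfails={"rand": "Hypothetical: not yet supported"}
)
def test_example(test_data):
    (x,) = test_data
    assert x.shape == (1, 3, 4, 4)
```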
diff --git a/backends/arm/test/misc/test_debug_feats.py b/backends/arm/test/misc/test_debug_feats.py
index a9491418a4..690549d717 100644
--- a/backends/arm/test/misc/test_debug_feats.py
+++ b/backends/arm/test/misc/test_debug_feats.py
@@ -48,7 +48,7 @@ def _tosa_MI_pipeline(self, module: torch.nn.Module, dump_file=None):
(
ArmTester(
module,
- example_inputs=module.get_inputs(),
+ example_inputs=module.get_inputs(), # type: ignore[operator]
compile_spec=common.get_tosa_compile_spec("TOSA-0.80+MI"),
)
.export()
@@ -61,7 +61,7 @@ def _tosa_BI_pipeline(self, module: torch.nn.Module, dump_file=None):
(
ArmTester(
module,
- example_inputs=module.get_inputs(),
+ example_inputs=module.get_inputs(), # type: ignore[operator]
compile_spec=common.get_tosa_compile_spec("TOSA-0.80+BI"),
)
.quantize()
diff --git a/backends/arm/test/misc/test_lifted_tensor.py b/backends/arm/test/misc/test_lifted_tensor.py
index a16b1e639b..092483fd63 100644
--- a/backends/arm/test/misc/test_lifted_tensor.py
+++ b/backends/arm/test/misc/test_lifted_tensor.py
@@ -1,4 +1,4 @@
-# Copyright 2024 Arm Limited and/or its affiliates.
+# Copyright 2024-2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
@@ -10,7 +10,7 @@
import torch
from executorch.backends.arm.test import common
from executorch.backends.arm.test.tester.arm_tester import ArmTester
-from parameterized import parameterized
+from parameterized import parameterized # type: ignore[import-untyped]
class LiftedTensor(torch.nn.Module):
@@ -23,14 +23,14 @@ class LiftedTensor(torch.nn.Module):
(operator.sub, (torch.rand(2, 2), 2)),
]
- def __init__(self, op: callable):
+ def __init__(self, op: callable): # type: ignore[valid-type]
super().__init__()
self.op = op
self.lifted_tensor = torch.Tensor([[1, 2], [3, 4]])
def forward(self, x: torch.Tensor, length) -> torch.Tensor:
sliced = self.lifted_tensor[:, :length]
- return self.op(sliced, x)
+ return self.op(sliced, x) # type: ignore[misc]
class LiftedScalarTensor(torch.nn.Module):
@@ -42,13 +42,13 @@ class LiftedScalarTensor(torch.nn.Module):
(operator.sub, (torch.randn(3),), 1.0),
]
- def __init__(self, op: callable, arg1: Union[int, float, torch.tensor]):
+ def __init__(self, op: callable, arg1: Union[int, float, torch.tensor]): # type: ignore[valid-type]
super().__init__()
self.op = op
self.arg1 = arg1
def forward(self, x: torch.Tensor) -> torch.Tensor:
- return self.op(x, self.arg1)
+ return self.op(x, self.arg1) # type: ignore[misc]
class TestLiftedTensor(unittest.TestCase):
diff --git a/backends/arm/test/misc/test_tosa_spec.py b/backends/arm/test/misc/test_tosa_spec.py
index 77b10cf315..d61b3fe718 100644
--- a/backends/arm/test/misc/test_tosa_spec.py
+++ b/backends/arm/test/misc/test_tosa_spec.py
@@ -1,4 +1,4 @@
-# Copyright 2024 Arm Limited and/or its affiliates.
+# Copyright 2024-2025 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
@@ -13,7 +13,7 @@
)
from executorch.exir.backend.compile_spec_schema import CompileSpec
-from parameterized import parameterized
+from parameterized import parameterized # type: ignore[import-untyped]
test_valid_0_80_strings = [
"TOSA-0.80+BI",
@@ -64,13 +64,13 @@
class TestTosaSpecification(unittest.TestCase):
"""Tests the TOSA specification class"""
- @parameterized.expand(test_valid_0_80_strings)
+ @parameterized.expand(test_valid_0_80_strings) # type: ignore[misc]
def test_version_string_0_80(self, version_string: str):
tosa_spec = TosaSpecification.create_from_string(version_string)
assert isinstance(tosa_spec, Tosa_0_80)
assert tosa_spec.profile in ["BI", "MI"]
- @parameterized.expand(test_valid_1_00_strings)
+ @parameterized.expand(test_valid_1_00_strings) # type: ignore[misc]
def test_version_string_1_00(self, version_string: str):
tosa_spec = TosaSpecification.create_from_string(version_string)
assert isinstance(tosa_spec, Tosa_1_00)
@@ -83,7 +83,7 @@ def test_version_string_1_00(self, version_string: str):
e in test_valid_1_00_extensions[profile] for e in tosa_spec.extensions
]
- @parameterized.expand(test_invalid_strings)
+ @parameterized.expand(test_invalid_strings) # type: ignore[misc]
def test_invalid_version_strings(self, version_string: str):
tosa_spec = None
with self.assertRaises(ValueError):
@@ -91,12 +91,12 @@ def test_invalid_version_strings(self, version_string: str):
assert tosa_spec is None
- @parameterized.expand(test_compile_specs)
+ @parameterized.expand(test_compile_specs) # type: ignore[misc]
def test_create_from_compilespec(self, compile_specs: list[CompileSpec]):
tosa_spec = TosaSpecification.create_from_compilespecs(compile_specs)
assert isinstance(tosa_spec, TosaSpecification)
- @parameterized.expand(test_compile_specs_no_version)
+ @parameterized.expand(test_compile_specs_no_version) # type: ignore[misc]
def test_create_from_invalid_compilespec(self, compile_specs: list[CompileSpec]):
tosa_spec = None
with self.assertRaises(ValueError):
diff --git a/backends/arm/test/models/test_conformer.py b/backends/arm/test/models/test_conformer.py
new file mode 100644
index 0000000000..e3be7811dd
--- /dev/null
+++ b/backends/arm/test/models/test_conformer.py
@@ -0,0 +1,126 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import logging
+import unittest
+
+import torch
+from executorch.backends.arm.test import common, conftest
+
+from executorch.backends.arm.test.tester.arm_tester import ArmTester
+
+from torchaudio.models import Conformer
+
+
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.INFO)
+
+
+class TestConformer(unittest.TestCase):
+ """Tests Torchaudio Conformer"""
+
+    # Adjust the numbers below as op support increases. Note: most of the
+    # delegate calls are directly consecutive to each other in the .pte. The
+    # reason for that is that some assert ops are removed by passes in the
+    # .to_executorch step, i.e. after the Arm partitioner.
+ ops_after_partitioner = {
+ "executorch_exir_dialects_edge__ops_aten_arange_start_step": 1,
+ "executorch_exir_dialects_edge__ops_aten_full_like_default": 4,
+ "executorch_exir_dialects_edge__ops_aten_max_default": 1,
+ "executorch_exir_dialects_edge__ops_aten_mul_Scalar": 4,
+ "executorch_exir_dialects_edge__ops_aten_eq_Scalar": 2,
+ "executorch_exir_dialects_edge__ops_aten_where_self": 4,
+ "executorch_exir_dialects_edge__ops_aten_logical_not_default": 4,
+ "executorch_exir_dialects_edge__ops_aten_any_dim": 2,
+ "torch.ops.aten._assert_scalar.default": 10,
+ "torch.ops.aten._local_scalar_dense.default": 1,
+ "torch.ops.aten.scalar_tensor.default": 2,
+ "torch.ops.higher_order.executorch_call_delegate": 5,
+ }
+
+ dim = 16
+ lengths = torch.randint(1, 100, (10,), dtype=torch.int32)
+ input_data = torch.rand(10, int(lengths.max()), dim)
+ conformer = Conformer(
+ input_dim=dim,
+ num_heads=4,
+ ffn_dim=64,
+ num_layers=2,
+ depthwise_conv_kernel_size=31,
+ )
+ conformer = conformer.eval()
+
+ def test_conformer_tosa_MI(self):
+ (
+ ArmTester(
+ self.conformer,
+ example_inputs=(self.input_data, self.lengths),
+ compile_spec=common.get_tosa_compile_spec(tosa_spec="TOSA-0.80+MI"),
+ )
+ .export()
+ .to_edge_transform_and_lower()
+ .dump_operator_distribution()
+ .check_count(self.ops_after_partitioner)
+ .to_executorch()
+ # TODO(MLETORCH-632): Fix numerical errors
+ .run_method_and_compare_outputs(
+ inputs=(self.input_data, self.lengths), rtol=1, atol=5
+ )
+ )
+
+ @unittest.expectedFailure # TODO(MLETORCH-635)
+ def test_conformer_tosa_BI(self):
+ (
+ ArmTester(
+ self.conformer,
+ example_inputs=(self.input_data, self.lengths),
+ compile_spec=common.get_tosa_compile_spec(tosa_spec="TOSA-0.80+BI"),
+ )
+ .quantize()
+ .export()
+ .to_edge_transform_and_lower()
+ .to_executorch()
+ .run_method_and_compare_outputs(
+ qtol=1, rtol=1, atol=5, inputs=(self.input_data, self.lengths)
+ )
+ )
+
+ @unittest.expectedFailure # TODO(MLETORCH-635)
+ def test_conformer_u55_BI(self):
+ tester = (
+ ArmTester(
+ self.conformer,
+ example_inputs=(self.input_data, self.lengths),
+ compile_spec=common.get_u55_compile_spec(),
+ )
+ .quantize()
+ .export()
+ .to_edge_transform_and_lower()
+ .to_executorch()
+ .serialize()
+ )
+ if conftest.is_option_enabled("corstone_fvp"):
+ tester.run_method_and_compare_outputs(
+ atol=1.0, qtol=1, inputs=(self.input_data, self.lengths)
+ )
+
+ @unittest.expectedFailure # TODO(MLETORCH-635)
+ def test_conformer_u85_BI(self):
+ tester = (
+ ArmTester(
+ self.conformer,
+ example_inputs=(self.input_data, self.lengths),
+ compile_spec=common.get_u85_compile_spec(),
+ )
+ .quantize()
+ .export()
+ .to_edge_transform_and_lower()
+ .to_executorch()
+ .serialize()
+ )
+ if conftest.is_option_enabled("corstone_fvp"):
+ tester.run_method_and_compare_outputs(
+ atol=1.0, qtol=1, inputs=(self.input_data, self.lengths)
+ )
diff --git a/backends/arm/test/models/test_mobilenet_v2_arm.py b/backends/arm/test/models/test_mobilenet_v2_arm.py
index 21bd43202d..62b14a1022 100644
--- a/backends/arm/test/models/test_mobilenet_v2_arm.py
+++ b/backends/arm/test/models/test_mobilenet_v2_arm.py
@@ -14,8 +14,10 @@
from executorch.backends.arm.test import common, conftest
from executorch.backends.arm.test.tester.arm_tester import ArmTester
-from torchvision import models, transforms
-from torchvision.models.mobilenetv2 import MobileNet_V2_Weights
+from torchvision import models, transforms # type: ignore[import-untyped]
+from torchvision.models.mobilenetv2 import ( # type: ignore[import-untyped]
+ MobileNet_V2_Weights,
+)
logger = logging.getLogger(__name__)
diff --git a/backends/arm/test/ops/test_add.py b/backends/arm/test/ops/test_add.py
index db6fde53ae..b4b43f88c7 100644
--- a/backends/arm/test/ops/test_add.py
+++ b/backends/arm/test/ops/test_add.py
@@ -5,169 +5,143 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
-import unittest
from typing import Tuple
-import pytest
import torch
-from executorch.backends.arm.test import common, conftest
-from executorch.backends.arm.test.tester.arm_tester import ArmTester
-from executorch.exir.backend.compile_spec_schema import CompileSpec
-from parameterized import parameterized
-
-
-class TestSimpleAdd(unittest.TestCase):
- """Tests a single add op, x+x and x+y."""
-
- class Add(torch.nn.Module):
- test_parameters = [
- (torch.FloatTensor([1, 2, 3, 5, 7]),),
- (3 * torch.ones(8),),
- (10 * torch.randn(8),),
- (torch.ones(1, 1, 4, 4),),
- (torch.ones(1, 3, 4, 2),),
- ]
-
- def forward(self, x):
- return x + x
-
- class Add2(torch.nn.Module):
- test_parameters = [
- (
- torch.FloatTensor([1, 2, 3, 5, 7]),
- (torch.FloatTensor([2, 1, 2, 1, 10])),
- ),
- (torch.ones(1, 10, 4, 6), torch.ones(1, 10, 4, 6)),
- (torch.randn(1, 1, 4, 4), torch.ones(1, 1, 4, 1)),
- (torch.randn(1, 3, 4, 4), torch.randn(1, 3, 4, 4)),
- (10000 * torch.randn(1, 1, 4, 4), torch.randn(1, 1, 4, 1)),
- ]
-
- def __init__(self):
- super().__init__()
-
- def forward(self, x, y):
- return x + y
-
- def _test_add_tosa_MI_pipeline(
- self, module: torch.nn.Module, test_data: Tuple[torch.Tensor]
- ):
- (
- ArmTester(
- module,
- example_inputs=test_data,
- compile_spec=common.get_tosa_compile_spec("TOSA-0.80+MI"),
- )
- .export()
- .check_count({"torch.ops.aten.add.Tensor": 1})
- .check_not(["torch.ops.quantized_decomposed"])
- .to_edge()
- .partition()
- .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
- .to_executorch()
- .run_method_and_compare_outputs(inputs=test_data)
- )
-
- def _test_add_tosa_BI_pipeline(
- self, module: torch.nn.Module, test_data: Tuple[torch.Tensor]
- ):
- (
- ArmTester(
- module,
- example_inputs=test_data,
- compile_spec=common.get_tosa_compile_spec("TOSA-0.80+BI"),
- )
- .quantize()
- .export()
- .check_count({"torch.ops.aten.add.Tensor": 1})
- .check(["torch.ops.quantized_decomposed"])
- .to_edge()
- .partition()
- .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
- .to_executorch()
- .run_method_and_compare_outputs(inputs=test_data, qtol=1)
- )
-
- def _test_add_ethos_BI_pipeline(
- self,
- module: torch.nn.Module,
- compile_spec: CompileSpec,
- test_data: Tuple[torch.Tensor],
- ):
- tester = (
- ArmTester(
- module,
- example_inputs=test_data,
- compile_spec=compile_spec,
- )
- .quantize()
- .export()
- .check_count({"torch.ops.aten.add.Tensor": 1})
- .check(["torch.ops.quantized_decomposed"])
- .to_edge()
- .partition()
- .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
- .to_executorch()
- .serialize()
- )
- if conftest.is_option_enabled("corstone_fvp"):
- tester.run_method_and_compare_outputs(qtol=1, inputs=test_data)
-
- return tester
-
- @parameterized.expand(Add.test_parameters)
- def test_add_tosa_MI(self, test_data: torch.Tensor):
- test_data = (test_data,)
- self._test_add_tosa_MI_pipeline(self.Add(), test_data)
-
- @parameterized.expand(Add.test_parameters)
- def test_add_tosa_BI(self, test_data: torch.Tensor):
- test_data = (test_data,)
- self._test_add_tosa_BI_pipeline(self.Add(), test_data)
-
- @parameterized.expand(Add.test_parameters)
- @pytest.mark.corstone_fvp
- def test_add_u55_BI(self, test_data: torch.Tensor):
- test_data = (test_data,)
- self._test_add_ethos_BI_pipeline(
- self.Add(),
- common.get_u55_compile_spec(),
- test_data,
- )
-
- @parameterized.expand(Add.test_parameters)
- @pytest.mark.corstone_fvp
- def test_add_u85_BI(self, test_data: torch.Tensor):
- test_data = (test_data,)
- self._test_add_ethos_BI_pipeline(
- self.Add(),
- common.get_u85_compile_spec(),
- test_data,
- )
-
- @parameterized.expand(Add2.test_parameters)
- def test_add2_tosa_MI(self, operand1: torch.Tensor, operand2: torch.Tensor):
- test_data = (operand1, operand2)
- self._test_add_tosa_MI_pipeline(self.Add2(), test_data)
-
- @parameterized.expand(Add2.test_parameters)
- def test_add2_tosa_BI(self, operand1: torch.Tensor, operand2: torch.Tensor):
- test_data = (operand1, operand2)
- self._test_add_tosa_BI_pipeline(self.Add2(), test_data)
-
- @parameterized.expand(Add2.test_parameters)
- @pytest.mark.corstone_fvp
- def test_add2_u55_BI(self, operand1: torch.Tensor, operand2: torch.Tensor):
- test_data = (operand1, operand2)
- self._test_add_ethos_BI_pipeline(
- self.Add2(), common.get_u55_compile_spec(), test_data
- )
-
- @parameterized.expand(Add2.test_parameters)
- @pytest.mark.corstone_fvp
- def test_add2_u85_BI(self, operand1: torch.Tensor, operand2: torch.Tensor):
- test_data = (operand1, operand2)
- self._test_add_ethos_BI_pipeline(
- self.Add2(), common.get_u85_compile_spec(), test_data
- )
+from executorch.backends.arm.test import common
+from executorch.backends.arm.test.tester.test_pipeline import (
+ EthosU55PipelineBI,
+ EthosU85PipelineBI,
+ TosaPipelineBI,
+ TosaPipelineMI,
+)
+
+aten_op = "torch.ops.aten.add.Tensor"
+exir_op = "executorch_exir_dialects_edge__ops_aten_add_Tensor"
+
+input_t1 = Tuple[torch.Tensor] # Input x
+
+
+class Add(torch.nn.Module):
+ def forward(self, x: torch.Tensor):
+ return x + x
+
+    test_data: dict[str, input_t1] = {
+ "5d_float": (torch.FloatTensor([1, 2, 3, 5, 7]),),
+ "1d_ones": ((3 * torch.ones(8),)),
+ "1d_randn": (10 * torch.randn(8),),
+ "4d_ones_1": (torch.ones(1, 1, 4, 4),),
+ "4d_ones_2": (torch.ones(1, 3, 4, 2),),
+ }
+
+
+input_t2 = Tuple[torch.Tensor, torch.Tensor] # Input x, y
+
+
+class Add2(torch.nn.Module):
+ def forward(self, x: torch.Tensor, y: torch.Tensor):
+ return x + y
+
+    test_data: dict[str, input_t2] = {
+ "5d_float": (
+ torch.FloatTensor([1, 2, 3, 5, 7]),
+ (torch.FloatTensor([2, 1, 2, 1, 10])),
+ ),
+ "4d_ones": (torch.ones(1, 10, 4, 6), torch.ones(1, 10, 4, 6)),
+ "4d_randn_1": (torch.randn(1, 1, 4, 4), torch.ones(1, 1, 4, 1)),
+ "4d_randn_2": (torch.randn(1, 3, 4, 4), torch.randn(1, 3, 4, 4)),
+ "4d_randn_big": (10000 * torch.randn(1, 1, 4, 4), torch.randn(1, 1, 4, 1)),
+ }
+
+
+@common.parametrize("test_data", Add.test_data)
+def test_add_tosa_MI(test_data: input_t1):
+ pipeline = TosaPipelineMI[input_t1](Add(), test_data, aten_op, exir_op)
+ pipeline.run()
+
+
+@common.parametrize("test_data", Add.test_data)
+def test_add_tosa_BI(test_data: input_t1):
+ pipeline = TosaPipelineBI[input_t1](Add(), test_data, aten_op, exir_op)
+ pipeline.run()
+
+
+@common.parametrize("test_data", Add.test_data)
+def test_add_u55_BI(test_data: input_t1):
+ pipeline = EthosU55PipelineBI[input_t1](
+ Add(), test_data, aten_op, exir_op, run_on_fvp=False
+ )
+ pipeline.run()
+
+
+@common.parametrize("test_data", Add.test_data)
+def test_add_u85_BI(test_data: input_t1):
+ pipeline = EthosU85PipelineBI[input_t1](
+ Add(), test_data, aten_op, exir_op, run_on_fvp=False
+ )
+ pipeline.run()
+
+
+@common.parametrize("test_data", Add.test_data)
+@common.SkipIfNoCorstone300
+def test_add_u55_BI_on_fvp(test_data: input_t1):
+ pipeline = EthosU55PipelineBI[input_t1](
+ Add(), test_data, aten_op, exir_op, run_on_fvp=True
+ )
+ pipeline.run()
+
+
+@common.parametrize("test_data", Add.test_data)
+@common.SkipIfNoCorstone320
+def test_add_u85_BI_on_fvp(test_data: input_t1):
+ pipeline = EthosU85PipelineBI[input_t1](
+ Add(), test_data, aten_op, exir_op, run_on_fvp=True
+ )
+ pipeline.run()
+
+
+@common.parametrize("test_data", Add2.test_data)
+def test_add2_tosa_MI(test_data: input_t2):
+ pipeline = TosaPipelineMI[input_t2](Add2(), test_data, aten_op, exir_op)
+ pipeline.run()
+
+
+@common.parametrize("test_data", Add2.test_data)
+def test_add2_tosa_BI(test_data: input_t2):
+ pipeline = TosaPipelineBI[input_t2](Add2(), test_data, aten_op, exir_op)
+ pipeline.run()
+
+
+@common.parametrize("test_data", Add2.test_data)
+def test_add2_u55_BI(test_data: input_t2):
+ pipeline = EthosU55PipelineBI[input_t2](
+ Add2(), test_data, aten_op, exir_op, run_on_fvp=False
+ )
+ pipeline.run()
+
+
+@common.parametrize("test_data", Add2.test_data)
+@common.SkipIfNoCorstone300
+def test_add2_u55_BI_on_fvp(test_data: input_t2):
+ pipeline = EthosU55PipelineBI[input_t2](
+ Add2(), test_data, aten_op, exir_op, run_on_fvp=True
+ )
+ pipeline.run()
+
+
+@common.parametrize("test_data", Add2.test_data)
+def test_add2_u85_BI(test_data: input_t2):
+ pipeline = EthosU85PipelineBI[input_t2](
+ Add2(), test_data, aten_op, exir_op, run_on_fvp=False
+ )
+ pipeline.run()
+
+
+@common.parametrize("test_data", Add2.test_data)
+@common.SkipIfNoCorstone320
+def test_add2_u85_BI_on_fvp(test_data: input_t2):
+ pipeline = EthosU85PipelineBI[input_t2](
+ Add2(), test_data, aten_op, exir_op, run_on_fvp=True
+ )
+ pipeline.run()
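
As a sketch of how the FVP-gated variants above could be extended with a one-off case, the following hypothetical test reuses the Corstone-300 skip marker and the same pipeline class; the test name and input are made up for illustration:

```python
# Hypothetical one-off case gated on the Corstone-300 FVP, mirroring the
# *_on_fvp tests above (not part of the patch).
@common.SkipIfNoCorstone300
def test_add_single_case_u55_BI_on_fvp():
    data = (torch.ones(1, 3, 4, 2),)
    EthosU55PipelineBI[input_t1](
        Add(), data, aten_op, exir_op, run_on_fvp=True
    ).run()
```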
diff --git a/backends/arm/test/ops/test_bmm.py b/backends/arm/test/ops/test_bmm.py
index 06470d91e8..bd6e1ef689 100644
--- a/backends/arm/test/ops/test_bmm.py
+++ b/backends/arm/test/ops/test_bmm.py
@@ -6,7 +6,7 @@
import unittest
-from typing import Tuple
+from typing import Callable, Tuple
import pytest
@@ -16,39 +16,37 @@
from executorch.exir.backend.compile_spec_schema import CompileSpec
from parameterized import parameterized
-torch.manual_seed(1)
-
class TestBMM(unittest.TestCase):
"""Tests Batch MatMul"""
class BMM(torch.nn.Module):
- test_parameters = [
- (torch.rand(2, 1, 1), torch.rand(2, 1, 1)),
- (torch.rand(5, 3, 5), torch.rand(5, 5, 2)),
- (torch.ones(1, 55, 3), torch.ones(1, 3, 44)),
- (10000 * torch.randn(10, 1, 10), torch.randn(10, 10, 5)),
- (-10 * torch.randn(2, 32, 64), 5 + 5 * torch.randn(2, 64, 32)),
+ test_data_generators = [
+ lambda: (torch.rand(2, 1, 1), torch.rand(2, 1, 1)),
+ lambda: (torch.rand(5, 3, 5), torch.rand(5, 5, 2)),
+ lambda: (torch.ones(1, 55, 3), torch.ones(1, 3, 44)),
+ lambda: (10000 * torch.randn(10, 1, 10), torch.randn(10, 10, 5)),
+ lambda: (-10 * torch.randn(2, 32, 64), 5 + 5 * torch.randn(2, 64, 32)),
]
def forward(self, x, y):
return torch.bmm(x, y)
class MatMul(torch.nn.Module):
- test_parameters = [
- (torch.rand(2, 3, 5), torch.rand(2, 5, 2)),
- (torch.rand(1, 2, 3, 5), torch.rand(1, 2, 5, 2)),
+ test_data_generators = [
+ lambda: (torch.rand(2, 3, 5), torch.rand(2, 5, 2)),
+ lambda: (torch.rand(1, 2, 3, 5), torch.rand(1, 2, 5, 2)),
]
def forward(self, x, y):
return torch.matmul(x, y)
class BMMSingleInput(torch.nn.Module):
- test_parameters = [
- (torch.rand(20, 3, 3),),
- (torch.rand(2, 128, 128),),
- (10000 * torch.randn(4, 25, 25),),
- (5 + 5 * torch.randn(3, 64, 64),),
+ test_data_generators = [
+ lambda: (torch.rand(20, 3, 3),),
+ lambda: (torch.rand(2, 128, 128),),
+ lambda: (10000 * torch.randn(4, 25, 25),),
+ lambda: (5 + 5 * torch.randn(3, 64, 64),),
]
def forward(self, x):
@@ -120,67 +118,69 @@ def _test_bmm_ethosu_BI_pipeline(
if conftest.is_option_enabled("corstone_fvp"):
tester.run_method_and_compare_outputs(inputs=test_data, qtol=1)
- @parameterized.expand(BMM.test_parameters)
- def test_bmm_tosa_MI(self, operand1: torch.Tensor, operand2: torch.Tensor):
- test_data = (operand1, operand2)
+ @parameterized.expand(BMM.test_data_generators)
+ def test_bmm_tosa_MI(self, test_data_generator: Callable[[], Tuple]):
+ test_data = test_data_generator()
self._test_bmm_tosa_MI_pipeline(self.BMM(), test_data)
- @parameterized.expand(BMMSingleInput.test_parameters)
- def test_bmm_single_input_tosa_MI(self, operand1: torch.Tensor):
- test_data = (operand1,)
+ @parameterized.expand(BMMSingleInput.test_data_generators)
+ def test_bmm_single_input_tosa_MI(self, test_data_generator: Callable[[], Tuple]):
+ test_data = test_data_generator()
self._test_bmm_tosa_MI_pipeline(self.BMMSingleInput(), test_data)
- @parameterized.expand(MatMul.test_parameters)
- def test_matmul_tosa_MI(self, operand1: torch.Tensor, operand2: torch.Tensor):
- test_data = (operand1, operand2)
+ @parameterized.expand(MatMul.test_data_generators)
+ def test_matmul_tosa_MI(self, test_data_generator: Callable[[], Tuple]):
+ test_data = test_data_generator()
self._test_bmm_tosa_MI_pipeline(self.MatMul(), test_data)
- @parameterized.expand(MatMul.test_parameters)
- def test_matmul_tosa_BI(self, operand1: torch.Tensor, operand2: torch.Tensor):
- test_data = (operand1, operand2)
+ @parameterized.expand(MatMul.test_data_generators)
+ def test_matmul_tosa_BI(self, test_data_generator: Callable[[], Tuple]):
+ test_data = test_data_generator()
self._test_bmm_tosa_BI_pipeline(self.MatMul(), test_data)
- @parameterized.expand(BMM.test_parameters)
- def test_bmm_tosa_BI(self, operand1: torch.Tensor, operand2: torch.Tensor):
- test_data = (operand1, operand2)
+ @parameterized.expand(BMM.test_data_generators)
+ def test_bmm_tosa_BI(self, test_data_generator: Callable[[], Tuple]):
+ test_data = test_data_generator()
self._test_bmm_tosa_BI_pipeline(self.BMM(), test_data)
- @parameterized.expand(BMMSingleInput.test_parameters)
- def test_bmm_single_input_tosa_BI(self, operand1: torch.Tensor):
- test_data = (operand1,)
+ @parameterized.expand(BMMSingleInput.test_data_generators)
+ def test_bmm_single_input_tosa_BI(self, test_data_generator: Callable[[], Tuple]):
+ test_data = test_data_generator()
self._test_bmm_tosa_BI_pipeline(self.BMMSingleInput(), test_data)
- @parameterized.expand(BMM.test_parameters)
+ @parameterized.expand(BMM.test_data_generators)
@pytest.mark.corstone_fvp
@unittest.expectedFailure
- def test_bmm_u55_BI_xfails(self, operand1: torch.Tensor, operand2: torch.Tensor):
- test_data = (operand1, operand2)
+ def test_bmm_u55_BI_xfails(self, test_data_generator: Callable[[], Tuple]):
+ test_data = test_data_generator()
self._test_bmm_ethosu_BI_pipeline(
self.BMM(), common.get_u55_compile_spec(), test_data
)
- @parameterized.expand(BMM.test_parameters)
+ @parameterized.expand(BMM.test_data_generators)
@pytest.mark.corstone_fvp
- def test_bmm_u85_BI(self, operand1: torch.Tensor, operand2: torch.Tensor):
- test_data = (operand1, operand2)
+ def test_bmm_u85_BI(self, test_data_generator: Callable[[], Tuple]):
+ test_data = test_data_generator()
self._test_bmm_ethosu_BI_pipeline(
self.BMM(), common.get_u85_compile_spec(), test_data
)
# Expected to fail with error: Warning, unsupported fusing of TOSA Rescale previous operator is of type: Memcpy
- @parameterized.expand(BMMSingleInput.test_parameters)
+ @parameterized.expand(BMMSingleInput.test_data_generators)
@pytest.mark.corstone_fvp
@unittest.expectedFailure
- def test_bmm_single_input_u55_BI_xfails(self, operand1: torch.Tensor):
- test_data = (operand1,)
+ def test_bmm_single_input_u55_BI_xfails(
+ self, test_data_generator: Callable[[], Tuple]
+ ):
+ test_data = test_data_generator()
self._test_bmm_ethosu_BI_pipeline(
self.BMMSingleInput(), common.get_u55_compile_spec(), test_data
)
- @parameterized.expand(BMMSingleInput.test_parameters)
+ @parameterized.expand(BMMSingleInput.test_data_generators)
@pytest.mark.corstone_fvp
- def test_bmm_single_input_u85_BI(self, operand1: torch.Tensor):
- test_data = (operand1,)
+ def test_bmm_single_input_u85_BI(self, test_data_generator: Callable[[], Tuple]):
+ test_data = test_data_generator()
self._test_bmm_ethosu_BI_pipeline(
self.BMMSingleInput(), common.get_u85_compile_spec(), test_data
)
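
The switch from stored tensors to zero-argument lambdas defers random data generation until the test body runs; a standalone sketch of the pattern (illustrative, not part of the test suite):

```python
import torch

# Each case is a callable, so fresh random tensors are produced at test time
# rather than once at import time.
gen = lambda: (torch.rand(2, 3, 5), torch.rand(2, 5, 2))

x, y = gen()            # new tensors for this invocation
out = torch.bmm(x, y)   # -> shape (2, 3, 2)
```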
diff --git a/backends/arm/test/ops/test_cat.py b/backends/arm/test/ops/test_cat.py
index 115b4402f5..a1613d1d04 100644
--- a/backends/arm/test/ops/test_cat.py
+++ b/backends/arm/test/ops/test_cat.py
@@ -1,5 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
-# Copyright 2024 Arm Limited and/or its affiliates.
+# Copyright 2024-2025 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
@@ -33,6 +33,8 @@ class Cat(torch.nn.Module):
),
-1,
),
+ ((torch.randn(1, 2, 4, 4), torch.randn(1, 2, 4, 1)), 3),
+ ((torch.randn(1, 2, 4, 4), torch.randn(1, 2, 4, 4)), 0),
((torch.randn(2, 2, 4, 4), torch.randn(2, 2, 4, 1)), 3),
(
(
@@ -47,8 +49,8 @@ class Cat(torch.nn.Module):
def __init__(self):
super().__init__()
- def forward(self, tensors: tuple[torch.Tensor, ...], dim: int) -> torch.Tensor:
- return torch.cat(tensors, dim=dim)
+ def forward(self, t: tuple[torch.Tensor, ...], dim: int) -> torch.Tensor:
+ return torch.cat(t, dim=dim)
def _test_cat_tosa_MI_pipeline(
self, module: torch.nn.Module, test_data: Tuple[tuple[torch.Tensor, ...], int]
@@ -134,22 +136,38 @@ def test_cat_tosa_BI(self, operands: tuple[torch.Tensor, ...], dim: int):
test_data = (operands, dim)
self._test_cat_tosa_BI_pipeline(self.Cat(), test_data)
- # Mismatch in provided number of inputs and model signature, MLETORCH 519
- @parameterized.expand(Cat.test_parameters)
+ @parameterized.expand(Cat.test_parameters[:-3])
@pytest.mark.corstone_fvp
- @conftest.expectedFailureOnFVP
def test_cat_u55_BI(self, operands: tuple[torch.Tensor, ...], dim: int):
test_data = (operands, dim)
self._test_cat_ethosu_BI_pipeline(
self.Cat(), common.get_u55_compile_spec(), test_data
)
- # Mismatch in provided number of inputs and model signature, MLETORCH 519
- @parameterized.expand(Cat.test_parameters)
+ # MLETORCH-630 Cat does not work on FVP with batch>1
+ @parameterized.expand(Cat.test_parameters[-3:])
@pytest.mark.corstone_fvp
@conftest.expectedFailureOnFVP
+ def test_cat_u55_BI_xfails(self, operands: tuple[torch.Tensor, ...], dim: int):
+ test_data = (operands, dim)
+ self._test_cat_ethosu_BI_pipeline(
+ self.Cat(), common.get_u55_compile_spec(), test_data
+ )
+
+ @parameterized.expand(Cat.test_parameters[:-3])
+ @pytest.mark.corstone_fvp
def test_cat_u85_BI(self, operands: tuple[torch.Tensor, ...], dim: int):
test_data = (operands, dim)
self._test_cat_ethosu_BI_pipeline(
self.Cat(), common.get_u85_compile_spec(), test_data
)
+
+ # MLETORCH-630 Cat does not work on FVP with batch>1
+ @parameterized.expand(Cat.test_parameters[-3:])
+ @pytest.mark.corstone_fvp
+ @conftest.expectedFailureOnFVP
+ def test_cat_u85_BI_xfails(self, operands: tuple[torch.Tensor, ...], dim: int):
+ test_data = (operands, dim)
+ self._test_cat_ethosu_BI_pipeline(
+ self.Cat(), common.get_u85_compile_spec(), test_data
+ )
diff --git a/backends/arm/test/ops/test_clamp.py b/backends/arm/test/ops/test_clamp.py
new file mode 100644
index 0000000000..5cf333068c
--- /dev/null
+++ b/backends/arm/test/ops/test_clamp.py
@@ -0,0 +1,165 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import unittest
+from numbers import Number
+from typing import Tuple, Union
+
+import pytest
+import torch
+
+from executorch.backends.arm.quantizer.arm_quantizer import (
+ ArmQuantizer,
+ get_symmetric_quantization_config,
+)
+from executorch.backends.arm.test import common, conftest
+from executorch.backends.arm.test.tester.arm_tester import ArmTester
+from executorch.backends.arm.tosa_specification import TosaSpecification
+from executorch.backends.xnnpack.test.tester.tester import Quantize
+from executorch.exir.backend.compile_spec_schema import CompileSpec
+from parameterized import parameterized
+
+
+test_data_suite = [
+ # (test_name, test_data, min, max)
+ ("rank_1", torch.rand(10) * 2, -1.0, 1.0),
+ ("rank_2", torch.rand(1, 35), 0.5, 0.8),
+ ("rank_3", torch.ones(1, 10, 10), -1, -1),
+ ("rank_4", torch.rand(1, 10, 10, 1) * 2, -0.1, 2.0),
+ ("rank_4_mixed_min_max_dtype", torch.rand(1, 10, 10, 5) + 10, 8.0, 10),
+ ("rank_4_no_min", torch.rand(1, 10, 10, 1) * 10, None, 5),
+ ("rank_4_no_max", torch.rand(1, 10, 10, 1) - 3, -3.3, None),
+]
+
+
+class TestClamp(unittest.TestCase):
+ """Tests Clamp Operator."""
+
+ class Clamp(torch.nn.Module):
+ def __init__(
+ self,
+ min: Union[torch.Tensor, Number, None],
+ max: Union[torch.Tensor, Number, None],
+ ):
+ super().__init__()
+
+ self.clamp_min = min
+ self.clamp_max = max
+
+ def forward(self, x):
+ return torch.clamp(x, self.clamp_min, self.clamp_max)
+
+ def _test_clamp_tosa_MI_pipeline(
+ self, module: torch.nn.Module, test_data: Tuple[torch.Tensor]
+ ):
+ (
+ ArmTester(
+ module,
+ example_inputs=test_data,
+ compile_spec=common.get_tosa_compile_spec("TOSA-0.80+MI"),
+ )
+ .export()
+ .check(["torch.ops.aten.clamp.default"])
+ .check_not(["torch.ops.quantized_decomposed"])
+ .to_edge_transform_and_lower()
+ .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
+ .to_executorch()
+ .run_method_and_compare_outputs(inputs=test_data)
+ )
+
+ def _test_clamp_tosa_BI_pipeline(
+ self, module: torch.nn.Module, test_data: Tuple[torch.Tensor]
+ ):
+ tosa_spec = TosaSpecification.create_from_string("TOSA-0.80+BI")
+ compile_spec = common.get_tosa_compile_spec(tosa_spec)
+ quantizer = ArmQuantizer(tosa_spec).set_io(get_symmetric_quantization_config())
+ (
+ ArmTester(
+ module,
+ example_inputs=test_data,
+ compile_spec=compile_spec,
+ )
+ .quantize(Quantize(quantizer, get_symmetric_quantization_config()))
+ .export()
+ .check_count({"torch.ops.aten.clamp.default": 1})
+ .check(["torch.ops.quantized_decomposed"])
+ .to_edge_transform_and_lower()
+ .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
+ .to_executorch()
+ .run_method_and_compare_outputs(inputs=test_data)
+ )
+
+ def _test_clamp_tosa_ethos_BI_pipeline(
+ self,
+ compile_spec: list[CompileSpec],
+ module: torch.nn.Module,
+ test_data: Tuple[torch.tensor],
+ ):
+ tosa_spec = TosaSpecification.create_from_compilespecs(compile_spec)
+ quantizer = ArmQuantizer(tosa_spec).set_io(get_symmetric_quantization_config())
+ tester = (
+ ArmTester(
+ module,
+ example_inputs=test_data,
+ compile_spec=compile_spec,
+ )
+ .quantize(Quantize(quantizer, get_symmetric_quantization_config()))
+ .export()
+ .check_count({"torch.ops.aten.clamp.default": 1})
+ .check(["torch.ops.quantized_decomposed"])
+ .to_edge_transform_and_lower()
+ .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
+ .to_executorch()
+ .serialize()
+ )
+ if conftest.is_option_enabled("corstone_fvp"):
+ tester.run_method_and_compare_outputs(qtol=1, inputs=test_data)
+
+ @parameterized.expand(test_data_suite)
+ def test_clamp_tosa_MI(
+ self,
+ test_name: str,
+ test_data: torch.Tensor,
+ min: Union[torch.Tensor, Number, None],
+ max: Union[torch.Tensor, Number, None],
+ ):
+ self._test_clamp_tosa_MI_pipeline(self.Clamp(min, max), (test_data,))
+
+ @parameterized.expand(test_data_suite)
+ def test_clamp_tosa_BI(
+ self,
+ test_name: str,
+ test_data: torch.Tensor,
+ min: Union[torch.Tensor, Number, None],
+ max: Union[torch.Tensor, Number, None],
+ ):
+ self._test_clamp_tosa_BI_pipeline(self.Clamp(min, max), (test_data,))
+
+ @parameterized.expand(test_data_suite)
+ @pytest.mark.corstone_fvp
+ def test_clamp_tosa_u55_BI(
+ self,
+ test_name: str,
+ test_data: torch.Tensor,
+ min: Union[torch.Tensor, Number, None],
+ max: Union[torch.Tensor, Number, None],
+ ):
+ self._test_clamp_tosa_ethos_BI_pipeline(
+ common.get_u55_compile_spec(), self.Clamp(min, max), (test_data,)
+ )
+
+ @parameterized.expand(test_data_suite)
+ @pytest.mark.corstone_fvp
+ def test_clamp_tosa_u85_BI(
+ self,
+ test_name: str,
+ test_data: torch.Tensor,
+ min: Union[torch.Tensor, Number, None],
+ max: Union[torch.Tensor, Number, None],
+ ):
+ self._test_clamp_tosa_ethos_BI_pipeline(
+ common.get_u85_compile_spec(), self.Clamp(min, max), (test_data,)
+ )
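
The IO-quantization setup used by the BI pipelines in this test boils down to three calls, all of which appear above; a minimal sketch of that sequence in isolation:

```python
# Build a quantizer from the TOSA spec, make it quantize model inputs/outputs
# as well, and wrap it in the tester's Quantize stage.
tosa_spec = TosaSpecification.create_from_string("TOSA-0.80+BI")
quantizer = ArmQuantizer(tosa_spec).set_io(get_symmetric_quantization_config())
quantize_stage = Quantize(quantizer, get_symmetric_quantization_config())
```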
diff --git a/backends/arm/test/ops/test_conv1d.py b/backends/arm/test/ops/test_conv1d.py
index 3e0dfa6c5c..92da09a5ef 100644
--- a/backends/arm/test/ops/test_conv1d.py
+++ b/backends/arm/test/ops/test_conv1d.py
@@ -6,7 +6,7 @@
import unittest
-from typing import List, Optional, Tuple, Union
+from typing import List, Tuple, Union
import pytest
@@ -25,7 +25,6 @@ class Conv1d(torch.nn.Module):
def __init__(
self,
- inputs: Optional[torch.Tensor] = None,
length=8,
nbr_conv=1, # Number of chained convs
in_channels: Union[List, int, None] = None,
@@ -75,11 +74,10 @@ def __init__(
if not isinstance(padding_mode, List):
padding_mode = [padding_mode]
- # Generate test data if not provided
- if inputs is None:
- self.inputs = (torch.randn(batches, in_channels[0], length).to(dtype),)
- else:
- self.inputs = (inputs,)
+ self.batches = batches
+ self.in_channels = in_channels
+ self.length = length
+ self.dtype = dtype
# Build chain of convs
for i in range(self.nbr_convs):
@@ -100,7 +98,9 @@ def __init__(
)
def get_inputs(self):
- return self.inputs
+ return (
+ torch.randn(self.batches, self.in_channels[0], self.length).to(self.dtype),
+ )
def forward(self, x):
for i in range(self.nbr_convs):
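
The same idea in isolation: the module stores only the shape parameters and builds a fresh random input on every get_inputs() call. A minimal sketch with a made-up class name:

```python
import torch


class LazyInputModule(torch.nn.Module):
    """Illustrative stand-in for the pattern above, not the actual Conv1d test class."""

    def __init__(self, batches=1, channels=3, length=8, dtype=torch.float32):
        super().__init__()
        self.batches, self.channels, self.length, self.dtype = batches, channels, length, dtype

    def get_inputs(self):
        # A new random tensor is generated per call instead of being cached at init.
        return (torch.randn(self.batches, self.channels, self.length).to(self.dtype),)
```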
diff --git a/backends/arm/test/ops/test_conv2d.py b/backends/arm/test/ops/test_conv2d.py
index b80228c6f2..878c65757f 100644
--- a/backends/arm/test/ops/test_conv2d.py
+++ b/backends/arm/test/ops/test_conv2d.py
@@ -4,17 +4,20 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
-import unittest
-from typing import List, Optional, Tuple, Union
-
-import pytest
+from typing import List, Tuple, Union
import torch
-from executorch.backends.arm.test import common, conftest
-from executorch.backends.arm.test.tester.arm_tester import ArmTester
-from executorch.exir.backend.compile_spec_schema import CompileSpec
-from parameterized import parameterized
+from executorch.backends.arm.test import common
+from executorch.backends.arm.test.tester.test_pipeline import (
+ EthosU55PipelineBI,
+ EthosU85PipelineBI,
+ TosaPipelineBI,
+ TosaPipelineMI,
+)
+
+aten_op = "torch.ops.aten.conv2d.default"
+exir_op = "executorch_exir_dialects_edge__ops_aten_convolution_default"
class Conv2d(torch.nn.Module):
@@ -25,7 +28,6 @@ class Conv2d(torch.nn.Module):
def __init__(
self,
- inputs: Optional[torch.Tensor] = None,
height=8,
width=8,
nbr_conv=1, # Number of chained convs
@@ -76,13 +78,11 @@ def __init__(
if not isinstance(padding_mode, List):
padding_mode = [padding_mode]
- # Generate test data if not provided
- if inputs is None:
- self.inputs = (
- torch.randn(batches, in_channels[0], height, width).to(dtype),
- )
- else:
- self.inputs = (inputs,)
+ self.batches = batches
+ self.in_channels = in_channels
+ self.height = height
+ self.width = width
+ self.dtype = dtype
# Build chain of convs
for i in range(self.nbr_convs):
@@ -103,7 +103,11 @@ def __init__(
)
def get_inputs(self):
- return self.inputs
+ return (
+ torch.randn(self.batches, self.in_channels[0], self.height, self.width).to(
+ self.dtype
+ ),
+ )
def forward(self, x):
for i in range(self.nbr_convs):
@@ -325,124 +329,80 @@ def forward(self, x):
# Shenanigan to get a nicer output when test fails. With unittest it looks like:
# FAIL: test_conv2d_tosa_BI_2_3x3_1x3x12x12_st2_pd1
-testsuite = [
- ("2x2_3x2x40x40_nobias", conv2d_2x2_3x2x40x40_nobias),
- ("3x3_1x3x256x256_st1", conv2d_3x3_1x3x256x256_st1),
- ("3x3_1x3x12x12_st2_pd1", conv2d_3x3_1x3x12x12_st2_pd1),
- ("1x1_1x2x128x128_st1", conv2d_1x1_1x2x128x128_st1),
- ("2x2_1x1x14x13_st2_needs_adjust_pass", conv2d_2x2_1x1x14x13_st2),
- ("5x5_1x3x14x15_st3_pd1_needs_adjust_pass", conv2d_5x5_1x3x14x15_st3_pd1),
- ("7x7_1x3x16x16_st2_pd1_dl2_needs_adjust_pass", conv2d_7x7_1x3x16x16_st2_pd1_dl2),
- ("7x7_1x3x15x15_st1_pd0_dl1_needs_adjust_pass", conv2d_7x7_1x3x15x15_st1_pd0_dl1),
- ("5x5_1x3x14x14_st5_pd0_dl1_needs_adjust_pass", conv2d_5x5_1x3x14x14_st5_pd0_dl1),
- ("5x5_1x3x9x9_st5_pd0_dl1_needs_adjust_pass", conv2d_5x5_1x3x9x9_st5_pd0_dl1),
- ("3x3_1x3x9x8_st3_pd0_dl1_needs_adjust_pass", conv2d_3x3_1x3x9x8_st3_pd0_dl1),
- ("3x3_1x3x8x9_st3_pd0_dl1_needs_adjust_pass", conv2d_3x3_1x3x8x9_st3_pd0_dl1),
- ("3x4_1x3x7x7_st3_pd0_dl1_needs_adjust_pass", conv2d_3x4_1x3x7x7_st3_pd0_dl1),
- ("4x3_1x3x7x7_st3_pd0_dl1_needs_adjust_pass", conv2d_4x3_1x3x7x7_st3_pd0_dl1),
- ("5x5_3x2x128x128_st1", conv2d_5x5_3x2x128x128_st1),
- ("3x3_1x3x224x224_st2_pd1", conv2d_3x3_1x3x224x224_st2_pd1),
- ("two_conv2d_nobias", two_conv2d_nobias),
- ("two_conv2d", two_conv2d),
-]
-
-
-class TestConv2D(unittest.TestCase):
- """Tests Conv2D, both single ops and multiple Convolutions in series."""
-
- def _test_conv2d_tosa_MI_pipeline(
- self, module: torch.nn.Module, test_data: Tuple[torch.Tensor]
- ):
- (
- ArmTester(
- module,
- example_inputs=test_data,
- compile_spec=common.get_tosa_compile_spec(
- "TOSA-0.80+MI",
- ),
- )
- .export()
- .to_edge()
- .partition()
- .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
- .check_not(["executorch_exir_dialects_edge__ops_aten_convolution_default"])
- .to_executorch()
- .run_method_and_compare_outputs(inputs=test_data)
- )
-
- def _test_conv2d_tosa_BI_pipeline(
- self,
- module: torch.nn.Module,
- test_data: Tuple[torch.Tensor],
- ):
- (
- ArmTester(
- module,
- example_inputs=test_data,
- compile_spec=common.get_tosa_compile_spec(
- "TOSA-0.80+BI",
- ),
- )
- .quantize()
- .export()
- .to_edge()
- .partition()
- .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
- .check_not(["executorch_exir_dialects_edge__ops_aten_convolution_default"])
- .to_executorch()
- .run_method_and_compare_outputs(inputs=test_data, qtol=1)
- )
-
- def _test_conv2d_ethosu_BI_pipeline(
- self,
- compile_spec: CompileSpec,
- module: torch.nn.Module,
- test_data: Tuple[torch.Tensor],
- ):
- tester = (
- ArmTester(
- module,
- example_inputs=test_data,
- compile_spec=compile_spec,
- )
- .quantize()
- .export()
- .to_edge()
- .partition()
- .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
- .check_not(["executorch_exir_dialects_edge__ops_aten_convolution_default"])
- .to_executorch()
- .serialize()
- )
- if conftest.is_option_enabled("corstone_fvp"):
- tester.run_method_and_compare_outputs(qtol=1, inputs=test_data)
-
- @parameterized.expand(testsuite)
- def test_conv2d_tosa_MI(self, test_name, model):
- self._test_conv2d_tosa_MI_pipeline(model, model.get_inputs())
-
- @parameterized.expand(testsuite)
- def test_conv2d_tosa_BI(self, test_name, model):
- self._test_conv2d_tosa_BI_pipeline(model, model.get_inputs())
-
- # These cases have numerical issues on FVP, MLETORCH-520
- testsuite.remove(("2x2_3x2x40x40_nobias", conv2d_2x2_3x2x40x40_nobias))
- testsuite.remove(("5x5_3x2x128x128_st1", conv2d_5x5_3x2x128x128_st1))
-
- @parameterized.expand(testsuite)
- @pytest.mark.corstone_fvp
- def test_conv2d_u55_BI(self, test_name, model):
- self._test_conv2d_ethosu_BI_pipeline(
- common.get_u55_compile_spec(),
- model,
- model.get_inputs(),
- )
-
- @parameterized.expand(testsuite)
- @pytest.mark.corstone_fvp
- def test_conv2d_u85_BI(self, test_name, model):
- self._test_conv2d_ethosu_BI_pipeline(
- common.get_u85_compile_spec(),
- model,
- model.get_inputs(),
- )
+test_modules = {
+ "2x2_3x2x40x40_nobias": conv2d_2x2_3x2x40x40_nobias,
+ "3x3_1x3x256x256_st1": conv2d_3x3_1x3x256x256_st1,
+ "3x3_1x3x12x12_st2_pd1": conv2d_3x3_1x3x12x12_st2_pd1,
+ "1x1_1x2x128x128_st1": conv2d_1x1_1x2x128x128_st1,
+ "2x2_1x1x14x13_st2_needs_adjust_pass": conv2d_2x2_1x1x14x13_st2,
+ "5x5_1x3x14x15_st3_pd1_needs_adjust_pass": conv2d_5x5_1x3x14x15_st3_pd1,
+ "7x7_1x3x16x16_st2_pd1_dl2_needs_adjust_pass": conv2d_7x7_1x3x16x16_st2_pd1_dl2,
+ "7x7_1x3x15x15_st1_pd0_dl1_needs_adjust_pass": conv2d_7x7_1x3x15x15_st1_pd0_dl1,
+ "5x5_1x3x14x14_st5_pd0_dl1_needs_adjust_pass": conv2d_5x5_1x3x14x14_st5_pd0_dl1,
+ "5x5_1x3x9x9_st5_pd0_dl1_needs_adjust_pass": conv2d_5x5_1x3x9x9_st5_pd0_dl1,
+ "3x3_1x3x9x8_st3_pd0_dl1_needs_adjust_pass": conv2d_3x3_1x3x9x8_st3_pd0_dl1,
+ "3x3_1x3x8x9_st3_pd0_dl1_needs_adjust_pass": conv2d_3x3_1x3x8x9_st3_pd0_dl1,
+ "3x4_1x3x7x7_st3_pd0_dl1_needs_adjust_pass": conv2d_3x4_1x3x7x7_st3_pd0_dl1,
+ "4x3_1x3x7x7_st3_pd0_dl1_needs_adjust_pass": conv2d_4x3_1x3x7x7_st3_pd0_dl1,
+ "5x5_3x2x128x128_st1": conv2d_5x5_3x2x128x128_st1,
+ "3x3_1x3x224x224_st2_pd1": conv2d_3x3_1x3x224x224_st2_pd1,
+ "two_conv2d_nobias": two_conv2d_nobias,
+ "two_conv2d": two_conv2d,
+}
+
+fvp_xfails = {
+ "2x2_3x2x40x40_nobias": "MLETORCH-520: Numerical issues on FVP.",
+ "5x5_3x2x128x128_st1": "MLETORCH-520: Numerical issues on FVP.",
+}
+input_t = Tuple[torch.Tensor]
+
+
+@common.parametrize("test_module", test_modules)
+def test_conv2d_tosa_MI(test_module):
+ pipeline = TosaPipelineMI[input_t](
+ test_module, test_module.get_inputs(), aten_op, exir_op
+ )
+ pipeline.run()
+
+
+@common.parametrize("test_module", test_modules)
+def test_conv2d_tosa_BI(test_module):
+ pipeline = TosaPipelineBI[input_t](
+ test_module, test_module.get_inputs(), aten_op, exir_op
+ )
+ pipeline.change_args("run_method_and_compare_outputs", qtol=1)
+ pipeline.run()
+
+
+@common.parametrize("test_module", test_modules)
+def test_conv2d_u55_BI(test_module):
+ pipeline = EthosU55PipelineBI[input_t](
+ test_module, test_module.get_inputs(), aten_op, exir_op, run_on_fvp=False
+ )
+ pipeline.run()
+
+
+@common.parametrize("test_module", test_modules)
+def test_conv2d_u85_BI(test_module):
+ pipeline = EthosU85PipelineBI[input_t](
+ test_module, test_module.get_inputs(), aten_op, exir_op, run_on_fvp=False
+ )
+ pipeline.run()
+
+
+@common.parametrize("test_module", test_modules, fvp_xfails)
+@common.SkipIfNoCorstone300
+def test_conv2d_u55_BI_on_fvp(test_module):
+ pipeline = EthosU55PipelineBI[input_t](
+ test_module, test_module.get_inputs(), aten_op, exir_op, run_on_fvp=True
+ )
+ pipeline.run()
+
+
+@common.parametrize("test_module", test_modules, fvp_xfails)
+@common.SkipIfNoCorstone320
+def test_conv2d_u85_BI_on_fvp(test_module):
+ pipeline = EthosU85PipelineBI[input_t](
+ test_module, test_module.get_inputs(), aten_op, exir_op, run_on_fvp=True
+ )
+ pipeline.run()
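
Extending this table-driven suite only requires a new entry in test_modules (the key doubles as the pytest id) and, if needed, a matching key in fvp_xfails; a hypothetical registration reusing an existing module instance:

```python
# Hypothetical extra case (names are illustrative, not part of the patch).
test_modules["two_conv2d_extra"] = two_conv2d
fvp_xfails["two_conv2d_extra"] = "Hypothetical: numerical issue on FVP."
```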
diff --git a/backends/arm/test/ops/test_conv_combos.py b/backends/arm/test/ops/test_conv_combos.py
index 8352727a1c..f6e13a2222 100644
--- a/backends/arm/test/ops/test_conv_combos.py
+++ b/backends/arm/test/ops/test_conv_combos.py
@@ -16,6 +16,7 @@
from executorch.backends.arm.test.tester.arm_tester import ArmTester
from executorch.exir.backend.backend_details import CompileSpec
from parameterized import parameterized
+from torch.nn.parameter import Parameter
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
@@ -112,12 +113,16 @@ class ComboConvBatchnormRelu6(torch.nn.Module):
"executorch_exir_dialects_edge__ops_aten_hardtanh_default",
]
- def __init__(self):
+ def __init__(self, affine: bool):
super().__init__()
self.conv2d = torch.nn.Conv2d(
in_channels=3, out_channels=3, kernel_size=3, stride=1, groups=1
)
- self.batch_norm2d = torch.nn.BatchNorm2d(3, affine=False)
+ self.batch_norm2d = torch.nn.BatchNorm2d(3, affine=affine)
+ self.batch_norm2d.running_mean = torch.rand(3)
+ self.batch_norm2d.running_var = torch.rand(3)
+ self.batch_norm2d.weight = Parameter(torch.rand(3))
+ self.batch_norm2d.bias = Parameter(torch.rand(3))
self.relu6 = torch.nn.ReLU6()
def get_inputs(self) -> Tuple[torch.Tensor]:
@@ -289,24 +294,30 @@ def test_conv_meandim_u85_BI(self):
##############################
## Conv + batch norm + relu ##
##############################
- def test_conv_batchnorm_relu6_tosa_MI(self):
- model = ComboConvBatchnormRelu6()
+ affine_params = [("affine", True), ("_no_affine", False)]
+
+ @parameterized.expand(affine_params)
+ def test_conv_batchnorm_relu6_tosa_MI(self, test_suffix, affine):
+ model = ComboConvBatchnormRelu6(affine)
self._test_conv_combo_tosa_MI_pipeline(model, model.get_inputs())
- def test_conv_batchnorm_relu6_tosa_BI(self):
- model = ComboConvBatchnormRelu6()
+ @parameterized.expand(affine_params)
+ def test_conv_batchnorm_relu6_tosa_BI(self, test_suffix, affine):
+ model = ComboConvBatchnormRelu6(affine)
self._test_conv_combo_tosa_BI_pipeline(model, model.get_inputs())
+ @parameterized.expand(affine_params)
@pytest.mark.corstone_fvp
- def test_conv_batchnorm_relu6_u55_BI(self):
- model = ComboConvBatchnormRelu6()
+ def test_conv_batchnorm_relu6_u55_BI(self, test_suffix, affine):
+ model = ComboConvBatchnormRelu6(affine)
self._test_conv_combo_ethos_BI_pipeline(
model, common.get_u55_compile_spec(), model.get_inputs()
)
+ @parameterized.expand(affine_params)
@pytest.mark.corstone_fvp
- def test_conv_batchnorm_relu_u85_BI(self):
- model = ComboConvBatchnormRelu6()
+ def test_conv_batchnorm_relu_u85_BI(self, test_suffix, affine):
+ model = ComboConvBatchnormRelu6(affine)
self._test_conv_combo_ethos_BI_pipeline(
model,
common.get_u85_compile_spec(),
@@ -353,8 +364,7 @@ def test_block_bottleneck_residual_tosa_MI(self):
model = ComboBlockBottleneckResidual()
self._test_conv_combo_tosa_MI_pipeline(model, model.get_inputs())
- # TODO: Investigate flakyness (MLTORCH-307)
- @unittest.skip(reason="Skiped due to flakyness (MLTORCH-307)")
+    @pytest.mark.flaky  # TODO: Investigate flakiness (MLTORCH-307)
def test_block_bottleneck_residual_tosa_BI(self):
model = ComboBlockBottleneckResidual()
self._test_conv_combo_tosa_BI_pipeline(model, model.get_inputs())
diff --git a/backends/arm/test/ops/test_depthwise_conv.py b/backends/arm/test/ops/test_depthwise_conv.py
index b8d69c89f1..59ce628693 100644
--- a/backends/arm/test/ops/test_depthwise_conv.py
+++ b/backends/arm/test/ops/test_depthwise_conv.py
@@ -252,8 +252,8 @@ def _test_dw_conv_ethos_BI_pipeline(
def test_dw_conv_tosa_MI(self, test_name: str, model: torch.nn.Module):
self._test_dw_conv_tosa_MI_pipeline(model, model.get_inputs())
- # TODO: Investigate flakyness (MLTORCH-307)
@parameterized.expand(testsuite_conv1d + testsuite_conv2d)
+    @pytest.mark.flaky  # TODO: Investigate flakiness (MLTORCH-307)
def test_dw_conv_tosa_BI(self, test_name: str, model: torch.nn.Module):
self._test_dw_conv_tosa_BI_pipeline(model, model.get_inputs())
diff --git a/backends/arm/test/ops/test_expand.py b/backends/arm/test/ops/test_expand.py
index 116f5d64e8..d0807f3db0 100644
--- a/backends/arm/test/ops/test_expand.py
+++ b/backends/arm/test/ops/test_expand.py
@@ -37,15 +37,17 @@ class Expand(torch.nn.Module):
test_parameters = [
(torch.rand(1), (2,)),
(torch.randn(1, 4), (1, -1)),
- (torch.rand(1, 1, 2, 2), (4, 3, -1, 2)),
(torch.randn(1), (2, 2, 4)),
- (torch.rand(3, 2, 4, 1), (-1, -1, -1, 3)),
+ (torch.randn(1, 1, 1, 5), (1, 4, -1, -1)),
(torch.randn(1, 1, 192), (1, -1, -1)),
+ (torch.randn(1, 1), (1, 2, 2, 4)),
+ (torch.randn(1, 1), (2, 2, 2, 4)),
(torch.randn(10, 1, 1, 97), (-1, 4, -1, -1)),
+ (torch.rand(1, 1, 2, 2), (4, 3, -1, 2)),
]
- def forward(self, x: torch.Tensor, multiples: Sequence):
- return x.expand(multiples)
+ def forward(self, x: torch.Tensor, m: Sequence):
+ return x.expand(m)
def _test_expand_tosa_MI_pipeline(self, module: torch.nn.Module, test_data: Tuple):
(
@@ -113,20 +115,34 @@ def test_expand_tosa_MI(self, test_input, multiples):
def test_expand_tosa_BI(self, test_input, multiples):
self._test_expand_tosa_BI_pipeline(self.Expand(), (test_input, multiples))
- # Mismatch in provided number of inputs and model signature, MLETORCH 519
- @parameterized.expand(Expand.test_parameters)
+ @parameterized.expand(Expand.test_parameters[:-3])
@pytest.mark.corstone_fvp
- @conftest.expectedFailureOnFVP
def test_expand_u55_BI(self, test_input, multiples):
self._test_expand_ethosu_BI_pipeline(
common.get_u55_compile_spec(), self.Expand(), (test_input, multiples)
)
- # Mismatch in provided number of inputs and model signature, MLETORCH 519
- @parameterized.expand(Expand.test_parameters)
+ # MLETORCH-629: Expand does not work on FVP with batch>1
+ @parameterized.expand(Expand.test_parameters[-3:])
@pytest.mark.corstone_fvp
@conftest.expectedFailureOnFVP
+ def test_expand_u55_BI_xfails(self, test_input, multiples):
+ self._test_expand_ethosu_BI_pipeline(
+ common.get_u55_compile_spec(), self.Expand(), (test_input, multiples)
+ )
+
+ @parameterized.expand(Expand.test_parameters[:-3])
+ @pytest.mark.corstone_fvp
def test_expand_u85_BI(self, test_input, multiples):
self._test_expand_ethosu_BI_pipeline(
common.get_u85_compile_spec(), self.Expand(), (test_input, multiples)
)
+
+ # MLETORCH-629: Expand does not work on FVP with batch>1
+ @parameterized.expand(Expand.test_parameters[-3:])
+ @pytest.mark.corstone_fvp
+ @conftest.expectedFailureOnFVP
+ def test_expand_u85_BI_xfails(self, test_input, multiples):
+ self._test_expand_ethosu_BI_pipeline(
+ common.get_u85_compile_spec(), self.Expand(), (test_input, multiples)
+ )
diff --git a/backends/arm/test/ops/test_full.py b/backends/arm/test/ops/test_full.py
index fc82fa4dd7..586e6bd4db 100644
--- a/backends/arm/test/ops/test_full.py
+++ b/backends/arm/test/ops/test_full.py
@@ -143,20 +143,16 @@ def test_full_tosa_MI(self, test_tensor: Tuple):
def test_full_tosa_BI(self, test_tensor: Tuple):
self._test_full_tosa_BI_pipeline(self.AddVariableFull(), test_tensor)
- # Mismatch in provided number of inputs and model signature, MLETORCH 519
@parameterized.expand(AddVariableFull.test_parameters)
@pytest.mark.corstone_fvp
- @conftest.expectedFailureOnFVP
def test_full_u55_BI(self, test_tensor: Tuple):
self._test_full_tosa_u55_pipeline(
self.AddVariableFull(),
test_tensor,
)
- # Mismatch in provided number of inputs and model signature, MLETORCH 519
@parameterized.expand(AddVariableFull.test_parameters)
@pytest.mark.corstone_fvp
- @conftest.expectedFailureOnFVP
def test_full_u85_BI(self, test_tensor: Tuple):
self._test_full_tosa_u85_pipeline(
self.AddVariableFull(),
diff --git a/backends/arm/test/ops/test_hardsigmoid.py b/backends/arm/test/ops/test_hardsigmoid.py
new file mode 100644
index 0000000000..f73a995b12
--- /dev/null
+++ b/backends/arm/test/ops/test_hardsigmoid.py
@@ -0,0 +1,128 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import unittest
+
+from typing import Tuple
+
+import pytest
+import torch
+
+from executorch.backends.arm.test import common, conftest
+from executorch.backends.arm.test.tester.arm_tester import ArmTester
+from executorch.exir.backend.compile_spec_schema import CompileSpec
+from parameterized import parameterized
+
+
+test_data_suite = [
+ # (test_name, test_data)
+ ("zeros", torch.zeros(1, 10, 10, 10)),
+ ("ones", torch.ones(10, 10, 10)),
+ ("rand", torch.rand(10, 10) - 0.5),
+ ("randn_pos", torch.randn(10) + 10),
+ ("randn_neg", torch.randn(10) - 10),
+ ("ramp", torch.arange(-16, 16, 0.2)),
+]
+
+
+class TestHardsigmoid(unittest.TestCase):
+ class Hardsigmoid(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.hardsigmoid = torch.nn.Hardsigmoid()
+
+ def forward(self, x):
+ return self.hardsigmoid(x)
+
+ def _test_hardsigmoid_tosa_MI_pipeline(
+ self, module: torch.nn.Module, test_data: Tuple[torch.Tensor]
+ ):
+ (
+ ArmTester(
+ module,
+ example_inputs=test_data,
+ compile_spec=common.get_tosa_compile_spec("TOSA-0.80+MI"),
+ )
+ .export()
+ .check(["torch.ops.aten.hardsigmoid.default"])
+ .check_not(["torch.ops.quantized_decomposed"])
+ .to_edge_transform_and_lower()
+ .check_not(["executorch_exir_dialects_edge__ops_aten_clamp_default"])
+ .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
+ .to_executorch()
+ .run_method_and_compare_outputs(inputs=test_data)
+ )
+
+ def _test_hardsigmoid_tosa_BI_pipeline(
+ self, module: torch.nn.Module, test_data: Tuple
+ ):
+ (
+ ArmTester(
+ module,
+ example_inputs=test_data,
+ compile_spec=common.get_tosa_compile_spec("TOSA-0.80+BI"),
+ )
+ .quantize()
+ .export()
+ .check(["torch.ops.aten.hardsigmoid.default"])
+ .check(["torch.ops.quantized_decomposed"])
+ .to_edge_transform_and_lower()
+ .check_not(["executorch_exir_dialects_edge__ops_aten_clamp_default"])
+ .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
+ .to_executorch()
+ .run_method_and_compare_outputs(inputs=test_data)
+ )
+
+ def _test_hardsigmoid_tosa_ethos_BI_pipeline(
+ self,
+ compile_spec: list[CompileSpec],
+ module: torch.nn.Module,
+ test_data: Tuple[torch.Tensor],
+ ):
+ tester = (
+ ArmTester(
+ module,
+ example_inputs=test_data,
+ compile_spec=compile_spec,
+ )
+ .quantize()
+ .export()
+ .check_count({"torch.ops.aten.hardsigmoid.default": 1})
+ .check(["torch.ops.quantized_decomposed"])
+ .to_edge_transform_and_lower()
+ .check_not(["executorch_exir_dialects_edge__ops_aten_clamp_default"])
+ .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
+ .to_executorch()
+ .serialize()
+ )
+ if conftest.is_option_enabled("corstone_fvp"):
+ tester.run_method_and_compare_outputs(qtol=1, inputs=test_data)
+
+ @parameterized.expand(test_data_suite)
+ def test_hardsigmoid_tosa_MI(
+ self,
+ test_name: str,
+ test_data: torch.Tensor,
+ ):
+ self._test_hardsigmoid_tosa_MI_pipeline(self.Hardsigmoid(), (test_data,))
+
+ @parameterized.expand(test_data_suite)
+ def test_hardsigmoid_tosa_BI(self, test_name: str, test_data: torch.Tensor):
+ self._test_hardsigmoid_tosa_BI_pipeline(self.Hardsigmoid(), (test_data,))
+
+ @parameterized.expand(test_data_suite)
+ @pytest.mark.corstone_fvp
+ def test_hardsigmoid_tosa_u55_BI(self, test_name: str, test_data: torch.Tensor):
+ self._test_hardsigmoid_tosa_ethos_BI_pipeline(
+ common.get_u55_compile_spec(), self.Hardsigmoid(), (test_data,)
+ )
+
+ @parameterized.expand(test_data_suite)
+ @pytest.mark.corstone_fvp
+ def test_hardsigmoid_tosa_u85_BI(self, test_name: str, test_data: torch.Tensor):
+ self._test_hardsigmoid_tosa_ethos_BI_pipeline(
+ common.get_u85_compile_spec(), self.Hardsigmoid(), (test_data,)
+ )
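Reviewer note: the check_not on the edge clamp op in this new test relies on hardsigmoid decomposing through a clamp that ends up inside the delegate. A small standalone sketch of the underlying identity (my own reference snippet, not code from this patch):

    import torch

    def hardsigmoid_ref(x: torch.Tensor) -> torch.Tensor:
        # hardsigmoid(x) == clamp(x / 6 + 1/2, 0, 1) == relu6(x + 3) / 6
        return torch.clamp(x / 6.0 + 0.5, min=0.0, max=1.0)

    x = torch.arange(-16, 16, 0.2)
    assert torch.allclose(hardsigmoid_ref(x), torch.nn.functional.hardsigmoid(x))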
diff --git a/backends/arm/test/ops/test_hardswish.py b/backends/arm/test/ops/test_hardswish.py
new file mode 100644
index 0000000000..81aba540e3
--- /dev/null
+++ b/backends/arm/test/ops/test_hardswish.py
@@ -0,0 +1,128 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import unittest
+
+from typing import Tuple
+
+import pytest
+import torch
+
+from executorch.backends.arm.test import common, conftest
+from executorch.backends.arm.test.tester.arm_tester import ArmTester
+from executorch.exir.backend.compile_spec_schema import CompileSpec
+from parameterized import parameterized
+
+
+test_data_suite = [
+ # (test_name, test_data)
+ ("zeros", torch.zeros(1, 10, 10, 10)),
+ ("ones", torch.ones(10, 10, 10)),
+ ("rand", torch.rand(10, 10) - 0.5),
+ ("randn_pos", torch.randn(10) + 10),
+ ("randn_neg", torch.randn(10) - 10),
+ ("ramp", torch.arange(-16, 16, 0.2)),
+]
+
+
+class TestHardswish(unittest.TestCase):
+ class Hardswish(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.hardswish = torch.nn.Hardswish()
+
+ def forward(self, x):
+ return self.hardswish(x)
+
+ def _test_hardswish_tosa_MI_pipeline(
+ self, module: torch.nn.Module, test_data: Tuple[torch.Tensor]
+ ):
+ (
+ ArmTester(
+ module,
+ example_inputs=test_data,
+ compile_spec=common.get_tosa_compile_spec("TOSA-0.80+MI"),
+ )
+ .export()
+ .check(["torch.ops.aten.hardswish.default"])
+ .check_not(["torch.ops.quantized_decomposed"])
+ .to_edge_transform_and_lower()
+ .check_not(["executorch_exir_dialects_edge__ops_aten_clamp_default"])
+ .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
+ .to_executorch()
+ .run_method_and_compare_outputs(inputs=test_data)
+ )
+
+ def _test_hardswish_tosa_BI_pipeline(
+ self, module: torch.nn.Module, test_data: Tuple
+ ):
+ (
+ ArmTester(
+ module,
+ example_inputs=test_data,
+ compile_spec=common.get_tosa_compile_spec("TOSA-0.80+BI"),
+ )
+ .quantize()
+ .export()
+ .check(["torch.ops.aten.hardswish.default"])
+ .check(["torch.ops.quantized_decomposed"])
+ .to_edge_transform_and_lower()
+ .check_not(["executorch_exir_dialects_edge__ops_aten_clamp_default"])
+ .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
+ .to_executorch()
+ .run_method_and_compare_outputs(inputs=test_data)
+ )
+
+ def _test_hardswish_tosa_ethos_BI_pipeline(
+ self,
+ compile_spec: list[CompileSpec],
+ module: torch.nn.Module,
+ test_data: Tuple[torch.Tensor],
+ ):
+ tester = (
+ ArmTester(
+ module,
+ example_inputs=test_data,
+ compile_spec=compile_spec,
+ )
+ .quantize()
+ .export()
+ .check_count({"torch.ops.aten.hardswish.default": 1})
+ .check(["torch.ops.quantized_decomposed"])
+ .to_edge_transform_and_lower()
+ .check_not(["executorch_exir_dialects_edge__ops_aten_clamp_default"])
+ .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
+ .to_executorch()
+ .serialize()
+ )
+ if conftest.is_option_enabled("corstone_fvp"):
+ tester.run_method_and_compare_outputs(qtol=1, inputs=test_data)
+
+ @parameterized.expand(test_data_suite)
+ def test_hardswish_tosa_MI(
+ self,
+ test_name: str,
+ test_data: torch.Tensor,
+ ):
+ self._test_hardswish_tosa_MI_pipeline(self.Hardswish(), (test_data,))
+
+ @parameterized.expand(test_data_suite)
+ def test_hardswish_tosa_BI(self, test_name: str, test_data: torch.Tensor):
+ self._test_hardswish_tosa_BI_pipeline(self.Hardswish(), (test_data,))
+
+ @parameterized.expand(test_data_suite)
+ @pytest.mark.corstone_fvp
+ def test_hardswish_tosa_u55_BI(self, test_name: str, test_data: torch.Tensor):
+ self._test_hardswish_tosa_ethos_BI_pipeline(
+ common.get_u55_compile_spec(), self.Hardswish(), (test_data,)
+ )
+
+ @parameterized.expand(test_data_suite)
+ @pytest.mark.corstone_fvp
+ def test_hardswish_tosa_u85_BI(self, test_name: str, test_data: torch.Tensor):
+ self._test_hardswish_tosa_ethos_BI_pipeline(
+ common.get_u85_compile_spec(), self.Hardswish(), (test_data,)
+ )
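Reviewer note: same remark as for hardsigmoid above; hardswish is the self-gated variant, so the clamp check applies here as well. A reference sketch of the identity (mine, not from this patch):

    import torch

    def hardswish_ref(x: torch.Tensor) -> torch.Tensor:
        # hardswish(x) == x * clamp(x / 6 + 1/2, 0, 1) == x * hardsigmoid(x)
        return x * torch.clamp(x / 6.0 + 0.5, min=0.0, max=1.0)

    x = torch.arange(-16, 16, 0.2)
    assert torch.allclose(hardswish_ref(x), torch.nn.functional.hardswish(x))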
diff --git a/backends/arm/test/ops/test_layer_norm.py b/backends/arm/test/ops/test_layer_norm.py
index c287f51ebc..82f0af8dcf 100644
--- a/backends/arm/test/ops/test_layer_norm.py
+++ b/backends/arm/test/ops/test_layer_norm.py
@@ -109,7 +109,7 @@ def _test_layernorm_tosa_BI_pipeline(
.partition()
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
.to_executorch()
- .run_method_and_compare_outputs(inputs=test_data)
+ .run_method_and_compare_outputs(qtol=1, inputs=test_data)
)
def _test_layernorm_ethosu_BI_pipeline(
diff --git a/backends/arm/test/ops/test_logsoftmax.py b/backends/arm/test/ops/test_logsoftmax.py
index d1581423a0..f34d4afbb5 100644
--- a/backends/arm/test/ops/test_logsoftmax.py
+++ b/backends/arm/test/ops/test_logsoftmax.py
@@ -6,7 +6,9 @@
import unittest
-from typing import Tuple
+from typing import Callable, Tuple
+
+import pytest
import torch
from executorch.backends.arm.test import common
@@ -15,27 +17,27 @@
from parameterized import parameterized
-test_data_suite = [
+test_data_generators = [
# (test_name, test_data, dim)
- ("zeros", torch.zeros(10, 8, 5, 2), 0),
- ("zeros_neg_dim", torch.zeros(10, 7, 8, 9), -4),
- ("ones", torch.ones(10, 10), 1),
- ("ones_neg_dim", torch.ones(10, 3, 4), -1),
- ("rand", torch.rand(1, 2, 5, 8), 2),
- ("rand_neg_dim", torch.rand(2, 10, 8, 10), -2),
- ("randn", torch.randn(10, 10, 10, 10), 3),
- ("randn_neg_dim", torch.randn(10, 5, 8, 7), -3),
+ lambda: ("zeros", torch.zeros(10, 8, 5, 2), 0),
+ lambda: ("zeros_neg_dim", torch.zeros(10, 7, 8, 9), -4),
+ lambda: ("ones", torch.ones(10, 10), 1),
+ lambda: ("ones_neg_dim", torch.ones(10, 3, 4), -1),
+ lambda: ("rand", torch.rand(1, 2, 5, 8), 2),
+ lambda: ("rand_neg_dim", torch.rand(2, 10, 8, 10), -2),
+ lambda: ("randn", torch.randn(10, 10, 10, 10), 3),
+ lambda: ("randn_neg_dim", torch.randn(10, 5, 8, 7), -3),
]
-test_data_suite_u55 = [
+test_data_generators_u55 = [
# (test_name, test_data, dim)
- ("ones", torch.ones(10, 10), 1),
- ("ones_neg_dim", torch.ones(10, 3, 4), -1),
- ("randn_neg_dim", torch.randn(10, 5, 8, 7), -3),
- ("zeros", torch.zeros(10, 8, 5, 2), 0),
- ("zeros_neg_dim", torch.zeros(10, 7, 8, 9), -4),
- ("rand", torch.rand(1, 2, 5, 8), 2),
- ("rand_neg_dim", torch.rand(2, 10, 8, 10), -2),
- ("randn", torch.randn(10, 10, 10, 10), 3),
+ lambda: ("ones", torch.ones(10, 10), 1),
+ lambda: ("ones_neg_dim", torch.ones(10, 3, 4), -1),
+ lambda: ("randn_neg_dim", torch.randn(10, 5, 8, 7), -3),
+ lambda: ("zeros", torch.zeros(10, 8, 5, 2), 0),
+ lambda: ("zeros_neg_dim", torch.zeros(10, 7, 8, 9), -4),
+ lambda: ("rand", torch.rand(1, 2, 5, 8), 2),
+ lambda: ("rand_neg_dim", torch.rand(2, 10, 8, 10), -2),
+ lambda: ("randn", torch.randn(10, 10, 10, 10), 3),
]
@@ -128,42 +130,29 @@ def _test_logsoftmax_tosa_u85_BI_pipeline(
common.get_u85_compile_spec(), module, test_data
)
- @parameterized.expand(test_data_suite)
- def test_logsoftmax_tosa_MI(
- self,
- test_name: str,
- test_data: torch.Tensor,
- dim: int,
- ):
+ @parameterized.expand(test_data_generators)
+ def test_logsoftmax_tosa_MI(self, test_data_generator: Callable[[], Tuple]):
+ test_name, test_data, dim = test_data_generator()
self._test_logsoftmax_tosa_MI_pipeline(self.LogSoftmax(dim=dim), (test_data,))
- @parameterized.expand(test_data_suite)
- def test_logsoftmax_tosa_BI(
- self,
- test_name: str,
- test_data: torch.Tensor,
- dim: int,
- ):
+ @parameterized.expand(test_data_generators)
+ @pytest.mark.flaky # TODO: MLETORCH-460 - Numerically stabler (log)softmax implementation
+ def test_logsoftmax_tosa_BI(self, test_data_generator: Callable[[], Tuple]):
+ test_name, test_data, dim = test_data_generator()
self._test_logsoftmax_tosa_BI_pipeline(self.LogSoftmax(dim=dim), (test_data,))
- @parameterized.expand(test_data_suite_u55)
- def test_logsoftmax_tosa_u55_BI(
- self,
- test_name: str,
- test_data: torch.Tensor,
- dim: int,
- ):
+ @parameterized.expand(test_data_generators_u55)
+ @pytest.mark.flaky # TODO: MLETORCH-460 - Numerically stabler (log)softmax implementation
+ def test_logsoftmax_tosa_u55_BI(self, test_data_generator: Callable[[], Tuple]):
+ test_name, test_data, dim = test_data_generator()
self._test_logsoftmax_tosa_u55_BI_pipeline(
self.LogSoftmax(dim=dim), (test_data,)
)
- @parameterized.expand(test_data_suite)
- def test_logsoftmax_tosa_u85_BI(
- self,
- test_name: str,
- test_data: torch.Tensor,
- dim: int,
- ):
+ @parameterized.expand(test_data_generators)
+ @pytest.mark.flaky # TODO: MLETORCH-460 - Numerically stabler (log)softmax implementation
+ def test_logsoftmax_tosa_u85_BI(self, test_data_generator: Callable[[], Tuple]):
+ test_name, test_data, dim = test_data_generator()
self._test_logsoftmax_tosa_u85_BI_pipeline(
self.LogSoftmax(dim=dim), (test_data,)
)
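Reviewer note: the move from tensor tuples to zero-argument lambdas defers random data creation from module import to test execution, which is what makes the pytest.mark.flaky retries useful: each rerun draws a fresh sample instead of replaying the same unlucky one. A minimal self-contained sketch of the pattern (names below are illustrative, not from this file):

    import unittest

    import torch
    from parameterized import parameterized

    test_data_generators = [
        # Each entry is a zero-argument callable; the tensor is created only
        # when the test body calls it, not when the module is imported.
        lambda: ("rand", torch.rand(2, 5), 1),
        lambda: ("randn_neg_dim", torch.randn(3, 4), -1),
    ]

    class LogSoftmaxShapeDemo(unittest.TestCase):
        @parameterized.expand(test_data_generators)
        def test_shape(self, test_data_generator):
            _, data, dim = test_data_generator()
            out = torch.nn.functional.log_softmax(data, dim=dim)
            self.assertEqual(out.shape, data.shape)

    if __name__ == "__main__":
        unittest.main()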
diff --git a/backends/arm/test/ops/test_maximum.py b/backends/arm/test/ops/test_maximum.py
index 1fe2c20148..a255496d51 100644
--- a/backends/arm/test/ops/test_maximum.py
+++ b/backends/arm/test/ops/test_maximum.py
@@ -109,7 +109,6 @@ def test_maximum_tosa_BI(self, operand1: torch.Tensor, operand2: torch.Tensor):
self._test_maximum_tosa_BI_pipeline(self.Maximum(), test_data)
@parameterized.expand(Maximum.test_parameters)
- @unittest.expectedFailure # Bug in Vela, disabled until pin changes, bug MLETORCH-513
def test_maximum_u55_BI(self, operand1: torch.Tensor, operand2: torch.Tensor):
test_data = (operand1, operand2)
tester = self._test_maximum_ethos_BI_pipeline(
diff --git a/backends/arm/test/ops/test_minimum.py b/backends/arm/test/ops/test_minimum.py
index d455ca1d43..04693a4643 100644
--- a/backends/arm/test/ops/test_minimum.py
+++ b/backends/arm/test/ops/test_minimum.py
@@ -109,7 +109,6 @@ def test_minimum_tosa_BI(self, operand1: torch.Tensor, operand2: torch.Tensor):
self._test_minimum_tosa_BI_pipeline(self.Minimum(), test_data)
@parameterized.expand(Minimum.test_parameters)
- @unittest.expectedFailure # Bug in Vela, disabled until pin changes, bug MLETORCH-513
def test_minimum_u55_BI(self, operand1: torch.Tensor, operand2: torch.Tensor):
test_data = (operand1, operand2)
tester = self._test_minimum_ethos_BI_pipeline(
diff --git a/backends/arm/test/ops/test_mm.py b/backends/arm/test/ops/test_mm.py
index 5fa28076aa..d9b58da904 100644
--- a/backends/arm/test/ops/test_mm.py
+++ b/backends/arm/test/ops/test_mm.py
@@ -1,4 +1,4 @@
-# Copyright 2024 Arm Limited and/or its affiliates.
+# Copyright 2024-2025 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
@@ -7,8 +7,9 @@
import logging
import unittest
-from typing import Tuple
+from typing import Callable, Tuple
+import pytest
import torch
from executorch.backends.arm.test import common
from executorch.backends.arm.test.tester.arm_tester import ArmTester
@@ -18,30 +19,28 @@
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
-torch.manual_seed(0)
-
class TestMM(unittest.TestCase):
"""Tests MatMul"""
class MM(torch.nn.Module):
- test_parameters = [
- (torch.rand(3, 5), torch.rand(5, 2)),
- (torch.rand(1, 1), torch.rand(1, 1)),
- (torch.ones(55, 3), torch.ones(3, 44)),
- (10000 * torch.randn(1, 10), torch.randn(10, 5)),
- (-10 * torch.randn(32, 64), 5 + 5 * torch.randn(64, 32)),
+ test_data_generators = [
+ lambda: (torch.rand(3, 5), torch.rand(5, 2)),
+ lambda: (torch.rand(1, 1), torch.rand(1, 1)),
+ lambda: (torch.ones(55, 3), torch.ones(3, 44)),
+ lambda: (10000 * torch.randn(1, 10), torch.randn(10, 5)),
+ lambda: (-10 * torch.randn(32, 64), 5 + 5 * torch.randn(64, 32)),
]
def forward(self, x, y):
return torch.mm(x, y)
class MMSingleInput(torch.nn.Module):
- test_parameters = [
- (torch.rand(3, 3),),
- (torch.ones(128, 128),),
- (10000 * torch.randn(25, 25),),
- (5 + 5 * torch.randn(64, 64),),
+ test_data_generators = [
+ lambda: (torch.rand(3, 3),),
+ lambda: (torch.ones(128, 128),),
+ lambda: (10000 * torch.randn(25, 25),),
+ lambda: (5 + 5 * torch.randn(64, 64),),
]
def forward(self, x):
@@ -110,54 +109,55 @@ def _test_mm_ethosu_BI_pipeline(
.to_executorch()
)
- @parameterized.expand(MM.test_parameters)
- def test_mm_tosa_MI(self, operand1: torch.Tensor, operand2: torch.Tensor):
- test_data = (operand1, operand2)
+ @parameterized.expand(MM.test_data_generators)
+ def test_mm_tosa_MI(self, test_data_generator: Callable[[], Tuple]):
+ test_data = test_data_generator()
self._test_mm_tosa_MI_pipeline(self.MM(), test_data)
- @parameterized.expand(MMSingleInput.test_parameters)
- def test_mm_single_input_tosa_MI(self, operand1: torch.Tensor):
- test_data = (operand1,)
+ @parameterized.expand(MMSingleInput.test_data_generators)
+ def test_mm_single_input_tosa_MI(self, test_data_generator: Callable[[], Tuple]):
+ test_data = test_data_generator()
self._test_mm_tosa_MI_pipeline(self.MMSingleInput(), test_data)
- @parameterized.expand(MM.test_parameters)
- def test_mm_tosa_BI(self, operand1: torch.Tensor, operand2: torch.Tensor):
- test_data = (operand1, operand2)
+ @parameterized.expand(MM.test_data_generators)
+ @pytest.mark.flaky # TODO: Investigate flakiness (MLETORCH-534)
+ def test_mm_tosa_BI(self, test_data_generator: Callable[[], Tuple]):
+ test_data = test_data_generator()
self._test_mm_tosa_BI_pipeline(self.MM(), test_data)
- @parameterized.expand(MMSingleInput.test_parameters)
- def test_mm_single_input_tosa_BI(self, operand1: torch.Tensor):
- test_data = (operand1,)
+ @parameterized.expand(MMSingleInput.test_data_generators)
+ def test_mm_single_input_tosa_BI(self, test_data_generator: Callable[[], Tuple]):
+ test_data = test_data_generator()
self._test_mm_tosa_BI_pipeline(self.MMSingleInput(), test_data)
# Expected to fail with error: CPU performance estimation for "MatMul" not implemented
- @parameterized.expand(MM.test_parameters)
+ @parameterized.expand(MM.test_data_generators)
@unittest.expectedFailure
- def test_mm_u55_BI(self, operand1: torch.Tensor, operand2: torch.Tensor):
- test_data = (operand1, operand2)
+ def test_mm_u55_BI(self, test_data_generator: Callable[[], Tuple]):
+ test_data = test_data_generator()
self._test_mm_ethosu_BI_pipeline(
common.get_u55_compile_spec(), self.MM(), test_data
)
# Expected to fail with error: Warning, unsupported fusing of TOSA Rescale previous operator is of type: Memcpy
- @parameterized.expand(MMSingleInput.test_parameters)
+ @parameterized.expand(MMSingleInput.test_data_generators)
@unittest.expectedFailure
- def test_mm_single_input_u55_BI(self, operand1: torch.Tensor):
- test_data = (operand1,)
+ def test_mm_single_input_u55_BI(self, test_data_generator: Callable[[], Tuple]):
+ test_data = test_data_generator()
self._test_mm_ethosu_BI_pipeline(
common.get_u55_compile_spec(), self.MMSingleInput(), test_data
)
- @parameterized.expand(MM.test_parameters)
- def test_mm_u85_BI(self, operand1: torch.Tensor, operand2: torch.Tensor):
- test_data = (operand1, operand2)
+ @parameterized.expand(MM.test_data_generators)
+ def test_mm_u85_BI(self, test_data_generator: Callable[[], Tuple]):
+ test_data = test_data_generator()
self._test_mm_ethosu_BI_pipeline(
common.get_u85_compile_spec(), self.MM(), test_data
)
- @parameterized.expand(MMSingleInput.test_parameters)
- def test_mm_single_input_u85_BI(self, operand1: torch.Tensor):
- test_data = (operand1,)
+ @parameterized.expand(MMSingleInput.test_data_generators)
+ def test_mm_single_input_u85_BI(self, test_data_generator: Callable[[], Tuple]):
+ test_data = test_data_generator()
self._test_mm_ethosu_BI_pipeline(
common.get_u85_compile_spec(), self.MMSingleInput(), test_data
)
diff --git a/backends/arm/test/ops/test_softmax.py b/backends/arm/test/ops/test_softmax.py
index 794f6b791f..c60da18594 100644
--- a/backends/arm/test/ops/test_softmax.py
+++ b/backends/arm/test/ops/test_softmax.py
@@ -7,7 +7,9 @@
import unittest
-from typing import Tuple
+from typing import Callable, Tuple
+
+import pytest
import torch
from executorch.backends.arm.test import common
@@ -16,28 +18,28 @@
from parameterized import parameterized
-test_data_suite = [
+test_data_generators = [
# (test_name, test_data, dim)
- ("zeros", torch.zeros(10, 8, 5, 2), 0),
- ("zeros_neg_dim", torch.zeros(10, 7, 8, 9), -4),
- ("ones", torch.ones(10, 10), 1),
- ("ones_neg_dim", torch.ones(10, 3, 4), -1),
- ("rand", torch.rand(1, 2, 5, 8), 2),
- ("rand_neg_dim", torch.rand(2, 10, 8, 10), -2),
- ("randn", torch.randn(10, 10, 10, 10), 3),
- ("randn_neg_dim", torch.randn(10, 5, 8, 7), -3),
+ lambda: ("zeros", torch.zeros(10, 8, 5, 2), 0),
+ lambda: ("zeros_neg_dim", torch.zeros(10, 7, 8, 9), -4),
+ lambda: ("ones", torch.ones(10, 10), 1),
+ lambda: ("ones_neg_dim", torch.ones(10, 3, 4), -1),
+ lambda: ("rand", torch.rand(1, 2, 5, 8), 2),
+ lambda: ("rand_neg_dim", torch.rand(2, 10, 8, 10), -2),
+ lambda: ("randn", torch.randn(10, 10, 10, 10), 3),
+ lambda: ("randn_neg_dim", torch.randn(10, 5, 8, 7), -3),
]
-test_data_suite_u55 = [
+test_data_generators_u55 = [
# (test_name, test_data, dim)
- ("ones", torch.ones(10, 10), 1),
- ("ones_neg_dim", torch.ones(10, 3, 4), -1),
- ("randn_neg_dim", torch.randn(10, 5, 8, 7), -3),
- ("zeros", torch.zeros(10, 8, 5, 2), 0),
- ("zeros_neg_dim", torch.zeros(10, 7, 8, 9), -4),
- ("rand", torch.rand(1, 2, 5, 8), 2),
- ("rand_neg_dim", torch.rand(2, 10, 8, 10), -2),
- ("randn", torch.randn(10, 10, 10, 10), 3),
+ lambda: ("ones", torch.ones(10, 10), 1),
+ lambda: ("ones_neg_dim", torch.ones(10, 3, 4), -1),
+ lambda: ("randn_neg_dim", torch.randn(10, 5, 8, 7), -3),
+ lambda: ("zeros", torch.zeros(10, 8, 5, 2), 0),
+ lambda: ("zeros_neg_dim", torch.zeros(10, 7, 8, 9), -4),
+ lambda: ("rand", torch.rand(1, 2, 5, 8), 2),
+ lambda: ("rand_neg_dim", torch.rand(2, 10, 8, 10), -2),
+ lambda: ("randn", torch.randn(10, 10, 10, 10), 3),
]
@@ -130,38 +132,25 @@ def _test_softmax_tosa_u85_BI_pipeline(
common.get_u85_compile_spec(), module, test_data
)
- @parameterized.expand(test_data_suite)
- def test_softmax_tosa_MI(
- self,
- test_name: str,
- test_data: torch.Tensor,
- dim: int,
- ):
+ @parameterized.expand(test_data_generators)
+ def test_softmax_tosa_MI(self, test_data_generator: Callable[[], Tuple]):
+ test_name, test_data, dim = test_data_generator()
self._test_softmax_tosa_MI_pipeline(self.Softmax(dim=dim), (test_data,))
- @parameterized.expand(test_data_suite)
- def test_softmax_tosa_BI(
- self,
- test_name: str,
- test_data: torch.Tensor,
- dim: int,
- ):
+ @parameterized.expand(test_data_generators)
+ @pytest.mark.flaky # TODO: MLETORCH-460 - Numerically stabler (log)softmax implementation
+ def test_softmax_tosa_BI(self, test_data_generator: Callable[[], Tuple]):
+ test_name, test_data, dim = test_data_generator()
self._test_softmax_tosa_BI_pipeline(self.Softmax(dim=dim), (test_data,))
- @parameterized.expand(test_data_suite_u55)
- def test_softmax_tosa_u55_BI(
- self,
- test_name: str,
- test_data: torch.Tensor,
- dim: int,
- ):
+ @parameterized.expand(test_data_generators_u55)
+ @pytest.mark.flaky # TODO: MLETORCH-460 - Numerically stabler (log)softmax implementation
+ def test_softmax_tosa_u55_BI(self, test_data_generator: Callable[[], Tuple]):
+ test_name, test_data, dim = test_data_generator()
self._test_softmax_tosa_u55_BI_pipeline(self.Softmax(dim=dim), (test_data,))
- @parameterized.expand(test_data_suite)
- def test_softmax_tosa_u85_BI(
- self,
- test_name: str,
- test_data: torch.Tensor,
- dim: int,
- ):
+ @parameterized.expand(test_data_generators)
+ @pytest.mark.flaky # TODO: MLETORCH-460 - Numerically stabler (log)softmax implementation
+ def test_softmax_tosa_u85_BI(self, test_data_generator: Callable[[], Tuple]):
+ test_name, test_data, dim = test_data_generator()
self._test_softmax_tosa_u85_BI_pipeline(self.Softmax(dim=dim), (test_data,))
diff --git a/backends/arm/test/passes/test_cast_int64_pass.py b/backends/arm/test/passes/test_cast_int64_pass.py
new file mode 100644
index 0000000000..fdfab1f3af
--- /dev/null
+++ b/backends/arm/test/passes/test_cast_int64_pass.py
@@ -0,0 +1,44 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import unittest
+
+import torch
+from executorch.backends.arm._passes.cast_int64_pass import CastInt64ToInt32Pass
+
+from executorch.backends.arm.test import common
+
+from executorch.backends.arm.test.tester.arm_tester import ArmTester, RunPasses
+
+
+class Int64Model(torch.nn.Module):
+
+ def forward(self, x: torch.Tensor):
+ return x + 3
+
+ def get_inputs(self):
+ return (torch.rand(4),)
+
+
+class TestCastInt64Pass(unittest.TestCase):
+
+ def test_int64_model(self):
+ module = Int64Model()
+ test_pass_stage = RunPasses(passes_with_exported_program=[CastInt64ToInt32Pass])
+ tester = (
+ ArmTester(
+ module,
+ example_inputs=module.get_inputs(),
+ compile_spec=common.get_tosa_compile_spec("TOSA-0.80+BI"),
+ )
+ .export()
+ .to_edge()
+ .run_passes(test_pass_stage)
+ .run_method_and_compare_outputs()
+ )
+ exported_program = tester.get_artifact("RunPasses").exported_program()
+ for state in exported_program.state_dict:
+ assert exported_program.state_dict[state].dtype != torch.int64
diff --git a/backends/arm/test/passes/test_fuse_batchnorm_pass.py b/backends/arm/test/passes/test_fuse_batchnorm_pass.py
new file mode 100644
index 0000000000..09f8f578fc
--- /dev/null
+++ b/backends/arm/test/passes/test_fuse_batchnorm_pass.py
@@ -0,0 +1,158 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+import unittest
+
+import torch
+from executorch.backends.arm._passes.fuse_batchnorm2d_pass import FuseBatchnorm2DPass
+from executorch.backends.arm.test import common
+from executorch.backends.arm.test.tester.arm_tester import ArmTester, RunPasses
+from parameterized import parameterized
+
+
+class MergeOneOfTwoBN(torch.nn.Module):
+ ops_before_pass = {
+ "executorch_exir_dialects_edge__ops_aten__native_batch_norm_legit_no_training_default": 2,
+ "executorch_exir_dialects_edge__ops_aten_convolution_default": 1,
+ }
+ ops_after_pass = {
+ "executorch_exir_dialects_edge__ops_aten__native_batch_norm_legit_no_training_default": 1,
+ "executorch_exir_dialects_edge__ops_aten_convolution_default": 1,
+ }
+
+ def __init__(self, affine: bool):
+ super().__init__()
+ self.conv2d = torch.nn.Conv2d(
+ in_channels=3, out_channels=3, kernel_size=3, stride=1, groups=1
+ )
+ self.batch_norm2d = torch.nn.BatchNorm2d(3, affine=affine)
+ self.batch_norm2d.running_mean = torch.rand(3)
+ self.batch_norm2d.running_var = torch.rand(3)
+ if affine:
+ self.batch_norm2d.weight = torch.nn.Parameter(torch.rand(3))
+ self.batch_norm2d.bias = torch.nn.Parameter(torch.rand(3))
+ self.relu6 = torch.nn.ReLU6()
+
+ def get_inputs(self) -> tuple[torch.Tensor]:
+ return (torch.randn(1, 3, 256, 256),)
+
+ def forward(self, x):
+ x = self.conv2d(x)
+ x = self.batch_norm2d(x)
+ x = self.relu6(x)
+ x = self.batch_norm2d(x)
+ return x
+
+
+class MergeTwosOfTwoBN(torch.nn.Module):
+ ops_before_pass = {
+ "executorch_exir_dialects_edge__ops_aten__native_batch_norm_legit_no_training_default": 2,
+ "executorch_exir_dialects_edge__ops_aten_convolution_default": 2,
+ }
+ ops_after_pass = {
+ "executorch_exir_dialects_edge__ops_aten__native_batch_norm_legit_no_training_default": 0,
+ "executorch_exir_dialects_edge__ops_aten_convolution_default": 2,
+ }
+
+ def __init__(self, affine: bool):
+ super().__init__()
+ self.conv2d = torch.nn.Conv2d(
+ in_channels=3, out_channels=3, kernel_size=3, stride=1, groups=1
+ )
+ self.conv2d2 = torch.nn.Conv2d(
+ in_channels=3, out_channels=3, kernel_size=3, stride=1, groups=1
+ )
+ self.batch_norm2d = torch.nn.BatchNorm2d(3, affine=affine)
+ self.batch_norm2d.running_mean = torch.rand(3)
+ self.batch_norm2d.running_var = torch.rand(3)
+ if affine:
+ self.batch_norm2d.weight = torch.nn.Parameter(torch.rand(3))
+ self.batch_norm2d.bias = torch.nn.Parameter(torch.rand(3))
+ self.relu6 = torch.nn.ReLU6()
+
+ def get_inputs(self) -> tuple[torch.Tensor]:
+ return (torch.randn(1, 3, 256, 256),)
+
+ def forward(self, x):
+ x = self.conv2d(x)
+ x = self.batch_norm2d(x)
+ x = self.relu6(x)
+ x = self.conv2d2(x)
+ x = self.batch_norm2d(x)
+ return x
+
+
+class MergeNoBN(torch.nn.Module):
+ ops_before_pass = {
+ "executorch_exir_dialects_edge__ops_aten__native_batch_norm_legit_no_training_default": 2,
+ "executorch_exir_dialects_edge__ops_aten_convolution_default": 3,
+ }
+ ops_after_pass = {
+ "executorch_exir_dialects_edge__ops_aten__native_batch_norm_legit_no_training_default": 2,
+ "executorch_exir_dialects_edge__ops_aten_convolution_default": 3,
+ }
+
+ def __init__(self, affine: bool):
+ super().__init__()
+ self.conv2d = torch.nn.Conv2d(
+ in_channels=3, out_channels=3, kernel_size=3, stride=1, groups=1
+ )
+ self.conv2d2 = torch.nn.Conv2d(
+ in_channels=3, out_channels=3, kernel_size=3, stride=1, groups=1
+ )
+ self.batch_norm2d = torch.nn.BatchNorm2d(3, affine=affine)
+ self.batch_norm2d.running_mean = torch.rand(3)
+ self.batch_norm2d.running_var = torch.rand(3)
+ if affine:
+ self.batch_norm2d.weight = torch.nn.Parameter(torch.rand(3))
+ self.batch_norm2d.bias = torch.nn.Parameter(torch.rand(3))
+ self.relu6 = torch.nn.ReLU6()
+
+ def get_inputs(self) -> tuple[torch.Tensor]:
+ return (torch.randn(1, 3, 256, 256),)
+
+ def forward(self, x):
+ x1 = self.conv2d(x)
+ x = self.batch_norm2d(x1) # Can't be fused since x1 has multiple users
+ x = self.relu6(x)
+ y = self.conv2d2(x1)
+ z = self.conv2d2(x)
+ a = self.batch_norm2d(
+ y
+ ) # Can't be fused since parameters of conv2d2 have multiple users.
+
+ return z, a
+
+
+modules = [
+ MergeOneOfTwoBN(True),
+ MergeOneOfTwoBN(False),
+ MergeTwosOfTwoBN(True),
+ MergeNoBN(True),
+]
+
+
+class TestFuseBatchnormPass(unittest.TestCase):
+
+ @parameterized.expand(modules)
+ def test_fuse_batchnorm_tosa_MI(self, module):
+ """Test various cases where the batchnorm should and shouldn't be fused."""
+ inputs = module.get_inputs()
+ test_pass_stage = RunPasses(passes_with_exported_program=[FuseBatchnorm2DPass])
+ (
+ (
+ ArmTester(
+ module,
+ example_inputs=inputs,
+ compile_spec=common.get_tosa_compile_spec("TOSA-0.80+MI"),
+ )
+ .export()
+ .to_edge()
+ .check_count(module.ops_before_pass)
+ .run_passes(test_pass_stage)
+ .check_count(module.ops_after_pass)
+ .run_method_and_compare_outputs()
+ )
+ )
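Reviewer note: for readers new to the pass under test, the op-count assertions above rest on the standard conv+batchnorm folding identity. A numeric sketch of that identity (my own illustration, not the FuseBatchnorm2DPass implementation):

    import torch

    conv = torch.nn.Conv2d(3, 3, kernel_size=3, bias=True)
    bn = torch.nn.BatchNorm2d(3).eval()  # eval mode: uses running stats
    bn.running_mean = torch.rand(3)
    bn.running_var = torch.rand(3) + 0.1  # keep the variance away from zero

    with torch.no_grad():
        # Fold the batchnorm into the convolution's weight and bias.
        scale = bn.weight / torch.sqrt(bn.running_var + bn.eps)
        fused = torch.nn.Conv2d(3, 3, kernel_size=3, bias=True)
        fused.weight.copy_(conv.weight * scale[:, None, None, None])
        fused.bias.copy_((conv.bias - bn.running_mean) * scale + bn.bias)

        x = torch.randn(1, 3, 16, 16)
        assert torch.allclose(bn(conv(x)), fused(x), atol=1e-5)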
diff --git a/backends/arm/test/passes/test_insert_table_ops_pass.py b/backends/arm/test/passes/test_insert_table_ops_pass.py
new file mode 100644
index 0000000000..c0a9235fa6
--- /dev/null
+++ b/backends/arm/test/passes/test_insert_table_ops_pass.py
@@ -0,0 +1,55 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import unittest
+
+import torch
+from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
+ FoldAndAnnotateQParamsPass,
+)
+from executorch.backends.arm._passes.insert_table_ops import InsertTableOpsPass
+
+from executorch.backends.arm.test import common
+
+from executorch.backends.arm.test.tester.arm_tester import ArmTester, RunPasses
+
+
+class Sigmoid(torch.nn.Module):
+
+ def forward(self, x: torch.Tensor):
+ return x.sigmoid()
+
+ def get_inputs(self):
+ return (torch.rand(4),)
+
+
+class TestInsertTablePass(unittest.TestCase):
+
+ def test_insert_table_tosa_BI(self):
+ module = Sigmoid()
+ test_pass_stage = RunPasses(
+ [FoldAndAnnotateQParamsPass],
+ passes_with_exported_program=[InsertTableOpsPass],
+ )
+ (
+ ArmTester(
+ module,
+ example_inputs=module.get_inputs(),
+ compile_spec=common.get_tosa_compile_spec("TOSA-0.80+BI"),
+ )
+ .quantize()
+ .export()
+ .to_edge()
+ .run_passes(test_pass_stage)
+ .check("tosa._table")
+ .check_count(
+ {
+ "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 1,
+ "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 1,
+ }
+ )
+ .check_not(["aten_sigmoid_default"])
+ )
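Reviewer note: the check("tosa._table") above refers to the lookup-table lowering of quantized unary ops. A rough numeric illustration of the idea (the scales and zero-points are made-up values, not what InsertTableOpsPass actually emits):

    import torch

    in_scale, in_zp = 0.05, 0  # assumed input quantization parameters
    out_scale, out_zp = 1.0 / 256, -128  # assumed output quantization parameters

    # Build a 256-entry table: dequantize every possible int8 input value,
    # apply sigmoid, then requantize the result.
    qin = torch.arange(-128, 128, dtype=torch.float32)
    table = torch.clamp(
        torch.round(torch.sigmoid((qin - in_zp) * in_scale) / out_scale) + out_zp,
        -128,
        127,
    ).to(torch.int8)

    def quantized_sigmoid(q: torch.Tensor) -> torch.Tensor:
        # Index by the int8 value shifted into [0, 255].
        return table[q.to(torch.int64) + 128]

    q = torch.randint(-128, 128, (4,), dtype=torch.int8)
    print(quantized_sigmoid(q))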
diff --git a/backends/arm/test/passes/test_ioquantization_pass.py b/backends/arm/test/passes/test_ioquantization_pass.py
new file mode 100644
index 0000000000..e31007f1ed
--- /dev/null
+++ b/backends/arm/test/passes/test_ioquantization_pass.py
@@ -0,0 +1,70 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import unittest
+
+import torch
+
+from executorch.backends.arm.test import common
+
+from executorch.backends.arm.test.tester.arm_tester import ArmTester
+from executorch.exir.passes.quantize_io_pass import QuantizeInputs, QuantizeOutputs
+
+
+class SimpleModel(torch.nn.Module):
+ def forward(self, x, y):
+ return x + y
+
+ def get_inputs(self):
+ a = torch.rand(1, 2, 2, 1)
+ b = torch.rand(1, 2, 2, 1)
+ return (a, b)
+
+
+class TestIOQuantizationPass(unittest.TestCase):
+ """
+ Test that the executorch/exir/passes/quantize_io_pass pass works (meaning we don't get Q/DQ nodes) on a simple model.
+ """
+
+ def test_ioquantisation_pass(self):
+ model = SimpleModel()
+ tester = (
+ ArmTester(
+ model,
+ example_inputs=model.get_inputs(),
+ compile_spec=common.get_u55_compile_spec(),
+ )
+ .quantize()
+ .export()
+ .to_edge()
+ .check_count(
+ {
+ "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 3
+ }
+ )
+ .check_count(
+ {
+ "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 3
+ }
+ )
+ .partition()
+ .check_count(
+ {
+ "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 2
+ }
+ )
+ .check_count(
+ {
+ "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 1
+ }
+ )
+ )
+ edge = tester.get_artifact()
+ edge.transform(
+ passes=[QuantizeInputs(edge, [0, 1]), QuantizeOutputs(edge, [0])]
+ )
+ tester.check_not(["edge__ops_quantized_decomposed_quantize_per_tensor"])
+ tester.check_not(["edge__ops_quantized_decomposed_dequantize_per_tensor"])
diff --git a/backends/arm/test/runner_utils.py b/backends/arm/test/runner_utils.py
index 6996d53e91..577e114be0 100644
--- a/backends/arm/test/runner_utils.py
+++ b/backends/arm/test/runner_utils.py
@@ -65,16 +65,7 @@ def get_input_names(program: ExportedProgram) -> list[str]:
Returns:
A list of strings with the names of the model input.
"""
- input_names = []
-
- # E.g. bias and weights are 'placeholders' as well. This is used to
- # get only the use inputs.
- usr_inputs = program.graph_signature.user_inputs
- for node in program.graph.nodes:
- if node.op == "placeholder" and node.name in usr_inputs:
- input_names.append(node.name)
-
- return input_names
+ return [spec.arg.name for spec in program.graph_signature.input_specs]
def get_input_quantization_params(
@@ -178,7 +169,7 @@ def _tosa_dispatch(self, lowered_backend_module: LoweredBackendModule, inputs):
return run_tosa_graph(tosa_buffer, tosa_version, inputs)
def __torch_function__(self, func, types, args=..., kwargs=None):
- if isinstance(func, torch._higher_order_ops.executorch_call_delegate.ExecutorchCallDelegate): # type: ignore
+ if func is torch._higher_order_ops.executorch_call_delegate:
lowered_backend_module = cast(LoweredBackendModule, args[0])
if lowered_backend_module.backend_id == "ArmBackend":
return self._tosa_dispatch(lowered_backend_module, args[1:])
@@ -334,13 +325,16 @@ def run_corstone(
def prep_data_for_save(
- data: torch.Tensor,
+ data,
input_name: str,
quant_param: Optional[QuantizationParams] = None,
):
- data_np = np.array(data.detach(), order="C").astype(
- torch_to_numpy_dtype_dict[data.dtype]
- )
+ if isinstance(data, torch.Tensor):
+ data_np = np.array(data.detach(), order="C").astype(
+ torch_to_numpy_dtype_dict[data.dtype]
+ )
+ else:
+ data_np = np.array(data)
if quant_param is not None:
assert quant_param.node_name in input_name, (
f"The quantization params name '{quant_param.node_name}' does not "
@@ -492,6 +486,47 @@ def _tosa_refmodel_loglevel(loglevel: int) -> str:
return loglevel_map[clamped_logging_level]
+def corstone300_installed() -> bool:
+ cmd = ["FVP_Corstone_SSE-300_Ethos-U55", "--version"]
+ try:
+ _run_cmd(cmd, check=True)
+ except Exception:
+ return False
+ return True
+
+
+def corstone320_installed() -> bool:
+ cmd = ["FVP_Corstone_SSE-320", "--version"]
+ try:
+ _run_cmd(cmd, check=True)
+ except Exception:
+ return False
+ return True
+
+
+def get_elf_path(target_board):
+ elf_path = os.path.join(
+ "cmake-out",
+ f"arm_semihosting_executor_runner_{target_board}",
+ "arm_executor_runner",
+ )
+ if not os.path.exists(elf_path):
+ raise RuntimeError(
+ f"Did not find build arm_executor_runner in path {elf_path}, run setup_testing.sh?"
+ )
+ else:
+ return elf_path
+
+
+def arm_executor_runner_exists(target_board):
+ try:
+ get_elf_path(target_board)
+ except Exception:
+ return False
+ else:
+ return True
+
+
def run_tosa_graph(
graph: TosaGraph,
tosa_version: TosaSpecification,
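Reviewer note: the new corstone*_installed and arm_executor_runner_exists helpers look intended for gating FVP tests. A hypothetical skip-marker sketch of how they could be wired up (the "corstone-300" target-board string and the marker name are my assumptions, not part of this patch):

    import pytest

    from executorch.backends.arm.test import runner_utils

    requires_corstone300 = pytest.mark.skipif(
        not (
            runner_utils.corstone300_installed()
            and runner_utils.arm_executor_runner_exists("corstone-300")
        ),
        reason="Corstone-300 FVP or arm_executor_runner build not available",
    )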
diff --git a/backends/arm/test/tester/arm_tester.py b/backends/arm/test/tester/arm_tester.py
index 639cea5bae..11e7d86304 100644
--- a/backends/arm/test/tester/arm_tester.py
+++ b/backends/arm/test/tester/arm_tester.py
@@ -8,11 +8,11 @@
import os
from collections import Counter
from pprint import pformat
-from typing import Iterable, List, Optional, Tuple, Union
+from typing import Callable, Iterable, List, Optional, Tuple, Type, Union
import executorch.backends.xnnpack.test.tester.tester as tester
-import serializer.tosa_serializer as ts
+import serializer.tosa_serializer as ts # type: ignore[import-untyped]
import torch.fx
import torch.utils._pytree as pytree
@@ -25,6 +25,7 @@
)
from executorch.backends.arm.test.runner_utils import (
dbg_tosa_fb_to_json,
+ get_elf_path,
get_output_nodes,
get_output_quantization_params,
get_target_board,
@@ -41,10 +42,18 @@
from executorch.backends.xnnpack.test.tester import Tester
from executorch.devtools.backend_debug import get_delegation_info
-from executorch.exir import EdgeCompileConfig, ExecutorchProgramManager
+from executorch.exir import (
+ EdgeCompileConfig,
+ EdgeProgramManager,
+ ExecutorchProgramManager,
+ ExportedProgram,
+)
+from executorch.exir.backend.backend_api import validation_disabled
from executorch.exir.backend.compile_spec_schema import CompileSpec
from executorch.exir.backend.partitioner import Partitioner
from executorch.exir.lowered_backend_module import LoweredBackendModule
+from executorch.exir.pass_base import ExportPass
+from executorch.exir.program._program import _update_exported_program_graph_module
from tabulate import tabulate
from torch.export.graph_signature import ExportGraphSignature, InputSpec, OutputSpec
@@ -132,11 +141,8 @@ def run_artifact(self, inputs):
inputs_flattened, _ = tree_flatten(inputs)
intermediate_path = get_intermediate_path(self.compile_spec)
target_board = get_target_board(self.compile_spec)
- elf_path = os.path.join(
- "cmake-out",
- f"arm_semihosting_executor_runner_{target_board}",
- "arm_executor_runner",
- )
+ elf_path = get_elf_path(target_board)
+
if not os.path.exists(elf_path):
raise FileNotFoundError(
f"Did not find build arm_executor_runner in path {elf_path}, run setup_testing.sh?"
@@ -158,6 +164,44 @@ def run_artifact(self, inputs):
return super().run_artifact(inputs)
+class RunPasses(tester.RunPasses):
+
+ def __init__(
+ self,
+ pass_list: Optional[List[Type[ExportPass]]] = None,
+ pass_functions: Optional[List[Callable]] = None,
+ passes_with_exported_program: Optional[List[Type[ExportPass]]] = None,
+ ):
+ """Passes are run in the order they are passed: first pass_list, second pass_functions,
+ and lastly passes_with_exported_program."""
+ self.pass_with_exported_program = passes_with_exported_program
+ super().__init__(pass_list, pass_functions)
+
+ def run(
+ self, artifact: Union[EdgeProgramManager, ExportedProgram], inputs=None
+ ) -> None:
+ if self.pass_with_exported_program is not None:
+ self.pass_functions = self.pass_functions or [] # type: ignore
+
+ # pass_function list from superclass expects functions that take in
+ # and return ExportedPrograms.
+ # Create a wrapper to fit pass_with_exported_program into this.
+ def wrap_ep_pass(ep_pass: Type[ExportPass]):
+ def wrapped_ep_pass(ep: ExportedProgram) -> ExportedProgram:
+ pass_result = ep_pass(ep).call(ep.graph_module)
+ with validation_disabled():
+ return _update_exported_program_graph_module(
+ ep, pass_result.graph_module
+ )
+
+ return wrapped_ep_pass
+
+ self.pass_functions.extend(
+ [wrap_ep_pass(ep_pass) for ep_pass in self.pass_with_exported_program]
+ )
+ super().run(artifact, inputs)
+
+
class InitialModel(tester.Stage):
def __init__(self, model: torch.nn.Module):
self.model = model
diff --git a/backends/arm/test/tester/test_pipeline.py b/backends/arm/test/tester/test_pipeline.py
new file mode 100644
index 0000000000..bc67783ddb
--- /dev/null
+++ b/backends/arm/test/tester/test_pipeline.py
@@ -0,0 +1,369 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import logging
+from typing import Any, Callable, Generic, List, TypeVar
+
+import torch
+from executorch.backends.arm.test import common
+from executorch.backends.arm.test.tester.arm_tester import ArmTester
+from executorch.exir.backend.compile_spec_schema import CompileSpec
+
+
+logger = logging.getLogger(__name__)
+T = TypeVar("T")
+""" Generic type used for test data in the pipeline. Depends on which type the operator expects."""
+
+
+class BasePipelineMaker(Generic[T]):
+ """
+ The BasePipelineMaker defines a list of stages to be applied to a torch.nn.Module for lowering it in the Arm backend. It is meant to be inherited and adjusted for particular targets.
+ Importantly, the pipeline list can be modified before running the pipeline to support various pipeline extensions and debugging use cases.
+
+ Attributes:
+ module: The module which the pipeline is applied to.
+ test_data: Data used for quantizing and testing the module.
+ aten_ops: Aten dialect ops expected to be found in the graph after export.
+ exir_ops: Exir dialect ops expected to be found in the graph after to_edge.
+ compile_spec: The compile spec used in the lowering process.
+ use_to_edge_transform_and_lower: Selects between two possible routes for lowering the module:
+ tester.to_edge_transform_and_lower()
+ or
+ tester.to_edge().check(exir_ops).partition()
+ """
+
+ class PipelineStage:
+ """
+ Helper class to store a pipeline stage as a function call + args for calling later on.
+
+ Attributes:
+ id: name of the function to be called, used for referring to stages in the pipeline
+ func: handle to the function to be called
+ args: args used when called
+ kwargs: kwargs used when called
+ is_called: keeps track of whether the function has been called
+ """
+
+ def __init__(self, func, *args, **kwargs):
+ self.id: str = func.__name__
+ self.func: Callable = func
+ self.args = args
+ self.kwargs = kwargs
+ self.is_called = False
+
+ def __call__(self):
+ if not self.is_called:
+ self.func(*self.args, **self.kwargs)
+ else:
+ raise RuntimeError(f"{self.id} called twice.")
+ self.is_called = True
+
+ def update(self, *args, **kwargs):
+ if not self.is_called:
+ self.args = args
+ self.kwargs = kwargs
+ else:
+ raise RuntimeError(f"{self.id} args updated after being called.")
+
+ def __init__(
+ self,
+ module: torch.nn.Module,
+ test_data: T,
+ aten_ops: str | List[str],
+ exir_ops: str | List[str],
+ compile_spec: List[CompileSpec],
+ use_to_edge_transform_and_lower: bool = False,
+ ):
+
+ self.tester = ArmTester(
+ module, example_inputs=test_data, compile_spec=compile_spec
+ )
+
+ self.aten_ops = aten_ops if isinstance(aten_ops, list) else [aten_ops]
+ self.exir_ops = exir_ops if isinstance(exir_ops, list) else [exir_ops]
+ self.test_data = test_data
+ self._stages = []
+
+ self.add_stage(-1, self.tester.export)
+ self.add_stage(-1, self.tester.check, self.aten_ops)
+ if use_to_edge_transform_and_lower:
+ self.add_stage(-1, self.tester.to_edge_transform_and_lower)
+
+ else:
+ self.add_stage(-1, self.tester.to_edge)
+ self.add_stage(-1, self.tester.check, self.exir_ops)
+ self.add_stage(-1, self.tester.partition)
+ self.add_stage(-1, self.tester.check_not, self.exir_ops)
+ self.add_stage(
+ -1,
+ self.tester.check_count,
+ {"torch.ops.higher_order.executorch_call_delegate": 1},
+ )
+ self.add_stage(-1, self.tester.to_executorch)
+
+ def add_stage(self, pos: int, func: Callable, *args, **kwargs):
+ """Adds a stage defined by a function with arguments to the pipeline at index pos. Pos wraps around the list for negative values."""
+ pipeline_stage = self.PipelineStage(func, *args, **kwargs)
+ pipeline_length = len(self._stages)
+
+ if pos < 0:
+ pos = pipeline_length + (pos + 1)
+
+ if not -pipeline_length <= pos <= pipeline_length:
+ raise ValueError(
+ f"Pos must be between [-{pipeline_length}, {pipeline_length}]"
+ )
+
+ self._stages.insert(pos, pipeline_stage)
+
+ logger.debug(f"Added stage {func.__name__} to {type(self).__name__}")
+
+ return self
+
+ def pop_stage(self, pos: int):
+ """Removes and returns the stage at postion pos"""
+ return self._stages.pop(pos)
+
+ def find_pos(self, stage_id: str):
+ """Returns the position of the stage id. Note that this only finds the first stage with the given id, i.e. it should only be used with unique stages."""
+ for i, stage in enumerate(self._stages):
+ if stage.id == stage_id:
+ return i
+
+ raise Exception(f"Stage id {stage_id} not found in pipeline")
+
+ def add_stage_after(self, stage_id: str, func: Callable, *args, **kwargs):
+ """Adds a stage after the given stage id. Note that this only finds the first stage with the given id, i.e. it should only be used with unique stages."""
+ pos = self.find_pos(stage_id)
+ self.add_stage(pos + 1, func, *args, **kwargs)
+ return self
+
+ def dump_artifact(self, stage_id: str):
+ """Adds a dump_artifact stage after the given stage id. Note that this only finds the first stage with the given id, i.e. it should only be used with unique stages."""
+ self.add_stage_after(stage_id, self.tester.dump_artifact)
+ return self
+
+ def dump_operator_distribution(self, stage_id: str):
+ """Adds a dump_operator_distribution stage after the given stage id. Note that this only finds the first stage with the given id, i.e. it should only be used with unique stages."""
+ self.add_stage_after(stage_id, self.tester.dump_operator_distribution)
+ return self
+
+ def change_args(self, stage_id: str, *args, **kwargs):
+ """Updates the args to the given stage id. Note that this only finds the first stage with the given id, i.e. it should only be used with unique stages."""
+ pos = self.find_pos(stage_id)
+ pipeline_stage = self._stages[pos]
+ pipeline_stage.update(*args, **kwargs)
+ return self
+
+ def run(self):
+ """Calls each stage in order."""
+ stage_list = [stage.id for stage in self._stages]
+ logger.info(f"Running pipeline with stages:\n {stage_list}.")
+
+ for stage in self._stages:
+ try:
+ stage()
+ except Exception as e:
+ logger.error(f"\nFailure in stage <{stage.id}>: \n {str(e)}")
+ raise e
+
+
+class TosaPipelineBI(BasePipelineMaker, Generic[T]):
+ """Lowers a graph to BI TOSA spec (with quantization) and tests it with the TOSA reference model."""
+
+ def __init__(
+ self,
+ module: torch.nn.Module,
+ test_data: Any,
+ aten_op: str,
+ exir_op: str,
+ tosa_version: str = "TOSA-0.80+BI",
+ use_to_edge_transform_and_lower: bool = False,
+ ):
+ compile_spec = common.get_tosa_compile_spec(
+ tosa_version,
+ )
+ super().__init__(
+ module,
+ test_data,
+ aten_op,
+ exir_op,
+ compile_spec,
+ use_to_edge_transform_and_lower,
+ )
+ self.add_stage(0, self.tester.quantize)
+ self.add_stage_after(
+ "quantize",
+ self.tester.check,
+ [
+ "torch.ops.quantized_decomposed.dequantize_per_tensor.default",
+ "torch.ops.quantized_decomposed.quantize_per_tensor.default",
+ ],
+ )
+
+ remove_quant_nodes_stage = (
+ "to_edge_transform_and_lower"
+ if use_to_edge_transform_and_lower
+ else "partition"
+ )
+ self.add_stage_after(
+ remove_quant_nodes_stage,
+ self.tester.check_not,
+ [
+ "torch.ops.quantized_decomposed.dequantize_per_tensor.default",
+ "torch.ops.quantized_decomposed.quantize_per_tensor.default",
+ ],
+ )
+
+ self.add_stage(
+ -1, self.tester.run_method_and_compare_outputs, inputs=self.test_data
+ )
+
+
+class TosaPipelineMI(BasePipelineMaker, Generic[T]):
+ """Lowers a graph to MI TOSA spec and tests it with the TOSA reference model"""
+
+ def __init__(
+ self,
+ module: torch.nn.Module,
+ test_data: Any,
+ aten_op: str,
+ exir_op: str,
+ tosa_version: str = "TOSA-0.80+MI",
+ use_to_edge_transform_and_lower: bool = False,
+ ):
+ compile_spec = common.get_tosa_compile_spec(
+ tosa_version,
+ )
+ super().__init__(
+ module,
+ test_data,
+ aten_op,
+ exir_op,
+ compile_spec,
+ use_to_edge_transform_and_lower,
+ )
+ self.add_stage_after(
+ "export",
+ self.tester.check_not,
+ [
+ "torch.ops.quantized_decomposed.dequantize_per_tensor.default",
+ "torch.ops.quantized_decomposed.quantize_per_tensor.default",
+ ],
+ )
+
+ self.add_stage(
+ -1, self.tester.run_method_and_compare_outputs, inputs=self.test_data
+ )
+
+
+class EthosU55PipelineBI(BasePipelineMaker, Generic[T]):
+ """Lowers a graph to u55 BI TOSA spec and tests it on the Corstone300 FVP, if run_on_fvp is true."""
+
+ def __init__(
+ self,
+ module: torch.nn.Module,
+ test_data: T,
+ aten_ops: str | List[str],
+ exir_ops: str | List[str],
+ run_on_fvp: bool = False,
+ use_to_edge_transform_and_lower: bool = False,
+ ):
+ compile_spec = common.get_u55_compile_spec()
+ super().__init__(
+ module,
+ test_data,
+ aten_ops,
+ exir_ops,
+ compile_spec,
+ use_to_edge_transform_and_lower,
+ )
+ self.add_stage(0, self.tester.quantize)
+ self.add_stage_after(
+ "quantize",
+ self.tester.check,
+ [
+ "torch.ops.quantized_decomposed.dequantize_per_tensor.default",
+ "torch.ops.quantized_decomposed.quantize_per_tensor.default",
+ ],
+ )
+
+ remove_quant_nodes_stage = (
+ "to_edge_transform_and_lower"
+ if use_to_edge_transform_and_lower
+ else "partition"
+ )
+ self.add_stage_after(
+ remove_quant_nodes_stage,
+ self.tester.check_not,
+ [
+ "torch.ops.quantized_decomposed.dequantize_per_tensor.default",
+ "torch.ops.quantized_decomposed.quantize_per_tensor.default",
+ ],
+ )
+
+ if run_on_fvp:
+ self.add_stage(-1, self.tester.serialize)
+ self.add_stage(
+ -1,
+ self.tester.run_method_and_compare_outputs,
+ qtol=1,
+ inputs=self.test_data,
+ )
+
+
+class EthosU85PipelineBI(BasePipelineMaker, Generic[T]):
+ """Lowers a graph to u85 BI TOSA spec and tests it on the Corstone320 FVP, if run_on_fvp is true."""
+
+ def __init__(
+ self,
+ module: torch.nn.Module,
+ test_data: T,
+ aten_ops: str | List[str],
+ exir_ops: str | List[str],
+ run_on_fvp: bool = False,
+ use_to_edge_transform_and_lower: bool = False,
+ ):
+ compile_spec = common.get_u85_compile_spec()
+ super().__init__(
+ module,
+ test_data,
+ aten_ops,
+ exir_ops,
+ compile_spec,
+ use_to_edge_transform_and_lower,
+ )
+ self.add_stage(0, self.tester.quantize)
+ self.add_stage_after(
+ "quantize",
+ self.tester.check,
+ [
+ "torch.ops.quantized_decomposed.dequantize_per_tensor.default",
+ "torch.ops.quantized_decomposed.quantize_per_tensor.default",
+ ],
+ )
+
+ remove_quant_nodes_stage = (
+ "to_edge_transform_and_lower"
+ if use_to_edge_transform_and_lower
+ else "partition"
+ )
+ self.add_stage_after(
+ remove_quant_nodes_stage,
+ self.tester.check_not,
+ [
+ "torch.ops.quantized_decomposed.dequantize_per_tensor.default",
+ "torch.ops.quantized_decomposed.quantize_per_tensor.default",
+ ],
+ )
+
+ if run_on_fvp:
+ self.add_stage(-1, self.tester.serialize)
+ self.add_stage(
+ -1,
+ self.tester.run_method_and_compare_outputs,
+ qtol=1,
+ inputs=self.test_data,
+ )
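Reviewer note: a minimal usage sketch of the new pipeline classes, to show how a future op test could be written against them. The Add module and the op strings below are illustrative assumptions, not taken from this patch:

    import torch

    from executorch.backends.arm.test.tester.test_pipeline import TosaPipelineMI

    class Add(torch.nn.Module):
        def forward(self, x, y):
            return x + y

    test_data = (torch.rand(4), torch.rand(4))
    pipeline = TosaPipelineMI(
        Add(),
        test_data,
        aten_op="torch.ops.aten.add.Tensor",
        exir_op="executorch_exir_dialects_edge__ops_aten_add_Tensor",
    )
    pipeline.dump_artifact("export")  # inspect the graph right after export
    pipeline.pop_stage(-1)  # e.g. drop the final reference-model run
    pipeline.run()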
diff --git a/backends/arm/tosa_mapping.py b/backends/arm/tosa_mapping.py
index ec57bd5ce2..75d82f2a4b 100644
--- a/backends/arm/tosa_mapping.py
+++ b/backends/arm/tosa_mapping.py
@@ -1,4 +1,4 @@
-# Copyright 2023-2024 Arm Limited and/or its affiliates.
+# Copyright 2023-2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
@@ -11,7 +11,7 @@
# the standardised TOSA representation.
#
-import serializer.tosa_serializer as ts
+import serializer.tosa_serializer as ts # type: ignore
import torch
diff --git a/backends/arm/tosa_quant_utils.py b/backends/arm/tosa_quant_utils.py
index 9869a08c0b..d53362cb36 100644
--- a/backends/arm/tosa_quant_utils.py
+++ b/backends/arm/tosa_quant_utils.py
@@ -10,9 +10,9 @@
import math
from typing import cast, NamedTuple
-import serializer.tosa_serializer as ts
+import serializer.tosa_serializer as ts # type: ignore
import torch.fx
-import tosa.Op as TosaOp
+import tosa.Op as TosaOp # type: ignore
from executorch.backends.arm.tosa_mapping import TosaArg
from executorch.exir.dialects._ops import ops as exir_ops
from serializer.tosa_serializer import TosaSerializerTensor
diff --git a/backends/arm/tosa_utils.py b/backends/arm/tosa_utils.py
index 9fefdbb3ff..15d29b5748 100644
--- a/backends/arm/tosa_utils.py
+++ b/backends/arm/tosa_utils.py
@@ -9,8 +9,7 @@
import os
from typing import Any
-import numpy as np
-import serializer.tosa_serializer as ts
+import serializer.tosa_serializer as ts # type: ignore
import torch
from executorch.backends.arm.tosa_mapping import TosaArg
@@ -72,45 +71,6 @@ def dbg_fail(node, tosa_graph, path):
raise RuntimeError("TOSA Internal Error on node, enable logging for further info.")
-# Helper function to match TOSA's broadcasting rank requirement
-# Ref: TOSA 0.80 specification - 1.9.3. Data Layouts from
-# https://www.mlplatform.org/tosa/tosa_spec.html
-def promote_shape(tosa_fb, arg, promoted_shape, out_dtype):
- assert np.prod(arg.shape) == np.prod(promoted_shape), "Incompatible promoted shape"
- reshape_res = tosa_fb.addIntermediate(promoted_shape, out_dtype)
- attr = ts.TosaSerializerAttribute()
- attr.ReshapeAttribute(promoted_shape)
- tosa_fb.addOperator(TosaOp.Op().RESHAPE, [arg.name], [reshape_res.name], attr)
- return reshape_res
-
-
-# Helper transpose function to match TOSA's shape requirements
-# E.g., TOSA 0.80 specification - 2.3.3 CONV2D shapes:
-# https://www.mlplatform.org/tosa/tosa_spec.html#_conv2d
-def transpose_helper(tosa_fb, input, new_order, out_dtype):
- # Check new_order's length is equal to input rank
- assert len(input.shape) == len(new_order), "Wrong shape order length"
-
- # Check no duplications
- assert len(set(new_order)) == len(new_order), "Contain duplicated dim numbers"
-
- # Check all dims are valid
- for idx in new_order:
- if idx < 0:
- assert True, "Negative dim number"
- elif idx >= len(input.shape):
- assert True, "Dim is greater than input rank"
-
- input_shape_transpoed = [input.shape[i] for i in new_order]
- attr = ts.TosaSerializerAttribute()
- attr.TransposeAttribute(new_order)
- input_transposed = tosa_fb.addIntermediate(input_shape_transpoed, out_dtype)
- tosa_fb.addOperator(
- TosaOp.Op().TRANSPOSE, [input.name], [input_transposed.name], attr
- )
- return input_transposed
-
-
def getNodeArgs(node: Node) -> list[TosaArg]:
return [TosaArg(arg) for arg in node.args]
diff --git a/backends/arm/util/arm_model_evaluator.py b/backends/arm/util/arm_model_evaluator.py
index f8aeab25ba..e13f9c4df0 100644
--- a/backends/arm/util/arm_model_evaluator.py
+++ b/backends/arm/util/arm_model_evaluator.py
@@ -1,4 +1,4 @@
-# Copyright 2024 Arm Limited and/or its affiliates.
+# Copyright 2024-2025 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
@@ -17,7 +17,7 @@
import torch
from torch.nn.modules import Module
from torch.utils.data import DataLoader
-from torchvision import datasets, transforms
+from torchvision import datasets, transforms # type: ignore[import-untyped]
# Logger for outputting progress for longer running evaluation
@@ -59,7 +59,7 @@ def __init__(
if tosa_output_path:
self.tosa_output_path = tosa_output_path
else:
- self.tosa_output_path = None
+ self.tosa_output_path = ""
def get_model_error(self) -> defaultdict:
"""
@@ -104,7 +104,7 @@ def get_compression_ratio(self) -> float:
return compression_ratio
- def evaluate(self) -> dict[Any]:
+ def evaluate(self) -> dict[str, Any]:
model_error_dict = self.get_model_error()
output_metrics = {"name": self.model_name, "metrics": dict(model_error_dict)}
@@ -112,7 +112,7 @@ def evaluate(self) -> dict[Any]:
if self.tosa_output_path:
# We know output_metrics["metrics"] is list since we just defined it, safe to ignore.
# pyre-ignore[16]
- output_metrics["metrics"][
+ output_metrics["metrics"][ # type: ignore[index]
"compression_ratio"
] = self.get_compression_ratio()
diff --git a/backends/cadence/aot/compiler.py b/backends/cadence/aot/compiler.py
index bf4a274134..f9abe1c542 100644
--- a/backends/cadence/aot/compiler.py
+++ b/backends/cadence/aot/compiler.py
@@ -33,7 +33,6 @@
ExecutorchProgramManager,
to_edge,
)
-from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import PassResult
from executorch.exir.passes import ToOutVarPass
from executorch.exir.passes.sym_shape_eval_pass import HintBasedSymShapeEvalPass
@@ -57,6 +56,7 @@ def convert_pt2(
model: torch.nn.Module,
inputs: tuple[object, ...],
quantizer: CadenceQuantizer,
+ dump_graphs: bool = False,
) -> torch.fx.GraphModule:
"""
Prepare and convert a model using the given quantizer.
@@ -87,6 +87,10 @@ def convert_pt2(
.module()
)
+ if dump_graphs:
+ logging.info("Graph before quantization:")
+ logging.info(model_gm.graph.print_tabular())
+
# Prepare
prepared_model = prepare_pt2e(model_gm, quantizer)
@@ -96,6 +100,10 @@ def convert_pt2(
# Convert
converted_model = convert_pt2e(prepared_model)
+ if dump_graphs:
+ logging.info("Graph after quantization (before fusion):")
+ logging.info(converted_model.graph.print_tabular())
+
return converted_model
@@ -128,6 +136,7 @@ def quantize_pt2(
model: torch.nn.Module,
inputs: tuple[object, ...],
quantizer: Optional[CadenceQuantizer] = None,
+ dump_graphs: bool = False,
) -> torch.fx.GraphModule:
"""
Prepare, convert and fuse the model using the given quantizer.
@@ -141,11 +150,15 @@ def quantize_pt2(
quantizer = CadenceDefaultQuantizer()
# Get converted graph module
- converted_gm = convert_pt2(model, inputs, quantizer)
+ converted_gm = convert_pt2(model, inputs, quantizer, dump_graphs)
# Get fused model
fused_gm = fuse_pt2(converted_gm, quantizer)
+ if dump_graphs:
+ logging.info("Graph after quantization and fusion:")
+ logging.info(fused_gm.graph.print_tabular())
+
return fused_gm
@@ -153,7 +166,6 @@ def quantize_pt2(
def export_program(
model: torch.nn.Module,
inputs: tuple[object, ...],
- dump_graphs: bool = False,
) -> ExportedProgram:
assert isinstance(model, torch.nn.Module), "model should be an nn.Module"
@@ -163,10 +175,6 @@ def export_program(
# Export the model and return it.
expo_program = export(model, inputs, strict=True)
- if dump_graphs:
- logging.info("Exported graph:")
- expo_program.graph_module.graph.print_tabular()
-
return expo_program
@@ -180,13 +188,14 @@ def export_to_edge(
assert isinstance(model, torch.nn.Module), "model should be an nn.Module"
# Export the model into an ExportedProgram.
- expo_program = export_program(model, inputs, dump_graphs=dump_graphs)
+ expo_program = export_program(model, inputs)
# Call to_edge to convert the graph to edge IR.
# Note: dim_order is skipped (https://github.com/pytorch/executorch/issues/3704)
edge_prog_manager = to_edge(
expo_program,
compile_config=EdgeCompileConfig(
+ _skip_dim_order=True,
# Allow specific non-core aten ops in the IR.
_core_aten_ops_exception_list=[
torch.ops.aten._native_batch_norm_legit_functional.default,
@@ -194,18 +203,16 @@ def export_to_edge(
torch.ops.aten.linalg_vector_norm.default,
torch.ops.aten.unfold.default,
torch.ops.aten.angle.default,
- # cadence replaced to_dim_order_copy with _to_copy for performance
- # skip _to_copy op to get around of dim order check
- # We should remove this op once cadence can support dim order
- exir_ops.edge.aten._to_copy.default,
],
),
constant_methods=constant_methods,
)
if dump_graphs:
- logging.info("Edge graph:")
- edge_prog_manager.exported_program().graph_module.graph.print_tabular()
+ logging.info("Graph after Edge lowering:")
+ logging.info(
+ edge_prog_manager.exported_program().graph_module.graph.print_tabular()
+ )
return edge_prog_manager
diff --git a/backends/cadence/aot/replace_ops.py b/backends/cadence/aot/replace_ops.py
index cc304a226a..89ef821c56 100644
--- a/backends/cadence/aot/replace_ops.py
+++ b/backends/cadence/aot/replace_ops.py
@@ -11,7 +11,6 @@
# pyre-unsafe
-import copy
import math
from operator import neg
from typing import cast, Dict, Iterable, Sequence, Set, Tuple
@@ -36,12 +35,7 @@
from executorch.backends.cadence.aot.utils import get_edge_overload_packet
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.dialects.edge._ops import EdgeOpOverload, EdgeOpOverloadPacket
-from executorch.exir.dim_order_utils import get_memory_format
from executorch.exir.pass_base import ExportPass, NodeMetadata, PassResult, ProxyValue
-from executorch.exir.passes.dim_order_ops_registry import (
- DimOrderOpsMap,
- MemoryFormatOpsMap,
-)
from torch._subclasses import FakeTensor
from torch.fx.node import Argument
@@ -1805,72 +1799,6 @@ def call_operator(
)
-@register_cadence_pass(CadencePassAttribute(opt_level=0))
-class ReplaceToDimOrderCopyWithToCopyPass(ExportPass):
- """
- dim_order_ops::to_dim_order_copy is not supported, so this is an opt_level=0 pass.
- If the dim order is sequential, we don't need the extra work with strides and
- can just use to_copy.
- """
-
- def call_operator(
- self,
- op,
- args: Tuple[Argument, ...],
- kwargs: Dict[str, Argument],
- meta: NodeMetadata,
- ) -> ProxyValue:
- if op not in DimOrderOpsMap:
- return super().call_operator(op, args, kwargs, meta)
-
- # new kwargs with dim_order, and no memory_format for the new op
- nkwargs = dict(copy.deepcopy(kwargs)) # orig kwargs are immutable
-
- ndim = None
-
- # can always get the shape, assuming rank is specialized
-
- # pyre-ignore[16]: `None` has no attribute `to_tensor`
- if isinstance(args[0], ProxyValue) and args[0].is_tensor():
- # pyre-ignore[16]: `None` has no attribute `to_tensor`
- ndim = args[0].to_tensor().dim()
- elif isinstance(args[0], torch.Tensor):
- # pyre-ignore[16]: `None` has no attribute `dim`
- ndim = args[0].dim()
- elif isinstance(args[0], torch.fx.immutable_collections.immutable_list):
- # pyre-ignore[6]: Incompatible parameter type
- ndim = len(args[0])
- else:
- assert 0, f"Expecting a Tensor or a ProxyValue but got {type(args[0])}"
-
- # get the "to" memory format for the EdgeOp
- contiguous_dim_order = list(range(ndim))
- dim_order = nkwargs.pop("dim_order", None)
-
- # Cadence only supports contiguous memory format
- assert (
- dim_order is None
- # pyre-ignore[6]: Incompatible parameter type
- or len(dim_order) == 0
- or dim_order == contiguous_dim_order
- ), "Expected dim order in congituous or prevserve memory format, but got {}".format(
- dim_order
- )
-
- # bring back memory format
- # pyre-ignore[6]: Incompatible parameter type
- nkwargs["memory_format"] = get_memory_format(dim_order)
-
- memory_format_op = MemoryFormatOpsMap[op]
-
- return super().call_operator(
- memory_format_op,
- args,
- nkwargs,
- meta,
- )
-
-
@register_cadence_pass(CadencePassAttribute(opt_level=0))
class ReplaceFullLikeWithFullPass(ExportPass):
"""
@@ -2180,5 +2108,4 @@ class CadenceReplaceOpsInGraph:
ReplaceSingleElementTensorArgumentsFromFullOpWithScalarPass,
ReplaceAtenAvgPoolWithJarvisAvgPoolPass,
ReplaceAtenLinalgVectorNormWithCadenceLinalgVectorNormPass,
- ReplaceToDimOrderCopyWithToCopyPass,
]
diff --git a/backends/cadence/fusion_g3/operators/op_add.cpp b/backends/cadence/fusion_g3/operators/op_add.cpp
index d51fee5338..409c4cc510 100644
--- a/backends/cadence/fusion_g3/operators/op_add.cpp
+++ b/backends/cadence/fusion_g3/operators/op_add.cpp
@@ -35,21 +35,7 @@ Tensor& add_out(
const Tensor& b,
const Scalar& alpha,
Tensor& out) {
- // Common Dtype
- ScalarType common_type =
- executorch::runtime::promoteTypes(a.scalar_type(), b.scalar_type());
-
#ifdef OP_ARG_CHECK
- // Check Common Dtype
- ET_KERNEL_CHECK(
- ctx,
- (canCast(common_type, out.scalar_type()) &&
- torch::executor::check_alpha_type(
- torch::executor::native::utils::get_scalar_dtype(alpha),
- common_type)),
- InvalidArgument,
- out);
-
// Check Dim Order
ET_KERNEL_CHECK(
ctx,
@@ -65,10 +51,6 @@ Tensor& add_out(
out);
#endif
- // Compute Dtype
- ScalarType compute_type =
- torch::executor::native::utils::get_compute_type(common_type);
-
static constexpr const char op_name[] = "add.out";
int kTensorDimensionLimit = 5;
@@ -77,12 +59,12 @@ Tensor& add_out(
int inp2_shape[kTensorDimensionLimit];
int out_shape[kTensorDimensionLimit];
- bool broadcast = 0;
+ bool broadcast = false;
int max_dim = a.dim() > b.dim() ? a.dim() : b.dim();
max_dim = out.dim() > max_dim ? out.dim() : max_dim;
- bool optimized = 1;
+ bool optimized = true;
/* Added change to work with input dimensions more than 5 */
for (int i = 0; i < max_dim; i++) {
@@ -109,15 +91,19 @@ Tensor& add_out(
for (int i = 0; i < out.dim(); i++) {
if (((inp1_shape[i]) != (out_shape[i])) ||
((inp2_shape[i]) != (out_shape[i]))) {
- broadcast = 1;
+ broadcast = true;
}
}
- if ((broadcast == 1) && (max_dim > kTensorDimensionLimit)) {
- optimized = 0;
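+ // Use the fast path only when broadcasting fits within kTensorDimensionLimit
+ // and a, b, and out all share an Int or Float dtype.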
+ if (((broadcast) && (max_dim > kTensorDimensionLimit)) ||
+ (!(((a.scalar_type() == ScalarType::Int) ||
+ (a.scalar_type() == ScalarType::Float)) &&
+ (a.scalar_type() == b.scalar_type()) &&
+ (a.scalar_type() == out.scalar_type())))) {
+ optimized = false;
}
- if ((compute_type == ScalarType::Int) && (optimized)) {
+ if ((a.scalar_type() == ScalarType::Int) && (optimized)) {
const int* const inp1_data = a.const_data_ptr<int>();
const int* const inp2_data = b.const_data_ptr<int>();
int* const out_data = out.mutable_data_ptr<int>();
@@ -169,7 +155,7 @@ Tensor& add_out(
alpha_val,
out.numel());
}
- } else if ((compute_type == ScalarType::Float) && (optimized)) {
+ } else if ((a.scalar_type() == ScalarType::Float) && (optimized)) {
const float* const inp1_data = a.const_data_ptr<float>();
const float* const inp2_data = b.const_data_ptr<float>();
float* const out_data = out.mutable_data_ptr<float>();
@@ -222,6 +208,23 @@ Tensor& add_out(
out.numel());
}
} else {
+ // Common Dtype
+ ScalarType common_type =
+ executorch::runtime::promoteTypes(a.scalar_type(), b.scalar_type());
+ // Compute Dtype
+ ScalarType compute_type =
+ torch::executor::native::utils::get_compute_type(common_type);
+
+ // Check Common Dtype
+ ET_KERNEL_CHECK(
+ ctx,
+ (canCast(common_type, out.scalar_type()) &&
+ torch::executor::check_alpha_type(
+ torch::executor::native::utils::get_scalar_dtype(alpha),
+ common_type)),
+ InvalidArgument,
+ out);
+
ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
const CTYPE_COMPUTE val_alpha =
torch::executor::native::utils::scalar_to<CTYPE_COMPUTE>(alpha);
@@ -249,22 +252,7 @@ Tensor& add_scalar_out(
const Scalar& b,
const Scalar& alpha,
Tensor& out) {
- // Common Dtype
- ScalarType common_type =
- torch::executor::native::utils::promote_type_with_scalar(
- a.scalar_type(), b);
-
#ifdef OP_ARG_CHECK
- // Check Common Dtype
- ET_KERNEL_CHECK(
- ctx,
- (common_type == out.scalar_type() &&
- torch::executor::check_alpha_type(
- torch::executor::native::utils::get_scalar_dtype(alpha),
- common_type)),
- InvalidArgument,
- out);
-
// Check Dim Order
ET_KERNEL_CHECK(
ctx,
@@ -279,14 +267,23 @@ Tensor& add_scalar_out(
InvalidArgument,
out);
#endif
- // Compute Dtype
- ScalarType compute_type =
- torch::executor::native::utils::get_compute_type(common_type);
// @lint-ignore CLANGTIDY facebook-hte-CArray
static constexpr const char op_name[] = "add.Scalar_out";
- if (compute_type == ScalarType::Int) {
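+ // The scalar fast path needs an Int or Float input whose dtype matches the
+ // output; an Int tensor combined with a floating-point scalar falls back.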
+ bool optimized = true;
+
+ if (!(((a.scalar_type() == ScalarType::Int) ||
+ (a.scalar_type() == ScalarType::Float)) &&
+ (a.scalar_type() == out.scalar_type()))) {
+ optimized = false;
+ }
+
+ if ((b.isFloatingPoint()) && (a.scalar_type() == ScalarType::Int)) {
+ optimized = false;
+ }
+
+ if ((a.scalar_type() == ScalarType::Int) && (optimized)) {
const int* const inp1_data = a.const_data_ptr<int>();
int inp2_val;
torch::executor::native::utils::extract_scalar(b, &inp2_val);
@@ -306,7 +303,7 @@ Tensor& add_scalar_out(
alpha_val,
out.numel());
- } else if (compute_type == ScalarType::Float) {
+ } else if ((a.scalar_type() == ScalarType::Float) && (optimized)) {
const float* const inp1_data = a.const_data_ptr<float>();
float inp2_val;
torch::executor::native::utils::extract_scalar(b, &inp2_val);
@@ -327,6 +324,24 @@ Tensor& add_scalar_out(
out.numel());
} else {
+ // Common Dtype
+ ScalarType common_type =
+ torch::executor::native::utils::promote_type_with_scalar(
+ a.scalar_type(), b);
+ // Compute Dtype
+ ScalarType compute_type =
+ torch::executor::native::utils::get_compute_type(common_type);
+
+ // Check Common Dtype
+ ET_KERNEL_CHECK(
+ ctx,
+ (common_type == out.scalar_type() &&
+ torch::executor::check_alpha_type(
+ torch::executor::native::utils::get_scalar_dtype(alpha),
+ common_type)),
+ InvalidArgument,
+ out);
+
ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
torch::executor::native::utils::
apply_unitensor_elementwise_fn(
diff --git a/backends/cadence/fusion_g3/operators/op_cat.cpp b/backends/cadence/fusion_g3/operators/op_cat.cpp
index 74fd96a212..84224b37b0 100644
--- a/backends/cadence/fusion_g3/operators/op_cat.cpp
+++ b/backends/cadence/fusion_g3/operators/op_cat.cpp
@@ -46,11 +46,6 @@ Tensor& cat_out(
int kTensorDimensionLimit = executorch::runtime::kTensorDimensionLimit;
#ifdef OP_ARG_CHECK
- ET_KERNEL_CHECK(
- ctx,
- torch::executor::check_cat_args(tensors, dim, out),
- InvalidArgument,
- out);
Tensor::SizesType expected_out_size[kTensorDimensionLimit];
size_t expected_out_dim = 0;
@@ -106,7 +101,16 @@ Tensor& cat_out(
out_shapes[i] = out_size[i];
}
- if ((out.scalar_type() == ScalarType::Int) ||
+ bool optimized = true;
+
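+ // The optimized copy kernel moves raw elements, so every input tensor must
+ // share the output dtype.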
+ for (int i = 0; i < tensors.size(); i++) {
+ if (out.scalar_type() != tensors[i].scalar_type()) {
+ optimized = false;
+ break;
+ }
+ }
+
+ if ((optimized) && (out.scalar_type() == ScalarType::Int) ||
(out.scalar_type() == ScalarType::Short) ||
(out.scalar_type() == ScalarType::Char) ||
(out.scalar_type() == ScalarType::UInt32) ||
@@ -125,6 +129,12 @@ Tensor& cat_out(
(int)dim,
get_element_size(out.scalar_type()));
} else {
+ ET_KERNEL_CHECK(
+ ctx,
+ torch::executor::check_cat_args(tensors, dim, out),
+ InvalidArgument,
+ out);
+
const size_t outer = executorch::runtime::getLeadingDims(out, dim);
const size_t dim_stride = executorch::runtime::getTrailingDims(out, dim);
const size_t ninputs = tensors.size();
diff --git a/backends/cadence/fusion_g3/operators/op_dequantize.cpp b/backends/cadence/fusion_g3/operators/op_dequantize.cpp
index 3e0235170b..dd9d4f2a51 100644
--- a/backends/cadence/fusion_g3/operators/op_dequantize.cpp
+++ b/backends/cadence/fusion_g3/operators/op_dequantize.cpp
@@ -117,7 +117,7 @@ Tensor& dequantize_impl(
}
}
} else {
- if (*zero_point_data != 0) // tesor
+ if (*zero_point_data != 0) // tensor
{
is_asym_dequant |= 1;
}
@@ -125,8 +125,14 @@ Tensor& dequantize_impl(
}
float* out_data = out.mutable_data_ptr<float>();
+ bool optimized = true;
+
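+ // The optimized dequantize kernels write float output only.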
+ if (out.scalar_type() != ScalarType::Float) {
+ optimized = false;
+ }
+
if (is_asym_dequant) {
- if (input.scalar_type() == ScalarType::Byte) {
+ if ((input.scalar_type() == ScalarType::Byte) && (optimized)) {
const uint8_t* input_data = input.const_data_ptr<uint8_t>();
XT_KERNEL_CHECK(
ctx,
@@ -139,7 +145,7 @@ Tensor& dequantize_impl(
axis,
zero_point_data,
scale_data);
- } else if (input.scalar_type() == ScalarType::Char) {
+ } else if ((input.scalar_type() == ScalarType::Char) && (optimized)) {
const int8_t* input_data = input.const_data_ptr<int8_t>();
XT_KERNEL_CHECK(
ctx,
@@ -152,7 +158,7 @@ Tensor& dequantize_impl(
axis,
zero_point_data,
scale_data);
- } else if (input.scalar_type() == ScalarType::UInt16) {
+ } else if ((input.scalar_type() == ScalarType::UInt16) && (optimized)) {
const uint16_t* input_data = input.const_data_ptr<uint16_t>();
XT_KERNEL_CHECK(
ctx,
@@ -165,7 +171,7 @@ Tensor& dequantize_impl(
axis,
zero_point_data,
scale_data);
- } else if (input.scalar_type() == ScalarType::Short) {
+ } else if ((input.scalar_type() == ScalarType::Short) && (optimized)) {
const int16_t* input_data = input.const_data_ptr<int16_t>();
XT_KERNEL_CHECK(
ctx,
@@ -178,7 +184,7 @@ Tensor& dequantize_impl(
axis,
zero_point_data,
scale_data);
- } else if (input.scalar_type() == (ScalarType)Bits4u) {
+ } else if ((input.scalar_type() == (ScalarType)Bits4u) && (optimized)) {
const uint8_t* input_data = input.const_data_ptr<uint8_t>();
XT_KERNEL_CHECK(
ctx,
@@ -191,7 +197,7 @@ Tensor& dequantize_impl(
axis,
zero_point_data,
scale_data);
- } else if (input.scalar_type() == (ScalarType)Bits4) {
+ } else if ((input.scalar_type() == (ScalarType)Bits4) && (optimized)) {
const int8_t* input_data = input.const_data_ptr<int8_t>();
XT_KERNEL_CHECK(
ctx,
@@ -338,7 +344,7 @@ Tensor& dequantize_impl(
}
}
} else {
- if (input.scalar_type() == ScalarType::Byte) {
+ if ((input.scalar_type() == ScalarType::Byte) && (optimized)) {
const uint8_t* input_data = input.const_data_ptr<uint8_t>();
XT_KERNEL_CHECK(
ctx,
@@ -350,7 +356,7 @@ Tensor& dequantize_impl(
input.dim(),
axis,
scale_data);
- } else if (input.scalar_type() == ScalarType::Char) {
+ } else if ((input.scalar_type() == ScalarType::Char) && (optimized)) {
const int8_t* input_data = input.const_data_ptr<int8_t>();
XT_KERNEL_CHECK(
ctx,
@@ -362,7 +368,7 @@ Tensor& dequantize_impl(
input.dim(),
axis,
scale_data);
- } else if (input.scalar_type() == ScalarType::UInt16) {
+ } else if ((input.scalar_type() == ScalarType::UInt16) && (optimized)) {
const uint16_t* input_data = input.const_data_ptr<uint16_t>();
XT_KERNEL_CHECK(
ctx,
@@ -374,7 +380,7 @@ Tensor& dequantize_impl(
input.dim(),
axis,
scale_data);
- } else if (input.scalar_type() == ScalarType::Short) {
+ } else if ((input.scalar_type() == ScalarType::Short) && (optimized)) {
const int16_t* input_data = input.const_data_ptr<int16_t>();
XT_KERNEL_CHECK(
ctx,
@@ -386,7 +392,7 @@ Tensor& dequantize_impl(
input.dim(),
axis,
scale_data);
- } else if (input.scalar_type() == (ScalarType)Bits4u) {
+ } else if ((input.scalar_type() == (ScalarType)Bits4u) && (optimized)) {
const uint8_t* input_data = input.const_data_ptr<uint8_t>();
XT_KERNEL_CHECK(
ctx,
@@ -398,7 +404,7 @@ Tensor& dequantize_impl(
input.dim(),
axis,
scale_data);
- } else if (input.scalar_type() == (ScalarType)Bits4) {
+ } else if ((input.scalar_type() == (ScalarType)Bits4) && (optimized)) {
const int8_t* input_data = input.const_data_ptr<int8_t>();
XT_KERNEL_CHECK(
ctx,
diff --git a/backends/cadence/fusion_g3/operators/op_div.cpp b/backends/cadence/fusion_g3/operators/op_div.cpp
index 1461f643a8..85e5da4276 100644
--- a/backends/cadence/fusion_g3/operators/op_div.cpp
+++ b/backends/cadence/fusion_g3/operators/op_div.cpp
@@ -54,10 +54,6 @@ Tensor& div_out(
const Tensor& a,
const Tensor& b,
Tensor& out) {
- // Common Dtype
- ScalarType common_type =
- executorch::runtime::promoteTypes(a.scalar_type(), b.scalar_type());
-
#ifdef OP_ARG_CHECK
// Check Dim Order
ET_KERNEL_CHECK(
@@ -73,11 +69,6 @@ Tensor& div_out(
InvalidArgument,
out);
#endif
-
- // Compute Dtype
- ScalarType compute_type =
- torch::executor::native::utils::get_compute_type(common_type);
-
// @lint-ignore CLANGTIDY facebook-hte-CArray
static constexpr const char op_name[] = "div.out";
@@ -87,12 +78,12 @@ Tensor& div_out(
int inp2_shape[kTensorDimensionLimit];
int out_shape[kTensorDimensionLimit];
- bool broadcast = 0;
+ bool broadcast = false;
int max_dim = a.dim() > b.dim() ? a.dim() : b.dim();
max_dim = out.dim() > max_dim ? out.dim() : max_dim;
- bool optimized = 1;
+ bool optimized = true;
for (int i = 0; i < max_dim; i++) {
out_shape[i] = 1;
@@ -118,15 +109,19 @@ Tensor& div_out(
for (int i = 0; i < out.dim(); i++) {
if (((inp1_shape[i]) != (out_shape[i])) ||
((inp2_shape[i]) != (out_shape[i]))) {
- broadcast = 1;
+ broadcast = true;
}
}
- if ((broadcast == 1) && (max_dim > kTensorDimensionLimit)) {
- optimized = 0;
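+ // The fast path needs Int or Float inputs of matching dtype, a float output,
+ // and broadcasting that fits within kTensorDimensionLimit.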
+ if (((broadcast) && (max_dim > kTensorDimensionLimit)) ||
+ (!(((a.scalar_type() == ScalarType::Int) ||
+ (a.scalar_type() == ScalarType::Float)) &&
+ (a.scalar_type() == b.scalar_type()) &&
+ (out.scalar_type() == ScalarType::Float)))) {
+ optimized = false;
}
- if ((compute_type == ScalarType::Int) && (optimized)) {
+ if ((a.scalar_type() == ScalarType::Int) && (optimized)) {
const int* const inp1_data = a.const_data_ptr<int>();
const int* const inp2_data = b.const_data_ptr<int>();
float* const out_data = out.mutable_data_ptr<float>();
@@ -162,7 +157,7 @@ Tensor& div_out(
inp2_data,
out.numel());
}
- } else if ((compute_type == ScalarType::Float) && (optimized)) {
+ } else if ((a.scalar_type() == ScalarType::Float) && (optimized)) {
const float* const inp1_data = a.const_data_ptr<float>();
const float* const inp2_data = b.const_data_ptr<float>();
float* const out_data = out.mutable_data_ptr<float>();
@@ -244,19 +239,7 @@ Tensor& div_out_mode(
ET_KERNEL_CHECK(
ctx, mode_val == "trunc" || mode_val == "floor", InvalidArgument, out);
- // Common Dtype
- ScalarType common_type =
- executorch::runtime::promoteTypes(a.scalar_type(), b.scalar_type());
-
#ifdef OP_ARG_CHECK
- // Check Common Dtype
- ET_KERNEL_CHECK(
- ctx,
- (canCast(common_type, out.scalar_type()) &&
- common_type != ScalarType::Bool),
- InvalidArgument,
- out);
-
// Check Dim Order
ET_KERNEL_CHECK(
ctx,
@@ -271,9 +254,6 @@ Tensor& div_out_mode(
InvalidArgument,
out);
#endif
- // Compute Dtype
- ScalarType compute_type =
- torch::executor::native::utils::get_compute_type(common_type);
// @lint-ignore CLANGTIDY facebook-hte-CArray
static constexpr const char op_name[] = "div.out_mode";
@@ -287,12 +267,12 @@ Tensor& div_out_mode(
int inp2_shape[kTensorDimensionLimit];
int out_shape[kTensorDimensionLimit];
- bool broadcast = 0;
+ bool broadcast = false;
int max_dim = a.dim() > b.dim() ? a.dim() : b.dim();
max_dim = out.dim() > max_dim ? out.dim() : max_dim;
- bool optimized = 1;
+ bool optimized = true;
for (int i = 0; i < max_dim; i++) {
out_shape[i] = 1;
@@ -318,17 +298,21 @@ Tensor& div_out_mode(
for (int i = 0; i < out.dim(); i++) {
if (((inp1_shape[i]) != (out_shape[i])) ||
((inp2_shape[i]) != (out_shape[i]))) {
- broadcast = 1;
+ broadcast = true;
}
}
- if ((broadcast == 1) && (max_dim > kTensorDimensionLimit)) {
- optimized = 0;
+ if (((broadcast) && (max_dim > kTensorDimensionLimit)) ||
+ (!(((a.scalar_type() == ScalarType::Int) ||
+ (a.scalar_type() == ScalarType::Float)) &&
+ (a.scalar_type() == b.scalar_type()) &&
+ (a.scalar_type() == out.scalar_type())))) {
+ optimized = false;
}
int mode_value = (mode_val == "trunc") ? 1 : 2;
- if ((compute_type == ScalarType::Int) && (optimized)) {
+ if ((a.scalar_type() == ScalarType::Int) && (optimized)) {
const int* const inp1_data = a.const_data_ptr<int>();
const int* const inp2_data = b.const_data_ptr<int>();
int* const out_data = out.mutable_data_ptr<int>();
@@ -367,7 +351,7 @@ Tensor& div_out_mode(
mode_value,
out.numel());
}
- } else if ((compute_type == ScalarType::Float) && (optimized)) {
+ } else if ((a.scalar_type() == ScalarType::Float) && (optimized)) {
const float* const inp1_data = a.const_data_ptr<float>();
const float* const inp2_data = b.const_data_ptr<float>();
float* const out_data = out.mutable_data_ptr<float>();
@@ -407,6 +391,21 @@ Tensor& div_out_mode(
out.numel());
}
} else {
+ // Common Dtype
+ ScalarType common_type =
+ executorch::runtime::promoteTypes(a.scalar_type(), b.scalar_type());
+ // Compute Dtype
+ ScalarType compute_type =
+ torch::executor::native::utils::get_compute_type(common_type);
+
+ // Check Common Dtype
+ ET_KERNEL_CHECK(
+ ctx,
+ (canCast(common_type, out.scalar_type()) &&
+ common_type != ScalarType::Bool),
+ InvalidArgument,
+ out);
+
ET_SWITCH_REAL_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
torch::executor::native::utils::
apply_bitensor_elementwise_fn(
@@ -456,15 +455,7 @@ Tensor& div_scalar_out(
const Tensor& a,
const Scalar& b,
Tensor& out) {
- // Common Dtype
- ScalarType common_type =
- torch::executor::native::utils::promote_type_with_scalar(
- a.scalar_type(), b);
-
#ifdef OP_ARG_CHECK
- // Check Common Dtype
- ET_KERNEL_CHECK(ctx, common_type == out.scalar_type(), InvalidArgument, out);
-
// Check Dim Order
ET_KERNEL_CHECK(
ctx,
@@ -480,14 +471,22 @@ Tensor& div_scalar_out(
out);
#endif
- // Compute Dtype
- ScalarType compute_type =
- torch::executor::native::utils::get_compute_type(common_type);
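+ // The scalar fast path takes an Int or Float input and writes a float
+ // output; an Int tensor divided by a floating-point scalar falls back.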
+ bool optimized = true;
+
+ if (!(((a.scalar_type() == ScalarType::Int) ||
+ (a.scalar_type() == ScalarType::Float)) &&
+ (out.scalar_type() == ScalarType::Float))) {
+ optimized = false;
+ }
+
+ if ((b.isFloatingPoint()) && (a.scalar_type() == ScalarType::Int)) {
+ optimized = false;
+ }
// @lint-ignore CLANGTIDY facebook-hte-CArray
static constexpr const char op_name[] = "div.Scalar_out";
- if (compute_type == ScalarType::Int) {
+ if ((a.scalar_type() == ScalarType::Int) && (optimized)) {
const int* const inp1_data = a.const_data_ptr<int>();
int inp2_val;
torch::executor::native::utils::extract_scalar(b, &inp2_val);
@@ -502,7 +501,7 @@ Tensor& div_scalar_out(
inp1_data,
inp2_val,
out.numel());
- } else if (compute_type == ScalarType::Float) {
+ } else if ((a.scalar_type() == ScalarType::Float) && (optimized)) {
const float* const inp1_data = a.const_data_ptr<float>();
float inp2_val;
torch::executor::native::utils::extract_scalar(b, &inp2_val);
@@ -526,6 +525,11 @@ Tensor& div_scalar_out(
: ScalarType::Float;
ScalarType compute_type =
torch::executor::native::utils::get_compute_type(common_type);
+
+ // Check Common Dtype
+ ET_KERNEL_CHECK(
+ ctx, common_type == out.scalar_type(), InvalidArgument, out);
+
ET_SWITCH_FLOAT_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
const CTYPE_COMPUTE val_b =
torch::executor::native::utils::scalar_to<CTYPE_COMPUTE>(b);
@@ -560,29 +564,7 @@ Tensor& div_scalar_mode_out(
ET_KERNEL_CHECK(
ctx, mode_val == "trunc" || mode_val == "floor", InvalidArgument, out);
- // Common Dtype
- ScalarType common_type =
- torch::executor::native::utils::promote_type_with_scalar(
- a.scalar_type(), b);
-
#ifdef OP_ARG_CHECK
- // Check Common Dtype
- ET_KERNEL_CHECK(
- ctx,
- (canCast(common_type, out.scalar_type()) &&
- common_type != ScalarType::Bool),
- InvalidArgument,
- out);
-
- // Check for intergral division by zero
- ET_KERNEL_CHECK_MSG(
- ctx,
- !(executorch::runtime::isIntegralType(common_type, true) &&
- torch::executor::native::utils::scalar_to(b) == 0),
- InvalidArgument,
- out,
- "Div mode operation encountered integer division by zero");
-
// Check Dim Order
ET_KERNEL_CHECK(
ctx,
@@ -598,18 +580,26 @@ Tensor& div_scalar_mode_out(
out);
#endif
- // Compute Dtype
- ScalarType compute_type =
- torch::executor::native::utils::get_compute_type(common_type);
-
const bool mode_is_trunc = mode_val == "trunc";
+ bool optimized = true;
+
+ if (!(((a.scalar_type() == ScalarType::Int) ||
+ (a.scalar_type() == ScalarType::Float)) &&
+ (a.scalar_type() == out.scalar_type()))) {
+ optimized = false;
+ }
+
+ if ((b.isFloatingPoint()) && (a.scalar_type() == ScalarType::Int)) {
+ optimized = false;
+ }
+
// @lint-ignore CLANGTIDY facebook-hte-CArray
static constexpr const char op_name[] = "div.Scalar_mode_out";
int mode_value = (mode_val == "trunc") ? 1 : 2;
- if (compute_type == ScalarType::Int) {
+ if ((a.scalar_type() == ScalarType::Int) && (optimized)) {
const int* const inp1_data = a.const_data_ptr<int>();
int inp2_val;
torch::executor::native::utils::extract_scalar(b, &inp2_val);
@@ -625,7 +615,7 @@ Tensor& div_scalar_mode_out(
inp2_val,
mode_value,
out.numel());
- } else if (compute_type == ScalarType::Float) {
+ } else if ((a.scalar_type() == ScalarType::Float) && (optimized)) {
const float* const inp1_data = a.const_data_ptr<float>();
float inp2_val;
torch::executor::native::utils::extract_scalar(b, &inp2_val);
@@ -642,6 +632,31 @@ Tensor& div_scalar_mode_out(
mode_value,
out.numel());
} else {
+ // Common Dtype
+ ScalarType common_type =
+ torch::executor::native::utils::promote_type_with_scalar(
+ a.scalar_type(), b);
+ // Compute Dtype
+ ScalarType compute_type =
+ torch::executor::native::utils::get_compute_type(common_type);
+
+ // Check Common Dtype
+ ET_KERNEL_CHECK(
+ ctx,
+ (canCast(common_type, out.scalar_type()) &&
+ common_type != ScalarType::Bool),
+ InvalidArgument,
+ out);
+
+ // Check for integral division by zero
+ ET_KERNEL_CHECK_MSG(
+ ctx,
+ !(executorch::runtime::isIntegralType(common_type, true) &&
+ torch::executor::native::utils::scalar_to(b) == 0),
+ InvalidArgument,
+ out,
+ "Div mode operation encountered integer division by zero");
+
ET_SWITCH_REAL_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
const CTYPE_COMPUTE val_b =
torch::executor::native::utils::scalar_to<CTYPE_COMPUTE>(b);
diff --git a/backends/cadence/fusion_g3/operators/op_exp.cpp b/backends/cadence/fusion_g3/operators/op_exp.cpp
index 4b6b898b17..41b5d70b22 100644
--- a/backends/cadence/fusion_g3/operators/op_exp.cpp
+++ b/backends/cadence/fusion_g3/operators/op_exp.cpp
@@ -49,9 +49,10 @@ Tensor& exp_out(KernelRuntimeContext& ctx, const Tensor& in, Tensor& out) {
out);
#endif
- if (in.scalar_type() == ScalarType::Float) {
- float* __restrict__ out_data = out.mutable_data_ptr();
- const float* __restrict__ in_data = in.const_data_ptr();
+ if ((in.scalar_type() == ScalarType::Float) &&
+ (out.scalar_type() == ScalarType::Float)) {
+ float* const out_data = out.mutable_data_ptr<float>();
+ const float* const in_data = in.const_data_ptr<float>();
XT_KERNEL_CHECK(
ctx, out, xa_nn_elm_exp_f32_f32, out_data, in_data, out.numel());
diff --git a/backends/cadence/fusion_g3/operators/op_mean.cpp b/backends/cadence/fusion_g3/operators/op_mean.cpp
index 289baceb12..ae0cfd1e27 100644
--- a/backends/cadence/fusion_g3/operators/op_mean.cpp
+++ b/backends/cadence/fusion_g3/operators/op_mean.cpp
@@ -44,15 +44,16 @@ int prepare_data(
for (int i = 0; i < num_out_dims; i++) {
out_shape[i] = out.size(i);
}
-
int num_axis_dims = 0;
- for (const auto& d : dim_list.value()) {
- if (d < 0) {
- p_axis[num_axis_dims] = num_inp_dims + d;
- num_axis_dims++;
- } else {
- p_axis[num_axis_dims] = d;
- num_axis_dims++;
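+ // An empty dim_list means reduce over all dimensions; only iterate when a
+ // value is present.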
+ if (dim_list.has_value()) {
+ for (const auto& d : dim_list.value()) {
+ if (d < 0) {
+ p_axis[num_axis_dims] = num_inp_dims + d;
+ num_axis_dims++;
+ } else {
+ p_axis[num_axis_dims] = d;
+ num_axis_dims++;
+ }
}
}
@@ -69,12 +70,6 @@ Tensor& mean_out(
(void)ctx;
#ifdef OP_ARG_CHECK
- ET_KERNEL_CHECK(
- ctx,
- torch::executor::check_mean_dim_args(in, dim_list, keepdim, dtype, out),
- InvalidArgument,
- out);
-
ET_KERNEL_CHECK(
ctx,
executorch::runtime::tensors_have_same_dim_order(in, out),
@@ -97,13 +92,14 @@ Tensor& mean_out(
constexpr int kNnlibMaxDim = 5;
- bool optimized = 1;
+ bool optimized = true;
- if (out.scalar_type() != ScalarType::Float)
- optimized = 0;
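+ // The optimized reduction handles float input and output only.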
+ if (!((out.scalar_type() == ScalarType::Float) &&
+ (in.scalar_type() == ScalarType::Float)))
+ optimized = false;
if (in.dim() > kNnlibMaxDim)
- optimized = 0;
+ optimized = false;
if (optimized) {
float* __restrict__ p_out = out.mutable_data_ptr<float>();
@@ -135,9 +131,8 @@ Tensor& mean_out(
num_inp_dims,
num_out_dims);
- if (num_axis_dims == num_inp_dims) {
+ if ((num_axis_dims == num_inp_dims) || (!dim_list.has_value())) {
num_out_dims = 1;
- out_shape[0] = 1;
}
int inp_shape_max = inp_shape[p_axis[0]];
@@ -168,29 +163,38 @@ Tensor& mean_out(
num_axis_dims,
p_scratch_in);
} else {
- ET_SWITCH_REALHB_TYPES(in.scalar_type(), ctx, "mean.out", CTYPE_IN, [&] {
- ET_SWITCH_FLOATH_TYPES(
- out.scalar_type(), ctx, "mean.out", CTYPE_OUT, [&] {
- CTYPE_OUT* out_data = out.mutable_data_ptr();
- const size_t num =
- torch::executor::get_reduced_dim_product(in, dim_list);
- for (size_t out_ix = 0; out_ix < out.numel(); ++out_ix) {
- CTYPE_OUT sum = 0;
- if (in.numel() > 0) {
- sum = torch::executor::
- map_reduce_over_dim_list(
- [](CTYPE_IN v) { return static_cast(v); },
- [](CTYPE_OUT outv, CTYPE_OUT acc) {
- return acc + outv;
- },
- in,
- dim_list,
- out_ix);
- }
- out_data[out_ix] = sum / static_cast(num);
- }
- });
- });
+ ET_KERNEL_CHECK(
+ ctx,
+ torch::executor::check_mean_dim_args(in, dim_list, keepdim, dtype, out),
+ InvalidArgument,
+ out);
+
+ ET_SWITCH_REALHBBF16_TYPES(
+ in.scalar_type(), ctx, "mean.out", CTYPE_IN, [&] {
+ ET_SWITCH_FLOATHBF16_TYPES(
+ out.scalar_type(), ctx, "mean.out", CTYPE_OUT, [&] {
+ CTYPE_OUT* out_data = out.mutable_data_ptr<CTYPE_OUT>();
+ const size_t num =
+ torch::executor::get_reduced_dim_product(in, dim_list);
+ for (size_t out_ix = 0; out_ix < out.numel(); ++out_ix) {
+ CTYPE_OUT sum = 0;
+ if (in.numel() > 0) {
+ sum = torch::executor::
+ map_reduce_over_dim_list<CTYPE_IN, CTYPE_OUT>(
+ [](CTYPE_IN v) {
+ return static_cast<CTYPE_OUT>(v);
+ },
+ [](CTYPE_OUT outv, CTYPE_OUT acc) {
+ return acc + outv;
+ },
+ in,
+ dim_list,
+ out_ix);
+ }
+ out_data[out_ix] = sum / static_cast<CTYPE_OUT>(num);
+ }
+ });
+ });
}
return out;
diff --git a/backends/cadence/fusion_g3/operators/op_mul.cpp b/backends/cadence/fusion_g3/operators/op_mul.cpp
index 93b4c5a992..bee6ac9cbd 100644
--- a/backends/cadence/fusion_g3/operators/op_mul.cpp
+++ b/backends/cadence/fusion_g3/operators/op_mul.cpp
@@ -33,15 +33,7 @@ Tensor& mul_out(
const Tensor& a,
const Tensor& b,
Tensor& out) {
- // Common Dtype
- ScalarType common_type =
- executorch::runtime::promoteTypes(a.scalar_type(), b.scalar_type());
-
#ifdef OP_ARG_CHECK
- // Check Common Dtype
- ET_KERNEL_CHECK(
- ctx, canCast(common_type, out.scalar_type()), InvalidArgument, out);
-
// Check Dim Order
ET_KERNEL_CHECK(
ctx,
@@ -57,10 +49,6 @@ Tensor& mul_out(
out);
#endif
- // Compute Dtype
- ScalarType compute_type =
- torch::executor::native::utils::get_compute_type(common_type);
-
// @lint-ignore CLANGTIDY facebook-hte-CArray
static constexpr const char op_name[] = "mul.out";
int kTensorDimensionLimit = 5;
@@ -69,12 +57,12 @@ Tensor& mul_out(
int inp2_shape[kTensorDimensionLimit];
int out_shape[kTensorDimensionLimit];
- bool broadcast = 0;
+ bool broadcast = false;
int max_dim = a.dim() > b.dim() ? a.dim() : b.dim();
max_dim = out.dim() > max_dim ? out.dim() : max_dim;
- bool optimized = 1;
+ bool optimized = true;
/* Added change to work with input dimensions more than 5 */
for (int i = 0; i < max_dim; i++) {
@@ -101,15 +89,19 @@ Tensor& mul_out(
for (int i = 0; i < out.dim(); i++) {
if (((inp1_shape[i]) != (out_shape[i])) ||
((inp2_shape[i]) != (out_shape[i]))) {
- broadcast = 1;
+ broadcast = true;
}
}
- if ((broadcast == 1) && (max_dim > kTensorDimensionLimit)) {
- optimized = 0;
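+ // Vectorized mul requires a, b, and out to share an Int or Float dtype and
+ // broadcasting to stay within kTensorDimensionLimit.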
+ if (((broadcast) && (max_dim > kTensorDimensionLimit)) ||
+ (!(((a.scalar_type() == ScalarType::Int) ||
+ (a.scalar_type() == ScalarType::Float)) &&
+ (a.scalar_type() == b.scalar_type()) &&
+ (a.scalar_type() == out.scalar_type())))) {
+ optimized = false;
}
- if ((compute_type == ScalarType::Int) && (optimized)) {
+ if ((a.scalar_type() == ScalarType::Int) && (optimized)) {
const int* const inp1_data = a.const_data_ptr<int>();
const int* const inp2_data = b.const_data_ptr<int>();
int* const out_data = out.mutable_data_ptr<int>();
@@ -154,7 +146,7 @@ Tensor& mul_out(
inp2_data,
out.numel());
}
- } else if ((compute_type == ScalarType::Float) && (optimized)) {
+ } else if ((a.scalar_type() == ScalarType::Float) && (optimized)) {
const float* const inp1_data = a.const_data_ptr<float>();
const float* const inp2_data = b.const_data_ptr<float>();
float* const out_data = out.mutable_data_ptr<float>();
@@ -200,6 +192,16 @@ Tensor& mul_out(
out.numel());
}
} else {
+ // Common Dtype
+ ScalarType common_type =
+ executorch::runtime::promoteTypes(a.scalar_type(), b.scalar_type());
+ // Compute Dtype
+ ScalarType compute_type =
+ torch::executor::native::utils::get_compute_type(common_type);
+ // Check Common Dtype
+ ET_KERNEL_CHECK(
+ ctx, canCast(common_type, out.scalar_type()), InvalidArgument, out);
+
ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
torch::executor::native::utils::apply_bitensor_elementwise_fn<
CTYPE_COMPUTE,
@@ -224,15 +226,7 @@ Tensor& mul_scalar_out(
const Tensor& a,
const Scalar& b,
Tensor& out) {
- // Common Dtype
- ScalarType common_type =
- torch::executor::native::utils::promote_type_with_scalar(
- a.scalar_type(), b);
-
#ifdef OP_ARG_CHECK
- // Check Common Dtype
- ET_KERNEL_CHECK(ctx, common_type == out.scalar_type(), InvalidArgument, out);
-
// Check Dim Order
ET_KERNEL_CHECK(
ctx,
@@ -244,13 +238,23 @@ Tensor& mul_scalar_out(
ET_KERNEL_CHECK(
ctx, resize_tensor(out, a.sizes()) == Error::Ok, InvalidArgument, out);
#endif
- // Compute Dtype
- ScalarType compute_type =
- torch::executor::native::utils::get_compute_type(common_type);
// @lint-ignore CLANGTIDY facebook-hte-CArray
static constexpr const char op_name[] = "mul.Scalar_out";
- if (compute_type == ScalarType::Int) {
+
+ bool optimized = true;
+
+ if (!(((a.scalar_type() == ScalarType::Int) ||
+ (a.scalar_type() == ScalarType::Float)) &&
+ (a.scalar_type() == out.scalar_type()))) {
+ optimized = false;
+ }
+
+ if ((b.isFloatingPoint()) && (a.scalar_type() == ScalarType::Int)) {
+ optimized = false;
+ }
+
+ if ((a.scalar_type() == ScalarType::Int) && (optimized)) {
const int* const inp1_data = a.const_data_ptr<int>();
int inp2_val;
torch::executor::native::utils::extract_scalar(b, &inp2_val);
@@ -264,7 +268,7 @@ Tensor& mul_scalar_out(
inp1_data,
inp2_val,
out.numel());
- } else if (compute_type == ScalarType::Float) {
+ } else if ((a.scalar_type() == ScalarType::Float) && (optimized)) {
const float* const inp1_data = a.const_data_ptr<float>();
float inp2_val;
torch::executor::native::utils::extract_scalar(b, &inp2_val);
@@ -279,6 +283,17 @@ Tensor& mul_scalar_out(
inp2_val,
out.numel());
} else {
+ // Common Dtype
+ ScalarType common_type =
+ torch::executor::native::utils::promote_type_with_scalar(
+ a.scalar_type(), b);
+ // Compute Dtype
+ ScalarType compute_type =
+ torch::executor::native::utils::get_compute_type(common_type);
+ // Check Common Dtype
+ ET_KERNEL_CHECK(
+ ctx, common_type == out.scalar_type(), InvalidArgument, out);
+
ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
const CTYPE_COMPUTE val_b =
torch::executor::native::utils::scalar_to<CTYPE_COMPUTE>(b);
diff --git a/backends/cadence/fusion_g3/operators/op_native_layer_norm.cpp b/backends/cadence/fusion_g3/operators/op_native_layer_norm.cpp
index 9857bbce37..b4f076e810 100644
--- a/backends/cadence/fusion_g3/operators/op_native_layer_norm.cpp
+++ b/backends/cadence/fusion_g3/operators/op_native_layer_norm.cpp
@@ -123,14 +123,7 @@ std::tuple native_layer_norm_out(
std::tuple<Tensor&, Tensor&, Tensor&> ret_val(out, mean_out, rstd_out);
int kTensorDimensionLimit = executorch::runtime::kTensorDimensionLimit;
-
#ifdef OP_ARG_CHECK
- ET_KERNEL_CHECK(
- ctx,
- torch::executor::check_layer_norm_args(
- input, normalized_shape, weight, bias, out, mean_out, rstd_out),
- InvalidArgument,
- ret_val);
// Only support default dim order for now.
// TODO: Support other dim orders.
@@ -189,12 +182,34 @@ std::tuple native_layer_norm_out(
ret_val);
#endif
+ bool optimized = true;
+
int input_shape[kTensorDimensionLimit];
for (int i = 0; i < input.dim(); i++) {
input_shape[i] = input.size(i);
}
- if (out.scalar_type() == ScalarType::Float) {
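+ // The optimized kernel needs float input, out, mean_out, and rstd_out, plus
+ // matching dtypes for weight and bias when they are provided.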
+ if (!(((input.scalar_type() == ScalarType::Float) &&
+ (input.scalar_type() == out.scalar_type()) &&
+ (out.scalar_type() == mean_out.scalar_type()) &&
+ (mean_out.scalar_type() == rstd_out.scalar_type())))) {
+ optimized = false;
+ }
+
+ if (optimized) {
+ if (weight.has_value()) {
+ if (!(input.scalar_type() == weight.value().scalar_type())) {
+ optimized = false;
+ }
+ }
+ if (bias.has_value()) {
+ if (!(input.scalar_type() == bias.value().scalar_type())) {
+ optimized = false;
+ }
+ }
+ }
+
+ if ((input.scalar_type() == ScalarType::Float) && (optimized)) {
float* const out_data = out.mutable_data_ptr<float>();
float* const mean_data = mean_out.mutable_data_ptr<float>();
float* const rstd_data = rstd_out.mutable_data_ptr<float>();
@@ -247,6 +262,13 @@ std::tuple native_layer_norm_out(
free(weight_data);
}
} else {
+ ET_KERNEL_CHECK(
+ ctx,
+ torch::executor::check_layer_norm_args(
+ input, normalized_shape, weight, bias, out, mean_out, rstd_out),
+ InvalidArgument,
+ ret_val);
+
ET_SWITCH_FLOAT_TYPES(
input.scalar_type(), ctx, "native_layer_norm.out", CTYPE, [&]() {
layer_norm(
diff --git a/backends/cadence/fusion_g3/operators/op_permute_copy.cpp b/backends/cadence/fusion_g3/operators/op_permute_copy.cpp
index 23c2d1e5fb..34def4fd1b 100644
--- a/backends/cadence/fusion_g3/operators/op_permute_copy.cpp
+++ b/backends/cadence/fusion_g3/operators/op_permute_copy.cpp
@@ -65,12 +65,6 @@ Tensor& permute_copy_out(
* the checks only in operator level(As there are no checks in kernel).
*/
#ifdef OP_ARG_CHECK
- ET_KERNEL_CHECK(
- ctx,
- torch::executor::check_permute_copy_args(in, dims, out),
- InvalidArgument,
- out);
-
ET_KERNEL_CHECK(
ctx,
executorch::runtime::tensors_have_same_dim_order(in, out),
@@ -112,7 +106,8 @@ Tensor& permute_copy_out(
signed char* out_data = out.mutable_data_ptr<signed char>();
const signed char* const inp_data = in.const_data_ptr<signed char>();
- if (((out.scalar_type() == ScalarType::Int) ||
+ if (((out.scalar_type() == in.scalar_type()) &&
+ (out.scalar_type() == ScalarType::Int) ||
(out.scalar_type() == ScalarType::Short) ||
(out.scalar_type() == ScalarType::Char) ||
(out.scalar_type() == ScalarType::UInt32) ||
@@ -131,9 +126,15 @@ Tensor& permute_copy_out(
in.dim(),
get_element_size(out.scalar_type()));
} else {
+ ET_KERNEL_CHECK(
+ ctx,
+ torch::executor::check_permute_copy_args(in, dims, out),
+ InvalidArgument,
+ out);
+
const auto in_type = out.scalar_type();
- size_t in_coord[5] = {0};
- size_t trailing_dims_memo[kTensorDimensionLimit];
+ size_t in_coord[executorch::runtime::kTensorDimensionLimit] = {0};
+ size_t trailing_dims_memo[executorch::runtime::kTensorDimensionLimit];
executorch::runtime::memoizeTrailingDims(in, trailing_dims_memo);
// in and out must be the same dtype
ET_SWITCH_ALL_TYPES(in_type, ctx, "permute_copy.out", CTYPE, [&] {
diff --git a/backends/cadence/fusion_g3/operators/op_quantize.cpp b/backends/cadence/fusion_g3/operators/op_quantize.cpp
index 8237c3c266..2af77eca6c 100644
--- a/backends/cadence/fusion_g3/operators/op_quantize.cpp
+++ b/backends/cadence/fusion_g3/operators/op_quantize.cpp
@@ -159,6 +159,12 @@ Tensor& quantize_impl(
bool is_asym_quant = 0;
+ bool optimized = true;
+
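+ // The optimized quantize kernels expect float input.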
+ if (input.scalar_type() != ScalarType::Float) {
+ optimized = false;
+ }
+
if (zero_point_data != NULL) // asymmetric quant
{
if (axis != NULL) // channel
@@ -177,7 +183,7 @@ Tensor& quantize_impl(
}
if (is_asym_quant) {
- if (out.scalar_type() == ScalarType::Byte) {
+ if ((out.scalar_type() == ScalarType::Byte) && (optimized)) {
uint8_t* out_data = out.mutable_data_ptr<uint8_t>();
XT_KERNEL_CHECK(
ctx,
@@ -192,7 +198,7 @@ Tensor& quantize_impl(
zero_point_data,
quant_min,
quant_max);
- } else if (out.scalar_type() == ScalarType::Char) {
+ } else if ((out.scalar_type() == ScalarType::Char) && (optimized)) {
int8_t* out_data = out.mutable_data_ptr<int8_t>();
XT_KERNEL_CHECK(
@@ -208,7 +214,7 @@ Tensor& quantize_impl(
zero_point_data,
quant_min,
quant_max);
- } else if (out.scalar_type() == ScalarType::UInt16) {
+ } else if ((out.scalar_type() == ScalarType::UInt16) && (optimized)) {
uint16_t* out_data = out.mutable_data_ptr<uint16_t>();
XT_KERNEL_CHECK(
ctx,
@@ -223,7 +229,7 @@ Tensor& quantize_impl(
zero_point_data,
quant_min,
quant_max);
- } else if (out.scalar_type() == ScalarType::Short) {
+ } else if ((out.scalar_type() == ScalarType::Short) && (optimized)) {
int16_t* out_data = out.mutable_data_ptr<int16_t>();
XT_KERNEL_CHECK(
ctx,
@@ -238,7 +244,7 @@ Tensor& quantize_impl(
zero_point_data,
quant_min,
quant_max);
- } else if (out.scalar_type() == (ScalarType)Bits4u) {
+ } else if ((out.scalar_type() == (ScalarType)Bits4u) && (optimized)) {
uint8_t* out_data = out.mutable_data_ptr<uint8_t>();
XT_KERNEL_CHECK(
ctx,
@@ -253,7 +259,7 @@ Tensor& quantize_impl(
zero_point_data,
quant_min,
quant_max);
- } else if (out.scalar_type() == (ScalarType)Bits4) {
+ } else if ((out.scalar_type() == (ScalarType)Bits4) && (optimized)) {
int8_t* out_data = out.mutable_data_ptr<int8_t>();
XT_KERNEL_CHECK(
ctx,
@@ -391,7 +397,7 @@ Tensor& quantize_impl(
#undef ASYM_QUANTIZE_IMPL_CHANNEL
}
} else {
- if (out.scalar_type() == ScalarType::Byte) {
+ if ((out.scalar_type() == ScalarType::Byte) && (optimized)) {
uint8_t* out_data = out.mutable_data_ptr<uint8_t>();
XT_KERNEL_CHECK(
ctx,
@@ -405,7 +411,7 @@ Tensor& quantize_impl(
scale_data,
quant_min,
quant_max);
- } else if (out.scalar_type() == ScalarType::Char) {
+ } else if ((out.scalar_type() == ScalarType::Char) && (optimized)) {
int8_t* out_data = out.mutable_data_ptr<int8_t>();
XT_KERNEL_CHECK(
ctx,
@@ -419,7 +425,7 @@ Tensor& quantize_impl(
scale_data,
quant_min,
quant_max);
- } else if (out.scalar_type() == ScalarType::UInt16) {
+ } else if ((out.scalar_type() == ScalarType::UInt16) && (optimized)) {
uint16_t* out_data = out.mutable_data_ptr<uint16_t>();
XT_KERNEL_CHECK(
ctx,
@@ -433,7 +439,7 @@ Tensor& quantize_impl(
scale_data,
quant_min,
quant_max);
- } else if (out.scalar_type() == ScalarType::Short) {
+ } else if ((out.scalar_type() == ScalarType::Short) && (optimized)) {
int16_t* out_data = out.mutable_data_ptr<int16_t>();
XT_KERNEL_CHECK(
ctx,
@@ -447,7 +453,7 @@ Tensor& quantize_impl(
scale_data,
quant_min,
quant_max);
- } else if (out.scalar_type() == (ScalarType)Bits4u) {
+ } else if ((out.scalar_type() == (ScalarType)Bits4u) && (optimized)) {
uint8_t* out_data = out.mutable_data_ptr<uint8_t>();
XT_KERNEL_CHECK(
ctx,
@@ -461,7 +467,7 @@ Tensor& quantize_impl(
scale_data,
quant_min,
quant_max);
- } else if (out.scalar_type() == (ScalarType)Bits4) {
+ } else if ((out.scalar_type() == (ScalarType)Bits4) && (optimized)) {
int8_t* out_data = out.mutable_data_ptr<int8_t>();
XT_KERNEL_CHECK(
ctx,
diff --git a/backends/cadence/fusion_g3/operators/op_slice_copy.cpp b/backends/cadence/fusion_g3/operators/op_slice_copy.cpp
index c481cf726b..9158eecf13 100644
--- a/backends/cadence/fusion_g3/operators/op_slice_copy.cpp
+++ b/backends/cadence/fusion_g3/operators/op_slice_copy.cpp
@@ -58,12 +58,6 @@ Tensor& slice_copy_Tensor_out(
int kTensorDimensionLimit = executorch::runtime::kTensorDimensionLimit;
#ifdef OP_ARG_CHECK
- ET_KERNEL_CHECK(
- ctx,
- torch::executor::check_slice_copy_args(in, dim, step, out),
- InvalidArgument,
- out);
-
ET_KERNEL_CHECK(
ctx,
executorch::runtime::tensors_have_same_dim_order(in, out),
@@ -101,12 +95,13 @@ Tensor& slice_copy_Tensor_out(
signed char* out_data = out.mutable_data_ptr<signed char>();
const signed char* const inp_data = in.const_data_ptr<signed char>();
- if ((out.scalar_type() == ScalarType::Int) ||
- (out.scalar_type() == ScalarType::Short) ||
- (out.scalar_type() == ScalarType::Char) ||
- (out.scalar_type() == ScalarType::UInt32) ||
- (out.scalar_type() == ScalarType::UInt16) ||
- (out.scalar_type() == ScalarType::Byte)) {
+ if ((out.scalar_type() == in.scalar_type()) &&
+ ((out.scalar_type() == ScalarType::Int) ||
+ (out.scalar_type() == ScalarType::Short) ||
+ (out.scalar_type() == ScalarType::Char) ||
+ (out.scalar_type() == ScalarType::UInt32) ||
+ (out.scalar_type() == ScalarType::UInt16) ||
+ (out.scalar_type() == ScalarType::Byte))) {
XT_KERNEL_CHECK(
ctx,
out,
@@ -122,6 +117,12 @@ Tensor& slice_copy_Tensor_out(
(int)dim,
get_element_size(out.scalar_type()));
} else {
+ ET_KERNEL_CHECK(
+ ctx,
+ torch::executor::check_slice_copy_args(in, dim, step, out),
+ InvalidArgument,
+ out);
+
torch::executor::compute_slice(in, dim, start, length, step, out);
}
diff --git a/backends/cadence/fusion_g3/operators/op_softmax.cpp b/backends/cadence/fusion_g3/operators/op_softmax.cpp
index ee87ebaf5a..14b128e928 100644
--- a/backends/cadence/fusion_g3/operators/op_softmax.cpp
+++ b/backends/cadence/fusion_g3/operators/op_softmax.cpp
@@ -39,14 +39,7 @@ Tensor& _softmax_out(
// Adjust for negative dim
dim = dim < 0 ? dim + executorch::runtime::nonzero_dim(in) : dim;
-
#ifdef OP_ARG_CHECK
- ET_KERNEL_CHECK(
- ctx,
- torch::executor::check_softmax_args(in, dim, half_to_float, out),
- InvalidArgument,
- out);
-
ET_KERNEL_CHECK(
ctx, resize_tensor(out, in.sizes()) == Error::Ok, InvalidArgument, out);
@@ -63,7 +56,8 @@ Tensor& _softmax_out(
inp_shapes[i] = in_size[i];
}
- if (out.scalar_type() == ScalarType::Float) {
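+ // Require float input and output before dispatching to the optimized kernel.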
+ if ((in.scalar_type() == ScalarType::Float) &&
+ (out.scalar_type() == ScalarType::Float)) {
const float* const inp_data = in.const_data_ptr<float>();
float* const out_data = out.mutable_data_ptr<float>();
int axis = dim;
@@ -77,6 +71,12 @@ Tensor& _softmax_out(
in.dim(),
&axis);
} else {
+ ET_KERNEL_CHECK(
+ ctx,
+ torch::executor::check_softmax_args(in, dim, half_to_float, out),
+ InvalidArgument,
+ out);
+
ET_SWITCH_FLOATH_TYPES(in.scalar_type(), ctx, "_softmax.out", CTYPE, [&]() {
const CTYPE* const in_data = in.const_data_ptr<CTYPE>();
CTYPE* const out_data = out.mutable_data_ptr<CTYPE>();
diff --git a/backends/cadence/fusion_g3/operators/op_sub.cpp b/backends/cadence/fusion_g3/operators/op_sub.cpp
index 4bae81c5b2..9bafec5df9 100644
--- a/backends/cadence/fusion_g3/operators/op_sub.cpp
+++ b/backends/cadence/fusion_g3/operators/op_sub.cpp
@@ -35,19 +35,6 @@ Tensor& sub_out(
const Scalar& alpha,
Tensor& out) {
#ifdef OP_ARG_CHECK
- ScalarType alpha_type =
- torch::executor::native::utils::get_scalar_dtype(alpha);
- // Check alpha type
- ET_KERNEL_CHECK(ctx, alpha_type != ScalarType::Bool, InvalidArgument, out);
-
- // Check Common Dtype
- ET_KERNEL_CHECK(
- ctx,
- (canCast(common_type, out.scalar_type()) &&
- canCast(alpha_type, common_type)),
- InvalidArgument,
- out);
-
// Check Dim Order
ET_KERNEL_CHECK(
ctx,
@@ -72,12 +59,12 @@ Tensor& sub_out(
int inp2_shape[kTensorDimensionLimit];
int out_shape[kTensorDimensionLimit];
- bool broadcast = 0;
+ bool broadcast = false;
int max_dim = a.dim() > b.dim() ? a.dim() : b.dim();
max_dim = out.dim() > max_dim ? out.dim() : max_dim;
- bool optimized = 1;
+ bool optimized = true;
for (int i = 0; i < max_dim; i++) {
out_shape[i] = 1;
@@ -103,16 +90,16 @@ Tensor& sub_out(
for (int i = 0; i < out.dim(); i++) {
if (((inp1_shape[i]) != (out_shape[i])) ||
((inp2_shape[i]) != (out_shape[i]))) {
- broadcast = 1;
+ broadcast = true;
}
}
- if (((broadcast == 1) && (max_dim > kTensorDimensionLimit)) ||
+ if (((broadcast) && (max_dim > kTensorDimensionLimit)) ||
(!(((a.scalar_type() == ScalarType::Int) ||
(a.scalar_type() == ScalarType::Float)) &&
(a.scalar_type() == b.scalar_type()) &&
(a.scalar_type() == out.scalar_type())))) {
- optimized = 0;
+ optimized = false;
}
if ((a.scalar_type() == ScalarType::Int) && (optimized)) {
@@ -207,6 +194,19 @@ Tensor& sub_out(
ScalarType compute_type =
torch::executor::native::utils::get_compute_type(common_type);
+ ScalarType alpha_type =
+ torch::executor::native::utils::get_scalar_dtype(alpha);
+ // Check alpha type
+ ET_KERNEL_CHECK(ctx, alpha_type != ScalarType::Bool, InvalidArgument, out);
+
+ // Check Common Dtype
+ ET_KERNEL_CHECK(
+ ctx,
+ (canCast(common_type, out.scalar_type()) &&
+ canCast(alpha_type, common_type)),
+ InvalidArgument,
+ out);
+
ET_SWITCH_REAL_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
const CTYPE_COMPUTE val_alpha =
torch::executor::native::utils::scalar_to<CTYPE_COMPUTE>(alpha);
@@ -236,18 +236,6 @@ Tensor& sub_scalar_out(
const Scalar& alpha,
Tensor& out) {
#ifdef OP_ARG_CHECK
- ScalarType alpha_type =
- torch::executor::native::utils::get_scalar_dtype(alpha);
- // Check alpha type
- ET_KERNEL_CHECK(ctx, alpha_type != ScalarType::Bool, InvalidArgument, out);
-
- // Check Common Dtype
- ET_KERNEL_CHECK(
- ctx,
- (common_type == out.scalar_type() && canCast(alpha_type, common_type)),
- InvalidArgument,
- out);
-
// Check Dim Order
ET_KERNEL_CHECK(
ctx,
@@ -266,14 +254,16 @@ Tensor& sub_scalar_out(
// @lint-ignore CLANGTIDY facebook-hte-CArray
static constexpr const char op_name[] = "sub.Scalar_out";
- bool optimized = 1;
- ScalarType b_type = torch::executor::native::utils::get_scalar_dtype(b);
+ bool optimized = true;
if (!(((a.scalar_type() == ScalarType::Int) ||
(a.scalar_type() == ScalarType::Float)) &&
- (a.scalar_type() == b_type) &&
(a.scalar_type() == out.scalar_type()))) {
- optimized = 0;
+ optimized = false;
+ }
+
+ if ((b.isFloatingPoint()) && (a.scalar_type() == ScalarType::Int)) {
+ optimized = false;
}
if ((a.scalar_type() == ScalarType::Int) && (optimized)) {
@@ -322,6 +312,19 @@ Tensor& sub_scalar_out(
// Compute Dtype
ScalarType compute_type =
torch::executor::native::utils::get_compute_type(common_type);
+
+ ScalarType alpha_type =
+ torch::executor::native::utils::get_scalar_dtype(alpha);
+ // Check alpha type
+ ET_KERNEL_CHECK(ctx, alpha_type != ScalarType::Bool, InvalidArgument, out);
+
+ // Check Common Dtype
+ ET_KERNEL_CHECK(
+ ctx,
+ (common_type == out.scalar_type() && canCast(alpha_type, common_type)),
+ InvalidArgument,
+ out);
+
ET_SWITCH_REAL_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
const CTYPE_COMPUTE val_b =
torch::executor::native::utils::scalar_to<CTYPE_COMPUTE>(b);
diff --git a/backends/cadence/hifi/operators/op_add.cpp b/backends/cadence/hifi/operators/op_add.cpp
index ec0e48e379..3a590ea071 100644
--- a/backends/cadence/hifi/operators/op_add.cpp
+++ b/backends/cadence/hifi/operators/op_add.cpp
@@ -16,9 +16,9 @@
#include
#include
-using exec_aten::Scalar;
-using exec_aten::ScalarType;
-using exec_aten::Tensor;
+using executorch::aten::Scalar;
+using executorch::aten::ScalarType;
+using executorch::aten::Tensor;
using executorch::runtime::can_cast;
using executorch::runtime::CppTypeToScalarType;
using executorch::runtime::KernelRuntimeContext;
diff --git a/backends/cadence/hifi/operators/op_cat.cpp b/backends/cadence/hifi/operators/op_cat.cpp
index e367d71b79..8ad52753de 100644
--- a/backends/cadence/hifi/operators/op_cat.cpp
+++ b/backends/cadence/hifi/operators/op_cat.cpp
@@ -30,7 +30,7 @@ namespace native {
Tensor& cat_out(
RuntimeContext& ctx,
- exec_aten::ArrayRef tensors,
+ executorch::aten::ArrayRef<Tensor> tensors,
int64_t dim,
Tensor& out) {
if (dim < 0) {
diff --git a/backends/cadence/hifi/operators/op_clamp.cpp b/backends/cadence/hifi/operators/op_clamp.cpp
index d31161a7d5..4fa29c00dd 100644
--- a/backends/cadence/hifi/operators/op_clamp.cpp
+++ b/backends/cadence/hifi/operators/op_clamp.cpp
@@ -51,8 +51,8 @@ namespace native {
Tensor& clamp_tensor_out(
RuntimeContext& ctx,
const Tensor& in,
- const exec_aten::optional& min_opt,
- const exec_aten::optional& max_opt,
+ const executorch::aten::optional<Tensor>& min_opt,
+ const executorch::aten::optional<Tensor>& max_opt,
Tensor& out) {
(void)ctx;
diff --git a/backends/cadence/hifi/operators/dequantize_per_tensor.cpp b/backends/cadence/hifi/operators/op_dequantize_per_tensor.cpp
similarity index 100%
rename from backends/cadence/hifi/operators/dequantize_per_tensor.cpp
rename to backends/cadence/hifi/operators/op_dequantize_per_tensor.cpp
diff --git a/backends/cadence/hifi/operators/op_div.cpp b/backends/cadence/hifi/operators/op_div.cpp
index 05f3db7ec3..816422858b 100644
--- a/backends/cadence/hifi/operators/op_div.cpp
+++ b/backends/cadence/hifi/operators/op_div.cpp
@@ -17,10 +17,10 @@
#include
#include
-using exec_aten::Scalar;
-using exec_aten::ScalarType;
-using exec_aten::Tensor;
using executorch::aten::RuntimeContext;
+using executorch::aten::Scalar;
+using executorch::aten::ScalarType;
+using executorch::aten::Tensor;
using torch::executor::Error;
namespace cadence {
@@ -165,7 +165,7 @@ Tensor& div_out_mode(
RuntimeContext& ctx,
const Tensor& a,
const Tensor& b,
- exec_aten::optional mode,
+ executorch::aten::optional mode,
Tensor& out) {
ET_KERNEL_CHECK(
ctx,
diff --git a/backends/cadence/hifi/operators/op_maximum.cpp b/backends/cadence/hifi/operators/op_maximum.cpp
index f85d3470e9..592ea3bc1e 100644
--- a/backends/cadence/hifi/operators/op_maximum.cpp
+++ b/backends/cadence/hifi/operators/op_maximum.cpp
@@ -12,9 +12,9 @@
#include
#include
-using exec_aten::ScalarType;
-using exec_aten::Tensor;
using executorch::aten::RuntimeContext;
+using executorch::aten::ScalarType;
+using executorch::aten::Tensor;
using executorch::runtime::can_cast;
using executorch::runtime::canCast;
using executorch::runtime::CppTypeToScalarType;
diff --git a/backends/cadence/hifi/operators/op_mean.cpp b/backends/cadence/hifi/operators/op_mean.cpp
index 342c982a07..82fa7502de 100644
--- a/backends/cadence/hifi/operators/op_mean.cpp
+++ b/backends/cadence/hifi/operators/op_mean.cpp
@@ -56,7 +56,7 @@ int prepare_data(
return num_axis_dims;
}
-Tensor& mean_dim_out(
+Tensor& mean_out(
RuntimeContext& ctx,
const Tensor& in,
optional<ArrayRef<int64_t>> dim_list,
diff --git a/backends/cadence/hifi/operators/op_minimum.cpp b/backends/cadence/hifi/operators/op_minimum.cpp
index 6f81ad5c3e..b78ee64882 100644
--- a/backends/cadence/hifi/operators/op_minimum.cpp
+++ b/backends/cadence/hifi/operators/op_minimum.cpp
@@ -12,9 +12,9 @@
#include
#include
-using exec_aten::ScalarType;
-using exec_aten::Tensor;
using executorch::aten::RuntimeContext;
+using executorch::aten::ScalarType;
+using executorch::aten::Tensor;
using executorch::runtime::can_cast;
using executorch::runtime::canCast;
using executorch::runtime::CppTypeToScalarType;
diff --git a/backends/cadence/hifi/operators/op_mul.cpp b/backends/cadence/hifi/operators/op_mul.cpp
index 396833dd1a..b8c3ab7c02 100644
--- a/backends/cadence/hifi/operators/op_mul.cpp
+++ b/backends/cadence/hifi/operators/op_mul.cpp
@@ -15,10 +15,10 @@
#include
#include
-using exec_aten::Scalar;
-using exec_aten::ScalarType;
-using exec_aten::Tensor;
using executorch::aten::RuntimeContext;
+using executorch::aten::Scalar;
+using executorch::aten::ScalarType;
+using executorch::aten::Tensor;
using executorch::runtime::can_cast;
using executorch::runtime::CppTypeToScalarType;
using torch::executor::Error;
diff --git a/backends/cadence/hifi/operators/quantize_per_tensor.cpp b/backends/cadence/hifi/operators/op_quantize_per_tensor.cpp
similarity index 100%
rename from backends/cadence/hifi/operators/quantize_per_tensor.cpp
rename to backends/cadence/hifi/operators/op_quantize_per_tensor.cpp
diff --git a/backends/cadence/hifi/operators/quantized_layer_norm.cpp b/backends/cadence/hifi/operators/op_quantized_layer_norm.cpp
similarity index 100%
rename from backends/cadence/hifi/operators/quantized_layer_norm.cpp
rename to backends/cadence/hifi/operators/op_quantized_layer_norm.cpp
diff --git a/backends/cadence/hifi/operators/quantized_linear_out.cpp b/backends/cadence/hifi/operators/op_quantized_linear_out.cpp
similarity index 97%
rename from backends/cadence/hifi/operators/quantized_linear_out.cpp
rename to backends/cadence/hifi/operators/op_quantized_linear_out.cpp
index b8e1d117fb..3d9983b40c 100644
--- a/backends/cadence/hifi/operators/quantized_linear_out.cpp
+++ b/backends/cadence/hifi/operators/op_quantized_linear_out.cpp
@@ -219,7 +219,7 @@ void quantized_linear_out(
int64_t out_zero_point,
__ET_UNUSED const optional<Tensor>& offset,
Tensor& out) {
- if (out.scalar_type() == exec_aten::ScalarType::Byte) {
+ if (out.scalar_type() == executorch::aten::ScalarType::Byte) {
_quantized_linear_asym8u(
in,
weight,
@@ -231,7 +231,7 @@ void quantized_linear_out(
out_zero_point,
offset,
out);
- } else if (out.scalar_type() == exec_aten::ScalarType::Char) {
+ } else if (out.scalar_type() == executorch::aten::ScalarType::Char) {
_quantized_linear_asym8s(
in,
weight,
@@ -261,7 +261,7 @@ void quantized_linear_per_tensor_out(
int64_t out_zero_point,
__ET_UNUSED const optional& offset,
Tensor& out) {
- if (out.scalar_type() == exec_aten::ScalarType::Byte) {
+ if (out.scalar_type() == executorch::aten::ScalarType::Byte) {
_quantized_linear_per_tensor_asym8u(
in,
weight,
@@ -273,7 +273,7 @@ void quantized_linear_per_tensor_out(
out_zero_point,
offset,
out);
- } else if (out.scalar_type() == exec_aten::ScalarType::Char) {
+ } else if (out.scalar_type() == executorch::aten::ScalarType::Char) {
_quantized_linear_per_tensor_asym8s(
in,
weight,
diff --git a/backends/cadence/hifi/operators/quantized_relu_out.cpp b/backends/cadence/hifi/operators/op_quantized_relu_out.cpp
similarity index 98%
rename from backends/cadence/hifi/operators/quantized_relu_out.cpp
rename to backends/cadence/hifi/operators/op_quantized_relu_out.cpp
index d78e555ad1..0860109f7c 100644
--- a/backends/cadence/hifi/operators/quantized_relu_out.cpp
+++ b/backends/cadence/hifi/operators/op_quantized_relu_out.cpp
@@ -45,7 +45,7 @@ void quantized_relu_(
}
}
-void quantized_relu_out(
+void quantized_relu_per_tensor_out(
KernelRuntimeContext& ctx,
const Tensor& input,
const Tensor& in_zero_point,
@@ -100,4 +100,4 @@ void quantized_relu_out(
} // namespace native
} // namespace HiFi
} // namespace impl
-} // namespace cadence
\ No newline at end of file
+} // namespace cadence
diff --git a/backends/cadence/hifi/operators/op_remainder.cpp b/backends/cadence/hifi/operators/op_remainder.cpp
index d8c4a6d2d8..99cd6ad544 100644
--- a/backends/cadence/hifi/operators/op_remainder.cpp
+++ b/backends/cadence/hifi/operators/op_remainder.cpp
@@ -8,6 +8,7 @@
#include
+#include
#include
#include
#include
@@ -15,8 +16,6 @@
#include
#include
-#include "kernels.h"
-
using executorch::aten::RuntimeContext;
using executorch::aten::Scalar;
using executorch::aten::ScalarType;
diff --git a/backends/cadence/hifi/operators/op_rsqrt.cpp b/backends/cadence/hifi/operators/op_rsqrt.cpp
index 1cf717988a..885c26723a 100644
--- a/backends/cadence/hifi/operators/op_rsqrt.cpp
+++ b/backends/cadence/hifi/operators/op_rsqrt.cpp
@@ -11,9 +11,9 @@
#include
-using exec_aten::ScalarType;
-using exec_aten::Tensor;
using executorch::aten::RuntimeContext;
+using executorch::aten::ScalarType;
+using executorch::aten::Tensor;
namespace cadence {
namespace impl {
diff --git a/backends/cadence/hifi/operators/op_sigmoid.cpp b/backends/cadence/hifi/operators/op_sigmoid.cpp
index 35321cc27e..872d9255bd 100644
--- a/backends/cadence/hifi/operators/op_sigmoid.cpp
+++ b/backends/cadence/hifi/operators/op_sigmoid.cpp
@@ -14,9 +14,9 @@
#include
#include
-using exec_aten::ScalarType;
-using exec_aten::Tensor;
using executorch::aten::RuntimeContext;
+using executorch::aten::ScalarType;
+using executorch::aten::Tensor;
using torch::executor::Error;
namespace cadence {
@@ -24,7 +24,7 @@ namespace impl {
namespace HiFi {
namespace native {
-using Tensor = exec_aten::Tensor;
+using Tensor = executorch::aten::Tensor;
Tensor& sigmoid_out(RuntimeContext& ctx, const Tensor& in, Tensor& out) {
(void)ctx;
diff --git a/backends/cadence/hifi/operators/op_softmax.cpp b/backends/cadence/hifi/operators/op_softmax.cpp
index e026afd2c9..2ef233c9ff 100644
--- a/backends/cadence/hifi/operators/op_softmax.cpp
+++ b/backends/cadence/hifi/operators/op_softmax.cpp
@@ -8,11 +8,11 @@
#include
+#include
#include
#include
#include
#include
-#include "kernels.h"
using executorch::aten::ScalarType;
using executorch::aten::Tensor;
@@ -24,7 +24,7 @@ namespace impl {
namespace HiFi {
namespace native {
-Tensor& softmax_out(
+Tensor& _softmax_out(
KernelRuntimeContext& ctx,
const Tensor& in,
int64_t dim,
@@ -50,7 +50,7 @@ Tensor& softmax_out(
// Adjust for negative dim
dim = dim < 0 ? dim + executorch::runtime::nonzero_dim(in) : dim;
- const exec_aten::optional& dim_t = dim;
+ const executorch::aten::optional& dim_t = dim;
const size_t d = ET_NORMALIZE_IX(dim_t.value(), in.dim());
const size_t size = in.size(d);
diff --git a/backends/cadence/hifi/operators/op_sub.cpp b/backends/cadence/hifi/operators/op_sub.cpp
index cf10e41435..02c8c60eac 100644
--- a/backends/cadence/hifi/operators/op_sub.cpp
+++ b/backends/cadence/hifi/operators/op_sub.cpp
@@ -16,10 +16,10 @@
#include
#include
-using exec_aten::Scalar;
-using exec_aten::ScalarType;
-using exec_aten::Tensor;
using executorch::aten::RuntimeContext;
+using executorch::aten::Scalar;
+using executorch::aten::ScalarType;
+using executorch::aten::Tensor;
using executorch::runtime::can_cast;
using executorch::runtime::CppTypeToScalarType;
using torch::executor::Error;
diff --git a/backends/cadence/hifi/operators/op_tanh.cpp b/backends/cadence/hifi/operators/op_tanh.cpp
index 13578beb88..3fdd3111ef 100644
--- a/backends/cadence/hifi/operators/op_tanh.cpp
+++ b/backends/cadence/hifi/operators/op_tanh.cpp
@@ -11,9 +11,9 @@
#include
#include
-using exec_aten::ScalarType;
-using exec_aten::Tensor;
using executorch::aten::RuntimeContext;
+using executorch::aten::ScalarType;
+using executorch::aten::Tensor;
using torch::executor::Error;
namespace cadence {
diff --git a/backends/cadence/hifi/operators/targets.bzl b/backends/cadence/hifi/operators/targets.bzl
index 6c671a5f24..1c2b481410 100644
--- a/backends/cadence/hifi/operators/targets.bzl
+++ b/backends/cadence/hifi/operators/targets.bzl
@@ -1,243 +1,70 @@
load("@fbsource//tools/build_defs:platform_defs.bzl", "CXX")
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
-def define_common_targets():
- """Defines targets that should be shared between fbcode and xplat.
-
- The directory containing this targets.bzl file should also contain both
- TARGETS and BUCK files that call this function.
- """
-
- # Define build targets for all operators registered in the tables above.
- runtime.cxx_library(
- name = "quantize_per_tensor",
- srcs = [
- "quantize_per_tensor.cpp"
- ],
- platforms = CXX,
- deps = [
- "//executorch/kernels/portable/cpu/util:all_deps",
- "//executorch/kernels/portable/cpu/pattern:all_deps",
- "//executorch/runtime/kernel:kernel_includes",
- "//executorch/kernels/portable/cpu:scalar_utils",
- "//executorch/backends/cadence/hifi/kernels:kernels",
- "//executorch/backends/cadence/hifi/third-party/nnlib:nnlib-extensions"
- ],
- visibility = [
- "//executorch/backends/cadence/...",
- "@EXECUTORCH_CLIENTS",
- ],
- )
+def define_operator(name: str, deps: list[str] | None = None) -> None:
+ op_name = "op_{}".format(name)
- runtime.cxx_library(
- name = "dequantize_per_tensor",
- srcs = [
- "dequantize_per_tensor.cpp"
- ],
- platforms = CXX,
- deps = [
- "//executorch/kernels/portable/cpu/util:all_deps",
- "//executorch/kernels/portable/cpu/pattern:all_deps",
- "//executorch/runtime/kernel:kernel_includes",
- "//executorch/kernels/portable/cpu:scalar_utils",
- "//executorch/backends/cadence/hifi/kernels:kernels",
- "//executorch/backends/cadence/hifi/third-party/nnlib:nnlib-extensions"
- ],
- visibility = [
- "//executorch/backends/cadence/...",
- "@EXECUTORCH_CLIENTS",
- ],
- )
+ # Deps used by all operators.
+ common_deps = [
+ "//executorch/kernels/portable/cpu/util:all_deps",
+ "//executorch/kernels/portable/cpu/pattern:all_deps",
+ "//executorch/runtime/kernel:kernel_includes",
+ "//executorch/kernels/portable/cpu:scalar_utils",
+ "//executorch/backends/cadence/hifi/kernels:kernels",
+ "//executorch/kernels/portable/cpu/util:dtype_util",
+ "//executorch/kernels/portable/cpu/util:elementwise_util",
+ "//executorch/backends/cadence/hifi/third-party/nnlib:nnlib-extensions"
+ ]
+ if deps == None:
+ deps = []
runtime.cxx_library(
- name = "quantized_layer_norm",
- srcs = [
- "quantized_layer_norm.cpp"
- ],
- exported_headers = ["operators.h"],
+ name = op_name,
+ srcs = [op_name + ".cpp"],
platforms = CXX,
- deps = [
- "//executorch/kernels/portable/cpu/util:all_deps",
- "//executorch/kernels/portable/cpu/pattern:all_deps",
- "//executorch/runtime/kernel:kernel_includes",
- "//executorch/kernels/portable/cpu:scalar_utils",
- "//executorch/backends/cadence/hifi/kernels:kernels",
- "//executorch/backends/cadence/hifi/third-party/nnlib:nnlib-extensions"
- ],
visibility = [
"//executorch/backends/cadence/...",
"@EXECUTORCH_CLIENTS",
],
- )
-
- runtime.cxx_library(
- name = "quantized_linear_out",
- srcs = [
- "quantized_linear_out.cpp"
- ],
+ deps = deps + common_deps,
exported_headers = ["operators.h"],
- platforms = CXX,
- deps = [
- "//executorch/kernels/portable/cpu/util:all_deps",
- "//executorch/kernels/portable/cpu/pattern:all_deps",
- "//executorch/runtime/kernel:kernel_includes",
- "//executorch/kernels/portable/cpu:scalar_utils",
- "//executorch/backends/cadence/hifi/kernels:kernels",
- "//executorch/backends/cadence/hifi/third-party/nnlib:nnlib-extensions"
- ],
- visibility = [
- "//executorch/backends/cadence/...",
- "@EXECUTORCH_CLIENTS",
- ],
- )
-
- runtime.cxx_library(
- name = "op_add",
- srcs = [
- "op_add.cpp",
- ],
- platforms = CXX,
- deps = [
- "//executorch/kernels/portable/cpu/util:all_deps",
- "//executorch/kernels/portable/cpu/pattern:all_deps",
- "//executorch/runtime/kernel:kernel_includes",
- "//executorch/kernels/portable/cpu:scalar_utils",
- "//executorch/backends/cadence/hifi/kernels:kernels",
- "//executorch/backends/cadence/hifi/third-party/nnlib:nnlib-extensions",
- "//executorch/kernels/portable/cpu/util:dtype_util",
- "//executorch/kernels/portable/cpu/util:elementwise_util",
- ],
- visibility = [
- "//executorch/backends/cadence/...",
- "@EXECUTORCH_CLIENTS",
- ],
- )
-
-
- runtime.cxx_library(
- name = "op_mul",
- srcs = [
- "op_mul.cpp",
- ],
- platforms = CXX,
- deps = [
- "//executorch/kernels/portable/cpu/util:all_deps",
- "//executorch/kernels/portable/cpu/pattern:all_deps",
- "//executorch/runtime/kernel:kernel_includes",
- "//executorch/kernels/portable/cpu:scalar_utils",
- "//executorch/backends/cadence/hifi/kernels:kernels",
- "//executorch/kernels/portable/cpu/util:dtype_util",
- "//executorch/kernels/portable/cpu/util:elementwise_util",
- "//executorch/backends/cadence/hifi/third-party/nnlib:nnlib-extensions"
- ],
- visibility = [
- "//executorch/backends/cadence/...",
- "@EXECUTORCH_CLIENTS",
- ],
- )
-
- runtime.cxx_library(
- name = "op_sub",
- srcs = [
- "op_sub.cpp",
- ],
- platforms = CXX,
- deps = [
- "//executorch/kernels/portable/cpu/util:all_deps",
- "//executorch/kernels/portable/cpu/pattern:all_deps",
- "//executorch/runtime/kernel:kernel_includes",
- "//executorch/kernels/portable/cpu:scalar_utils",
- "//executorch/backends/cadence/hifi/kernels:kernels",
- "//executorch/kernels/portable/cpu/util:dtype_util",
- "//executorch/kernels/portable/cpu/util:elementwise_util",
- "//executorch/backends/cadence/hifi/third-party/nnlib:nnlib-extensions"
- ],
- visibility = [
- "//executorch/backends/cadence/...",
- "@EXECUTORCH_CLIENTS",
- ],
)
- runtime.cxx_library(
- name = "op_div",
- srcs = [
- "op_div.cpp",
- ],
- platforms = CXX,
- deps = [
- "//executorch/kernels/portable/cpu/util:all_deps",
- "//executorch/kernels/portable/cpu/pattern:all_deps",
- "//executorch/runtime/kernel:kernel_includes",
- "//executorch/kernels/portable/cpu:scalar_utils",
- "//executorch/backends/cadence/hifi/kernels:kernels",
- "//executorch/kernels/portable/cpu/util:dtype_util",
- "//executorch/kernels/portable/cpu/util:elementwise_util",
- "//executorch/backends/cadence/hifi/third-party/nnlib:nnlib-extensions"
- ],
- visibility = [
- "//executorch/backends/cadence/...",
- "@EXECUTORCH_CLIENTS",
- ],
- )
+OPERATORS = [
+ "add",
+ "atan2",
+ "cat",
+ "clamp",
+ "dequantize_per_tensor",
+ "div",
+ "full",
+ "maximum",
+ "mean",
+ "minimum",
+ "mul",
+ "permute_copy",
+ "pow",
+ "quantize_per_tensor",
+ "quantized_layer_norm",
+ "quantized_linear_out",
+ "quantized_relu_out",
+ "remainder",
+ "rsqrt",
+ "sigmoid",
+ "softmax",
+ "sub",
+ "tanh",
+ "where"
+]
- runtime.cxx_library(
- name = "op_sigmoid",
- srcs = [
- "op_sigmoid.cpp",
- ],
- platforms = CXX,
- deps = [
- "//executorch/kernels/portable/cpu/util:all_deps",
- "//executorch/kernels/portable/cpu/pattern:all_deps",
- "//executorch/runtime/kernel:kernel_includes",
- "//executorch/backends/cadence/hifi/kernels:kernels",
- "//executorch/kernels/portable/cpu/util:dtype_util",
- "//executorch/kernels/portable/cpu/util:elementwise_util",
- "//executorch/backends/cadence/hifi/third-party/nnlib:nnlib-extensions"
- ],
- visibility = [
- "//executorch/backends/cadence/...",
- "@EXECUTORCH_CLIENTS",
- ],
- )
+def define_common_targets():
+ """Defines targets that should be shared between fbcode and xplat.
- runtime.cxx_library(
- name = "op_tanh",
- srcs = [
- "op_tanh.cpp",
- ],
- platforms = CXX,
- deps = [
- "//executorch/kernels/portable/cpu/util:all_deps",
- "//executorch/kernels/portable/cpu/pattern:all_deps",
- "//executorch/runtime/kernel:kernel_includes",
- "//executorch/backends/cadence/hifi/kernels:kernels",
- "//executorch/backends/cadence/hifi/third-party/nnlib:nnlib-extensions"
- ],
- visibility = [
- "//executorch/backends/cadence/...",
- "@EXECUTORCH_CLIENTS",
- ],
- )
+ The directory containing this targets.bzl file should also contain both
+ TARGETS and BUCK files that call this function.
+ """
-
- runtime.cxx_library(
- name = "op_where",
- srcs = [
- "op_where.cpp",
- ],
- platforms = CXX,
- deps = [
- "//executorch/kernels/portable/cpu/util:all_deps",
- "//executorch/kernels/portable/cpu/pattern:all_deps",
- "//executorch/runtime/kernel:kernel_includes",
- "//executorch/backends/cadence/hifi/kernels:kernels",
- "//executorch/kernels/portable/cpu/util:elementwise_util",
- "//executorch/backends/cadence/hifi/third-party/nnlib:nnlib-extensions"
- ],
- visibility = [
- "//executorch/backends/cadence/...",
- "@EXECUTORCH_CLIENTS",
- ],
- )
+ # Define build targets for all operators registered in the tables above.
+ for op in OPERATORS:
+ define_operator(op)
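
The refactor above replaces a couple hundred lines of per-operator boilerplate with a single `define_operator` helper driven by the `OPERATORS` list. A minimal plain-Python sketch of the expansion it performs is shown below; `expand_operator` and the `__main__` check are hypothetical and only mirror the keyword arguments the Starlark `runtime.cxx_library` call receives.

```python
# Sketch (illustrative only, not part of the build files) of what
# define_operator("add") expands to, mirrored in plain Python data structures.

COMMON_DEPS = [
    "//executorch/kernels/portable/cpu/util:all_deps",
    "//executorch/kernels/portable/cpu/pattern:all_deps",
    "//executorch/runtime/kernel:kernel_includes",
    "//executorch/kernels/portable/cpu:scalar_utils",
    "//executorch/backends/cadence/hifi/kernels:kernels",
    "//executorch/kernels/portable/cpu/util:dtype_util",
    "//executorch/kernels/portable/cpu/util:elementwise_util",
    "//executorch/backends/cadence/hifi/third-party/nnlib:nnlib-extensions",
]

def expand_operator(name, deps=None):
    """Return the keyword arguments the cxx_library rule would receive."""
    deps = deps or []
    return {
        "name": "op_" + name,
        "srcs": ["op_" + name + ".cpp"],
        "exported_headers": ["operators.h"],
        "deps": deps + COMMON_DEPS,
        "visibility": ["//executorch/backends/cadence/...", "@EXECUTORCH_CLIENTS"],
    }

if __name__ == "__main__":
    target = expand_operator("add")
    assert target["name"] == "op_add" and target["srcs"] == ["op_add.cpp"]
```

Per-operator extra dependencies can still be threaded through the optional `deps` argument, which is prepended to the common dependency list.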
diff --git a/backends/cadence/hifi/third-party/nnlib/xa_nn_elm_minimum_maximum_f32.c b/backends/cadence/hifi/third-party/nnlib/xa_nn_elm_minimum_maximum_f32.c
index 50d24c8bae..7d95e536c9 100644
--- a/backends/cadence/hifi/third-party/nnlib/xa_nn_elm_minimum_maximum_f32.c
+++ b/backends/cadence/hifi/third-party/nnlib/xa_nn_elm_minimum_maximum_f32.c
@@ -843,4 +843,3 @@ WORD32 xa_nn_elm_minimum_broadcast_4D_f32xf32_f32(FLOAT32 * __restrict__ p_out,
}
#endif
-
diff --git a/backends/cadence/hifi/third-party/nnlib/xa_nn_transpose_8.c b/backends/cadence/hifi/third-party/nnlib/xa_nn_transpose_8.c
new file mode 100644
index 0000000000..b069035dc9
--- /dev/null
+++ b/backends/cadence/hifi/third-party/nnlib/xa_nn_transpose_8.c
@@ -0,0 +1,232 @@
+/*******************************************************************************
+* Copyright (c) 2018-2024 Cadence Design Systems, Inc.
+*
+* Permission is hereby granted, free of charge, to any person obtaining
+* a copy of this software and associated documentation files (the
+* "Software"), to use this Software with Cadence processor cores only and
+* not with any other processors and platforms, subject to
+* the following conditions:
+*
+* The above copyright notice and this permission notice shall be included
+* in all copies or substantial portions of the Software.
+*
+* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
+* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
+* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
+* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
+* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+
+******************************************************************************/
+#include "xa_nnlib_common.h"
+
+#include
+
+/*
+ * Currently supports only up to 5D input tensors.
+ * 1/2/3/4 D input tensors will be scaled up to 5D.
+ * For example, 2x3 -> 1x1x1x2x3.
+ */
+
+WORD32 xa_nn_transpose_8_8(WORD8 * __restrict__ p_out
+ ,const WORD32 *const p_out_shape
+ ,const WORD8 * __restrict__ p_inp
+ ,const WORD32 *const p_inp_shape
+ ,const WORD32 * __restrict__ p_permute_vec
+ ,WORD32 num_out_dims
+ ,WORD32 num_inp_dims)
+{
+ /* NULL pointer checks */
+ XA_NNLIB_ARG_CHK_PTR(p_out, -1);
+ XA_NNLIB_ARG_CHK_PTR(p_inp, -1);
+ XA_NNLIB_ARG_CHK_PTR(p_permute_vec, -1);
+ XA_NNLIB_ARG_CHK_PTR(p_out_shape, -1);
+ XA_NNLIB_ARG_CHK_PTR(p_inp_shape, -1);
+
+ /* Invalid input checks */
+ XA_NNLIB_ARG_CHK_COND(((num_inp_dims <= 0) || (num_inp_dims > 5)), -1);
+ XA_NNLIB_ARG_CHK_COND((num_out_dims != num_inp_dims), -1);
+
+ int itr = 0;
+ for(itr=0; itr < num_inp_dims; itr++)
+ {
+ XA_NNLIB_ARG_CHK_COND((p_inp_shape[itr] <= 0), -1);
+ }
+ for(itr=0; itr < num_out_dims; itr++)
+ {
+ XA_NNLIB_ARG_CHK_COND((p_out_shape[itr] <= 0), -1);
+ }
+
+ /* Output shape provided must be correct based on input
+ * shape and permute values */
+ for(itr=0; itr < num_out_dims; itr++)
+ {
+ int output_dim = p_out_shape[itr];
+ int expected_dim = p_inp_shape[p_permute_vec[itr]];
+ XA_NNLIB_ARG_CHK_COND((output_dim != expected_dim), -1);
+ }
+
+ /* Pointer alignment checks */
+ XA_NNLIB_ARG_CHK_ALIGN(p_out, sizeof(WORD8), -1);
+ XA_NNLIB_ARG_CHK_ALIGN(p_inp, sizeof(WORD8), -1);
+ XA_NNLIB_ARG_CHK_ALIGN(p_permute_vec, sizeof(WORD32), -1);
+ XA_NNLIB_ARG_CHK_ALIGN(p_out_shape, sizeof(WORD32), -1);
+ XA_NNLIB_ARG_CHK_ALIGN(p_inp_shape, sizeof(WORD32), -1);
+
+ /* Shift all dim with 1 in the outer part */
+ int eff_output_shape[5];
+ int eff_permute_vec[5];
+
+ for(int i = 0; i < num_out_dims; i++)
+ {
+ eff_output_shape[i] = p_out_shape[i];
+ eff_permute_vec[i] = p_permute_vec[i];
+ }
+
+ int one_i=num_out_dims-1, non_one_i=num_out_dims-1;
+ while(one_i > 0 && non_one_i >=0){
+ while(one_i > 0 && eff_output_shape[one_i]!=1){
+ one_i--;
+ }
+ non_one_i = one_i;
+ while(non_one_i >= 0 && eff_output_shape[non_one_i]==1)
+ {
+ non_one_i--;
+ }
+ if(one_i > 0 && non_one_i >=0){
+ int temp;
+ /*swap output_shape*/
+ {
+ temp = eff_output_shape[one_i];
+ eff_output_shape[one_i] = eff_output_shape[non_one_i];
+ eff_output_shape[non_one_i] = temp;
+ }
+ /*swap permute_vec*/
+ {
+ temp = eff_permute_vec[one_i];
+ eff_permute_vec[one_i] = eff_permute_vec[non_one_i];
+ eff_permute_vec[non_one_i] = temp;
+ }
+
+ }
+ }
+
+
+ /* Promoting lesser dim tensors to 5D tensors.
+ * Also updating the permute_vec and shapes as needed for optimization */
+ int p_5D_inp_shape[5] = {1, 1, 1, 1, 1};
+ int p_5D_out_shape[5] = {1, 1, 1, 1, 1};
+ int p_5D_permute_vec[5] = {0, 1, 2, 3, 4};
+
+ /* Check if any inner inp dimension is same in the output */
+ int last_dim_same = 1, last_n_same_dim = 0;
+ itr = num_inp_dims - 1;
+ while(itr >= 0)
+ {
+ last_n_same_dim = (last_dim_same && (eff_permute_vec[itr] == itr)) ? (last_n_same_dim + 1) : last_n_same_dim;
+ last_dim_same = (eff_permute_vec[itr] == itr) ? last_dim_same & 1 : last_dim_same & 0;
+ itr--;
+ }
+
+ int dims_added = 5 - num_inp_dims;
+ itr = num_inp_dims - 1;
+ int same_count = last_n_same_dim;
+ int count = 4;
+ while(itr >= 0)
+ {
+ p_5D_inp_shape[count] = (same_count > 0) ? p_5D_inp_shape[count]*p_inp_shape[itr] : p_inp_shape[itr];
+ p_5D_out_shape[count] = (same_count > 0) ? p_5D_out_shape[count]*eff_output_shape[itr] : eff_output_shape[itr];
+ same_count--;
+ itr--;
+ count = (same_count > 0) ? count : count - 1;
+ }
+
+ itr = num_inp_dims - 1;
+ same_count = (last_n_same_dim) ? num_inp_dims - (last_n_same_dim - 1) : 0;
+ count = 4;
+ while(itr >= 0)
+ {
+ p_5D_permute_vec[count] = (same_count > 0) ? eff_permute_vec[itr-(last_n_same_dim - 1)] + dims_added + last_n_same_dim - 1 : eff_permute_vec[itr] + dims_added;
+ same_count--;
+ itr--;
+ count--;
+ }
+
+ int out_dim0, out_dim1, out_dim2, out_dim3, out_dim4;
+ int inp_dim1, inp_dim2, inp_dim3, inp_dim4;
+ int inp_stride[5];
+
+ out_dim0 = p_5D_out_shape[0];
+ out_dim1 = p_5D_out_shape[1];
+ out_dim2 = p_5D_out_shape[2];
+ out_dim3 = p_5D_out_shape[3];
+ out_dim4 = p_5D_out_shape[4];
+
+ inp_dim1 = p_5D_inp_shape[1];
+ inp_dim2 = p_5D_inp_shape[2];
+ inp_dim3 = p_5D_inp_shape[3];
+ inp_dim4 = p_5D_inp_shape[4];
+
+ inp_stride[0] = inp_dim1*inp_dim2*inp_dim3*inp_dim4;
+ inp_stride[1] = inp_dim2*inp_dim3*inp_dim4;
+ inp_stride[2] = inp_dim3*inp_dim4;
+ inp_stride[3] = inp_dim4;
+ inp_stride[4] = 1;
+
+ if(last_n_same_dim)
+ {
+ int itr0, itr1, itr2, itr3;
+ WORD8 *p_inp0 = (WORD8*)p_inp;
+ for(itr0 = 0; itr0 < out_dim0; itr0++)
+ {
+ WORD8 *p_inp1 = p_inp0+(itr0*inp_stride[p_5D_permute_vec[0]]);
+#pragma loop_count min=1
+ for(itr1 = 0; itr1 < out_dim1; itr1++)
+ {
+ WORD8 *p_inp2 = p_inp1+(itr1*inp_stride[p_5D_permute_vec[1]]);
+#pragma loop_count min=1
+ for(itr2 = 0; itr2 < out_dim2; itr2++)
+ {
+ WORD8 *p_inp3 = p_inp2+(itr2*inp_stride[p_5D_permute_vec[2]]);
+#pragma loop_count min=1
+ for(itr3 = 0; itr3 < out_dim3; itr3++, p_out+=out_dim4)
+ {
+ WORD8 *p_inp4 = p_inp3+(itr3*inp_stride[p_5D_permute_vec[3]]);
+ memcpy(p_out, p_inp4, out_dim4);
+ }
+ }
+ }
+ }
+ }
+ else
+ {
+ int itr0, itr1, itr2, itr3, itr4;
+ WORD8 *p_inp0 = (WORD8*)p_inp;
+ for(itr0 = 0; itr0 < out_dim0; itr0++)
+ {
+ WORD8 *p_inp1 = p_inp0+(itr0*inp_stride[p_5D_permute_vec[0]]);
+ for(itr1 = 0; itr1 < out_dim1; itr1++)
+ {
+ WORD8 *p_inp2 = p_inp1+(itr1*inp_stride[p_5D_permute_vec[1]]);
+ for(itr2 = 0; itr2 < out_dim2; itr2++)
+ {
+ WORD8 *p_inp3 = p_inp2+(itr2*inp_stride[p_5D_permute_vec[2]]);
+ for(itr3 = 0; itr3 < out_dim3; itr3++)
+ {
+ WORD8 *p_inp4 = p_inp3+(itr3*inp_stride[p_5D_permute_vec[3]]);
+ for(itr4 = 0; itr4 < out_dim4; itr4++)
+ {
+ WORD8 d0 = *(p_inp4);
+ p_inp4 += inp_stride[p_5D_permute_vec[4]];
+ *p_out++ = d0;
+
+ }
+ }
+ }
+ }
+ }
+ }
+
+ return 0;
+}
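
The kernel above promotes 1-4D inputs to 5D and shifts the permute vector accordingly before walking the data. A small Python sketch of that promotion follows, using NumPy purely for verification; the real kernel additionally collapses trailing dimensions that keep their order so it can `memcpy` contiguous runs.

```python
# Minimal sketch (illustrative only) of the shape/permute promotion described
# in the comment at the top of xa_nn_transpose_8_8: tensors with fewer than 5
# dims are left-padded with 1s, and the original axes shift right accordingly.

import numpy as np

def promote_to_5d(shape, permute_vec):
    dims_added = 5 - len(shape)
    shape_5d = [1] * dims_added + list(shape)
    # Leading padded dims stay in place; original axes shift right by dims_added.
    permute_5d = list(range(dims_added)) + [p + dims_added for p in permute_vec]
    return shape_5d, permute_5d

inp = np.arange(6).reshape(2, 3)                      # 2x3 input
shape_5d, perm_5d = promote_to_5d(inp.shape, [1, 0])  # transpose the two dims
out = np.transpose(inp.reshape(shape_5d), perm_5d)
assert shape_5d == [1, 1, 1, 2, 3]
assert out.reshape(3, 2).tolist() == inp.T.tolist()
```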
diff --git a/backends/qualcomm/CMakeLists.txt b/backends/qualcomm/CMakeLists.txt
index 3c66796594..bc0f51a236 100644
--- a/backends/qualcomm/CMakeLists.txt
+++ b/backends/qualcomm/CMakeLists.txt
@@ -1,4 +1,5 @@
# Copyright (c) Qualcomm Innovation Center, Inc.
+# Copyright 2025 Arm Limited and/or its affiliates.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
@@ -199,11 +200,6 @@ target_link_libraries(
#
target_link_options_shared_lib(qnn_executorch_backend)
-#
-# add compile option
-#
-target_compile_options(executorch PUBLIC -DET_EVENT_TRACER_ENABLED)
-
#
# add sources
#
diff --git a/backends/qualcomm/_passes/__init__.py b/backends/qualcomm/_passes/__init__.py
new file mode 100644
index 0000000000..de4b7ce2cc
--- /dev/null
+++ b/backends/qualcomm/_passes/__init__.py
@@ -0,0 +1,34 @@
+from .annotate_and_quant_scalar import AnnotateAndQuantScalar
+from .annotate_decomposed import AnnotateDecomposed
+from .annotate_quant_attrs import AnnotateQuantAttrs
+from .convert_bmm_to_matmul import ConvertBmmToMatmul
+from .convert_interpolate_with_upsample2d import ConvertInterpolateWithUpsample2D
+from .convert_prelu import ConvertPReLU
+from .convert_to_linear import ConvertToLinear
+from .expand_broadcast_tensor_shape import ExpandBroadcastTensorShape
+from .fold_qdq import FoldQDQ
+from .i64_to_i32 import I64toI32
+from .layout_transform import LayoutTransform
+from .recompose_pixel_unshuffle import RecomposePixelUnshuffle
+from .recompose_rms_norm import RecomposeRmsNorm
+from .remove_redundancy import RemoveRedundancy
+from .replace_index_put_input import ReplaceIndexPutInput
+
+
+__all__ = [
+ AnnotateAndQuantScalar,
+ AnnotateDecomposed,
+ AnnotateQuantAttrs,
+ ConvertBmmToMatmul,
+ ConvertInterpolateWithUpsample2D,
+ ConvertPReLU,
+ ConvertToLinear,
+ ExpandBroadcastTensorShape,
+ FoldQDQ,
+ I64toI32,
+ LayoutTransform,
+ RecomposePixelUnshuffle,
+ RecomposeRmsNorm,
+ RemoveRedundancy,
+ ReplaceIndexPutInput,
+]
diff --git a/backends/qualcomm/_passes/annotate_and_quant_scalar.py b/backends/qualcomm/_passes/annotate_and_quant_scalar.py
index 1db50694ec..86475c39b1 100644
--- a/backends/qualcomm/_passes/annotate_and_quant_scalar.py
+++ b/backends/qualcomm/_passes/annotate_and_quant_scalar.py
@@ -53,7 +53,9 @@ def _get_source_scalar_node(self, node: torch.fx.Node) -> torch.fx.Node:
if node.op == "placeholder":
if not (shape := node.meta["val"].size()):
return node
- assert f"The output of node {node} is not a scalar, but a tensor with shape {shape}"
+ assert (
+ not shape
+ ), f"The output of node {node} is not a scalar, but a tensor with shape {shape}"
return self._get_source_scalar_node(node.args[0])
def _update_scalar_node_attrs(self, node: torch.fx.Node, quant_attrs: Dict) -> Dict:
diff --git a/backends/qualcomm/_passes/fuse_consecutive_transpose.py b/backends/qualcomm/_passes/fuse_consecutive_transpose.py
index c81818e00e..16ce380307 100644
--- a/backends/qualcomm/_passes/fuse_consecutive_transpose.py
+++ b/backends/qualcomm/_passes/fuse_consecutive_transpose.py
@@ -15,8 +15,18 @@
class FuseConsecutiveTranspose(ExportPass):
"""
- This pass fuses consecutive transpose / permute into one to reduce runtime
- overhead
+ This pass fuses consecutive transpose / permute nodes into a single permute, or
+ removes them entirely when they cancel out, to reduce runtime overhead.
+ To simplify the fuse logic, we first clone transposes so that each permute node's
+ output feeds at most one permute node.
+ Example:
+ Before clone transpose:
+ relu -> permute1 ─> permute2
+ |──────> permute3
+
+ After clone transpose:
+ relu ─> permute1 ──────> permute2
+ |───> permute4(new) ─> permute3
"""
def __init__(self):
@@ -27,6 +37,30 @@ def __init__(self):
self.visited = set()
self.nodes = []
+ def _clone_transpose(
+ self, graph_module: torch.fx.GraphModule
+ ) -> torch.fx.GraphModule:
+ graph = graph_module.graph
+ for n in graph_module.graph.nodes:
+ if n.target in self.op_map:
+ users = [user for user in list(n.users) if user.target in self.op_map]
+ if len(users) > 1:
+ for i in range(1, len(users)):
+ with graph.inserting_after(n):
+ clone_permute_node = graph.create_node(
+ "call_function",
+ exir_ops.edge.aten.permute_copy.default,
+ (n.args[0], n.args[1]),
+ )
+ clone_permute_node.meta = n.meta
+ users[i].replace_input_with(n, clone_permute_node)
+
+ def _is_dispensable(self, axis_order):
+ for index, value in enumerate(axis_order):
+ if index != value:
+ return False
+ return True
+
def _traverse(self, node):
if node in self.visited or node.target not in self.op_map:
return
@@ -34,47 +68,50 @@ def _traverse(self, node):
self.nodes.append(node)
self.visited.add(node)
next_users = [n for n in list(node.users) if n.target in self.op_map]
+
+ assert (
+ len(next_users) <= 1
+ ), "Each permute node should have at most 1 permute output node after _clone_transpose"
if not next_users:
return
-
- if len(next_users) == 1:
- self._traverse(list(node.users)[0])
else:
- raise NotImplementedError(
- f"Check the node {node}, wich encounter mutilple permute output case"
- )
+ self._traverse(list(node.users)[0])
def _fuse(self, graph_module: torch.fx.GraphModule) -> torch.fx.GraphModule:
graph = graph_module.graph
for n in graph_module.graph.nodes:
self._traverse(n)
if len(self.nodes) > 1:
- permute_order = []
input_node, output_node = self.nodes[0].args[0], self.nodes[-1]
input_shape = input_node.meta["val"].shape
axis_order = torch.arange(len(input_shape)).tolist()
for node in self.nodes:
- permute_order.append(node.args[1])
axis_order = [axis_order[i] for i in node.args[1]]
- with graph.inserting_after(input_node):
- permute_op = exir_ops.edge.aten.permute_copy.default
- permute_node = graph.create_node(
- "call_function", permute_op, (input_node, axis_order)
- )
- users = output_node.users.copy()
- for user in users:
- user.replace_input_with(output_node, permute_node)
-
- # copy metadata
- permute_node.meta = output_node.meta
- # Without "qnn_permute", we might obtain wrong input shape
- if [pn.meta.get(QCOM_INSERTED_PERMUTE) for pn in self.nodes]:
- permute_node.meta[QCOM_INSERTED_PERMUTE] = True
+ # If axis order is just [0,1,2,3], we ignore permute node
+ if self._is_dispensable(axis_order):
+ for user in output_node.users.copy():
+ user.replace_input_with(output_node, n.args[0])
+ else:
+ with graph.inserting_after(input_node):
+ permute_op = exir_ops.edge.aten.permute_copy.default
+ permute_node = graph.create_node(
+ "call_function", permute_op, (input_node, axis_order)
+ )
+ users = output_node.users.copy()
+ for user in users:
+ user.replace_input_with(output_node, permute_node)
+
+ # copy metadata
+ permute_node.meta = output_node.meta
+ # Without "qnn_permute", we might obtain wrong input shape
+ if [pn.meta.get(QCOM_INSERTED_PERMUTE) for pn in self.nodes]:
+ permute_node.meta[QCOM_INSERTED_PERMUTE] = True
# clear current stack
self.nodes = []
def call(self, graph_module: torch.fx.GraphModule):
+ self._clone_transpose(graph_module)
self._fuse(graph_module)
graph_module.recompile()
dead_code_elimination_pass(graph_module)
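
The updated pass composes the chained permute orders into one axis order and now drops the fused node when that order is the identity. A standalone sketch of that composition logic, with `fuse_permutes` and `is_dispensable` as illustrative stand-ins for the loop in `_fuse` and the new `_is_dispensable` helper:

```python
# Sketch (standalone illustration, not the pass itself) of folding a chain of
# permute orders into a single axis order and detecting when it is an identity
# permutation that can be removed entirely.

def fuse_permutes(rank, permute_orders):
    axis_order = list(range(rank))
    for order in permute_orders:
        axis_order = [axis_order[i] for i in order]
    return axis_order

def is_dispensable(axis_order):
    return all(i == v for i, v in enumerate(axis_order))

# permute(0,2,3,1) followed by permute(0,3,1,2) round-trips back to NCHW,
# so the fused node can be dropped.
fused = fuse_permutes(4, [[0, 2, 3, 1], [0, 3, 1, 2]])
assert fused == [0, 1, 2, 3] and is_dispensable(fused)

# A single non-trivial permute stays as one fused permute_copy node.
assert fuse_permutes(4, [[0, 2, 3, 1]]) == [0, 2, 3, 1]
```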
diff --git a/backends/qualcomm/_passes/i64_to_i32.py b/backends/qualcomm/_passes/i64_to_i32.py
index 1d2171cc37..29c747d1a1 100644
--- a/backends/qualcomm/_passes/i64_to_i32.py
+++ b/backends/qualcomm/_passes/i64_to_i32.py
@@ -3,6 +3,8 @@
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
+from typing import FrozenSet
+
import torch
from executorch.backends.qualcomm.builders.utils import get_parameter, is_constant
from executorch.exir.dialects._ops import ops as exir_ops
@@ -15,9 +17,14 @@ class I64toI32(ExportPass):
Cast unsupported int64 datatype into int32.
"""
- def __init__(self, edge_program: torch.export.ExportedProgram):
+ def __init__(
+ self,
+ edge_program: torch.export.ExportedProgram,
+ skip_node: FrozenSet[str] = frozenset(),
+ ):
super(I64toI32, self).__init__()
self.edge_program = edge_program
+ self.skip_node = skip_node
# pyre-ignore[4]
self.copy_op = exir_ops.edge.aten._to_copy.default
@@ -42,6 +49,8 @@ def _is_tensor_of_dtype(self, node_val, dtype: torch.dtype) -> bool:
def _cast_to_int32(self, graph_module: torch.fx.GraphModule):
for n in graph_module.graph.nodes:
+ if n.target in self.skip_node:
+ continue
if is_constant(n, self.edge_program):
param = get_parameter(n, self.edge_program)
if param.dtype == torch.int64:
diff --git a/backends/qualcomm/_passes/insert_requantize.py b/backends/qualcomm/_passes/insert_requantize.py
index 11aad02a0c..83b729f3c4 100644
--- a/backends/qualcomm/_passes/insert_requantize.py
+++ b/backends/qualcomm/_passes/insert_requantize.py
@@ -89,15 +89,9 @@ def _single_output_annotation(
requantize_dict = n.meta.pop(QCOM_REQUANTIZE)
# {quant_attr: user_node_name_list}
group_quant_attr_dict = self._invert_dict(requantize_dict)
- # TODO: If users of the node contain output node,
- # we replace the node with to_copy op. However, it would
- # be problem when the node has multiple to_copy ops
- add_output = len(group_quant_attr_dict) == 1
for hashable_quant_attr, user_nodes in group_quant_attr_dict.items():
user_nodes_copy = user_nodes.copy()
- if add_output:
- user_nodes_copy.append("output")
self._insert_to_copy(gm, n, dict(hashable_quant_attr), user_nodes_copy)
def _insert(self, graph_module: torch.fx.GraphModule) -> torch.fx.GraphModule:
diff --git a/backends/qualcomm/_passes/layout_transform.py b/backends/qualcomm/_passes/layout_transform.py
index 098910ed86..ccc34d3a52 100644
--- a/backends/qualcomm/_passes/layout_transform.py
+++ b/backends/qualcomm/_passes/layout_transform.py
@@ -30,6 +30,7 @@ class LayoutTransform(ExportPass):
"""
layout_sensitive_ops = {
+ exir_ops.edge.aten.adaptive_avg_pool2d.default,
exir_ops.edge.aten.avg_pool2d.default,
exir_ops.edge.aten.convolution.default,
exir_ops.edge.aten.max_pool2d_with_indices.default,
diff --git a/backends/qualcomm/_passes/utils.py b/backends/qualcomm/_passes/utils.py
index ac6525ae76..a606a21c62 100755
--- a/backends/qualcomm/_passes/utils.py
+++ b/backends/qualcomm/_passes/utils.py
@@ -43,3 +43,63 @@ def get_quant_attrs(
quant_attrs[QCOM_ENCODING] = quant_node.target
return quant_attrs
+
+
+def get_passes_dependency_for_capture_program():
+ """
+ This function records the dependencies for passes used in the capture_program.
+
+ It returns a dictionary where the keys are pass classes and the values are lists of
+ dependencies required by each pass. This helps in managing and organizing the sequence
+ of passes needed for the capture_program to function correctly.
+
+ Returns:
+ dict: A dictionary mapping each pass to its corresponding list of dependencies.
+ """
+ from executorch.backends.qualcomm._passes import (
+ AnnotateAndQuantScalar,
+ AnnotateDecomposed,
+ AnnotateQuantAttrs,
+ ConvertBmmToMatmul,
+ ConvertInterpolateWithUpsample2D,
+ ConvertPReLU,
+ ConvertToLinear,
+ ExpandBroadcastTensorShape,
+ FoldQDQ,
+ I64toI32,
+ LayoutTransform,
+ RecomposePixelUnshuffle,
+ RecomposeRmsNorm,
+ RemoveRedundancy,
+ ReplaceIndexPutInput,
+ )
+
+ return {
+ RecomposePixelUnshuffle: [RemoveRedundancy],
+ RecomposeRmsNorm: [RemoveRedundancy],
+ ConvertToLinear: [RecomposePixelUnshuffle],
+ ConvertPReLU: [RemoveRedundancy],
+ ConvertBmmToMatmul: [ConvertToLinear],
+ ConvertInterpolateWithUpsample2D: [RemoveRedundancy],
+ I64toI32: [RemoveRedundancy],
+ AnnotateQuantAttrs: [
+ RecomposePixelUnshuffle,
+ RecomposeRmsNorm,
+ ConvertToLinear,
+ ConvertPReLU,
+ ConvertBmmToMatmul,
+ ConvertInterpolateWithUpsample2D,
+ ],
+ AnnotateAndQuantScalar: [
+ AnnotateQuantAttrs,
+ ],
+ AnnotateDecomposed: [RemoveRedundancy],
+ FoldQDQ: [AnnotateQuantAttrs, AnnotateAndQuantScalar, AnnotateDecomposed],
+ ExpandBroadcastTensorShape: [RemoveRedundancy],
+ LayoutTransform: [
+ AnnotateQuantAttrs,
+ AnnotateAndQuantScalar,
+ ExpandBroadcastTensorShape,
+ ],
+ ReplaceIndexPutInput: [LayoutTransform],
+ }
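
The dependency table above is consumed by `_topological_sort_passes` in `backends/qualcomm/utils/utils.py` (changed later in this diff) via `torch.fx`'s `PassManager` and `this_before_that_pass_constraint`. A simplified, string-keyed sketch of the same idea, assuming a plain depth-first topological sort in place of the PassManager constraint solver:

```python
# Sketch (simplified, string-keyed stand-in for the real pass classes) of how
# the dependency table is used: each entry maps a pass to the passes that must
# run before it, so ordering reduces to a topological sort.

def topo_sort(passes, dep_table):
    ordered, seen = [], set()

    def visit(p):
        if p in seen:
            return
        seen.add(p)
        for dep in dep_table.get(p, []):
            if dep in passes:
                visit(dep)
        ordered.append(p)

    for p in passes:
        visit(p)
    return ordered

dep_table = {
    "ConvertToLinear": ["RecomposePixelUnshuffle"],
    "RecomposePixelUnshuffle": ["RemoveRedundancy"],
    "AnnotateQuantAttrs": ["ConvertToLinear"],
    "FoldQDQ": ["AnnotateQuantAttrs"],
}
passes = ["FoldQDQ", "ConvertToLinear", "RemoveRedundancy",
          "RecomposePixelUnshuffle", "AnnotateQuantAttrs"]
order = topo_sort(passes, dep_table)
assert order.index("RemoveRedundancy") < order.index("RecomposePixelUnshuffle")
assert order.index("AnnotateQuantAttrs") < order.index("FoldQDQ")
```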
diff --git a/backends/qualcomm/builders/__init__.py b/backends/qualcomm/builders/__init__.py
index 61ed30679e..7a4d6d764b 100644
--- a/backends/qualcomm/builders/__init__.py
+++ b/backends/qualcomm/builders/__init__.py
@@ -7,6 +7,7 @@
from . import (
node_visitor,
op_abs,
+ op_adaptive_avg_pool2d,
op_add,
op_arange,
op_avg_pool2d,
@@ -78,6 +79,7 @@
__all__ = [
node_visitor,
op_abs,
+ op_adaptive_avg_pool2d,
op_add,
op_arange,
op_avg_pool2d,
diff --git a/backends/qualcomm/builders/op_adaptive_avg_pool2d.py b/backends/qualcomm/builders/op_adaptive_avg_pool2d.py
new file mode 100644
index 0000000000..c944e1646e
--- /dev/null
+++ b/backends/qualcomm/builders/op_adaptive_avg_pool2d.py
@@ -0,0 +1,125 @@
+# Copyright (c) Qualcomm Innovation Center, Inc.
+# All rights reserved
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+import warnings
+from typing import Dict
+
+import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper
+import numpy as np
+
+import torch
+
+from .node_visitor import NodeVisitor, register_node_visitor
+from .qnn_constants import OpPoolAvg2d, QNN_OP_PACKAGE_NAME_QTI_AISW
+
+
+@register_node_visitor
+class AdaptiveAvgPool2D(NodeVisitor):
+ target = ["aten.adaptive_avg_pool2d.default"]
+
+ def __init__(self, *args) -> None:
+ super().__init__(*args)
+
+ def define_node(
+ self,
+ node: torch.fx.Node,
+ nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
+ ) -> PyQnnWrapper.PyQnnOpWrapper:
+
+ input_node = node.args[0]
+ input_tensor = self.get_tensor(input_node, node)
+ input_tensor_wrapper = self.define_tensor(
+ input_node,
+ node,
+ input_tensor,
+ PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
+ nodes_to_wrappers,
+ )
+
+ input_height = input_tensor.shape[1]
+ input_width = input_tensor.shape[2]
+
+ output_height = node.args[1][0]
+ output_width = node.args[1][1]
+
+ filter_height = input_height // output_height
+ filter_width = input_width // output_width
+ filter = [filter_height, filter_width]
+ filter_shape = [len(filter)]
+
+ stride_height = filter_height
+ stride_width = filter_width
+ stride = [stride_height, stride_width]
+ stride_shape = [len(stride)]
+
+ height = (output_height - 1) * stride_height + filter_height - input_height
+ width = (output_width - 1) * stride_width + filter_width - input_width
+ if height % 2 != 0 or width % 2 != 0:
+ warnings.warn(
+ "[QNN Delegate Op Builder]: Height or Width is not divisble by 2 with no remainder, fall back op",
+ stacklevel=1,
+ )
+ return
+
+ padding_height = height / 2
+ padding_width = width / 2
+ padding = [padding_height, padding_width]
+ padding_shape = [2, 2]
+
+ out_tensor = self.get_tensor(node, node)
+ output_tensor_wrapper = self.define_tensor(
+ node,
+ node,
+ out_tensor,
+ PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
+ nodes_to_wrappers,
+ )
+
+ adaptive_avg_pool2d_op = PyQnnWrapper.PyQnnOpWrapper(
+ node.name,
+ QNN_OP_PACKAGE_NAME_QTI_AISW,
+ OpPoolAvg2d.op_name,
+ )
+
+ adaptive_avg_pool2d_op.AddInputTensors([input_tensor_wrapper])
+ adaptive_avg_pool2d_op.AddOutputTensors([output_tensor_wrapper])
+
+ adaptive_avg_pool2d_op.AddTensorParam(
+ OpPoolAvg2d.param_filter_size,
+ PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
+ len(filter_shape),
+ filter_shape,
+ np.array(
+ filter,
+ dtype=np.uint32,
+ ),
+ True,
+ )
+
+ adaptive_avg_pool2d_op.AddTensorParam(
+ OpPoolAvg2d.param_stride,
+ PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
+ len(stride_shape),
+ stride_shape,
+ np.array(
+ stride,
+ dtype=np.uint32,
+ ),
+ True,
+ )
+
+ adaptive_avg_pool2d_op.AddTensorParam(
+ OpPoolAvg2d.param_pad_amount,
+ PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
+ len(padding_shape),
+ padding_shape,
+ np.array(
+ [[padding[0], padding[0]], [padding[1], padding[1]]],
+ dtype=np.uint32,
+ ),
+ True,
+ )
+
+ return adaptive_avg_pool2d_op
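
The builder lowers `adaptive_avg_pool2d` to a fixed-window `PoolAvg2d` by deriving filter, stride, and padding from the input and output spatial sizes, and falls back when the required padding is odd. A worked example of that arithmetic, with `avg_pool_params` as a hypothetical helper and the 7x7 spatial size taken from the new unit test:

```python
# Worked example (illustrative numbers only) of mapping adaptive_avg_pool2d to
# a fixed AvgPool2d: for a 7x7 input pooled to (1, 1), filter = stride = 7 and
# the computed padding is 0.

def avg_pool_params(input_hw, output_hw):
    (ih, iw), (oh, ow) = input_hw, output_hw
    fh, fw = ih // oh, iw // ow              # filter size
    sh, sw = fh, fw                          # stride equals filter size
    pad_h = (oh - 1) * sh + fh - ih          # total padding along H
    pad_w = (ow - 1) * sw + fw - iw          # total padding along W
    if pad_h % 2 or pad_w % 2:
        raise ValueError("odd padding: the builder falls back instead of lowering")
    return (fh, fw), (sh, sw), (pad_h // 2, pad_w // 2)

assert avg_pool_params((7, 7), (1, 1)) == ((7, 7), (7, 7), (0, 0))
assert avg_pool_params((8, 8), (2, 2)) == ((4, 4), (4, 4), (0, 0))
```

This mapping is exact when the input spatial size is an integer multiple of the requested output size; otherwise the warning path above rejects the node.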
diff --git a/backends/qualcomm/builders/op_layer_norm.py b/backends/qualcomm/builders/op_layer_norm.py
index 2006c71648..06f822014e 100644
--- a/backends/qualcomm/builders/op_layer_norm.py
+++ b/backends/qualcomm/builders/op_layer_norm.py
@@ -63,15 +63,19 @@ def define_node(
nodes_to_wrappers,
)
+ layer_norm_input_tensors = [input_tensor_wrapper, weight_tensor_wrapper]
+
bias_node = node.args[3]
- bias_tensor = get_parameter(bias_node, self.edge_program)
- bias_tensor_wrapper = self.define_tensor(
- bias_node,
- node,
- bias_tensor,
- PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC,
- nodes_to_wrappers,
- )
+ if bias_node is not None:
+ bias_tensor = get_parameter(bias_node, self.edge_program)
+ bias_tensor_wrapper = self.define_tensor(
+ bias_node,
+ node,
+ bias_tensor,
+ PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC,
+ nodes_to_wrappers,
+ )
+ layer_norm_input_tensors.append(bias_tensor_wrapper)
epsilon = node.args[4]
@@ -89,9 +93,7 @@ def define_node(
QNN_OP_PACKAGE_NAME_QTI_AISW,
OpLayerNorm.op_name,
)
- layer_norm_op.AddInputTensors(
- [input_tensor_wrapper, weight_tensor_wrapper, bias_tensor_wrapper]
- )
+ layer_norm_op.AddInputTensors(layer_norm_input_tensors)
layer_norm_op.AddOutputTensors([output_tensor_wrapper])
layer_norm_op.AddScalarParam(
OpLayerNorm.param_epsilon,
diff --git a/backends/qualcomm/builders/op_rms_norm.py b/backends/qualcomm/builders/op_rms_norm.py
index d1daa6c1e5..e5b4778312 100644
--- a/backends/qualcomm/builders/op_rms_norm.py
+++ b/backends/qualcomm/builders/op_rms_norm.py
@@ -66,7 +66,7 @@ def define_node(
nodes_to_wrappers,
)
- # Fake node, nn moudle seems to be inconsistant with document
+ # Fake node, nn module seems to be inconsistent with the documentation
bias_tensor = torch.zeros(weight_tensor.shape)
bias_node = torch.fx.Node(
node.graph,
diff --git a/backends/qualcomm/quantizer/annotators.py b/backends/qualcomm/quantizer/annotators.py
index e1792cb183..8bf2265fb5 100644
--- a/backends/qualcomm/quantizer/annotators.py
+++ b/backends/qualcomm/quantizer/annotators.py
@@ -512,6 +512,11 @@ def annotate_sqrt(node: Node, quantization_config: QuantizationConfig) -> None:
annotate_single_in_single_out(node, quantization_config)
+@register_annotator([torch.ops.aten.square.default])
+def annotate_square(node: Node, quantization_config: QuantizationConfig) -> None:
+ annotate_single_in_single_out(node, quantization_config)
+
+
@register_annotator([torch.ops.aten.gelu.default])
def annotate_gelu(node: Node, quantization_config: QuantizationConfig) -> None:
annotate_single_in_single_out(node, quantization_config)
diff --git a/backends/qualcomm/quantizer/custom_annotation.py b/backends/qualcomm/quantizer/custom_annotation.py
index d1c1757cc1..33237f3beb 100644
--- a/backends/qualcomm/quantizer/custom_annotation.py
+++ b/backends/qualcomm/quantizer/custom_annotation.py
@@ -14,17 +14,80 @@
QuantizationConfig,
)
from executorch.exir.dialects._ops import ops as exir_ops
-from torch.ao.quantization.observer import MinMaxObserver
+from torch.ao.quantization.observer import FixedQParamsObserver, MinMaxObserver
from torch.ao.quantization.quantizer import (
QuantizationAnnotation,
+ QuantizationSpec,
SharedQuantizationSpec,
)
from torch.fx import Node
-def annotate_matmul_16a8w( # noqa: C901
- gm: torch.fx.GraphModule, traverse_input1=True
-) -> None:
+def annotate_linear_16a8w_in_affine_layer(gm: torch.fx.GraphModule) -> None:
+ def annotate_conv2d(node: Node, quantization_config: QuantizationConfig) -> None:
+ input_qspec_map = {}
+ input_act = node.args[0]
+ input_spec = quantization_config.input_activation
+ input_qspec_map[input_act] = input_spec
+
+ weight = node.args[1]
+ input_qspec_map[weight] = quantization_config.weight
+
+ node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
+ input_qspec_map=input_qspec_map,
+ output_qspec=quantization_config.output_activation,
+ _annotated=True,
+ )
+
+ quantization_config_16a8w_per_channel = get_ptq_per_channel_quant_config(
+ torch.uint16, weight_dtype=torch.int8, act_observer=MinMaxObserver
+ )
+ for node in gm.graph.nodes:
+ if node.op == "call_function" and node.target == torch.ops.aten.conv2d.default:
+ if "nn_module_stack" in node.meta:
+ module_values_list = list(node.meta["nn_module_stack"].values())
+ full_qualified_name = module_values_list[-1][0]
+ if full_qualified_name == "output.conv":
+ annotate_conv2d(
+ node, quantization_config=quantization_config_16a8w_per_channel
+ )
+
+
+def annotate_prefill_kv_output(gm: torch.fx.GraphModule, kv_quant_attrs: dict):
+ for node in gm.graph.nodes:
+ if node.op == "output":
+ for index, prefill_output in enumerate(node.args[0]):
+ kv_quant_attr = kv_quant_attrs[index]
+ fixed_observer = FixedQParamsObserver.with_args(
+ scale=kv_quant_attr[0],
+ zero_point=kv_quant_attr[1],
+ quant_min=kv_quant_attr[2],
+ quant_max=kv_quant_attr[3],
+ dtype=kv_quant_attr[4],
+ qscheme=torch.torch.per_tensor_affine,
+ )
+
+ fixed_output_spec = QuantizationSpec(
+ quant_min=kv_quant_attr[2],
+ quant_max=kv_quant_attr[3],
+ dtype=kv_quant_attr[4],
+ ch_axis=0,
+ observer_or_fake_quant_ctr=fixed_observer,
+ )
+
+ input_qspec_map = {}
+ for input in prefill_output.args:
+ if isinstance(input, Node):
+ input_qspec_map[input] = fixed_output_spec
+
+ prefill_output.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
+ input_qspec_map=input_qspec_map,
+ output_qspec=fixed_output_spec,
+ _annotated=True,
+ )
+
+
+def annotate_matmul_16a8w(gm: torch.fx.GraphModule) -> None: # noqa: C901
"""
This function is specific for matmul op 16a8w.
For k, we will tag such as the below, and
@@ -142,8 +205,7 @@ def annotate_matmul_input1(node: Node):
for node in gm.graph.nodes:
if node.op == "call_function" and node.target == torch.ops.aten.matmul.default:
annotate_matmul(node, quantization_config_16a8w)
- if traverse_input1:
- annotate_matmul_input1(node.args[1])
+ annotate_matmul_input1(node.args[1])
def custom_annotate_llama_matmul_16a8w(gm: torch.fx.GraphModule) -> None: # noqa: C901
diff --git a/backends/qualcomm/runtime/QnnManager.cpp b/backends/qualcomm/runtime/QnnManager.cpp
index f2650301a3..83a94fdfdf 100644
--- a/backends/qualcomm/runtime/QnnManager.cpp
+++ b/backends/qualcomm/runtime/QnnManager.cpp
@@ -154,8 +154,9 @@ Error QnnManager::RegisterMem(
const std::shared_ptr& tensor_wrapper) {
SharedBuffer& shared_buffer_manager = SharedBuffer::GetSharedBufferManager();
// Not enable shared buffer
- if (!options_->shared_buffer())
+ if (!options_->shared_buffer()) {
return Error::Internal;
+ }
if (backend_params_ptr_->qnn_mem_manager_ptr_ == nullptr) {
QNN_EXECUTORCH_LOG_WARN(
diff --git a/backends/qualcomm/runtime/QnnManager.h b/backends/qualcomm/runtime/QnnManager.h
index 0157ee5837..17294afbd8 100644
--- a/backends/qualcomm/runtime/QnnManager.h
+++ b/backends/qualcomm/runtime/QnnManager.h
@@ -145,7 +145,7 @@ class QnnManager {
{Qnn_DataType_t::QNN_DATATYPE_UFIXED_POINT_8,
executorch::aten::ScalarType::Byte},
{Qnn_DataType_t::QNN_DATATYPE_UFIXED_POINT_16,
- executorch::aten::ScalarType::Bits16},
+ executorch::aten::ScalarType::UInt16},
};
};
} // namespace qnn
diff --git a/backends/qualcomm/runtime/backends/QnnMemManager.h b/backends/qualcomm/runtime/backends/QnnMemManager.h
index 664f717dc0..a0bdafab7b 100644
--- a/backends/qualcomm/runtime/backends/QnnMemManager.h
+++ b/backends/qualcomm/runtime/backends/QnnMemManager.h
@@ -77,7 +77,7 @@ class QnnMemManager {
Qnn_DataType_t::QNN_DATATYPE_SFIXED_POINT_16},
{executorch::aten::ScalarType::Byte,
Qnn_DataType_t::QNN_DATATYPE_UFIXED_POINT_8},
- {executorch::aten::ScalarType::Bits16,
+ {executorch::aten::ScalarType::UInt16,
Qnn_DataType_t::QNN_DATATYPE_UFIXED_POINT_16},
};
};
diff --git a/backends/qualcomm/scripts/build.sh b/backends/qualcomm/scripts/build.sh
index ed77a87351..506bb92752 100755
--- a/backends/qualcomm/scripts/build.sh
+++ b/backends/qualcomm/scripts/build.sh
@@ -87,6 +87,7 @@ if [ "$BUILD_AARCH64" = true ]; then
-DCMAKE_TOOLCHAIN_FILE=$ANDROID_NDK_ROOT/build/cmake/android.toolchain.cmake \
-DANDROID_ABI='arm64-v8a' \
-DANDROID_NATIVE_API_LEVEL=23 \
+ -DEXECUTORCH_BUILD_KERNELS_QUANTIZED=ON \
-DPYTHON_EXECUTABLE=$PYTHON_EXECUTABLE \
-B$BUILD_ROOT
@@ -101,6 +102,7 @@ if [ "$BUILD_AARCH64" = true ]; then
-DANDROID_ABI='arm64-v8a' \
-DANDROID_NATIVE_API_LEVEL=23 \
-DCMAKE_PREFIX_PATH=$CMAKE_PREFIX_PATH \
+ -DEXECUTORCH_BUILD_KERNELS_QUANTIZED=ON \
-DCMAKE_FIND_ROOT_PATH_MODE_PACKAGE=BOTH \
-DPYTHON_EXECUTABLE=$PYTHON_EXECUTABLE \
-B$EXAMPLE_ROOT
@@ -125,6 +127,7 @@ if [ "$BUILD_X86_64" = true ]; then
-DEXECUTORCH_BUILD_QNN=ON \
-DEXECUTORCH_BUILD_DEVTOOLS=ON \
-DEXECUTORCH_BUILD_EXTENSION_MODULE=ON \
+ -DEXECUTORCH_BUILD_KERNELS_QUANTIZED=ON \
-DEXECUTORCH_BUILD_EXTENSION_TENSOR=ON \
-DEXECUTORCH_ENABLE_EVENT_TRACER=ON \
-DPYTHON_EXECUTABLE=$PYTHON_EXECUTABLE \
diff --git a/backends/qualcomm/tests/models.py b/backends/qualcomm/tests/models.py
index d66aa34e5a..3ad183c2c2 100644
--- a/backends/qualcomm/tests/models.py
+++ b/backends/qualcomm/tests/models.py
@@ -16,6 +16,15 @@ def forward(self, x):
return torch.abs(x)
+class AdaptiveAvgPool2D(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+
+ def forward(self, x):
+ adaptive_avg_pool = torch.nn.AdaptiveAvgPool2d((1, 1))
+ return adaptive_avg_pool(x)
+
+
class Add(torch.nn.Module):
def __init__(self):
super().__init__()
@@ -685,15 +694,24 @@ def forward(self, x):
class LayerNorm(torch.nn.Module):
- def __init__(self):
+ def __init__(self, bias=True):
super().__init__()
- self.layer_norm = torch.nn.LayerNorm([768], eps=1e-6)
+ self.layer_norm = torch.nn.LayerNorm([768], eps=1e-6, bias=bias)
self.linear = torch.nn.Linear(768, 196)
def forward(self, x):
return self.linear(self.layer_norm(x))
+class LayerNormAdd(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.layer_norm = torch.nn.LayerNorm([512], eps=1e-6, bias=False)
+
+ def forward(self, x, y):
+ return self.layer_norm(x) + y
+
+
class LeakyReLUDefault(torch.nn.Module):
def __init__(self):
super().__init__()
diff --git a/backends/qualcomm/tests/test_qnn_delegate.py b/backends/qualcomm/tests/test_qnn_delegate.py
index 30ed34032f..498ee4ea68 100644
--- a/backends/qualcomm/tests/test_qnn_delegate.py
+++ b/backends/qualcomm/tests/test_qnn_delegate.py
@@ -37,8 +37,9 @@
skip_annotation,
update_spill_fill_size,
)
+from executorch.examples.models.llama.llama_transformer import MOEFeedForward
-from executorch.examples.models.llama.llama_transformer import ModelArgs, MOEFeedForward
+from executorch.examples.models.llama.model_args import ModelArgs
from executorch.examples.qualcomm.utils import setup_common_args_and_variables
@@ -97,6 +98,11 @@ def test_qnn_backend_abs(self):
sample_input = (torch.randn(1, 2, 3, 4),)
self.lower_module_and_test_output(module, sample_input)
+ def test_qnn_backend_adaptive_avg_pool2d(self):
+ module = AdaptiveAvgPool2D() # noqa: F405
+ sample_input = (torch.randn(1, 512, 7, 7),)
+ self.lower_module_and_test_output(module, sample_input)
+
def test_qnn_backend_arange(self):
modules = [
Arange(start=1, end=11, step=1, dtype=torch.int32), # noqa: F405
@@ -432,9 +438,11 @@ def test_qnn_backend_interpolate_nearest_2d(self):
self.lower_module_and_test_output(module, sample_input)
def test_qnn_backend_layer_norm(self):
- module = LayerNorm() # noqa: F405
+ modules = [LayerNorm(), LayerNorm(bias=False)] # noqa: F405
sample_input = (torch.randn(196, 768),)
- self.lower_module_and_test_output(module, sample_input)
+ for i, module in enumerate(modules):
+ with self.subTest(i=i):
+ self.lower_module_and_test_output(module, sample_input)
def test_qnn_backend_leaky_relu(self):
test_comb = [
@@ -915,6 +923,12 @@ def test_qnn_backend_abs(self):
module = self.get_qdq_module(module, sample_input)
self.lower_module_and_test_output(module, sample_input)
+ def test_qnn_backend_adaptive_avg_pool2d(self):
+ module = AdaptiveAvgPool2D() # noqa: F405
+ sample_input = (torch.randn(1, 512, 7, 7),)
+ module = self.get_qdq_module(module, sample_input)
+ self.lower_module_and_test_output(module, sample_input)
+
def test_qnn_backend_arange(self):
modules = [
Arange(start=1, end=6, step=0.5, dtype=torch.float32), # noqa: F405
@@ -1280,10 +1294,12 @@ def test_qnn_backend_interpolate_nearest_2d(self):
self.lower_module_and_test_output(module, sample_input)
def test_qnn_backend_layer_norm(self):
- module = LayerNorm() # noqa: F405
+ modules = [LayerNorm(), LayerNorm(bias=False)] # noqa: F405
sample_input = (torch.randn(196, 768),)
- module = self.get_qdq_module(module, sample_input)
- self.lower_module_and_test_output(module, sample_input)
+ for i, module in enumerate(modules):
+ with self.subTest(i=i):
+ module = self.get_qdq_module(module, sample_input)
+ self.lower_module_and_test_output(module, sample_input)
def test_qnn_backend_leaky_relu(self):
test_comb = [
@@ -2675,6 +2691,42 @@ def required_envs(self, conditions=None) -> bool:
]
)
+ def test_conv_former(self):
+ if not self.required_envs([self.image_dataset]):
+ self.skipTest("missing required envs")
+
+ cmds = [
+ "python",
+ f"{self.executorch_root}/examples/qualcomm/oss_scripts/conv_former.py",
+ "--dataset",
+ self.image_dataset,
+ "--artifact",
+ self.artifact_dir,
+ "--build_folder",
+ self.build_folder,
+ "--device",
+ self.device,
+ "--model",
+ self.model,
+ "--ip",
+ self.ip,
+ "--port",
+ str(self.port),
+ ]
+ if self.host:
+ cmds.extend(["--host", self.host])
+
+ p = subprocess.Popen(cmds, stdout=subprocess.DEVNULL)
+ with Listener((self.ip, self.port)) as listener:
+ conn = listener.accept()
+ p.communicate()
+ msg = json.loads(conn.recv())
+ if "Error" in msg:
+ self.fail(msg["Error"])
+ else:
+ self.assertGreaterEqual(msg["top_1"], 60)
+ self.assertGreaterEqual(msg["top_5"], 80)
+
def test_dino_v2(self):
if not self.required_envs([self.image_dataset]):
self.skipTest("missing required envs")
@@ -3529,7 +3581,7 @@ def test_stories_single_llama(self):
cmds = [
"python",
- f"{self.executorch_root}/examples/qualcomm/oss_scripts/llama2/llama.py",
+ f"{self.executorch_root}/examples/qualcomm/oss_scripts/llama/llama.py",
"--artifact",
self.artifact_dir,
"--build_folder",
@@ -3556,6 +3608,8 @@ def test_stories_single_llama(self):
"16a4w",
"--temperature",
"0",
+ "--llama_model",
+ "stories110m",
]
if self.host:
cmds.extend(["--host", self.host])
diff --git a/backends/qualcomm/utils/constants.py b/backends/qualcomm/utils/constants.py
index 1cc51690ff..4f73d331ad 100644
--- a/backends/qualcomm/utils/constants.py
+++ b/backends/qualcomm/utils/constants.py
@@ -26,8 +26,8 @@
QCOM_SCALE_OFFSET = "scale_offset"
QCOM_ZERO_POINT = "zero_point"
QCOM_ZERO_POINTS = "zero_points"
-QCOM_PASS_EXPAND_BROADCAST_SHAPE = "expand_broadcast_shape"
-QCOM_PASS_SKIP_ADVANCED_REQUANT = "skip_advanced_requant"
+QCOM_PASS_ACTIVATE_KEY = "activate"
+QCOM_PASS_ARGS_KWARGS_DEFAULTS_KEY = "args_kwargs_defaults"
# constants in backends/qualcomm/tests
QCOM_ANNOTATION = "annotation"
diff --git a/backends/qualcomm/utils/utils.py b/backends/qualcomm/utils/utils.py
index a4acae9585..1bcfa3a6f6 100644
--- a/backends/qualcomm/utils/utils.py
+++ b/backends/qualcomm/utils/utils.py
@@ -3,13 +3,13 @@
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
-
+import inspect
import operator
import re
import time
import warnings
from collections import OrderedDict
-from typing import Any, Callable, Dict, FrozenSet, List, Optional, Tuple
+from typing import Any, Callable, Dict, List, Optional, Tuple
import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManagerAdaptor
@@ -46,6 +46,9 @@
from executorch.backends.qualcomm._passes.replace_index_put_input import (
ReplaceIndexPutInput,
)
+from executorch.backends.qualcomm._passes.utils import (
+ get_passes_dependency_for_capture_program,
+)
from executorch.backends.qualcomm.builders.node_visitor import (
QNN_QUANT_TYPE_MAP,
@@ -74,8 +77,8 @@
option_to_flatbuffer,
)
from executorch.backends.qualcomm.utils.constants import (
- QCOM_PASS_EXPAND_BROADCAST_SHAPE,
- QCOM_PASS_SKIP_ADVANCED_REQUANT,
+ QCOM_PASS_ACTIVATE_KEY,
+ QCOM_PASS_ARGS_KWARGS_DEFAULTS_KEY,
QCOM_QNN_COMPILE_SPEC,
QCOM_QUANTIZED_IO,
)
@@ -89,10 +92,12 @@
from executorch.exir.backend.compile_spec_schema import CompileSpec
from executorch.exir.capture import ExecutorchBackendConfig
from executorch.exir.lowered_backend_module import LoweredBackendModule
+from executorch.exir.passes import PassManager
from executorch.exir.program._program import _get_updated_graph_signature
-from torch._decomp import core_aten_decompositions as torch_core_aten_decompositions
+from torch._decomp import core_aten_decompositions, remove_decompositions
from torch.export.exported_program import ExportedProgram
from torch.fx import passes
+from torch.fx.passes.infra.pass_manager import this_before_that_pass_constraint
from torch.fx.passes.operator_support import OperatorSupportBase
from torch.library import Library
@@ -283,9 +288,10 @@ def set_spec(module, options):
def get_decomp_table() -> Dict[torch._ops.OperatorBase, Callable]:
- source_decompositions = torch_core_aten_decompositions()
+ source_decompositions = core_aten_decompositions()
# The below super ops are supported by QNN
- remove_decompositions = [
+ skip_decompositions = [
+ torch.ops.aten.adaptive_avg_pool2d.default,
torch.ops.aten.pixel_shuffle.default,
torch.ops.aten.pixel_unshuffle.default,
torch.ops.aten.hardsigmoid.default,
@@ -293,39 +299,92 @@ def get_decomp_table() -> Dict[torch._ops.OperatorBase, Callable]:
torch.ops.aten._safe_softmax.default,
]
- for key in remove_decompositions:
- source_decompositions.pop(key)
+ remove_decompositions(source_decompositions, skip_decompositions)
return source_decompositions
+def get_capture_program_passes():
+ """
+ Defines and returns the default ordered passes for the capture program.
+ This function creates an OrderedDict containing a series of default passes.
+
+ Returns:
+ OrderedDict: An ordered dictionary containing all default passes along with their activation status and initialization parameters.
+ """
+
+ # The second value in each tuple in `default_passes_and_setting` indicates whether the corresponding pass is activated by default.
+ # If a pass is activated, it will be executed by default.
+ default_passes_and_setting = [
+ (RemoveRedundancy, True),
+ (RecomposePixelUnshuffle, True),
+ (RecomposeRmsNorm, True),
+ (ConvertToLinear, True),
+ (ConvertPReLU, True),
+ (ConvertBmmToMatmul, True),
+ (ConvertInterpolateWithUpsample2D, True),
+ (I64toI32, True),
+ (AnnotateQuantAttrs, True),
+ (AnnotateAndQuantScalar, True),
+ (AnnotateDecomposed, True),
+ (FoldQDQ, True),
+ (ExpandBroadcastTensorShape, False),
+ (LayoutTransform, True),
+ (ReplaceIndexPutInput, True),
+ ]
+
+ passes = OrderedDict()
+ for p, act in default_passes_and_setting:
+ init_signature = inspect.signature(p.__init__)
+
+ args_kwargs_defaults = {
+ k: v.default if v.default is not inspect.Parameter.empty else None
+ for k, v in init_signature.parameters.items()
+ if k != "self"
+ }
+
+ passes[p] = {
+ QCOM_PASS_ACTIVATE_KEY: act,
+ QCOM_PASS_ARGS_KWARGS_DEFAULTS_KEY: args_kwargs_defaults,
+ }
+
+ return passes
+
+
+def _topological_sort_passes(passes: OrderedDict):
+ dep_table = get_passes_dependency_for_capture_program()
+ pm = PassManager()
+ for p in passes:
+ pm.add_pass(p)
+
+ for that, these in dep_table.items():
+ for this in these:
+ pm.add_constraint(this_before_that_pass_constraint(this, that))
+
+ pm.solve_constraints()
+ sorted_passes = OrderedDict()
+ for p in pm.passes:
+ sorted_passes[p] = passes[p]
+ return sorted_passes
+
+
def _transform(
- edge_program: ExportedProgram, custom_pass_config: FrozenSet[str] = frozenset()
+ edge_program: ExportedProgram, passes_job: OrderedDict = None
) -> ExportedProgram:
# currently ExirExportedProgram.transform does not accept
# changes of input number which was caused by FoldQDQ
# apply passes one by one here to avoid IR capture failure
graph_module = edge_program.graph_module
- RemoveRedundancy()(graph_module)
- RecomposePixelUnshuffle()(graph_module)
- RecomposeRmsNorm()(graph_module)
- ConvertToLinear()(graph_module)
- ConvertPReLU(edge_program)(graph_module)
- ConvertBmmToMatmul()(graph_module)
- ConvertInterpolateWithUpsample2D()(graph_module)
- I64toI32(edge_program)(graph_module)
- AnnotateQuantAttrs(
- edge_program, QCOM_PASS_SKIP_ADVANCED_REQUANT in custom_pass_config
- )(graph_module)
- AnnotateAndQuantScalar(edge_program)(graph_module)
- AnnotateDecomposed(edge_program)(graph_module)
- FoldQDQ()(graph_module)
- # this pass is not necessary for network without layout-sensitive ops
- # enable defaultly will introduce overhead from extra view_copy nodes
- if QCOM_PASS_EXPAND_BROADCAST_SHAPE in custom_pass_config:
- ExpandBroadcastTensorShape()(graph_module)
- LayoutTransform(edge_program)(graph_module)
- ReplaceIndexPutInput(edge_program)(graph_module)
+ passes_job = passes_job if passes_job is not None else get_capture_program_passes()
+ passes_job = _topological_sort_passes(passes_job)
+ for p in passes_job:
+ if not passes_job[p][QCOM_PASS_ACTIVATE_KEY]:
+ continue
+
+ kwargs = passes_job[p][QCOM_PASS_ARGS_KWARGS_DEFAULTS_KEY]
+ if "edge_program" in kwargs:
+ kwargs["edge_program"] = edge_program
+ p(**kwargs)(graph_module)
# Since QDQ nodes are stripped, update graph signature again to validate program
edge_program._graph_signature = _get_updated_graph_signature(
@@ -339,7 +398,7 @@ def _transform(
def capture_program(
module: torch.nn.Module,
inputs: Tuple[torch.Tensor],
- custom_pass_config: FrozenSet[str] = frozenset(),
+ passes_job: OrderedDict = None,
) -> exir.ExirExportedProgram:
ep = torch.export.export(module, inputs, strict=True)
decomposed_ep = ep.run_decompositions(get_decomp_table())
@@ -350,7 +409,8 @@ def capture_program(
core_ep = ExirExportedProgram(decomposed_ep, False)
core_ep.transform(ConvertBinaryOpsWithScalar())
edge_ep = core_ep.to_edge(qnn_edge_config())
- _transform(edge_ep.exported_program, custom_pass_config)
+
+ _transform(edge_ep.exported_program, passes_job)
return edge_ep
@@ -906,28 +966,34 @@ def generate_multi_graph_program(
def generate_composite_llama_program(
+ llama_model: torch.nn.Module,
graph_names: List[str],
sample_inputs_list: List[Tuple[Any]],
lower_module_dict: Dict[str, List[LoweredBackendModule]],
call_delegate_node_name_dict: Dict[str, List[str]],
call_delegate_inputs_dict: Dict[str, List[Tuple[str, int | None]]],
outputs_dict: Dict[str, List[Tuple[str, int]]],
+ embedding_quantize: str,
backend_config: ExecutorchBackendConfig = None,
constant_methods: Optional[Dict[str, Any]] = None,
) -> ExecutorchProgramManager:
class CompositeLlamaModule(torch.nn.Module):
def __init__(
self,
+ llama_model,
lower_module_list,
call_delegate_node_name_list,
call_delegate_inputs_list,
outputs_list,
+ embedding_quantize,
) -> None:
super().__init__()
+ self.llama_model = llama_model
self.lower_module_list = lower_module_list
self.call_delegate_node_name_list = call_delegate_node_name_list
self.call_delegate_inputs_list = call_delegate_inputs_list
self.outputs_list = outputs_list
+ self.embedding_quantize = embedding_quantize
def reorder(
self,
@@ -960,6 +1026,13 @@ def forward(
}
for num, arg in enumerate(args):
module_input_dict[f"args_{num}"] = arg
+
+ if self.embedding_quantize:
+ hidden_states = self.llama_model.tok_embeddings(tokens)
+ module_input_dict["quantized_decomposed_embedding_4bit_dtype"] = (
+ hidden_states
+ )
+
for lower_module, call_delegate_node_name, call_delegate_inputs in zip(
self.lower_module_list,
self.call_delegate_node_name_list,
@@ -976,10 +1049,12 @@ def forward(
progs_dict = {}
for graph_name, sample_inputs in zip(graph_names, sample_inputs_list):
composite_llama_module = CompositeLlamaModule(
+ llama_model,
lower_module_dict[graph_name],
call_delegate_node_name_dict[graph_name],
call_delegate_inputs_dict[graph_name],
outputs_dict[graph_name],
+ embedding_quantize,
)
prog = torch.export.export(composite_llama_module, sample_inputs)
progs_dict[graph_name] = prog
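For illustration, a minimal sketch (not part of this patch) of how the `passes_job` mechanism added in this file could be used to toggle an optional pass. The toy module, inputs, and the import path for `ExpandBroadcastTensorShape` are assumptions inferred from the pass names above:

```python
import torch

from executorch.backends.qualcomm._passes.expand_broadcast_tensor_shape import (
    ExpandBroadcastTensorShape,  # assumed module path for the pass referenced above
)
from executorch.backends.qualcomm.utils.constants import QCOM_PASS_ACTIVATE_KEY
from executorch.backends.qualcomm.utils.utils import (
    capture_program,
    get_capture_program_passes,
)

# Hypothetical module and example inputs.
module = torch.nn.Linear(16, 16).eval()
inputs = (torch.randn(1, 16),)

# Start from the defaults, then opt in to the broadcast-expansion pass,
# which is disabled by default to avoid extra view_copy overhead.
passes_job = get_capture_program_passes()
passes_job[ExpandBroadcastTensorShape][QCOM_PASS_ACTIVATE_KEY] = True

edge_prog = capture_program(module, inputs, passes_job=passes_job)
```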
diff --git a/backends/vulkan/_passes/TARGETS b/backends/vulkan/_passes/TARGETS
index 74048cfb6a..4e60fc7bd7 100644
--- a/backends/vulkan/_passes/TARGETS
+++ b/backends/vulkan/_passes/TARGETS
@@ -30,6 +30,19 @@ runtime.python_library(
]
)
+runtime.python_library(
+ name = "remove_asserts",
+ srcs = ["remove_asserts.py"],
+ visibility = [
+ "//executorch/backends/...",
+ ],
+ deps = [
+ "//caffe2:torch",
+ "//executorch/exir:pass_base",
+ "//executorch/exir/dialects:lib",
+ ],
+)
+
runtime.python_library(
name = "remove_local_scalar_dense",
srcs = ["remove_local_scalar_dense_ops.py"],
@@ -83,6 +96,7 @@ runtime.python_library(
deps = [
":insert_prepack_nodes",
":int4_weight_only_quantizer",
+ ":remove_asserts",
":remove_local_scalar_dense",
":remove_redundant_ops",
":tag_memory_meta_pass"
diff --git a/backends/vulkan/_passes/__init__.py b/backends/vulkan/_passes/__init__.py
index 416339574b..8c29f5488f 100644
--- a/backends/vulkan/_passes/__init__.py
+++ b/backends/vulkan/_passes/__init__.py
@@ -2,6 +2,10 @@
from executorch.backends.vulkan._passes.int4_weight_only_quantizer import (
VkInt4WeightOnlyQuantizer,
)
+from executorch.backends.vulkan._passes.remove_asserts import (
+ remove_asserts,
+ RemoveAssertsTransform,
+)
from executorch.backends.vulkan._passes.remove_local_scalar_dense_ops import (
RemoveLocalScalarDenseOpsTransform,
)
@@ -13,6 +17,8 @@
__all__ = [
"insert_prepack_nodes",
"VkInt4WeightOnlyQuantizer",
+ "remove_asserts",
+ "RemoveAssertsTransform",
"RemoveLocalScalarDenseOpsTransform",
"RemoveRedundantOpsTransform",
"TagMemoryMetaPass",
diff --git a/backends/vulkan/_passes/insert_prepack_nodes.py b/backends/vulkan/_passes/insert_prepack_nodes.py
index 7876806d6d..bf1fc28ba5 100644
--- a/backends/vulkan/_passes/insert_prepack_nodes.py
+++ b/backends/vulkan/_passes/insert_prepack_nodes.py
@@ -60,6 +60,12 @@ def prepack_not_required(node: torch.fx.Node) -> bool:
)
# This pass assumes that the SpecPropPass() has already been applied
assert "spec" in node.meta
+ # Mutable buffers will not be marked as constant, but they might as well be
+ # for the purposes of memory planning. Mark the spec as constant so that
+ # the node is handled correctly by the memory planning pass.
+ if not node.meta["spec"].const:
+ assert is_param_node(program, node)
+ node.meta["spec"].const = True
# Validate that the original node is marked as a constant. Constant tensors
# do not participate in memory planning.
assert node.meta["spec"].const
@@ -68,7 +74,9 @@ def prepack_not_required(node: torch.fx.Node) -> bool:
# Set the mem_obj_id to -1 to indicate that this node requires a dedicated
# memory object.
prepack_node.meta["spec"].mem_obj_id = -1
- node.replace_all_uses_with(prepack_node, lambda x, y=prepack_node: x != y)
+ node.replace_all_uses_with(
+ prepack_node, lambda x, y=prepack_node: (x != y and x.op != "output")
+ )
program.graph.eliminate_dead_code()
return program
diff --git a/backends/vulkan/_passes/remove_asserts.py b/backends/vulkan/_passes/remove_asserts.py
new file mode 100644
index 0000000000..835f2ec141
--- /dev/null
+++ b/backends/vulkan/_passes/remove_asserts.py
@@ -0,0 +1,52 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-strict
+
+from typing import Set, Union
+
+import torch
+
+from executorch.exir.dialects.edge._ops import EdgeOpOverload
+from executorch.exir.pass_base import ExportPass, PassResult
+from executorch.exir.program._program import _get_updated_graph_signature
+
+from torch.export.exported_program import ExportedProgram
+
+OpType = Union[str, torch._ops.OpOverload, EdgeOpOverload]
+
+
+class RemoveAssertsTransform(ExportPass):
+ """
+ Remove operators which perform assertions. These are not possible to execute in
+ Vulkan since GLSL shaders cannot abort execution at runtime. Therefore, remove these
+ operators.
+ """
+
+ assert_ops: Set[OpType] = {
+ torch.ops.aten._assert_scalar.default,
+ torch.ops.aten.sym_constrain_range_for_size.default,
+ }
+
+ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
+ for node in graph_module.graph.nodes:
+ if node.target in self.assert_ops:
+ graph_module.graph.erase_node(node)
+
+ graph_module.graph.eliminate_dead_code()
+ graph_module.recompile()
+ return PassResult(graph_module, True)
+
+
+def remove_asserts(edge_program: ExportedProgram) -> ExportedProgram:
+ graph_module = edge_program.graph_module
+ RemoveAssertsTransform()(graph_module)
+
+ edge_program._graph_signature = _get_updated_graph_signature(
+ edge_program.graph_signature, graph_module
+ )
+ edge_program._validate()
+ return edge_program
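A minimal usage sketch (not part of this patch) for the new pass, assuming a toy model; `remove_asserts` strips the assertion ops and re-validates the exported program as defined above:

```python
import torch

from executorch.backends.vulkan._passes import remove_asserts
from executorch.exir import to_edge

# Hypothetical model and example inputs.
model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU()).eval()
example_inputs = (torch.randn(1, 8),)

edge = to_edge(torch.export.export(model, example_inputs))
# Remove assertion ops, which GLSL shaders cannot execute, before Vulkan lowering.
edge_program = remove_asserts(edge.exported_program())
```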
diff --git a/backends/vulkan/_passes/tag_memory_meta_pass.py b/backends/vulkan/_passes/tag_memory_meta_pass.py
index 1d08817e26..f2f54404ca 100644
--- a/backends/vulkan/_passes/tag_memory_meta_pass.py
+++ b/backends/vulkan/_passes/tag_memory_meta_pass.py
@@ -23,9 +23,6 @@
from executorch.exir.pass_base import ExportPass, PassResult
-from torch.fx.passes.tools_common import NodeList
-from torch.fx.passes.utils.fuser_utils import topo_sort
-
logger: logging.Logger = logging.getLogger("")
logger.setLevel(logging.INFO)
@@ -220,9 +217,7 @@ def should_delay_annotation(self, node: torch.fx.Node) -> bool:
# noqa
def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
- sorted_nodes: NodeList = topo_sort(list(graph_module.graph.nodes))
-
- for node in sorted_nodes:
+ for node in graph_module.graph.nodes:
if not self.should_annotate(node) or self.should_delay_annotation(node):
continue
diff --git a/backends/vulkan/op_registry.py b/backends/vulkan/op_registry.py
index d70cf93b88..25cf74dc8f 100644
--- a/backends/vulkan/op_registry.py
+++ b/backends/vulkan/op_registry.py
@@ -478,7 +478,7 @@ def register_convolution_op(features: OpFeatures):
@update_features("llama::sdpa_with_kv_cache")
-def register_sdpa_op(features: OpFeatures):
+def register_sdpa_with_kv_cache_op(features: OpFeatures):
features.texture_impl = TextureImplFeatures(
valid_packed_dims={PackedDim.WIDTH},
)
@@ -489,6 +489,16 @@ def register_sdpa_op(features: OpFeatures):
return features
+@update_features(["llama::update_cache", "llama::custom_sdpa"])
+def register_sdpa_ops(features: OpFeatures):
+ features.resize_fn = False
+ features.buffer_impl = False
+ features.texture_impl = TextureImplFeatures(
+ valid_packed_dims={PackedDim.WIDTH},
+ )
+ return features
+
+
@update_features(exir_ops.edge.et_vk.apply_rotary_emb.default)
def register_rotary_emb_op(features: OpFeatures):
features.texture_impl = TextureImplFeatures(
diff --git a/backends/vulkan/partitioner/vulkan_partitioner.py b/backends/vulkan/partitioner/vulkan_partitioner.py
index 3c31e0316a..6ff3fa8d70 100644
--- a/backends/vulkan/partitioner/vulkan_partitioner.py
+++ b/backends/vulkan/partitioner/vulkan_partitioner.py
@@ -250,11 +250,19 @@ def _is_node_supported(self, node: torch.fx.Node) -> bool:
self.log_skip(node, "local scalar dense of incompatible op node")
return False
+ features = None
if target not in vulkan_supported_ops:
- self.log_skip(node, "no operator implementation")
- return False
+ # For some ops, e.g. custom ops, the name is registered instead of the
+ # OpOverload object.
+ if not isinstance(target, str) and target.name() in vulkan_supported_ops:
+ features = vulkan_supported_ops[target.name()]
+ else:
+ self.log_skip(node, "no operator implementation")
+ return False
+ else:
+ features = vulkan_supported_ops[target]
- features = vulkan_supported_ops[target]
+ assert features is not None
if not features.check_node_fn(node):
self.log_skip(node, "op args not supported")
diff --git a/backends/vulkan/runtime/VulkanBackend.cpp b/backends/vulkan/runtime/VulkanBackend.cpp
index 8bff63d0e8..3d249aab4a 100644
--- a/backends/vulkan/runtime/VulkanBackend.cpp
+++ b/backends/vulkan/runtime/VulkanBackend.cpp
@@ -417,10 +417,10 @@ bool maybe_update_scalar_tensor(
executorch::aten::Tensor& scalar_tensor_src) {
const int32_t cur_val = graph->read_symint(ref);
int32_t scalar_tensor_val = 0;
- exec_aten::ScalarType dtype = scalar_tensor_src.scalar_type();
- if (dtype == exec_aten::ScalarType::Int) {
+ executorch::aten::ScalarType dtype = scalar_tensor_src.scalar_type();
+ if (dtype == executorch::aten::ScalarType::Int) {
scalar_tensor_val = *scalar_tensor_src.const_data_ptr<int32_t>();
- } else if (dtype == exec_aten::ScalarType::Long) {
+ } else if (dtype == executorch::aten::ScalarType::Long) {
scalar_tensor_val = int32_t(*scalar_tensor_src.const_data_ptr<int64_t>());
}
bool was_updated = false;
diff --git a/backends/vulkan/runtime/graph/ops/glsl/activations.h b/backends/vulkan/runtime/graph/ops/glsl/activations.h
index 94c9e1274d..2ba0ccc467 100644
--- a/backends/vulkan/runtime/graph/ops/glsl/activations.h
+++ b/backends/vulkan/runtime/graph/ops/glsl/activations.h
@@ -42,3 +42,15 @@ vec4 hardsigmoid(vec4 tex) {
hardsigmoid(tex.z),
hardsigmoid(tex.w));
}
+
+float leaky_relu(float x, float negative_slope) {
+ return x * (float(x > 0.0) + negative_slope * float(x <= 0.0));
+}
+
+vec4 leaky_relu(vec4 tex, float negative_slope) {
+ return vec4(
+ leaky_relu(tex.x, negative_slope),
+ leaky_relu(tex.y, negative_slope),
+ leaky_relu(tex.z, negative_slope),
+ leaky_relu(tex.w, negative_slope));
+}
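As a sanity check (not part of the patch), the branchless formulation used by the new GLSL helper can be compared against PyTorch's reference leaky ReLU:

```python
import torch
import torch.nn.functional as F

def leaky_relu_branchless(x: torch.Tensor, negative_slope: float) -> torch.Tensor:
    # Mirrors the GLSL helper: x * (float(x > 0) + negative_slope * float(x <= 0))
    return x * ((x > 0).float() + negative_slope * (x <= 0).float())

x = torch.randn(1024)
assert torch.allclose(leaky_relu_branchless(x, 0.01), F.leaky_relu(x, 0.01))
```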
diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_output_tile.glsl b/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_output_tile.glsl
index c05c7e4450..3265a97398 100644
--- a/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_output_tile.glsl
+++ b/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_output_tile.glsl
@@ -14,12 +14,12 @@
#define TILE_SIZE ${TILE_SIZE}
-#define STRIDE_EQ_DILATION ${STRIDE_EQ_DILATION}
-
#define BATCH_SIZE_X ${BATCH_SIZE_X}
#define BATCH_SIZE_Y ${BATCH_SIZE_Y}
+#define LOCAL_WG_SIZE 64
+
#define op(X, A, B) ${OPERATOR}
#include "indexing_utils.h"
@@ -30,14 +30,28 @@ ${layout_declare_tensor(0, "w", "t_out", DTYPE, "texture3d")}
${layout_declare_tensor(1, "r", "t_in", DTYPE, "texture3d")}
${layout_declare_tensor(2, "r", "t_kernel", DTYPE, "texture2d")}
${layout_declare_tensor(3, "r", "t_bias", DTYPE, "texture2d")}
-${layout_declare_ubo(4, "ivec3", "out_limits")}
-${layout_declare_ubo(5, "ivec4", "in_sizes")}
-${layout_declare_ubo(6, "ivec2", "kernel_size", "ivec2", "stride", "ivec2", "padding", "ivec2", "dilation")}
-${layout_declare_ubo(7, "ivec2", "overlay_region", "int", "in_group_size")}
-${layout_declare_ubo(8, "float", "out_min", "float", "out_max")}
+
+layout(push_constant) uniform restrict Block {
+ ivec4 out_limits;
+ ivec4 in_sizes;
+ ivec2 kernel_size;
+ ivec2 stride;
+ ivec2 padding;
+ ivec2 dilation;
+ ivec2 overlay_region;
+ int in_group_size;
+ int dummy_padding;
+ float out_min;
+ float out_max;
+};
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
+// For performance improvement, reduce register usage by caching positions in shared memory.
+// Offset index by 1 every 16 points to avoid bank access conflict.
+#define offset_pos_index(index) (index + ((index) >> 4))
+shared ivec3 pos_shared[offset_pos_index(LOCAL_WG_SIZE)];
+
/*
* Computes a depthwise convolution. Each shader invocation calculates the
* output at a single output location.
@@ -63,6 +77,8 @@ void main() {
return;
}
+ pos_shared[offset_pos_index(gl_LocalInvocationIndex)] = pos;
+
// Compute the index of the top-left element of the overlay region. Negative
// indices indicate that the top-left element is in a region added by padding.
const ivec2 ipos = pos.xy * stride - padding;
@@ -109,18 +125,19 @@ void main() {
for (int j = 0; j < TILE_SIZE; j++, kx++) {
prev_kernel_line[j] = texelFetch(t_kernel, ivec2(kx, pos.z), 0);
for (int s = 0; s < BATCH_SIZE_X; s++) {
- sum[0][s] = fma(in_texels[j + s], prev_kernel_line[j], sum[0][s]);
+ sum[0][s] = fma(in_texels[j + s], prev_kernel_line[j], sum[0][s]);
}
}
}
}
+ const ivec3 out_pos = pos_shared[offset_pos_index(gl_LocalInvocationIndex)];
for (int y = 0; y < BATCH_SIZE_Y; y++) {
for (int x = 0; x < BATCH_SIZE_X; x++) {
- if (any(greaterThanEqual(ivec3(pos.x + x, pos.y + y, pos.z), out_limits))) {
+ if (any(greaterThanEqual(ivec3(out_pos.x + x, out_pos.y + y, out_pos.z), out_limits.xyz))) {
continue;
}
- imageStore(t_out, ivec3(pos.x + x, pos.y + y, pos.z), op(sum[y][x], out_min, out_max));
+ imageStore(t_out, ivec3(out_pos.x + x, out_pos.y + y, out_pos.z), op(sum[y][x], out_min, out_max));
}
}
}
diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_output_tile.yaml b/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_output_tile.yaml
index d3672f5ec2..9cf6c22c6c 100644
--- a/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_output_tile.yaml
+++ b/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_output_tile.yaml
@@ -12,7 +12,6 @@ conv2d_dw_output_tile:
TILE_SIZE: 3
BATCH_SIZE_X: 4
BATCH_SIZE_Y: 2
- STRIDE_EQ_DILATION: 0
generate_variant_forall:
DTYPE:
- VALUE: half
@@ -26,15 +25,3 @@ conv2d_dw_output_tile:
- NAME: conv2d_dw_output_tile_5x5_clamp
OPERATOR: clamp(X, A, B)
TILE_SIZE: 5
- - NAME: conv2d_dw_sed_output_tile_3x3
- STRIDE_EQ_DILATION: 1
- - NAME: conv2d_dw_sed_output_tile_3x3_clamp
- OPERATOR: clamp(X, A, B)
- STRIDE_EQ_DILATION: 1
- - NAME: conv2d_dw_sed_output_tile_5x5
- TILE_SIZE: 5
- STRIDE_EQ_DILATION: 1
- - NAME: conv2d_dw_sed_output_tile_5x5_clamp
- OPERATOR: clamp(X, A, B)
- TILE_SIZE: 5
- STRIDE_EQ_DILATION: 1
diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_sned_output_tile.glsl b/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_sned_output_tile.glsl
index bb70ee1aab..ceadc35779 100644
--- a/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_sned_output_tile.glsl
+++ b/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_sned_output_tile.glsl
@@ -24,11 +24,20 @@ ${layout_declare_tensor(0, "w", "t_out", DTYPE, "texture3d")}
${layout_declare_tensor(1, "r", "t_in", DTYPE, "texture3d")}
${layout_declare_tensor(2, "r", "t_kernel", DTYPE, "texture2d")}
${layout_declare_tensor(3, "r", "t_bias", DTYPE, "texture2d")}
-${layout_declare_ubo(4, "ivec3", "out_limits")}
-${layout_declare_ubo(5, "ivec4", "in_sizes")}
-${layout_declare_ubo(6, "ivec2", "kernel_size", "ivec2", "stride", "ivec2", "padding", "ivec2", "dilation")}
-${layout_declare_ubo(7, "ivec2", "overlay_region", "int", "in_group_size")}
-${layout_declare_ubo(8, "float", "out_min", "float", "out_max")}
+
+layout(push_constant) uniform restrict Block {
+ ivec4 out_limits;
+ ivec4 in_sizes;
+ ivec2 kernel_size;
+ ivec2 stride;
+ ivec2 padding;
+ ivec2 dilation;
+ ivec2 overlay_region;
+ int in_group_size;
+ int dummy_padding;
+ float out_min;
+ float out_max;
+};
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
diff --git a/backends/vulkan/runtime/graph/ops/glsl/unary_op.yaml b/backends/vulkan/runtime/graph/ops/glsl/unary_op.yaml
index 77a334a05e..6757d2a6d4 100644
--- a/backends/vulkan/runtime/graph/ops/glsl/unary_op.yaml
+++ b/backends/vulkan/runtime/graph/ops/glsl/unary_op.yaml
@@ -42,3 +42,5 @@ unary_op:
OPERATOR: hardswish(X)
- NAME: hardsigmoid
OPERATOR: hardsigmoid(X)
+ - NAME: leaky_relu
+ OPERATOR: leaky_relu(X, A)
diff --git a/backends/vulkan/runtime/graph/ops/impl/Convolution.cpp b/backends/vulkan/runtime/graph/ops/impl/Convolution.cpp
index 3c367f334d..71b7ce80cc 100644
--- a/backends/vulkan/runtime/graph/ops/impl/Convolution.cpp
+++ b/backends/vulkan/runtime/graph/ops/impl/Convolution.cpp
@@ -407,7 +407,9 @@ void add_conv2d_node(
wg_size = {wg_size[0] * wg_size[1] * wg_size[2], 1, 1};
}
- if (method == Conv2dMethod::Pointwise) {
+ vkapi::ParamsBindList param_buffers;
+ std::vector<PushConstantDataInfo> push_constants;
+ if (method == Conv2dMethod::Pointwise || method == Conv2dMethod::Depthwise) {
const utils::ivec4 kernel_param_size_stride = {
kernel_params.kernel_size[0],
kernel_params.kernel_size[1],
@@ -420,55 +422,43 @@ void add_conv2d_node(
kernel_params.dilation[0],
kernel_params.dilation[1]};
- graph.execute_nodes().emplace_back(new DispatchNode(
- graph,
- shader,
- wg_size,
- graph.create_local_wg_size(wg_size),
- // Inputs and Outputs
- {{out, vkapi::MemoryAccessType::WRITE},
- {{in, arg_weight, arg_bias}, vkapi::MemoryAccessType::READ}},
- // Shader params buffers
- {},
- // Specialization Constants
- {},
- // Resizing Logic
- resize_conv2d_node,
- {weight_data, stride, padding, dilation, transposed, output_padding},
- {
- graph.logical_limits_pc_of(out),
- graph.sizes_pc_of(in),
- PushConstantDataInfo(
- &kernel_param_size_stride, sizeof(kernel_param_size_stride)),
- PushConstantDataInfo(
- &kernel_param_pad_dial, sizeof(kernel_param_pad_dial)),
- PushConstantDataInfo(
- &extra_params, sizeof(extra_params), sizeof(utils::ivec4)),
- PushConstantDataInfo(&out_params, sizeof(out_params)),
- }));
+ push_constants = {
+ graph.logical_limits_pc_of(out),
+ graph.sizes_pc_of(in),
+ PushConstantDataInfo(
+ &kernel_param_size_stride, sizeof(kernel_param_size_stride)),
+ PushConstantDataInfo(
+ &kernel_param_pad_dial, sizeof(kernel_param_pad_dial)),
+ PushConstantDataInfo(
+ &extra_params, sizeof(extra_params), sizeof(utils::ivec4)),
+ PushConstantDataInfo(&out_params, sizeof(out_params)),
+ };
} else {
- graph.execute_nodes().emplace_back(new DispatchNode(
- graph,
- shader,
- wg_size,
- graph.create_local_wg_size(wg_size),
- // Inputs and Outputs
- {{out, vkapi::MemoryAccessType::WRITE},
- {{in, arg_weight, arg_bias}, vkapi::MemoryAccessType::READ}},
- // Shader params buffers
- {
- t_out->logical_limits_ubo(),
- t_in->sizes_ubo(),
- graph.create_params_buffer(kernel_params),
- graph.create_params_buffer(extra_params),
- graph.create_params_buffer(out_params),
- },
- // Specialization Constants
- {},
- // Resizing Logic
- resize_conv2d_node,
- {weight_data, stride, padding, dilation, transposed, output_padding}));
+ param_buffers = {
+ t_out->logical_limits_ubo(),
+ t_in->sizes_ubo(),
+ graph.create_params_buffer(kernel_params),
+ graph.create_params_buffer(extra_params),
+ graph.create_params_buffer(out_params),
+ };
}
+
+ graph.execute_nodes().emplace_back(new DispatchNode(
+ graph,
+ shader,
+ wg_size,
+ graph.create_local_wg_size(wg_size),
+ // Inputs and Outputs
+ {{out, vkapi::MemoryAccessType::WRITE},
+ {{in, arg_weight, arg_bias}, vkapi::MemoryAccessType::READ}},
+ // Shader params buffers
+ param_buffers,
+ // Specialization Constants
+ {},
+ // Resizing Logic
+ resize_conv2d_node,
+ {weight_data, stride, padding, dilation, transposed, output_padding},
+ push_constants));
}
void add_conv1d_node(
diff --git a/backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.cpp b/backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.cpp
index a78ac0519c..1042c23bcb 100644
--- a/backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.cpp
+++ b/backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.cpp
@@ -73,13 +73,18 @@ void add_q_8w_linear_node(
auto viewFn = VK_GET_OP_FN("aten.view_copy.default");
ValueRef mat1_W_packed = mat1;
ValueRef out_W_packed = out;
+ // Create temporary tensors to store the width packed versions of mat1 and out
+ TmpTensor mat1_tmp(
+ &graph, graph.sizes_of(mat1), graph.dtype_of(mat1), utils::kWidthPacked);
+ TmpTensor out_tmp(
+ &graph, graph.sizes_of(out), graph.dtype_of(out), utils::kWidthPacked);
if (!graph.is_buffer_storage(out) &&
graph.packed_dim_of(mat1) != WHCN::kWidthDim) {
// Ensure mat1 is width packed
- mat1_W_packed = graph.add_tensor_like(mat1, utils::kWidthPacked);
+ mat1_W_packed = mat1_tmp;
viewFn(graph, {mat1, graph.add_none(), mat1_W_packed});
// Ensure out is packed correctly
- out_W_packed = graph.add_tensor_like(out, utils::kWidthPacked);
+ out_W_packed = out_tmp;
}
ValueRef q_mat2 = prepack_standard(
graph, q_mat2_data, graph.storage_type_of(out), utils::kWidthPacked);
diff --git a/backends/vulkan/runtime/graph/ops/impl/SDPA.cpp b/backends/vulkan/runtime/graph/ops/impl/SDPA.cpp
index 2c46201351..6dcf2fc4f4 100644
--- a/backends/vulkan/runtime/graph/ops/impl/SDPA.cpp
+++ b/backends/vulkan/runtime/graph/ops/impl/SDPA.cpp
@@ -176,17 +176,32 @@ void resize_sdpa_out(
graph->get_tensor(out)->virtual_resize(graph->sizes_of(q_projected));
}
-void sdpa_with_kv_cache_impl(
- ComputeGraph& graph,
- const std::vector<ValueRef>& args) {
+void update_cache_impl(ComputeGraph& graph, const std::vector<ValueRef>& args) {
+ int arg_idx = 0;
+ const ValueRef value = args[arg_idx++];
+ const ValueRef cache = args[arg_idx++];
+ const ValueRef input_pos_symint = args[arg_idx++];
+ const ValueRef out = args[arg_idx++];
+
+ // Unused variables
+ (void)out;
+
+ VK_CHECK_COND(graph.size_at<int32_t>(-4, value) == 1);
+ VK_CHECK_COND(graph.size_at<int32_t>(-4, cache) == 1);
+ VK_CHECK_COND(
+ graph.size_at<int32_t>(-1, value) == graph.size_at<int32_t>(-1, cache));
+ VK_CHECK_COND(
+ graph.size_at<int32_t>(-2, value) == graph.size_at<int32_t>(-2, cache));
+
+ add_kv_cache_update_node(graph, input_pos_symint, value, cache);
+}
+
+void sdpa_impl(ComputeGraph& graph, const std::vector<ValueRef>& args) {
int arg_idx = 0;
const ValueRef q_projected = args[arg_idx++];
- const ValueRef k_projected = args[arg_idx++];
- const ValueRef v_projected = args[arg_idx++];
- const ValueRef k_cache_data = args[arg_idx++];
- const ValueRef v_cache_data = args[arg_idx++];
+ const ValueRef k_cache = args[arg_idx++];
+ const ValueRef v_cache = args[arg_idx++];
const ValueRef input_pos_symint = args[arg_idx++];
- const ValueRef sequence_len = args[arg_idx++];
const ValueRef attn_mask = args[arg_idx++];
const ValueRef dropout_p = args[arg_idx++];
const ValueRef is_causal = args[arg_idx++];
@@ -195,23 +210,20 @@ void sdpa_with_kv_cache_impl(
// Output tensors
const ValueRef out = args[arg_idx++];
- // Unused variables
- (void)sequence_len;
-
// Batches must be 1
VK_CHECK_COND(graph.size_at<int32_t>(-4, q_projected) == 1);
- VK_CHECK_COND(graph.size_at<int32_t>(-4, k_projected) == 1);
- VK_CHECK_COND(graph.size_at<int32_t>(-4, v_projected) == 1);
+ VK_CHECK_COND(graph.size_at<int32_t>(-4, k_cache) == 1);
+ VK_CHECK_COND(graph.size_at<int32_t>(-4, v_cache) == 1);
// k and v projected must have the same shape
- VK_CHECK_COND(graph.sizes_of(k_projected) == graph.sizes_of(v_projected));
+ VK_CHECK_COND(graph.sizes_of(k_cache) == graph.sizes_of(v_cache));
// head dim must match between tensors
VK_CHECK_COND(
graph.size_at<int32_t>(-1, q_projected) ==
- graph.size_at<int32_t>(-1, k_projected));
+ graph.size_at<int32_t>(-1, k_cache));
// All tensors must have the packed dim be the width (head) dimension
VK_CHECK_COND(graph.packed_dim_of(q_projected) == WHCN::kWidthDim);
- VK_CHECK_COND(graph.packed_dim_of(k_projected) == WHCN::kWidthDim);
- VK_CHECK_COND(graph.packed_dim_of(v_projected) == WHCN::kWidthDim);
+ VK_CHECK_COND(graph.packed_dim_of(k_cache) == WHCN::kWidthDim);
+ VK_CHECK_COND(graph.packed_dim_of(v_cache) == WHCN::kWidthDim);
// Some variables are not supported yet
VK_CHECK_COND(
graph.val_is_none(dropout_p) ||
@@ -222,16 +234,8 @@ void sdpa_with_kv_cache_impl(
graph.val_is_none(is_causal) || graph.extract_scalar(is_causal));
VK_CHECK_COND(graph.val_is_none(attn_mask));
- const ValueRef k_cache =
- prepack_standard_like(graph, k_cache_data, q_projected);
- const ValueRef v_cache =
- prepack_standard_like(graph, v_cache_data, q_projected);
-
const int32_t max_seq_len = graph.size_at<int32_t>(1, k_cache);
- add_kv_cache_update_node(graph, input_pos_symint, k_projected, k_cache);
- add_kv_cache_update_node(graph, input_pos_symint, v_projected, v_cache);
-
// Slice caches from 0 to input_pos + sequence_len
const ValueRef k_cache_sliced = graph.add_tensor_view(k_cache);
const ValueRef v_cache_sliced = graph.add_tensor_view(v_cache);
@@ -257,7 +261,7 @@ void sdpa_with_kv_cache_impl(
// Repeat interleave
const int64_t num_heads = graph.size_at<int64_t>(2, q_projected);
- const int64_t num_kv_heads = graph.size_at<int64_t>(2, k_projected);
+ const int64_t num_kv_heads = graph.size_at<int64_t>(2, k_cache);
const ValueRef num_repeats =
graph.add_scalar<int64_t>(num_heads / num_kv_heads);
@@ -331,8 +335,52 @@ void sdpa_with_kv_cache_impl(
new ExecuteNode(resize_sdpa_out, {q_projected, out}));
}
+void sdpa_with_kv_cache_impl(
+ ComputeGraph& graph,
+ const std::vector<ValueRef>& args) {
+ int arg_idx = 0;
+ const ValueRef q_projected = args[arg_idx++];
+ const ValueRef k_projected = args[arg_idx++];
+ const ValueRef v_projected = args[arg_idx++];
+ const ValueRef k_cache_data = args[arg_idx++];
+ const ValueRef v_cache_data = args[arg_idx++];
+ const ValueRef input_pos_symint = args[arg_idx++];
+ const ValueRef sequence_len = args[arg_idx++];
+ const ValueRef attn_mask = args[arg_idx++];
+ const ValueRef dropout_p = args[arg_idx++];
+ const ValueRef is_causal = args[arg_idx++];
+ const ValueRef scale = args[arg_idx++];
+
+ // Output tensors
+ const ValueRef out = args[arg_idx++];
+
+ (void)sequence_len;
+
+ const ValueRef k_cache =
+ prepack_standard_like(graph, k_cache_data, q_projected);
+ const ValueRef v_cache =
+ prepack_standard_like(graph, v_cache_data, q_projected);
+
+ update_cache_impl(graph, {k_projected, k_cache, input_pos_symint, -1});
+ update_cache_impl(graph, {v_projected, v_cache, input_pos_symint, -1});
+
+ sdpa_impl(
+ graph,
+ {q_projected,
+ k_cache,
+ v_cache,
+ input_pos_symint,
+ attn_mask,
+ dropout_p,
+ is_causal,
+ scale,
+ out});
+}
+
REGISTER_OPERATORS {
VK_REGISTER_OP(sdpa_with_kv_cache.default, sdpa_with_kv_cache_impl);
+ VK_REGISTER_OP(update_cache.default, update_cache_impl);
+ VK_REGISTER_OP(llama.custom_sdpa.default, sdpa_impl);
}
} // namespace vkcompute
diff --git a/backends/vulkan/runtime/graph/ops/impl/UnaryOp.cpp b/backends/vulkan/runtime/graph/ops/impl/UnaryOp.cpp
index 62922e8d9e..4bf73fad5a 100644
--- a/backends/vulkan/runtime/graph/ops/impl/UnaryOp.cpp
+++ b/backends/vulkan/runtime/graph/ops/impl/UnaryOp.cpp
@@ -114,6 +114,17 @@ float get_val_or_inf(ComputeGraph& graph, const ValueRef& val, bool max) {
"hardshrink"); \
}
+#define DEFINE_LEAKY_RELU_FN(op_name) \
+ void op_name(ComputeGraph& graph, const std::vector<ValueRef>& args) { \
+ return add_unary_op_node( \
+ graph, \
+ args[0], \
+ get_val_or_inf(graph, args[1], /*neg slope*/ false), \
+ kDummyFloat, \
+ args[2], \
+ "leaky_relu"); \
+ }
+
void gelu(ComputeGraph& graph, const std::vector<ValueRef>& args) {
// args[1] is the `approximate` string
// https://fburl.com/code/9omngmyo
@@ -137,6 +148,7 @@ DEFINE_RELU_FN(relu);
DEFINE_HARDSHRINK_FN(hardshrink);
DEFINE_ACTIVATION_FN(hardswish);
DEFINE_ACTIVATION_FN(hardsigmoid);
+DEFINE_LEAKY_RELU_FN(leaky_relu);
REGISTER_OPERATORS {
VK_REGISTER_OP(aten.abs.default, abs);
@@ -155,6 +167,7 @@ REGISTER_OPERATORS {
VK_REGISTER_OP(aten.hardshrink.default, hardshrink);
VK_REGISTER_OP(aten.hardswish.default, hardswish);
VK_REGISTER_OP(aten.hardsigmoid.default, hardsigmoid);
+ VK_REGISTER_OP(aten.leaky_relu.default, leaky_relu);
}
} // namespace vkcompute
diff --git a/backends/vulkan/test/op_tests/cases.py b/backends/vulkan/test/op_tests/cases.py
index 9cec4891c1..2130573c0c 100644
--- a/backends/vulkan/test/op_tests/cases.py
+++ b/backends/vulkan/test/op_tests/cases.py
@@ -1072,6 +1072,7 @@ def get_reduce_op_inputs():
"aten.cos.default",
"aten.hardswish.default",
"aten.hardsigmoid.default",
+ "aten.leaky_relu.default",
]
)
def get_unary_ops_inputs():
diff --git a/backends/xnnpack/CMakeLists.txt b/backends/xnnpack/CMakeLists.txt
index ed8cf8d8e1..a21ef4f668 100644
--- a/backends/xnnpack/CMakeLists.txt
+++ b/backends/xnnpack/CMakeLists.txt
@@ -1,4 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
+# Copyright 2024-2025 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
@@ -128,8 +129,17 @@ if(NOT CMAKE_TOOLCHAIN_FILE MATCHES ".*(iOS|ios\.toolchain)\.cmake$")
#
list(TRANSFORM _xnn_executor_runner__srcs PREPEND "${EXECUTORCH_ROOT}/")
add_executable(xnn_executor_runner ${_xnn_executor_runner__srcs})
+
+ if(EXECUTORCH_ENABLE_EVENT_TRACER)
+ if(EXECUTORCH_BUILD_DEVTOOLS)
+ list(APPEND xnn_executor_runner_libs etdump)
+ else()
+ message(SEND_ERROR "Use of 'EXECUTORCH_ENABLE_EVENT_TRACER' requires 'EXECUTORCH_BUILD_DEVTOOLS' to be enabled.")
+ endif()
+ endif()
+
target_link_libraries(
- xnn_executor_runner xnnpack_backend gflags portable_ops_lib
+ xnn_executor_runner gflags portable_ops_lib ${xnn_executor_runner_libs}
)
target_compile_options(xnn_executor_runner PUBLIC ${_common_compile_options})
endif()
diff --git a/backends/xnnpack/test/ops/test_add.py b/backends/xnnpack/test/ops/test_add.py
index 784a9d3bbf..29a87df130 100644
--- a/backends/xnnpack/test/ops/test_add.py
+++ b/backends/xnnpack/test/ops/test_add.py
@@ -7,7 +7,7 @@
import unittest
import torch
-from executorch.backends.xnnpack.test.tester import Tester
+from executorch.backends.xnnpack.test.tester import Quantize, Tester
class TestAdd(unittest.TestCase):
@@ -136,9 +136,12 @@ def test_qs8_add2(self):
def test_qs8_add3(self):
inputs = (torch.randn(1, 1, 4, 4), torch.randn(1, 1, 4, 1))
+ calibration_samples = [
+ (torch.randn(1, 1, 4, 4), torch.randn(1, 1, 4, 1)) for _ in range(100)
+ ]
(
Tester(self.Add(), inputs)
- .quantize()
+ .quantize(Quantize(calibration_samples=calibration_samples))
.export()
.check_count({"torch.ops.aten.add.Tensor": 4})
.check(["torch.ops.quantized_decomposed"])
@@ -152,7 +155,7 @@ def test_qs8_add3(self):
)
.to_executorch()
.serialize()
- .run_method_and_compare_outputs()
+ .run_method_and_compare_outputs(num_runs=10, atol=0.02, rtol=0.02)
)
class AddRelu(torch.nn.Module):
diff --git a/backends/xnnpack/test/ops/test_conv1d.py b/backends/xnnpack/test/ops/test_conv1d.py
index 833ad69da6..b4c8c41492 100644
--- a/backends/xnnpack/test/ops/test_conv1d.py
+++ b/backends/xnnpack/test/ops/test_conv1d.py
@@ -13,7 +13,7 @@
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
from executorch.backends.xnnpack.test.test_xnnpack_utils import randomize_bn
-from executorch.backends.xnnpack.test.tester import RunPasses, Tester
+from executorch.backends.xnnpack.test.tester import Quantize, RunPasses, Tester
from executorch.backends.xnnpack.test.tester.tester import ToEdgeTransformAndLower
from executorch.exir.passes.constant_prop_pass import constant_prop_pass
@@ -98,9 +98,17 @@ def _test_conv1d(
stage=None,
skip_to_executorch=False,
):
+ calibration_samples = (
+ [tuple(torch.randn_like(inputs[i]) for i in range(len(inputs)))]
+ if quantized
+ else None
+ )
+
tester = (
(
- Tester(module, inputs, dynamic_shape).quantize()
+ Tester(module, inputs, dynamic_shape).quantize(
+ Quantize(calibration_samples=calibration_samples)
+ )
if quantized
else Tester(module, inputs)
)
@@ -114,7 +122,9 @@ def _test_conv1d(
# For some tests we want to skip to_executorch because otherwise it will require the
# quantized operators to be loaded and we don't want to do that in the test.
if not skip_to_executorch:
- tester.to_executorch().serialize().run_method_and_compare_outputs()
+ tester.to_executorch().serialize().run_method_and_compare_outputs(
+ num_runs=10, atol=0.02, rtol=0.02
+ )
def test_fp16_conv1d(self):
inputs = (torch.randn(2, 2, 4).to(torch.float16),)
diff --git a/backends/xnnpack/test/tester/tester.py b/backends/xnnpack/test/tester/tester.py
index dc885135bb..7954425602 100644
--- a/backends/xnnpack/test/tester/tester.py
+++ b/backends/xnnpack/test/tester/tester.py
@@ -12,7 +12,7 @@
import sys
from abc import ABC, abstractmethod
from collections import Counter, OrderedDict
-from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
+from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Type, Union
import torch
from executorch.backends.xnnpack._passes import XNNPACKPassManager
@@ -146,12 +146,14 @@ def __init__(
quantizer: Optional[Quantizer] = None,
quantization_config: Optional[QuantizationConfig] = None,
calibrate: bool = True,
+ calibration_samples: Optional[Sequence[Any]] = None,
):
self.quantizer = quantizer or XNNPACKQuantizer()
self.quantization_config = (
quantization_config or get_symmetric_quantization_config()
)
self.calibrate = calibrate
+ self.calibration_samples = calibration_samples
self.quantizer.set_global(self.quantization_config)
@@ -168,7 +170,11 @@ def run(
if self.calibrate:
# Calibrate prepared model to provide data to quantization observers.
- prepared(*inputs)
+ if self.calibration_samples is not None:
+ for inp in self.calibration_samples:
+ prepared(*inp)
+ else:
+ prepared(*inputs)
converted = convert_pt2e(prepared)
self.converted_graph = converted
diff --git a/codegen/tools/gen_selected_op_variants.py b/codegen/tools/gen_selected_op_variants.py
index da1c1215e2..95ae47f6f1 100644
--- a/codegen/tools/gen_selected_op_variants.py
+++ b/codegen/tools/gen_selected_op_variants.py
@@ -17,7 +17,7 @@
from torchgen.code_template import CodeTemplate
-ops_and_dtypes_template_str = """((exec_aten::string_view(operator_name).compare("$operator_name") == 0)\n && ($dtype_checks))"""
+ops_and_dtypes_template_str = """((executorch::aten::string_view(operator_name).compare("$operator_name") == 0)\n && ($dtype_checks))"""
ops_and_dtypes_template = CodeTemplate(ops_and_dtypes_template_str)
selected_kernel_dtypes_h_template_str = """#pragma once
@@ -27,7 +27,7 @@
inline constexpr bool should_include_kernel_dtype(
const char *operator_name,
- exec_aten::ScalarType scalar_type
+ executorch::aten::ScalarType scalar_type
) {
return $body;
}
@@ -91,7 +91,8 @@ def write_selected_op_variants(yaml_file_path: str, output_dir: str) -> None:
dtype_set = set([x.split(";")[0] for x in tensor_meta])
dtype_list = sorted([dtype_enum_to_type[x] for x in dtype_set])
conditions = [
- "scalar_type == exec_aten::ScalarType::" + x for x in dtype_list
+ "scalar_type == executorch::aten::ScalarType::" + x
+ for x in dtype_list
]
body_parts.append(
ops_and_dtypes_template.substitute(
diff --git a/codegen/tools/test/test_gen_selected_op_variants.py b/codegen/tools/test/test_gen_selected_op_variants.py
index 755b413cf5..e6f056e130 100644
--- a/codegen/tools/test/test_gen_selected_op_variants.py
+++ b/codegen/tools/test/test_gen_selected_op_variants.py
@@ -71,13 +71,13 @@ def test_generates_correct_header(self) -> None:
inline constexpr bool should_include_kernel_dtype(
const char *operator_name,
- exec_aten::ScalarType scalar_type
+ executorch::aten::ScalarType scalar_type
) {
- return ((exec_aten::string_view(operator_name).compare("add.out") == 0)
- && (scalar_type == exec_aten::ScalarType::Float || scalar_type == exec_aten::ScalarType::Int))
- || ((exec_aten::string_view(operator_name).compare("mul.out") == 0)
- && (scalar_type == exec_aten::ScalarType::Float))
- || ((exec_aten::string_view(operator_name).compare("sub.out") == 0)
+ return ((executorch::aten::string_view(operator_name).compare("add.out") == 0)
+ && (scalar_type == executorch::aten::ScalarType::Float || scalar_type == executorch::aten::ScalarType::Int))
+ || ((executorch::aten::string_view(operator_name).compare("mul.out") == 0)
+ && (scalar_type == executorch::aten::ScalarType::Float))
+ || ((executorch::aten::string_view(operator_name).compare("sub.out") == 0)
&& (true));
}
""",
@@ -124,7 +124,7 @@ def test_generates_correct_header(self) -> None:
inline constexpr bool should_include_kernel_dtype(
const char *operator_name,
- exec_aten::ScalarType scalar_type
+ executorch::aten::ScalarType scalar_type
) {
return true;
}
diff --git a/devtools/bundled_program/bundled_program.cpp b/devtools/bundled_program/bundled_program.cpp
index 54f84f6fef..1da42aa95d 100644
--- a/devtools/bundled_program/bundled_program.cpp
+++ b/devtools/bundled_program/bundled_program.cpp
@@ -23,10 +23,10 @@
#include
#include
-using exec_aten::ArrayRef;
-using exec_aten::Half;
-using exec_aten::ScalarType;
-using exec_aten::Tensor;
+using executorch::aten::ArrayRef;
+using executorch::aten::Half;
+using executorch::aten::ScalarType;
+using executorch::aten::Tensor;
using ::executorch::runtime::Error;
using ::executorch::runtime::EValue;
using ::executorch::runtime::Method;
@@ -67,16 +67,16 @@ TensorImpl impl_like(bundled_program_flatbuffer::Tensor* bundled_tensor) {
ScalarType scalar_type =
static_cast<ScalarType>(bundled_tensor->scalar_type());
ssize_t dim = bundled_tensor->sizes()->size();
- exec_aten::SizesType* sizes = bundled_tensor->mutable_sizes()->data();
+ executorch::aten::SizesType* sizes = bundled_tensor->mutable_sizes()->data();
void* data = bundled_tensor->mutable_data()->data();
- exec_aten::DimOrderType* dim_order =
+ executorch::aten::DimOrderType* dim_order =
bundled_tensor->mutable_dim_order()->data();
// The strides of the created TensorImpl will only actually be used during
// comparison (`tensors_are_close` below). To eliminate the use of a memory
// allocator, here we set the initial strides to null and reconstruct the
// stride array as a temporary variable during comparison.
- exec_aten::StridesType* strides = nullptr;
+ executorch::aten::StridesType* strides = nullptr;
return TensorImpl(scalar_type, dim, sizes, data, dim_order, strides);
}
#endif
@@ -165,7 +165,7 @@ bool tensors_are_close(
// Construct the stride array for the bundled tensor based on its dim order,
// since the strides of bundled_tensor are null in lean mode.
- exec_aten::StridesType strides[kMaxDim] = {0};
+ executorch::aten::StridesType strides[kMaxDim] = {0};
auto status = torch::executor::dim_order_to_stride(
bundled_tensor.sizes().data(),
bundled_tensor.dim_order().data(),
@@ -176,7 +176,7 @@ bool tensors_are_close(
// TODO(T132992348): support comparison between tensors of different strides
ET_CHECK_MSG(
- ArrayRef<exec_aten::StridesType>(strides, bundled_tensor.dim()) ==
+ ArrayRef<executorch::aten::StridesType>(strides, bundled_tensor.dim()) ==
method_output_tensor.strides(),
"The two inputs of `tensors_are_close` function shall have same strides");
#endif
diff --git a/devtools/etdump/etdump_flatcc.cpp b/devtools/etdump/etdump_flatcc.cpp
index c8e55b18d7..a34b5188c5 100644
--- a/devtools/etdump/etdump_flatcc.cpp
+++ b/devtools/etdump/etdump_flatcc.cpp
@@ -19,7 +19,7 @@
#include
-using ::exec_aten::Tensor;
+using ::executorch::aten::Tensor;
using ::executorch::runtime::AllocatorID;
using ::executorch::runtime::ArrayRef;
using ::executorch::runtime::ChainID;
@@ -37,27 +37,27 @@ namespace etdump {
namespace {
executorch_flatbuffer_ScalarType_enum_t get_flatbuffer_scalar_type(
- exec_aten::ScalarType tensor_scalar_type) {
+ executorch::aten::ScalarType tensor_scalar_type) {
switch (tensor_scalar_type) {
- case exec_aten::ScalarType::Byte:
+ case executorch::aten::ScalarType::Byte:
return executorch_flatbuffer_ScalarType_BYTE;
- case exec_aten::ScalarType::Char:
+ case executorch::aten::ScalarType::Char:
return executorch_flatbuffer_ScalarType_CHAR;
- case exec_aten::ScalarType::Short:
+ case executorch::aten::ScalarType::Short:
return executorch_flatbuffer_ScalarType_SHORT;
- case exec_aten::ScalarType::Float:
+ case executorch::aten::ScalarType::Float:
return executorch_flatbuffer_ScalarType_FLOAT;
- case exec_aten::ScalarType::Int:
+ case executorch::aten::ScalarType::Int:
return executorch_flatbuffer_ScalarType_INT;
- case exec_aten::ScalarType::Long:
+ case executorch::aten::ScalarType::Long:
return executorch_flatbuffer_ScalarType_LONG;
- case exec_aten::ScalarType::Double:
+ case executorch::aten::ScalarType::Double:
return executorch_flatbuffer_ScalarType_DOUBLE;
- case exec_aten::ScalarType::Bool:
+ case executorch::aten::ScalarType::Bool:
return executorch_flatbuffer_ScalarType_BOOL;
- case exec_aten::ScalarType::Bits16:
+ case executorch::aten::ScalarType::Bits16:
return executorch_flatbuffer_ScalarType_BITS16;
- case exec_aten::ScalarType::UInt16:
+ case executorch::aten::ScalarType::UInt16:
return executorch_flatbuffer_ScalarType_UINT16;
default:
ET_CHECK_MSG(
@@ -69,7 +69,7 @@ executorch_flatbuffer_ScalarType_enum_t get_flatbuffer_scalar_type(
etdump_Tensor_ref_t add_tensor_entry(
flatcc_builder_t* builder_,
- const exec_aten::Tensor& tensor,
+ const executorch::aten::Tensor& tensor,
long offset) {
etdump_Tensor_start(builder_);
@@ -508,7 +508,7 @@ void ETDumpGen::set_debug_buffer(Span<uint8_t> buffer) {
debug_buffer_ = buffer;
}
-size_t ETDumpGen::copy_tensor_to_debug_buffer(exec_aten::Tensor tensor) {
+size_t ETDumpGen::copy_tensor_to_debug_buffer(executorch::aten::Tensor tensor) {
if (tensor.nbytes() == 0) {
return static_cast<size_t>(-1);
}
@@ -536,7 +536,7 @@ void ETDumpGen::log_evalue(const EValue& evalue, LoggedEValueType evalue_type) {
switch (evalue.tag) {
case Tag::Tensor: {
- exec_aten::Tensor tensor = evalue.toTensor();
+ executorch::aten::Tensor tensor = evalue.toTensor();
long offset = copy_tensor_to_debug_buffer(tensor);
etdump_Tensor_ref_t tensor_ref =
add_tensor_entry(builder_, tensor, offset);
@@ -555,7 +555,8 @@ void ETDumpGen::log_evalue(const EValue& evalue, LoggedEValueType evalue_type) {
}
case Tag::ListTensor: {
- exec_aten::ArrayRef<exec_aten::Tensor> tensors = evalue.toTensorList();
+ executorch::aten::ArrayRef<executorch::aten::Tensor> tensors =
+ evalue.toTensorList();
etdump_Tensor_vec_start(builder_);
for (size_t i = 0; i < tensors.size(); ++i) {
long offset = copy_tensor_to_debug_buffer(tensors[i]);
diff --git a/devtools/etdump/etdump_flatcc.h b/devtools/etdump/etdump_flatcc.h
index 4a818d18e5..d778106653 100644
--- a/devtools/etdump/etdump_flatcc.h
+++ b/devtools/etdump/etdump_flatcc.h
@@ -106,7 +106,7 @@ class ETDumpGen : public ::executorch::runtime::EventTracer {
virtual void log_intermediate_output_delegate(
const char* name,
::executorch::runtime::DebugHandle delegate_debug_index,
- const exec_aten::Tensor& output) override;
+ const executorch::aten::Tensor& output) override;
/**
* Log an intermediate tensor array output from a delegate.
@@ -114,7 +114,8 @@ class ETDumpGen : public ::executorch::runtime::EventTracer {
virtual void log_intermediate_output_delegate(
const char* name,
::executorch::runtime::DebugHandle delegate_debug_index,
- const ::executorch::runtime::ArrayRef<exec_aten::Tensor> output) override;
+ const ::executorch::runtime::ArrayRef<executorch::aten::Tensor> output)
+ override;
/**
* Log an intermediate int output from a delegate.
@@ -157,7 +158,7 @@ class ETDumpGen : public ::executorch::runtime::EventTracer {
void check_ready_to_add_events();
int64_t create_string_entry(const char* name);
- size_t copy_tensor_to_debug_buffer(exec_aten::Tensor tensor);
+ size_t copy_tensor_to_debug_buffer(executorch::aten::Tensor tensor);
/**
* Templated helper function used to log various types of intermediate output.
diff --git a/devtools/etdump/tests/etdump_test.cpp b/devtools/etdump/tests/etdump_test.cpp
index f45652ab8f..664a5ee1a0 100644
--- a/devtools/etdump/tests/etdump_test.cpp
+++ b/devtools/etdump/tests/etdump_test.cpp
@@ -20,8 +20,8 @@
#include
#include
-using ::exec_aten::ScalarType;
-using ::exec_aten::Tensor;
+using ::executorch::aten::ScalarType;
+using ::executorch::aten::Tensor;
using ::executorch::etdump::ETDumpGen;
using ::executorch::etdump::ETDumpResult;
using ::executorch::runtime::AllocatorID;
@@ -205,12 +205,12 @@ TEST_F(ProfilerETDumpTest, DebugEvent) {
TEST_F(ProfilerETDumpTest, DebugEventTensorList) {
for (size_t i = 0; i < 2; i++) {
TensorFactory<ScalarType::Float> tf;
- exec_aten::Tensor storage[2] = {tf.ones({3, 2}), tf.ones({3, 2})};
+ executorch::aten::Tensor storage[2] = {tf.ones({3, 2}), tf.ones({3, 2})};
EValue evalue_1(storage[0]);
EValue evalue_2(storage[1]);
EValue* values_p[2] = {&evalue_1, &evalue_2};
- BoxedEvalueList<exec_aten::Tensor> a_box(values_p, storage, 2);
+ BoxedEvalueList<executorch::aten::Tensor> a_box(values_p, storage, 2);
EValue evalue(a_box);
evalue.tag = Tag::ListTensor;
diff --git a/devtools/visualization/__init__.py b/devtools/visualization/__init__.py
index 645cc5d537..df1d74c7fa 100644
--- a/devtools/visualization/__init__.py
+++ b/devtools/visualization/__init__.py
@@ -8,4 +8,5 @@
ModelExplorerServer,
SingletonModelExplorerServer,
visualize,
+ visualize_graph,
)
diff --git a/devtools/visualization/visualization_utils.py b/devtools/visualization/visualization_utils.py
index 4d520a6636..d21d11082a 100644
--- a/devtools/visualization/visualization_utils.py
+++ b/devtools/visualization/visualization_utils.py
@@ -6,9 +6,13 @@
import subprocess
import time
+from typing import Any, Callable, Type
from executorch.exir import EdgeProgramManager, ExecutorchProgramManager
+from executorch.exir.program._program import _update_exported_program_graph_module
+from torch._export.verifier import Verifier
from torch.export.exported_program import ExportedProgram
+from torch.fx import GraphModule
try:
from model_explorer import config, consts, visualize_from_config # type: ignore
@@ -27,7 +31,7 @@ class SingletonModelExplorerServer:
server: None | subprocess.Popen = None
num_open: int = 0
- wait_after_start = 2.0
+ wait_after_start = 3.0
def __init__(self, open_in_browser: bool = True, port: int | None = None):
if SingletonModelExplorerServer.server is None:
@@ -124,3 +128,29 @@ def visualize(
no_open_in_browser=no_open_in_browser,
**kwargs,
)
+
+
+def visualize_graph(
+ graph_module: GraphModule,
+ exported_program: ExportedProgram | EdgeProgramManager | ExecutorchProgramManager,
+ reuse_server: bool = True,
+ no_open_in_browser: bool = False,
+ **kwargs,
+):
+ """Overrides the graph_module of the supplied exported_program with 'graph_module' before visualizing.
+ Also disables validating operators to allow visualizing graphs containing custom ops.
+
+ A typical use case is after running passes, which return a graph_module rather than an ExportedProgram.
+ """
+
+ class _any_op(Verifier):
+ dialect = "ANY_OP"
+
+ def allowed_op_types(self) -> tuple[Type[Any], ...]:
+ return (Callable,) # type: ignore
+
+ exported_program = _get_exported_program(exported_program)
+ exported_program = _update_exported_program_graph_module(
+ exported_program, graph_module, override_verifiers=[_any_op]
+ )
+ visualize(exported_program, reuse_server, no_open_in_browser, **kwargs)
diff --git a/devtools/visualization/visualization_utils_test.py b/devtools/visualization/visualization_utils_test.py
index dafefa7dfd..d49c6d2f72 100644
--- a/devtools/visualization/visualization_utils_test.py
+++ b/devtools/visualization/visualization_utils_test.py
@@ -8,6 +8,7 @@
import pytest
import torch
+from executorch.backends.arm._passes.decompose_linear_pass import DecomposeLinearPass
from executorch.backends.xnnpack.test.tester import Tester
from executorch.devtools.visualization import (
@@ -15,8 +16,9 @@
SingletonModelExplorerServer,
visualization_utils,
visualize,
+ visualize_graph,
)
-from executorch.exir import ExportedProgram
+from executorch.exir import ExportedProgram, to_edge_transform_and_lower
try:
from model_explorer.config import ModelExplorerConfig # type: ignore
@@ -145,6 +147,17 @@ def test_visualize_to_executorch(server):
)
+def test_visualize_graph(server):
+ with server():
+ model = Linear(20, 30)
+ exported_program = torch.export.export(model, model.get_inputs())
+ exported_program = to_edge_transform_and_lower(
+ exported_program
+ ).exported_program()
+ modified_gm = DecomposeLinearPass()(exported_program.graph_module).graph_module
+ visualize_graph(modified_gm, exported_program)
+
+
if __name__ == "__main__":
"""A test to run locally to make sure that the web browser opens up
automatically as intended.
@@ -158,3 +171,7 @@ def test_visualize_to_executorch(server):
test_visualize_to_edge(SingletonModelExplorerServer)
test_visualize_partition(SingletonModelExplorerServer)
test_visualize_to_executorch(SingletonModelExplorerServer)
+ test_visualize_graph(SingletonModelExplorerServer)
+
+ # Sleep to give the server time to load the last graph before killing it.
+ time.sleep(3.0)
diff --git a/docs/source/_static/img/et-logo.png b/docs/source/_static/img/et-logo.png
new file mode 100644
index 0000000000..b7995a5db7
Binary files /dev/null and b/docs/source/_static/img/et-logo.png differ
diff --git a/docs/source/_static/img/swiftpm_xcode1.png b/docs/source/_static/img/swiftpm_xcode1.png
index 61859c38fa..11b9c23782 100644
Binary files a/docs/source/_static/img/swiftpm_xcode1.png and b/docs/source/_static/img/swiftpm_xcode1.png differ
diff --git a/docs/source/apple-runtime.md b/docs/source/apple-runtime.md
index fe744add52..4114b78060 100644
--- a/docs/source/apple-runtime.md
+++ b/docs/source/apple-runtime.md
@@ -25,7 +25,7 @@ The prebuilt ExecuTorch runtime, backend, and kernels are available as a [Swift
#### Xcode
-In Xcode, go to `File > Add Package Dependencies`. Paste the URL of the [ExecuTorch repo](https://github.com/pytorch/executorch) into the search bar and select it. Make sure to change the branch name to the desired ExecuTorch version in format "swiftpm-<version>", (e.g. "swiftpm-0.4.0"), or a branch name in format "swiftpm-<version>.<year><month><date>" (e.g. "swiftpm-0.4.0-20241201") for a nightly build on a specific date.
+In Xcode, go to `File > Add Package Dependencies`. Paste the URL of the [ExecuTorch repo](https://github.com/pytorch/executorch) into the search bar and select it. Make sure to change the branch name to the desired ExecuTorch version in format "swiftpm-<version>", (e.g. "swiftpm-0.5.0"), or a branch name in format "swiftpm-<version>.<year><month><date>" (e.g. "swiftpm-0.5.0-20250130") for a nightly build on a specific date.

@@ -58,7 +58,7 @@ let package = Package(
],
dependencies: [
// Use "swiftpm-." branch name for a nightly build.
- .package(url: "https://github.com/pytorch/executorch.git", branch: "swiftpm-0.4.0")
+ .package(url: "https://github.com/pytorch/executorch.git", branch: "swiftpm-0.5.0")
],
targets: [
.target(
diff --git a/docs/source/getting-started-faqs.md b/docs/source/getting-started-faqs.md
new file mode 100644
index 0000000000..e103309f71
--- /dev/null
+++ b/docs/source/getting-started-faqs.md
@@ -0,0 +1,56 @@
+# FAQs and Common Issues
+
+This page summarizes frequently asked questions and provides guidance on issues that commonly occur when adopting ExecuTorch.
+
+If a specific issue is not covered here, consider searching for or creating an issue on GitHub under [Issues](https://github.com/pytorch/executorch/issues) or [Discussions](https://github.com/pytorch/executorch/discussions).
+
+## Export
+
+### Missing out variants: { _ }
+
+The model likely contains torch custom operators. Custom ops need an ExecuTorch implementation and must be loaded at export time. See the [ExecuTorch Custom Ops Documentation](https://pytorch.org/executorch/main/kernel-library-custom-aten-kernel.html#apis) for details on how to do this.
+
+### RuntimeError: PyTorch convert function for op _ not implemented
+
+The model likely contains an operator that is not yet supported on ExecuTorch. In this case, consider searching for or creating an issue on [GitHub](https://github.com/pytorch/executorch/issues).
+
+## Runtime
+
+ExecuTorch error codes are defined in [executorch/runtime/core/error.h](https://github.com/pytorch/executorch/blob/main/runtime/core/error.h).
+
+### Inference is Slow / Performance Troubleshooting
+
+If building the runtime from source, ensure that the build is done in release mode. For CMake builds, this can be done by passing `-DCMAKE_BUILD_TYPE=Release`.
+
+Ensure the model is delegated. If not targeting a specific accelerator, use the XNNPACK delegate for CPU performance. Undelegated operators will typically fall back to the ExecuTorch portable library, which is designed as a fallback and is not intended for performance-sensitive operators. To target XNNPACK, pass an `XnnpackPartitioner` to `to_edge_transform_and_lower`, as in the sketch below. See [Building and Running ExecuTorch with XNNPACK Backend](https://pytorch.org/executorch/main/tutorial-xnnpack-delegate-lowering.html) for more information.
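+
+A minimal export-time sketch of delegating to XNNPACK (the model and example inputs below are placeholders):
+```
+import torch
+from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
+from executorch.exir import to_edge_transform_and_lower
+
+# Placeholder model and inputs; substitute your own.
+model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU()).eval()
+example_inputs = (torch.randn(1, 8),)
+
+exported = torch.export.export(model, example_inputs)
+et_program = to_edge_transform_and_lower(
+    exported, partitioner=[XnnpackPartitioner()]
+).to_executorch()
+
+with open("model.pte", "wb") as f:
+    f.write(et_program.buffer)
+```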
+
+Thread count can have a significant impact on CPU performance. The optimal thread count may depend on the model and application. By default, ExecuTorch currently uses as many threads as there are cores. Consider setting the thread count to half the number of cores, or simply to 4 on mobile CPUs.
+
+Thread count can be set with the following function. Ensure this is done prior to loading or running a model.
+```
+::executorch::extension::threadpool::get_threadpool()->_unsafe_reset_threadpool(num_threads);
+```
+
+For a deeper investigation into model performance, ExecuTorch supports operator-level performance profiling. See [Using the ExecuTorch Developer Tools to Profile a Model](https://pytorch.org/executorch/main/tutorials/devtools-integration-tutorial.html) for more information.
+
+### Missing Logs
+
+ExecuTorch provides hooks to route runtime logs. By default, logs are sent to stdout/stderr, but users can override `et_pal_emit_log_message` to route logs to a custom destination. The Android and iOS extensions also provide out-of-box log routing to the appropriate platform logs. See [Runtime Platform Abstraction Layer (PAL)](https://pytorch.org/executorch/main/runtime-platform-abstraction-layer.html) for more information.
+
+### Error setting input: 0x10 / Attempted to resize a bounded tensor...
+
+This usually means the inputs provided do not match the shape of the example inputs used during model export. If the model is expected to handle varying size inputs (dynamic shapes), make sure the model export specifies the appropriate bounds. See [Expressing Dynamism](https://pytorch.org/docs/stable/export.html#expressing-dynamism) for more information on specifying dynamic shapes.
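+
+A minimal sketch of exporting with a dynamic input dimension (the tiny model, the argument name `x`, and the bound of 2048 are illustrative):
+```
+import torch
+from torch.export import Dim, export
+
+
+class TinyModel(torch.nn.Module):  # stand-in for your own model
+    def forward(self, x):
+        return x * 2
+
+
+# Allow dimension 1 of `x` to vary at runtime, up to an application-specific bound.
+seq_len = Dim("seq_len", max=2048)
+exported = export(TinyModel().eval(), (torch.randn(1, 32),), dynamic_shapes={"x": {1: seq_len}})
+```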
+
+### Error 0x14 (Operator Missing)
+
+This usually means that the selective build configuration is incorrect. Ensure that the operator library is generated from the current version of the model, that the corresponding `et_operator_library` is a dependency of the app-level `executorch_generated_lib`, and that the generated lib is linked into the application.
+
+This can also occur if the ExecuTorch portable library does not yet have an implementation of the given ATen operator. In this case, consider searching for or creating an issue on [GitHub](https://github.com/pytorch/executorch/issues).
+
+### Error 0x20 (Not Found)
+
+This error can occur for a few reasons, but the most common is a missing backend target. Ensure the appropriate backend target is linked. For XNNPACK, this is `xnnpack_backend`. If the backend is linked but is still not available, try linking with `--whole-archive`: `-Wl,--whole-archive libxnnpack_backend.a -Wl,--no-whole-archive`.
+
+### Duplicate Kernel Registration Abort
+
+This manifests as a crash whose call stack includes ExecuTorch kernel registration code and ends in an `et_pal_abort`. This typically means there are multiple `gen_operators_lib` targets linked into the application. There must be only one generated operator library per target, though each model can have its own `gen_selected_ops/generate_bindings_for_kernels` call.
diff --git a/docs/source/index.rst b/docs/source/index.rst
index b3c69dd9e7..ea3cf5d827 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -87,7 +87,7 @@ Topics in this section will help you get started with ExecuTorch.
getting-started-setup
export-overview
runtime-build-and-cross-compilation
-
+ getting-started-faqs
.. toctree::
:glob:
diff --git a/docs/source/native-delegates-executorch-xnnpack-delegate.md b/docs/source/native-delegates-executorch-xnnpack-delegate.md
index de54de7706..6bfbfa6be3 100644
--- a/docs/source/native-delegates-executorch-xnnpack-delegate.md
+++ b/docs/source/native-delegates-executorch-xnnpack-delegate.md
@@ -70,7 +70,7 @@ Since weight packing creates an extra copy of the weights inside XNNPACK, We fre
When executing the XNNPACK subgraphs, we prepare the tensor inputs and outputs and feed them to the XNNPACK runtime graph. After executing the runtime graph, the output pointers are filled with the computed tensors.
#### **Profiling**
-We have enabled basic profiling for XNNPACK delegate that can be enabled with the following compiler flag `-DENABLE_XNNPACK_PROFILING`. With ExecuTorch's Developer Tools integration, you can also now use the Developer Tools to profile the model. You can follow the steps in [Using the ExecuTorch Developer Tools to Profile a Model](./tutorials/devtools-integration-tutorial) on how to profile ExecuTorch models and use Developer Tools' Inspector API to view XNNPACK's internal profiling information.
+Basic profiling for the XNNPACK delegate can be enabled with the compiler flag `-DEXECUTORCH_ENABLE_EVENT_TRACER` (add `-DENABLE_XNNPACK_PROFILING` for additional details). With ExecuTorch's Developer Tools integration, you can also use the Developer Tools to profile the model. Follow the steps in [Using the ExecuTorch Developer Tools to Profile a Model](./tutorials/devtools-integration-tutorial) to profile ExecuTorch models and use the Developer Tools' Inspector API to view XNNPACK's internal profiling information. An example implementation is available in the `xnn_executor_runner` (see [tutorial here](tutorial-xnnpack-delegate-lowering.md#profiling)).
[comment]: <> (TODO: Refactor quantizer to a more official quantization doc)
diff --git a/docs/source/tutorial-xnnpack-delegate-lowering.md b/docs/source/tutorial-xnnpack-delegate-lowering.md
index 1c71a6ba80..d1148511c5 100644
--- a/docs/source/tutorial-xnnpack-delegate-lowering.md
+++ b/docs/source/tutorial-xnnpack-delegate-lowering.md
@@ -177,3 +177,6 @@ Now you should be able to find the executable built at `./cmake-out/backends/xnn
## Building and Linking with the XNNPACK Backend
You can build the XNNPACK backend [CMake target](https://github.com/pytorch/executorch/blob/main/backends/xnnpack/CMakeLists.txt#L83), and link it with your application binary such as an Android or iOS application. For more information on this you may take a look at this [resource](demo-apps-android.md) next.
+
+## Profiling
+To enable profiling in the `xnn_executor_runner`, pass the flags `-DEXECUTORCH_ENABLE_EVENT_TRACER=ON` and `-DEXECUTORCH_BUILD_DEVTOOLS=ON` to the build command (add `-DENABLE_XNNPACK_PROFILING=ON` for additional details). This enables ETDump generation when running inference and exposes command-line flags for profiling (see `xnn_executor_runner --help` for details).
diff --git a/examples/apple/coreml/executor_runner/main.mm b/examples/apple/coreml/executor_runner/main.mm
index 35608dd092..1824458e34 100644
--- a/examples/apple/coreml/executor_runner/main.mm
+++ b/examples/apple/coreml/executor_runner/main.mm
@@ -249,8 +249,8 @@ Args parse_command_line_args(NSArray *args) {
}
Buffer buffer(tensor_meta->nbytes(), 0);
auto sizes = tensor_meta->sizes();
- exec_aten::TensorImpl tensor_impl(tensor_meta->scalar_type(), std::size(sizes), const_cast(sizes.data()), buffer.data());
- exec_aten::Tensor tensor(&tensor_impl);
+ executorch::aten::TensorImpl tensor_impl(tensor_meta->scalar_type(), std::size(sizes), const_cast(sizes.data()), buffer.data());
+ executorch::aten::Tensor tensor(&tensor_impl);
EValue input_value(std::move(tensor));
Error err = method.set_input(input_value, i);
if (err != Error::Ok) {
diff --git a/examples/arm/setup.sh b/examples/arm/setup.sh
index 49729fdbf6..8b4cd275e4 100755
--- a/examples/arm/setup.sh
+++ b/examples/arm/setup.sh
@@ -7,28 +7,15 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
-set -eu
-
-if [[ "${1:-'.'}" == "-h" || "${#}" -gt 2 ]]; then
- echo "Usage: $(basename $0) <--i-agree-to-the-contained-eula> [path-to-a-scratch-dir]"
- echo "Supplied args: $*"
- exit 1
-fi
-
-
-########
-### Helper functions
-########
-ARCH="$(uname -m)"
-OS="$(uname -s)"
-
-
+set -u
########
### Hardcoded constants
########
script_dir=$(cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd)
et_dir=$(realpath $script_dir/../..)
+ARCH="$(uname -m)"
+OS="$(uname -s)"
if [[ "${ARCH}" == "x86_64" ]]; then
# FVPs
@@ -78,39 +65,40 @@ tosa_reference_model_rev="v0.80.1"
# vela
vela_repo_url="https://gitlab.arm.com/artificial-intelligence/ethos-u/ethos-u-vela"
-vela_rev="fc970e3da72e5f6930b840b357684126602b3126"
+vela_rev="e131bf4f528f0d461868229972e07f371dcbc881"
-########
-### Mandatory user args
-########
-eula_acceptance="${1:-'.'}"
-if [[ "${eula_acceptance}" != "--i-agree-to-the-contained-eula" ]]; then
- if [[ ${ARM_FVP_INSTALL_I_AGREE_TO_THE_CONTAINED_EULA} != "True" ]]; then
- echo "Must pass first positional argument '--i-agree-to-the-contained-eula' to agree to EULA associated with downloading the FVP. Exiting!"
- exit 1
- else
- echo "Arm EULA for FVP agreed to with ARM_FVP_INSTALL_I_AGREE_TO_THE_CONTAINED_EULA=True environment variable"
- fi
-else
- shift; # drop this arg
-fi
########
### Optional user args
########
-root_dir=${1:-"${script_dir}/ethos-u-scratch"}
+root_dir=${2:-"${script_dir}/ethos-u-scratch"}
mkdir -p ${root_dir}
root_dir=$(realpath ${root_dir})
+setup_path_script="${root_dir}/setup_path.sh"
+
########
### Functions
########
function setup_fvp() {
+
+ # Mandatory user arg --i-agree-to-the-contained-eula
+ eula_acceptance="${1:-'.'}"
+ if [[ "${eula_acceptance}" != "--i-agree-to-the-contained-eula" ]]; then
+ if [[ ${ARM_FVP_INSTALL_I_AGREE_TO_THE_CONTAINED_EULA} != "True" ]]; then
+ echo "Must pass first positional argument '--i-agree-to-the-contained-eula' to agree to EULA associated with downloading the FVP. Exiting!"
+ exit 1
+ else
+ echo "Arm EULA for FVP agreed to with ARM_FVP_INSTALL_I_AGREE_TO_THE_CONTAINED_EULA=True environment variable"
+ fi
+ else
+ shift; # drop this arg
+ fi
if [[ "${OS}" != "Linux" ]]; then
echo "[${FUNCNAME[0]}] Warning: FVP only supported with Linux OS, skipping FVP setup..."
echo "[${FUNCNAME[0]}] Warning: For MacOS, using https://github.com/Arm-Examples/FVPs-on-Mac is recommended."
- echo "[${FUNCNAME[0]}] Warning: Follow the instructions and make sure the path is set correctly."
+ echo "[${FUNCNAME[0]}] Warning: Follow the instructions and make sure the path is set correctly."
return 1
fi
@@ -148,17 +136,7 @@ function setup_fvp() {
exit 1
;;
esac
-
- model_dir_variable=${fvp}_model_dir
- fvp_model_dir=${!model_dir_variable}
- fvp_bin_path="$(cd models/${fvp_model_dir} && pwd)"
- export PATH=${PATH}:${fvp_bin_path}
-
- echo "export PATH=\${PATH}:${fvp_bin_path}" >> ${setup_path_script}
done
-
- # Fixup for Corstone-320 python dependency
- echo "export LD_LIBRARY_PATH=${root_dir}/FVP-corstone320/python/lib/" >> ${setup_path_script}
}
function setup_toolchain() {
@@ -173,10 +151,6 @@ function setup_toolchain() {
echo "[${FUNCNAME[0]}] Installing toolchain ..."
rm -rf "${toolchain_dir}"
tar xf "${toolchain_dir}.tar.xz"
- toolchain_bin_path="$(cd ${toolchain_dir}/bin && pwd)"
- export PATH=${PATH}:${toolchain_bin_path}
- hash arm-none-eabi-gcc
- echo "export PATH=\${PATH}:${toolchain_bin_path}" >> ${setup_path_script}
}
function setup_tosa_reference_model() {
@@ -188,48 +162,81 @@ function setup_tosa_reference_model() {
}
function setup_vela() {
- #
- # Prepare the Vela compiler for AoT to Ethos-U compilation
- #
pip install ethos-u-vela@git+${vela_repo_url}@${vela_rev}
}
+function setup_path() {
+ echo $setup_path_script
+}
+
+function create_setup_path(){
+ echo "" > "${setup_path_script}"
+ fvps=("corstone300" "corstone320")
+ for fvp in "${fvps[@]}"; do
+ model_dir_variable=${fvp}_model_dir
+ fvp_model_dir=${!model_dir_variable}
+ fvp_bin_path="${root_dir}/FVP-${fvp}/models/${fvp_model_dir}"
+ echo "export PATH=\${PATH}:${fvp_bin_path}" >> ${setup_path_script}
+ done
+
+ # Fixup for Corstone-320 python dependency
+ echo "export LD_LIBRARY_PATH=${root_dir}/FVP-corstone320/python/lib/" >> ${setup_path_script}
+
+ toolchain_bin_path="$(cd ${toolchain_dir}/bin && pwd)"
+ echo "export PATH=\${PATH}:${toolchain_bin_path}" >> ${setup_path_script}
+
+ echo "hash FVP_Corstone_SSE-300_Ethos-U55" >> ${setup_path_script}
+ echo "hash FVP_Corstone_SSE-300_Ethos-U65" >> ${setup_path_script}
+ echo "hash FVP_Corstone_SSE-320" >> ${setup_path_script}
+}
+
########
### main
########
-# do basic checks
-# Make sure we are on a supported platform
-if [[ "${ARCH}" != "x86_64" ]] && [[ "${ARCH}" != "aarch64" ]] \
- && [[ "${ARCH}" != "arm64" ]]; then
- echo "[main] Error: only x86-64 & aarch64 architecture is supported for now!"
- exit 1
-fi
+# Only run this if script is executed, not if it is sourced
+(return 0 2>/dev/null) && is_script_sourced=1 || is_script_sourced=0
+if [[ $is_script_sourced -eq 0 ]]
+ then
+ set -e
+ if [[ "${ARCH}" != "x86_64" ]] && [[ "${ARCH}" != "aarch64" ]] \
+ && [[ "${ARCH}" != "arm64" ]]; then
+ echo "[main] Error: only x86-64 & aarch64 architecture is supported for now!"
+ exit 1
+ fi
-cd "${script_dir}"
+ # Make sure we are on a supported platform
+ if [[ "${1:-'.'}" == "-h" || "${#}" -gt 2 ]]; then
+ echo "Usage: $(basename $0) <--i-agree-to-the-contained-eula> [path-to-a-scratch-dir]"
+ echo "Supplied args: $*"
+ exit 1
+ fi
-# Setup the root dir
-cd "${root_dir}"
-echo "[main] Using root dir ${root_dir}"
+ cd "${script_dir}"
-setup_path_script="${root_dir}/setup_path.sh"
-echo "" > "${setup_path_script}"
+ # Setup the root dir
+ cd "${root_dir}"
+ echo "[main] Using root dir ${root_dir}"
+
+ # Import utils
+ source $et_dir/backends/arm/scripts/utils.sh
-# Import utils
-source $et_dir/backends/arm/scripts/utils.sh
+ # Setup FVP
+ setup_fvp ${1:-'.'}
-# Setup toolchain
-setup_toolchain
+ # Setup toolchain
+ setup_toolchain
-# Setup the tosa_reference_model
-setup_tosa_reference_model
+ # Create new setup_path script only if fvp and toolchain setup went well.
+ create_setup_path
-# Setup vela and patch in codegen fixes
-setup_vela
+ # Setup the tosa_reference_model
+ setup_tosa_reference_model
-# Setup FVP
-setup_fvp
+ # Setup vela and patch in codegen fixes
+ setup_vela
-echo "[main] update path by doing 'source ${setup_path_script}'"
+ echo "[main] update path by doing 'source ${setup_path_script}'"
-echo "[main] success!"
-exit 0
+ echo "[main] success!"
+ exit 0
+fi
diff --git a/examples/cadence/operators/facto_util.py b/examples/cadence/operators/facto_util.py
index 304b1c7e72..5e6a58ce9f 100644
--- a/examples/cadence/operators/facto_util.py
+++ b/examples/cadence/operators/facto_util.py
@@ -22,7 +22,16 @@ def apply_tensor_contraints(op_name: str, tensor_constraints: list[object]) -> N
tensor_constraints.extend(
[
cp.Dtype.In(lambda deps: [torch.float]),
- cp.Rank.Le(lambda deps: 2**3),
+ cp.Rank.Le(lambda deps: 2**2),
+ cp.Value.Ge(lambda deps, dtype, struct: -2),
+ cp.Value.Le(lambda deps, dtype, struct: 2),
+ ]
+ )
+ case "mean.dim":
+ tensor_constraints.extend(
+ [
+ cp.Dtype.In(lambda deps: [torch.float]),
+ cp.Rank.Le(lambda deps: 2**2),
]
)
case "exp.default":
@@ -86,8 +95,27 @@ def facto_testcase_gen(op_name: str) -> List[Tuple[List[str], OrderedDict[str, s
cp.Value.Le(lambda deps, dtype: 2),
]
)
+ elif in_spec.type.is_scalar_type():
+ spec.inspec[index].constraints.extend(
+ [
+ cp.Dtype.In(lambda deps: apply_scalar_contraints(op_name)),
+ ]
+ )
elif in_spec.type.is_tensor():
spec.inspec[index].constraints.extend(tensor_constraints)
+ elif in_spec.type.is_dim_list():
+ spec.inspec[index].constraints.extend(
+ [
+ cp.Length.Ge(lambda deps: 1),
+ cp.Optional.Eq(lambda deps: False),
+ ]
+ )
+ elif in_spec.type.is_bool():
+ spec.inspec[index].constraints.extend(
+ [
+ cp.Dtype.In(lambda deps: [torch.bool]),
+ ]
+ )
return [
(posargs, inkwargs)
diff --git a/examples/cadence/operators/test_g3_ops.py b/examples/cadence/operators/test_g3_ops.py
index 158e13d389..58433cc739 100644
--- a/examples/cadence/operators/test_g3_ops.py
+++ b/examples/cadence/operators/test_g3_ops.py
@@ -259,6 +259,35 @@ def test_g3__softmax_out(
self.run_and_verify(model, (inputs,))
+ # pyre-ignore[16]: Module `parameterized.parameterized` has no attribute `expand`.
+ @parameterized.expand([*facto_util.facto_testcase_gen("mean.dim")])
+ def test_g3_mean_dim_out(
+ self,
+ posargs: List[int],
+ inkwargs: OrderedDict[str, str],
+ ) -> None:
+ class Meandim(nn.Module):
+ def forward(
+ self,
+ x: torch.Tensor,
+ dim_list: Tuple[int],
+ keepdim: bool,
+ dtype: torch.dtype = torch.float32,
+ ) -> torch.Tensor:
+ return torch.ops.aten.mean.dim(
+ x,
+ dim_list,
+ keepdim,
+ dtype=dtype,
+ )
+
+ model = Meandim()
+
+ self.run_and_verify(
+ model,
+ inputs=tuple(posargs),
+ )
+
if __name__ == "__main__":
unittest.main()
diff --git a/examples/demo-apps/apple_ios/ExecuTorchDemo/ExecuTorchDemo.xcodeproj/project.pbxproj b/examples/demo-apps/apple_ios/ExecuTorchDemo/ExecuTorchDemo.xcodeproj/project.pbxproj
index aff4c7a74b..f08d61396d 100644
--- a/examples/demo-apps/apple_ios/ExecuTorchDemo/ExecuTorchDemo.xcodeproj/project.pbxproj
+++ b/examples/demo-apps/apple_ios/ExecuTorchDemo/ExecuTorchDemo.xcodeproj/project.pbxproj
@@ -806,7 +806,7 @@
isa = XCRemoteSwiftPackageReference;
repositoryURL = "https://github.com/pytorch/executorch";
requirement = {
- branch = "swiftpm-0.4.0.20241120";
+ branch = "swiftpm-0.5.0.20250130";
kind = branch;
};
};
diff --git a/examples/demo-apps/apple_ios/LLaMA/LLaMA.xcodeproj/project.pbxproj b/examples/demo-apps/apple_ios/LLaMA/LLaMA.xcodeproj/project.pbxproj
index 0145d7745f..2cc9380879 100644
--- a/examples/demo-apps/apple_ios/LLaMA/LLaMA.xcodeproj/project.pbxproj
+++ b/examples/demo-apps/apple_ios/LLaMA/LLaMA.xcodeproj/project.pbxproj
@@ -808,7 +808,7 @@
isa = XCRemoteSwiftPackageReference;
repositoryURL = "https://github.com/pytorch/executorch";
requirement = {
- branch = "swiftpm-0.4.0.20241120";
+ branch = "swiftpm-0.5.0.20250130";
kind = branch;
};
};
diff --git a/examples/demo-apps/apple_ios/LLaMA/docs/delegates/mps_README.md b/examples/demo-apps/apple_ios/LLaMA/docs/delegates/mps_README.md
index bfe66bbd4e..e1a1530acf 100644
--- a/examples/demo-apps/apple_ios/LLaMA/docs/delegates/mps_README.md
+++ b/examples/demo-apps/apple_ios/LLaMA/docs/delegates/mps_README.md
@@ -76,7 +76,7 @@ sudo /Applications/CMake.app/Contents/bin/cmake-gui --install
The prebuilt ExecuTorch runtime, backend, and kernels are available as a Swift PM package.
### Xcode
-Open the project in Xcode.In Xcode, go to `File > Add Package Dependencies`. Paste the URL of the ExecuTorch repo into the search bar and select it. Make sure to change the branch name to the desired ExecuTorch version, e.g., “0.4.0”, or just use the “latest” branch name for the latest stable build.
+Open the project in Xcode. In Xcode, go to `File > Add Package Dependencies`. Paste the URL of the ExecuTorch repo into the search bar and select it. Make sure to change the branch name to the desired ExecuTorch version, e.g., “swiftpm-0.5.0”, or a branch name in the format "swiftpm-<version>.<date>" (e.g., "swiftpm-0.5.0.20250130") for a nightly build on a specific date.
Note: If you're running into any issues related to package dependencies, quit Xcode entirely, delete the whole executorch repo, clean the caches by running the command below in terminal and clone the repo again.
diff --git a/examples/demo-apps/apple_ios/LLaMA/docs/delegates/xnnpack_README.md b/examples/demo-apps/apple_ios/LLaMA/docs/delegates/xnnpack_README.md
index b357628042..784ebe50f8 100644
--- a/examples/demo-apps/apple_ios/LLaMA/docs/delegates/xnnpack_README.md
+++ b/examples/demo-apps/apple_ios/LLaMA/docs/delegates/xnnpack_README.md
@@ -130,9 +130,9 @@ While we recommended using the latest prebuilt package pre-configured with the X
Go to Project Navigator, click on LLaMA. `Project --> LLaMA --> Package Dependencies`, and update the package dependencies to any of the available options below:
-- Branch --> swiftpm-0.4.0.20241120 (amend to match the latest nightly build)
-- Branch --> 0.4.0
-- Branch --> 0.3.0
+- Branch --> swiftpm-0.5.0.20250130 (amend to match the latest nightly build)
+- Branch --> swiftpm-0.5.0
+- Branch --> swiftpm-0.4.0
### 2.2 Manually build the package locally and link them
diff --git a/examples/demo-apps/react-native/rnllama/ios/rnllama.xcodeproj/project.pbxproj b/examples/demo-apps/react-native/rnllama/ios/rnllama.xcodeproj/project.pbxproj
index 489fa4d9f7..1a58797064 100644
--- a/examples/demo-apps/react-native/rnllama/ios/rnllama.xcodeproj/project.pbxproj
+++ b/examples/demo-apps/react-native/rnllama/ios/rnllama.xcodeproj/project.pbxproj
@@ -947,7 +947,7 @@
isa = XCRemoteSwiftPackageReference;
repositoryURL = "https://github.com/pytorch/executorch.git";
requirement = {
- branch = "swiftpm-0.4.0.20241120";
+ branch = "swiftpm-0.5.0.20250130";
kind = branch;
};
};
diff --git a/examples/models/deepseek-r1-distill-llama-8B/README.md b/examples/models/deepseek-r1-distill-llama-8B/README.md
new file mode 100644
index 0000000000..3a7a723c73
--- /dev/null
+++ b/examples/models/deepseek-r1-distill-llama-8B/README.md
@@ -0,0 +1,72 @@
+# Summary
+This example demonstrates how to run the [DeepSeek R1 Distill Llama 8B](https://huggingface.co/deepseek-ai/DeepSeek-R1-Distill-Llama-8B) model via ExecuTorch. The architecture of this distilled model is exactly the same as Llama, so all the instructions in the [Llama README](../llama/README.md) apply as is.
+
+# Instructions
+## Step 1: Setup
+1. Follow the [tutorial](https://pytorch.org/executorch/main/getting-started-setup) to set up ExecuTorch. For installation, run `./install_executorch.sh`.
+
+2. Run the installation step for Llama specific requirements
+```
+./examples/models/llama/install_requirements.sh
+```
+
+## Step 2: Prepare and run the model
+1. Download the model
+```
+pip install -U "huggingface_hub[cli]"
+huggingface-cli download deepseek-ai/DeepSeek-R1-Distill-Llama-8B --local-dir /target_dir/DeepSeek-R1-Distill-Llama-8B --local-dir-use-symlinks False
+```
+
+2. Download the [tokenizer.model](https://huggingface.co/meta-llama/Llama-3.1-8B/blob/main/original/tokenizer.model) from the Llama 3.1 repo; it will be needed later when running the model with the runtime.
+
+3. Convert the model to a `.pth` file.
+```
+pip install torchtune
+```
+
+Run this Python code:
+```
+from torchtune.models import convert_weights
+from torchtune.training import FullModelHFCheckpointer
+import torch
+
+# Convert from safetensors to the TorchTune format. This assumes the model has been downloaded from Hugging Face.
+checkpointer = FullModelHFCheckpointer(
+ checkpoint_dir='/target_dir/DeepSeek-R1-Distill-Llama-8B',
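+ # The shard names below must match the *.safetensors files in the downloaded checkpoint directory.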
+ checkpoint_files=['model-00001-of-000002.safetensors', 'model-00002-of-000002.safetensors'],
+ output_dir='/tmp/deepseek-ai/DeepSeek-R1-Distill-Llama-8B/',
+ model_type='LLAMA3' # or other types that TorchTune supports
+)
+
+print("loading checkpoint")
+sd = checkpointer.load_checkpoint()
+
+# Convert from TorchTune to Meta (PyTorch native)
+sd = convert_weights.tune_to_meta(sd['model'])
+
+print("saving checkpoint")
+torch.save(sd, "/tmp/deepseek-ai/DeepSeek-R1-Distill-Llama-8B/checkpoint.pth")
+```
+
+4. Download and save the params.json file
+```
+wget https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct/resolve/main/original/params.json -O /tmp/params.json
+```
+
+5. Generate a PTE file for use with the Llama runner.
+```
+python -m examples.models.llama.export_llama \
+ --checkpoint /tmp/deepseek-ai/DeepSeek-R1-Distill-Llama-8B/checkpoint.pth \
+ -p /tmp/params.json \
+ -kv \
+ --use_sdpa_with_kv_cache \
+ -X \
+ -qmode 8da4w \
+ --group_size 128 \
+ -d fp16 \
+ --metadata '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}' \
+ --embedding-quantize 4,32 \
+ --output_name="DeepSeek-R1-Distill-Llama-8B.pte"
+```
+
+6. Run the model on your desktop for validation or integrate with iOS/Android apps. Instructions for these are available in the Llama [README](../llama/README.md) starting at Step 3.
diff --git a/examples/models/llama/TARGETS b/examples/models/llama/TARGETS
index 4fe7f6cc2b..f6b78e876c 100644
--- a/examples/models/llama/TARGETS
+++ b/examples/models/llama/TARGETS
@@ -14,6 +14,8 @@ runtime.python_library(
srcs = [
"llama_transformer.py",
"rope.py",
+ "attention.py",
+ "model_args.py",
],
_is_external_target = True,
base_module = "executorch.examples.models.llama",
diff --git a/examples/models/llama/attention.py b/examples/models/llama/attention.py
new file mode 100644
index 0000000000..ec55f2f1ee
--- /dev/null
+++ b/examples/models/llama/attention.py
@@ -0,0 +1,255 @@
+from abc import ABC, abstractmethod
+from typing import Any, Dict, Optional, Tuple, Type, TypedDict
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from executorch.examples.models.llama.model_args import ModelArgs
+from executorch.examples.models.llama.rope import Rope
+
+
+class ForwardOptions(TypedDict, total=False):
+ """Optional parameters for `Attention.forward` (compative with Python 3.10 and plus)."""
+
+ mask: Optional[torch.Tensor]
+ input_pos: Optional[torch.Tensor]
+ in_cache_state: Optional[Any]
+ out_cache_state: Optional[Any]
+
+
+class Attention(nn.Module, ABC):
+ """Abstract base class for attention mechanisms with unified interface."""
+
+ @abstractmethod
+ def forward(
+ self,
+ x: torch.Tensor,
+ freqs_cos: torch.Tensor,
+ freqs_sin: torch.Tensor,
+ **kwargs: ForwardOptions,
+ ) -> Tuple[torch.Tensor, Optional[Any]]:
+ """Forward pass for attention mechanism.
+
+ Args:
+ x: Input tensor of shape (batch_size, seq_len, dim)
+ freqs_cos, freqs_sin: Rotary position embedding frequencies
+ ForwardOptions: grouped optional args
+
+ Returns:
+ Tuple of (output tensor, updated cache state)
+ """
+ pass
+
+
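+# Maps a lowercase attention-type name (ModelArgs.attention_type) to its implementation class;
+# see TransformerBlock in llama_transformer.py for where the lookup happens.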
+ATTENTION_REGISTRY: Dict[str, Type[Attention]] = {}
+
+
+def register_attention(name: str):
+ """Decorator to register attention classes"""
+
+ def decorator(cls: Type[Attention]):
+ ATTENTION_REGISTRY[name.lower()] = cls
+ return cls
+
+ return decorator
+
+
+class KVCache(nn.Module):
+ def __init__(
+ self,
+ max_batch_size: int,
+ max_context_length: int,
+ n_heads: int,
+ head_dim: int,
+ enable_dynamic_shape: bool,
+ dtype=torch.float32,
+ ):
+ super().__init__()
+ self.max_context_length = max_context_length
+ cache_shape = (max_batch_size, n_heads, max_context_length, head_dim)
+
+ self.max_batch_size = max_batch_size
+ self.n_heads = n_heads
+ self.head_dim = head_dim
+ self.enable_dynamic_shape = enable_dynamic_shape
+ self.register_buffer(
+ "k_cache", torch.zeros(cache_shape, dtype=dtype, device="cpu")
+ )
+ self.register_buffer(
+ "v_cache", torch.zeros(cache_shape, dtype=dtype, device="cpu")
+ )
+
+ def update(
+ self, input_pos: torch.Tensor, k_val: torch.Tensor, v_val: torch.Tensor
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ # input_pos: [S], k_val: [B, H, S, D]
+ if self.enable_dynamic_shape:
+ start_pos = input_pos[0].item()
+ torch._check_is_size(start_pos)
+ torch._check(start_pos < self.max_context_length)
+ dim_to_slice = 2
+ seq_length = k_val.size(dim_to_slice)
+ # Replace the entry in the cache for this token
+ # The following lines are equivalent to:
+ # cache_k[:bsz, start_pos : start_pos + seqlen] = xk
+ # cache_v[:bsz, start_pos : start_pos + seqlen] = xv
+ # when dim_to_slice is 1
+ # We use .narrow() here to make the compiler happy
+ # pyre-ignore: Incompatible parameter type [6]
+ narrowed_k = self.k_cache.narrow(dim_to_slice, start_pos, seq_length)
+ # pyre-ignore: Incompatible parameter type [6]
+ narrowed_v = self.v_cache.narrow(dim_to_slice, start_pos, seq_length)
+
+ narrowed_k.copy_(k_val)
+ narrowed_v.copy_(v_val)
+ return self.k_cache, self.v_cache
+ else:
+ k_out = self.k_cache
+ v_out = self.v_cache
+ k_out[:, :, input_pos] = k_val
+ v_out[:, :, input_pos] = v_val
+
+ return k_out, v_out
+
+
+class SDPA(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ head_dim: int,
+ n_rep: int,
+ max_context_len: int,
+ enable_dynamic_shape: bool,
+ ):
+ super().__init__()
+ self.dim = dim
+ self.head_dim = head_dim
+ self.n_rep = n_rep
+ self.max_context_len = max_context_len
+ self.enable_dynamic_shape = enable_dynamic_shape
+
+ def forward(
+ self,
+ input_pos: torch.Tensor,
+ q: torch.Tensor, # Already have rotary embeddings. (bs, n_local_heads, seqlen, head_dim)
+ k: torch.Tensor, # Already have rotary embeddings. (bs, n_local_kv_heads, seqlen, head_dim)
+ v: torch.Tensor, # (bs, n_local_kv_heads, seqlen, head_dim)
+ bsz,
+ seqlen,
+ mask: torch.Tensor,
+ ) -> torch.Tensor:
+ if self.enable_dynamic_shape:
+ start_pos = input_pos[-1].item()
+ torch._check_is_size(start_pos)
+ torch._check(start_pos < self.max_context_len)
+ seq_length = q.size(2)
+ # pyre-ignore: Incompatible parameter type [6]
+ attn_mask = mask.narrow(0, start_pos, seq_length)
+ else:
+ attn_mask = mask[None, None, input_pos]
+
+ # TODO(kimishpatel): This should not be necessary because scaled_dot_product_attention
+ # can natively support GQA now. But needs enable_gqa=True
+ k = k.repeat_interleave(self.n_rep, dim=1)
+ v = v.repeat_interleave(self.n_rep, dim=1)
+ y = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=0.0)
+
+ return y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)
+
+
+@register_attention("mha")
+class AttentionMHA(Attention):
+ def __init__(self, args: ModelArgs, layer_id: int, rope: Rope):
+ super().__init__()
+ self.use_kv_cache = args.use_kv_cache
+ self.n_heads = args.n_heads
+ self.n_kv_heads = self.n_heads if args.n_kv_heads is None else args.n_kv_heads
+ assert self.n_heads % self.n_kv_heads == 0
+ model_parallel_size = 1
+ self.n_local_heads = self.n_heads // model_parallel_size
+ self.n_local_kv_heads = self.n_kv_heads // model_parallel_size
+ self.n_rep = self.n_local_heads // self.n_local_kv_heads
+ self.head_dim = args.head_dim
+ self.max_batch_size = args.max_batch_size
+ self.max_context_len = args.max_context_len
+ self.dim = args.dim
+ self.wq = nn.Linear(self.dim, self.n_heads * self.head_dim, bias=False)
+ self.wk = nn.Linear(self.dim, self.n_kv_heads * self.head_dim, bias=False)
+ self.wv = nn.Linear(self.dim, self.n_kv_heads * self.head_dim, bias=False)
+ self.wo = nn.Linear(self.n_heads * self.head_dim, self.dim, bias=False)
+
+ self.layer_id = layer_id
+
+ self.rope = rope
+
+ causal_mask = torch.tril(
+ torch.ones(
+ self.max_context_len,
+ self.max_context_len,
+ dtype=torch.bool,
+ device="cpu",
+ )
+ )
+ self.register_buffer("mask", causal_mask, persistent=False)
+
+ if self.use_kv_cache:
+ self.kv_cache = KVCache(
+ args.max_batch_size,
+ args.max_context_len,
+ self.n_kv_heads,
+ self.head_dim,
+ args.enable_dynamic_shape,
+ )
+ self.SDPA = SDPA(
+ dim=self.n_local_heads * self.head_dim,
+ head_dim=self.head_dim,
+ n_rep=self.n_rep,
+ max_context_len=self.max_context_len,
+ enable_dynamic_shape=args.enable_dynamic_shape,
+ )
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ freqs_cos: torch.Tensor,
+ freqs_sin: torch.Tensor,
+ **kwargs: ForwardOptions,
+ ) -> Tuple[torch.Tensor, Optional[Any]]:
+ input_pos = kwargs.get("input_pos")
+ bsz, seqlen, _ = x.shape
+
+ # QKV
+ q, k, v = self.wq(x), self.wk(x), self.wv(x)
+ # We need view_copy elimination
+ q = q.view(bsz, seqlen, self.n_local_heads, self.head_dim)
+ k = k.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
+ v = v.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
+
+ # RoPE relative positional embeddings
+ q, k = self.rope.forward(q, k, freqs_cos, freqs_sin)
+
+ q = q.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim)
+ k = k.transpose(1, 2)
+ v = v.transpose(1, 2)
+
+ if self.use_kv_cache:
+ assert input_pos is not None
+ k, v = self.kv_cache.update(input_pos, k, v)
+ output = self.SDPA(input_pos, q, k, v, bsz, seqlen, self.mask)
+ return self.wo(output)
+
+ # grouped multiquery attention: expand out keys and values
+ k = k.repeat_interleave(self.n_rep, dim=1)
+ v = v.repeat_interleave(self.n_rep, dim=1)
+
+ assert hasattr(self, "mask")
+
+ mask = self.mask[:seqlen, :seqlen]
+
+ output = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
+
+ output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
+
+ output = self.wo(output)
+
+ return output
diff --git a/examples/models/llama/export_llama_lib.py b/examples/models/llama/export_llama_lib.py
index c25dce6ffc..618c74e870 100644
--- a/examples/models/llama/export_llama_lib.py
+++ b/examples/models/llama/export_llama_lib.py
@@ -21,6 +21,8 @@
import pkg_resources
import torch
+
+from executorch.backends.vulkan._passes.remove_asserts import remove_asserts
from executorch.devtools.backend_debug import get_delegation_info
from executorch.devtools.etrecord import generate_etrecord
@@ -335,6 +337,13 @@ def build_args_parser() -> argparse.ArgumentParser:
help="maximum length sequence to evaluate",
)
+ parser.add_argument(
+ "--max_context_length",
+ type=int,
+ default=128,
+ help="maximum length of context for model to remember",
+ )
+
parser.add_argument("-2", "--fairseq2", action="store_true")
parser.add_argument("-v", "--verbose", action="store_true")
parser.add_argument(
@@ -579,6 +588,7 @@ def _prepare_for_llama_export(args) -> LLMEdgeManager:
tokenizer_path=args.tokenizer_path,
verbose=args.verbose,
max_seq_len=args.max_seq_length,
+ max_context_len=args.max_context_length,
input_prune_map_path=args.input_prune_map,
output_prune_map_path=args.output_prune_map,
metadata_str=args.metadata,
@@ -637,6 +647,11 @@ def _validate_args(args):
"""
TODO: Combine all the backends under --backend args
"""
+
+ if args.max_context_length < args.max_seq_length:
+ raise ValueError(
+ f"max_context_length {args.max_context_length} must be >= max_seq_len {args.max_seq_length}. max_context_length impacts kv cache size that is used to remember history, while max_seq_length refers to user prompt length. Please use --max_context_length to specify context length."
+ )
if args.enable_dynamic_shape and (args.coreml or args.mps or args.qnn):
raise ValueError(
"Dynamic shape is not supported with coreml, MPS or qnn backends."
@@ -662,6 +677,7 @@ def _validate_args(args):
def _export_llama(args) -> LLMEdgeManager: # noqa: C901
_validate_args(args)
+
pt2e_quant_params, quantizers, quant_dtype = get_quantizer_and_quant_params(args)
# export_to_edge
@@ -713,6 +729,10 @@ def _export_llama(args) -> LLMEdgeManager: # noqa: C901
)
modelname = f"vulkan_{modelname}"
+ # Need to remove asserts from the graph to prevent graph breaks
+ # pyre-ignore: Undefined attribute [16]: `Optional` has no attribute `exported_program`.
+ remove_asserts(builder_exported_to_edge.edge_manager.exported_program())
+
if args.mps:
partitioners.append(get_mps_partitioner(args.use_kv_cache))
modelname = f"mps_{modelname}"
@@ -760,13 +780,13 @@ def _export_llama(args) -> LLMEdgeManager: # noqa: C901
atten = builder_exported_to_edge.model.layers[0].attention
if args.use_qnn_sha:
cache_shape = torch.Size(
- (atten.max_batch_size, atten.max_seq_len, atten.head_dim)
+ (atten.max_batch_size, atten.max_context_len, atten.head_dim)
)
else:
cache_shape = torch.Size(
(
atten.max_batch_size,
- atten.max_seq_len,
+ atten.max_context_len,
atten.n_kv_heads,
atten.head_dim,
)
@@ -861,6 +881,7 @@ def _load_llama_model_metadata(
use_sdpa_with_kv_cache: bool,
enable_dynamic_shape: bool,
max_seq_len: int,
+ max_context_len: int,
n_layers: int,
vocab_size: int,
metadata_str: Optional[str] = None,
@@ -870,6 +891,7 @@ def _load_llama_model_metadata(
"get_bos_id": 3 if is_fairseq2 else 1,
"get_eos_ids": [3] if is_fairseq2 else [2],
"get_max_seq_len": max_seq_len,
+ "get_max_context_len": max_context_len,
"get_n_layers": n_layers,
"get_vocab_size": vocab_size,
"use_kv_cache": use_kv_cache,
@@ -904,6 +926,7 @@ def _load_llama_model(
tokenizer_path: Optional[str] = None,
verbose: bool = False,
max_seq_len: int = 128,
+ max_context_len: int = 128,
input_prune_map_path: Optional[str] = None,
output_prune_map_path: Optional[str] = None,
metadata_str: Optional[str] = None,
@@ -948,6 +971,7 @@ def _load_llama_model(
generate_full_logits=generate_full_logits,
fairseq2=weight_type == WeightType.FAIRSEQ2,
max_seq_len=max_seq_len,
+ max_context_len=max_context_len,
enable_dynamic_shape=enable_dynamic_shape,
input_prune_map_path=input_prune_map_path,
output_prune_map_path=output_prune_map_path,
@@ -1006,10 +1030,13 @@ def _load_llama_model(
# pyre-fixme[6]: For 5th argument expected `ModelArgs` but got
# `Union[Tensor, Module]`.
model.max_seq_len,
- # pyre-fixme[6]: For 6th argument expected `int` but got `Union[Tensor,
+ # pyre-fixme[6]: For 6th argument expected `ModelArgs` but got
+ # `Union[Tensor, Module]`.
+ model.max_context_len,
+ # pyre-fixme[6]: For 7th argument expected `int` but got `Union[Tensor,
# Module]`.
model.n_layers,
- # pyre-fixme[6]: For 7th argument expected `int` but got `Union[Tensor,
+ # pyre-fixme[6]: For 8th argument expected `int` but got `Union[Tensor,
# Module]`.
model.vocab_size,
metadata_str,
diff --git a/examples/models/llama/llama_transformer.py b/examples/models/llama/llama_transformer.py
index d5661ae400..08526dde19 100644
--- a/examples/models/llama/llama_transformer.py
+++ b/examples/models/llama/llama_transformer.py
@@ -7,19 +7,16 @@
# Please refer to README.md in the same folder for more information.
-from dataclasses import dataclass
-from functools import partial
-from typing import Dict, Optional, Tuple
+from typing import Optional
import torch
import torch.nn.functional as F
-from executorch.examples.models.llama.rope import (
- hf_apply_rotary_emb,
- hf_precompute_freqs_cis,
- precompute_freqs_cis,
- RotaryEmbedding,
-)
+from executorch.examples.models.llama.attention import ATTENTION_REGISTRY
+
+from executorch.examples.models.llama.model_args import ModelArgs
+
+from executorch.examples.models.llama.rope import Rope
from torch import nn
@@ -71,359 +68,6 @@ def forward(self, x):
return output * self.weight
-def find_multiple(n: int, k: int) -> int:
- if n % k == 0:
- return n
- return n + k - (n % k)
-
-
-@dataclass
-class ModelArgs:
- dim: int = 4096
- n_layers: int = 32
- n_heads: int = 32
- n_kv_heads: Optional[int] = None
- vocab_size: int = -1 # defined later by tokenizer
- hidden_dim: Optional[int] = None
- head_dim: Optional[int] = None # Optional customized head_dim
- multiple_of: int = 256 # make SwiGLU hidden layer size multiple of large power of 2
- ffn_dim_multiplier: Optional[float] = None
- norm_eps: float = 1e-5
- max_batch_size: int = 32
- max_seq_len: int = 2048
- moe: bool = False # True to enable the MoE (Mixture of Experts)
- num_experts: int = 8 # Number of experts
- num_activated_experts: int = 2 # Number of experts to activate
- use_kv_cache: bool = False # Use key/value cache
- use_sdpa_with_kv_cache_op: bool = (
- False # Use custom sdpa op that updates kv cache in-place
- )
- # Generate logits for all inputs. When it's True, it would take big memory usage
- # at runtime. Enable it only necessary (e.g., use perplexity tools that requires
- # logits for all input tokens.)
- generate_full_logits: bool = False
- enable_dynamic_shape: bool = False # export model with dynamic shape support
- # A dictionary mapping from pruned token-id to original token-id
- input_prune_map: Optional[Dict[int, int]] = None
- # A dictionary mapping from pruned token-id to original token-id
- output_prune_map: Optional[Dict[int, int]] = None
- use_hf_rope: bool = False # Use HuggingFace's RoPE implementation
- rope_theta: Optional[float] = (
- None # The official name to override self.rope_freq_base.
- )
- rope_freq_base: float = 10000.0 # The base frequency for RoPE. Keep it for BC.
- use_scaled_rope: bool = False # Use scaled RoPE, introduced in llama3.1.
- rope_scale_factor: int = 8
- # Additional Model Metadata needed at runtime
- bos_idx: int = 1
- eos_idx: int = 3
- bos_count: int = -1 # i.e., a single EOS is used as BOS
- eos_count: int = 2
-
- quantization_args: Optional[dict] = None
- lora_args: Optional[dict] = None
-
- def __post_init__(self):
- if self.n_kv_heads is None:
- self.n_kv_heads = self.n_heads
-
- # rope_theta overrides rope_freq_base since it's the official name.
- if self.rope_theta is not None:
- self.rope_freq_base = self.rope_theta
-
- if self.use_sdpa_with_kv_cache_op:
- assert self.use_kv_cache, "use_sdpa_with_kv_cache_op requires use_kv_cache"
-
- if self.hidden_dim is None:
- # If hidden_dim is not explicitly set in the ModelArgs,
- # then calculate implicitly based on dim and also multiple of `args.multiple_of`
- multiple_of = self.multiple_of
- hidden_dim = 4 * self.dim
- hidden_dim = int(2 * hidden_dim / 3)
- if self.ffn_dim_multiplier is not None:
- hidden_dim = int(self.ffn_dim_multiplier * hidden_dim)
- self.hidden_dim = find_multiple(hidden_dim, multiple_of)
-
- if self.head_dim is None:
- self.head_dim = self.dim // self.n_heads
-
-
-class Rope(torch.nn.Module):
- def __init__(self, params: ModelArgs):
- super().__init__()
- self.params = params
- if self.params.use_hf_rope:
- self.precompute_freqs_cis = hf_precompute_freqs_cis
- else:
- self.precompute_freqs_cis = partial(
- precompute_freqs_cis,
- use_scaled=self.params.use_scaled_rope,
- scale_factor=self.params.rope_scale_factor,
- )
- freqs_cos, freqs_sin = self.precompute_freqs_cis(
- self.params.head_dim,
- (
- self.params.max_seq_len # Normal llama2.
- if self.params.ffn_dim_multiplier is None
- else self.params.max_seq_len * 2 # Sharded checkpoint.
- ),
- self.params.rope_freq_base,
- )
- self.register_buffer("freqs_cos", freqs_cos, persistent=False)
- self.register_buffer("freqs_sin", freqs_sin, persistent=False)
- if self.params.use_hf_rope:
- self.apply_rotary_emb = hf_apply_rotary_emb
- else:
- self.apply_rotary_emb = RotaryEmbedding()
-
- def forward(
- self,
- q: torch.Tensor,
- k: torch.Tensor,
- freqs_cos: torch.Tensor,
- freqs_sin: torch.Tensor,
- ):
- return self.apply_rotary_emb(q, k, freqs_cos, freqs_sin)
-
- def get_freqs(self, input_pos: Optional[torch.Tensor], seq_len: int):
- """
- Get the precomputed frequencies for the given input position and sequence length.
-
- Args:
- input_pos (torch.Tensor): The input position tensor.
- seq_len (int): The sequence length.
-
- Returns:
- Tuple[torch.Tensor, torch.Tensor]: The precomputed frequencies for the given input position and sequence length.
- """
- if self.params.use_kv_cache:
- assert (
- input_pos is not None
- ), "input_pos must be provided when use_kv_cache is True"
-
- if self.params.enable_dynamic_shape:
- # when KV cache is used, seqlen is most likely 1. We want to slice from the start_pos.
- input_pos_item = input_pos[-1].item()
- torch._check_is_size(input_pos_item)
- torch._check(input_pos_item < self.params.max_seq_len)
- # pyre-ignore: Incompatible parameter type [6]: torch.narrow does expect int or Tensor
- freqs_cos = self.freqs_cos.narrow(0, input_pos_item, seq_len)
- # pyre-ignore: Incompatible parameter type [6]
- freqs_sin = self.freqs_sin.narrow(0, input_pos_item, seq_len)
- else:
- # When not using dynamic shape, use of the .item results in
- # symints, due to querying the data from tensor.
- # this path avoids that for mps backend, although probably mps backend
- # can support dynamic shape?
- freqs_cos = self.freqs_cos[input_pos]
- freqs_sin = self.freqs_sin[input_pos]
-
- else:
- assert input_pos is None, "input_pos is unused when use_kv_cache is False"
- freqs_cos = self.freqs_cos[:seq_len]
- freqs_sin = self.freqs_sin[:seq_len]
- return freqs_cos, freqs_sin
-
-
-class KVCache(nn.Module):
- def __init__(
- self,
- max_batch_size: int,
- max_seq_length: int,
- n_heads: int,
- head_dim: int,
- enable_dynamic_shape: bool,
- dtype=torch.float32,
- ):
- super().__init__()
- self.max_seq_length = max_seq_length
- cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim)
-
- self.max_batch_size = max_batch_size
- self.n_heads = n_heads
- self.head_dim = head_dim
- self.enable_dynamic_shape = enable_dynamic_shape
- self.register_buffer(
- "k_cache", torch.zeros(cache_shape, dtype=dtype, device="cpu")
- )
- self.register_buffer(
- "v_cache", torch.zeros(cache_shape, dtype=dtype, device="cpu")
- )
-
- def update(
- self, input_pos: torch.Tensor, k_val: torch.Tensor, v_val: torch.Tensor
- ) -> Tuple[torch.Tensor, torch.Tensor]:
- # input_pos: [S], k_val: [B, H, S, D]
- if self.enable_dynamic_shape:
- start_pos = input_pos[0].item()
- torch._check_is_size(start_pos)
- torch._check(start_pos < self.max_seq_length)
- dim_to_slice = 2
- seq_length = k_val.size(dim_to_slice)
- # Replace the entry in the cache for this token
- # The following lines are equivalent to:
- # cache_k[:bsz, start_pos : start_pos + seqlen] = xk
- # cache_v[:bsz, start_pos : start_pos + seqlen] = xv
- # when dim_to_slice is 1
- # We use .narrow() here to make the compiler happy
- # pyre-ignore: Incompatible parameter type [6]
- narrowed_k = self.k_cache.narrow(dim_to_slice, start_pos, seq_length)
- # pyre-ignore: Incompatible parameter type [6]
- narrowed_v = self.v_cache.narrow(dim_to_slice, start_pos, seq_length)
-
- narrowed_k.copy_(k_val)
- narrowed_v.copy_(v_val)
- return self.k_cache, self.v_cache
- else:
- k_out = self.k_cache
- v_out = self.v_cache
- k_out[:, :, input_pos] = k_val
- v_out[:, :, input_pos] = v_val
-
- return k_out, v_out
-
-
-class SDPA(nn.Module):
- def __init__(
- self,
- dim: int,
- head_dim: int,
- n_rep: int,
- max_seq_len: int,
- enable_dynamic_shape: bool,
- ):
- super().__init__()
- self.dim = dim
- self.head_dim = head_dim
- self.n_rep = n_rep
- self.max_seq_len = max_seq_len
- self.enable_dynamic_shape = enable_dynamic_shape
-
- def forward(
- self,
- input_pos: torch.Tensor,
- q: torch.Tensor, # Already have rotary embeddings. (bs, n_local_heads, seqlen, head_dim)
- k: torch.Tensor, # Already have rotary embeddings. (bs, n_local_kv_heads, seqlen, head_dim)
- v: torch.Tensor, # (bs, n_local_kv_heads, seqlen, head_dim)
- bsz,
- seqlen,
- mask: torch.Tensor,
- ) -> torch.Tensor:
- if self.enable_dynamic_shape:
- start_pos = input_pos[-1].item()
- torch._check_is_size(start_pos)
- torch._check(start_pos < self.max_seq_len)
- seq_length = q.size(2)
- # pyre-ignore: Incompatible parameter type [6]
- attn_mask = mask.narrow(0, start_pos, seq_length)
- else:
- attn_mask = mask[None, None, input_pos]
-
- # TODO(kimishpatel): This should not be necessary because scaled_dot_product_attention
- # can natively support GQA now. But needs enable_gqa=True
- k = k.repeat_interleave(self.n_rep, dim=1)
- v = v.repeat_interleave(self.n_rep, dim=1)
- y = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=0.0)
-
- return y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)
-
-
-class Attention(nn.Module):
- def __init__(self, args: ModelArgs, layer_id: int, rope: Rope):
- super().__init__()
- self.use_kv_cache = args.use_kv_cache
- self.n_heads = args.n_heads
- self.n_kv_heads = self.n_heads if args.n_kv_heads is None else args.n_kv_heads
- assert self.n_heads % self.n_kv_heads == 0
- model_parallel_size = 1
- self.n_local_heads = self.n_heads // model_parallel_size
- self.n_local_kv_heads = self.n_kv_heads // model_parallel_size
- self.n_rep = self.n_local_heads // self.n_local_kv_heads
- self.head_dim = args.head_dim
- self.max_batch_size = args.max_batch_size
- self.max_seq_len = args.max_seq_len
- self.dim = args.dim
- self.wq = nn.Linear(self.dim, self.n_heads * self.head_dim, bias=False)
- self.wk = nn.Linear(self.dim, self.n_kv_heads * self.head_dim, bias=False)
- self.wv = nn.Linear(self.dim, self.n_kv_heads * self.head_dim, bias=False)
- self.wo = nn.Linear(self.n_heads * self.head_dim, self.dim, bias=False)
-
- self.layer_id = layer_id
-
- self.rope = rope
-
- causal_mask = torch.tril(
- torch.ones(
- self.max_seq_len,
- self.max_seq_len,
- dtype=torch.bool,
- device="cpu",
- )
- )
- self.register_buffer("mask", causal_mask, persistent=False)
-
- if self.use_kv_cache:
- self.kv_cache = KVCache(
- args.max_batch_size,
- args.max_seq_len,
- self.n_kv_heads,
- self.head_dim,
- args.enable_dynamic_shape,
- )
- self.SDPA = SDPA(
- dim=self.n_local_heads * self.head_dim,
- head_dim=self.head_dim,
- n_rep=self.n_rep,
- max_seq_len=self.max_seq_len,
- enable_dynamic_shape=args.enable_dynamic_shape,
- )
-
- def forward(
- self,
- x: torch.Tensor,
- freqs_cos: torch.Tensor,
- freqs_sin: torch.Tensor,
- input_pos: Optional[torch.Tensor] = None,
- ):
- bsz, seqlen, _ = x.shape
-
- # QKV
- q, k, v = self.wq(x), self.wk(x), self.wv(x)
- # We need view_copy elimination
- q = q.view(bsz, seqlen, self.n_local_heads, self.head_dim)
- k = k.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
- v = v.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
-
- # RoPE relative positional embeddings
- q, k = self.rope.forward(q, k, freqs_cos, freqs_sin)
-
- q = q.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim)
- k = k.transpose(1, 2)
- v = v.transpose(1, 2)
-
- if self.use_kv_cache:
- assert input_pos is not None
- k, v = self.kv_cache.update(input_pos, k, v)
- output = self.SDPA(input_pos, q, k, v, bsz, seqlen, self.mask)
- return self.wo(output)
-
- # grouped multiquery attention: expand out keys and values
- k = k.repeat_interleave(self.n_rep, dim=1)
- v = v.repeat_interleave(self.n_rep, dim=1)
-
- assert hasattr(self, "mask")
-
- mask = self.mask[:seqlen, :seqlen]
-
- output = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
-
- output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
-
- output = self.wo(output)
-
- return output
-
-
class FeedForward(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
@@ -490,7 +134,13 @@ def __init__(self, layer_id: int, args: ModelArgs, rope: Rope):
self.n_heads = args.n_heads
self.dim = args.dim
self.head_dim = args.head_dim
- self.attention = Attention(args, layer_id, rope)
+ if args.attention_type not in ATTENTION_REGISTRY:
+ raise ValueError(
+ f"Unknown attention type: {args.attention_type}. "
+ f"Available: {list(ATTENTION_REGISTRY.keys())}"
+ )
+ cls = ATTENTION_REGISTRY[args.attention_type]
+ self.attention = cls(args, layer_id, rope)
if args.moe:
self.block_sparse_moe = MOEFeedForward(args)
else:
@@ -500,7 +150,7 @@ def __init__(self, layer_id: int, args: ModelArgs, rope: Rope):
def forward(self, x, freqs_cos, freqs_sin, input_pos=None): # x: 1xN
h = self.attention.forward(
- self.attention_norm(x), freqs_cos, freqs_sin, input_pos
+ self.attention_norm(x), freqs_cos, freqs_sin, input_pos=input_pos
)
h = x + h
@@ -528,6 +178,7 @@ def __init__(self, params: ModelArgs):
self.use_kv_cache = params.use_kv_cache
self.generate_full_logits = params.generate_full_logits
self.max_seq_len = params.max_seq_len
+ self.max_context_len = params.max_context_len
self.input_prune_map = params.input_prune_map
self.output_prune_map = params.output_prune_map
diff --git a/examples/models/llama/model.py b/examples/models/llama/model.py
index 9f7994916a..19c7ed0b31 100644
--- a/examples/models/llama/model.py
+++ b/examples/models/llama/model.py
@@ -15,8 +15,9 @@
get_checkpoint_dtype,
get_default_model_resource_dir,
)
+from executorch.examples.models.llama.llama_transformer import Transformer
-from executorch.examples.models.llama.llama_transformer import ModelArgs, Transformer
+from executorch.examples.models.llama.model_args import ModelArgs
try:
from .fairseq2 import convert_to_llama_checkpoint
@@ -52,8 +53,13 @@ def __init__(self, **kwargs):
self.input_prune_map_path = kwargs.get("input_prune_map_path", None)
self.output_prune_map_path = kwargs.get("output_prune_map_path", None)
self.max_seq_len = kwargs.get("max_seq_len", 128)
+ self.max_context_len = kwargs.get("max_context_len", 128)
self.args = kwargs.get("args", None)
+ assert (
+ self.max_context_len >= self.max_seq_len
+ ), f"max_context_len({self.max_context_len}) must be >= max_seq_len({self.max_seq_len})"
+
# The example is using a dummy small model with random weights for demo purpose only.
# Follow the instruction in https://github.com/facebookresearch/llama to download the model.
device = "cpu"
@@ -136,6 +142,7 @@ def __init__(self, **kwargs):
model_args: ModelArgs = ModelArgs(
max_seq_len=self.max_seq_len,
+ max_context_len=self.max_context_len,
max_batch_size=1,
use_kv_cache=self.use_kv_cache,
use_sdpa_with_kv_cache_op=self.use_sdpa_with_kv_cache_op,
@@ -219,7 +226,7 @@ def __init__(self, **kwargs):
window_size = int(attention_sink_params[1])
eviction_batch_size = int(attention_sink_params[2])
- assert self.args.max_seq_length == sink_size + window_size
+ assert self.args.max_context_length == sink_size + window_size
self.model_ = enable_attention_sink(
module=self.model_,
diff --git a/examples/models/llama/model_args.py b/examples/models/llama/model_args.py
new file mode 100644
index 0000000000..e1c4edb8e9
--- /dev/null
+++ b/examples/models/llama/model_args.py
@@ -0,0 +1,81 @@
+from dataclasses import dataclass
+from typing import Dict, Optional
+
+
+@dataclass
+class ModelArgs:
+ dim: int = 4096
+ n_layers: int = 32
+ n_heads: int = 32
+ n_kv_heads: Optional[int] = None
+ vocab_size: int = -1 # defined later by tokenizer
+ hidden_dim: Optional[int] = None
+ head_dim: Optional[int] = None # Optional customized head_dim
+ multiple_of: int = 256 # make SwiGLU hidden layer size multiple of large power of 2
+ ffn_dim_multiplier: Optional[float] = None
+ norm_eps: float = 1e-5
+ max_batch_size: int = 32
+ max_seq_len: int = 2048
+ max_context_len: int = 2048
+ moe: bool = False # True to enable the MoE (Mixture of Experts)
+ num_experts: int = 8 # Number of experts
+ num_activated_experts: int = 2 # Number of experts to activate
+ attention_type: str = "mha" # Attention type, registered in attention.py
+ use_kv_cache: bool = False # Use key/value cache
+ use_sdpa_with_kv_cache_op: bool = (
+ False # Use custom sdpa op that updates kv cache in-place
+ )
+ # Generate logits for all inputs. When it's True, it would take big memory usage
+ # at runtime. Enable it only necessary (e.g., use perplexity tools that requires
+ # logits for all input tokens.)
+ generate_full_logits: bool = False
+ enable_dynamic_shape: bool = False # export model with dynamic shape support
+ # A dictionary mapping from pruned token-id to original token-id
+ input_prune_map: Optional[Dict[int, int]] = None
+ # A dictionary mapping from pruned token-id to original token-id
+ output_prune_map: Optional[Dict[int, int]] = None
+ use_hf_rope: bool = False # Use HuggingFace's RoPE implementation
+ rope_theta: Optional[float] = (
+ None # The official name to override self.rope_freq_base.
+ )
+ rope_freq_base: float = 10000.0 # The base frequency for RoPE. Keep it for BC.
+ use_scaled_rope: bool = False # Use scaled RoPE, introduced in llama3.1.
+ rope_scale_factor: int = 8
+ # Additional Model Metadata needed at runtime
+ bos_idx: int = 1
+ eos_idx: int = 3
+ bos_count: int = -1 # i.e., a single EOS is used as BOS
+ eos_count: int = 2
+
+ quantization_args: Optional[dict] = None
+ lora_args: Optional[dict] = None
+
+ def __post_init__(self):
+ if self.n_kv_heads is None:
+ self.n_kv_heads = self.n_heads
+
+ # rope_theta overrides rope_freq_base since it's the official name.
+ if self.rope_theta is not None:
+ self.rope_freq_base = self.rope_theta
+
+ if self.use_sdpa_with_kv_cache_op:
+ assert self.use_kv_cache, "use_sdpa_with_kv_cache_op requires use_kv_cache"
+
+ if self.hidden_dim is None:
+ # If hidden_dim is not explicitly set in the ModelArgs,
+ # then calculate implicitly based on dim and also multiple of `args.multiple_of`
+ multiple_of = self.multiple_of
+ hidden_dim = 4 * self.dim
+ hidden_dim = int(2 * hidden_dim / 3)
+ if self.ffn_dim_multiplier is not None:
+ hidden_dim = int(self.ffn_dim_multiplier * hidden_dim)
+
+ def find_multiple(n: int, k: int) -> int:
+ if n % k == 0:
+ return n
+ return n + k - (n % k)
+
+ self.hidden_dim = find_multiple(hidden_dim, multiple_of)
+
+ if self.head_dim is None:
+ self.head_dim = self.dim // self.n_heads
diff --git a/examples/models/llama/rope.py b/examples/models/llama/rope.py
index cd3ddb0d3b..01352f404d 100644
--- a/examples/models/llama/rope.py
+++ b/examples/models/llama/rope.py
@@ -8,9 +8,11 @@
# Different RoPE implementations
import math
+from functools import partial
from typing import Optional, Tuple
import torch
+from executorch.examples.models.llama.model_args import ModelArgs
# ======================== Stock Implementation ========================
@@ -205,3 +207,80 @@ def hf_apply_rotary_emb_to_k(k, cos, sin, position_ids=None, unsqueeze_dim=1):
sin = sin.unsqueeze(unsqueeze_dim)
k_embed = (k * cos) + (rotate_half(k) * sin)
return k_embed
+
+
+class Rope(torch.nn.Module):
+ def __init__(self, params: ModelArgs):
+ super().__init__()
+ self.params = params
+ if self.params.use_hf_rope:
+ self.precompute_freqs_cis = hf_precompute_freqs_cis
+ else:
+ self.precompute_freqs_cis = partial(
+ precompute_freqs_cis,
+ use_scaled=self.params.use_scaled_rope,
+ scale_factor=self.params.rope_scale_factor,
+ )
+ freqs_cos, freqs_sin = self.precompute_freqs_cis(
+ self.params.head_dim,
+ (
+ self.params.max_context_len # Normal llama2.
+ if self.params.ffn_dim_multiplier is None
+ else self.params.max_context_len * 2 # Sharded checkpoint.
+ ),
+ self.params.rope_freq_base,
+ )
+ self.register_buffer("freqs_cos", freqs_cos, persistent=False)
+ self.register_buffer("freqs_sin", freqs_sin, persistent=False)
+ if self.params.use_hf_rope:
+ self.apply_rotary_emb = hf_apply_rotary_emb
+ else:
+ self.apply_rotary_emb = RotaryEmbedding()
+
+ def forward(
+ self,
+ q: torch.Tensor,
+ k: torch.Tensor,
+ freqs_cos: torch.Tensor,
+ freqs_sin: torch.Tensor,
+ ):
+ return self.apply_rotary_emb(q, k, freqs_cos, freqs_sin)
+
+ def get_freqs(self, input_pos: Optional[torch.Tensor], seq_len: int):
+ """
+ Get the precomputed frequencies for the given input position and sequence length.
+
+ Args:
+ input_pos (torch.Tensor): The input position tensor.
+ seq_len (int): The sequence length.
+
+ Returns:
+ Tuple[torch.Tensor, torch.Tensor]: The precomputed frequencies for the given input position and sequence length.
+ """
+ if self.params.use_kv_cache:
+ assert (
+ input_pos is not None
+ ), "input_pos must be provided when use_kv_cache is True"
+
+ if self.params.enable_dynamic_shape:
+ # when KV cache is used, seqlen is most likely 1. We want to slice from the start_pos.
+ input_pos_item = input_pos[-1].item()
+ torch._check_is_size(input_pos_item)
+ torch._check(input_pos_item < self.params.max_context_len)
+ # pyre-ignore: Incompatible parameter type [6]: torch.narrow does expect int or Tensor
+ freqs_cos = self.freqs_cos.narrow(0, input_pos_item, seq_len)
+ # pyre-ignore: Incompatible parameter type [6]
+ freqs_sin = self.freqs_sin.narrow(0, input_pos_item, seq_len)
+ else:
+ # When dynamic shape is not enabled, calling .item() would produce symints
+ # because it queries data from the tensor. This indexing path avoids that
+ # for the MPS backend, although MPS could probably support dynamic shape.
+ freqs_cos = self.freqs_cos[input_pos]
+ freqs_sin = self.freqs_sin[input_pos]
+
+ else:
+ assert input_pos is None, "input_pos is unused when use_kv_cache is False"
+ freqs_cos = self.freqs_cos[:seq_len]
+ freqs_sin = self.freqs_sin[:seq_len]
+ return freqs_cos, freqs_sin
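+
+
+# Illustrative usage (hypothetical values, not taken from this change):
+#
+#   params = ModelArgs(dim=2048, n_heads=32, use_kv_cache=True, max_context_len=128)
+#   rope = Rope(params)
+#   # cos/sin tables for a single decode step at position 10 ...
+#   freqs_cos, freqs_sin = rope.get_freqs(torch.tensor([10]), seq_len=1)
+#   # ... applied to query/key tensors q and k.
+#   q_rot, k_rot = rope(q, k, freqs_cos, freqs_sin)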
diff --git a/examples/models/llama/source_transformation/attention.py b/examples/models/llama/source_transformation/attention.py
index 7dc9003f13..d5f065550d 100644
--- a/examples/models/llama/source_transformation/attention.py
+++ b/examples/models/llama/source_transformation/attention.py
@@ -12,7 +12,7 @@
from typing import List, Optional, Tuple
import torch
-from executorch.examples.models.llama.llama_transformer import Attention
+from executorch.examples.models.llama.attention import Attention
from torch import nn
@@ -32,7 +32,7 @@ class KVCacheSHA(torch.nn.Module):
def __init__(
self,
max_batch_size: int,
- max_seq_length: int,
+ max_context_length: int,
n_heads: int,
head_dim: int,
dtype=torch.float32,
@@ -40,7 +40,7 @@ def __init__(
super().__init__()
# a buffer per head
- cache_shape = (max_batch_size, max_seq_length, head_dim)
+ cache_shape = (max_batch_size, max_context_length, head_dim)
for i in range(n_heads):
self.register_buffer(
f"past_k_caches_{i}",
@@ -79,7 +79,7 @@ class SDPASHA(torch.nn.Module):
def __init__(
self,
max_batch_size: int,
- max_seq_length: int,
+ max_context_length: int,
n_heads: int,
n_rep: int,
head_dim: int,
@@ -90,7 +90,7 @@ def __init__(
self.n_rep = n_rep
self.dim = dim
self.kv_cache = KVCacheSHA(
- max_batch_size, max_seq_length, n_heads // n_rep, head_dim
+ max_batch_size, max_context_length, n_heads // n_rep, head_dim
)
self.scale_factor = math.sqrt(head_dim)
@@ -134,11 +134,11 @@ def __init__(self, attention_mha: nn.Module):
self.n_rep = self.n_heads // self.n_kv_heads
self.dim = attention_mha.dim
self.max_batch_size = attention_mha.max_batch_size
- self.max_seq_len = attention_mha.max_seq_len
+ self.max_context_len = attention_mha.max_context_len
self.head_dim = attention_mha.dim // self.n_heads
self.SDPA = SDPASHA(
self.max_batch_size,
- self.max_seq_len,
+ self.max_context_len,
self.n_heads,
self.n_rep,
self.head_dim,
@@ -184,8 +184,8 @@ def __init__(self, attention_mha: nn.Module):
causal_mask = torch.tril(
torch.ones(
- self.max_seq_len,
- self.max_seq_len,
+ self.max_context_len,
+ self.max_context_len,
dtype=torch.bool,
device="cpu",
)
diff --git a/examples/models/llama/source_transformation/attention_sink.py b/examples/models/llama/source_transformation/attention_sink.py
index 5b3bfba9ad..22bd8a3e22 100644
--- a/examples/models/llama/source_transformation/attention_sink.py
+++ b/examples/models/llama/source_transformation/attention_sink.py
@@ -12,15 +12,12 @@
import torch
-from executorch.examples.models.llama.llama_transformer import (
- Attention,
- KVCache,
- ModelArgs,
- Rope,
-)
+from executorch.examples.models.llama.attention import AttentionMHA, KVCache
+from executorch.examples.models.llama.model_args import ModelArgs
from executorch.examples.models.llama.rope import (
apply_rotary_emb_to_k,
hf_apply_rotary_emb_to_k,
+ Rope,
)
from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter
@@ -44,8 +41,8 @@ def __init__(
self.apply_rotary_emb_to_k = hf_apply_rotary_emb_to_k
else:
self.apply_rotary_emb_to_k = apply_rotary_emb_to_k
- self.max_seq_length = window_size + sink_size
- assert self.max_seq_length == self.params.max_seq_len
+ self.max_context_length = window_size + sink_size
+ assert self.max_context_length == self.params.max_context_len
self.eviction_batch_size = eviction_batch_size
self.position_shift = 0
@@ -54,11 +51,14 @@ def get_freqs(self, input_pos: Optional[torch.Tensor], seq_len: int):
input_pos_item = input_pos.item()
torch._check_is_size(input_pos_item)
- if input_pos_item + self.position_shift + seq_len > self.max_seq_length:
+ if input_pos_item + self.position_shift + seq_len > self.max_context_length:
# There are not enough spaces in the cache to store the new tokens.
# We need to evict some old tokens and shift some recent tokens.
num_to_evict = max(
- input_pos_item + self.position_shift - self.max_seq_length + seq_len,
+ input_pos_item
+ + self.position_shift
+ - self.max_context_length
+ + seq_len,
self.eviction_batch_size,
)
self.position_shift -= num_to_evict # pyre-ignore [8]
@@ -121,7 +121,7 @@ def __init__(
):
super().__init__(
max_batch_size=max_batch_size,
- max_seq_length=window_size + sink_size,
+ max_context_length=window_size + sink_size,
n_heads=n_heads,
head_dim=head_dim,
enable_dynamic_shape=enable_dynamic_shape,
@@ -148,11 +148,14 @@ def evict_tokens(self, input_pos: torch.Tensor, seq_len: int) -> int:
"""
input_pos_item = input_pos.item()
torch._check_is_size(input_pos_item)
- if input_pos_item + self.position_shift + seq_len > self.max_seq_length:
+ if input_pos_item + self.position_shift + seq_len > self.max_context_length:
# There are not enough spaces in the cache to store the new tokens.
# We need to evict some old tokens and shift some recent tokens.
num_to_evict = max(
- input_pos_item + self.position_shift - self.max_seq_length + seq_len,
+ input_pos_item
+ + self.position_shift
+ - self.max_context_length
+ + seq_len,
self.eviction_batch_size,
)
num_to_keep = (
@@ -260,7 +263,7 @@ def _replace_attention(
eviction_batch_size=eviction_batch_size,
)
- if isinstance(child_module, Attention):
+ if isinstance(child_module, AttentionMHA):
kv_cache = child_module.kv_cache
kv_cache_with_attention_sink = KVCacheWithAttentionSink(
n_heads=kv_cache.n_heads,
diff --git a/examples/models/llama/source_transformation/quantized_kv_cache.py b/examples/models/llama/source_transformation/quantized_kv_cache.py
index 90ec9879e5..023fc6800f 100644
--- a/examples/models/llama/source_transformation/quantized_kv_cache.py
+++ b/examples/models/llama/source_transformation/quantized_kv_cache.py
@@ -10,7 +10,7 @@
import torch
import torch.nn as nn
-from executorch.examples.models.llama.llama_transformer import KVCache
+from executorch.examples.models.llama.attention import KVCache
from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib # noqa: F401
@@ -33,7 +33,7 @@ class QuantizedKVCache(nn.Module):
def __init__(
self,
max_batch_size,
- max_seq_length,
+ max_context_length,
n_heads,
head_dim,
cache_type: QuantizedCacheType = QuantizedCacheType.AffineSymmetric,
@@ -52,8 +52,8 @@ def __init__(
self.use_custom_update_cache_op = use_custom_update_cache_op
self.quantized_cache_dtype = torch.int8
self.cache_fp_type = torch.float32
- cache_shape = (max_batch_size, max_seq_length, n_heads, head_dim)
- scale_shape = (max_batch_size, max_seq_length, n_heads, 1)
+ cache_shape = (max_batch_size, max_context_length, n_heads, head_dim)
+ scale_shape = (max_batch_size, max_context_length, n_heads, 1)
self.register_buffer(
"k_cache", torch.zeros(cache_shape, dtype=self.quantized_cache_dtype)
)
@@ -161,13 +161,15 @@ def from_float(
cache_type: QuantizedCacheType,
use_custom_update_cache_op: bool = False,
):
- max_batch_size, n_heads, max_seq_length, head_dim = kv_cache.k_cache.shape
+ max_batch_size, n_heads, max_context_length, head_dim = kv_cache.k_cache.shape
if isinstance(kv_cache, CustomKVCache):
# If replacing custom kv cache, then the shape is [B, S, H, D]
- max_batch_size, max_seq_length, n_heads, head_dim = kv_cache.k_cache.shape
+ max_batch_size, max_context_length, n_heads, head_dim = (
+ kv_cache.k_cache.shape
+ )
return cls(
max_batch_size,
- max_seq_length,
+ max_context_length,
n_heads,
head_dim,
cache_type,
@@ -226,14 +228,14 @@ class CustomKVCache(nn.Module):
def __init__(
self,
max_batch_size: int,
- max_seq_length: int,
+ max_context_length: int,
n_heads: int,
head_dim: int,
dtype=torch.float32,
):
super().__init__()
- self.max_seq_length = max_seq_length
- cache_shape = (max_batch_size, max_seq_length, n_heads, head_dim)
+ self.max_context_length = max_context_length
+ cache_shape = (max_batch_size, max_context_length, n_heads, head_dim)
self.max_batch_size = max_batch_size
self.n_heads = n_heads
@@ -275,13 +277,13 @@ def replace_kv_cache_with_custom_kv_cache(module):
if isinstance(child, KVCache):
cache_shape = child.k_cache.shape
cache_dtype = child.k_cache.dtype
- max_batch_size, n_heads, max_seq_length, head_dim = cache_shape
+ max_batch_size, n_heads, max_context_length, head_dim = cache_shape
setattr(
module,
name,
CustomKVCache(
max_batch_size,
- max_seq_length,
+ max_context_length,
n_heads,
head_dim,
dtype=cache_dtype,
diff --git a/examples/models/llama/source_transformation/sdpa.py b/examples/models/llama/source_transformation/sdpa.py
index 6a54d6a119..1bb7d27754 100644
--- a/examples/models/llama/source_transformation/sdpa.py
+++ b/examples/models/llama/source_transformation/sdpa.py
@@ -13,7 +13,7 @@
import torch
-from executorch.examples.models.llama.llama_transformer import KVCache, SDPA
+from executorch.examples.models.llama.attention import KVCache, SDPA
class SDPACustom(torch.nn.Module):
@@ -268,14 +268,14 @@ class KVCacheCoreML(torch.nn.Module):
def __init__(
self,
max_batch_size: int,
- max_seq_length: int,
+ max_context_length: int,
n_heads: int,
head_dim: int,
dtype=torch.float32,
):
super().__init__()
- self.max_seq_length = max_seq_length
- cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim)
+ self.max_context_length = max_context_length
+ cache_shape = (max_batch_size, n_heads, max_context_length, head_dim)
self.max_batch_size = max_batch_size
self.n_heads = n_heads
@@ -303,7 +303,7 @@ def replace_kv_cache_with_coreml_kv_cache(module: torch.nn.Module):
name,
KVCacheCoreML(
child.max_batch_size,
- child.max_seq_length,
+ child.max_context_length,
child.n_heads,
child.head_dim,
child.k_cache.dtype,
@@ -318,13 +318,13 @@ class KVCacheSimple(torch.nn.Module):
def __init__(
self,
max_batch_size: int,
- max_seq_length: int,
+ max_context_length: int,
n_heads: int,
head_dim: int,
dtype=torch.float32,
):
super().__init__()
- cache_shape = (max_batch_size, max_seq_length, n_heads, head_dim)
+ cache_shape = (max_batch_size, max_context_length, n_heads, head_dim)
self.register_buffer(
"past_k_caches",
torch.zeros(cache_shape, dtype=dtype, device="cpu"),
@@ -358,7 +358,7 @@ def replace_kv_cache_with_simple_kv_cache(module: torch.nn.Module):
name,
KVCacheSimple(
child.max_batch_size,
- child.max_seq_length,
+ child.max_context_length,
child.n_heads,
child.head_dim,
child.k_cache.dtype,
@@ -373,9 +373,9 @@ def replace_causal_mask(module: torch.nn.Module):
for buffer_fqn_name, buffer in module.named_buffers():
buffer_name = buffer_fqn_name.split(".")[-1]
if buffer_name == "mask":
- max_seq_len = buffer.shape[-1]
+ max_context_len = buffer.shape[-1]
mask = torch.full(
- (max_seq_len, max_seq_len),
+ (max_context_len, max_context_len),
float("-inf"),
device="cpu",
)
diff --git a/examples/models/llama/source_transformation/test_attention_sink.py b/examples/models/llama/source_transformation/test_attention_sink.py
index 4dd522dff2..fc882ebf4a 100644
--- a/examples/models/llama/source_transformation/test_attention_sink.py
+++ b/examples/models/llama/source_transformation/test_attention_sink.py
@@ -7,7 +7,7 @@
import unittest
import torch
-from executorch.examples.models.llama.llama_transformer import ModelArgs
+from executorch.examples.models.llama.model_args import ModelArgs
from executorch.examples.models.llama.source_transformation.attention_sink import (
KVCacheWithAttentionSink,
@@ -29,7 +29,7 @@ def _init_rope(self, params: ModelArgs, eviction_batch_size: int):
def setUp(self):
torch.manual_seed(42)
self.params = ModelArgs(
- use_kv_cache=True, enable_dynamic_shape=True, max_seq_len=256
+ use_kv_cache=True, enable_dynamic_shape=True, max_context_len=256
)
self.rope_with_attention_sink = self._init_rope(
params=self.params, eviction_batch_size=1
@@ -135,7 +135,7 @@ def _init_cache(self, sink_size, eviction_batch_size):
self.params = ModelArgs(
use_kv_cache=True,
enable_dynamic_shape=True,
- max_seq_len=self.window_size + sink_size,
+ max_context_len=self.window_size + sink_size,
)
self.rope_with_attention_sink = RopeWithAttentionSink(
params=self.params,
diff --git a/examples/models/llama/source_transformation/test_quantized_kv_cache.py b/examples/models/llama/source_transformation/test_quantized_kv_cache.py
index 67ebbc7b3f..4252518a4e 100644
--- a/examples/models/llama/source_transformation/test_quantized_kv_cache.py
+++ b/examples/models/llama/source_transformation/test_quantized_kv_cache.py
@@ -8,7 +8,7 @@
import torch
-from executorch.examples.models.llama.llama_transformer import KVCache
+from executorch.examples.models.llama.attention import KVCache
from executorch.examples.models.llama.source_transformation.quantized_kv_cache import (
QuantizedCacheType,
@@ -20,7 +20,7 @@ class QuantizedKVCacheTest(unittest.TestCase):
def _init_cache(self):
self.kv_cache = KVCache(
self.max_batch_size,
- self.max_seq_len,
+ self.max_context_len,
self.n_kv_heads,
self.head_dim,
self.enable_dynamic_shape,
@@ -36,7 +36,7 @@ def _init_kv(self):
def setUp(self):
torch.manual_seed(42)
self.max_batch_size = 1
- self.max_seq_len = 5
+ self.max_context_len = 5
self.n_kv_heads = 8
self.head_dim = 17
self.enable_dynamic_shape = False
diff --git a/examples/models/llama/source_transformation/test_sdpa_with_quantized_kv_cache.py b/examples/models/llama/source_transformation/test_sdpa_with_quantized_kv_cache.py
index 0081c5072c..35c88e10b6 100644
--- a/examples/models/llama/source_transformation/test_sdpa_with_quantized_kv_cache.py
+++ b/examples/models/llama/source_transformation/test_sdpa_with_quantized_kv_cache.py
@@ -8,7 +8,7 @@
import torch
-from executorch.examples.models.llama.llama_transformer import KVCache
+from executorch.examples.models.llama.attention import KVCache
from executorch.examples.models.llama.source_transformation.quantized_kv_cache import (
CustomKVCache,
@@ -23,7 +23,7 @@ class SDPAWithQuantizedKVCacheTest(unittest.TestCase):
def _init_cache(self):
self.kv_cache = KVCache(
self.max_batch_size,
- self.max_seq_len,
+ self.max_context_len,
self.n_kv_heads,
self.head_dim,
self.enable_dynamic_shape,
@@ -40,7 +40,7 @@ def _init_cache(self):
# as a sequence of token positions
self.custom_kv_cache = CustomKVCache(
self.max_batch_size,
- self.max_seq_len,
+ self.max_context_len,
self.n_kv_heads,
self.head_dim,
dtype=self.dtype,
@@ -57,7 +57,7 @@ def _init_kv(self):
def setUp(self):
torch.manual_seed(42)
self.max_batch_size = 1
- self.max_seq_len = 5
+ self.max_context_len = 5
self.n_kv_heads = 4
self.n_heads = 8
self.head_dim = 17
diff --git a/examples/models/llama/tests/test_pre_quantization_transforms.py b/examples/models/llama/tests/test_pre_quantization_transforms.py
index dc7c640dba..345f3fad9b 100644
--- a/examples/models/llama/tests/test_pre_quantization_transforms.py
+++ b/examples/models/llama/tests/test_pre_quantization_transforms.py
@@ -7,7 +7,8 @@
import unittest
import torch
-from executorch.examples.models.llama.llama_transformer import ModelArgs, Transformer
+from executorch.examples.models.llama.llama_transformer import Transformer
+from executorch.examples.models.llama.model_args import ModelArgs
from executorch.examples.models.llama.source_transformation.pre_quantization import (
sanitize_checkpoint_from_pre_quantization,
transform_embedding_for_pre_quantization,
diff --git a/examples/models/llama/tests/test_simple_sdpa.py b/examples/models/llama/tests/test_simple_sdpa.py
index 4088165c71..d60bc30b7d 100644
--- a/examples/models/llama/tests/test_simple_sdpa.py
+++ b/examples/models/llama/tests/test_simple_sdpa.py
@@ -7,7 +7,7 @@
import unittest
import torch
-from executorch.examples.models.llama.llama_transformer import KVCache, SDPA
+from executorch.examples.models.llama.attention import KVCache, SDPA
from executorch.examples.models.llama.source_transformation.sdpa import SDPASimple
@@ -15,7 +15,7 @@ class SDPATest(unittest.TestCase):
def test_simple_sdpa(self):
# Verify the correctness between the simple SDPA and the original SDPA module defined in llama_transformer.py
max_batch_size = 1
- max_seq_length = 128
+ max_context_length = 128
n_heads = 8
head_dim = 8
dim = 64
@@ -25,7 +25,7 @@ def test_simple_sdpa(self):
n_local_heads = n_heads
kv_cache = KVCache(
max_batch_size=max_batch_size,
- max_seq_length=max_seq_length,
+ max_context_length=max_context_length,
n_heads=n_heads,
head_dim=head_dim,
enable_dynamic_shape=False,
@@ -34,14 +34,14 @@ def test_simple_sdpa(self):
dim=dim,
head_dim=head_dim,
n_rep=n_rep,
- max_seq_len=max_seq_length,
+ max_context_len=max_context_length,
enable_dynamic_shape=False,
)
input_pos = torch.tensor([0])
query = torch.randn(1, 1, n_local_heads, head_dim)
key = torch.randn(1, 1, n_local_heads, head_dim)
value = torch.randn(1, 1, n_local_heads, head_dim)
- mask = torch.randn(max_seq_length, max_seq_length)
+ mask = torch.randn(max_context_length, max_context_length)
query = query.transpose(1, 2)
key = key.transpose(1, 2)
value = value.transpose(1, 2)
diff --git a/examples/models/llama3_2_vision/cross_attention/cross_attention_mask_test.cpp b/examples/models/llama3_2_vision/cross_attention/cross_attention_mask_test.cpp
index e2256b14a8..8d144b4f72 100644
--- a/examples/models/llama3_2_vision/cross_attention/cross_attention_mask_test.cpp
+++ b/examples/models/llama3_2_vision/cross_attention/cross_attention_mask_test.cpp
@@ -11,9 +11,9 @@
#include
using namespace ::testing;
-using exec_aten::ScalarType;
-using exec_aten::Tensor;
-using exec_aten::TensorImpl;
+using executorch::aten::ScalarType;
+using executorch::aten::Tensor;
+using executorch::aten::TensorImpl;
TEST(CrossAttentxnMaskTest, TestCrossAttentionMask) {
std::vector tokens = {
diff --git a/examples/models/llava/model.py b/examples/models/llava/model.py
index 68a9e59e0c..304b49759f 100644
--- a/examples/models/llava/model.py
+++ b/examples/models/llava/model.py
@@ -12,7 +12,8 @@
import requests
import torch
-from executorch.examples.models.llama.llama_transformer import ModelArgs, Transformer
+from executorch.examples.models.llama.llama_transformer import Transformer
+from executorch.examples.models.llama.model_args import ModelArgs
from executorch.examples.models.llama.source_transformation.quantized_kv_cache import (
replace_kv_cache_with_custom_kv_cache,
diff --git a/examples/models/llava/runner/llava_image_prefiller.h b/examples/models/llava/runner/llava_image_prefiller.h
index b4b1ef420c..c48fe2b1fe 100644
--- a/examples/models/llava/runner/llava_image_prefiller.h
+++ b/examples/models/llava/runner/llava_image_prefiller.h
@@ -26,7 +26,7 @@ class ET_EXPERIMENTAL LlavaImagePrefiller
* @param start_pos The starting position in KV cache of the input in the LLM
* @return logits of the image prefill.
*/
- inline ::executorch::runtime::Result prefill(
+ inline ::executorch::runtime::Result prefill(
::executorch::extension::llm::Image& image,
int64_t& start_pos) override {
auto image_tensor = executorch::extension::from_blob(
diff --git a/examples/models/llava/runner/llava_text_decoder_runner.h b/examples/models/llava/runner/llava_text_decoder_runner.h
index 236d412910..4c7809361b 100644
--- a/examples/models/llava/runner/llava_text_decoder_runner.h
+++ b/examples/models/llava/runner/llava_text_decoder_runner.h
@@ -23,7 +23,7 @@ class ET_EXPERIMENTAL LlavaTextDecoderRunner
float temperature)
: TextDecoderRunner(module, true, vocab_size, temperature){};
- inline executorch::runtime::Result step(
+ inline executorch::runtime::Result step(
executorch::extension::TensorPtr& tokens,
executorch::extension::TensorPtr& start_pos) override {
// run token embedding
diff --git a/examples/models/phi-3-mini/runner.cpp b/examples/models/phi-3-mini/runner.cpp
index ca299d3b11..1163a35d66 100644
--- a/examples/models/phi-3-mini/runner.cpp
+++ b/examples/models/phi-3-mini/runner.cpp
@@ -73,14 +73,15 @@ void Runner::generate(const std::string& prompt, std::size_t max_seq_len) {
std::cout << std::endl;
}
-uint64_t Runner::logits_to_token(const exec_aten::Tensor& logits_tensor) {
+uint64_t Runner::logits_to_token(
+ const executorch::aten::Tensor& logits_tensor) {
return sampler_->sample(logits_tensor.data_ptr());
}
uint64_t Runner::prefill(std::vector& tokens) {
auto result = module_->forward(executorch::extension::from_blob(
tokens.data(),
- {1, static_cast(tokens.size())},
+ {1, static_cast(tokens.size())},
ScalarType::Long));
ET_CHECK_MSG(result.error() == Error::Ok, "Failed to prefill tokens");
diff --git a/examples/models/phi-3-mini/runner.h b/examples/models/phi-3-mini/runner.h
index 9b24f97170..2048acdab2 100644
--- a/examples/models/phi-3-mini/runner.h
+++ b/examples/models/phi-3-mini/runner.h
@@ -38,7 +38,7 @@ class Runner {
void generate(const std::string& prompt, std::size_t max_seq_len);
private:
- uint64_t logits_to_token(const exec_aten::Tensor& logits_tensor);
+ uint64_t logits_to_token(const executorch::aten::Tensor& logits_tensor);
uint64_t prefill(std::vector& tokens);
uint64_t run_model_step(uint64_t token);
diff --git a/examples/portable/custom_ops/custom_ops_1_out.cpp b/examples/portable/custom_ops/custom_ops_1_out.cpp
index 660107f275..e26dfefe23 100644
--- a/examples/portable/custom_ops/custom_ops_1_out.cpp
+++ b/examples/portable/custom_ops/custom_ops_1_out.cpp
@@ -11,8 +11,8 @@
namespace custom {
namespace native {
-using exec_aten::ScalarType;
-using exec_aten::Tensor;
+using executorch::aten::ScalarType;
+using executorch::aten::Tensor;
using executorch::runtime::KernelRuntimeContext;
namespace {
diff --git a/examples/portable/custom_ops/custom_ops_2_out.cpp b/examples/portable/custom_ops/custom_ops_2_out.cpp
index 69436750cc..138a8eeed8 100644
--- a/examples/portable/custom_ops/custom_ops_2_out.cpp
+++ b/examples/portable/custom_ops/custom_ops_2_out.cpp
@@ -11,8 +11,8 @@
namespace custom {
namespace native {
-using exec_aten::ScalarType;
-using exec_aten::Tensor;
+using executorch::aten::ScalarType;
+using executorch::aten::Tensor;
using executorch::runtime::KernelRuntimeContext;
namespace {
diff --git a/examples/portable/executor_runner/executor_runner.cpp b/examples/portable/executor_runner/executor_runner.cpp
index 65ba762743..f7702fae3d 100644
--- a/examples/portable/executor_runner/executor_runner.cpp
+++ b/examples/portable/executor_runner/executor_runner.cpp
@@ -1,5 +1,6 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
+ * Copyright 2024-2025 Arm Limited and/or its affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
@@ -25,10 +26,14 @@
#include
#include
#include
+#include
#include
#include
#include
#include
+#ifdef ET_EVENT_TRACER_ENABLED
+#include
+#endif // ET_EVENT_TRACER_ENABLED
static uint8_t method_allocator_pool[4 * 1024U * 1024U]; // 4 MB
@@ -38,10 +43,15 @@ DEFINE_string(
model_path,
"model.pte",
"Model serialized in flatbuffer format.");
+DEFINE_uint32(num_executions, 1, "Number of times to run the model.");
+#ifdef ET_EVENT_TRACER_ENABLED
+DEFINE_string(etdump_path, "model.etdump", "Write ETDump data to this path.");
+#endif // ET_EVENT_TRACER_ENABLED
using executorch::extension::FileDataLoader;
using executorch::runtime::Error;
using executorch::runtime::EValue;
+using executorch::runtime::EventTracer;
using executorch::runtime::HierarchicalAllocator;
using executorch::runtime::MemoryAllocator;
using executorch::runtime::MemoryManager;
@@ -51,6 +61,56 @@ using executorch::runtime::Program;
using executorch::runtime::Result;
using executorch::runtime::Span;
+/// Helper to manage resources for ETDump generation
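+/// When ET_EVENT_TRACER_ENABLED is defined, an ETDumpGen tracer is created and
+/// its collected data can be written to the path given by --etdump_path;
+/// otherwise get_event_tracer() returns nullptr and no ETDump is produced.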
+class EventTraceManager {
+ public:
+ EventTraceManager() : event_tracer_ptr_(nullptr) {
+#ifdef ET_EVENT_TRACER_ENABLED
+ event_tracer_ptr_ = std::make_shared();
+#endif // ET_EVENT_TRACER_ENABLED
+ }
+
+ EventTracer* get_event_tracer() const {
+ return event_tracer_ptr_.get();
+ };
+
+ Error write_etdump_to_file() const {
+ EventTracer* const event_tracer_ptr = get_event_tracer();
+ if (!event_tracer_ptr) {
+ return Error::NotSupported;
+ }
+
+#ifdef ET_EVENT_TRACER_ENABLED
+ executorch::etdump::ETDumpGen* const etdump_ptr =
+ static_cast(event_tracer_ptr);
+
+ const char* filename = FLAGS_etdump_path.c_str();
+
+ std::unique_ptr etdump_file(
+ fopen(filename, "w+"), fclose);
+ if (!etdump_file) {
+ ET_LOG(Error, "Failed to open ETDump file at %s.", filename);
+ return Error::AccessFailed;
+ }
+
+ executorch::etdump::ETDumpResult result = etdump_ptr->get_etdump_data();
+ if (result.buf != nullptr && result.size > 0) {
+ fwrite((uint8_t*)result.buf, 1, result.size, etdump_file.get());
+ free(result.buf);
+ ET_LOG(Info, "ETDump written to file '%s'.", filename);
+ } else {
+ ET_LOG(Error, "No ETDump data available!");
+ return Error::NotFound;
+ }
+#endif // ET_EVENT_TRACER_ENABLED
+
+ return Error::Ok;
+ }
+
+ private:
+ std::shared_ptr event_tracer_ptr_;
+};
+
int main(int argc, char** argv) {
executorch::runtime::runtime_init();
@@ -158,8 +218,9 @@ int main(int argc, char** argv) {
// the method can mutate the memory-planned buffers, so the method should only
// be used by a single thread at a time, but it can be reused.
//
-
- Result method = program->load_method(method_name, &memory_manager);
+ EventTraceManager tracer;
+ Result method = program->load_method(
+ method_name, &memory_manager, tracer.get_event_tracer());
ET_CHECK_MSG(
method.ok(),
"Loading of method %s failed with status 0x%" PRIx32,
@@ -178,18 +239,23 @@ int main(int argc, char** argv) {
ET_LOG(Info, "Inputs prepared.");
// Run the model.
- Error status = method->execute();
- ET_CHECK_MSG(
- status == Error::Ok,
- "Execution of method %s failed with status 0x%" PRIx32,
- method_name,
- (uint32_t)status);
- ET_LOG(Info, "Model executed successfully.");
+ for (uint32_t i = 0; i < FLAGS_num_executions; i++) {
+ Error status = method->execute();
+ ET_CHECK_MSG(
+ status == Error::Ok,
+ "Execution of method %s failed with status 0x%" PRIx32,
+ method_name,
+ (uint32_t)status);
+ }
+ ET_LOG(
+ Info,
+ "Model executed successfully %" PRIu32 " time(s).",
+ FLAGS_num_executions);
// Print the outputs.
std::vector outputs(method->outputs_size());
ET_LOG(Info, "%zu outputs: ", outputs.size());
- status = method->get_outputs(outputs.data(), outputs.size());
+ Error status = method->get_outputs(outputs.data(), outputs.size());
ET_CHECK(status == Error::Ok);
// Print the first and last 100 elements of long lists of scalars.
std::cout << executorch::extension::evalue_edge_items(100);
@@ -197,5 +263,12 @@ int main(int argc, char** argv) {
std::cout << "Output " << i << ": " << outputs[i] << std::endl;
}
+ if (tracer.get_event_tracer()) {
+ // Dump ETDump data containing profiling/debugging data to the file
+ // specified by the --etdump_path command line flag.
+ Error status = tracer.write_etdump_to_file();
+ ET_CHECK_MSG(status == Error::Ok, "Failed to save ETDump file.");
+ }
+
return 0;
}
diff --git a/examples/qualcomm/CMakeLists.txt b/examples/qualcomm/CMakeLists.txt
index a8e16bb5c9..55969f937e 100644
--- a/examples/qualcomm/CMakeLists.txt
+++ b/examples/qualcomm/CMakeLists.txt
@@ -84,11 +84,8 @@ set(CMAKE_POSITION_INDEPENDENT_CODE ${_pic_flag})
# build qnn_executor_runner
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/executor_runner)
-# build qnn_llama_runner for llama2
-add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/oss_scripts/llama2)
-
-# build qnn_llama_runner for llama3.2
-add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/oss_scripts/llama3_2)
+# build qnn_llama_runner for llama
+add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/oss_scripts/llama)
# build qaihub_llama2_7b_runner and qaihub_llama3_8b_runner
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/qaihub_scripts/llama)
diff --git a/examples/qualcomm/README.md b/examples/qualcomm/README.md
index 3d5eb42939..bdac58d2bf 100644
--- a/examples/qualcomm/README.md
+++ b/examples/qualcomm/README.md
@@ -4,10 +4,10 @@ This directory contains examples for some AI models.
We have separated the example scripts into the following subfolders; please refer to [README.md](../../backends/qualcomm/README.md) for the example scripts' directory structure:
-1. executor_runner: This folder contains a general executor runner capable of running most of the models. As a rule of thumb, if a model does not have its own customized runner, execute the model using [executor_runner](./executor_runner/qnn_executor_runner.cpp). On the other hand, if a model has its own runner, such as [llama2](./oss_scripts/llama2/qnn_llama_runner.cpp), use the customized runner to execute the model. Customized runner should be located under the same folder as the model's python script.
+1. executor_runner: This folder contains a general executor runner capable of running most of the models. As a rule of thumb, if a model does not have its own customized runner, execute the model using [executor_runner](./executor_runner/qnn_executor_runner.cpp). On the other hand, if a model has its own runner, such as [llama](./oss_scripts/llama/qnn_llama_runner.cpp), use the customized runner to execute the model. Customized runner should be located under the same folder as the model's python script.
2. oss_scripts: OSS stands for Open Source Software. This folder contains python scripts for open source models. Some models under this folder might also have their own customized runner.
- For example, [llama2](./oss_scripts/llama2/qnn_llama_runner.cpp) contains not only the python scripts to prepare the model but also a customized runner for executing the model.
+ For example, [llama](./oss_scripts/llama/qnn_llama_runner.cpp) contains not only the python scripts to prepare the model but also a customized runner for executing the model.
3. qaihub_scripts: QAIHub stands for [Qualcomm AI Hub](https://aihub.qualcomm.com/). On QAIHub, users can find pre-compiled context binaries, a format used by QNN to save its models. This provides users with a new option for model deployment. Unlike oss_scripts & scripts, whose example scripts convert a model from nn.Module to ExecuTorch .pte files, qaihub_scripts provides example scripts for converting pre-compiled context binaries to ExecuTorch .pte files. Additionally, users can find customized example runners specific to the QAIHub models for execution. For example, [qaihub_llama2_7b](./qaihub_scripts/llama2/qaihub_llama2_7b.py) is a script that converts context binaries to ExecuTorch .pte files, and [qaihub_llama2_7b_runner](./qaihub_scripts/llama2/qaihub_llama2_7b_runner.cpp) is a customized example runner for executing llama2 .pte files. Please be aware that context binaries downloaded from QAIHub are tied to a specific QNN SDK version.
Before executing the scripts and runner, please ensure that you are using the QNN SDK version that matches the context binary. Please refer to [Check context binary version](#check-context-binary-version) for a tutorial on how to check the QNN version of a context binary.
diff --git a/examples/qualcomm/oss_scripts/conv_former.py b/examples/qualcomm/oss_scripts/conv_former.py
new file mode 100644
index 0000000000..76131d659d
--- /dev/null
+++ b/examples/qualcomm/oss_scripts/conv_former.py
@@ -0,0 +1,139 @@
+# Copyright (c) Qualcomm Innovation Center, Inc.
+# All rights reserved
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import json
+import os
+import sys
+from multiprocessing.connection import Client
+
+import numpy as np
+import timm
+import torch
+from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype
+from executorch.backends.qualcomm.utils.constants import (
+ QCOM_PASS_EXPAND_BROADCAST_SHAPE,
+)
+from executorch.examples.qualcomm.utils import (
+ build_executorch_binary,
+ get_imagenet_dataset,
+ make_output_dir,
+ parse_skip_delegation_node,
+ setup_common_args_and_variables,
+ SimpleADB,
+ topk_accuracy,
+)
+
+
+def main(args):
+ skip_node_id_set, skip_node_op_set = parse_skip_delegation_node(args)
+
+ # Ensure the working directory exists.
+ os.makedirs(args.artifact, exist_ok=True)
+
+ if not args.compile_only and args.device is None:
+ raise RuntimeError(
+ "device serial is required if not compile only. "
+ "Please specify a device serial by -s/--device argument."
+ )
+
+ data_num = 100
+ if args.compile_only:
+ inputs = [(torch.rand(1, 3, 224, 224),)]
+ else:
+ inputs, targets, input_list = get_imagenet_dataset(
+ dataset_path=f"{args.dataset}",
+ data_size=data_num,
+ image_shape=(256, 256),
+ crop_size=224,
+ )
+
+ pte_filename = "conv_former"
+ model = timm.create_model("convformer_s18.sail_in1k", pretrained=True)
+
+ model = model.eval()
+
+ build_executorch_binary(
+ model,
+ inputs[0],
+ args.model,
+ f"{args.artifact}/{pte_filename}",
+ inputs,
+ quant_dtype=QuantDtype.use_8a8w,
+ custom_pass_config={QCOM_PASS_EXPAND_BROADCAST_SHAPE},
+ )
+
+ if args.compile_only:
+ sys.exit(0)
+
+ adb = SimpleADB(
+ qnn_sdk=os.getenv("QNN_SDK_ROOT"),
+ build_path=f"{args.build_folder}",
+ pte_path=f"{args.artifact}/{pte_filename}.pte",
+ workspace=f"/data/local/tmp/executorch/{pte_filename}",
+ device_id=args.device,
+ host_id=args.host,
+ soc_model=args.model,
+ shared_buffer=args.shared_buffer,
+ )
+ adb.push(inputs=inputs, input_list=input_list)
+ adb.execute()
+
+ # collect output data
+ output_data_folder = f"{args.artifact}/outputs"
+ make_output_dir(output_data_folder)
+
+ adb.pull(output_path=args.artifact)
+
+ # top-k analysis
+ predictions = []
+ for i in range(data_num):
+ predictions.append(
+ np.fromfile(
+ os.path.join(output_data_folder, f"output_{i}_0.raw"), dtype=np.float32
+ )
+ )
+
+ k_val = [1, 5]
+ topk = [topk_accuracy(predictions, targets, k).item() for k in k_val]
+ if args.ip and args.port != -1:
+ with Client((args.ip, args.port)) as conn:
+ conn.send(json.dumps({f"top_{k}": topk[i] for i, k in enumerate(k_val)}))
+ else:
+ for i, k in enumerate(k_val):
+ print(f"top_{k}->{topk[i]}%")
+
+
+if __name__ == "__main__":
+ parser = setup_common_args_and_variables()
+ parser.add_argument(
+ "-a",
+ "--artifact",
+ help="path for storing generated artifacts by this example. Default ./conv_former",
+ default="./conv_former",
+ type=str,
+ )
+
+ parser.add_argument(
+ "-d",
+ "--dataset",
+ help=(
+ "path to the validation folder of ImageNet dataset. "
+ "e.g. --dataset imagenet-mini/val "
+ "for https://www.kaggle.com/datasets/ifigotin/imagenetmini-1000)"
+ ),
+ type=str,
+ required=True,
+ )
+
+ args = parser.parse_args()
+ try:
+ main(args)
+ except Exception as e:
+ if args.ip and args.port != -1:
+ with Client((args.ip, args.port)) as conn:
+ conn.send(json.dumps({"Error": str(e)}))
+ else:
+ raise Exception(e)
diff --git a/examples/qualcomm/oss_scripts/fastvit.py b/examples/qualcomm/oss_scripts/fastvit.py
index 0e2c695ab3..f0d2f4c3f0 100644
--- a/examples/qualcomm/oss_scripts/fastvit.py
+++ b/examples/qualcomm/oss_scripts/fastvit.py
@@ -10,6 +10,9 @@
import numpy as np
import torch
+from executorch.backends.qualcomm._passes.expand_broadcast_tensor_shape import (
+ ExpandBroadcastTensorShape,
+)
from executorch.backends.qualcomm.quantizer.annotators import (
QuantizationConfig,
QuantizationSpec,
@@ -23,10 +26,11 @@
)
from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype
-from executorch.backends.qualcomm.utils.constants import (
- QCOM_PASS_EXPAND_BROADCAST_SHAPE,
+from executorch.backends.qualcomm.utils.constants import QCOM_PASS_ACTIVATE_KEY
+from executorch.backends.qualcomm.utils.utils import (
+ convert_linear_to_conv2d,
+ get_capture_program_passes,
)
-from executorch.backends.qualcomm.utils.utils import convert_linear_to_conv2d
from executorch.examples.qualcomm.utils import (
build_executorch_binary,
get_imagenet_dataset,
@@ -111,6 +115,8 @@ def main(args):
bias=q_config.bias,
)
# lower to QNN
+ passes_job = get_capture_program_passes()
+ passes_job[ExpandBroadcastTensorShape][QCOM_PASS_ACTIVATE_KEY] = True
build_executorch_binary(
convert_linear_to_conv2d(get_instance(args.oss_repo, args.pretrained_weight)),
inputs[0],
@@ -121,7 +127,7 @@ def main(args):
skip_node_op_set=skip_node_op_set,
quant_dtype=QuantDtype.use_8a8w,
custom_quantizer=quantizer,
- custom_pass_config={QCOM_PASS_EXPAND_BROADCAST_SHAPE},
+ passes_job=passes_job,
shared_buffer=args.shared_buffer,
)
diff --git a/examples/qualcomm/oss_scripts/llama3_2/CMakeLists.txt b/examples/qualcomm/oss_scripts/llama/CMakeLists.txt
similarity index 60%
rename from examples/qualcomm/oss_scripts/llama3_2/CMakeLists.txt
rename to examples/qualcomm/oss_scripts/llama/CMakeLists.txt
index 93b35a697c..4059ae7151 100644
--- a/examples/qualcomm/oss_scripts/llama3_2/CMakeLists.txt
+++ b/examples/qualcomm/oss_scripts/llama/CMakeLists.txt
@@ -18,38 +18,37 @@ target_link_libraries(
)
target_link_options_shared_lib(custom_ops)
-# preprocess qnn runner src files for llama3.2
-set(_llama3_2_runner__srcs ${_llama_runner__srcs})
-list(TRANSFORM _llama3_2_runner__srcs PREPEND "${EXECUTORCH_SOURCE_DIR}/")
-list(FILTER _llama3_2_runner__srcs EXCLUDE REGEX ".*(/runner/).*")
+# preprocess qnn runner src files for llama
+set(_llama_runner__srcs ${_llama_runner__srcs})
+list(TRANSFORM _llama_runner__srcs PREPEND "${EXECUTORCH_SOURCE_DIR}/")
+list(FILTER _llama_runner__srcs EXCLUDE REGEX ".*(/runner/).*")
list(
PREPEND
- _llama3_2_runner__srcs
- ${CMAKE_CURRENT_LIST_DIR}/qnn_llama3_2_runner.cpp
+ _llama_runner__srcs
+ ${CMAKE_CURRENT_LIST_DIR}/qnn_llama_runner.cpp
${CMAKE_CURRENT_LIST_DIR}/runner/runner.cpp
${CMAKE_CURRENT_LIST_DIR}/runner/runner.h
- ${CMAKE_CURRENT_LIST_DIR}/runner/io_memory.cpp
- ${CMAKE_CURRENT_LIST_DIR}/runner/io_memory.h
+ ${CMAKE_CURRENT_LIST_DIR}/runner/io_manager.cpp
+ ${CMAKE_CURRENT_LIST_DIR}/runner/io_manager.h
)
-list(
- APPEND _llama3_2_runner__srcs
- ${CMAKE_CURRENT_SOURCE_DIR}/../../../../extension/llm/tokenizer/tiktoken.cpp
-)
list(
APPEND
- _llama3_2_runner__srcs
+ _llama_runner__srcs
+ ${CMAKE_CURRENT_SOURCE_DIR}/../../../../extension/llm/tokenizer/tiktoken.cpp
${CMAKE_CURRENT_SOURCE_DIR}/../../../models/llama/tokenizer/llama_tiktoken.cpp
)
-# build qnn llama3.2 1b runner
-add_executable(qnn_llama3_2_runner ${_llama3_2_runner__srcs})
+# build qnn llama runner
+add_executable(qnn_llama_runner ${_llama_runner__srcs})
target_include_directories(
- qnn_llama3_2_runner PUBLIC ${_common_include_directories}
+ qnn_llama_runner PUBLIC ${_common_include_directories}
)
+target_link_options_shared_lib(quantized_ops_lib)
+
target_link_libraries(
- qnn_llama3_2_runner
+ qnn_llama_runner
qnn_executorch_backend
executorch_core
extension_data_loader
@@ -58,10 +57,12 @@ target_link_libraries(
gflags
re2::re2
custom_ops
+ quantized_ops_lib
+ quantized_kernels
)
target_compile_options(
- qnn_llama3_2_runner PUBLIC ${_common_compile_options}
+ qnn_llama_runner PUBLIC ${_common_compile_options}
)
set_target_properties(
- qnn_llama3_2_runner PROPERTIES LINK_FLAGS "-Wl,-rpath='$ORIGIN'"
+ qnn_llama_runner PROPERTIES LINK_FLAGS "-Wl,-rpath='$ORIGIN'"
)
diff --git a/examples/qualcomm/oss_scripts/llama/README.md b/examples/qualcomm/oss_scripts/llama/README.md
new file mode 100644
index 0000000000..79c20180d6
--- /dev/null
+++ b/examples/qualcomm/oss_scripts/llama/README.md
@@ -0,0 +1,70 @@
+# Summary
+
+## Overview
+This file provides instructions for running LLAMA models with different parameters via the Qualcomm HTP backend. We currently support the following models:
+ 1. LLAMA2 Stories 110M
+ 2. LLAMA3.2 1B
+ 3. LLAMA3.2 3B (WIP)
+We offer the following modes to execute the model:
+
+- Prefill Mode: This is also known as batch prefill mode, where the model takes in a list of tokens as input and generates the next token along with the key-value (KV) cache for all tokens. This mode is efficient for generating the initial sequence of tokens (usually the user's prompt).
+
+- KV Cache Mode: In KV Cache mode, the model takes in a single previous token and generates the next predicted token along with its KV cache. It is efficient for generating subsequent tokens after the initial prompt.
+
+- Hybrid Mode: Hybrid mode leverages the strengths of both batch prefill and KV cache modes to optimize token generation speed. Initially, it uses prefill mode to efficiently generate the prompt's key-value (KV) cache. Then, the mode switches to KV cache mode, which excels at generating subsequent tokens. A minimal sketch of this flow is shown below.
+
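+The sketch below outlines, at a high level, how hybrid mode stitches the two
+graphs together. It is illustrative only: `prefill_step`, `decode_step`, and the
+values used are hypothetical stand-ins, not the runner's actual API.
+
+```python
+# Illustrative sketch of the hybrid-mode flow; names are hypothetical stand-ins.
+import torch
+
+def prefill_step(tokens):          # stand-in for the exported prefill graph
+    return torch.rand(len(tokens), 32000), "kv_cache"
+
+def decode_step(token, kv_cache):  # stand-in for the exported KV-cache graph
+    return torch.rand(1, 32000), kv_cache
+
+prompt_tokens = [1, 15043, 29892]  # pretend-encoded prompt
+eos_id, kv_seq_len = 2, 128
+
+# 1) Prefill: run the whole prompt once to build the KV cache and the first token.
+logits, kv_cache = prefill_step(prompt_tokens)
+next_token = int(torch.argmax(logits[-1]))
+
+# 2) Decode: generate one token at a time, reusing and updating the KV cache.
+generated = [next_token]
+while next_token != eos_id and len(prompt_tokens) + len(generated) < kv_seq_len:
+    logits, kv_cache = decode_step(next_token, kv_cache)
+    next_token = int(torch.argmax(logits[-1]))
+    generated.append(next_token)
+```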
+
+## Instructions
+### Note
+1. For hybrid mode, export takes longer and can take 1 to 4 hours to complete, depending on the specific model being exported.
+2. When exporting a hybrid mode model, memory consumption will be higher. Taking LLAMA3.2 1B as an example, please ensure the device has at least 80 GB of memory and swap space.
+
+
+### Step 1: Setup
+1. Follow the [tutorial](https://pytorch.org/executorch/main/getting-started-setup) to set up ExecuTorch.
+2. Follow the [tutorial](https://pytorch.org/executorch/stable/build-run-qualcomm-ai-engine-direct-backend.html) to build Qualcomm AI Engine Direct Backend.
+
+### Step 2: Prepare Model
+
+#### LLAMA2
+Download and prepare stories110M model
+
+```bash
+# tokenizer.model & stories110M.pt:
+wget "https://huggingface.co/karpathy/tinyllamas/resolve/main/stories110M.pt"
+wget "https://raw.githubusercontent.com/karpathy/llama2.c/master/tokenizer.model"
+
+# tokenizer.bin:
+python -m extension.llm.tokenizer.tokenizer -t tokenizer.model -o tokenizer.bin
+
+# params.json:
+echo '{"dim": 768, "multiple_of": 32, "n_heads": 12, "n_layers": 12, "norm_eps": 1e-05, "vocab_size": 32000}' > params.json
+```
+
+#### LLAMA3.2
+Follow the [instructions](https://www.llama.com/) to download models.
+At the end of this step, users should have the following files ready: `consolidated.00.pth`, `params.json`, and `tokenizer.model`.
+
+
+### Step 3: Run default examples using hybrid mode.
+#### LLAMA2
+```bash
+python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --ptq 16a4w --checkpoint stories110M.pt --params params.json --tokenizer_model tokenizer.model --tokenizer_bin tokenizer.bin --llama_model stories110m --model_mode hybrid --prefill_seq_len 32 --kv_seq_len 128 --prompt "Once upon a time"
+```
+
+#### LLAMA3.2
+Default example using hybrid mode.
+```bash
+python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --ptq 16a4w --checkpoint consolidated.00.pth --params params.json --tokenizer_model tokenizer.model --llama_model llama3_2 --model_mode hybrid --prefill_seq_len 32 --kv_seq_len 128 --prompt "what is 1+1"
+```
+
+### Additional Configs when running the script
+If you would like to compile the model only, we have provided the flag `--compile_only`. Taking LLAMA3.2 as an example:
+```bash
+python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -m ${SOC_MODEL} --ptq 16a4w --checkpoint consolidated.00.pth --params params.json --tokenizer_model tokenizer.model --llama_model llama3_2 --model_mode hybrid --prefill_seq_len 32 --kv_seq_len 128 --prompt "what is 1+1" --compile_only
+```
+
+On the other hand, if you already have a pre-compiled .pte model, you can perform inference by providing the flag `--pre_gen_pte` and specifying the folder that contains the .pte model. Taking LLAMA3.2 as an example:
+```bash
+python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --ptq 16a4w --checkpoint consolidated.00.pth --params params.json --tokenizer_model tokenizer.model --llama_model llama3_2 --model_mode hybrid --prefill_seq_len 32 --kv_seq_len 128 --prompt "what is 1+1" --pre_gen_pte ${FOLDER_TO_PRE_GEN_PTE}
+```
\ No newline at end of file
diff --git a/examples/qualcomm/oss_scripts/llama2/TARGETS b/examples/qualcomm/oss_scripts/llama/TARGETS
similarity index 57%
rename from examples/qualcomm/oss_scripts/llama2/TARGETS
rename to examples/qualcomm/oss_scripts/llama/TARGETS
index b0f5ea7f64..419316acf0 100644
--- a/examples/qualcomm/oss_scripts/llama2/TARGETS
+++ b/examples/qualcomm/oss_scripts/llama/TARGETS
@@ -5,7 +5,6 @@ load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
oncall("executorch")
-
python_library(
name = "static_llama",
srcs = [
@@ -16,12 +15,33 @@ python_library(
],
)
+python_library(
+ name = "llama_lib",
+ srcs = ["llama.py"],
+ deps = [
+ "//caffe2:torch",
+ "//executorch/backends/qualcomm/partition:partition",
+ "//executorch/backends/qualcomm/quantizer:quantizer",
+ "//executorch/devtools:lib",
+ "//executorch/examples/models:models",
+ "//executorch/examples/qualcomm/oss_scripts/llama:static_llama",
+ "//executorch/examples/qualcomm:utils",
+ "//executorch/extension/export_util:export_util",
+ "//executorch/extension/llm/custom_ops:model_sharding_py",
+ "//executorch/extension/llm/export:export_lib",
+ "//executorch/extension/pybindings:aten_lib",
+ ],
+)
+
python_binary(
name = "llama",
srcs = ["llama.py"],
- main_function = "executorch.examples.qualcomm.oss_scripts.llama2.llama.main",
+ main_function = "executorch.examples.qualcomm.oss_scripts.llama.llama.main",
+ preload_deps = [
+ "//executorch/extension/llm/custom_ops:model_sharding_py",
+ ],
deps = [
- ":static_llama",
+ "//executorch/examples/qualcomm/oss_scripts/llama:static_llama",
"//caffe2:torch",
"//executorch/extension/pybindings:aten_lib",
"//executorch/backends/qualcomm/partition:partition",
@@ -38,6 +58,8 @@ runtime.command_alias(
name = "llama_qnn",
env = {
"LD_LIBRARY_PATH": "$(location fbsource//third-party/qualcomm/qnn/qnn-{0}:qnn_offline_compile_libs)".format(get_qnn_library_verision()),
+ # Placeholder to pass the QNN_SDK_ROOT check in executorch/examples/qualcomm/utils.py
+ "QNN_SDK_ROOT": "",
},
exe = ":llama",
)
diff --git a/examples/qualcomm/oss_scripts/llama3_2/llama.py b/examples/qualcomm/oss_scripts/llama/llama.py
similarity index 65%
rename from examples/qualcomm/oss_scripts/llama3_2/llama.py
rename to examples/qualcomm/oss_scripts/llama/llama.py
index a18690e941..e575a3f5c4 100755
--- a/examples/qualcomm/oss_scripts/llama3_2/llama.py
+++ b/examples/qualcomm/oss_scripts/llama/llama.py
@@ -14,20 +14,32 @@
import os
import sys
import time
+from collections import OrderedDict
from functools import partial
from multiprocessing.connection import Client
import torch
+from executorch.backends.qualcomm._passes.i64_to_i32 import I64toI32
from executorch.backends.qualcomm.partition.qnn_partitioner import QnnPartitioner
from executorch.backends.qualcomm.quantizer.custom_annotation import (
+ annotate_linear_16a8w_in_affine_layer,
annotate_matmul_16a8w,
+ annotate_prefill_kv_output,
)
from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype
from executorch.backends.qualcomm.serialization.qc_schema import QcomChipset
-from executorch.backends.qualcomm.utils.constants import QCOM_QUANTIZED_IO
+
+from executorch.backends.qualcomm.serialization.qc_schema_serialize import (
+ flatbuffer_to_option,
+ option_to_flatbuffer,
+)
+from executorch.backends.qualcomm.utils.constants import (
+ QCOM_PASS_ARGS_KWARGS_DEFAULTS_KEY,
+ QCOM_QUANTIZED_IO,
+)
from executorch.backends.qualcomm.utils.utils import (
capture_program,
convert_linear_to_conv2d,
@@ -35,10 +47,15 @@
generate_htp_compiler_spec,
generate_multi_graph_program,
generate_qnn_executorch_compiler_spec,
+ get_capture_program_passes,
get_soc_to_chipset_map,
update_spill_fill_size,
)
-from executorch.examples.qualcomm.oss_scripts.llama2.model.static_llama import (
+from executorch.examples.models.llama.source_transformation.quantize import (
+ get_quant_embedding_transform,
+)
+from executorch.examples.models.llama.tokenizer.tiktoken import Tokenizer as Tiktoken
+from executorch.examples.qualcomm.oss_scripts.llama.model.static_llama import (
LlamaModel,
ModelArgs,
)
@@ -55,6 +72,9 @@
from executorch.exir.passes.memory_planning_pass import MemoryPlanningPass
from executorch.extension.llm.custom_ops import model_sharding
from executorch.extension.llm.export.builder import DType
+from executorch.extension.llm.tokenizer.tokenizer import (
+ Tokenizer as SentencePieceTokenizer,
+)
from executorch.extension.llm.tokenizer.utils import get_tokenizer
from torch.ao.quantization.observer import MinMaxObserver
@@ -66,74 +86,116 @@
logging.getLogger().setLevel(logging.INFO)
+def smart_mask_updator(atten_mask, pos, k_caches, v_caches, new_k_caches, new_v_caches):
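+ # Smart-mask update: write the new K/V entries in place at position `pos` and
+ # unmask that position in the attention mask; existing cache entries are untouched.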
+ for i, k_cache in enumerate(k_caches):
+ k_cache[:, :, pos] = new_k_caches[i][:, :, 0]
+
+ for i, v_cache in enumerate(v_caches):
+ v_cache[:, pos, :] = new_v_caches[i]
+
+ atten_mask[0][pos] = 0
+ pos += 1
+ return (atten_mask, pos, k_caches, v_caches)
+
+
+def shift_pointer_updator(
+ atten_mask, pos, k_caches, v_caches, new_k_caches, new_v_caches
+):
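+ # Shift-pointer update: drop the oldest entry from each cache, append the new
+ # K/V at the end, and unmask one more position from the right of the mask.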
+ k_caches = [
+ torch.cat([k_cache[:, :, 1:], new_k_caches[i]], dim=-1)
+ for i, k_cache in enumerate(k_caches)
+ ]
+ v_caches = [
+ torch.cat([v_cache[:, 1:, :], new_v_caches[i]], dim=1)
+ for i, v_cache in enumerate(v_caches)
+ ]
+
+ pos += 1
+ atten_mask[0][-pos - 1] = 0
+ return (atten_mask, pos, k_caches, v_caches)
+
+
def _kv_calibrate(
example_inputs,
user_prompts,
module: torch.fx.GraphModule,
- tokenizer_model_path="tokenizer.model",
+ tokenizer,
max_seq_len=512,
+ updator=smart_mask_updator,
+ use_i64_token=False,
):
- sp_model = get_tokenizer(tokenizer_model_path)
_, atten_mask, _, k_caches, v_caches = example_inputs
# TODO: change criteria & support batch inputs if necessary
pos = torch.tensor(0, dtype=torch.int32)
max_cache_len = max_seq_len - 1
- token_list = sp_model.encode(
- user_prompts, bos=True, eos=False, allowed_special="all"
- )
+
+ token_list = []
+ # Llama2 tokenizer has no special tokens
+ if isinstance(tokenizer, SentencePieceTokenizer):
+ token_list = tokenizer.encode(user_prompts, bos=True, eos=False)
+ elif isinstance(tokenizer, Tiktoken):
+ token_list = tokenizer.encode(
+ user_prompts, bos=True, eos=False, allowed_special="all"
+ )
+ else:
+ raise RuntimeError("Unkown tokenizer")
with torch.no_grad():
- while token_list[-1] != sp_model.eos_id and pos < max_cache_len:
+ while token_list[-1] != tokenizer.eos_id and pos < max_cache_len:
+ dtype = torch.int64 if use_i64_token else torch.int32
+ token = torch.full((1, 1), token_list[pos], dtype=dtype)
logits, new_k_caches, new_v_caches = module(
- torch.full((1, 1), token_list[pos], dtype=torch.int32),
+ token,
atten_mask,
torch.full((1, 1), pos),
*k_caches,
*v_caches,
)
- k_caches = [
- torch.cat([k_cache[:, :, 1:], new_k_caches[i]], dim=-1)
- for i, k_cache in enumerate(k_caches)
- ]
- v_caches = [
- torch.cat([v_cache[:, 1:, :], new_v_caches[i]], dim=1)
- for i, v_cache in enumerate(v_caches)
- ]
-
- pos += 1
- atten_mask[0][-pos - 1] = 0
+ atten_mask, pos, k_caches, v_caches = updator(
+ atten_mask, pos, k_caches, v_caches, new_k_caches, new_v_caches
+ )
if pos >= len(token_list):
token_list.append(torch.argmax(logits[:, -1], dim=-1).item())
- print(f"calibration data:\n{sp_model.decode(token_list)}")
+ print(f"kv calibration data:\n{tokenizer.decode(token_list)}")
def _prefill_calibrate(
example_inputs,
user_prompts,
module: torch.fx.GraphModule,
- tokenizer_model_path="tokenizer.model",
+ tokenizer,
max_seq_len=512,
+ use_i64_token=False,
):
- sp_model = get_tokenizer(tokenizer_model_path)
_, atten_mask = example_inputs
max_cache_len = max_seq_len - 1
# TODO: change criteria & support batch inputs if necessary
- token_list = sp_model.encode(
- user_prompts, bos=True, eos=False, allowed_special="all"
- )
+
+ token_list = []
+ # Llama2 tokenizer has no special tokens
+ if isinstance(tokenizer, SentencePieceTokenizer):
+ token_list = tokenizer.encode(user_prompts, bos=True, eos=False)
+ elif isinstance(tokenizer, Tiktoken):
+ token_list = tokenizer.encode(
+ user_prompts, bos=True, eos=False, allowed_special="all"
+ )
+ else:
+ raise RuntimeError("Unkown tokenizer")
+
pos = len(token_list)
+ dtype = torch.int64 if use_i64_token else torch.int32
with torch.no_grad():
- while token_list[-1] != sp_model.eos_id and pos < max_cache_len:
- tmp_token_list = torch.tensor(token_list).reshape(1, -1)
+ while token_list[-1] != tokenizer.eos_id and pos < max_cache_len:
+ tmp_token_list = torch.tensor(token_list, dtype=dtype).reshape(1, -1)
if pos < max_cache_len:
tmp_token_list = torch.cat(
[
tmp_token_list,
- torch.zeros((1, max_cache_len - pos), dtype=torch.int32),
+ torch.zeros((1, max_cache_len - pos), dtype=dtype),
],
dim=1,
)
@@ -144,31 +206,36 @@ def _prefill_calibrate(
token_list.append(torch.argmax(logits[:, pos - 1], dim=-1).item())
pos += 1
- print(f"calibration data:\n{sp_model.decode(token_list)}")
+ print(f"prefill calibration data:\n{tokenizer.decode(token_list)}")
def calibrate(
example_inputs,
user_prompts,
module: torch.fx.GraphModule,
- tokenizer_model_path="tokenizer.model",
+ tokenizer,
max_seq_len=512,
+ kv_updator=smart_mask_updator,
+ use_i64_token=False,
):
if len(example_inputs) == 2:
_prefill_calibrate(
example_inputs,
user_prompts,
module,
- tokenizer_model_path,
+ tokenizer,
max_seq_len,
+ use_i64_token,
)
elif len(example_inputs) == 5:
_kv_calibrate(
example_inputs,
user_prompts,
module,
- tokenizer_model_path,
+ tokenizer,
max_seq_len,
+ updator=kv_updator,
+ use_i64_token=use_i64_token,
)
else:
raise RuntimeError("Get wrong inputs")
@@ -190,6 +257,7 @@ def __init__(self, llama_model, pte_filename) -> None:
else:
tokens, atten_mask = self.get_example_inputs(use_kv_cache=False)
self.inputs = (tokens, atten_mask)
+ self.llama_graph_module = llama_model
def _tag_ios(self, gm: torch.fx.GraphModule, fixed_point_type):
if not self.has_quant_io:
@@ -280,7 +348,7 @@ def _tag_ios(self, gm: torch.fx.GraphModule, fixed_point_type):
return quant_attrs
- def quantize(self, quant_dtype, args, custom_annotations=()):
+ def quantize(self, quant_dtype, args, tokenizer, custom_annotations=()):
self.quant_dtype = quant_dtype
quantizer = make_quantizer(
quant_dtype=quant_dtype,
@@ -295,19 +363,22 @@ def quantize(self, quant_dtype, args, custom_annotations=()):
with torch.no_grad():
fx_graph_module = torch.export.export(
- self.llama_model, self.inputs, strict=True
+ self.llama_graph_module, self.inputs, strict=True
).module()
fx_graph_module = prepare_pt2e(fx_graph_module, quantizer)
+
logging.info("Quantizing the model...")
calibrate(
self.get_example_inputs(self.llama_meta["get_use_kv_cache"]),
args.prompt,
fx_graph_module,
- tokenizer_model_path=args.tokenizer_model,
+ tokenizer=tokenizer,
max_seq_len=self.llama_meta["get_max_seq_len"],
+ kv_updator=args.kv_updator,
+ use_i64_token=args.embedding_quantize is not None,
)
- self.llama_model = convert_pt2e(fx_graph_module)
+ self.llama_graph_module = convert_pt2e(fx_graph_module)
def lowering_modules(
self,
@@ -315,7 +386,9 @@ def lowering_modules(
fixed_point_type,
use_fp16=False,
soc_model=QcomChipset.SM8650,
- num_sharding=0,
+ num_sharding=1,
+ passes_job=OrderedDict(),
+ shared_buffer=False,
):
executorch_config = ExecutorchBackendConfig(
# For shared buffer, user must pass the memory address
@@ -331,22 +404,24 @@ def lowering_modules(
with torch.no_grad():
# backend option
backend_options = generate_htp_compiler_spec(
- use_fp16=use_fp16, use_multi_contexts=num_sharding > 0
+ use_fp16=use_fp16, use_multi_contexts=num_sharding > 1
)
compiler_specs = generate_qnn_executorch_compiler_spec(
soc_model=soc_model,
backend_options=backend_options,
- shared_buffer=False,
+ shared_buffer=shared_buffer,
)
skip_node_op_set = {"llama.fallback.default"}
partitioner = QnnPartitioner(
compiler_specs, skip_node_op_set=skip_node_op_set
)
edge_prog = capture_program(
- self.llama_model, self.inputs, custom_pass_config=frozenset()
+ self.llama_graph_module,
+ self.inputs,
+ passes_job,
)
- if num_sharding > 0:
+ if num_sharding > 1:
model_sharding.split_graph(
edge_prog.exported_program,
self.llama_meta["get_n_layers"],
@@ -363,10 +438,10 @@ def lowering_modules(
compile_config=EdgeCompileConfig(_check_ir_validity=False),
)
edge_prog_mgr = edge_prog_mgr.to_backend(partitioner)
- if num_sharding > 0:
+ if num_sharding > 1:
update_spill_fill_size(edge_prog_mgr.exported_program())
exec_prog_mgr = edge_prog_mgr.to_executorch(config=executorch_config)
- with open(f"{work_space}/{pte_filename}.pte", "wb") as file:
+ with open(f"{work_space}/{self.pte_filename}.pte", "wb") as file:
exec_prog_mgr.write_to_file(file)
def get_example_inputs(self, use_kv_cache=True):
@@ -376,7 +451,7 @@ def get_quant_attrs(self):
return self.quant_attrs
-def compile(args, pte_filename):
+def compile(args, pte_filename, tokenizer):
os.makedirs(args.artifact, exist_ok=True)
start_ts = time.time()
@@ -396,24 +471,37 @@ def compile(args, pte_filename):
)
llama_instance_list = []
+ use_i64_token = args.embedding_quantize is not None
with torch.device("meta"):
if args.model_mode == "kv":
llama_instance_list.append(
- LlamaModel(kv_config, output_new_cache_only=True)
+ LlamaModel(
+ kv_config, output_new_cache_only=True, use_i64_token=use_i64_token
+ )
)
elif args.model_mode == "prefill":
llama_instance_list.append(
- LlamaModel(prefill_config, output_new_cache_only=False)
+ LlamaModel(
+ prefill_config,
+ output_new_cache_only=False,
+ use_i64_token=use_i64_token,
+ )
)
elif args.model_mode == "hybrid":
llama_instance_list.append(
- LlamaModel(prefill_config, output_new_cache_only=False)
+ LlamaModel(
+ kv_config, output_new_cache_only=True, use_i64_token=use_i64_token
+ )
)
llama_instance_list.append(
- LlamaModel(kv_config, output_new_cache_only=True)
+ LlamaModel(
+ prefill_config,
+ output_new_cache_only=False,
+ use_i64_token=use_i64_token,
+ )
)
else:
- raise RuntimeError(f"No such model_mode {args.model_mode}.")
+ raise RuntimeError(f"Unknown model_mode: {args.model_mode}.")
if "model" in state_dict:
state_dict = state_dict["model"]
@@ -452,6 +540,7 @@ def compile(args, pte_filename):
assert args.tokenizer_model is not None, "Need tokenizer model for calibration"
+ passes_job = get_capture_program_passes()
if args.dtype_override is not None:
dtype_override = DType[args.dtype_override]
for i in range(len(llama_instance_list)):
@@ -460,6 +549,13 @@ def compile(args, pte_filename):
)
for i in range(len(llama_instance_list)):
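+ # When --embedding-quantize is set, the embedding falls back to CPU and takes int64 token ids, so the I64toI32 pass must skip the "tokens" input.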
+ if args.embedding_quantize:
+ llama_instance_list[i] = get_quant_embedding_transform(args)(
+ llama_instance_list[i]
+ )
+ passes_job[I64toI32][QCOM_PASS_ARGS_KWARGS_DEFAULTS_KEY]["skip_node"] = {
+ "tokens"
+ }
llama_instance_list[i] = convert_linear_to_conv2d(llama_instance_list[i])
llama_instance_list[i] = SingleLlama(
llama_instance_list[i].eval(), pte_filename
@@ -467,44 +563,68 @@ def compile(args, pte_filename):
if args.ptq:
start_quantize_ts = time.time()
- for llama_instance in llama_instance_list:
- llama_instance.quantize(
- quant_dtype=quant_dtype,
- args=args,
- custom_annotations=(
- partial(
- annotate_matmul_16a8w,
- traverse_input1=llama_instance.llama_meta["get_use_kv_cache"],
- ),
- ),
+ custom_annotations = (annotate_matmul_16a8w,)
+ if args.llama_model == "stories110m":
+ custom_annotations = custom_annotations + (
+ annotate_linear_16a8w_in_affine_layer,
)
+ if args.ptq is not None:
+ kv_quant_attrs = {}
+ for i, llama_instance in enumerate(llama_instance_list):
+ llama_instance.quantize(
+ quant_dtype=quant_dtype,
+ args=args,
+ tokenizer=tokenizer,
+ custom_annotations=custom_annotations,
+ )
+ # In hybrid mode, store the kv graph's output quant_attrs and apply them to the prefill graph's output quant_attrs later
+ if i == 0 and args.model_mode == "hybrid":
+ output_indices = 0
+ for node in llama_instance.llama_graph_module.graph.nodes:
+ if node.op == "output":
+ for output in node.args[0]:
+ kv_quant_attrs[output_indices] = output.args[1:]
+ output_indices += 1
+ break
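+ # The extended annotation takes effect on the next iteration, when the prefill instance is quantized.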
+ custom_annotations = custom_annotations + (
+ partial(
+ annotate_prefill_kv_output,
+ kv_quant_attrs=kv_quant_attrs,
+ ),
+ )
end_quantize_ts = time.time()
logging.info(f"Time for quantizing: {end_quantize_ts - start_quantize_ts}")
start_lowering_ts = time.time()
quant_attrs = None
- if len(llama_instance_list) == 1:
+ if args.model_mode in ["kv", "prefill"]:
llama_instance_list[0].lowering_modules(
args.artifact,
fixed_point_type,
use_fp16=use_fp16,
soc_model=get_soc_to_chipset_map()[args.model],
num_sharding=args.num_sharding,
+ passes_job=passes_job,
+ shared_buffer=args.shared_buffer,
)
quant_attrs = llama_instance_list[0].get_quant_attrs()
- else:
+ elif args.model_mode == "hybrid":
sample_inputs_list = [
llama_instace.inputs for llama_instace in llama_instance_list
]
edge_progs = [
- capture_program(llama_instance.llama_model, sample_input)
+ capture_program(
+ llama_instance.llama_graph_module,
+ sample_input,
+ passes_job=passes_job,
+ )
for llama_instance, sample_input in zip(
llama_instance_list, sample_inputs_list
)
]
- if args.num_sharding > 0:
+ if args.num_sharding > 1:
for i in range(len(llama_instance_list)):
model_sharding.split_graph(
edge_progs[i].exported_program,
@@ -518,14 +638,14 @@ def compile(args, pte_filename):
fixed_point_type,
)
backend_options = generate_htp_compiler_spec(
- use_fp16=use_fp16, use_multi_contexts=args.num_sharding > 0
+ use_fp16=use_fp16, use_multi_contexts=args.num_sharding > 1
)
- graph_names = ["prefill_forward", "kv_forward"]
+ graph_names = ["kv_forward", "prefill_forward"]
compiler_specs = [
generate_qnn_executorch_compiler_spec(
soc_model=get_soc_to_chipset_map()[args.model],
backend_options=backend_options,
- shared_buffer=True,
+ shared_buffer=args.shared_buffer,
multiple_graphs=True,
graph_name=graph_name,
)
@@ -539,9 +659,13 @@ def compile(args, pte_filename):
)
for i, edge_prog in enumerate(edge_progs)
]
- if args.num_sharding > 0:
- for exported_program in exported_programs:
- update_spill_fill_size(exported_program)
+ if args.num_sharding > 1:
+ max_sf_size = update_spill_fill_size(exported_programs)
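+ # Propagate the largest spill-fill buffer size into the HTP backend options of the shared compiler spec.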
+ qnn_executorch_options = flatbuffer_to_option(compiler_specs[0][0].value)
+ qnn_executorch_options.backend_options.htp_options.max_sf_buf_size = (
+ max_sf_size
+ )
+ compiler_specs[0][0].value = option_to_flatbuffer(qnn_executorch_options)
executorch_config = ExecutorchBackendConfig(
# For shared buffer, user must pass the memory address
@@ -555,6 +679,7 @@ def compile(args, pte_filename):
extract_delegate_segments=True,
)
+ bundle_progs_list = []
lower_module_dict = {name: [] for name in graph_names}
call_delegate_inputs_dict = {name: [] for name in graph_names}
call_delegate_node_name_dict = {name: [] for name in graph_names}
@@ -570,11 +695,17 @@ def compile(args, pte_filename):
call_delegate_inputs_list = []
for arg in node.args:
if arg.op == "call_function":
- while "getitem" not in arg.name:
- arg = arg.args[0]
- call_delegate_inputs_list.append(
- (arg.args[0].name, arg.args[1])
- )
+ if (
+ arg.target
+ == exir_ops.edge.quantized_decomposed.embedding_4bit.dtype
+ ):
+ call_delegate_inputs_list.append((arg.name, None))
+ else:
+ while "getitem" not in arg.name:
+ arg = arg.args[0]
+ call_delegate_inputs_list.append(
+ (arg.args[0].name, arg.args[1])
+ )
elif arg.op == "placeholder":
call_delegate_inputs_list.append((arg.name, None))
# No extra needs to do for get_attr node
@@ -584,95 +715,59 @@ def compile(args, pte_filename):
elif node.op == "output":
for arg in node.args[0]:
outputs_dict[graph_name].append((arg.args[0].name, arg.args[1]))
-
- if args.num_sharding > 0:
- bundle_progs_list = []
- for num in range(args.num_sharding - 1, -1, -1):
- processed_bytes = []
- for prog, graph_name in zip(exported_programs, graph_names):
- processed_bytes.append(
- getattr(
- prog.graph_module, f"lowered_module_{num}"
- ).processed_bytes
- )
-
- call_delegate_node = [
- list(node.users.keys())[0]
- for node in prog.graph_module.graph.nodes
- if node.op == "get_attr"
- and node.name == f"lowered_module_{num}"
- ]
- input_nodes_dict[graph_name] = [
- node
- for node in call_delegate_node[0].args
- if node.op == "placeholder"
- ]
-
- prog_mgr, bundle_progs = generate_multi_graph_program(
- compiler_specs=compiler_specs[0],
- processed_bytes=processed_bytes,
- input_nodes_dict=input_nodes_dict,
- backend_config=executorch_config,
- constant_methods=llama_instance_list[
- 1
- ].llama_meta, # kv method meta
- )
- bundle_progs_list.append(bundle_progs)
- for graph_name in graph_names:
- lower_module_dict[graph_name].append(
- prog_mgr.exported_program(graph_name).graph_module._modules.get(
- "lowered_module_0"
- )
- )
-
- exec_prog = generate_composite_llama_program(
- graph_names=graph_names,
- sample_inputs_list=sample_inputs_list,
- lower_module_dict=lower_module_dict,
- call_delegate_node_name_dict=call_delegate_node_name_dict,
- call_delegate_inputs_dict=call_delegate_inputs_dict,
- outputs_dict=outputs_dict,
- backend_config=executorch_config,
- constant_methods=llama_instance_list[1].llama_meta, # kv method meta
- )
- with open(f"{args.artifact}/{pte_filename}.pte", "wb") as file:
- exec_prog.write_to_file(file)
- else:
+ for num in range(args.num_sharding - 1, -1, -1):
processed_bytes = []
- input_nodes_dict = {name: [] for name in graph_names}
- output_nodes_dict = {name: [] for name in graph_names}
for prog, graph_name in zip(exported_programs, graph_names):
processed_bytes.append(
- prog.graph_module.lowered_module_0.processed_bytes
+ getattr(prog.graph_module, f"lowered_module_{num}").processed_bytes
)
- input_nodes_dict[graph_name] = [
- node
+ call_delegate_node = [
+ list(node.users.keys())[0]
for node in prog.graph_module.graph.nodes
- if node.op == "placeholder"
+ if node.op == "get_attr" and node.name == f"lowered_module_{num}"
]
- output_nodes_dict[graph_name] = [
+ input_nodes_dict[graph_name] = [
node
- for node in prog.graph_module.graph.nodes
- if node.op == "output"
+ for node in call_delegate_node[0].args
+ if node.op == "placeholder"
+ or node.target
+ == exir_ops.edge.quantized_decomposed.embedding_4bit.dtype
]
-
- prog_mgr, _ = generate_multi_graph_program(
+ prog_mgr, bundle_progs = generate_multi_graph_program(
compiler_specs=compiler_specs[0],
processed_bytes=processed_bytes,
input_nodes_dict=input_nodes_dict,
- output_nodes_dict=output_nodes_dict,
backend_config=executorch_config,
- constant_methods=llama_instance_list[1].llama_meta, # kv method meta
+ constant_methods=llama_instance_list[0].llama_meta, # kv method meta
)
- with open(f"{args.artifact}/{pte_filename}.pte", "wb") as file:
- prog_mgr.write_to_file(file)
+ bundle_progs_list.append(bundle_progs)
+ for graph_name in graph_names:
+ lower_module_dict[graph_name].append(
+ prog_mgr.exported_program(graph_name).graph_module._modules.get(
+ "lowered_module_0"
+ )
+ )
+ exec_prog = generate_composite_llama_program(
+ llama_model=llama_instance_list[1].llama_model,
+ graph_names=graph_names,
+ sample_inputs_list=sample_inputs_list,
+ lower_module_dict=lower_module_dict,
+ call_delegate_node_name_dict=call_delegate_node_name_dict,
+ call_delegate_inputs_dict=call_delegate_inputs_dict,
+ outputs_dict=outputs_dict,
+ embedding_quantize=args.embedding_quantize,
+ backend_config=executorch_config,
+ constant_methods=llama_instance_list[1].llama_meta, # kv method meta
+ )
+ with open(f"{args.artifact}/{pte_filename}.pte", "wb") as file:
+ exec_prog.write_to_file(file)
end_lowering_ts = time.time()
logging.info(f"Time for compiling: {end_lowering_ts - start_lowering_ts}")
return quant_attrs
-def inference(args, quant_attrs, pte_filename, pre_gen_pte=""):
+def inference(args, quant_attrs, pte_filename, runtime_tokenizer_path, pre_gen_pte=""):
workspace = f"/data/local/tmp/{getpass.getuser()}/executorch/single_llama"
if args.model_mode == "prefill":
@@ -682,14 +777,14 @@ def inference(args, quant_attrs, pte_filename, pre_gen_pte=""):
elif args.model_mode == "hybrid":
eval_mode = 2
else:
- raise RuntimeError(f"No such model_mode {args.model_mode}.")
+ raise RuntimeError(f"Unknown model_mode: {args.model_mode}.")
seq_len = args.prefill_seq_len if args.model_mode == "prefill" else args.kv_seq_len
runner_args = " ".join(
[
f"--model_path {pte_filename}.pte",
"--output_path outputs/outputs.txt",
- f"--tokenizer_path {os.path.basename(args.tokenizer_model)}",
+ f"--tokenizer_path {os.path.basename(runtime_tokenizer_path)}",
f'--prompt "{args.prompt}"',
f"--seq_len {seq_len}",
f"--eval_mode {eval_mode}",
@@ -697,12 +792,13 @@ def inference(args, quant_attrs, pte_filename, pre_gen_pte=""):
f"--system_prompt '{args.system_prompt}'",
f"--logits_scale {quant_attrs['scale']}",
f"--logits_offset {quant_attrs['zero_point']}",
+ f"--kv_updator {'SmartMask' if args.kv_updator == smart_mask_updator else 'ShiftPointer'}",
]
)
runner_cmd = " ".join(
[
f"cd {workspace} &&",
- f"./qnn_llama3_2_runner {runner_args}",
+ f"./qnn_llama_runner {runner_args}",
]
)
@@ -720,10 +816,10 @@ def inference(args, quant_attrs, pte_filename, pre_gen_pte=""):
host_id=args.host,
soc_model=args.model,
shared_buffer=args.shared_buffer,
- runner=f"examples/qualcomm/oss_scripts/llama3_2/qnn_llama3_2_runner",
+ runner=f"examples/qualcomm/oss_scripts/llama/qnn_llama_runner",
)
# No pregen inputs, input_list is not required
- adb.push(inputs=[], input_list="", files=[args.tokenizer_model])
+ adb.push(inputs=[], input_list="", files=[runtime_tokenizer_path])
adb.execute(custom_runner_cmd=runner_cmd)
# collect output data
@@ -751,13 +847,13 @@ def post_process():
logging.info(f"Results[{idx}]:\n{output}")
-def main():
+def _build_parser():
parser = setup_common_args_and_variables()
parser.add_argument(
"-a",
"--artifact",
- help="path for storing generated artifacts and output by this example. Default ./llama3_2_qnn",
- default="./llama3_2_qnn",
+ help="path for storing generated artifacts and output by this example. Default ./llama_qnn",
+ default="./llama_qnn",
type=str,
)
@@ -768,6 +864,13 @@ def main():
type=str,
)
+ parser.add_argument(
+ "--llama_model",
+ choices=["stories110m", "llama3_2"],
+ help="The Llama model to export. Current available options are: [stories110m, llama3_2]",
+ required=True,
+ )
+
parser.add_argument(
"--checkpoint",
help="Pass llama checkpoint.",
@@ -783,10 +886,9 @@ def main():
)
parser.add_argument(
- "--model_size",
- help="Determine what runner be used. For llama 3.2, we only support 1B/3B. ",
- choices=["1B", "3B"],
- required=True,
+ "--tokenizer_bin",
+ help="For Llama2. Pass Llama2 tokenizer binary.",
+ required=False,
type=str,
)
@@ -806,7 +908,7 @@ def main():
parser.add_argument(
"--system_prompt",
- help="Tells the model what kind of assistant it should be. For example, You are a helpful AI assistant for travel tips and recommendations. Default is None",
+ help="For Llama3. Tells the model what kind of assistant it should be. For example, You are a helpful AI assistant for travel tips and recommendations. Default is None",
default="",
type=str,
)
@@ -829,14 +931,14 @@ def main():
parser.add_argument(
"--pre_gen_pte",
- help="Run the Pre-generated llama in the given directory",
+ help="Run the pre-generated llama in the given directory.",
type=str,
)
parser.add_argument(
"--num_sharding",
type=int,
- default=0,
+ default=1,
help="Specify the number of splits by inserting the fallback custom op. The graph will be split evenly by layers.",
)
@@ -862,31 +964,81 @@ def main():
type=int,
)
- args = parser.parse_args()
+ parser.add_argument(
+ "--kv_updator",
+ help="Choose how to update kv cache during runtime",
+ choices=["smart_mask", "shift_pointer"],
+ default="smart_mask",
+ type=str,
+ )
+
+ parser.add_argument(
+ "-E",
+ "--embedding-quantize",
+ default=None,
+ type=str,
+ help="Fallback to cpu embedding operator and type of embedding quantization, ',', e.g., '4,32'.",
+ )
+
+ return parser
+
+
+def main(args) -> None:
+ parser = _build_parser()
+
+ args = parser.parse_args(args)
if args.compile_only and args.pre_gen_pte:
exit("Cannot set both compile_only and pre_gen_pte as true")
if args.model_mode == "kv":
- pte_filename = "kv_llama3_2_qnn"
+ pte_filename = "kv_llama_qnn"
elif args.model_mode == "prefill":
- pte_filename = "prefill_llama3_2_qnn"
+ pte_filename = "prefill_llama_qnn"
elif args.model_mode == "hybrid":
assert (
args.kv_seq_len >= args.prefill_seq_len
), "Please ensure kv_seq_len is >= prefill_seq_len"
- pte_filename = "hybrid_llama3_2_qnn"
+ pte_filename = "hybrid_llama_qnn"
else:
- raise RuntimeError(f"No such model_mode {args.model_mode}.")
+ raise RuntimeError(f"Unknown model_mode: {args.model_mode}.")
+
+ tokenizer = get_tokenizer(args.tokenizer_model)
+ runtime_tokenizer_path = ""
+ if args.llama_model == "stories110m":
+ assert isinstance(
+ tokenizer, SentencePieceTokenizer
+ ), f"Wrong tokenizer provided for stories110m."
+ assert (
+ args.tokenizer_bin is not None
+ ), "Please provide tokenizer_bin for stories110m."
+ runtime_tokenizer_path = args.tokenizer_bin
+ elif args.llama_model == "llama3_2":
+ assert isinstance(
+ tokenizer, Tiktoken
+ ), f"Wrong tokenizer provided for llama3_2."
+ runtime_tokenizer_path = args.tokenizer_model
+ else:
+ raise RuntimeError(f"Unknown llama_model: {args.llama_model}.")
+
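+ # smart_mask is served from a shared buffer, so --shared_buffer is enabled automatically when it is selected.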
+ if args.kv_updator == "smart_mask":
+ args.shared_buffer = True
+ args.kv_updator = smart_mask_updator
+ elif args.kv_updator == "shift_pointer":
+ args.kv_updator = shift_pointer_updator
+ else:
+ exit(f"Using an unkown kv update {args.kv_updator}")
if args.pre_gen_pte:
quant_attrs = json.load(
open(f"{args.pre_gen_pte}/{pte_filename}_quant_attrs.txt")
)
- inference(args, quant_attrs, pte_filename, args.pre_gen_pte)
+ inference(
+ args, quant_attrs, pte_filename, runtime_tokenizer_path, args.pre_gen_pte
+ )
exit(f"Finish the running pre_gen_pte from {args.pre_gen_pte}")
if args.compile_only:
- quant_attrs = compile(args, pte_filename)
+ quant_attrs = compile(args, pte_filename, tokenizer)
if quant_attrs:
json.dump(
{
@@ -900,7 +1052,7 @@ def main():
exit(f"Finish compile_only and save to {args.artifact}")
try:
- quant_attrs = compile(args, pte_filename)
+ quant_attrs = compile(args, pte_filename, tokenizer)
if quant_attrs:
logging.info(
f"Logit scale: {quant_attrs['scale']}; Logit offset: {quant_attrs['zero_point']}"
@@ -914,7 +1066,7 @@ def main():
)
else:
logging.warning("Quant attributes of the logit is None.")
- inference(args, quant_attrs, pte_filename)
+ inference(args, quant_attrs, pte_filename, runtime_tokenizer_path)
except Exception as e:
if args.ip and args.port != -1:
with Client((args.ip, args.port)) as conn:
@@ -925,4 +1077,4 @@ def main():
# flake8: noqa: C901
if __name__ == "__main__":
- main()
+ main(sys.argv[1:])
diff --git a/examples/qualcomm/oss_scripts/llama2/model/static_llama.py b/examples/qualcomm/oss_scripts/llama/model/static_llama.py
similarity index 97%
rename from examples/qualcomm/oss_scripts/llama2/model/static_llama.py
rename to examples/qualcomm/oss_scripts/llama/model/static_llama.py
index d1b618ed07..253abc9578 100755
--- a/examples/qualcomm/oss_scripts/llama2/model/static_llama.py
+++ b/examples/qualcomm/oss_scripts/llama/model/static_llama.py
@@ -12,10 +12,8 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
-from executorch.examples.models.llama.llama_transformer import (
- ModelArgs,
- precompute_freqs_cis,
-)
+from executorch.examples.models.llama.model_args import ModelArgs
+from executorch.examples.models.llama.rope import precompute_freqs_cis
def apply_rotary_emb_single(
@@ -299,7 +297,9 @@ def forward(
class LlamaModel(nn.Module):
- def __init__(self, config: ModelArgs, output_new_cache_only=True):
+ def __init__(
+ self, config: ModelArgs, output_new_cache_only=True, use_i64_token=False
+ ):
super().__init__()
self.dim = config.dim
self.head_dim = config.dim // config.n_heads
@@ -312,6 +312,7 @@ def __init__(self, config: ModelArgs, output_new_cache_only=True):
self.rope_freq_base = config.rope_freq_base
self.use_kv_cache = config.use_kv_cache
self.output_new_cache_only = output_new_cache_only
+ self.use_i64_token = use_i64_token
self.layers = nn.ModuleList(
[
@@ -390,10 +391,12 @@ def forward(
return logits, output_k_cache, output_v_cache
def get_example_inputs(self, use_kv_cache=True):
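+ # Token ids are int64 only when the embedding is quantized on CPU (use_i64_token); otherwise int32 is used.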
+ dtype = torch.int64 if self.use_i64_token else torch.int32
if use_kv_cache:
tokens = torch.randint(
- self.vocab_size, (self.max_batch_size, 1), dtype=torch.int32
+ self.vocab_size, (self.max_batch_size, 1), dtype=dtype
)
+
pos_ids = torch.zeros((self.max_batch_size, 1), dtype=torch.int32)
k_cache, v_cache = [], []
atten_mask = torch.full((self.max_batch_size, self.max_seq_len), -255.0)
@@ -424,7 +427,7 @@ def get_example_inputs(self, use_kv_cache=True):
)
max_promp = self.max_seq_len - 1
- tokens = torch.arange(0, max_promp, 1, dtype=torch.int32).unsqueeze(0)
+ tokens = torch.arange(0, max_promp, 1, dtype=dtype).unsqueeze(0)
atten_mask = torch.triu(torch.rand((max_promp, max_promp)), 1)
atten_mask[atten_mask != 0] = -255
return (
diff --git a/examples/qualcomm/oss_scripts/llama3_2/qnn_llama3_2_runner.cpp b/examples/qualcomm/oss_scripts/llama/qnn_llama_runner.cpp
similarity index 83%
rename from examples/qualcomm/oss_scripts/llama3_2/qnn_llama3_2_runner.cpp
rename to examples/qualcomm/oss_scripts/llama/qnn_llama_runner.cpp
index 2af882580e..1bc90a11f9 100644
--- a/examples/qualcomm/oss_scripts/llama3_2/qnn_llama3_2_runner.cpp
+++ b/examples/qualcomm/oss_scripts/llama/qnn_llama_runner.cpp
@@ -9,12 +9,13 @@
/**
* @file
*
- * This tool can run Llama3.2 1B/3B with Qualcomm AI Engine Direct.
+ * This tool can run Llama2 110M and Llama3.2 1B / 3B (WIP) with Qualcomm AI
+ * Engine Direct.
*
*/
#include
-#include <executorch/examples/qualcomm/oss_scripts/llama3_2/runner/runner.h>
+#include <executorch/examples/qualcomm/oss_scripts/llama/runner/runner.h>
#include
#include
#include
@@ -22,7 +23,7 @@
DEFINE_string(
model_path,
- "qnn_llama2.pte",
+ "kv_llama_qnn.pte",
"Model serialized in flatbuffer format.");
DEFINE_string(
@@ -42,14 +43,18 @@ DEFINE_double(
DEFINE_int32(
seq_len,
128,
- "Total number of tokens to generate (prompt + output). Defaults to max_seq_len. If the number of input tokens + seq_len > max_seq_len, the output will be truncated to max_seq_len tokens.");
+ "Total number of tokens to generate (prompt + output).");
DEFINE_int32(
eval_mode,
- 0,
+ 1,
"0: PromptProcessor(prefill) / 1: TokenGenerator(kv) / 2: HybridMode (prefill+kv)");
DEFINE_double(logits_scale, 0.0, "Logits scale");
DEFINE_int32(logits_offset, 0, "Logits offset");
+DEFINE_string(
+ kv_updator,
+ "How to update kv cache. Choose between SmartMask and ShiftPointer",
+ "SmartMask");
int main(int argc, char** argv) {
gflags::ParseCommandLineFlags(&argc, &argv, true);
@@ -61,7 +66,8 @@ int main(int argc, char** argv) {
FLAGS_logits_scale,
FLAGS_logits_offset,
FLAGS_temperature,
- FLAGS_eval_mode);
+ FLAGS_eval_mode,
+ FLAGS_kv_updator);
std::vector<char> buf;
buf.reserve(5 * FLAGS_seq_len); // assume each token is around 5 char
std::ofstream fout(FLAGS_output_path.c_str());
diff --git a/examples/qualcomm/oss_scripts/llama/runner/io_manager.cpp b/examples/qualcomm/oss_scripts/llama/runner/io_manager.cpp
new file mode 100644
index 0000000000..7992913a58
--- /dev/null
+++ b/examples/qualcomm/oss_scripts/llama/runner/io_manager.cpp
@@ -0,0 +1,1126 @@
+/*
+ * Copyright (c) Qualcomm Innovation Center, Inc.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include
+#include
+#include
+
+using executorch::aten::Tensor;
+using executorch::aten::TensorImpl;
+using executorch::extension::Module;
+using executorch::runtime::Error;
+using executorch::runtime::MemoryAllocator;
+using executorch::runtime::MethodMeta;
+using executorch::runtime::Result;
+using executorch::runtime::TensorInfo;
+
+namespace example {
+
+IoMgrBase::IoMgrBase(std::vector<std::shared_ptr<Module>>& modules)
+ : data_ptr_(nullptr, [](void*) {}), modules_(modules) {}
+
+IoMgrBase::~IoMgrBase() {}
+
+void* IoMgrBase::get_mutable_ptr() {
+ return data_ptr_.get();
+}
+
+std::vector<Tensor> IoMgrBase::get_input_tensors(
+ int shard_index,
+ const std::string& method_name) {
+ std::vector<Tensor> ret;
+ ret.reserve(input_tensors_[method_name][shard_index].size());
+ for (TensorImpl* impl : input_tensors_[method_name][shard_index]) {
+ ret.emplace_back(Tensor(impl));
+ }
+ return ret;
+}
+
+std::vector<Tensor> IoMgrBase::get_output_tensors(
+ int shard_index,
+ const std::string& method_name) {
+ std::vector<Tensor> ret;
+ ret.reserve(output_tensors_[method_name][shard_index].size());
+ for (TensorImpl* impl : output_tensors_[method_name][shard_index]) {
+ ret.emplace_back(Tensor(impl));
+ }
+ return ret;
+}
+
+ShiftPointerIoMgr::ShiftPointerIoMgr(
+ std::vector<std::shared_ptr<Module>>& modules,
+ int32_t prefill_cache_len,
+ int32_t kv_cache_len,
+ int32_t vocab_size,
+ int32_t num_layers,
+ int32_t head_dim,
+ int32_t num_heads,
+ EvalMode eval_mode,
+ const std::string& prefill_forward_name,
+ const std::string& kv_forward_name,
+ const bool use_int64_token)
+ : IoMgrBase(modules),
+ shard_layers_({num_layers}),
+ kv_cache_len_(kv_cache_len),
+ prefill_cache_len_(prefill_cache_len),
+ vocab_size_(vocab_size),
+ num_layers_(num_layers),
+ head_dim_(head_dim),
+ num_heads_(num_heads),
+ eval_mode_(eval_mode),
+ prefill_forward_name_(prefill_forward_name),
+ kv_forward_name_(kv_forward_name),
+ use_int64_token_(use_int64_token) {
+ if (!prefill_forward_name_.empty()) {
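+ // One entry per module (graph shard) for the prefill method's input/output tensors.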
+ input_tensors_[prefill_forward_name_] =
+ std::vector<std::vector<TensorImpl*>>(modules.size());
+ output_tensors_[prefill_forward_name_] =
+ std::vector<std::vector<TensorImpl*>>(modules.size());
+ k_cache_in_[prefill_forward_name_] =
+ std::vector<std::unique_ptr<TensorImpl>>();
+ v_cache_in_[prefill_forward_name_] =
+ std::vector<std::unique_ptr<TensorImpl>>();
+ k_cache_out_[prefill_forward_name_] =
+ std::vector<std::unique_ptr<TensorImpl>>();
+ v_cache_out_[prefill_forward_name_] =
+ std::vector<std::unique_ptr<TensorImpl>>();
+ }
+ if (!kv_forward_name_.empty()) {
+ input_tensors_[kv_forward_name_] =
+ std::vector<std::vector<TensorImpl*>>(modules.size());
+ output_tensors_[kv_forward_name_] =
+ std::vector