From ffce22ade64b9e090014594f7b417633326add58 Mon Sep 17 00:00:00 2001 From: Uros Males Date: Mon, 27 Jan 2025 13:36:16 +0100 Subject: [PATCH 01/13] Tests for dot_general op (#175) --- tests/jax/ops/test_dot_general.py | 46 +++++++++++++++++++++++++++---- 1 file changed, 40 insertions(+), 6 deletions(-) diff --git a/tests/jax/ops/test_dot_general.py b/tests/jax/ops/test_dot_general.py index b8f62bb1..d8001934 100644 --- a/tests/jax/ops/test_dot_general.py +++ b/tests/jax/ops/test_dot_general.py @@ -8,17 +8,51 @@ from infra import run_op_test_with_random_inputs +# Tests for dot_general op where vectors containing indices of contracting dimensions +# are of size 1 and are equal. In training models, besides cases that correspond to matmul, +# this is the most common one we have. @pytest.mark.parametrize( ["x_shape", "y_shape"], [ - [(32, 32), (32, 32)], - [(64, 64), (64, 64)], - [(32, 64), (64, 32)], - [(64, 32), (32, 64)], + [(1, 32), (1, 32)], + [(1, 32, 64), (1, 32, 32)], + [(2, 32, 64), (2, 32, 64)], + [(2, 16, 32, 64), (2, 16, 64, 32)], ], ) -def test_dot_general(x_shape: tuple, y_shape: tuple): +def test_dot_general_common(x_shape: tuple, y_shape: tuple): def dot_general(x: jax.Array, y: jax.Array) -> jax.Array: - return jnp.dot(x, y) + return jax.lax.dot_general(x, y, dimension_numbers=((1, 1), (0, 0))) + + run_op_test_with_random_inputs(dot_general, [x_shape, y_shape]) + + +# Tests for dot_general op where this operation corresponds to regular matmul. +@pytest.mark.parametrize( + ["x_shape", "y_shape"], + [ + [(1, 32, 64), (1, 64, 32)], + [(2, 32, 64), (2, 64, 64)], + ], +) +def test_dot_general_matmul(x_shape: tuple, y_shape: tuple): + def dot_general(x: jax.Array, y: jax.Array) -> jax.Array: + return jax.lax.dot_general(x, y, dimension_numbers=((2, 1), (0, 0))) + + run_op_test_with_random_inputs(dot_general, [x_shape, y_shape]) + + +# Tests for dot_general op where vectors containing indices of +# contracting dimensions are of size greater than 1. +@pytest.mark.parametrize( + ["x_shape", "y_shape"], + [ + [(1, 16, 16, 8), (1, 16, 8, 16)], + [(2, 8, 8, 16), (2, 8, 16, 8)], + ], +) +def test_dot_general_multiple_contract(x_shape: tuple, y_shape: tuple): + def dot_general(x: jax.Array, y: jax.Array) -> jax.Array: + return jax.lax.dot_general(x, y, dimension_numbers=(((1, 3), (1, 2)), (0, 0))) run_op_test_with_random_inputs(dot_general, [x_shape, y_shape]) From 35b2f4d8a9b94f9285a0defc1df1d16f90863607 Mon Sep 17 00:00:00 2001 From: Marko Rakita Date: Mon, 27 Jan 2025 16:19:23 +0100 Subject: [PATCH 02/13] Add BART model tests (#191) Added tests for HF BART base and large model variant with a language modeling head on top. --- tests/jax/models/bart/__init__.py | 0 tests/jax/models/bart/base/__init__.py | 0 tests/jax/models/bart/base/test_bart_base.py | 40 +++++++++++++++++ tests/jax/models/bart/large/__init__.py | 0 .../jax/models/bart/large/test_bart_large.py | 40 +++++++++++++++++ tests/jax/models/bart/tester.py | 43 +++++++++++++++++++ 6 files changed, 123 insertions(+) create mode 100644 tests/jax/models/bart/__init__.py create mode 100644 tests/jax/models/bart/base/__init__.py create mode 100644 tests/jax/models/bart/base/test_bart_base.py create mode 100644 tests/jax/models/bart/large/__init__.py create mode 100644 tests/jax/models/bart/large/test_bart_large.py create mode 100644 tests/jax/models/bart/tester.py diff --git a/tests/jax/models/bart/__init__.py b/tests/jax/models/bart/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/jax/models/bart/base/__init__.py b/tests/jax/models/bart/base/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/jax/models/bart/base/test_bart_base.py b/tests/jax/models/bart/base/test_bart_base.py new file mode 100644 index 00000000..d6989251 --- /dev/null +++ b/tests/jax/models/bart/base/test_bart_base.py @@ -0,0 +1,40 @@ +# SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +# +# SPDX-License-Identifier: Apache-2.0 + +import pytest +from infra import RunMode + +from ..tester import FlaxBartForCausalLMTester + +MODEL_PATH = "facebook/bart-base" + + +# ----- Fixtures ----- + + +@pytest.fixture +def inference_tester() -> FlaxBartForCausalLMTester: + return FlaxBartForCausalLMTester(MODEL_PATH) + + +@pytest.fixture +def training_tester() -> FlaxBartForCausalLMTester: + return FlaxBartForCausalLMTester(MODEL_PATH, RunMode.TRAINING) + + +# ----- Tests ----- + + +@pytest.mark.xfail(reason="failed to legalize operation 'stablehlo.reduce'") +def test_flax_bart_base_inference( + inference_tester: FlaxBartForCausalLMTester, +): + inference_tester.test() + + +@pytest.mark.skip(reason="Support for training not implemented") +def test_flax_bart_base_training( + training_tester: FlaxBartForCausalLMTester, +): + training_tester.test() diff --git a/tests/jax/models/bart/large/__init__.py b/tests/jax/models/bart/large/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/jax/models/bart/large/test_bart_large.py b/tests/jax/models/bart/large/test_bart_large.py new file mode 100644 index 00000000..142641dc --- /dev/null +++ b/tests/jax/models/bart/large/test_bart_large.py @@ -0,0 +1,40 @@ +# SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +# +# SPDX-License-Identifier: Apache-2.0 + +import pytest +from infra import RunMode + +from ..tester import FlaxBartForCausalLMTester + +MODEL_PATH = "facebook/bart-large" + + +# ----- Fixtures ----- + + +@pytest.fixture +def inference_tester() -> FlaxBartForCausalLMTester: + return FlaxBartForCausalLMTester(MODEL_PATH) + + +@pytest.fixture +def training_tester() -> FlaxBartForCausalLMTester: + return FlaxBartForCausalLMTester(MODEL_PATH, RunMode.TRAINING) + + +# ----- Tests ----- + + +@pytest.mark.xfail(reason="failed to legalize operation 'stablehlo.reduce'") +def test_flax_bart_large_inference( + inference_tester: FlaxBartForCausalLMTester, +): + inference_tester.test() + + +@pytest.mark.skip(reason="Support for training not implemented") +def test_flax_bart_large_training( + training_tester: FlaxBartForCausalLMTester, +): + training_tester.test() diff --git a/tests/jax/models/bart/tester.py b/tests/jax/models/bart/tester.py new file mode 100644 index 00000000..282c163c --- /dev/null +++ b/tests/jax/models/bart/tester.py @@ -0,0 +1,43 @@ +# SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +# +# SPDX-License-Identifier: Apache-2.0 + +from typing import Dict, Sequence + +import jax +from flax import linen as nn +from infra import ComparisonConfig, ModelTester, RunMode +from transformers import AutoTokenizer, FlaxBartForCausalLM + + +class FlaxBartForCausalLMTester(ModelTester): + """Tester for BART model variants with a language modeling head on top.""" + + # TODO(mrakita): Add tests for other variants. + + def __init__( + self, + model_name: str, + comparison_config: ComparisonConfig = ComparisonConfig(), + run_mode: RunMode = RunMode.INFERENCE, + ) -> None: + self._model_name = model_name + super().__init__(comparison_config, run_mode) + + # @override + def _get_model(self) -> nn.Module: + return FlaxBartForCausalLM.from_pretrained(self._model_name, from_pt=True) + + # @override + def _get_input_activations(self) -> Sequence[jax.Array]: + tokenizer = AutoTokenizer.from_pretrained(self._model_name) + inputs = tokenizer("Hello", return_tensors="np") + return inputs["input_ids"] + + # @override + def _get_forward_method_kwargs(self) -> Dict[str, jax.Array]: + assert hasattr(self._model, "params") + return { + "params": self._model.params, + "input_ids": self._get_input_activations(), + } From 5280a81263a427911c7f743ab7c9576a264d43d3 Mon Sep 17 00:00:00 2001 From: Vladimir Milosevic <157983820+vmilosevic@users.noreply.github.com> Date: Mon, 27 Jan 2025 17:20:33 +0100 Subject: [PATCH 03/13] Adding PR template (#188) Adding PR template, This addition aims to help contributors consistently describe their changes and ensure our PRs remain well-documented. --- .github/pull_request_template.md | 12 ++++++++++++ 1 file changed, 12 insertions(+) create mode 100644 .github/pull_request_template.md diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md new file mode 100644 index 00000000..4a91d818 --- /dev/null +++ b/.github/pull_request_template.md @@ -0,0 +1,12 @@ +### Ticket +Link to Github Issue + +### Problem description +Provide context for the problem. + +### What's changed +Describe the approach used to solve the problem. +Summarize the changes made and its impact. + +### Checklist +- [ ] New/Existing tests provide coverage for changes From ad10ac621526193e22b4616be2ca0ebaee48103e Mon Sep 17 00:00:00 2001 From: Vladimir Milosevic <157983820+vmilosevic@users.noreply.github.com> Date: Tue, 28 Jan 2025 10:33:38 +0100 Subject: [PATCH 04/13] Correct the tt-mlir path in get-docker-tag.sh (#190) Script get-docker-tag.sh was incorrectly generating docker tag, it didn't take into account tt-mlir docker tag change. Correct the path to match path where cmake ExternalProject_Add clones tt-mlir to gix the docker tag. --- .github/get-docker-tag.sh | 26 ++++++++++++-------------- 1 file changed, 12 insertions(+), 14 deletions(-) diff --git a/.github/get-docker-tag.sh b/.github/get-docker-tag.sh index 9ed9f0f7..8dea07ba 100755 --- a/.github/get-docker-tag.sh +++ b/.github/get-docker-tag.sh @@ -8,26 +8,24 @@ # Exit immediately if a command exits with a non-zero status set -e - -# Execute this in a separate bash process -( - # Read tt-mlir version from third_party/CMakeLists.txt and clone third_party/tt-mlir +MLIR_DOCKER_TAG=$( + # Read tt-mlir version from third_party/CMakeLists.txt + # clone tt-mlir version to tmp/third_party/tt-mlir # Get the MLIR docker tag + TT_MLIR_PATH=tmp/third_party/tt-mlir TT_MLIR_VERSION=$(grep -oP 'set\(TT_MLIR_VERSION "\K[^"]+' third_party/CMakeLists.txt) - if [ ! -d "third_party/tt-mlir" ]; then - git clone https://github.com/tenstorrent/tt-mlir.git third_party/tt-mlir --quiet + if [ ! -d $TT_MLIR_PATH ]; then + git clone https://github.com/tenstorrent/tt-mlir.git $TT_MLIR_PATH --quiet fi - cd third_party/tt-mlir + cd $TT_MLIR_PATH git fetch --quiet git checkout $TT_MLIR_VERSION --quiet if [ -f ".github/get-docker-tag.sh" ]; then - MLIR_DOCKER_TAG=$(.github/get-docker-tag.sh) + .github/get-docker-tag.sh else - MLIR_DOCKER_TAG="default-tag" + echo "default-tag" fi - cd ../.. ) - -DOCKERFILE_HASH_FILES=".github/Dockerfile.base .github/Dockerfile.ci" -DOCKERFILE_HASH=$( (echo $MLIR_DOCKER_TAG; sha256sum $DOCKERFILE_HASH_FILES) | sha256sum | cut -d ' ' -f 1) -echo dt-$DOCKERFILE_HASH +DOCKERFILE_HASH=$( (cat .github/Dockerfile.base .github/Dockerfile.ci | sha256sum) | cut -d ' ' -f 1) +COMBINED_HASH=$( (echo $DOCKERFILE_HASH $MLIR_DOCKER_TAG | sha256sum) | cut -d ' ' -f 1) +echo dt-$COMBINED_HASH From 1e2191bb16a252241ed8ce588833436ecb2ed497 Mon Sep 17 00:00:00 2001 From: Vladimir Milosevic <157983820+vmilosevic@users.noreply.github.com> Date: Tue, 28 Jan 2025 15:17:00 +0100 Subject: [PATCH 05/13] Adding nightly test run (#195) ### Ticket N/A ### Problem description Adding nightly test run. For start it will run all tests same as PR workflow. Later we can add test marks and label tests as push or nightly to select which test group executes in a workflow. See https://github.com/tenstorrent/tt-forge-fe/blob/main/pytest.ini and https://github.com/tenstorrent/tt-forge-fe/blob/main/.github/workflows/build-and-test.yml ### What's changed Added nightly run scheduled workflow Added test_mark inputs for 'Build and Test' workflow ### Checklist - [x] New/Existing tests provide coverage for changes --- .github/workflows/build-and-test.yml | 13 +++++++++++++ .github/workflows/on-nightly.yml | 13 +++++++++++++ 2 files changed, 26 insertions(+) create mode 100644 .github/workflows/on-nightly.yml diff --git a/.github/workflows/build-and-test.yml b/.github/workflows/build-and-test.yml index bdf44a0a..482466b2 100644 --- a/.github/workflows/build-and-test.yml +++ b/.github/workflows/build-and-test.yml @@ -7,12 +7,25 @@ on: description: 'Git SHA of commit in tenstorrent/tt-mlir' required: false type: string + test_mark: + description: 'Test mark to run' + required: true + default: 'push' + type: choice + options: + - push + - nightly workflow_call: inputs: mlir_override: description: 'Git SHA of commit in tenstorrent/tt-mlir' required: false type: string + test_mark: + description: 'Test mark to run' + required: false + default: 'push' + type: string permissions: packages: write diff --git a/.github/workflows/on-nightly.yml b/.github/workflows/on-nightly.yml new file mode 100644 index 00000000..f1b90093 --- /dev/null +++ b/.github/workflows/on-nightly.yml @@ -0,0 +1,13 @@ +name: On nightly + +on: + workflow_dispatch: + schedule: + - cron: '0 0 * * *' + +jobs: + build-and-test: + uses: ./.github/workflows/build-and-test.yml + secrets: inherit + with: + test_mark: 'nightly' From 067af4e6e37a5d5abcf88da287311ce47ef96eda Mon Sep 17 00:00:00 2001 From: Stefan Gligorijevic <189116645+sgligorijevicTT@users.noreply.github.com> Date: Tue, 28 Jan 2025 18:25:03 +0100 Subject: [PATCH 06/13] Uplift mlir and add shardy to dependencies (#194) --- src/common/CMakeLists.txt | 4 ++++ third_party/CMakeLists.txt | 2 +- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/src/common/CMakeLists.txt b/src/common/CMakeLists.txt index bcc3b88a..42d8eed6 100644 --- a/src/common/CMakeLists.txt +++ b/src/common/CMakeLists.txt @@ -39,8 +39,10 @@ target_include_directories(TTPJRTCommon PUBLIC ${PROJECT_SOURCE_DIR}/third_party/tt-mlir/src/tt-mlir-build/include/ttmlir/Target/Common ${PROJECT_SOURCE_DIR}/third_party/tt-mlir/src/tt-mlir/include ${PROJECT_SOURCE_DIR}/third_party/tt-mlir/src/tt-mlir/runtime/include + ${PROJECT_SOURCE_DIR}/third_party/tt-mlir/src/tt-mlir-build/shardy/ ${PROJECT_SOURCE_DIR}/third_party/tt-mlir/src/tt-mlir-build/stablehlo/ ${TTMLIR_TOOLCHAIN_DIR}/include + ${TTMLIR_TOOLCHAIN_DIR}/src/shardy ${TTMLIR_TOOLCHAIN_DIR}/src/stablehlo ) @@ -63,6 +65,8 @@ ChloOps Version VhloOps VhloTypes +SdyDialect +SdyRegister StablehloOps StablehloRegister StablehloReferenceToken diff --git a/third_party/CMakeLists.txt b/third_party/CMakeLists.txt index 342853cf..b0ddd8b2 100644 --- a/third_party/CMakeLists.txt +++ b/third_party/CMakeLists.txt @@ -3,7 +3,7 @@ # SPDX-License-Identifier: Apache-2.0 # -set(TT_MLIR_VERSION "ad93092841542d104b8bacba6b92d97678ccd94e") +set(TT_MLIR_VERSION "e581e6a815de6b0ce3c4ceedc70620d19be29969") set(LOGURU_VERSION "4adaa185883e3c04da25913579c451d3c32cfac1") if (TOOLCHAIN STREQUAL "ON") From c90d7c66aa1c288b2563282a9790c9c7b84d3d5f Mon Sep 17 00:00:00 2001 From: Marko Rakita Date: Wed, 29 Jan 2025 11:53:43 +0100 Subject: [PATCH 07/13] Add roberta-large model test (#192) ### Problem description Roberta large model variant test was missing. ### What's changed - Added a roberta-large model test - Fixed few small bugs in some other model tests ### Checklist - [x] New/Existing tests provide coverage for changes --- tests/jax/models/bart/tester.py | 2 +- tests/jax/models/gpt2/test_gpt2.py | 4 +- tests/jax/models/mlpmixer/test_mlpmixer.py | 2 +- tests/jax/models/mnist/cnn/test_mnist_cnn.py | 4 +- tests/jax/models/roberta/__init__.py | 0 tests/jax/models/roberta/base/__init__.py | 0 .../models/roberta/base/test_roberta_base.py | 40 ++++++++++++++ tests/jax/models/roberta/large/__init__.py | 0 .../roberta/large/test_roberta_large.py | 40 ++++++++++++++ .../roberta/{test_roberta.py => tester.py} | 52 +++++-------------- .../models/squeezebert/test_squeezebert.py | 6 +-- 11 files changed, 103 insertions(+), 47 deletions(-) create mode 100644 tests/jax/models/roberta/__init__.py create mode 100644 tests/jax/models/roberta/base/__init__.py create mode 100644 tests/jax/models/roberta/base/test_roberta_base.py create mode 100644 tests/jax/models/roberta/large/__init__.py create mode 100644 tests/jax/models/roberta/large/test_roberta_large.py rename tests/jax/models/roberta/{test_roberta.py => tester.py} (50%) diff --git a/tests/jax/models/bart/tester.py b/tests/jax/models/bart/tester.py index 282c163c..d6681adf 100644 --- a/tests/jax/models/bart/tester.py +++ b/tests/jax/models/bart/tester.py @@ -26,7 +26,7 @@ def __init__( # @override def _get_model(self) -> nn.Module: - return FlaxBartForCausalLM.from_pretrained(self._model_name, from_pt=True) + return FlaxBartForCausalLM.from_pretrained(self._model_name) # @override def _get_input_activations(self) -> Sequence[jax.Array]: diff --git a/tests/jax/models/gpt2/test_gpt2.py b/tests/jax/models/gpt2/test_gpt2.py index 40d6c2bb..988de9d9 100644 --- a/tests/jax/models/gpt2/test_gpt2.py +++ b/tests/jax/models/gpt2/test_gpt2.py @@ -2,7 +2,7 @@ # # SPDX-License-Identifier: Apache-2.0 -from typing import Sequence, Dict +from typing import Dict, Sequence import jax import pytest @@ -52,7 +52,7 @@ def training_tester() -> GPT2Tester: @pytest.mark.xfail(reason="failed to legalize operation 'stablehlo.reduce'") -def test_gp2_inference( +def test_gpt2_inference( inference_tester: GPT2Tester, ): inference_tester.test() diff --git a/tests/jax/models/mlpmixer/test_mlpmixer.py b/tests/jax/models/mlpmixer/test_mlpmixer.py index 469e0688..e3fd985c 100644 --- a/tests/jax/models/mlpmixer/test_mlpmixer.py +++ b/tests/jax/models/mlpmixer/test_mlpmixer.py @@ -95,7 +95,7 @@ def training_tester() -> MlpMixerTester: "(https://github.com/tenstorrent/tt-xla/issues/187)" ) ) # segfault -def test_mlpmixer(inference_tester: MlpMixerTester): +def test_mlpmixer_inference(inference_tester: MlpMixerTester): inference_tester.test() diff --git a/tests/jax/models/mnist/cnn/test_mnist_cnn.py b/tests/jax/models/mnist/cnn/test_mnist_cnn.py index 88004ec1..912c66ac 100644 --- a/tests/jax/models/mnist/cnn/test_mnist_cnn.py +++ b/tests/jax/models/mnist/cnn/test_mnist_cnn.py @@ -68,14 +68,14 @@ def training_tester() -> MNISTCNNTester: @pytest.mark.skip( reason='void mlir::OperationConverter::finalize(mlir::ConversionPatternRewriter &): Assertion `newValue && "replacement value not found"\' failed.' ) # This is a segfault, marking it as xfail would bring down the whole test suite -def test_mnist_inference( +def test_mnist_cnn_inference( inference_tester: MNISTCNNTester, ): inference_tester.test() @pytest.mark.skip(reason="Support for training not implemented") -def test_mnist_training( +def test_mnist_cnn_training( training_tester: MNISTCNNTester, ): training_tester.test() diff --git a/tests/jax/models/roberta/__init__.py b/tests/jax/models/roberta/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/jax/models/roberta/base/__init__.py b/tests/jax/models/roberta/base/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/jax/models/roberta/base/test_roberta_base.py b/tests/jax/models/roberta/base/test_roberta_base.py new file mode 100644 index 00000000..9969204a --- /dev/null +++ b/tests/jax/models/roberta/base/test_roberta_base.py @@ -0,0 +1,40 @@ +# SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +# +# SPDX-License-Identifier: Apache-2.0 + +import pytest +from infra import RunMode + +from ..tester import FlaxRobertaForMaskedLMTester + +MODEL_PATH = "FacebookAI/roberta-base" + + +# ----- Fixtures ----- + + +@pytest.fixture +def inference_tester() -> FlaxRobertaForMaskedLMTester: + return FlaxRobertaForMaskedLMTester(MODEL_PATH) + + +@pytest.fixture +def training_tester() -> FlaxRobertaForMaskedLMTester: + return FlaxRobertaForMaskedLMTester(MODEL_PATH, RunMode.TRAINING) + + +# ----- Tests ----- + + +@pytest.mark.xfail(reason="failed to legalize operation 'stablehlo.reduce_window'") +def test_flax_roberta_base_inference( + inference_tester: FlaxRobertaForMaskedLMTester, +): + inference_tester.test() + + +@pytest.mark.skip(reason="Support for training not implemented") +def test_flax_roberta_base_training( + training_tester: FlaxRobertaForMaskedLMTester, +): + training_tester.test() diff --git a/tests/jax/models/roberta/large/__init__.py b/tests/jax/models/roberta/large/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/jax/models/roberta/large/test_roberta_large.py b/tests/jax/models/roberta/large/test_roberta_large.py new file mode 100644 index 00000000..e1f94d03 --- /dev/null +++ b/tests/jax/models/roberta/large/test_roberta_large.py @@ -0,0 +1,40 @@ +# SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +# +# SPDX-License-Identifier: Apache-2.0 + +import pytest +from infra import RunMode + +from ..tester import FlaxRobertaForMaskedLMTester + +MODEL_PATH = "FacebookAI/roberta-large" + + +# ----- Fixtures ----- + + +@pytest.fixture +def inference_tester() -> FlaxRobertaForMaskedLMTester: + return FlaxRobertaForMaskedLMTester(MODEL_PATH) + + +@pytest.fixture +def training_tester() -> FlaxRobertaForMaskedLMTester: + return FlaxRobertaForMaskedLMTester(MODEL_PATH, RunMode.TRAINING) + + +# ----- Tests ----- + + +@pytest.mark.xfail(reason="failed to legalize operation 'stablehlo.reduce_window'") +def test_flax_roberta_large_inference( + inference_tester: FlaxRobertaForMaskedLMTester, +): + inference_tester.test() + + +@pytest.mark.skip(reason="Support for training not implemented") +def test_flax_roberta_large_training( + training_tester: FlaxRobertaForMaskedLMTester, +): + training_tester.test() diff --git a/tests/jax/models/roberta/test_roberta.py b/tests/jax/models/roberta/tester.py similarity index 50% rename from tests/jax/models/roberta/test_roberta.py rename to tests/jax/models/roberta/tester.py index 32542cd5..71fc3ad4 100644 --- a/tests/jax/models/roberta/test_roberta.py +++ b/tests/jax/models/roberta/tester.py @@ -5,26 +5,32 @@ from typing import Dict, Sequence import jax -import pytest from flax import linen as nn -from infra import ModelTester, RunMode +from infra import ComparisonConfig, ModelTester, RunMode from transformers import AutoTokenizer, FlaxRobertaForMaskedLM -MODEL_PATH = "FacebookAI/roberta-base" - -# ----- Tester ----- - class FlaxRobertaForMaskedLMTester(ModelTester): """Tester for Roberta model on a masked language modeling task.""" + # TODO(mrakita): Add tests for other variants. + + def __init__( + self, + model_name: str, + comparison_config: ComparisonConfig = ComparisonConfig(), + run_mode: RunMode = RunMode.INFERENCE, + ) -> None: + self._model_name = model_name + super().__init__(comparison_config, run_mode) + # @override def _get_model(self) -> nn.Module: - return FlaxRobertaForMaskedLM.from_pretrained(MODEL_PATH) + return FlaxRobertaForMaskedLM.from_pretrained(self._model_name) # @override def _get_input_activations(self) -> Sequence[jax.Array]: - tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH) + tokenizer = AutoTokenizer.from_pretrained(self._model_name) inputs = tokenizer("Hello .", return_tensors="np") return inputs["input_ids"] @@ -39,33 +45,3 @@ def _get_forward_method_kwargs(self) -> Dict[str, jax.Array]: # @ override def _get_static_argnames(self): return ["train"] - - -# ----- Fixtures ----- - - -@pytest.fixture -def inference_tester() -> FlaxRobertaForMaskedLMTester: - return FlaxRobertaForMaskedLMTester() - - -@pytest.fixture -def training_tester() -> FlaxRobertaForMaskedLMTester: - return FlaxRobertaForMaskedLMTester(RunMode.TRAINING) - - -# ----- Tests ----- - - -@pytest.mark.xfail(reason="failed to legalize operation 'stablehlo.reduce_window'") -def test_roberta_inference( - inference_tester: FlaxRobertaForMaskedLMTester, -): - inference_tester.test() - - -@pytest.mark.skip(reason="Support for training not implemented") -def test_flax_roberta_training( - training_tester: FlaxRobertaForMaskedLMTester, -): - training_tester.test() diff --git a/tests/jax/models/squeezebert/test_squeezebert.py b/tests/jax/models/squeezebert/test_squeezebert.py index fe98f0cc..2362066b 100644 --- a/tests/jax/models/squeezebert/test_squeezebert.py +++ b/tests/jax/models/squeezebert/test_squeezebert.py @@ -72,15 +72,15 @@ def training_tester() -> SqueezeBertTester: # ----- Tests ----- -@pytest.mark.skip(reason="failed to legalize operation 'stablehlo.dot_general'") -def test_flax_distilbert_inference( +@pytest.mark.xfail(reason="failed to legalize operation 'stablehlo.reduce'") +def test_squeezebert_inference( inference_tester: SqueezeBertTester, ): inference_tester.test() @pytest.mark.skip(reason="Support for training not implemented") -def test_flax_distilbert_training( +def test_squeezebert_training( training_tester: SqueezeBertTester, ): training_tester.test() From 977757bc281b50ce69059194ddfa0b4059bea843 Mon Sep 17 00:00:00 2001 From: Vladimir Milosevic <157983820+vmilosevic@users.noreply.github.com> Date: Wed, 29 Jan 2025 16:33:10 +0100 Subject: [PATCH 08/13] Upload the code coverage and test reports to codecov (#199) ### Ticket N/A ### Problem description We want to visualize and track history of code coverage using codecov platform ### What's changed Added steps to build and test workflow to upload code coverage and test reports to codecov Upload code coverage only from n300 run Fiter coverage stats to show only tt-xla/src image ### Checklist - [x] New/Existing tests provide coverage for changes --- .github/workflows/build-and-test.yml | 20 +++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/.github/workflows/build-and-test.yml b/.github/workflows/build-and-test.yml index 482466b2..3466d8de 100644 --- a/.github/workflows/build-and-test.yml +++ b/.github/workflows/build-and-test.yml @@ -264,7 +264,25 @@ jobs: group_suite: true - name: Prepare code coverage report + if: matrix.build.runs-on == 'n300' && (success() || failure()) run: | lcov --directory build --capture --output-file coverage.info - lcov --extract coverage.info '**/src/*' --output-file coverage.info + lcov --extract coverage.info '**/tt-xla/src/*' --output-file coverage.info + sed -i 's|SF:/__w/tt-xla/tt-xla/src/|SF:src/|' coverage.info lcov --list coverage.info + + - name: Upload coverage reports to Codecov + if: matrix.build.runs-on == 'n300' && (success() || failure()) + uses: codecov/codecov-action@v5 + with: + files: coverage.info + disable_search: true + token: ${{ secrets.CODECOV_TOKEN }} + + - name: Upload test results to Codecov + if: success() || failure() + uses: codecov/test-results-action@v1 + with: + files: ${{ steps.strings.outputs.test_report_path }} + disable_search: true + token: ${{ secrets.CODECOV_TOKEN }} From 27c15d4181a975ed8de53f3943ec0b548d36fd6b Mon Sep 17 00:00:00 2001 From: Kristijan Mitrovic Date: Thu, 30 Jan 2025 09:54:23 +0100 Subject: [PATCH 09/13] Updated ttmlir version, skipped convert op tests (#208) --- tests/jax/ops/test_convert.py | 1 + third_party/CMakeLists.txt | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/jax/ops/test_convert.py b/tests/jax/ops/test_convert.py index 252e62de..c4cf7ba5 100644 --- a/tests/jax/ops/test_convert.py +++ b/tests/jax/ops/test_convert.py @@ -30,6 +30,7 @@ "float64", ], ) +@pytest.mark.xfail(reason="https://github.com/tenstorrent/tt-xla/issues/206") def test_convert(from_dtype: DTypeLike, to_dtype: DTypeLike): def convert(x: jax.Array) -> jax.Array: return jlx.convert_element_type(x, new_dtype=jnp.dtype(to_dtype)) diff --git a/third_party/CMakeLists.txt b/third_party/CMakeLists.txt index b0ddd8b2..46dcef9f 100644 --- a/third_party/CMakeLists.txt +++ b/third_party/CMakeLists.txt @@ -3,7 +3,7 @@ # SPDX-License-Identifier: Apache-2.0 # -set(TT_MLIR_VERSION "e581e6a815de6b0ce3c4ceedc70620d19be29969") +set(TT_MLIR_VERSION "2e45a8f777989e356ce857a59fccc21fc733ee0e") set(LOGURU_VERSION "4adaa185883e3c04da25913579c451d3c32cfac1") if (TOOLCHAIN STREQUAL "ON") From 86db9b3c4a3734951523fbec371b07c34de50804 Mon Sep 17 00:00:00 2001 From: Kristijan Mitrovic Date: Thu, 30 Jan 2025 19:24:17 +0100 Subject: [PATCH 10/13] Added ttir to linalg library to cmake (#217) Uplifted ttmlir to affea5d63684658ee263a359b8904b433f5edf21 --- src/common/CMakeLists.txt | 1 + third_party/CMakeLists.txt | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/src/common/CMakeLists.txt b/src/common/CMakeLists.txt index 42d8eed6..1ca6f741 100644 --- a/src/common/CMakeLists.txt +++ b/src/common/CMakeLists.txt @@ -112,6 +112,7 @@ target_link_libraries(TTPJRTCommon PUBLIC TTPJRTCommonDylibPlatform TTMLIRStatic TTMLIRTosaToTTIR + TTMLIRTTIRToLinalg MLIRTTIRPipelines TTMLIRStableHLOToTTIR ${STABLEHLO_LIBS} diff --git a/third_party/CMakeLists.txt b/third_party/CMakeLists.txt index 46dcef9f..f9c5963c 100644 --- a/third_party/CMakeLists.txt +++ b/third_party/CMakeLists.txt @@ -3,7 +3,7 @@ # SPDX-License-Identifier: Apache-2.0 # -set(TT_MLIR_VERSION "2e45a8f777989e356ce857a59fccc21fc733ee0e") +set(TT_MLIR_VERSION "affea5d63684658ee263a359b8904b433f5edf21") set(LOGURU_VERSION "4adaa185883e3c04da25913579c451d3c32cfac1") if (TOOLCHAIN STREQUAL "ON") From 820df6525e724e669f5487ba0d0ce8f313af0974 Mon Sep 17 00:00:00 2001 From: Vladimir Milosevic <157983820+vmilosevic@users.noreply.github.com> Date: Fri, 31 Jan 2025 01:23:48 +0100 Subject: [PATCH 11/13] Collect data for nightly jobs (#213) ### Ticket N/A ### Problem description Collect data for nightly jobs ### What's changed Added nightly workflow to collect data job ### Checklist - [ x] New/Existing tests provide coverage for changes --- .github/workflows/produce_data.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/produce_data.yml b/.github/workflows/produce_data.yml index 1b26c3d0..ca089055 100644 --- a/.github/workflows/produce_data.yml +++ b/.github/workflows/produce_data.yml @@ -6,6 +6,7 @@ on: - "On PR" - "On push" - "Build and Test" + - "On nightly" types: - completed From dfd586cba74400faec5e0a3c0deb8598939f1675 Mon Sep 17 00:00:00 2001 From: Vladimir Milosevic <157983820+vmilosevic@users.noreply.github.com> Date: Fri, 31 Jan 2025 08:13:40 +0100 Subject: [PATCH 12/13] Uplift third_party/tt-mlir to 5dde571ff7a685ed92b32d7e9ddfa25d52c621cd 2025-01-31 (#184) This PR uplifts the third_party/tt-mlir to the 5dde571ff7a685ed92b32d7e9ddfa25d52c621cd Co-authored-by: kmitrovicTT <169657397+kmitrovicTT@users.noreply.github.com> --- third_party/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/third_party/CMakeLists.txt b/third_party/CMakeLists.txt index f9c5963c..245cb418 100644 --- a/third_party/CMakeLists.txt +++ b/third_party/CMakeLists.txt @@ -3,7 +3,7 @@ # SPDX-License-Identifier: Apache-2.0 # -set(TT_MLIR_VERSION "affea5d63684658ee263a359b8904b433f5edf21") +set(TT_MLIR_VERSION "5dde571ff7a685ed92b32d7e9ddfa25d52c621cd") set(LOGURU_VERSION "4adaa185883e3c04da25913579c451d3c32cfac1") if (TOOLCHAIN STREQUAL "ON") From c1cb43a93c38fe09b092c41c278477798aeee1ec Mon Sep 17 00:00:00 2001 From: Stefan Gligorijevic <189116645+sgligorijevicTT@users.noreply.github.com> Date: Fri, 31 Jan 2025 10:53:54 +0100 Subject: [PATCH 13/13] Update fail reasons (#215) --- tests/jax/graphs/test_softmax.py | 3 --- tests/jax/models/albert/v2/base/test_albert_base.py | 5 +++-- tests/jax/models/albert/v2/large/test_albert_large.py | 4 +++- tests/jax/models/albert/v2/xlarge/test_albert_xlarge.py | 4 +++- tests/jax/models/albert/v2/xxlarge/test_albert_xxlarge.py | 4 +++- tests/jax/models/bart/base/test_bart_base.py | 4 +++- tests/jax/models/bart/large/test_bart_large.py | 4 +++- tests/jax/models/distilbert/test_distilbert.py | 4 +++- tests/jax/models/gpt2/test_gpt2.py | 4 +++- .../jax/models/llama/openllama_3b_v2/test_openllama_3b_v2.py | 2 +- tests/jax/models/roberta/base/test_roberta_base.py | 4 +++- tests/jax/models/roberta/large/test_roberta_large.py | 4 +++- tests/jax/models/squeezebert/test_squeezebert.py | 2 +- 13 files changed, 32 insertions(+), 16 deletions(-) diff --git a/tests/jax/graphs/test_softmax.py b/tests/jax/graphs/test_softmax.py index e81be2d9..5ec93aa8 100644 --- a/tests/jax/graphs/test_softmax.py +++ b/tests/jax/graphs/test_softmax.py @@ -16,9 +16,6 @@ [(64, 64), 1], ], ) -@pytest.mark.xfail( - reason="Atol comparison failed. Calculated: atol=inf. Required: atol=0.16" -) def test_softmax(x_shape: tuple, axis: int): def softmax(x: jax.Array) -> jax.Array: return jax.nn.softmax(x, axis=axis) diff --git a/tests/jax/models/albert/v2/base/test_albert_base.py b/tests/jax/models/albert/v2/base/test_albert_base.py index 8c9388ce..01d1eaff 100644 --- a/tests/jax/models/albert/v2/base/test_albert_base.py +++ b/tests/jax/models/albert/v2/base/test_albert_base.py @@ -7,7 +7,6 @@ from ..tester import AlbertV2Tester - MODEL_PATH = "albert/albert-base-v2" @@ -27,7 +26,9 @@ def training_tester() -> AlbertV2Tester: # ----- Tests ----- -@pytest.mark.xfail(reason="failed to legalize operation 'stablehlo.reduce'") +@pytest.mark.xfail( + reason="Cannot get the device from a tensor with host storage (https://github.com/tenstorrent/tt-xla/issues/171)" +) def test_flax_albert_v2_base_inference( inference_tester: AlbertV2Tester, ): diff --git a/tests/jax/models/albert/v2/large/test_albert_large.py b/tests/jax/models/albert/v2/large/test_albert_large.py index f47b8260..e5619417 100644 --- a/tests/jax/models/albert/v2/large/test_albert_large.py +++ b/tests/jax/models/albert/v2/large/test_albert_large.py @@ -26,7 +26,9 @@ def training_tester() -> AlbertV2Tester: # ----- Tests ----- -@pytest.mark.xfail(reason="failed to legalize operation 'stablehlo.reduce'") +@pytest.mark.xfail( + reason="Cannot get the device from a tensor with host storage (https://github.com/tenstorrent/tt-xla/issues/171)" +) def test_flax_albert_v2_large_inference( inference_tester: AlbertV2Tester, ): diff --git a/tests/jax/models/albert/v2/xlarge/test_albert_xlarge.py b/tests/jax/models/albert/v2/xlarge/test_albert_xlarge.py index 3132af8c..aee6302a 100644 --- a/tests/jax/models/albert/v2/xlarge/test_albert_xlarge.py +++ b/tests/jax/models/albert/v2/xlarge/test_albert_xlarge.py @@ -26,7 +26,9 @@ def training_tester() -> AlbertV2Tester: # ----- Tests ----- -@pytest.mark.xfail(reason="failed to legalize operation 'stablehlo.reduce'") +@pytest.mark.xfail( + reason="Cannot get the device from a tensor with host storage (https://github.com/tenstorrent/tt-xla/issues/171)" +) def test_flax_albert_v2_xlarge_inference( inference_tester: AlbertV2Tester, ): diff --git a/tests/jax/models/albert/v2/xxlarge/test_albert_xxlarge.py b/tests/jax/models/albert/v2/xxlarge/test_albert_xxlarge.py index d695a77f..08fa0e6c 100644 --- a/tests/jax/models/albert/v2/xxlarge/test_albert_xxlarge.py +++ b/tests/jax/models/albert/v2/xxlarge/test_albert_xxlarge.py @@ -26,7 +26,9 @@ def training_tester() -> AlbertV2Tester: # ----- Tests ----- -@pytest.mark.xfail(reason="failed to legalize operation 'stablehlo.reduce'") +@pytest.mark.xfail( + reason="Cannot get the device from a tensor with host storage (https://github.com/tenstorrent/tt-xla/issues/171)" +) def test_flax_albert_v2_xxlarge_inference( inference_tester: AlbertV2Tester, ): diff --git a/tests/jax/models/bart/base/test_bart_base.py b/tests/jax/models/bart/base/test_bart_base.py index d6989251..085d32e7 100644 --- a/tests/jax/models/bart/base/test_bart_base.py +++ b/tests/jax/models/bart/base/test_bart_base.py @@ -26,7 +26,9 @@ def training_tester() -> FlaxBartForCausalLMTester: # ----- Tests ----- -@pytest.mark.xfail(reason="failed to legalize operation 'stablehlo.reduce'") +@pytest.mark.xfail( + reason="Cannot get the device from a tensor with host storage (https://github.com/tenstorrent/tt-xla/issues/171)" +) def test_flax_bart_base_inference( inference_tester: FlaxBartForCausalLMTester, ): diff --git a/tests/jax/models/bart/large/test_bart_large.py b/tests/jax/models/bart/large/test_bart_large.py index 142641dc..2941deb4 100644 --- a/tests/jax/models/bart/large/test_bart_large.py +++ b/tests/jax/models/bart/large/test_bart_large.py @@ -26,7 +26,9 @@ def training_tester() -> FlaxBartForCausalLMTester: # ----- Tests ----- -@pytest.mark.xfail(reason="failed to legalize operation 'stablehlo.reduce'") +@pytest.mark.xfail( + reason="Unsupported data type (https://github.com/tenstorrent/tt-xla/issues/214)" +) def test_flax_bart_large_inference( inference_tester: FlaxBartForCausalLMTester, ): diff --git a/tests/jax/models/distilbert/test_distilbert.py b/tests/jax/models/distilbert/test_distilbert.py index bf9202d8..3ba560dd 100644 --- a/tests/jax/models/distilbert/test_distilbert.py +++ b/tests/jax/models/distilbert/test_distilbert.py @@ -53,7 +53,9 @@ def training_tester() -> FlaxDistilBertForMaskedLMTester: # ----- Tests ----- -@pytest.mark.xfail(reason="failed to legalize operation 'stablehlo.reduce'") +@pytest.mark.xfail( + reason="Cannot get the device from a tensor with host storage (https://github.com/tenstorrent/tt-xla/issues/171)" +) def test_flax_distilbert_inference( inference_tester: FlaxDistilBertForMaskedLMTester, ): diff --git a/tests/jax/models/gpt2/test_gpt2.py b/tests/jax/models/gpt2/test_gpt2.py index 988de9d9..13cb8c4c 100644 --- a/tests/jax/models/gpt2/test_gpt2.py +++ b/tests/jax/models/gpt2/test_gpt2.py @@ -51,7 +51,9 @@ def training_tester() -> GPT2Tester: # ----- Tests ----- -@pytest.mark.xfail(reason="failed to legalize operation 'stablehlo.reduce'") +@pytest.mark.xfail( + reason="Cannot get the device from a tensor with host storage (https://github.com/tenstorrent/tt-xla/issues/171)" +) def test_gpt2_inference( inference_tester: GPT2Tester, ): diff --git a/tests/jax/models/llama/openllama_3b_v2/test_openllama_3b_v2.py b/tests/jax/models/llama/openllama_3b_v2/test_openllama_3b_v2.py index 877ecff2..ee3b6b67 100644 --- a/tests/jax/models/llama/openllama_3b_v2/test_openllama_3b_v2.py +++ b/tests/jax/models/llama/openllama_3b_v2/test_openllama_3b_v2.py @@ -26,7 +26,7 @@ def training_tester() -> LLamaTester: # ----- Tests ----- -# @pytest.mark.xfail(reason="failed to legalize operation 'stablehlo.reduce'") +# @pytest.mark.xfail(reason="failed to legalize operation 'ttir.gather'") @pytest.mark.skip( reason="OOMs in CI (https://github.com/tenstorrent/tt-xla/issues/186)" ) diff --git a/tests/jax/models/roberta/base/test_roberta_base.py b/tests/jax/models/roberta/base/test_roberta_base.py index 9969204a..90311923 100644 --- a/tests/jax/models/roberta/base/test_roberta_base.py +++ b/tests/jax/models/roberta/base/test_roberta_base.py @@ -26,7 +26,9 @@ def training_tester() -> FlaxRobertaForMaskedLMTester: # ----- Tests ----- -@pytest.mark.xfail(reason="failed to legalize operation 'stablehlo.reduce_window'") +@pytest.mark.xfail( + reason="Cannot get the device from a tensor with host storage (https://github.com/tenstorrent/tt-xla/issues/171)" +) def test_flax_roberta_base_inference( inference_tester: FlaxRobertaForMaskedLMTester, ): diff --git a/tests/jax/models/roberta/large/test_roberta_large.py b/tests/jax/models/roberta/large/test_roberta_large.py index e1f94d03..8ae54b36 100644 --- a/tests/jax/models/roberta/large/test_roberta_large.py +++ b/tests/jax/models/roberta/large/test_roberta_large.py @@ -26,7 +26,9 @@ def training_tester() -> FlaxRobertaForMaskedLMTester: # ----- Tests ----- -@pytest.mark.xfail(reason="failed to legalize operation 'stablehlo.reduce_window'") +@pytest.mark.xfail( + reason="Cannot get the device from a tensor with host storage (https://github.com/tenstorrent/tt-xla/issues/171)" +) def test_flax_roberta_large_inference( inference_tester: FlaxRobertaForMaskedLMTester, ): diff --git a/tests/jax/models/squeezebert/test_squeezebert.py b/tests/jax/models/squeezebert/test_squeezebert.py index 2362066b..09efdb1a 100644 --- a/tests/jax/models/squeezebert/test_squeezebert.py +++ b/tests/jax/models/squeezebert/test_squeezebert.py @@ -72,7 +72,7 @@ def training_tester() -> SqueezeBertTester: # ----- Tests ----- -@pytest.mark.xfail(reason="failed to legalize operation 'stablehlo.reduce'") +@pytest.mark.xfail(reason="failed to legalize operation 'ttir.convolution'") def test_squeezebert_inference( inference_tester: SqueezeBertTester, ):