diff --git a/.github/workflows/build-and-test.yml b/.github/workflows/build-and-test.yml
index 5f3d5a98..bdf44a0a 100644
--- a/.github/workflows/build-and-test.yml
+++ b/.github/workflows/build-and-test.yml
@@ -2,7 +2,22 @@ name: Build and Test
 
 on:
   workflow_dispatch:
+    inputs:
+      mlir_override:
+        description: 'Git SHA of commit in tenstorrent/tt-mlir'
+        required: false
+        type: string
   workflow_call:
+    inputs:
+      mlir_override:
+        description: 'Git SHA of commit in tenstorrent/tt-mlir'
+        required: false
+        type: string
+
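+# The JUnit report step below needs `checks: write` (check runs) and
+# `pull-requests: write` (PR comments); `packages: write` is presumably for
+# pulling/publishing container images from GHCR.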
+permissions:
+  packages: write
+  checks: write
+  pull-requests: write
 
 jobs:
   # build-ttxla:
@@ -144,6 +159,13 @@ jobs:
           submodules: recursive
           lfs: true
 
+      - name: Override tt-mlir SHA if mlir_override is set
+        if: ${{ inputs.mlir_override }}
+        shell: bash
+        run: |
+          # Update the CMakeLists.txt file with the new SHA
+          sed -i "s/set(TT_MLIR_VERSION \".*\")/set(TT_MLIR_VERSION \"${{ inputs.mlir_override }}\")/" third_party/CMakeLists.txt
+
       - name: Set reusable strings
         id: strings
         shell: bash
@@ -187,6 +209,21 @@ jobs:
           cmake --build ${{ steps.strings.outputs.build-output-dir }}
           cmake --install ${{ steps.strings.outputs.build-output-dir }}
 
+      - name: Verify tt-mlir SHA override
+        if: ${{ inputs.mlir_override }}
+        continue-on-error: true
+        shell: bash
+        run: |
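+          # Non-fatal sanity check (note continue-on-error) that the tt-mlir
+          # checkout under third_party/ actually picked up the overridden
+          # commit; ::notice:: is a GitHub Actions workflow command that
+          # surfaces the line in the run UI.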
+          cd third_party/tt-mlir
+          branch_name=$(git rev-parse --abbrev-ref HEAD)
+          commit_sha=$(git rev-parse HEAD)
+          commit_title=$(git log -1 --pretty=%s)
+          echo "Branch name: $branch_name"
+          echo "Commit SHA: $commit_sha"
+          echo "Commit title: $commit_title"
+          echo "::notice::Using tt-mlir: $branch_name, commit: $commit_sha, title: $commit_title"
+          cd ../..
+
       - name: Run tests
         shell: bash
         run: |
@@ -203,11 +240,15 @@ jobs:
           path: ${{ steps.strings.outputs.test_report_path }}
 
       - name: Show Test Report
-        uses: mikepenz/action-junit-report@v4
+        uses: mikepenz/action-junit-report@v5
         if: success() || failure()
         with:
           report_paths: ${{ steps.strings.outputs.test_report_path }}
           check_name: TT-XLA Tests
+          comment: true
+          updateComment: true
+          detailed_summary: true
+          group_suite: true
 
       - name: Prepare code coverage report
         run: |
diff --git a/.github/workflows/on-pr.yml b/.github/workflows/on-pr.yml
index 34ea5ef6..5c407bbf 100644
--- a/.github/workflows/on-pr.yml
+++ b/.github/workflows/on-pr.yml
@@ -2,6 +2,11 @@ name: On PR
 
 on:
   workflow_dispatch:
+    inputs:
+      mlir_override:
+        description: 'Git SHA of commit in tenstorrent/tt-mlir'
+        required: false
+        type: string
   pull_request:
     branches: [ "main" ]
 
@@ -20,3 +25,5 @@ jobs:
     needs: [pre-commit, spdx]
     uses: ./.github/workflows/build-and-test.yml
     secrets: inherit
+    with:
+      mlir_override: ${{ inputs.mlir_override }}
diff --git a/requirements.txt b/requirements.txt
index 38d4414f..781ab540 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -13,3 +13,7 @@ lit
 pybind11
 pytest
 transformers
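+# New test dependencies: fsspec and ml_collections for the MLP-Mixer test,
+# torch for loading SqueezeBERT's PyTorch checkpoint, einops for both models.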
+fsspec
+einops
+torch
+ml_collections
diff --git a/tests/jax/graphs/test_activation_functions.py b/tests/jax/graphs/test_activation_functions.py
index 7dacb282..f715f5bc 100644
--- a/tests/jax/graphs/test_activation_functions.py
+++ b/tests/jax/graphs/test_activation_functions.py
@@ -9,9 +9,6 @@
 
 
 @pytest.mark.parametrize("x_shape", [(32, 32), (64, 64)])
-@pytest.mark.skip(
-    "ttnn::operations::binary::BinaryDeviceOperation: unsupported broadcast"
-)
 def test_relu(x_shape: tuple):
     """Test ReLU activation function."""
 
diff --git a/tests/jax/graphs/test_linear_transformation.py b/tests/jax/graphs/test_linear_transformation.py
index b3c60a58..0e055166 100644
--- a/tests/jax/graphs/test_linear_transformation.py
+++ b/tests/jax/graphs/test_linear_transformation.py
@@ -8,7 +8,6 @@
 from infra import run_graph_test_with_random_inputs
 
 
-@pytest.mark.skip("Skipped due to https://github.com/tenstorrent/tt-xla/issues/162")
 @pytest.mark.parametrize(
     ["x_shape", "y_shape", "bias_shape"],
     [
diff --git a/tests/jax/graphs/test_simple_regression.py b/tests/jax/graphs/test_simple_regression.py
index 4449bdc4..03e54d70 100644
--- a/tests/jax/graphs/test_simple_regression.py
+++ b/tests/jax/graphs/test_simple_regression.py
@@ -10,7 +10,9 @@
 @pytest.mark.parametrize(
     ["weights", "bias", "X", "y"], [[(1, 2), (1, 1), (2, 1), (1, 1)]]
 )
-@pytest.mark.skip("failed to legalize operation 'stablehlo.dot_general'")
+@pytest.mark.xfail(
+    reason="Atol comparison failed. Calculated: atol=0.850662112236023. Required: atol=0.16"
+)
 def test_simple_regression(weights, bias, X, y):
     def simple_regression(weights, bias, X, y):
         def loss(weights, bias, X, y):
diff --git a/tests/jax/graphs/test_softmax.py b/tests/jax/graphs/test_softmax.py
index 6853d363..e81be2d9 100644
--- a/tests/jax/graphs/test_softmax.py
+++ b/tests/jax/graphs/test_softmax.py
@@ -16,9 +16,8 @@
         [(64, 64), 1],
     ],
 )
-@pytest.mark.skip(
-    "tt-metal assert: Index is out of bounds for the rank. "
-    "Similar to https://github.com/tenstorrent/tt-xla/issues/12"
-)
+@pytest.mark.xfail(
+    reason="Atol comparison failed. Calculated: atol=inf. Required: atol=0.16"
+)
 def test_softmax(x_shape: tuple, axis: int):
     def softmax(x: jax.Array) -> jax.Array:
diff --git a/tests/jax/models/albert/v2/base/test_albert_base.py b/tests/jax/models/albert/v2/base/test_albert_base.py
index 5c54e308..8c9388ce 100644
--- a/tests/jax/models/albert/v2/base/test_albert_base.py
+++ b/tests/jax/models/albert/v2/base/test_albert_base.py
@@ -27,7 +27,7 @@ def training_tester() -> AlbertV2Tester:
 # ----- Tests -----
 
 
-@pytest.mark.skip(reason="failed to legalize operation 'stablehlo.dot_general'")
+@pytest.mark.xfail(reason="failed to legalize operation 'stablehlo.reduce'")
 def test_flax_albert_v2_base_inference(
     inference_tester: AlbertV2Tester,
 ):
diff --git a/tests/jax/models/albert/v2/large/test_albert_large.py b/tests/jax/models/albert/v2/large/test_albert_large.py
index ff59a8be..f47b8260 100644
--- a/tests/jax/models/albert/v2/large/test_albert_large.py
+++ b/tests/jax/models/albert/v2/large/test_albert_large.py
@@ -26,7 +26,7 @@ def training_tester() -> AlbertV2Tester:
 # ----- Tests -----
 
 
-@pytest.mark.skip(reason="failed to legalize operation 'stablehlo.dot_general'")
+@pytest.mark.xfail(reason="failed to legalize operation 'stablehlo.reduce'")
 def test_flax_albert_v2_large_inference(
     inference_tester: AlbertV2Tester,
 ):
diff --git a/tests/jax/models/albert/v2/xlarge/test_albert_xlarge.py b/tests/jax/models/albert/v2/xlarge/test_albert_xlarge.py
index 93346267..3132af8c 100644
--- a/tests/jax/models/albert/v2/xlarge/test_albert_xlarge.py
+++ b/tests/jax/models/albert/v2/xlarge/test_albert_xlarge.py
@@ -26,7 +26,7 @@ def training_tester() -> AlbertV2Tester:
 # ----- Tests -----
 
 
-@pytest.mark.skip(reason="failed to legalize operation 'stablehlo.dot_general'")
+@pytest.mark.xfail(reason="failed to legalize operation 'stablehlo.reduce'")
 def test_flax_albert_v2_xlarge_inference(
     inference_tester: AlbertV2Tester,
 ):
diff --git a/tests/jax/models/albert/v2/xxlarge/test_albert_xxlarge.py b/tests/jax/models/albert/v2/xxlarge/test_albert_xxlarge.py
index cce0ef8f..d695a77f 100644
--- a/tests/jax/models/albert/v2/xxlarge/test_albert_xxlarge.py
+++ b/tests/jax/models/albert/v2/xxlarge/test_albert_xxlarge.py
@@ -26,7 +26,7 @@ def training_tester() -> AlbertV2Tester:
 # ----- Tests -----
 
 
-@pytest.mark.skip(reason="failed to legalize operation 'stablehlo.dot_general'")
+@pytest.mark.xfail(reason="failed to legalize operation 'stablehlo.reduce'")
 def test_flax_albert_v2_xxlarge_inference(
     inference_tester: AlbertV2Tester,
 ):
diff --git a/tests/jax/models/distilbert/test_distilbert.py b/tests/jax/models/distilbert/test_distilbert.py
index 06d3b785..bf9202d8 100644
--- a/tests/jax/models/distilbert/test_distilbert.py
+++ b/tests/jax/models/distilbert/test_distilbert.py
@@ -53,7 +53,7 @@ def training_tester() -> FlaxDistilBertForMaskedLMTester:
 # ----- Tests -----
 
 
-@pytest.mark.skip(reason="failed to legalize operation 'stablehlo.dot_general'")
+@pytest.mark.xfail(reason="failed to legalize operation 'stablehlo.reduce'")
 def test_flax_distilbert_inference(
     inference_tester: FlaxDistilBertForMaskedLMTester,
 ):
diff --git a/tests/jax/models/example_model/mixed_args_and_kwargs/test_example_model_mixed_args_and_kwargs.py b/tests/jax/models/example_model/mixed_args_and_kwargs/test_example_model_mixed_args_and_kwargs.py
index 4210afb7..adb7adb2 100644
--- a/tests/jax/models/example_model/mixed_args_and_kwargs/test_example_model_mixed_args_and_kwargs.py
+++ b/tests/jax/models/example_model/mixed_args_and_kwargs/test_example_model_mixed_args_and_kwargs.py
@@ -78,7 +78,6 @@ def training_tester() -> ExampleModelMixedArgsAndKwargsTester:
 # ----- Tests -----
 
 
-@pytest.mark.skip("Skipped due to https://github.com/tenstorrent/tt-xla/issues/162")
 def test_example_model_inference(
     inference_tester: ExampleModelMixedArgsAndKwargsTester,
 ):
diff --git a/tests/jax/models/example_model/only_args/test_example_model_only_args.py b/tests/jax/models/example_model/only_args/test_example_model_only_args.py
index 069769ea..47ef50e5 100644
--- a/tests/jax/models/example_model/only_args/test_example_model_only_args.py
+++ b/tests/jax/models/example_model/only_args/test_example_model_only_args.py
@@ -73,7 +73,6 @@ def training_tester() -> ExampleModelOnlyArgsTester:
 # ----- Tests -----
 
 
-@pytest.mark.skip("Skipped due to https://github.com/tenstorrent/tt-xla/issues/162")
 def test_example_model_inference(inference_tester: ExampleModelOnlyArgsTester):
     inference_tester.test()
 
diff --git a/tests/jax/models/example_model/only_kwargs/test_example_model_only_kwargs.py b/tests/jax/models/example_model/only_kwargs/test_example_model_only_kwargs.py
index 81abbb31..e74ad2c5 100644
--- a/tests/jax/models/example_model/only_kwargs/test_example_model_only_kwargs.py
+++ b/tests/jax/models/example_model/only_kwargs/test_example_model_only_kwargs.py
@@ -73,7 +73,6 @@ def training_tester() -> ExampleModelOnlyKwargsTester:
 # ----- Tests -----
 
 
-@pytest.mark.skip("Skipped due to https://github.com/tenstorrent/tt-xla/issues/162")
 def test_example_model_inference(inference_tester: ExampleModelOnlyKwargsTester):
     inference_tester.test()
 
diff --git a/tests/jax/models/gpt2/test_gpt2.py b/tests/jax/models/gpt2/test_gpt2.py
index efb9a344..40d6c2bb 100644
--- a/tests/jax/models/gpt2/test_gpt2.py
+++ b/tests/jax/models/gpt2/test_gpt2.py
@@ -51,7 +51,7 @@ def training_tester() -> GPT2Tester:
 # ----- Tests -----
 
 
-@pytest.mark.skip(reason="failed to legalize operation 'stablehlo.dot_general'")
+@pytest.mark.xfail(reason="failed to legalize operation 'stablehlo.reduce'")
 def test_gp2_inference(
     inference_tester: GPT2Tester,
 ):
diff --git a/tests/jax/models/llama/openllama_3b_v2/test_openllama_3b_v2.py b/tests/jax/models/llama/openllama_3b_v2/test_openllama_3b_v2.py
index 18688ed0..877ecff2 100644
--- a/tests/jax/models/llama/openllama_3b_v2/test_openllama_3b_v2.py
+++ b/tests/jax/models/llama/openllama_3b_v2/test_openllama_3b_v2.py
@@ -3,7 +3,6 @@
 # SPDX-License-Identifier: Apache-2.0
 
 import pytest
-from flax import linen as nn
 from infra import RunMode
 
 from ..tester import LLamaTester
@@ -27,7 +26,10 @@ def training_tester() -> LLamaTester:
 # ----- Tests -----
 
 
-@pytest.mark.skip(reason="failed to legalize operation 'stablehlo.dot_general'")
+# @pytest.mark.xfail(reason="failed to legalize operation 'stablehlo.reduce'")
+@pytest.mark.skip(
+    reason="OOMs in CI (https://github.com/tenstorrent/tt-xla/issues/186)"
+)
 def test_openllama3b_inference(
     inference_tester: LLamaTester,
 ):
diff --git a/tests/jax/models/llama/tester.py b/tests/jax/models/llama/tester.py
index e915e0db..82b43330 100644
--- a/tests/jax/models/llama/tester.py
+++ b/tests/jax/models/llama/tester.py
@@ -21,8 +21,8 @@ def __init__(
         comparison_config: ComparisonConfig = ComparisonConfig(),
         run_mode: RunMode = RunMode.INFERENCE,
     ) -> None:
-        super().__init__(comparison_config, run_mode)
         self._model_name = model_name
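+        # _model_name must be assigned before delegating to the base __init__,
+        # which presumably ends up calling _get_model and reading the attribute.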
+        super().__init__(comparison_config, run_mode)
 
     # @override
     def _get_model(self) -> nn.Module:
diff --git a/tests/jax/models/mlpmixer/__init__.py b/tests/jax/models/mlpmixer/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tests/jax/models/mlpmixer/model_implementation.py b/tests/jax/models/mlpmixer/model_implementation.py
new file mode 100644
index 00000000..03679e4b
--- /dev/null
+++ b/tests/jax/models/mlpmixer/model_implementation.py
@@ -0,0 +1,82 @@
+# SPDX-FileCopyrightText: (c) 2025 Tenstorrent AI ULC
+#
+# SPDX-License-Identifier: Apache-2.0
+
+# This file incorporates work covered by the following copyright and permission
+# notice:
+#     SPDX-FileCopyrightText: Copyright 2024 Google LLC.
+#     SPDX-License-Identifier: Apache-2.0
+
+# This code is based on google-research/vision_transformer
+
+from typing import Any, Optional
+
+import einops
+import flax.linen as nn
+import jax
+import jax.numpy as jnp
+
+
+class MlpBlock(nn.Module):
+    mlp_dim: int
+
+    @nn.compact
+    def __call__(self, x: jax.Array) -> jax.Array:
+        y = nn.Dense(self.mlp_dim)(x)
+        y = nn.gelu(y)
+        return nn.Dense(x.shape[-1])(y)
+
+
+class MixerBlock(nn.Module):
+    """Mixer block layer."""
+
+    tokens_mlp_dim: int
+    channels_mlp_dim: int
+
+    @nn.compact
+    def __call__(self, x: jax.Array) -> jax.Array:
+        y = nn.LayerNorm()(x)
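+        # Token mixing: swap to (batch, channels, patches) so the MLP mixes
+        # information across patches, then swap back for channel mixing below.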
+        y = jnp.swapaxes(y, 1, 2)
+        y = MlpBlock(self.tokens_mlp_dim, name="token_mixing")(y)
+        y = jnp.swapaxes(y, 1, 2)
+        x = x + y
+
+        y = nn.LayerNorm()(x)
+        y = MlpBlock(self.channels_mlp_dim, name="channel_mixing")(y)
+        y = x + y
+
+        return y
+
+
+class MlpMixer(nn.Module):
+    """Mixer architecture."""
+
+    patches: Any
+    num_classes: int
+    num_blocks: int
+    hidden_dim: int
+    tokens_mlp_dim: int
+    channels_mlp_dim: int
+    model_name: Optional[str] = None
+
+    @nn.compact
+    def __call__(self, inputs: jax.Array) -> jax.Array:
+        x = nn.Conv(
+            self.hidden_dim, self.patches.size, strides=self.patches.size, name="stem"
+        )(
+            inputs
+        )  # Patch embedding
+        x = einops.rearrange(x, "n h w c -> n (h w) c")
+
+        for _ in range(self.num_blocks):
+            x = MixerBlock(self.tokens_mlp_dim, self.channels_mlp_dim)(x)
+
+        x = nn.LayerNorm(name="pre_head_layer_norm")(x)
+        x = jnp.mean(x, axis=1)
+
+        if self.num_classes:
+            x = nn.Dense(
+                self.num_classes, kernel_init=nn.initializers.zeros, name="head"
+            )(x)
+
+        return x
diff --git a/tests/jax/models/mlpmixer/test_mlpmixer.py b/tests/jax/models/mlpmixer/test_mlpmixer.py
new file mode 100644
index 00000000..469e0688
--- /dev/null
+++ b/tests/jax/models/mlpmixer/test_mlpmixer.py
@@ -0,0 +1,104 @@
+# SPDX-FileCopyrightText: (c) 2025 Tenstorrent AI ULC
+#
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any, Dict, Sequence
+
+import flax.traverse_util
+import fsspec
+import jax
+import ml_collections
+import numpy
+import pytest
+from flax import linen as nn
+from infra import ModelTester, RunMode
+
+from .model_implementation import MlpMixer
+
+# Hyperparameters for Mixer-B/16
+patch_size = 16
+num_classes = 21843
+num_blocks = 12
+hidden_dim = 768
+token_mlp_dim = 384
+channel_mlp_dim = 3072
+
+
+class MlpMixerTester(ModelTester):
+    """Tester for MlpMixer model."""
+
+    # @override
+    def _get_model(self) -> nn.Module:
+        patch = ml_collections.ConfigDict({"size": (patch_size, patch_size)})
+        return MlpMixer(
+            patches=patch,
+            num_classes=num_classes,
+            num_blocks=num_blocks,
+            hidden_dim=hidden_dim,
+            tokens_mlp_dim=token_mlp_dim,
+            channels_mlp_dim=channel_mlp_dim,
+        )
+
+    @staticmethod
+    def _retrieve_pretrained_weights() -> Dict:
+        # TODO(stefan): Discuss how weights should be handled org wide
+        link = "https://storage.googleapis.com/mixer_models/imagenet21k/Mixer-B_16.npz"
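+        # "filecache::" makes fsspec download the checkpoint once and serve
+        # later runs from the local cache directory; unflatten_dict rebuilds
+        # the nested Flax param tree from the npz's flat "a/b/c"-style keys.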
+        with fsspec.open("filecache::" + link, cache_storage="/tmp/files/") as f:
+            weights = numpy.load(f, encoding="bytes")
+            state_dict = {k: v for k, v in weights.items()}
+        pytree = flax.traverse_util.unflatten_dict(state_dict, sep="/")
+        return {"params": pytree}
+
+    # @override
+    def _get_forward_method_name(self) -> str:
+        return "apply"
+
+    # @override
+    def _get_input_activations(self) -> jax.Array:
+        key = jax.random.PRNGKey(42)
+        random_image = jax.random.normal(key, (1, 224, 224, 3))
+        return random_image
+
+    # @override
+    def _get_forward_method_args(self) -> Sequence[Any]:
+        ins = self._get_input_activations()
+        weights = self._retrieve_pretrained_weights()
+
+        # Alternatively, weights could be randomly initialized like this:
+        # weights = self._model.init(jax.random.PRNGKey(42), ins)
+
+        # JAX frameworks have a convention of passing weights as the first argument
+        return [weights, ins]
+
+
+# ----- Fixtures -----
+
+
+@pytest.fixture
+def inference_tester() -> MlpMixerTester:
+    return MlpMixerTester()
+
+
+@pytest.fixture
+def training_tester() -> MlpMixerTester:
+    return MlpMixerTester(RunMode.TRAINING)
+
+
+# ----- Tests -----
+
+
+@pytest.mark.skip(
+    reason=(
+        "Statically allocated circular buffers in program 16 clash with L1 buffers "
+        "on core range [(x=0,y=0) - (x=6,y=0)]. L1 buffer allocated at 475136 and "
+        "static circular buffer region ends at 951136 "
+        "(https://github.com/tenstorrent/tt-xla/issues/187)"
+    )
+)  # segfault
+def test_mlpmixer(inference_tester: MlpMixerTester):
+    inference_tester.test()
+
+
+@pytest.mark.skip(reason="Support for training not implemented")
+def test_mlpmixer_training(training_tester: MlpMixerTester):
+    training_tester.test()
diff --git a/tests/jax/models/mnist/cnn/test_mnist_cnn.py b/tests/jax/models/mnist/cnn/test_mnist_cnn.py
index 1a6f8fcf..88004ec1 100644
--- a/tests/jax/models/mnist/cnn/test_mnist_cnn.py
+++ b/tests/jax/models/mnist/cnn/test_mnist_cnn.py
@@ -67,7 +67,7 @@ def training_tester() -> MNISTCNNTester:
 
 @pytest.mark.skip(
     reason='void mlir::OperationConverter::finalize(mlir::ConversionPatternRewriter &): Assertion `newValue && "replacement value not found"\' failed.'
-)
+)  # This is a segfault; marking it as xfail would bring down the whole test suite
 def test_mnist_inference(
     inference_tester: MNISTCNNTester,
 ):
diff --git a/tests/jax/models/roberta/test_roberta.py b/tests/jax/models/roberta/test_roberta.py
index f883bf86..32542cd5 100644
--- a/tests/jax/models/roberta/test_roberta.py
+++ b/tests/jax/models/roberta/test_roberta.py
@@ -57,7 +57,7 @@ def training_tester() -> FlaxRobertaForMaskedLMTester:
 # ----- Tests -----
 
 
-@pytest.mark.skip(reason="failed to legalize operation 'stablehlo.dot_general'")
+@pytest.mark.xfail(reason="failed to legalize operation 'stablehlo.reduce_window'")
 def test_roberta_inference(
     inference_tester: FlaxRobertaForMaskedLMTester,
 ):
diff --git a/tests/jax/models/squeezebert/model_implementation.py b/tests/jax/models/squeezebert/model_implementation.py
new file mode 100644
index 00000000..9ff83e5f
--- /dev/null
+++ b/tests/jax/models/squeezebert/model_implementation.py
@@ -0,0 +1,363 @@
+# SPDX-FileCopyrightText: (c) 2025 Tenstorrent AI ULC
+#
+# SPDX-License-Identifier: Apache-2.0
+
+import re
+from typing import Any, Dict, Tuple
+
+import einops
+import flax.traverse_util
+import jax
+import jax.numpy as jnp
+from flax import linen as nn
+from transformers import SqueezeBertConfig
+
+
+class SqueezeBertEmbedding(nn.Module):
+    """Embedding layer for SqueezeBERT model."""
+
+    config: SqueezeBertConfig
+
+    def setup(self):
+        self.word_embedding = nn.Embed(
+            self.config.vocab_size, self.config.embedding_size
+        )
+        self.position_embedding = nn.Embed(
+            self.config.max_position_embeddings,
+            self.config.embedding_size,
+        )
+        self.token_type_embedding = nn.Embed(
+            self.config.type_vocab_size,
+            self.config.embedding_size,
+        )
+
+        self.layernorm = nn.LayerNorm()
+        self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
+
+    def __call__(
+        self,
+        input_ids: jax.Array,
+        token_type_ids: jax.Array = None,
+        position_ids: jax.Array = None,
+        deterministic: bool = False,
+    ) -> jax.Array:
+        if position_ids is None:
+            position_ids = jnp.arange(input_ids.shape[1])
+        if token_type_ids is None:
+            token_type_ids = jnp.zeros_like(input_ids)
+
+        word_embeddings = self.word_embedding(input_ids)
+        position_embeddings = self.position_embedding(position_ids)
+        token_type_embeddings = self.token_type_embedding(token_type_ids)
+
+        embeddings = word_embeddings + position_embeddings + token_type_embeddings
+        embeddings = self.layernorm(embeddings)
+        embeddings = self.dropout(embeddings, deterministic=deterministic)
+        return embeddings
+
+
+class SqueezeBertSelfAttention(nn.Module):
+    """Self-attention layer for SqueezeBERT model."""
+
+    config: SqueezeBertConfig
+
+    def setup(self):
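+        # SqueezeBERT's core idea: the Q/K/V/output projections are grouped
+        # pointwise (kernel size 1) convolutions over the sequence instead of
+        # dense layers; group counts come from the HF config.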
+        self.query = nn.Conv(
+            features=self.config.hidden_size,
+            kernel_size=(1,),
+            feature_group_count=self.config.q_groups,
+        )
+        self.key = nn.Conv(
+            features=self.config.hidden_size,
+            kernel_size=(1,),
+            feature_group_count=self.config.k_groups,
+        )
+        self.value = nn.Conv(
+            features=self.config.hidden_size,
+            kernel_size=(1,),
+            feature_group_count=self.config.v_groups,
+        )
+        self.output = nn.Conv(
+            features=self.config.hidden_size,
+            kernel_size=(1,),
+            feature_group_count=self.config.post_attention_groups,
+        )
+
+        self.attn_dropout = nn.Dropout(rate=self.config.attention_probs_dropout_prob)
+        self.resid_dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
+        self.layernorm = nn.LayerNorm()
+
+    def __call__(
+        self,
+        hidden_states: jax.Array,
+        attention_mask: jax.Array,
+        deterministic: bool = False,
+    ) -> jax.Array:
+        head_dim = self.config.hidden_size // self.config.num_attention_heads
+        query = self.query(hidden_states)
+        key = self.key(hidden_states)
+        value = self.value(hidden_states)
+
+        query = einops.rearrange(
+            query,
+            "b s (H d) -> b s H d",  # batch sequence Heads dim_head
+            H=self.config.num_attention_heads,
+            d=head_dim,
+        )
+        key = einops.rearrange(
+            key,
+            "b s (H d) -> b s H d",
+            H=self.config.num_attention_heads,
+            d=head_dim,
+        )
+        value = einops.rearrange(
+            value,
+            "b s (H d) -> b s H d",
+            H=self.config.num_attention_heads,
+            d=head_dim,
+        )
+
+        attention_scores = jnp.einsum("B s H d, B S H d -> B H s S", query, key)
+        attention_scores = attention_scores / jnp.sqrt(head_dim)
+        if attention_mask is not None:
+            attention_scores = attention_scores + attention_mask
+
+        attention_probs = nn.activation.softmax(attention_scores, axis=-1)
+        attention_probs = self.attn_dropout(
+            attention_probs, deterministic=deterministic
+        )
+
+        context = jnp.einsum("B H s S, B S H d -> B s H d", attention_probs, value)
+        context = einops.rearrange(context, "b s H d -> b s (H d)")
+
+        output = self.output(context)
+        output = self.resid_dropout(output, deterministic=deterministic)
+        output = hidden_states + output
+        output = self.layernorm(output)
+        return output
+
+
+class SqueezeBertMLP(nn.Module):
+    """MLP layer for SqueezeBERT model."""
+
+    config: SqueezeBertConfig
+
+    def setup(self):
+        self.w1 = nn.Conv(
+            features=self.config.intermediate_size,
+            kernel_size=(1,),
+            feature_group_count=self.config.intermediate_groups,
+        )
+        if self.config.hidden_act == "gelu":
+            self.act = nn.gelu
+        else:
+            raise ValueError(
+                f"Activation function {self.config.hidden_act} not supported."
+            )
+        self.w2 = nn.Conv(
+            features=self.config.hidden_size,
+            kernel_size=(1,),
+            feature_group_count=self.config.output_groups,
+        )
+        self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
+        self.layernorm = nn.LayerNorm()
+
+    def __call__(
+        self, hidden_states: jax.Array, deterministic: bool = False
+    ) -> jax.Array:
+        x = self.w1(hidden_states)
+        x = self.act(x)
+        x = self.w2(x)
+        x = self.dropout(x, deterministic=deterministic)
+        output = hidden_states + x
+        output = self.layernorm(output)
+        return output
+
+
+class SqueezeBertLayer(nn.Module):
+    """Layer for SqueezeBERT model."""
+
+    config: SqueezeBertConfig
+
+    def setup(self):
+        self.attention = SqueezeBertSelfAttention(self.config)
+        self.mlp = SqueezeBertMLP(self.config)
+
+    def __call__(
+        self,
+        hidden_states: jax.Array,
+        attention_mask: jax.Array,
+        deterministic: bool = False,
+    ) -> jax.Array:
+        attention_output = self.attention(
+            hidden_states, attention_mask, deterministic=deterministic
+        )
+        output = self.mlp(attention_output, deterministic=deterministic)
+        return output
+
+
+class SqueezeBertEncoder(nn.Module):
+    """Encoder for SqueezeBERT model."""
+
+    config: SqueezeBertConfig
+
+    def setup(self):
+        self.layers = [
+            SqueezeBertLayer(self.config) for _ in range(self.config.num_hidden_layers)
+        ]
+
+    def __call__(
+        self,
+        hidden_states: jax.Array,
+        attention_mask: jax.Array,
+        deterministic: bool = False,
+    ) -> jax.Array:
+        for layer in self.layers:
+            hidden_states = layer(
+                hidden_states, attention_mask, deterministic=deterministic
+            )
+        return hidden_states
+
+
+class SqueezeBertPooler(nn.Module):
+    """Pooler layer for SqueezeBERT model."""
+
+    config: SqueezeBertConfig
+
+    def setup(self):
+        self.dense = nn.Dense(self.config.hidden_size)
+        self.activation = nn.tanh
+
+    def __call__(self, hidden_states: jax.Array) -> jax.Array:
+        first_token_tensor = hidden_states[:, 0]
+        pooled_output = self.dense(first_token_tensor)
+        pooled_output = self.activation(pooled_output)
+        return pooled_output
+
+
+class SqueezeBertModel(nn.Module):
+    """SqueezeBERT model."""
+
+    config: SqueezeBertConfig
+
+    def setup(self):
+        self.embeddings = SqueezeBertEmbedding(self.config)
+        self.encoder = SqueezeBertEncoder(self.config)
+        self.pooler = SqueezeBertPooler(self.config)
+
+    def __call__(
+        self,
+        input_ids: jax.Array,
+        attention_mask: jax.Array,
+        token_type_ids: jax.Array = None,
+        position_ids: jax.Array = None,
+        *,
+        train: bool,
+    ) -> Tuple[jax.Array, jax.Array]:
+        if attention_mask is None:
+            attention_mask = jnp.ones_like(input_ids)
+        embeddings = self.embeddings(
+            input_ids, token_type_ids, position_ids, deterministic=not train
+        )
+        encoder_output = self.encoder(
+            embeddings, attention_mask, deterministic=not train
+        )
+        pooled_output = self.pooler(encoder_output)
+        return encoder_output, pooled_output
+
+
+class SqueezeBertForMaskedLM(nn.Module):
+    """SqueezeBERT model with masked language modeling head."""
+
+    config: SqueezeBertConfig
+
+    def setup(self):
+        self.squeezebert = SqueezeBertModel(self.config)
+        self.transform_dense = nn.Dense(self.config.hidden_size)
+
+        if self.config.hidden_act == "gelu":
+            self.transform_act = nn.gelu
+        else:
+            raise ValueError(
+                f"Activation function {self.config.hidden_act} not supported."
+            )
+
+        self.transform_layernorm = nn.LayerNorm()
+
+        self.decoder = nn.Dense(self.config.vocab_size)
+        # TODO(stefan): Figure out if SqueezeBERT uses tied weights for embeddings
+        # and the output layer; that is only relevant for training.
+
+    def __call__(
+        self,
+        input_ids: jax.Array,
+        attention_mask: jax.Array = None,
+        token_type_ids: jax.Array = None,
+        position_ids: jax.Array = None,
+        *,
+        train: bool,
+    ) -> jax.Array:
+        hidden_states, _ = self.squeezebert(
+            input_ids, attention_mask, token_type_ids, position_ids, train=train
+        )
+        hidden_states = self.transform_dense(hidden_states)
+        hidden_states = self.transform_act(hidden_states)
+        hidden_states = self.transform_layernorm(hidden_states)
+
+        prediction_scores = self.decoder(hidden_states)
+        return prediction_scores
+
+    @staticmethod
+    def init_from_pytorch_statedict(state_dict: Dict[str, Any]) -> Dict[str, Any]:
+        # Key substitutions for remapping huggingface checkpoints to this implementation
+        PATTERNS = [
+            ("transformer.", "squeezebert."),
+            ("LayerNorm", "layernorm"),
+            ("layernorm.weight", "layernorm.scale"),
+            ("_embeddings.weight", "_embedding.embedding"),
+            ("encoder.layers.", "encoder.layers_"),
+            ("attention.query.weight", "attention.query.kernel"),
+            ("attention.key.weight", "attention.key.kernel"),
+            ("attention.value.weight", "attention.value.kernel"),
+            ("post_attention.conv1d.weight", "attention.output.kernel"),
+            ("post_attention.conv1d.bias", "attention.output.bias"),
+            ("post_attention.layernorm", "attention.layernorm"),
+            ("intermediate.conv1d.weight", "mlp.w1.kernel"),
+            ("intermediate.conv1d.bias", "mlp.w1.bias"),
+            ("output.conv1d.weight", "mlp.w2.kernel"),
+            ("output.conv1d.bias", "mlp.w2.bias"),
+            ("output.layernorm", "mlp.layernorm"),
+            ("pooler.dense.weight", "pooler.dense.kernel"),
+            ("cls.predictions.transform.dense.weight", "transform_dense.kernel"),
+            ("cls.predictions.transform.dense.bias", "transform_dense.bias"),
+            ("cls.predictions.transform.layernorm", "transform_layernorm"),
+            ("cls.predictions.decoder.weight", "decoder.kernel"),
+            ("cls.predictions.bias", "decoder.bias"),
+        ]
+
+        def is_banned_key(key: str) -> bool:
+            return "seq_relationship" in key
+
+        def rewrite_key(key: str) -> str:
+            for pattern in PATTERNS:
+                key = re.sub(pattern[0], pattern[1], key)
+            return key
+
+        def process_value(k: str, v) -> jnp.ndarray:
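+            # PyTorch stores Dense kernels as (out, in) and Conv1d kernels as
+            # (out, in, width); Flax expects (in, out) and (width, in, out).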
+            if "kernel" in k:
+                if len(v.shape) == 2:
+                    return jnp.transpose(v)
+                if len(v.shape) == 3:
+                    return jnp.transpose(v, (2, 1, 0))
+            return v
+
+        for k, v in state_dict.items():
+            # Inplace conversion might lower peak memory usage
+            state_dict[k] = jnp.array(v)
+
+        state_dict = {
+            rewrite_key(k): v for k, v in state_dict.items() if not is_banned_key(k)
+        }
+        state_dict = {k: process_value(k, v) for k, v in state_dict.items()}
+        state_dict = flax.traverse_util.unflatten_dict(state_dict, sep=".")
+        return {"params": state_dict}
diff --git a/tests/jax/models/squeezebert/test_squeezebert.py b/tests/jax/models/squeezebert/test_squeezebert.py
new file mode 100644
index 00000000..fe98f0cc
--- /dev/null
+++ b/tests/jax/models/squeezebert/test_squeezebert.py
@@ -0,0 +1,86 @@
+# SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
+#
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Dict, Sequence
+
+import jax
+import pytest
+import torch
+from flax import linen as nn
+from huggingface_hub import hf_hub_download
+from infra import ModelTester, RunMode
+from model_implementation import SqueezeBertConfig, SqueezeBertForMaskedLM
+from transformers import AutoTokenizer
+
+MODEL_PATH = "squeezebert/squeezebert-uncased"
+
+# ----- Tester -----
+
+
+class SqueezeBertTester(ModelTester):
+    """Tester for SqueezeBERT model on a masked language modeling task."""
+
+    # @override
+    def _get_model(self) -> nn.Module:
+        config = SqueezeBertConfig.from_pretrained(MODEL_PATH)
+        return SqueezeBertForMaskedLM(config)
+
+    # @override
+    def _get_forward_method_name(self):
+        return "apply"
+
+    # @override
+    def _get_input_activations(self) -> Sequence[jax.Array]:
+        tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
+        inputs = tokenizer("The [MASK] barked at me", return_tensors="np")
+        return inputs["input_ids"]
+
+    # @override
+    def _get_forward_method_kwargs(self) -> Dict[str, jax.Array]:
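+        # Fetch the original PyTorch checkpoint from the HF hub and remap it
+        # into this Flax implementation's parameter tree.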
+        model_file = hf_hub_download(
+            repo_id=MODEL_PATH, filename="pytorch_model.bin"
+        )
+        state_dict = torch.load(model_file, weights_only=True)
+
+        params = self._model.init_from_pytorch_statedict(state_dict)
+
+        return {
+            "variables": params,  # JAX frameworks have a convention of passing weights as the first argument
+            "input_ids": self._get_input_activations(),
+            "train": False,
+        }
+
+    # @override
+    def _get_static_argnames(self):
+        return ["train"]
+
+
+# ----- Fixtures -----
+
+
+@pytest.fixture
+def inference_tester() -> SqueezeBertTester:
+    return SqueezeBertTester()
+
+
+@pytest.fixture
+def training_tester() -> SqueezeBertTester:
+    return SqueezeBertTester(RunMode.TRAINING)
+
+
+# ----- Tests -----
+
+
+@pytest.mark.skip(reason="failed to legalize operation 'stablehlo.dot_general'")
+def test_squeezebert_inference(
+    inference_tester: SqueezeBertTester,
+):
+    inference_tester.test()
+
+
+@pytest.mark.skip(reason="Support for training not implemented")
+def test_squeezebert_training(
+    training_tester: SqueezeBertTester,
+):
+    training_tester.test()
diff --git a/tests/jax/ops/test_broadcast_in_dim.py b/tests/jax/ops/test_broadcast_in_dim.py
index f431d061..ffda703d 100644
--- a/tests/jax/ops/test_broadcast_in_dim.py
+++ b/tests/jax/ops/test_broadcast_in_dim.py
@@ -8,9 +8,8 @@
 
 
 @pytest.mark.parametrize("input_shapes", [[(2, 1)]])
-@pytest.mark.skip(
-    "error: type of return operand 0 doesn't match function result type in "
-    "function @main"
+@pytest.mark.xfail(
+    reason="AssertionError: Atol comparison failed. Calculated: atol=0.804124116897583. Required: atol=0.16"
 )
 def test_broadcast_in_dim(input_shapes):
     def broadcast(a):
diff --git a/tests/jax/ops/test_constant.py b/tests/jax/ops/test_constant.py
index 02657e35..3587a4d5 100644
--- a/tests/jax/ops/test_constant.py
+++ b/tests/jax/ops/test_constant.py
@@ -8,10 +8,6 @@
 
 
 @pytest.mark.parametrize("shape", [(32, 32), (1, 1)])
-@pytest.mark.skip(
-    "error: type of return operand 0 doesn't match function result type in "
-    "function @main"
-)
 def test_constant_zeros(shape: tuple):
     def module_constant_zeros():
         return jnp.zeros(shape)
@@ -20,10 +16,6 @@ def module_constant_zeros():
 
 
 @pytest.mark.parametrize("shape", [(32, 32), (1, 1)])
-@pytest.mark.skip(
-    "error: type of return operand 0 doesn't match function result type in "
-    "function @main"
-)
 def test_constant_ones(shape: tuple):
     def module_constant_ones():
         return jnp.ones(shape)
@@ -31,7 +23,7 @@ def module_constant_ones():
     run_op_test(module_constant_ones, [])
 
 
-@pytest.mark.skip("Fails due to: error: failed to legalize operation 'ttir.constant'")
+@pytest.mark.xfail(reason="failed to legalize operation 'ttir.constant'")
 def test_constant_multi_value():
     def module_constant_multi():
         return jnp.array([[1, 2], [3, 4]], dtype=jnp.float32)
diff --git a/tests/jax/ops/test_convolution.py b/tests/jax/ops/test_convolution.py
index 6d348954..9e2143c9 100644
--- a/tests/jax/ops/test_convolution.py
+++ b/tests/jax/ops/test_convolution.py
@@ -78,10 +78,7 @@ def conv1d(img, weights):
             (1, 256, 256, 14, 14, 3, 3, 1, 1, 1),
             (1, 1024, 256, 14, 14, 1, 1, 1, 1, 0),
             (1, 256, 1024, 14, 14, 1, 1, 1, 1, 0),
-            pytest.param(  # TODO This passed in old infra. Investigate.
-                *(1, 2048, 1024, 14, 14, 1, 1, 2, 2, 0),
-                marks=pytest.mark.skip(reason="Segmentation fault"),
-            ),
+            (1, 2048, 1024, 14, 14, 1, 1, 2, 2, 0),
             (1, 512, 1024, 14, 14, 1, 1, 2, 2, 0),
             (1, 512, 512, 7, 7, 3, 3, 1, 1, 1),
             (1, 2048, 512, 7, 7, 1, 1, 1, 1, 0),