From ad94f9352814055730d432c5f630c0e2d1883cdc Mon Sep 17 00:00:00 2001 From: sgligorijevicTT <189116645+sgligorijevicTT@users.noreply.github.com> Date: Wed, 5 Feb 2025 10:08:56 +0000 Subject: [PATCH 1/2] Add BLOOM tests --- tests/jax/models/bloom/__init__.py | 0 tests/jax/models/bloom/bloom_1b1/__init__.py | 0 tests/jax/models/bloom/bloom_1b1/test_1b1.py | 50 +++++++++++++++++++ tests/jax/models/bloom/bloom_1b7/__init__.py | 0 tests/jax/models/bloom/bloom_1b7/test_1b7.py | 50 +++++++++++++++++++ tests/jax/models/bloom/bloom_3b/__init__.py | 0 tests/jax/models/bloom/bloom_3b/test_3b.py | 50 +++++++++++++++++++ tests/jax/models/bloom/bloom_560m/__init__.py | 0 .../jax/models/bloom/bloom_560m/test_560m.py | 50 +++++++++++++++++++ tests/jax/models/bloom/bloom_7b/__init__.py | 0 tests/jax/models/bloom/bloom_7b/test_7b.py | 49 ++++++++++++++++++ tests/jax/models/bloom/tester.py | 41 +++++++++++++++ 12 files changed, 290 insertions(+) create mode 100644 tests/jax/models/bloom/__init__.py create mode 100644 tests/jax/models/bloom/bloom_1b1/__init__.py create mode 100644 tests/jax/models/bloom/bloom_1b1/test_1b1.py create mode 100644 tests/jax/models/bloom/bloom_1b7/__init__.py create mode 100644 tests/jax/models/bloom/bloom_1b7/test_1b7.py create mode 100644 tests/jax/models/bloom/bloom_3b/__init__.py create mode 100644 tests/jax/models/bloom/bloom_3b/test_3b.py create mode 100644 tests/jax/models/bloom/bloom_560m/__init__.py create mode 100644 tests/jax/models/bloom/bloom_560m/test_560m.py create mode 100644 tests/jax/models/bloom/bloom_7b/__init__.py create mode 100644 tests/jax/models/bloom/bloom_7b/test_7b.py create mode 100644 tests/jax/models/bloom/tester.py diff --git a/tests/jax/models/bloom/__init__.py b/tests/jax/models/bloom/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/jax/models/bloom/bloom_1b1/__init__.py b/tests/jax/models/bloom/bloom_1b1/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/jax/models/bloom/bloom_1b1/test_1b1.py b/tests/jax/models/bloom/bloom_1b1/test_1b1.py new file mode 100644 index 00000000..7e416875 --- /dev/null +++ b/tests/jax/models/bloom/bloom_1b1/test_1b1.py @@ -0,0 +1,50 @@ +# SPDX-FileCopyrightText: (c) 2025 Tenstorrent AI ULC +# +# SPDX-License-Identifier: Apache-2.0 + +from typing import Callable + +import pytest +from infra import ModelTester, RunMode +from utils import record_model_test_properties + +from ..tester import BloomTester + +MODEL_PATH = "bigscience/bloom-1b1" +MODEL_NAME = "bloom-1.1b" + + +# ----- Fixtures ----- + + +@pytest.fixture +def inference_tester() -> BloomTester: + return BloomTester(MODEL_PATH) + + +@pytest.fixture +def training_tester() -> BloomTester: + return BloomTester(ModelTester, run_mode=RunMode.TRAINING) + + +# ----- Tests ----- + + +@pytest.mark.skip(reason="Unsupported data type") # segfault +def test_bloom_1b1_inference( + inference_tester: BloomTester, + record_tt_xla_property: Callable, +): + record_model_test_properties(record_tt_xla_property, MODEL_NAME) + + inference_tester.test() + + +@pytest.mark.skip(reason="Support for training not implemented") +def test_bloom_1b1_training( + training_tester: BloomTester, + record_tt_xla_property: Callable, +): + record_model_test_properties(record_tt_xla_property, MODEL_NAME) + + training_tester.test() diff --git a/tests/jax/models/bloom/bloom_1b7/__init__.py b/tests/jax/models/bloom/bloom_1b7/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/jax/models/bloom/bloom_1b7/test_1b7.py b/tests/jax/models/bloom/bloom_1b7/test_1b7.py new file mode 100644 index 00000000..7b36b8b0 --- /dev/null +++ b/tests/jax/models/bloom/bloom_1b7/test_1b7.py @@ -0,0 +1,50 @@ +# SPDX-FileCopyrightText: (c) 2025 Tenstorrent AI ULC +# +# SPDX-License-Identifier: Apache-2.0 + +from typing import Callable + +import pytest +from infra import ModelTester, RunMode +from utils import record_model_test_properties + +from ..tester import BloomTester + +MODEL_PATH = "bigscience/bloom-1b7" +MODEL_NAME = "bloom-1.7b" + + +# ----- Fixtures ----- + + +@pytest.fixture +def inference_tester() -> BloomTester: + return BloomTester(MODEL_PATH) + + +@pytest.fixture +def training_tester() -> BloomTester: + return BloomTester(ModelTester, run_mode=RunMode.TRAINING) + + +# ----- Tests ----- + + +@pytest.mark.skip(reason="Unsupported data type") # segfault +def test_bloom_1b7_inference( + inference_tester: BloomTester, + record_tt_xla_property: Callable, +): + record_model_test_properties(record_tt_xla_property, MODEL_NAME) + + inference_tester.test() + + +@pytest.mark.skip(reason="Support for training not implemented") +def test_bloom_1b7_training( + training_tester: BloomTester, + record_tt_xla_property: Callable, +): + record_model_test_properties(record_tt_xla_property, MODEL_NAME) + + training_tester.test() diff --git a/tests/jax/models/bloom/bloom_3b/__init__.py b/tests/jax/models/bloom/bloom_3b/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/jax/models/bloom/bloom_3b/test_3b.py b/tests/jax/models/bloom/bloom_3b/test_3b.py new file mode 100644 index 00000000..f56ea0cb --- /dev/null +++ b/tests/jax/models/bloom/bloom_3b/test_3b.py @@ -0,0 +1,50 @@ +# SPDX-FileCopyrightText: (c) 2025 Tenstorrent AI ULC +# +# SPDX-License-Identifier: Apache-2.0 + +from typing import Callable + +import pytest +from infra import ModelTester, RunMode +from utils import record_model_test_properties + +from ..tester import BloomTester + +MODEL_PATH = "bigscience/bloom-3b" +MODEL_NAME = "bloom-3b" + + +# ----- Fixtures ----- + + +@pytest.fixture +def inference_tester() -> BloomTester: + return BloomTester(MODEL_PATH) + + +@pytest.fixture +def training_tester() -> BloomTester: + return BloomTester(ModelTester, run_mode=RunMode.TRAINING) + + +# ----- Tests ----- + + +@pytest.mark.skip(reason="Unsupported data type") # segfault +def test_bloom_3b_inference( + inference_tester: BloomTester, + record_tt_xla_property: Callable, +): + record_model_test_properties(record_tt_xla_property, MODEL_NAME) + + inference_tester.test() + + +@pytest.mark.skip(reason="Support for training not implemented") +def test_bloom_3b_training( + training_tester: BloomTester, + record_tt_xla_property: Callable, +): + record_model_test_properties(record_tt_xla_property, MODEL_NAME) + + training_tester.test() diff --git a/tests/jax/models/bloom/bloom_560m/__init__.py b/tests/jax/models/bloom/bloom_560m/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/jax/models/bloom/bloom_560m/test_560m.py b/tests/jax/models/bloom/bloom_560m/test_560m.py new file mode 100644 index 00000000..75ed7bbf --- /dev/null +++ b/tests/jax/models/bloom/bloom_560m/test_560m.py @@ -0,0 +1,50 @@ +# SPDX-FileCopyrightText: (c) 2025 Tenstorrent AI ULC +# +# SPDX-License-Identifier: Apache-2.0 + +from typing import Callable + +import pytest +from infra import ModelTester, RunMode +from utils import record_model_test_properties + +from ..tester import BloomTester + +MODEL_PATH = "bigscience/bloom-560m" +MODEL_NAME = "bloom-560m" + + +# ----- Fixtures ----- + + +@pytest.fixture +def inference_tester() -> BloomTester: + return BloomTester(MODEL_PATH) + + +@pytest.fixture +def training_tester() -> BloomTester: + return BloomTester(ModelTester, run_mode=RunMode.TRAINING) + + +# ----- Tests ----- + + +@pytest.mark.skip(reason="Unsupported data type") # segfault +def test_bloom_560m_inference( + inference_tester: BloomTester, + record_tt_xla_property: Callable, +): + record_model_test_properties(record_tt_xla_property, MODEL_NAME) + + inference_tester.test() + + +@pytest.mark.skip(reason="Support for training not implemented") +def test_bloom_560m_training( + training_tester: BloomTester, + record_tt_xla_property: Callable, +): + record_model_test_properties(record_tt_xla_property, MODEL_NAME) + + training_tester.test() diff --git a/tests/jax/models/bloom/bloom_7b/__init__.py b/tests/jax/models/bloom/bloom_7b/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/jax/models/bloom/bloom_7b/test_7b.py b/tests/jax/models/bloom/bloom_7b/test_7b.py new file mode 100644 index 00000000..dc822500 --- /dev/null +++ b/tests/jax/models/bloom/bloom_7b/test_7b.py @@ -0,0 +1,49 @@ +# SPDX-FileCopyrightText: (c) 2025 Tenstorrent AI ULC +# +# SPDX-License-Identifier: Apache-2.0 + +from typing import Callable + +import pytest +from infra import ModelTester, RunMode +from utils import record_model_test_properties + +from ..tester import BloomTester + +MODEL_PATH = "bigscience/bloom-7b1" +MODEL_NAME = "bloom-7b" + +# ----- Fixtures ----- + + +@pytest.fixture +def inference_tester() -> BloomTester: + return BloomTester(MODEL_PATH) + + +@pytest.fixture +def training_tester() -> BloomTester: + return BloomTester(ModelTester, run_mode=RunMode.TRAINING) + + +# ----- Tests ----- + + +@pytest.mark.skip(reason="Unsupported data type") # segfault +def test_bloom_7b_inference( + inference_tester: BloomTester, + record_tt_xla_property: Callable, +): + record_model_test_properties(record_tt_xla_property, MODEL_NAME) + + inference_tester.test() + + +@pytest.mark.skip(reason="Support for training not implemented") +def test_bloom_7b_training( + training_tester: BloomTester, + record_tt_xla_property: Callable, +): + record_model_test_properties(record_tt_xla_property, MODEL_NAME) + + training_tester.test() diff --git a/tests/jax/models/bloom/tester.py b/tests/jax/models/bloom/tester.py new file mode 100644 index 00000000..fab4507e --- /dev/null +++ b/tests/jax/models/bloom/tester.py @@ -0,0 +1,41 @@ +# SPDX-FileCopyrightText: (c) 2025 Tenstorrent AI ULC +# +# SPDX-License-Identifier: Apache-2.0 + +from typing import Dict, Sequence + +import jax +from flax import linen as nn +from infra import ComparisonConfig, ModelTester, RunMode +from transformers import AutoTokenizer, FlaxBloomForCausalLM + + +class BloomTester(ModelTester): + """Tester for Bloom models.""" + + def __init__( + self, + model_name: str, + comparison_config: ComparisonConfig = ComparisonConfig(), + run_mode: RunMode = RunMode.INFERENCE, + ) -> None: + self._model_name = model_name + super().__init__(comparison_config, run_mode) + + # @override + def _get_model(self) -> nn.Module: + return FlaxBloomForCausalLM.from_pretrained(self._model_name, from_pt=True) + + # @override + def _get_input_activations(self) -> Sequence[jax.Array]: + tokenizer = AutoTokenizer.from_pretrained(self._model_name) + inputs = tokenizer("Hello", return_tensors="np") + return inputs["input_ids"] + + # @override + def _get_forward_method_kwargs(self) -> Dict[str, jax.Array]: + assert hasattr(self._model, "params") + return { + "params": self._model.params, + "input_ids": self._get_input_activations(), + } From 13da4739333644329454a6d99212a5941ba801a7 Mon Sep 17 00:00:00 2001 From: sgligorijevicTT <189116645+sgligorijevicTT@users.noreply.github.com> Date: Wed, 5 Feb 2025 13:24:14 +0000 Subject: [PATCH 2/2] Update skip reason --- tests/jax/models/bloom/bloom_1b1/test_1b1.py | 9 ++++++--- tests/jax/models/bloom/bloom_1b7/test_1b7.py | 4 ++-- tests/jax/models/bloom/bloom_3b/test_3b.py | 4 ++-- tests/jax/models/bloom/bloom_560m/test_560m.py | 4 ++-- tests/jax/models/bloom/bloom_7b/test_7b.py | 4 ++-- 5 files changed, 14 insertions(+), 11 deletions(-) diff --git a/tests/jax/models/bloom/bloom_1b1/test_1b1.py b/tests/jax/models/bloom/bloom_1b1/test_1b1.py index 7e416875..bae34635 100644 --- a/tests/jax/models/bloom/bloom_1b1/test_1b1.py +++ b/tests/jax/models/bloom/bloom_1b1/test_1b1.py @@ -6,7 +6,7 @@ import pytest from infra import ModelTester, RunMode -from utils import record_model_test_properties +from utils import compile_fail, record_model_test_properties from ..tester import BloomTester @@ -29,8 +29,11 @@ def training_tester() -> BloomTester: # ----- Tests ----- - -@pytest.mark.skip(reason="Unsupported data type") # segfault +# This is an interesting one. +# The error message seems to happen before the compile even begins +# And then then compile segfaults with no useful information +# It is highly likely that both are caused by the same root cause +@pytest.mark.skip(reason=compile_fail("Unsupported data type")) # segfault def test_bloom_1b1_inference( inference_tester: BloomTester, record_tt_xla_property: Callable, diff --git a/tests/jax/models/bloom/bloom_1b7/test_1b7.py b/tests/jax/models/bloom/bloom_1b7/test_1b7.py index 7b36b8b0..56bee84a 100644 --- a/tests/jax/models/bloom/bloom_1b7/test_1b7.py +++ b/tests/jax/models/bloom/bloom_1b7/test_1b7.py @@ -6,7 +6,7 @@ import pytest from infra import ModelTester, RunMode -from utils import record_model_test_properties +from utils import compile_fail, record_model_test_properties from ..tester import BloomTester @@ -30,7 +30,7 @@ def training_tester() -> BloomTester: # ----- Tests ----- -@pytest.mark.skip(reason="Unsupported data type") # segfault +@pytest.mark.skip(reason=compile_fail("Unsupported data type")) # segfault def test_bloom_1b7_inference( inference_tester: BloomTester, record_tt_xla_property: Callable, diff --git a/tests/jax/models/bloom/bloom_3b/test_3b.py b/tests/jax/models/bloom/bloom_3b/test_3b.py index f56ea0cb..808eafcd 100644 --- a/tests/jax/models/bloom/bloom_3b/test_3b.py +++ b/tests/jax/models/bloom/bloom_3b/test_3b.py @@ -6,7 +6,7 @@ import pytest from infra import ModelTester, RunMode -from utils import record_model_test_properties +from utils import compile_fail, record_model_test_properties from ..tester import BloomTester @@ -30,7 +30,7 @@ def training_tester() -> BloomTester: # ----- Tests ----- -@pytest.mark.skip(reason="Unsupported data type") # segfault +@pytest.mark.skip(reason=compile_fail("Unsupported data type")) # segfault def test_bloom_3b_inference( inference_tester: BloomTester, record_tt_xla_property: Callable, diff --git a/tests/jax/models/bloom/bloom_560m/test_560m.py b/tests/jax/models/bloom/bloom_560m/test_560m.py index 75ed7bbf..23100de8 100644 --- a/tests/jax/models/bloom/bloom_560m/test_560m.py +++ b/tests/jax/models/bloom/bloom_560m/test_560m.py @@ -6,7 +6,7 @@ import pytest from infra import ModelTester, RunMode -from utils import record_model_test_properties +from utils import compile_fail, record_model_test_properties from ..tester import BloomTester @@ -30,7 +30,7 @@ def training_tester() -> BloomTester: # ----- Tests ----- -@pytest.mark.skip(reason="Unsupported data type") # segfault +@pytest.mark.skip(reason=compile_fail("Unsupported data type")) # segfault def test_bloom_560m_inference( inference_tester: BloomTester, record_tt_xla_property: Callable, diff --git a/tests/jax/models/bloom/bloom_7b/test_7b.py b/tests/jax/models/bloom/bloom_7b/test_7b.py index dc822500..aff6420f 100644 --- a/tests/jax/models/bloom/bloom_7b/test_7b.py +++ b/tests/jax/models/bloom/bloom_7b/test_7b.py @@ -6,7 +6,7 @@ import pytest from infra import ModelTester, RunMode -from utils import record_model_test_properties +from utils import compile_fail, record_model_test_properties from ..tester import BloomTester @@ -29,7 +29,7 @@ def training_tester() -> BloomTester: # ----- Tests ----- -@pytest.mark.skip(reason="Unsupported data type") # segfault +@pytest.mark.skip(reason=compile_fail("Unsupported data type")) # segfault def test_bloom_7b_inference( inference_tester: BloomTester, record_tt_xla_property: Callable,