From ec63a694d4793cef239ec326af963884233bb9c4 Mon Sep 17 00:00:00 2001
From: Kristijan Mitrovic
Date: Thu, 23 Jan 2025 13:02:27 +0000
Subject: [PATCH] Marked all model tests. Added some more op tests.

---
 tests/__init__.py                             |  0
 tests/conftest.py                             |  9 ++++-
 .../models/albert/v2/base/test_albert_base.py | 11 +++++-
 .../albert/v2/large/test_albert_large.py      | 10 +++++
 .../albert/v2/xlarge/test_albert_xlarge.py    | 10 +++++
 .../albert/v2/xxlarge/test_albert_xxlarge.py  | 10 +++++
 .../jax/models/distilbert/test_distilbert.py  | 10 ++++-
 tests/jax/models/gpt2/test_gpt2.py            | 10 ++++-
 .../openllama_3b_v2/test_openllama_3b_v2.py   | 10 +++++
 tests/jax/models/mlpmixer/test_mlpmixer.py    | 17 ++++++--
 tests/jax/models/mnist/cnn/__init__.py        |  0
 tests/jax/models/mnist/cnn/test_mnist_cnn.py  | 11 +++++-
 tests/jax/models/roberta/test_roberta.py      | 10 ++++-
 .../models/squeezebert/test_squeezebert.py    | 10 ++++-
 tests/jax/ops/test_abs.py                     |  8 ++--
 tests/jax/ops/test_add.py                     |  8 ++--
 tests/jax/ops/test_broadcast_in_dim.py        |  8 ++--
 tests/jax/ops/test_cbrt.py                    |  9 ++++-
 tests/jax/ops/test_compare.py                 |  1 +
 tests/jax/ops/test_concatenate.py             | 11 +++++-
 tests/jax/ops/test_constant.py                | 10 +++--
 tests/jax/ops/test_convert.py                 | 11 +++++-
 tests/jax/ops/test_convolution.py             | 23 ++++++++++-
 tests/utils.py                                | 39 +++++++++++++++++++
 24 files changed, 224 insertions(+), 32 deletions(-)
 create mode 100644 tests/__init__.py
 create mode 100644 tests/jax/models/mnist/cnn/__init__.py
 create mode 100644 tests/utils.py

diff --git a/tests/__init__.py b/tests/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tests/conftest.py b/tests/conftest.py
index 922602b5..5adac87d 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -16,6 +16,8 @@ class RecordProperties(Enum):
     START_TIMESTAMP = "start_timestamp"
     # Timestamp of test end.
     END_TIMESTAMP = "end_timestamp"
+    # Test type: op test, graph test, model test.
+    TEST_TYPE = "test_type"
     # Frontend or framework used to run the test.
     FRONTEND = "frontend"
     # Kind of operation. e.g. eltwise.
@@ -68,8 +70,11 @@ def test_model(fixture1, fixture2, ..., record_tt_xla_property):
 @pytest.fixture(scope="function", autouse=True)
 def record_tt_xla_property(record_property: Callable):
     """
-    Autouse fixture that automatically records a property named 'frontend' with the
-    value 'tt-forge-fe' for each test function.
+    Autouse fixture that automatically records some test properties for each test
+    function.
+
+    It also yields back a callable which can be used explicitly in tests to record
+    additional properties.

     Example:

diff --git a/tests/jax/models/albert/v2/base/test_albert_base.py b/tests/jax/models/albert/v2/base/test_albert_base.py
index 8c9388ce..91a1edb4 100644
--- a/tests/jax/models/albert/v2/base/test_albert_base.py
+++ b/tests/jax/models/albert/v2/base/test_albert_base.py
@@ -2,13 +2,16 @@
 #
 # SPDX-License-Identifier: Apache-2.0

+from typing import Callable
+
 import pytest
 from infra import RunMode
+from utils import record_model_test_properties

 from ..tester import AlbertV2Tester

-
 MODEL_PATH = "albert/albert-base-v2"
+MODEL_NAME = MODEL_PATH.split("/")[1]


 # ----- Fixtures -----
@@ -30,12 +33,18 @@ def training_tester() -> AlbertV2Tester:
 @pytest.mark.xfail(reason="failed to legalize operation 'stablehlo.reduce'")
 def test_flax_albert_v2_base_inference(
     inference_tester: AlbertV2Tester,
+    record_tt_xla_property: Callable,
 ):
+    record_model_test_properties(record_tt_xla_property, MODEL_NAME)
+
     inference_tester.test()


 @pytest.mark.skip(reason="Support for training not implemented")
 def test_flax_albert_v2_base_training(
     training_tester: AlbertV2Tester,
+    record_tt_xla_property: Callable,
 ):
+    record_model_test_properties(record_tt_xla_property, MODEL_NAME)
+
     training_tester.test()
diff --git a/tests/jax/models/albert/v2/large/test_albert_large.py b/tests/jax/models/albert/v2/large/test_albert_large.py
index f47b8260..039cc3d8 100644
--- a/tests/jax/models/albert/v2/large/test_albert_large.py
+++ b/tests/jax/models/albert/v2/large/test_albert_large.py
@@ -2,12 +2,16 @@
 #
 # SPDX-License-Identifier: Apache-2.0

+from typing import Callable
+
 import pytest
 from infra import RunMode
+from utils import record_model_test_properties

 from ..tester import AlbertV2Tester

 MODEL_PATH = "albert/albert-large-v2"
+MODEL_NAME = MODEL_PATH.split("/")[1]


 # ----- Fixtures -----
@@ -29,12 +33,18 @@ def training_tester() -> AlbertV2Tester:
 @pytest.mark.xfail(reason="failed to legalize operation 'stablehlo.reduce'")
 def test_flax_albert_v2_large_inference(
     inference_tester: AlbertV2Tester,
+    record_tt_xla_property: Callable,
 ):
+    record_model_test_properties(record_tt_xla_property, MODEL_NAME)
+
     inference_tester.test()


 @pytest.mark.skip(reason="Support for training not implemented")
 def test_flax_albert_v2_large_training(
     training_tester: AlbertV2Tester,
+    record_tt_xla_property: Callable,
 ):
+    record_model_test_properties(record_tt_xla_property, MODEL_NAME)
+
     training_tester.test()
diff --git a/tests/jax/models/albert/v2/xlarge/test_albert_xlarge.py b/tests/jax/models/albert/v2/xlarge/test_albert_xlarge.py
index 3132af8c..613df01a 100644
--- a/tests/jax/models/albert/v2/xlarge/test_albert_xlarge.py
+++ b/tests/jax/models/albert/v2/xlarge/test_albert_xlarge.py
@@ -2,12 +2,16 @@
 #
 # SPDX-License-Identifier: Apache-2.0

+from typing import Callable
+
 import pytest
 from infra import RunMode
+from utils import record_model_test_properties

 from ..tester import AlbertV2Tester

 MODEL_PATH = "albert/albert-xlarge-v2"
+MODEL_NAME = MODEL_PATH.split("/")[1]


 # ----- Fixtures -----
@@ -29,12 +33,18 @@ def training_tester() -> AlbertV2Tester:
 @pytest.mark.xfail(reason="failed to legalize operation 'stablehlo.reduce'")
 def test_flax_albert_v2_xlarge_inference(
     inference_tester: AlbertV2Tester,
+    record_tt_xla_property: Callable,
 ):
+    record_model_test_properties(record_tt_xla_property, MODEL_NAME)
+
     inference_tester.test()


 @pytest.mark.skip(reason="Support for training not implemented")
 def test_flax_albert_v2_xlarge_training(
     training_tester: AlbertV2Tester,
+    record_tt_xla_property: Callable,
 ):
+    record_model_test_properties(record_tt_xla_property, MODEL_NAME)
+
     training_tester.test()
diff --git a/tests/jax/models/albert/v2/xxlarge/test_albert_xxlarge.py b/tests/jax/models/albert/v2/xxlarge/test_albert_xxlarge.py
index d695a77f..ed4a9d2e 100644
--- a/tests/jax/models/albert/v2/xxlarge/test_albert_xxlarge.py
+++ b/tests/jax/models/albert/v2/xxlarge/test_albert_xxlarge.py
@@ -2,12 +2,16 @@
 #
 # SPDX-License-Identifier: Apache-2.0

+from typing import Callable
+
 import pytest
 from infra import RunMode
+from utils import record_model_test_properties

 from ..tester import AlbertV2Tester

 MODEL_PATH = "albert/albert-xxlarge-v2"
+MODEL_NAME = MODEL_PATH.split("/")[1]


 # ----- Fixtures -----
@@ -29,12 +33,18 @@ def training_tester() -> AlbertV2Tester:
 @pytest.mark.xfail(reason="failed to legalize operation 'stablehlo.reduce'")
 def test_flax_albert_v2_xxlarge_inference(
     inference_tester: AlbertV2Tester,
+    record_tt_xla_property: Callable,
 ):
+    record_model_test_properties(record_tt_xla_property, MODEL_NAME)
+
     inference_tester.test()


 @pytest.mark.skip(reason="Support for training not implemented")
 def test_flax_albert_v2_xxlarge_training(
     training_tester: AlbertV2Tester,
+    record_tt_xla_property: Callable,
 ):
+    record_model_test_properties(record_tt_xla_property, MODEL_NAME)
+
     training_tester.test()
diff --git a/tests/jax/models/distilbert/test_distilbert.py b/tests/jax/models/distilbert/test_distilbert.py
index bf9202d8..977b7f14 100644
--- a/tests/jax/models/distilbert/test_distilbert.py
+++ b/tests/jax/models/distilbert/test_distilbert.py
@@ -2,15 +2,17 @@
 #
 # SPDX-License-Identifier: Apache-2.0

-from typing import Dict, Sequence
+from typing import Callable, Dict, Sequence

 import jax
 import pytest
 from flax import linen as nn
 from infra import ModelTester, RunMode
 from transformers import AutoTokenizer, FlaxDistilBertForMaskedLM
+from utils import record_model_test_properties

 MODEL_PATH = "distilbert/distilbert-base-uncased"
+MODEL_NAME = MODEL_PATH.split("/")[1]


 # ----- Tester -----
@@ -56,12 +58,18 @@ def training_tester() -> FlaxDistilBertForMaskedLMTester:
 @pytest.mark.xfail(reason="failed to legalize operation 'stablehlo.reduce'")
 def test_flax_distilbert_inference(
     inference_tester: FlaxDistilBertForMaskedLMTester,
+    record_tt_xla_property: Callable,
 ):
+    record_model_test_properties(record_tt_xla_property, MODEL_NAME)
+
     inference_tester.test()


 @pytest.mark.skip(reason="Support for training not implemented")
 def test_flax_distilbert_training(
     training_tester: FlaxDistilBertForMaskedLMTester,
+    record_tt_xla_property: Callable,
 ):
+    record_model_test_properties(record_tt_xla_property, MODEL_NAME)
+
     training_tester.test()
diff --git a/tests/jax/models/gpt2/test_gpt2.py b/tests/jax/models/gpt2/test_gpt2.py
index 40d6c2bb..58c6b9c1 100644
--- a/tests/jax/models/gpt2/test_gpt2.py
+++ b/tests/jax/models/gpt2/test_gpt2.py
@@ -2,15 +2,17 @@
 #
 # SPDX-License-Identifier: Apache-2.0

-from typing import Sequence, Dict
+from typing import Callable, Dict, Sequence

 import jax
 import pytest
 from flax import linen as nn
 from infra import ModelTester, RunMode
 from transformers import AutoTokenizer, FlaxGPT2LMHeadModel
+from utils import record_model_test_properties

 MODEL_PATH = "openai-community/gpt2"
+MODEL_NAME = MODEL_PATH.split("/")[1]


 class GPT2Tester(ModelTester):
@@ -54,12 +56,18 @@ def training_tester() -> GPT2Tester:
 @pytest.mark.xfail(reason="failed to legalize operation 'stablehlo.reduce'")
 def test_gp2_inference(
     inference_tester: GPT2Tester,
+    record_tt_xla_property: Callable,
 ):
+    record_model_test_properties(record_tt_xla_property, MODEL_NAME)
+
     inference_tester.test()


 @pytest.mark.skip(reason="Support for training not implemented")
 def test_gpt2_training(
     training_tester: GPT2Tester,
+    record_tt_xla_property: Callable,
 ):
+    record_model_test_properties(record_tt_xla_property, MODEL_NAME)
+
     training_tester.test()
diff --git a/tests/jax/models/llama/openllama_3b_v2/test_openllama_3b_v2.py b/tests/jax/models/llama/openllama_3b_v2/test_openllama_3b_v2.py
index 30d01081..68d2f3e3 100644
--- a/tests/jax/models/llama/openllama_3b_v2/test_openllama_3b_v2.py
+++ b/tests/jax/models/llama/openllama_3b_v2/test_openllama_3b_v2.py
@@ -2,12 +2,16 @@
 #
 # SPDX-License-Identifier: Apache-2.0

+from typing import Callable
+
 import pytest
 from infra import RunMode
+from utils import record_model_test_properties

 from ..tester import LLamaTester

 MODEL_PATH = "openlm-research/open_llama_3b_v2"
+MODEL_NAME = MODEL_PATH.split("/")[1]


 # ----- Fixtures -----
@@ -29,12 +33,18 @@ def training_tester() -> LLamaTester:
 @pytest.mark.xfail(reason="failed to legalize operation 'stablehlo.reduce'")
 def test_openllama3b_inference(
     inference_tester: LLamaTester,
+    record_tt_xla_property: Callable,
 ):
+    record_model_test_properties(record_tt_xla_property, MODEL_NAME)
+
     inference_tester.test()


 @pytest.mark.skip(reason="Support for training not implemented")
 def test_openllama3b_training(
     training_tester: LLamaTester,
+    record_tt_xla_property: Callable,
 ):
+    record_model_test_properties(record_tt_xla_property, MODEL_NAME)
+
     training_tester.test()
diff --git a/tests/jax/models/mlpmixer/test_mlpmixer.py b/tests/jax/models/mlpmixer/test_mlpmixer.py
index fd8c1fa9..9d7d624f 100644
--- a/tests/jax/models/mlpmixer/test_mlpmixer.py
+++ b/tests/jax/models/mlpmixer/test_mlpmixer.py
@@ -2,7 +2,7 @@
 #
 # SPDX-License-Identifier: Apache-2.0

-from typing import Any, Dict, Sequence
+from typing import Any, Callable, Dict, Sequence

 import flax.traverse_util
 import fsspec
@@ -12,6 +12,7 @@
 import pytest
 from flax import linen as nn
 from infra import ModelTester, RunMode
+from utils import record_model_test_properties

 from .model_implementation import MlpMixer

@@ -95,10 +96,20 @@ def training_tester() -> MlpMixerTester:
 @pytest.mark.skip(
     reason="error: failed to legalize operation 'ttir.convolution' that was explicitly marked illegal"
 )
-def test_mlpmixer(inference_tester: MlpMixerTester):
+def test_mlpmixer(
+    inference_tester: MlpMixerTester,
+    record_tt_xla_property: Callable,
+):
+    record_model_test_properties(record_tt_xla_property, MlpMixer.__qualname__)
+
     inference_tester.test()


 @pytest.mark.skip(reason="Support for training not implemented")
-def test_mlpmixer_training(training_tester: MlpMixerTester):
+def test_mlpmixer_training(
+    training_tester: MlpMixerTester,
+    record_tt_xla_property: Callable,
+):
+    record_model_test_properties(record_tt_xla_property, MlpMixer.__qualname__)
+
     training_tester.test()
diff --git a/tests/jax/models/mnist/cnn/__init__.py b/tests/jax/models/mnist/cnn/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tests/jax/models/mnist/cnn/test_mnist_cnn.py b/tests/jax/models/mnist/cnn/test_mnist_cnn.py
index 88004ec1..a8cfa5af 100644
--- a/tests/jax/models/mnist/cnn/test_mnist_cnn.py
+++ b/tests/jax/models/mnist/cnn/test_mnist_cnn.py
@@ -2,15 +2,16 @@
 #
 # SPDX-License-Identifier: Apache-2.0

-from typing import Dict, Sequence
+from typing import Callable, Dict, Sequence

 import jax
 import jax.numpy as jnp
 import pytest
 from flax import linen as nn
 from infra import ModelTester, RunMode
+from utils import record_model_test_properties

-from tests.jax.models.mnist.cnn.model_implementation import MNISTCNNModel
+from .model_implementation import MNISTCNNModel


 class MNISTCNNTester(ModelTester):
@@ -70,14 +71,20 @@ def training_tester() -> MNISTCNNTester:
 )  # This is a segfault, marking it as xfail would bring down the whole test suite
 def test_mnist_inference(
     inference_tester: MNISTCNNTester,
+    record_tt_xla_property: Callable,
 ):
+    record_model_test_properties(record_tt_xla_property, MNISTCNNModel.__qualname__)
+
     inference_tester.test()


 @pytest.mark.skip(reason="Support for training not implemented")
 def test_mnist_training(
     training_tester: MNISTCNNTester,
+    record_tt_xla_property: Callable,
 ):
+    record_model_test_properties(record_tt_xla_property, MNISTCNNModel.__qualname__)
+
     training_tester.test()
diff --git a/tests/jax/models/roberta/test_roberta.py b/tests/jax/models/roberta/test_roberta.py
index 32542cd5..61797242 100644
--- a/tests/jax/models/roberta/test_roberta.py
+++ b/tests/jax/models/roberta/test_roberta.py
@@ -2,15 +2,17 @@
 #
 # SPDX-License-Identifier: Apache-2.0

-from typing import Dict, Sequence
+from typing import Callable, Dict, Sequence

 import jax
 import pytest
 from flax import linen as nn
 from infra import ModelTester, RunMode
 from transformers import AutoTokenizer, FlaxRobertaForMaskedLM
+from utils import record_model_test_properties

 MODEL_PATH = "FacebookAI/roberta-base"
+MODEL_NAME = MODEL_PATH.split("/")[1]


 # ----- Tester -----
@@ -60,12 +62,18 @@ def training_tester() -> FlaxRobertaForMaskedLMTester:
 @pytest.mark.xfail(reason="failed to legalize operation 'stablehlo.reduce_window'")
 def test_roberta_inference(
     inference_tester: FlaxRobertaForMaskedLMTester,
+    record_tt_xla_property: Callable,
 ):
+    record_model_test_properties(record_tt_xla_property, MODEL_NAME)
+
     inference_tester.test()


 @pytest.mark.skip(reason="Support for training not implemented")
 def test_flax_roberta_training(
     training_tester: FlaxRobertaForMaskedLMTester,
+    record_tt_xla_property: Callable,
 ):
+    record_model_test_properties(record_tt_xla_property, MODEL_NAME)
+
     training_tester.test()
diff --git a/tests/jax/models/squeezebert/test_squeezebert.py b/tests/jax/models/squeezebert/test_squeezebert.py
index fe98f0cc..89a00112 100644
--- a/tests/jax/models/squeezebert/test_squeezebert.py
+++ b/tests/jax/models/squeezebert/test_squeezebert.py
@@ -2,7 +2,7 @@
 #
 # SPDX-License-Identifier: Apache-2.0

-from typing import Dict, Sequence
+from typing import Callable, Dict, Sequence

 import jax
 import pytest
@@ -12,8 +12,10 @@
 from infra import ModelTester, RunMode
 from model_implementation import SqueezeBertConfig, SqueezeBertForMaskedLM
 from transformers import AutoTokenizer
+from utils import record_model_test_properties

 MODEL_PATH = "squeezebert/squeezebert-uncased"
+MODEL_NAME = MODEL_PATH.split("/")[1]


 # ----- Tester -----
@@ -75,12 +77,18 @@ def training_tester() -> SqueezeBertTester:
 @pytest.mark.skip(reason="failed to legalize operation 'stablehlo.dot_general'")
 def test_flax_distilbert_inference(
     inference_tester: SqueezeBertTester,
+    record_tt_xla_property: Callable,
 ):
+    record_model_test_properties(record_tt_xla_property, MODEL_NAME)
+
     inference_tester.test()


 @pytest.mark.skip(reason="Support for training not implemented")
 def test_flax_distilbert_training(
     training_tester: SqueezeBertTester,
+    record_tt_xla_property: Callable,
 ):
+    record_model_test_properties(record_tt_xla_property, MODEL_NAME)
+
     training_tester.test()
diff --git a/tests/jax/ops/test_abs.py b/tests/jax/ops/test_abs.py
index 5af7cc47..c0fd97f5 100644
--- a/tests/jax/ops/test_abs.py
+++ b/tests/jax/ops/test_abs.py
@@ -7,8 +7,8 @@
 import jax
 import jax.numpy as jnp
 import pytest
-from conftest import RecordProperties
 from infra import run_op_test_with_random_inputs
+from utils import record_unary_op_test_properties


 @pytest.mark.parametrize("x_shape", [(32, 32), (64, 64)])
@@ -16,9 +16,9 @@ def test_abs(x_shape: tuple, record_tt_xla_property: Callable):
     def abs(x: jax.Array) -> jax.Array:
         return jnp.abs(x)

-    record_tt_xla_property(RecordProperties.OP_KIND.value, "Eltwise unary")
-    record_tt_xla_property(RecordProperties.FRAMEWORK_OP_NAME.value, "jax.numpy.abs")
-    record_tt_xla_property(RecordProperties.OP_NAME.value, "stablehlo.abs")
+    record_unary_op_test_properties(
+        record_tt_xla_property, "jax.numpy.abs", "stablehlo.abs"
+    )

     # Test both negative and positive values.
     run_op_test_with_random_inputs(abs, [x_shape], minval=-5.0, maxval=5.0)
diff --git a/tests/jax/ops/test_add.py b/tests/jax/ops/test_add.py
index bc9de970..17f6d4f0 100644
--- a/tests/jax/ops/test_add.py
+++ b/tests/jax/ops/test_add.py
@@ -7,8 +7,8 @@
 import jax
 import jax.numpy as jnp
 import pytest
-from conftest import RecordProperties
 from infra import run_op_test_with_random_inputs
+from utils import record_binary_op_test_properties


 @pytest.mark.parametrize(
@@ -22,8 +22,8 @@ def test_add(x_shape: tuple, y_shape: tuple, record_tt_xla_property: Callable):
     def add(x: jax.Array, y: jax.Array) -> jax.Array:
         return jnp.add(x, y)

-    record_tt_xla_property(RecordProperties.OP_KIND.value, "Eltwise unary")
-    record_tt_xla_property(RecordProperties.FRAMEWORK_OP_NAME.value, "jax.numpy.add")
-    record_tt_xla_property(RecordProperties.OP_NAME.value, "stablehlo.add")
+    record_binary_op_test_properties(
+        record_tt_xla_property, "jax.numpy.add", "stablehlo.add"
+    )

     run_op_test_with_random_inputs(add, [x_shape, y_shape])
diff --git a/tests/jax/ops/test_broadcast_in_dim.py b/tests/jax/ops/test_broadcast_in_dim.py
index 59c8964c..e06353d7 100644
--- a/tests/jax/ops/test_broadcast_in_dim.py
+++ b/tests/jax/ops/test_broadcast_in_dim.py
@@ -6,8 +6,8 @@

 import jax.numpy as jnp
 import pytest
-from conftest import RecordProperties
 from infra import run_op_test_with_random_inputs
+from utils import record_unary_op_test_properties


 @pytest.mark.parametrize("input_shapes", [[(2, 1)]])
@@ -18,10 +18,8 @@ def test_broadcast_in_dim(input_shapes: tuple, record_tt_xla_property: Callable)
     def broadcast(a):
         return jnp.broadcast_to(a, (2, 4))

-    record_tt_xla_property(RecordProperties.OP_KIND.value, "Eltwise unary")
-    record_tt_xla_property(
-        RecordProperties.FRAMEWORK_OP_NAME.value, "jax.numpy.broadcast_to"
+    record_unary_op_test_properties(
+        record_tt_xla_property, "jax.numpy.broadcast_to", "stablehlo.broadcast_in_dim"
     )
-    record_tt_xla_property(RecordProperties.OP_NAME.value, "stablehlo.broadcast_in_dim")

     run_op_test_with_random_inputs(broadcast, input_shapes)
diff --git a/tests/jax/ops/test_cbrt.py b/tests/jax/ops/test_cbrt.py
index b690034c..8361a86a 100644
--- a/tests/jax/ops/test_cbrt.py
+++ b/tests/jax/ops/test_cbrt.py
@@ -2,15 +2,22 @@
 #
 # SPDX-License-Identifier: Apache-2.0

+from typing import Callable
+
 import jax
 import jax.numpy as jnp
 import pytest
 from infra import run_op_test_with_random_inputs
+from utils import record_unary_op_test_properties


 @pytest.mark.parametrize("x_shape", [(32, 32), (64, 64)])
-def test_cbrt(x_shape: tuple):
+def test_cbrt(x_shape: tuple, record_tt_xla_property: Callable):
     def cbrt(x: jax.Array) -> jax.Array:
         return jnp.cbrt(x)

+    record_unary_op_test_properties(
+        record_tt_xla_property, "jax.numpy.cbrt", "stablehlo.cbrt"
+    )
+
     run_op_test_with_random_inputs(cbrt, [x_shape])
diff --git a/tests/jax/ops/test_compare.py b/tests/jax/ops/test_compare.py
index 9d0d63e4..f1948876 100644
--- a/tests/jax/ops/test_compare.py
+++ b/tests/jax/ops/test_compare.py
@@ -73,6 +73,7 @@ def less_or_equal(x: jax.Array, y: jax.Array) -> jax.Array:
     ],
 )
 def test_compare(x_shape: tuple, y_shape: tuple):
+    # TODO: record test properties once this test is split per op.
     run_op_test_with_random_inputs(equal, [x_shape, y_shape])
     run_op_test_with_random_inputs(not_equal, [x_shape, y_shape])
     run_op_test_with_random_inputs(greater, [x_shape, y_shape])
diff --git a/tests/jax/ops/test_concatenate.py b/tests/jax/ops/test_concatenate.py
index bf377677..42e92a6f 100644
--- a/tests/jax/ops/test_concatenate.py
+++ b/tests/jax/ops/test_concatenate.py
@@ -2,10 +2,13 @@
 #
 # SPDX-License-Identifier: Apache-2.0

+from typing import Callable
+
 import jax
 import jax.numpy as jnp
 import pytest
 from infra import run_op_test_with_random_inputs
+from utils import record_binary_op_test_properties


 @pytest.mark.parametrize(
@@ -17,8 +20,14 @@
         [(64, 64, 64, 64), (64, 64, 64, 64), 3],
     ],
 )
-def test_concatenate(x_shape: tuple, y_shape: tuple, axis: int):
+def test_concatenate(
+    x_shape: tuple, y_shape: tuple, axis: int, record_tt_xla_property: Callable
+):
     def concat(x: jax.Array, y: jax.Array) -> jax.Array:
         return jnp.concatenate([x, y], axis=axis)

+    record_binary_op_test_properties(
+        record_tt_xla_property, "jax.numpy.concatenate", "stablehlo.concatenate"
+    )
+
     run_op_test_with_random_inputs(concat, [x_shape, y_shape])
diff --git a/tests/jax/ops/test_constant.py b/tests/jax/ops/test_constant.py
index 3587a4d5..6d552995 100644
--- a/tests/jax/ops/test_constant.py
+++ b/tests/jax/ops/test_constant.py
@@ -2,13 +2,17 @@
 #
 # SPDX-License-Identifier: Apache-2.0

+from typing import Callable
+
 import jax.numpy as jnp
 import pytest
 from infra import run_op_test

+# TODO: record test properties.
+

 @pytest.mark.parametrize("shape", [(32, 32), (1, 1)])
-def test_constant_zeros(shape: tuple):
+def test_constant_zeros(shape: tuple, record_tt_xla_property: Callable):
     def module_constant_zeros():
         return jnp.zeros(shape)

@@ -16,7 +20,7 @@ def module_constant_zeros():


 @pytest.mark.parametrize("shape", [(32, 32), (1, 1)])
-def test_constant_ones(shape: tuple):
+def test_constant_ones(shape: tuple, record_tt_xla_property: Callable):
     def module_constant_ones():
         return jnp.ones(shape)

@@ -24,7 +28,7 @@


 @pytest.mark.xfail(reason="failed to legalize operation 'ttir.constant'")
-def test_constant_multi_value():
+def test_constant_multi_value(record_tt_xla_property: Callable):
     def module_constant_multi():
         return jnp.array([[1, 2], [3, 4]], dtype=jnp.float32)
diff --git a/tests/jax/ops/test_convert.py b/tests/jax/ops/test_convert.py
index 252e62de..69746ef5 100644
--- a/tests/jax/ops/test_convert.py
+++ b/tests/jax/ops/test_convert.py
@@ -2,12 +2,15 @@
 #
 # SPDX-License-Identifier: Apache-2.0

+from typing import Callable
+
 import jax
 import jax.lax as jlx
 import jax.numpy as jnp
 import pytest
 from infra import random_tensor, run_op_test
 from jax._src.typing import DTypeLike
+from utils import record_unary_op_test_properties


 # TODO we need to parametrize with all supported dtypes.
@@ -30,10 +33,16 @@
         "float64",
     ],
 )
-def test_convert(from_dtype: DTypeLike, to_dtype: DTypeLike):
+def test_convert(
+    from_dtype: DTypeLike, to_dtype: DTypeLike, record_tt_xla_property: Callable
+):
     def convert(x: jax.Array) -> jax.Array:
         return jlx.convert_element_type(x, new_dtype=jnp.dtype(to_dtype))

+    record_unary_op_test_properties(
+        record_tt_xla_property, "jax.lax.convert_element_type", "stablehlo.convert"
+    )
+
     x_shape = (32, 32)  # Shape does not make any impact here, thus not parametrized.
     input = random_tensor(x_shape, dtype=from_dtype)

diff --git a/tests/jax/ops/test_convolution.py b/tests/jax/ops/test_convolution.py
index 9e2143c9..3d60d3d6 100644
--- a/tests/jax/ops/test_convolution.py
+++ b/tests/jax/ops/test_convolution.py
@@ -2,9 +2,12 @@
 #
 # SPDX-License-Identifier: Apache-2.0

+from typing import Callable
+
 import jax
 import pytest
 from infra import ComparisonConfig, random_tensor, run_op_test
+from utils import record_op_test_properties


 # TODO investigate why conv has such poor precision.
@@ -27,7 +30,10 @@ def comparison_config() -> ComparisonConfig:
     ],
 )
 def test_conv1d(
-    img_shape: tuple, kernel_shape: tuple, comparison_config: ComparisonConfig
+    img_shape: tuple,
+    kernel_shape: tuple,
+    comparison_config: ComparisonConfig,
+    record_tt_xla_property: Callable,
 ):
     def conv1d(img, weights):
         return jax.lax.conv_general_dilated(
@@ -42,6 +48,13 @@ def conv1d(img, weights):
             batch_group_count=1,
         )

+    record_op_test_properties(
+        record_tt_xla_property,
+        "Convolution",
+        "jax.lax.conv_general_dilated",
+        "stablehlo.convolution",
+    )
+
     img = random_tensor(img_shape, dtype="bfloat16")
     kernel = random_tensor(kernel_shape, dtype="bfloat16")
@@ -103,6 +116,7 @@ def test_conv2d(
     stride_w: int,
     padding: int,
     comparison_config: ComparisonConfig,
+    record_tt_xla_property: Callable,
 ):
     def conv2d(img: jax.Array, kernel: jax.Array):
         return jax.lax.conv_general_dilated(
@@ -113,6 +127,13 @@ def conv2d(img: jax.Array, kernel: jax.Array):
             dimension_numbers=("NHWC", "OIHW", "NHWC"),
         )

+    record_op_test_properties(
+        record_tt_xla_property,
+        "Convolution",
+        "jax.lax.conv_general_dilated",
+        "stablehlo.convolution",
+    )
+
     img_shape = (batch_size, input_height, input_width, input_channels)
     kernel_shape = (output_channels, input_channels, filter_height, filter_width)
diff --git a/tests/utils.py b/tests/utils.py
new file mode 100644
index 00000000..7dcbde8c
--- /dev/null
+++ b/tests/utils.py
@@ -0,0 +1,39 @@
+# SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
+#
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Callable
+
+from conftest import RecordProperties
+
+
+def record_unary_op_test_properties(
+    record_property: Callable, framework_op_name: str, op_name: str
+):
+    record_property(RecordProperties.TEST_TYPE.value, "Op test")
+    record_property(RecordProperties.OP_KIND.value, "Unary op")
+    record_property(RecordProperties.FRAMEWORK_OP_NAME.value, framework_op_name)
+    record_property(RecordProperties.OP_NAME.value, op_name)
+
+
+def record_binary_op_test_properties(
+    record_property: Callable, framework_op_name: str, op_name: str
+):
+    record_property(RecordProperties.TEST_TYPE.value, "Op test")
+    record_property(RecordProperties.OP_KIND.value, "Binary op")
+    record_property(RecordProperties.FRAMEWORK_OP_NAME.value, framework_op_name)
+    record_property(RecordProperties.OP_NAME.value, op_name)
+
+
+def record_op_test_properties(
+    record_property: Callable, op_kind: str, framework_op_name: str, op_name: str
+):
+    record_property(RecordProperties.TEST_TYPE.value, "Op test")
+    record_property(RecordProperties.OP_KIND.value, op_kind)
+    record_property(RecordProperties.FRAMEWORK_OP_NAME.value, framework_op_name)
+    record_property(RecordProperties.OP_NAME.value, op_name)
+
+
+def record_model_test_properties(record_property: Callable, model_name: str):
+    record_property(RecordProperties.TEST_TYPE.value, "Model test")
+    record_property(RecordProperties.MODEL_NAME.value, model_name)
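
Note: for reference, a minimal sketch of how a future op test would consume the
helpers added in tests/utils.py. The exp op, its StableHLO name, and the shapes
below are illustrative assumptions, not part of this patch:

    from typing import Callable

    import jax
    import jax.numpy as jnp
    import pytest
    from infra import run_op_test_with_random_inputs
    from utils import record_unary_op_test_properties


    @pytest.mark.parametrize("x_shape", [(32, 32), (64, 64)])
    def test_exp(x_shape: tuple, record_tt_xla_property: Callable):
        def exp(x: jax.Array) -> jax.Array:
            return jnp.exp(x)

        # Records TEST_TYPE ("Op test"), OP_KIND ("Unary op") and both op names
        # via the shared helper instead of four separate record_property calls.
        record_unary_op_test_properties(
            record_tt_xla_property, "jax.numpy.exp", "stablehlo.exponential"
        )

        run_op_test_with_random_inputs(exp, [x_shape])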