From 02379668051aa34ecd772082b83f6bf2b7003861 Mon Sep 17 00:00:00 2001 From: Kristijan Mitrovic Date: Fri, 17 Jan 2025 14:40:28 +0000 Subject: [PATCH] First version --- tests/conftest.py | 99 ++++++++++++++++++++++++++ tests/jax/ops/test_abs.py | 9 ++- tests/jax/ops/test_add.py | 9 ++- tests/jax/ops/test_broadcast_in_dim.py | 11 ++- 4 files changed, 125 insertions(+), 3 deletions(-) create mode 100644 tests/conftest.py diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 00000000..922602b5 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,99 @@ +# SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +# +# SPDX-License-Identifier: Apache-2.0 + +from datetime import datetime +from enum import Enum +from typing import Callable + +import pytest + + +class RecordProperties(Enum): + """Properties we can record.""" + + # Timestamp of test start. + START_TIMESTAMP = "start_timestamp" + # Timestamp of test end. + END_TIMESTAMP = "end_timestamp" + # Frontend or framework used to run the test. + FRONTEND = "frontend" + # Kind of operation. e.g. eltwise. + OP_KIND = "op_kind" + # Name of the operation in the framework. e.g. torch.conv2d. + FRAMEWORK_OP_NAME = "framework_op_name" + # Name of the operation. e.g. ttir.conv2d. + OP_NAME = "op_name" + # Name of the model in which this op appears. + MODEL_NAME = "model_name" + + +@pytest.fixture(scope="function", autouse=True) +def record_test_timestamp(record_property: Callable): + """ + Autouse fixture used to capture execution time of a test. + + Parameters: + ---------- + record_property: Callable + A pytest built-in function used to record test metadata, such as custom + properties or additional information about the test execution. + + Yields: + ------- + Callable + The `record_property` callable, allowing tests to add additional properties if + needed. + + + Example: + -------- + ``` + def test_model(fixture1, fixture2, ..., record_tt_xla_property): + record_tt_xla_property("key", value) + + # Test logic... + ``` + """ + start_timestamp = datetime.strftime(datetime.now(), "%Y-%m-%dT%H:%M:%S%z") + record_property(RecordProperties.START_TIMESTAMP.value, start_timestamp) + + # Run the test. + yield + + end_timestamp = datetime.strftime(datetime.now(), "%Y-%m-%dT%H:%M:%S%z") + record_property(RecordProperties.END_TIMESTAMP.value, end_timestamp) + + +@pytest.fixture(scope="function", autouse=True) +def record_tt_xla_property(record_property: Callable): + """ + Autouse fixture that automatically records a property named 'frontend' with the + value 'tt-forge-fe' for each test function. + + Example: + + ``` + def test_model(fixture1, fixture2, ..., record_tt_xla_property): + record_tt_xla_property("key", value) + + # Test logic... + ``` + + Parameters: + ---------- + record_property: Callable + A pytest built-in function used to record test metadata, such as custom + properties or additional information about the test execution. + + Yields: + ------- + Callable + The `record_property` callable, allowing tests to add additional properties if + needed. + """ + # Record default properties for tt-xla. + record_property(RecordProperties.FRONTEND.value, "tt-xla") + + # Run the test. + yield record_property diff --git a/tests/jax/ops/test_abs.py b/tests/jax/ops/test_abs.py index d5c0827f..5af7cc47 100644 --- a/tests/jax/ops/test_abs.py +++ b/tests/jax/ops/test_abs.py @@ -2,16 +2,23 @@ # # SPDX-License-Identifier: Apache-2.0 +from typing import Callable + import jax import jax.numpy as jnp import pytest +from conftest import RecordProperties from infra import run_op_test_with_random_inputs @pytest.mark.parametrize("x_shape", [(32, 32), (64, 64)]) -def test_abs(x_shape: tuple): +def test_abs(x_shape: tuple, record_tt_xla_property: Callable): def abs(x: jax.Array) -> jax.Array: return jnp.abs(x) + record_tt_xla_property(RecordProperties.OP_KIND.value, "Eltwise unary") + record_tt_xla_property(RecordProperties.FRAMEWORK_OP_NAME.value, "jax.numpy.abs") + record_tt_xla_property(RecordProperties.OP_NAME.value, "stablehlo.abs") + # Test both negative and positive values. run_op_test_with_random_inputs(abs, [x_shape], minval=-5.0, maxval=5.0) diff --git a/tests/jax/ops/test_add.py b/tests/jax/ops/test_add.py index ce880a7c..bc9de970 100644 --- a/tests/jax/ops/test_add.py +++ b/tests/jax/ops/test_add.py @@ -2,9 +2,12 @@ # # SPDX-License-Identifier: Apache-2.0 +from typing import Callable + import jax import jax.numpy as jnp import pytest +from conftest import RecordProperties from infra import run_op_test_with_random_inputs @@ -15,8 +18,12 @@ [(64, 64), (64, 64)], ], ) -def test_add(x_shape: tuple, y_shape: tuple): +def test_add(x_shape: tuple, y_shape: tuple, record_tt_xla_property: Callable): def add(x: jax.Array, y: jax.Array) -> jax.Array: return jnp.add(x, y) + record_tt_xla_property(RecordProperties.OP_KIND.value, "Eltwise unary") + record_tt_xla_property(RecordProperties.FRAMEWORK_OP_NAME.value, "jax.numpy.add") + record_tt_xla_property(RecordProperties.OP_NAME.value, "stablehlo.add") + run_op_test_with_random_inputs(add, [x_shape, y_shape]) diff --git a/tests/jax/ops/test_broadcast_in_dim.py b/tests/jax/ops/test_broadcast_in_dim.py index ffda703d..59c8964c 100644 --- a/tests/jax/ops/test_broadcast_in_dim.py +++ b/tests/jax/ops/test_broadcast_in_dim.py @@ -2,8 +2,11 @@ # # SPDX-License-Identifier: Apache-2.0 +from typing import Callable + import jax.numpy as jnp import pytest +from conftest import RecordProperties from infra import run_op_test_with_random_inputs @@ -11,8 +14,14 @@ @pytest.mark.xfail( reason="AssertionError: Atol comparison failed. Calculated: atol=0.804124116897583. Required: atol=0.16" ) -def test_broadcast_in_dim(input_shapes): +def test_broadcast_in_dim(input_shapes: tuple, record_tt_xla_property: Callable): def broadcast(a): return jnp.broadcast_to(a, (2, 4)) + record_tt_xla_property(RecordProperties.OP_KIND.value, "Eltwise unary") + record_tt_xla_property( + RecordProperties.FRAMEWORK_OP_NAME.value, "jax.numpy.broadcast_to" + ) + record_tt_xla_property(RecordProperties.OP_NAME.value, "stablehlo.broadcast_in_dim") + run_op_test_with_random_inputs(broadcast, input_shapes)