Skip to content

Commit

Permalink
First version
Browse files Browse the repository at this point in the history
  • Loading branch information
kmitrovicTT committed Jan 23, 2025
1 parent 370b251 commit 0237966
Show file tree
Hide file tree
Showing 4 changed files with 125 additions and 3 deletions.
99 changes: 99 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
# SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
#
# SPDX-License-Identifier: Apache-2.0

from datetime import datetime
from enum import Enum
from typing import Callable

import pytest


class RecordProperties(Enum):
"""Properties we can record."""

# Timestamp of test start.
START_TIMESTAMP = "start_timestamp"
# Timestamp of test end.
END_TIMESTAMP = "end_timestamp"
# Frontend or framework used to run the test.
FRONTEND = "frontend"
# Kind of operation. e.g. eltwise.
OP_KIND = "op_kind"
# Name of the operation in the framework. e.g. torch.conv2d.
FRAMEWORK_OP_NAME = "framework_op_name"
# Name of the operation. e.g. ttir.conv2d.
OP_NAME = "op_name"
# Name of the model in which this op appears.
MODEL_NAME = "model_name"


@pytest.fixture(scope="function", autouse=True)
def record_test_timestamp(record_property: Callable):
"""
Autouse fixture used to capture execution time of a test.
Parameters:
----------
record_property: Callable
A pytest built-in function used to record test metadata, such as custom
properties or additional information about the test execution.
Yields:
-------
Callable
The `record_property` callable, allowing tests to add additional properties if
needed.
Example:
--------
```
def test_model(fixture1, fixture2, ..., record_tt_xla_property):
record_tt_xla_property("key", value)
# Test logic...
```
"""
start_timestamp = datetime.strftime(datetime.now(), "%Y-%m-%dT%H:%M:%S%z")
record_property(RecordProperties.START_TIMESTAMP.value, start_timestamp)

# Run the test.
yield

end_timestamp = datetime.strftime(datetime.now(), "%Y-%m-%dT%H:%M:%S%z")
record_property(RecordProperties.END_TIMESTAMP.value, end_timestamp)


@pytest.fixture(scope="function", autouse=True)
def record_tt_xla_property(record_property: Callable):
"""
Autouse fixture that automatically records a property named 'frontend' with the
value 'tt-forge-fe' for each test function.
Example:
```
def test_model(fixture1, fixture2, ..., record_tt_xla_property):
record_tt_xla_property("key", value)
# Test logic...
```
Parameters:
----------
record_property: Callable
A pytest built-in function used to record test metadata, such as custom
properties or additional information about the test execution.
Yields:
-------
Callable
The `record_property` callable, allowing tests to add additional properties if
needed.
"""
# Record default properties for tt-xla.
record_property(RecordProperties.FRONTEND.value, "tt-xla")

# Run the test.
yield record_property
9 changes: 8 additions & 1 deletion tests/jax/ops/test_abs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,23 @@
#
# SPDX-License-Identifier: Apache-2.0

from typing import Callable

import jax
import jax.numpy as jnp
import pytest
from conftest import RecordProperties
from infra import run_op_test_with_random_inputs


@pytest.mark.parametrize("x_shape", [(32, 32), (64, 64)])
def test_abs(x_shape: tuple):
def test_abs(x_shape: tuple, record_tt_xla_property: Callable):
def abs(x: jax.Array) -> jax.Array:
return jnp.abs(x)

record_tt_xla_property(RecordProperties.OP_KIND.value, "Eltwise unary")
record_tt_xla_property(RecordProperties.FRAMEWORK_OP_NAME.value, "jax.numpy.abs")
record_tt_xla_property(RecordProperties.OP_NAME.value, "stablehlo.abs")

# Test both negative and positive values.
run_op_test_with_random_inputs(abs, [x_shape], minval=-5.0, maxval=5.0)
9 changes: 8 additions & 1 deletion tests/jax/ops/test_add.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,12 @@
#
# SPDX-License-Identifier: Apache-2.0

from typing import Callable

import jax
import jax.numpy as jnp
import pytest
from conftest import RecordProperties
from infra import run_op_test_with_random_inputs


Expand All @@ -15,8 +18,12 @@
[(64, 64), (64, 64)],
],
)
def test_add(x_shape: tuple, y_shape: tuple):
def test_add(x_shape: tuple, y_shape: tuple, record_tt_xla_property: Callable):
def add(x: jax.Array, y: jax.Array) -> jax.Array:
return jnp.add(x, y)

record_tt_xla_property(RecordProperties.OP_KIND.value, "Eltwise unary")
record_tt_xla_property(RecordProperties.FRAMEWORK_OP_NAME.value, "jax.numpy.add")
record_tt_xla_property(RecordProperties.OP_NAME.value, "stablehlo.add")

run_op_test_with_random_inputs(add, [x_shape, y_shape])
11 changes: 10 additions & 1 deletion tests/jax/ops/test_broadcast_in_dim.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,26 @@
#
# SPDX-License-Identifier: Apache-2.0

from typing import Callable

import jax.numpy as jnp
import pytest
from conftest import RecordProperties
from infra import run_op_test_with_random_inputs


@pytest.mark.parametrize("input_shapes", [[(2, 1)]])
@pytest.mark.xfail(
reason="AssertionError: Atol comparison failed. Calculated: atol=0.804124116897583. Required: atol=0.16"
)
def test_broadcast_in_dim(input_shapes):
def test_broadcast_in_dim(input_shapes: tuple, record_tt_xla_property: Callable):
def broadcast(a):
return jnp.broadcast_to(a, (2, 4))

record_tt_xla_property(RecordProperties.OP_KIND.value, "Eltwise unary")
record_tt_xla_property(
RecordProperties.FRAMEWORK_OP_NAME.value, "jax.numpy.broadcast_to"
)
record_tt_xla_property(RecordProperties.OP_NAME.value, "stablehlo.broadcast_in_dim")

run_op_test_with_random_inputs(broadcast, input_shapes)

0 comments on commit 0237966

Please sign in to comment.