Rewrote tests for scalars and dtypes. Wrote tests for ranks. #168

Merged · 2 commits · Feb 5, 2025
2 changes: 1 addition & 1 deletion tests/infra/__init__.py
@@ -7,4 +7,4 @@
from .graph_tester import run_graph_test, run_graph_test_with_random_inputs
from .model_tester import ModelTester, RunMode
from .op_tester import run_op_test, run_op_test_with_random_inputs
-from .utils import random_tensor, supported_dtypes
+from .utils import random_tensor
48 changes: 34 additions & 14 deletions tests/infra/utils.py
@@ -2,18 +2,17 @@
#
# SPDX-License-Identifier: Apache-2.0

+from typing import Union
+
import jax
import jax.numpy as jnp
from jax import export
+from jax._src.typing import DTypeLike

from .device_runner import run_on_cpu
from .types import Framework, Tensor
from .workload import Workload

-# List of all data types that runtime currently supports.
-supported_dtypes = [jnp.float32, jnp.bfloat16, jnp.uint32, jnp.uint16]


def _str_to_dtype(dtype_str: str, framework: Framework = Framework.JAX):
    """Convert a string dtype to the corresponding framework-specific dtype."""
@@ -26,7 +25,7 @@ def _str_to_dtype(dtype_str: str, framework: Framework = Framework.JAX):
@run_on_cpu
def random_tensor(
    shape: tuple,
-    dtype: str = "float32",
+    dtype: Union[str, DTypeLike] = jnp.float32,
    random_seed: int = 0,
    minval: float = 0.0,
    maxval: float = 1.0,
@@ -36,20 +35,41 @@ def random_tensor(
    Generates a random tensor of `shape`, `dtype`, and `random_seed` in range
    [`minval`, `maxval`) for the desired `framework`.
    """
-    # Convert dtype string to actual dtype for the selected framework.
-    dtype_converted = _str_to_dtype(dtype, framework)
+    dtype_converted = (
+        _str_to_dtype(dtype, framework) if isinstance(dtype, str) else dtype
+    )

-    # Generate random tensor based on framework type
+    # Generate random tensor based on framework type.
    if framework == Framework.JAX:
        prng_key = jax.random.PRNGKey(random_seed)

-        return jax.random.uniform(
-            key=prng_key,
-            shape=shape,
-            dtype=dtype_converted,
-            minval=minval,
-            maxval=maxval,
-        )
+        if jnp.issubdtype(dtype_converted, jnp.integer):
+            return jax.random.randint(
+                key=prng_key,
+                shape=shape,
+                dtype=dtype_converted,
+                minval=int(minval),
+                maxval=int(maxval),
+            )
+        elif jnp.issubdtype(dtype_converted, jnp.floating):
+            return jax.random.uniform(
+                key=prng_key,
+                shape=shape,
+                dtype=dtype_converted,
+                minval=minval,
+                maxval=maxval,
+            )
+        elif jnp.issubdtype(dtype_converted, jnp.bool):
+            # Generate a random tensor of 0s and 1s and interpret it as a bool
+            # tensor. `maxval` is exclusive in jax.random.randint, so it must be
+            # 2 for the tensor to contain any 1s.
+            return jax.random.randint(
+                key=prng_key,
+                shape=shape,
+                dtype=jnp.int32,
+                minval=0,
+                maxval=2,
+            ).astype(dtype_converted)
+        else:
+            raise TypeError(f"Unsupported dtype: {dtype}")
    else:
        raise ValueError(f"Unsupported framework: {framework.value}.")

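With this change, `random_tensor` dispatches on the requested dtype: integers go through `jax.random.randint`, floats through `jax.random.uniform`, and bools are sampled as `int32` and cast. A minimal usage sketch of the new signature, assuming the `infra` package from this repo is importable as in the tests below:

import jax.numpy as jnp
from infra import random_tensor

# Floating-point dtypes sample uniformly from [minval, maxval).
x = random_tensor((32, 32), dtype=jnp.bfloat16, minval=0.0, maxval=10.0)

# Integer dtypes go through jax.random.randint; minval/maxval are cast to int.
y = random_tensor((32, 32), dtype=jnp.uint32, minval=0.0, maxval=10.0)

# Plain strings are still accepted and converted via _str_to_dtype.
z = random_tensor((2, 4), dtype="float32")
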
3 changes: 0 additions & 3 deletions tests/jax/ops/test_broadcast_in_dim.py
@@ -12,9 +12,6 @@


@pytest.mark.parametrize("input_shapes", [[(2, 1)]], ids=lambda val: f"{val}")
-@pytest.mark.xfail(
-    reason="AssertionError: Atol comparison failed. Calculated: atol=0.804124116897583. Required: atol=0.16"
-)
Comment on lines -15 to -17

Contributor Author: This test unexpectedly passes now. @mrakitaTT
Same happened with test_simple_regression in #174

Contributor: The most reasonable explanation is what @sgligorijevicTT suggested, that metal might have some non-determinism issues, but we should keep this on our radar and monitor whether it happens again in the future.
def test_broadcast_in_dim(input_shapes: tuple, record_tt_xla_property: Callable):
    def broadcast(a: jax.Array):
        return jnp.broadcast_to(a, (2, 4))
4 changes: 2 additions & 2 deletions tests/jax/ops/test_constant.py
@@ -7,7 +7,7 @@
import jax.numpy as jnp
import pytest
from infra import run_op_test
-from utils import record_op_test_properties
+from utils import compile_fail, record_op_test_properties


@pytest.mark.parametrize("shape", [(32, 32), (1, 1)], ids=lambda val: f"{val}")
@@ -40,7 +40,7 @@ def module_constant_ones():
    run_op_test(module_constant_ones, [])


-@pytest.mark.xfail(reason="failed to legalize operation 'ttir.constant'")
+@pytest.mark.xfail(reason=compile_fail("failed to legalize operation 'ttir.constant'"))
def test_constant_multi_value(record_tt_xla_property: Callable):
    def module_constant_multi():
        return jnp.array([[1, 2], [3, 4]], dtype=jnp.float32)
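`compile_fail` here (and `runtime_fail` in the next file) wrap the xfail reason strings. Their bodies are not part of this diff, so the following is only a hypothetical sketch inferred from the call sites, assuming the helpers simply prefix the reason with a failure category:

def compile_fail(reason: str) -> str:
    # Hypothetical helper: tag an xfail reason as a compile-time failure.
    return f"Compile failed: {reason}"


def runtime_fail(reason: str) -> str:
    # Hypothetical helper: tag an xfail reason as a runtime failure.
    return f"Runtime failed: {reason}"
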
224 changes: 206 additions & 18 deletions tests/jax/ops/test_convert.py
@@ -10,45 +10,233 @@
import pytest
from infra import random_tensor, run_op_test
from jax._src.typing import DTypeLike
-from utils import record_unary_op_test_properties
+from utils import compile_fail, record_unary_op_test_properties, runtime_fail
+
+from tests.utils import enable_x64

+# NOTE Use test_data_types.py as reference for all supported data types.


+def conditionally_skip(from_dtype: DTypeLike, to_dtype: DTypeLike):
+    """
+    Helper function which checks dtype combination and skips if unsupported for some
+    reason.
+
+    Extracted here in order not to pollute the test function.
+    """
+    # ---------- Atol comparison failed ----------
+
+    if from_dtype == jnp.uint32 and to_dtype in [jnp.uint16, jnp.int16]:
+        pytest.xfail(
+            runtime_fail(
+                "AssertionError: Atol comparison failed. Calculated: atol=9.0. Required: atol=0.16."
+            )
+        )
+
+    if from_dtype == jnp.float64 and to_dtype == jnp.uint16:
+        pytest.xfail(
+            runtime_fail(
+                "AssertionError: Atol comparison failed. Calculated: atol=9.0. Required: atol=0.16."
+            )
+        )
+
+    if from_dtype == jnp.float32 and to_dtype in [jnp.uint16, jnp.int16]:
+        pytest.xfail(
+            runtime_fail(
+                "AssertionError: Atol comparison failed. Calculated: atol=1.0. Required: atol=0.16."
+            )
+        )
+
+    if from_dtype == jnp.bfloat16 and to_dtype in [jnp.uint16, jnp.int16]:
+        pytest.xfail(
+            runtime_fail(
+                "AssertionError: Atol comparison failed. Calculated: atol=1.0. Required: atol=0.16."
+            )
+        )
+
+    # ---------- Cannot get the device from a tensor with host storage ----------
+
+    if from_dtype == jnp.uint64 and to_dtype in [
+        jnp.uint16,
+        jnp.uint32,
+        jnp.uint64,
+        jnp.int16,
+        jnp.int32,
+        jnp.int64,
+    ]:
+        pytest.xfail(
+            runtime_fail(
+                "Cannot get the device from a tensor with host storage "
+                "(https://github.com/tenstorrent/tt-xla/issues/171)"
+            )
+        )
+
+    if from_dtype in [jnp.int16, jnp.int32, jnp.int64] and to_dtype in [
+        jnp.uint16,
+        jnp.uint32,
+        jnp.uint64,
+        jnp.int16,
+        jnp.int32,
+        jnp.int64,
+    ]:
+        pytest.xfail(
+            runtime_fail(
+                "Cannot get the device from a tensor with host storage "
+                "(https://github.com/tenstorrent/tt-xla/issues/171)"
+            )
+        )
+
+    if from_dtype == jnp.float64 and to_dtype in [
+        jnp.uint16,
+        jnp.uint32,
+        jnp.uint64,
+        jnp.int16,
+        jnp.int32,
+        jnp.int64,
+    ]:
+        pytest.xfail(
+            runtime_fail(
+                "Cannot get the device from a tensor with host storage "
+                "(https://github.com/tenstorrent/tt-xla/issues/171)"
+            )
+        )
+
+    if to_dtype in [jnp.float32, jnp.float64] and from_dtype in [
+        jnp.uint64,
+        jnp.int16,
+        jnp.int32,
+        jnp.int64,
+        jnp.float64,
+    ]:
+        pytest.xfail(
+            runtime_fail(
+                "Cannot get the device from a tensor with host storage "
+                "(https://github.com/tenstorrent/tt-xla/issues/171)"
+            )
+        )
+
+    if to_dtype == jnp.bfloat16 and from_dtype in [
+        jnp.uint64,
+        jnp.int16,
+        jnp.int32,
+        jnp.int64,
+        jnp.float64,
+    ]:
+        pytest.xfail(
+            runtime_fail(
+                "Cannot get the device from a tensor with host storage "
+                "(https://github.com/tenstorrent/tt-xla/issues/171)"
+            )
+        )
+
+    # ---------- Executable expected parameter x of size y but got... ----------
+
+    if (
+        from_dtype in [jnp.uint16, jnp.uint32, jnp.float32, jnp.bfloat16]
+        and to_dtype == jnp.float64
+    ):
+        pytest.xfail(
+            compile_fail(
+                "Executable expected parameter 0 of size 8192 but got buffer with "
+                "incompatible size 4096 (https://github.com/tenstorrent/tt-xla/issues/170)"
+            )
+        )
+
+
-# TODO we need to parametrize with all supported dtypes.
@pytest.mark.parametrize(
    "from_dtype",
    [
-        "bfloat16",
-        "float32",
+        # uints
+        pytest.param(
+            jnp.uint8,
+            marks=pytest.mark.skip(reason="Unsupported data type"),
+        ),
+        jnp.uint16,
+        jnp.uint32,
+        jnp.uint64,
+        # ints
+        pytest.param(
+            jnp.int8,
+            marks=pytest.mark.skip(reason="Unsupported data type"),
+        ),
+        jnp.int16,
+        jnp.int32,
+        jnp.int64,
+        # floats
+        pytest.param(
+            jnp.float16,
+            marks=pytest.mark.skip(reason="Unsupported data type"),
+        ),
+        jnp.float32,
+        jnp.float64,
+        # bfloat
+        jnp.bfloat16,
+        # bool
+        pytest.param(
+            jnp.bool,
+            marks=pytest.mark.skip(
+                reason="Causes segfaults. Should be investigated separately."
+            ),
+        ),
    ],
)
@pytest.mark.parametrize(
    "to_dtype",
    [
-        "uint32",
-        "uint64",
-        "int32",
-        "int64",
-        "bfloat16",
-        "float32",
-        "float64",
+        # uints
+        pytest.param(
+            jnp.uint8,
+            marks=pytest.mark.skip(reason="Unsupported data type"),
+        ),
+        jnp.uint16,
+        jnp.uint32,
+        jnp.uint64,
+        # ints
+        pytest.param(
+            jnp.int8,
+            marks=pytest.mark.skip(reason="Unsupported data type"),
+        ),
+        jnp.int16,
+        jnp.int32,
+        jnp.int64,
+        # floats
+        pytest.param(
+            jnp.float16,
+            marks=pytest.mark.skip(reason="Unsupported data type"),
+        ),
+        jnp.float32,
+        jnp.float64,
+        # bfloat
+        jnp.bfloat16,
+        # bool
+        pytest.param(
+            jnp.bool,
+            marks=pytest.mark.skip(
+                reason="Causes segfaults. Should be investigated separately."
+            ),
+        ),
    ],
)
-@pytest.mark.skip(
-    f"Skipped unconditionally due to many fails. There is ongoing work on rewriting these tests."
-)
def test_convert(
    from_dtype: DTypeLike, to_dtype: DTypeLike, record_tt_xla_property: Callable
):
    def convert(x: jax.Array) -> jax.Array:
-        return jlx.convert_element_type(x, new_dtype=jnp.dtype(to_dtype))
+        return jlx.convert_element_type(x, new_dtype=to_dtype)

    record_unary_op_test_properties(
        record_tt_xla_property,
        "jax.lax.convert_element_type",
        "stablehlo.convert",
    )

-    x_shape = (32, 32)  # Shape does not make any impact here, thus not parametrized.
-    input = random_tensor(x_shape, dtype=from_dtype)
+    # Some dtype conversions are not supported. Check and decide whether to skip or
+    # proceed.
+    conditionally_skip(from_dtype, to_dtype)
+
+    # Shape does not make any impact here, thus not parametrized.
+    x_shape = (32, 32)
+
+    with enable_x64():
+        input = random_tensor(x_shape, from_dtype, minval=0.0, maxval=10.0)

    run_op_test(convert, [input])
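
`enable_x64` is imported from `tests.utils` but its body is not shown in this diff. JAX truncates 64-bit values to 32 bits unless the `jax_enable_x64` flag is set, so the `uint64`/`int64`/`float64` inputs above need it enabled while the tensor is created. A plausible sketch, assuming the helper just toggles the flag (JAX also ships `jax.experimental.enable_x64` for the same purpose):

from contextlib import contextmanager

import jax


@contextmanager
def enable_x64():
    # Hypothetical sketch: temporarily allow 64-bit dtypes in JAX.
    jax.config.update("jax_enable_x64", True)
    try:
        yield
    finally:
        # Restore JAX's default 32-bit behavior.
        jax.config.update("jax_enable_x64", False)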