Commit

kmitrovicTT committed Feb 5, 2025
1 parent e52deba commit 9050fab

Showing 6 changed files with 226 additions and 64 deletions.
13 changes: 12 additions & 1 deletion tests/infra/utils.py
@@ -51,14 +51,25 @@ def random_tensor(
             minval=int(minval),
             maxval=int(maxval),
         )
-        else:
+        elif jnp.issubdtype(dtype_converted, jnp.floating):
             return jax.random.uniform(
                 key=prng_key,
                 shape=shape,
                 dtype=dtype_converted,
                 minval=minval,
                 maxval=maxval,
             )
+        elif jnp.issubdtype(dtype_converted, jnp.bool):
+            # Generate a random tensor of 0s and 1s and interpret it as a bool
+            # tensor. maxval is exclusive in jax.random.randint, hence 2.
+            return jax.random.randint(
+                key=prng_key,
+                shape=shape,
+                dtype=jnp.int32,
+                minval=0,
+                maxval=2,
+            ).astype(dtype_converted)
+        else:
+            raise TypeError(f"Unsupported dtype: {dtype}")
     else:
         raise ValueError(f"Unsupported framework: {framework.value}.")

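With the new branch in place, random_tensor can generate bool tensors as well. A hypothetical usage sketch (assuming the default jax framework and the same infra import the tests use; this snippet is not part of the commit):

    import jax.numpy as jnp
    from infra import random_tensor

    # Bool tensor produced through the new randint + astype path.
    b = random_tensor((32, 32), jnp.bool)

    # Float tensor, mirroring the call in test_convert.py below.
    x = random_tensor((32, 32), jnp.float32, minval=0.0, maxval=10.0)
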
3 changes: 0 additions & 3 deletions tests/jax/ops/test_broadcast_in_dim.py
@@ -12,9 +12,6 @@
 
 
 @pytest.mark.parametrize("input_shapes", [[(2, 1)]], ids=lambda val: f"{val}")
-@pytest.mark.xfail(
-    reason="AssertionError: Atol comparison failed. Calculated: atol=0.804124116897583. Required: atol=0.16"
-)
 def test_broadcast_in_dim(input_shapes: tuple, record_tt_xla_property: Callable):
     def broadcast(a: jax.Array):
         return jnp.broadcast_to(a, (2, 4))
4 changes: 2 additions & 2 deletions tests/jax/ops/test_constant.py
@@ -7,7 +7,7 @@
 import jax.numpy as jnp
 import pytest
 from infra import run_op_test
-from utils import record_op_test_properties
+from utils import compile_fail, record_op_test_properties
 
 
 @pytest.mark.parametrize("shape", [(32, 32), (1, 1)], ids=lambda val: f"{val}")
@@ -40,7 +40,7 @@ def module_constant_ones():
     run_op_test(module_constant_ones, [])
 
 
-@pytest.mark.xfail(reason="failed to legalize operation 'ttir.constant'")
+@pytest.mark.xfail(reason=compile_fail("failed to legalize operation 'ttir.constant'"))
 def test_constant_multi_value(record_tt_xla_property: Callable):
     def module_constant_multi():
         return jnp.array([[1, 2], [3, 4]], dtype=jnp.float32)
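compile_fail, and the runtime_fail helper used in test_convert.py below, are imported from utils but their definitions are not shown in this diff. A minimal sketch under the assumption that they merely tag an xfail reason with its failure category (the names come from the imports above; the bodies are guesses):

    def compile_fail(reason: str) -> str:
        # Tag a reason string as a compile-time failure.
        return f"Compile failed: {reason}"

    def runtime_fail(reason: str) -> str:
        # Tag a reason string as a runtime failure.
        return f"Runtime failed: {reason}"
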
235 changes: 182 additions & 53 deletions tests/jax/ops/test_convert.py
@@ -10,91 +10,213 @@
 import pytest
 from infra import random_tensor, run_op_test
 from jax._src.typing import DTypeLike
-from utils import record_unary_op_test_properties
+from utils import compile_fail, record_unary_op_test_properties, runtime_fail
 
-# Allow 64bit precision in jax which is disabled by default.
-jax.config.update("jax_enable_x64", True)
+from tests.utils import enable_x64
 
+# NOTE Use test_data_types.py as reference for all supported data types.
+
+
+def conditionally_skip(from_dtype: DTypeLike, to_dtype: DTypeLike):
+    """
+    Helper function which checks the dtype combination and xfails if it is
+    unsupported for some reason.
+
+    Extracted here in order not to pollute the test function.
+    """
+    # ---------- Atol comparison failed ----------
+
+    if from_dtype == jnp.uint32 and to_dtype in [jnp.uint16, jnp.int16]:
+        pytest.xfail(
+            runtime_fail(
+                "AssertionError: Atol comparison failed. Calculated: atol=9.0. Required: atol=0.16."
+            )
+        )
+
+    if from_dtype == jnp.float64 and to_dtype == jnp.uint16:
+        pytest.xfail(
+            runtime_fail(
+                "AssertionError: Atol comparison failed. Calculated: atol=9.0. Required: atol=0.16."
+            )
+        )
+
+    if from_dtype == jnp.float32 and to_dtype in [jnp.uint16, jnp.int16]:
+        pytest.xfail(
+            runtime_fail(
+                "AssertionError: Atol comparison failed. Calculated: atol=1.0. Required: atol=0.16."
+            )
+        )
+
+    if from_dtype == jnp.bfloat16 and to_dtype in [jnp.uint16, jnp.int16]:
+        pytest.xfail(
+            runtime_fail(
+                "AssertionError: Atol comparison failed. Calculated: atol=1.0. Required: atol=0.16."
+            )
+        )
+
+    # ---------- Cannot get the device from a tensor with host storage ----------
+
+    if from_dtype == jnp.uint64 and to_dtype in [
+        jnp.uint16,
+        jnp.uint32,
+        jnp.uint64,
+        jnp.int16,
+        jnp.int32,
+        jnp.int64,
+    ]:
+        pytest.xfail(
+            runtime_fail(
+                "Cannot get the device from a tensor with host storage "
+                "(https://github.com/tenstorrent/tt-xla/issues/171)"
+            )
+        )
+
+    if from_dtype in [jnp.int16, jnp.int32, jnp.int64] and to_dtype in [
+        jnp.uint16,
+        jnp.uint32,
+        jnp.uint64,
+        jnp.int16,
+        jnp.int32,
+        jnp.int64,
+    ]:
+        pytest.xfail(
+            runtime_fail(
+                "Cannot get the device from a tensor with host storage "
+                "(https://github.com/tenstorrent/tt-xla/issues/171)"
+            )
+        )
+
+    if from_dtype == jnp.float64 and to_dtype in [
+        jnp.uint16,
+        jnp.uint32,
+        jnp.uint64,
+        jnp.int16,
+        jnp.int32,
+        jnp.int64,
+    ]:
+        pytest.xfail(
+            runtime_fail(
+                "Cannot get the device from a tensor with host storage "
+                "(https://github.com/tenstorrent/tt-xla/issues/171)"
+            )
+        )
+
+    if to_dtype in [jnp.float32, jnp.float64] and from_dtype in [
+        jnp.uint64,
+        jnp.int16,
+        jnp.int32,
+        jnp.int64,
+        jnp.float64,
+    ]:
+        pytest.xfail(
+            runtime_fail(
+                "Cannot get the device from a tensor with host storage "
+                "(https://github.com/tenstorrent/tt-xla/issues/171)"
+            )
+        )
+
+    if to_dtype == jnp.bfloat16 and from_dtype in [
+        jnp.uint64,
+        jnp.int16,
+        jnp.int32,
+        jnp.int64,
+        jnp.float64,
+    ]:
+        pytest.xfail(
+            runtime_fail(
+                "Cannot get the device from a tensor with host storage "
+                "(https://github.com/tenstorrent/tt-xla/issues/171)"
+            )
+        )
+
+    # ---------- Executable expected parameter x of size y but got... ----------
+
+    if (
+        from_dtype in [jnp.uint16, jnp.uint32, jnp.float32, jnp.bfloat16]
+        and to_dtype == jnp.float64
+    ):
+        pytest.xfail(
+            compile_fail(
+                "Executable expected parameter 0 of size 8192 but got buffer with "
+                "incompatible size 4096 (https://github.com/tenstorrent/tt-xla/issues/170)"
+            )
+        )
+
+
 @pytest.mark.parametrize(
     "from_dtype",
     [
+        # uints
+        pytest.param(
+            jnp.uint8,
+            marks=pytest.mark.skip(reason="Unsupported data type"),
+        ),
+        jnp.uint16,
         jnp.uint32,
+        jnp.uint64,
+        # ints
-        pytest.param(
-            jnp.uint64,
-            marks=pytest.mark.skip(
-                reason=(
-                    "Cannot get the device from a tensor with host storage. "
-                    "See issue https://github.com/tenstorrent/tt-xla/issues/171"
-                )
-            ),
-        ),
         pytest.param(
-            jnp.int16,
-            marks=pytest.mark.skip(
-                reason=(
-                    "Cannot get the device from a tensor with host storage. "
-                    "See issue https://github.com/tenstorrent/tt-xla/issues/171"
-                )
-            ),
+            jnp.int8,
+            marks=pytest.mark.skip(reason="Unsupported data type"),
         ),
+        jnp.int16,
+        jnp.int32,
+        jnp.int64,
+        # floats
         pytest.param(
-            jnp.int32,
-            marks=pytest.mark.skip(
-                reason=(
-                    "Cannot get the device from a tensor with host storage. "
-                    "See issue https://github.com/tenstorrent/tt-xla/issues/171"
-                )
-            ),
+            jnp.float16,
+            marks=pytest.mark.skip(reason="Unsupported data type"),
         ),
+        jnp.float32,
+        jnp.float64,
+        # bfloat
+        jnp.bfloat16,
+        # bool
         pytest.param(
-            jnp.int64,
+            jnp.bool,
             marks=pytest.mark.skip(
-                reason=(
-                    "Cannot get the device from a tensor with host storage. "
-                    "See issue https://github.com/tenstorrent/tt-xla/issues/171"
-                )
+                reason="Causes segfaults. Should be investigated separately."
             ),
         ),
-        jnp.float32,
-        jnp.bfloat16,
     ],
 )
 @pytest.mark.parametrize(
     "to_dtype",
     [
+        # uints
         pytest.param(
-            jnp.uint16,
-            marks=pytest.mark.skip(
-                reason=(
-                    "Fails due to low comparison metrics. "
-                    "See issue https://github.com/tenstorrent/tt-xla/issues/172"
-                )
-            ),
+            jnp.uint8,
+            marks=pytest.mark.skip(reason="Unsupported data type"),
         ),
+        jnp.uint16,
         jnp.uint32,
+        jnp.uint64,
+        # ints
         pytest.param(
-            jnp.int16,
-            marks=pytest.mark.skip(
-                reason=(
-                    "Fails due to low comparison metrics. "
-                    "See issue https://github.com/tenstorrent/tt-xla/issues/172"
-                )
-            ),
+            jnp.int8,
+            marks=pytest.mark.skip(reason="Unsupported data type"),
         ),
+        jnp.int16,
         jnp.int32,
+        jnp.int64,
+        # floats
+        pytest.param(
+            jnp.float16,
+            marks=pytest.mark.skip(reason="Unsupported data type"),
+        ),
         jnp.float32,
+        jnp.float64,
+        # bfloat
         jnp.bfloat16,
+        # bool
+        pytest.param(
+            jnp.bool,
+            marks=pytest.mark.skip(
+                reason="Causes segfaults. Should be investigated separately."
+            ),
+        ),
     ],
 )
-@pytest.mark.skip(
-    f"Skipped unconditionally due to many fails. There is ongoing work on rewriting these tests."
-)
 def test_convert(
     from_dtype: DTypeLike, to_dtype: DTypeLike, record_tt_xla_property: Callable
 ):
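Note that conditionally_skip calls pytest.xfail imperatively. Unlike the @pytest.mark.xfail decorator, which still runs the test and reports XFAIL only if it fails, an imperative pytest.xfail(reason) raises immediately and marks the test as xfailed from that point, so an unsupported conversion is never executed. A toy illustration (the condition is hypothetical, not from the commit):

    import pytest

    def test_something():
        if combination_is_unsupported:  # hypothetical flag
            pytest.xfail("known failure")  # test stops here, reported as XFAIL
        run_real_checks()  # only reached for supported combinations
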
@@ -107,7 +229,14 @@ def convert(x: jax.Array) -> jax.Array:
         "stablehlo.convert",
     )
 
-    x_shape = (32, 32)  # Shape does not make any impact here, thus not parametrized.
-    input = random_tensor(x_shape, from_dtype, minval=0.0, maxval=10.0)
+    # Some dtype conversions are not supported. Check and decide whether to skip or
+    # proceed.
+    conditionally_skip(from_dtype, to_dtype)
+
+    # Shape does not make any impact here, thus not parametrized.
+    x_shape = (32, 32)
+
+    with enable_x64():
+        input = random_tensor(x_shape, from_dtype, minval=0.0, maxval=10.0)
 
     run_op_test(convert, [input])
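The enable_x64 context manager imported from tests.utils is likewise not defined in this diff. A plausible sketch, assuming it simply toggles jax's jax_enable_x64 flag around the block and restores the previous value (inferred from the global jax.config.update call this commit removes; the actual implementation may differ):

    from contextlib import contextmanager

    import jax

    @contextmanager
    def enable_x64():
        # Temporarily allow 64-bit types in jax (disabled by default),
        # restoring the previous setting on exit.
        prev = jax.config.jax_enable_x64
        jax.config.update("jax_enable_x64", True)
        try:
            yield
        finally:
            jax.config.update("jax_enable_x64", prev)

Scoping the flag to the tensor creation, instead of flipping it globally at import time, keeps 64-bit mode from leaking into unrelated tests in the same session.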