diff --git a/tests/infra/utils.py b/tests/infra/utils.py index f88a627a..f9fdf30a 100644 --- a/tests/infra/utils.py +++ b/tests/infra/utils.py @@ -51,7 +51,7 @@ def random_tensor( minval=int(minval), maxval=int(maxval), ) - else: + elif jnp.issubdtype(dtype_converted, jnp.floating): return jax.random.uniform( key=prng_key, shape=shape, @@ -59,6 +59,17 @@ def random_tensor( minval=minval, maxval=maxval, ) + elif jnp.issubdtype(dtype_converted, jnp.bool): + # Generate random tensor of 0s and 1s and interpret is as a bool tensor. + return jax.random.randint( + key=prng_key, + shape=shape, + dtype=jnp.int32, + minval=0, + maxval=1, + ).astype(dtype_converted) + else: + raise TypeError(f"Unsupported dtype: {dtype}") else: raise ValueError(f"Unsupported framework: {framework.value}.") diff --git a/tests/jax/ops/test_broadcast_in_dim.py b/tests/jax/ops/test_broadcast_in_dim.py index 61f2e24f..aca8c93a 100644 --- a/tests/jax/ops/test_broadcast_in_dim.py +++ b/tests/jax/ops/test_broadcast_in_dim.py @@ -12,9 +12,6 @@ @pytest.mark.parametrize("input_shapes", [[(2, 1)]], ids=lambda val: f"{val}") -@pytest.mark.xfail( - reason="AssertionError: Atol comparison failed. Calculated: atol=0.804124116897583. Required: atol=0.16" -) def test_broadcast_in_dim(input_shapes: tuple, record_tt_xla_property: Callable): def broadcast(a: jax.Array): return jnp.broadcast_to(a, (2, 4)) diff --git a/tests/jax/ops/test_constant.py b/tests/jax/ops/test_constant.py index e7bb5f39..0b4abd10 100644 --- a/tests/jax/ops/test_constant.py +++ b/tests/jax/ops/test_constant.py @@ -7,7 +7,7 @@ import jax.numpy as jnp import pytest from infra import run_op_test -from utils import record_op_test_properties +from utils import compile_fail, record_op_test_properties @pytest.mark.parametrize("shape", [(32, 32), (1, 1)], ids=lambda val: f"{val}") @@ -40,7 +40,7 @@ def module_constant_ones(): run_op_test(module_constant_ones, []) -@pytest.mark.xfail(reason="failed to legalize operation 'ttir.constant'") +@pytest.mark.xfail(reason=compile_fail("failed to legalize operation 'ttir.constant'")) def test_constant_multi_value(record_tt_xla_property: Callable): def module_constant_multi(): return jnp.array([[1, 2], [3, 4]], dtype=jnp.float32) diff --git a/tests/jax/ops/test_convert.py b/tests/jax/ops/test_convert.py index 8349e0ca..d62d72b3 100644 --- a/tests/jax/ops/test_convert.py +++ b/tests/jax/ops/test_convert.py @@ -10,91 +10,213 @@ import pytest from infra import random_tensor, run_op_test from jax._src.typing import DTypeLike -from utils import record_unary_op_test_properties +from utils import compile_fail, record_unary_op_test_properties, runtime_fail -# Allow 64bit precision in jax which is disabled by default. -jax.config.update("jax_enable_x64", True) +from tests.utils import enable_x64 # NOTE Use test_data_types.py as reference for all supported data types. +def conditionally_skip(from_dtype: DTypeLike, to_dtype: DTypeLike): + """ + Helper function which checks dtype combination and skips if unsupported for some + reason. + + Extracted here in order not to pollute the test function. + """ + # ---------- Atol comparison failed ---------- + + if from_dtype == jnp.uint32 and to_dtype in [jnp.uint16, jnp.int16]: + pytest.xfail( + runtime_fail( + "AssertionError: Atol comparison failed. Calculated: atol=9.0. Required: atol=0.16." + ) + ) + + if from_dtype == jnp.float64 and to_dtype == jnp.uint16: + pytest.xfail( + runtime_fail( + "AssertionError: Atol comparison failed. Calculated: atol=9.0. Required: atol=0.16." + ) + ) + + if from_dtype == jnp.float32 and to_dtype in [jnp.uint16, jnp.int16]: + pytest.xfail( + runtime_fail( + "AssertionError: Atol comparison failed. Calculated: atol=1.0. Required: atol=0.16." + ) + ) + + if from_dtype == jnp.bfloat16 and to_dtype in [jnp.uint16, jnp.int16]: + pytest.xfail( + runtime_fail( + "AssertionError: Atol comparison failed. Calculated: atol=1.0. Required: atol=0.16." + ) + ) + + # ---------- Cannot get the device from a tensor with host storage ---------- + + if from_dtype == jnp.uint64 and to_dtype in [ + jnp.uint16, + jnp.uint32, + jnp.uint64, + jnp.int16, + jnp.int32, + jnp.int64, + ]: + pytest.xfail( + runtime_fail( + "Cannot get the device from a tensor with host storage " + "(https://github.com/tenstorrent/tt-xla/issues/171)" + ) + ) + + if from_dtype in [jnp.int16, jnp.int32, jnp.int64] and to_dtype in [ + jnp.uint16, + jnp.uint32, + jnp.uint64, + jnp.int16, + jnp.int32, + jnp.int64, + ]: + pytest.xfail( + runtime_fail( + "Cannot get the device from a tensor with host storage " + "(https://github.com/tenstorrent/tt-xla/issues/171)" + ) + ) + + if from_dtype == jnp.float64 and to_dtype in [ + jnp.uint16, + jnp.uint32, + jnp.uint64, + jnp.int16, + jnp.int32, + jnp.int64, + ]: + pytest.xfail( + runtime_fail( + "Cannot get the device from a tensor with host storage " + "(https://github.com/tenstorrent/tt-xla/issues/171)" + ) + ) + + if to_dtype in [jnp.float32, jnp.float64] and from_dtype in [ + jnp.uint64, + jnp.int16, + jnp.int32, + jnp.int64, + jnp.float64, + ]: + pytest.xfail( + runtime_fail( + "Cannot get the device from a tensor with host storage " + "(https://github.com/tenstorrent/tt-xla/issues/171)" + ) + ) + + if to_dtype == jnp.bfloat16 and from_dtype in [ + jnp.uint64, + jnp.int16, + jnp.int32, + jnp.int64, + jnp.float64, + ]: + pytest.xfail( + runtime_fail( + "Cannot get the device from a tensor with host storage " + "(https://github.com/tenstorrent/tt-xla/issues/171)" + ) + ) + + # ---------- Executable expected parameter x of size y but got... ---------- + + if ( + from_dtype in [jnp.uint16, jnp.uint32, jnp.float32, jnp.bfloat16] + and to_dtype == jnp.float64 + ): + pytest.xfail( + compile_fail( + "Executable expected parameter 0 of size 8192 but got buffer with " + "incompatible size 4096 (https://github.com/tenstorrent/tt-xla/issues/170)" + ) + ) + + @pytest.mark.parametrize( "from_dtype", [ + # uints + pytest.param( + jnp.uint8, + marks=pytest.mark.skip(reason="Unsupported data type"), + ), jnp.uint16, jnp.uint32, + jnp.uint64, + # ints pytest.param( - jnp.uint64, - marks=pytest.mark.skip( - reason=( - "Cannot get the device from a tensor with host storage. " - "See issue https://github.com/tenstorrent/tt-xla/issues/171" - ) - ), - ), - pytest.param( - jnp.int16, - marks=pytest.mark.skip( - reason=( - "Cannot get the device from a tensor with host storage. " - "See issue https://github.com/tenstorrent/tt-xla/issues/171" - ) - ), + jnp.int8, + marks=pytest.mark.skip(reason="Unsupported data type"), ), + jnp.int16, + jnp.int32, + jnp.int64, + # floats pytest.param( - jnp.int32, - marks=pytest.mark.skip( - reason=( - "Cannot get the device from a tensor with host storage. " - "See issue https://github.com/tenstorrent/tt-xla/issues/171" - ) - ), + jnp.float16, + marks=pytest.mark.skip(reason="Unsupported data type"), ), + jnp.float32, + jnp.float64, + # bfloat + jnp.bfloat16, + # bool pytest.param( - jnp.int64, + jnp.bool, marks=pytest.mark.skip( - reason=( - "Cannot get the device from a tensor with host storage. " - "See issue https://github.com/tenstorrent/tt-xla/issues/171" - ) + reason="Causes segfaults. Should be investigated separately." ), ), - jnp.float32, - jnp.bfloat16, ], ) @pytest.mark.parametrize( "to_dtype", [ + # uints pytest.param( - jnp.uint16, - marks=pytest.mark.skip( - reason=( - "Fails due to low comparison metrics. " - "See issue https://github.com/tenstorrent/tt-xla/issues/172" - ) - ), + jnp.uint8, + marks=pytest.mark.skip(reason="Unsupported data type"), ), + jnp.uint16, jnp.uint32, jnp.uint64, + # ints pytest.param( - jnp.int16, - marks=pytest.mark.skip( - reason=( - "Fails due to low comparison metrics. " - "See issue https://github.com/tenstorrent/tt-xla/issues/172" - ) - ), + jnp.int8, + marks=pytest.mark.skip(reason="Unsupported data type"), ), + jnp.int16, jnp.int32, jnp.int64, + # floats + pytest.param( + jnp.float16, + marks=pytest.mark.skip(reason="Unsupported data type"), + ), jnp.float32, + jnp.float64, + # bfloat jnp.bfloat16, + # bool + pytest.param( + jnp.bool, + marks=pytest.mark.skip( + reason="Causes segfaults. Should be investigated separately." + ), + ), ], ) -@pytest.mark.skip( - f"Skipped unconditionally due to many fails. There is ongoing work on rewriting these tests." -) def test_convert( from_dtype: DTypeLike, to_dtype: DTypeLike, record_tt_xla_property: Callable ): @@ -107,7 +229,14 @@ def convert(x: jax.Array) -> jax.Array: "stablehlo.convert", ) - x_shape = (32, 32) # Shape does not make any impact here, thus not parametrized. - input = random_tensor(x_shape, from_dtype, minval=0.0, maxval=10.0) + # Some dtype conversions are not supported. Check and decide whether to skip or + # proceed. + conditionally_skip(from_dtype, to_dtype) + + # Shape does not make any impact here, thus not parametrized. + x_shape = (32, 32) + + with enable_x64(): + input = random_tensor(x_shape, from_dtype, minval=0.0, maxval=10.0) - run_op_test(convert, [input]) + run_op_test(convert, [input]) diff --git a/tests/jax/test_data_types.py b/tests/jax/test_data_types.py index 6ace5d16..2f1a26e5 100644 --- a/tests/jax/test_data_types.py +++ b/tests/jax/test_data_types.py @@ -12,13 +12,13 @@ from infra import run_op_test from jax._src.typing import DTypeLike -# Allow 64bit precision in jax which is disabled by default. -jax.config.update("jax_enable_x64", True) +from tests.utils import enable_x64 @pytest.mark.parametrize( "dtype", [ + # uints pytest.param( jnp.uint8, marks=pytest.mark.xfail(reason="Unsupported data type"), @@ -26,6 +26,7 @@ jnp.uint16, jnp.uint32, jnp.uint64, + # ints pytest.param( jnp.int8, marks=pytest.mark.xfail(reason="Unsupported data type"), @@ -33,6 +34,7 @@ jnp.int16, jnp.int32, jnp.int64, + # floats pytest.param( jnp.float16, marks=pytest.mark.xfail(reason="Unsupported data type"), @@ -40,7 +42,7 @@ jnp.float32, pytest.param( jnp.float64, - marks=pytest.mark.skip( + marks=pytest.mark.xfail( reason=( "Executable expected parameter 0 of size 8 but got buffer " "with incompatible size 4. See issue " @@ -48,18 +50,22 @@ ) ), ), + # bfloat jnp.bfloat16, + # bool + jnp.bool, ], ) def test_dtypes(dtype: DTypeLike): def scalar() -> jax.Array: """ This test just returns a scalar of a certain dtype. It will fail if dtype is - unsupported. + unsupported. In mlir graph, it produces one simple stablehlo.constant op. Scalars are actually 0-dim arrays. They can be created the same way arrays are, using `jax.array(, dtype)` or using `dtype()`. """ return jnp.array(1, dtype) # same as dtype(1) - run_op_test(scalar, []) + with enable_x64(): + run_op_test(scalar, []) diff --git a/tests/utils.py b/tests/utils.py index e111f038..73ea98fb 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -2,8 +2,10 @@ # # SPDX-License-Identifier: Apache-2.0 +from contextlib import contextmanager from typing import Callable +import jax from conftest import RecordProperties @@ -41,3 +43,20 @@ def record_op_test_properties( def record_model_test_properties(record_property: Callable, model_name: str): record_property(RecordProperties.MODEL_NAME.value, model_name) + + +@contextmanager +def enable_x64(): + """ + Context manager that temporarily enables x64 in jax.config. + + Isolated as a context manager so that it doesn't change global config for all jax + imports and cause unexpected fails elsewhere. + """ + try: + # Set the config to True within this block, and yield back control. + jax.config.update("jax_enable_x64", True) + yield + finally: + # After `with` statement ends, turn it off again. + jax.config.update("jax_enable_x64", False)