diff --git a/tests/infra/__init__.py b/tests/infra/__init__.py
index 8e7183b5..126f1278 100644
--- a/tests/infra/__init__.py
+++ b/tests/infra/__init__.py
@@ -7,4 +7,4 @@
 from .graph_tester import run_graph_test, run_graph_test_with_random_inputs
 from .model_tester import ModelTester, RunMode
 from .op_tester import run_op_test, run_op_test_with_random_inputs
-from .utils import random_tensor, supported_dtypes
+from .utils import random_tensor
diff --git a/tests/infra/utils.py b/tests/infra/utils.py
index 1823fec5..f9fdf30a 100644
--- a/tests/infra/utils.py
+++ b/tests/infra/utils.py
@@ -2,18 +2,17 @@
 #
 # SPDX-License-Identifier: Apache-2.0
 
+from typing import Union
 import jax
 import jax.numpy as jnp
 from jax import export
+from jax._src.typing import DTypeLike
 
 from .device_runner import run_on_cpu
 from .types import Framework, Tensor
 from .workload import Workload
 
-# List of all data types that runtime currently supports.
-supported_dtypes = [jnp.float32, jnp.bfloat16, jnp.uint32, jnp.uint16]
-
 
 def _str_to_dtype(dtype_str: str, framework: Framework = Framework.JAX):
     """Convert a string dtype to the corresponding framework-specific dtype."""
@@ -26,7 +25,7 @@
 @run_on_cpu
 def random_tensor(
     shape: tuple,
-    dtype: str = "float32",
+    dtype: Union[str, DTypeLike] = jnp.float32,
     random_seed: int = 0,
     minval: float = 0.0,
     maxval: float = 1.0,
@@ -36,20 +35,41 @@
     Generates a random tensor of `shape`, `dtype`, and `random_seed` in range
     [`minval`, `maxval`) for the desired `framework`.
     """
-    # Convert dtype string to actual dtype for the selected framework.
-    dtype_converted = _str_to_dtype(dtype, framework)
+    dtype_converted = (
+        _str_to_dtype(dtype, framework) if isinstance(dtype, str) else dtype
+    )
 
-    # Generate random tensor based on framework type
+    # Generate random tensor based on framework type.
     if framework == Framework.JAX:
         prng_key = jax.random.PRNGKey(random_seed)
-        return jax.random.uniform(
-            key=prng_key,
-            shape=shape,
-            dtype=dtype_converted,
-            minval=minval,
-            maxval=maxval,
-        )
+        if jnp.issubdtype(dtype_converted, jnp.integer):
+            return jax.random.randint(
+                key=prng_key,
+                shape=shape,
+                dtype=dtype_converted,
+                minval=int(minval),
+                maxval=int(maxval),
+            )
+        elif jnp.issubdtype(dtype_converted, jnp.floating):
+            return jax.random.uniform(
+                key=prng_key,
+                shape=shape,
+                dtype=dtype_converted,
+                minval=minval,
+                maxval=maxval,
+            )
+        elif jnp.issubdtype(dtype_converted, jnp.bool):
+            # Generate a random tensor of 0s and 1s and interpret it as a bool tensor.
+            return jax.random.randint(
+                key=prng_key,
+                shape=shape,
+                dtype=jnp.int32,
+                minval=0,
+                maxval=2,  # `maxval` is exclusive, so this yields 0s and 1s.
+            ).astype(dtype_converted)
+        else:
+            raise TypeError(f"Unsupported dtype: {dtype}")
     else:
         raise ValueError(f"Unsupported framework: {framework.value}.")
diff --git a/tests/jax/ops/test_broadcast_in_dim.py b/tests/jax/ops/test_broadcast_in_dim.py
index 61f2e24f..aca8c93a 100644
--- a/tests/jax/ops/test_broadcast_in_dim.py
+++ b/tests/jax/ops/test_broadcast_in_dim.py
@@ -12,9 +12,6 @@
 
 @pytest.mark.parametrize("input_shapes", [[(2, 1)]], ids=lambda val: f"{val}")
-@pytest.mark.xfail(
-    reason="AssertionError: Atol comparison failed. Calculated: atol=0.804124116897583. Required: atol=0.16"
-)
 def test_broadcast_in_dim(input_shapes: tuple, record_tt_xla_property: Callable):
     def broadcast(a: jax.Array):
         return jnp.broadcast_to(a, (2, 4))
diff --git a/tests/jax/ops/test_constant.py b/tests/jax/ops/test_constant.py
index e7bb5f39..0b4abd10 100644
--- a/tests/jax/ops/test_constant.py
+++ b/tests/jax/ops/test_constant.py
@@ -7,7 +7,7 @@
 import jax.numpy as jnp
 import pytest
 from infra import run_op_test
-from utils import record_op_test_properties
+from utils import compile_fail, record_op_test_properties
 
 
 @pytest.mark.parametrize("shape", [(32, 32), (1, 1)], ids=lambda val: f"{val}")
@@ -40,7 +40,7 @@
     run_op_test(module_constant_ones, [])
 
 
-@pytest.mark.xfail(reason="failed to legalize operation 'ttir.constant'")
+@pytest.mark.xfail(reason=compile_fail("failed to legalize operation 'ttir.constant'"))
 def test_constant_multi_value(record_tt_xla_property: Callable):
     def module_constant_multi():
         return jnp.array([[1, 2], [3, 4]], dtype=jnp.float32)
diff --git a/tests/jax/ops/test_convert.py b/tests/jax/ops/test_convert.py
index 99ba4385..d62d72b3 100644
--- a/tests/jax/ops/test_convert.py
+++ b/tests/jax/ops/test_convert.py
@@ -10,37 +10,218 @@
 import pytest
 from infra import random_tensor, run_op_test
 from jax._src.typing import DTypeLike
-from utils import record_unary_op_test_properties
+from utils import compile_fail, record_unary_op_test_properties, runtime_fail
+
+from tests.utils import enable_x64
+
+# NOTE Use test_data_types.py as reference for all supported data types.
+
+
+def conditionally_skip(from_dtype: DTypeLike, to_dtype: DTypeLike):
+    """
+    Helper function which checks the dtype combination and xfails the test if it is
+    known to be unsupported.
+
+    Extracted here in order not to pollute the test function.
+    """
+    # ---------- Atol comparison failed ----------
+
+    if from_dtype == jnp.uint32 and to_dtype in [jnp.uint16, jnp.int16]:
+        pytest.xfail(
+            runtime_fail(
+                "AssertionError: Atol comparison failed. Calculated: atol=9.0. Required: atol=0.16."
+            )
+        )
+
+    if from_dtype == jnp.float64 and to_dtype == jnp.uint16:
+        pytest.xfail(
+            runtime_fail(
+                "AssertionError: Atol comparison failed. Calculated: atol=9.0. Required: atol=0.16."
+            )
+        )
+
+    if from_dtype == jnp.float32 and to_dtype in [jnp.uint16, jnp.int16]:
+        pytest.xfail(
+            runtime_fail(
+                "AssertionError: Atol comparison failed. Calculated: atol=1.0. Required: atol=0.16."
+            )
+        )
+
+    if from_dtype == jnp.bfloat16 and to_dtype in [jnp.uint16, jnp.int16]:
+        pytest.xfail(
+            runtime_fail(
+                "AssertionError: Atol comparison failed. Calculated: atol=1.0. Required: atol=0.16."
+ ) + ) + + # ---------- Cannot get the device from a tensor with host storage ---------- + + if from_dtype == jnp.uint64 and to_dtype in [ + jnp.uint16, + jnp.uint32, + jnp.uint64, + jnp.int16, + jnp.int32, + jnp.int64, + ]: + pytest.xfail( + runtime_fail( + "Cannot get the device from a tensor with host storage " + "(https://github.com/tenstorrent/tt-xla/issues/171)" + ) + ) + + if from_dtype in [jnp.int16, jnp.int32, jnp.int64] and to_dtype in [ + jnp.uint16, + jnp.uint32, + jnp.uint64, + jnp.int16, + jnp.int32, + jnp.int64, + ]: + pytest.xfail( + runtime_fail( + "Cannot get the device from a tensor with host storage " + "(https://github.com/tenstorrent/tt-xla/issues/171)" + ) + ) + + if from_dtype == jnp.float64 and to_dtype in [ + jnp.uint16, + jnp.uint32, + jnp.uint64, + jnp.int16, + jnp.int32, + jnp.int64, + ]: + pytest.xfail( + runtime_fail( + "Cannot get the device from a tensor with host storage " + "(https://github.com/tenstorrent/tt-xla/issues/171)" + ) + ) + + if to_dtype in [jnp.float32, jnp.float64] and from_dtype in [ + jnp.uint64, + jnp.int16, + jnp.int32, + jnp.int64, + jnp.float64, + ]: + pytest.xfail( + runtime_fail( + "Cannot get the device from a tensor with host storage " + "(https://github.com/tenstorrent/tt-xla/issues/171)" + ) + ) + + if to_dtype == jnp.bfloat16 and from_dtype in [ + jnp.uint64, + jnp.int16, + jnp.int32, + jnp.int64, + jnp.float64, + ]: + pytest.xfail( + runtime_fail( + "Cannot get the device from a tensor with host storage " + "(https://github.com/tenstorrent/tt-xla/issues/171)" + ) + ) + + # ---------- Executable expected parameter x of size y but got... ---------- + + if ( + from_dtype in [jnp.uint16, jnp.uint32, jnp.float32, jnp.bfloat16] + and to_dtype == jnp.float64 + ): + pytest.xfail( + compile_fail( + "Executable expected parameter 0 of size 8192 but got buffer with " + "incompatible size 4096 (https://github.com/tenstorrent/tt-xla/issues/170)" + ) + ) -# TODO we need to parametrize with all supported dtypes. @pytest.mark.parametrize( "from_dtype", [ - "bfloat16", - "float32", + # uints + pytest.param( + jnp.uint8, + marks=pytest.mark.skip(reason="Unsupported data type"), + ), + jnp.uint16, + jnp.uint32, + jnp.uint64, + # ints + pytest.param( + jnp.int8, + marks=pytest.mark.skip(reason="Unsupported data type"), + ), + jnp.int16, + jnp.int32, + jnp.int64, + # floats + pytest.param( + jnp.float16, + marks=pytest.mark.skip(reason="Unsupported data type"), + ), + jnp.float32, + jnp.float64, + # bfloat + jnp.bfloat16, + # bool + pytest.param( + jnp.bool, + marks=pytest.mark.skip( + reason="Causes segfaults. Should be investigated separately." + ), + ), ], ) @pytest.mark.parametrize( "to_dtype", [ - "uint32", - "uint64", - "int32", - "int64", - "bfloat16", - "float32", - "float64", + # uints + pytest.param( + jnp.uint8, + marks=pytest.mark.skip(reason="Unsupported data type"), + ), + jnp.uint16, + jnp.uint32, + jnp.uint64, + # ints + pytest.param( + jnp.int8, + marks=pytest.mark.skip(reason="Unsupported data type"), + ), + jnp.int16, + jnp.int32, + jnp.int64, + # floats + pytest.param( + jnp.float16, + marks=pytest.mark.skip(reason="Unsupported data type"), + ), + jnp.float32, + jnp.float64, + # bfloat + jnp.bfloat16, + # bool + pytest.param( + jnp.bool, + marks=pytest.mark.skip( + reason="Causes segfaults. Should be investigated separately." + ), + ), ], ) -@pytest.mark.skip( - f"Skipped unconditionally due to many fails. There is ongoing work on rewriting these tests." 
-)
 def test_convert(
     from_dtype: DTypeLike, to_dtype: DTypeLike, record_tt_xla_property: Callable
 ):
     def convert(x: jax.Array) -> jax.Array:
-        return jlx.convert_element_type(x, new_dtype=jnp.dtype(to_dtype))
+        return jlx.convert_element_type(x, new_dtype=to_dtype)
 
     record_unary_op_test_properties(
         record_tt_xla_property,
@@ -48,7 +229,14 @@ def convert(x: jax.Array) -> jax.Array:
         "stablehlo.convert",
     )
 
-    x_shape = (32, 32)  # Shape does not make any impact here, thus not parametrized.
-    input = random_tensor(x_shape, dtype=from_dtype)
+    # Some dtype conversions are not supported. Check and decide whether to xfail or
+    # proceed.
+    conditionally_skip(from_dtype, to_dtype)
+
+    # Shape has no impact here, thus it is not parametrized.
+    x_shape = (32, 32)
+
+    with enable_x64():
+        input = random_tensor(x_shape, from_dtype, minval=0.0, maxval=10.0)
 
-    run_op_test(convert, [input])
+    run_op_test(convert, [input])
diff --git a/tests/jax/test_data_types.py b/tests/jax/test_data_types.py
index 1b9a5a9f..2f1a26e5 100644
--- a/tests/jax/test_data_types.py
+++ b/tests/jax/test_data_types.py
@@ -1,48 +1,71 @@
 # SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
 #
 # SPDX-License-Identifier: Apache-2.0
-
-from typing import Union
+"""
+This file contains sanity tests which create arrays of various dtypes, so that other
+tests do not have to be additionally parametrized with dtypes.
+"""
 
 import jax
 import jax.numpy as jnp
 import pytest
-from infra import run_op_test, supported_dtypes
+from infra import run_op_test
 from jax._src.typing import DTypeLike
 
-# Convenience alias.
-scalar = Union[int, float]
-
-
-@pytest.mark.parametrize("dtype", supported_dtypes)
-@pytest.mark.skip(
-    "Passes locally but fails on CI due to AssertionError: Unexpected XLA layout override"
-)
-def test_scalar_dtype(dtype: DTypeLike):
-    """
-    This test just returns a scalar of a certain dtype. It will fail if dtype is
-    unsupported.
-    """
-
-    def add(x: scalar) -> scalar:
-        return x
-
-    in0 = dtype(1)  # Dummy scalar used as input.
-    run_op_test(add, [in0])
+from tests.utils import enable_x64
 
-@pytest.mark.parametrize("dtype", supported_dtypes)
-@pytest.mark.skip(
-    "Passes locally but fails on CI due to AssertionError: Unexpected XLA layout override"
+@pytest.mark.parametrize(
+    "dtype",
+    [
+        # uints
+        pytest.param(
+            jnp.uint8,
+            marks=pytest.mark.xfail(reason="Unsupported data type"),
+        ),
+        jnp.uint16,
+        jnp.uint32,
+        jnp.uint64,
+        # ints
+        pytest.param(
+            jnp.int8,
+            marks=pytest.mark.xfail(reason="Unsupported data type"),
+        ),
+        jnp.int16,
+        jnp.int32,
+        jnp.int64,
+        # floats
+        pytest.param(
+            jnp.float16,
+            marks=pytest.mark.xfail(reason="Unsupported data type"),
+        ),
+        jnp.float32,
+        pytest.param(
+            jnp.float64,
+            marks=pytest.mark.xfail(
+                reason=(
+                    "Executable expected parameter 0 of size 8 but got buffer "
+                    "with incompatible size 4. See issue "
+                    "https://github.com/tenstorrent/tt-xla/issues/170"
+                )
+            ),
+        ),
+        # bfloat
+        jnp.bfloat16,
+        # bool
+        jnp.bool,
+    ],
 )
-def test_array_dtype(dtype: DTypeLike):
-    """
-    This test just returns an array of a certain dtype. It will fail if dtype is
-    unsupported.
-    """
+def test_dtypes(dtype: DTypeLike):
+    def scalar() -> jax.Array:
+        """
+        This test just returns a scalar of a certain dtype. It will fail if dtype is
+        unsupported. In the MLIR graph it produces a single stablehlo.constant op.
 
-    def array(x: jax.Array) -> jax.Array:
-        return x
+        Scalars are actually 0-dim arrays. They can be created the same way arrays are,
+        using `jnp.array(value, dtype)` or using `dtype(value)`.
+        """
+        return jnp.array(1, dtype)  # same as dtype(1)
 
-    in0 = jnp.ones((32, 32), dtype)  # Dummy array used as input.
-    run_op_test(array, [in0])
+    with enable_x64():
+        run_op_test(scalar, [])
diff --git a/tests/jax/test_ranks.py b/tests/jax/test_ranks.py
index 426775b7..f0c70356 100644
--- a/tests/jax/test_ranks.py
+++ b/tests/jax/test_ranks.py
@@ -1,7 +1,66 @@
 # SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
 #
 # SPDX-License-Identifier: Apache-2.0
+"""
+This file contains sanity tests for some representative ops to make sure they work for
+various ranks, so that other tests do not have to be additionally parametrized with
+ranks.
+"""
 
-# TODO this file should contain sanity tests for some representative ops to make sure
-# they work for higher ranks, in order not to parametrize each test additionally with
-# ranks. See issue https://github.com/tenstorrent/tt-xla/issues/135.
+import jax
+import pytest
+from infra import run_op_test_with_random_inputs
+from jax import numpy as jnp
+
+
+@pytest.mark.parametrize(
+    "x_shape",
+    [
+        pytest.param(
+            (),
+            marks=pytest.mark.skip(
+                reason=(
+                    "Unexpected XLA layout override. "
+                    "See issue https://github.com/tenstorrent/tt-xla/issues/173"
+                )
+            ),
+        ),
+        (32,),
+        (32, 32),
+        (1, 32, 32),
+        (1, 3, 32, 32),
+    ],
+)
+def test_unary_op(x_shape: tuple):
+    """Using negative as it is trivial, since this test only focuses on ranks."""
+
+    def negate(x: jax.Array) -> jax.Array:
+        return jnp.negative(x)
+
+    run_op_test_with_random_inputs(negate, [x_shape])
+
+
+@pytest.mark.parametrize(
+    "shape",
+    [
+        pytest.param(
+            (),
+            marks=pytest.mark.skip(
+                reason=(
+                    "Unexpected XLA layout override. "
+                    "See issue https://github.com/tenstorrent/tt-xla/issues/173"
+                )
+            ),
+        ),
+        (32,),
+        (32, 32),
+        (1, 32, 32),
+        (1, 3, 32, 32),
+    ],
+)
+def test_binary_op(shape: tuple):
+    """Using add as it is trivial, since this test only focuses on ranks."""
+
+    def add(x: jax.Array, y: jax.Array) -> jax.Array:
+        return x + y
+
+    run_op_test_with_random_inputs(add, [shape, shape])
diff --git a/tests/jax/test_scalar_types.py b/tests/jax/test_scalar_types.py
index 4a62e854..943b3a5f 100644
--- a/tests/jax/test_scalar_types.py
+++ b/tests/jax/test_scalar_types.py
@@ -1,74 +1,40 @@
 # SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
 #
 # SPDX-License-Identifier: Apache-2.0
-
-from typing import Union
+"""
+This file contains sanity tests checking that scalars work as expected, since they
+represent a special case of 0-dim arrays.
+"""
 
 import jax
-import jax.numpy as jnp
 import pytest
-from infra import run_op_test, supported_dtypes
-from jax._src.typing import DTypeLike
-
-# Convenience alias.
-scalar = Union[int, float]
+from infra import run_op_test
+from jax import numpy as jnp
 
-# Convenience alias.
-scalar_or_array = Union[scalar, jax.Array]
 
+def test_scalar_scalar_add():
+    """Tests adding two scalars."""
 
-@pytest.mark.parametrize(
-    ["dtype0", "dtype1"],
-    [(d0, d1) for d0 in supported_dtypes for d1 in supported_dtypes],
-)
-@pytest.mark.skip(
-    "Passes locally but fails on CI due to AssertionError: Unexpected XLA layout override"
-)
-def test_scalar_scalar_add(dtype0: DTypeLike, dtype1: DTypeLike):
-    """
-    Tests adding of two scalars.
-
-    Adding two ints causes huge atol differences. It is known that tt-metal does not
-    work well with ints. Adding an int and a float works due to upcast to float.
-    """
+
+    def add() -> jax.Array:
+        return jnp.array(1, jnp.float32) + jnp.array(2, jnp.float32)
 
-    def add(x: scalar, y: scalar) -> scalar:
-        return x + y
+    run_op_test(add, [])
 
-    in0, in1 = dtype0(1), dtype1(2)
-    if in0.dtype in [jnp.uint32, jnp.uint16] and in1.dtype in [jnp.uint32, jnp.uint16]:
-        pytest.skip("Adding two ints causes huge atol differences.")
-
-    run_op_test(add, [in0, in1])
-
-
-@pytest.mark.parametrize(
-    ["in0", "in1"],
-    [
-        # Scalar and 0-dim array.
-        [jnp.array(1.0, jnp.float32), jnp.float32(2.0)],
-        [jnp.float32(2.0), jnp.array(1.0, jnp.float32)],
-        # Scalar and 1-dim array.
-        [jnp.ones((32,), jnp.float32), jnp.float32(2.0)],
-        [jnp.float32(2.0), jnp.ones((32,), jnp.float32)],
-        # Scalar and 2-dim array.
-        [jnp.ones((1, 32), jnp.float32), jnp.float32(2.0)],
-        [jnp.float32(2.0), jnp.ones((1, 32), jnp.float32)],
-    ],
-)
-@pytest.mark.skip(
-    "Passes locally but fails on CI due to AssertionError: Unexpected XLA layout override"
-)
-def test_scalar_array_add(in0: scalar_or_array, in1: scalar_or_array):
+@pytest.mark.skip("Fails due to https://github.com/tenstorrent/tt-metal/issues/16701")
+def test_scalar_array_add():
     """
-    Tests adding of scalar and an array.
+    Tests adding a scalar and an array.
 
     Also performs a sanity check that addition is commutative, that is scalar + array
     is the same as array + scalar.
     """
 
-    def add(x: scalar_or_array, y: scalar_or_array) -> scalar_or_array:
-        return x + y
+    def array_plus_scalar() -> jax.Array:
+        return jnp.ones((32, 32), jnp.float32) + jnp.array(2.0, jnp.float32)
+
+    def scalar_plus_array() -> jax.Array:
+        return jnp.array(2.0, jnp.float32) + jnp.ones((32, 32), jnp.float32)
 
-    run_op_test(add, [in0, in1])
+    run_op_test(array_plus_scalar, [])
+    run_op_test(scalar_plus_array, [])
diff --git a/tests/utils.py b/tests/utils.py
index e111f038..73ea98fb 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -2,8 +2,10 @@
 #
 # SPDX-License-Identifier: Apache-2.0
 
+from contextlib import contextmanager
 from typing import Callable
 
+import jax
 from conftest import RecordProperties
 
 
@@ -41,3 +43,20 @@
 def record_model_test_properties(record_property: Callable, model_name: str):
     record_property(RecordProperties.MODEL_NAME.value, model_name)
+
+
+@contextmanager
+def enable_x64():
+    """
+    Context manager that temporarily enables x64 in jax.config.
+
+    Isolated as a context manager so that it doesn't change global config for all jax
+    imports and cause unexpected failures elsewhere.
+    """
+    try:
+        # Set the config to True within this block, and yield back control.
+        jax.config.update("jax_enable_x64", True)
+        yield
+    finally:
+        # After `with` statement ends, turn it off again.
+        jax.config.update("jax_enable_x64", False)
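
Note on the reworked `random_tensor` (tests/infra/utils.py above): it now dispatches on the dtype kind instead of always calling `jax.random.uniform`. A minimal usage sketch of what call sites can expect; the shapes and bounds below are illustrative, not taken from the test suite:

    import jax.numpy as jnp
    from infra import random_tensor

    # Floating dtypes draw from jax.random.uniform in [minval, maxval).
    x = random_tensor((32, 32), jnp.bfloat16, minval=0.0, maxval=10.0)

    # Integer dtypes draw from jax.random.randint; minval/maxval are cast to int.
    y = random_tensor((32, 32), jnp.uint16, minval=0.0, maxval=10.0)

    # String dtypes are still accepted and converted via _str_to_dtype.
    z = random_tensor((32, 32), "float32")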
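On `enable_x64` (tests/utils.py above): JAX silently downcasts 64-bit dtypes to their 32-bit counterparts unless `jax_enable_x64` is set, which is why the 64-bit test cases wrap array creation in this context manager. A minimal sketch of the standard JAX behavior it relies on (not code from this PR):

    import jax.numpy as jnp
    from tests.utils import enable_x64

    with enable_x64():
        # Inside the block, 64-bit dtypes are preserved.
        assert jnp.array(1, jnp.float64).dtype == jnp.float64

    # Outside the block, JAX falls back to 32-bit and emits a warning.
    assert jnp.array(1, jnp.float64).dtype == jnp.float32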
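Design note on `conditionally_skip` (tests/jax/ops/test_convert.py above): it uses pytest's imperative API, where `pytest.xfail(reason)` ends the test immediately and reports it as expected-to-fail, while `pytest.skip(reason)` reports it as skipped. Registering a new known-bad dtype pair is therefore a single guard; the entry below is hypothetical, not part of this PR:

    # Hypothetical example of registering a newly discovered failure:
    if from_dtype == jnp.uint16 and to_dtype == jnp.uint64:
        pytest.xfail(runtime_fail("describe the observed failure here"))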