Skip to content

Commit

Permalink
Rewrote tests for scalars and dtypes. Wrote tests for ranks.
Browse files Browse the repository at this point in the history
Fixed #135.
  • Loading branch information
kmitrovicTT committed Jan 17, 2025
1 parent a1154f5 commit bacd7b4
Show file tree
Hide file tree
Showing 6 changed files with 229 additions and 119 deletions.
2 changes: 1 addition & 1 deletion tests/infra/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,4 @@
from .graph_tester import run_graph_test, run_graph_test_with_random_inputs
from .model_tester import ModelTester, RunMode
from .op_tester import run_op_test, run_op_test_with_random_inputs
from .utils import random_tensor, supported_dtypes
from .utils import random_tensor
37 changes: 23 additions & 14 deletions tests/infra/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,17 @@
#
# SPDX-License-Identifier: Apache-2.0

from typing import Union

import jax
import jax.numpy as jnp
from jax import export
from jax._src.typing import DTypeLike

from .device_runner import run_on_cpu
from .types import Framework, Tensor
from .workload import Workload

# List of all data types that runtime currently supports.
supported_dtypes = [jnp.float32, jnp.bfloat16, jnp.uint32, jnp.uint16]


def _str_to_dtype(dtype_str: str, framework: Framework = Framework.JAX):
"""Convert a string dtype to the corresponding framework-specific dtype."""
Expand All @@ -26,7 +25,7 @@ def _str_to_dtype(dtype_str: str, framework: Framework = Framework.JAX):
@run_on_cpu
def random_tensor(
shape: tuple,
dtype: str = "float32",
dtype: Union[str, DTypeLike] = jnp.float32,
random_seed: int = 0,
minval: float = 0.0,
maxval: float = 1.0,
Expand All @@ -36,20 +35,30 @@ def random_tensor(
Generates a random tensor of `shape`, `dtype`, and `random_seed` in range
[`minval`, `maxval`) for the desired `framework`.
"""
# Convert dtype string to actual dtype for the selected framework.
dtype_converted = _str_to_dtype(dtype, framework)
dtype_converted = (
_str_to_dtype(dtype, framework) if isinstance(dtype, str) else dtype
)

# Generate random tensor based on framework type
# Generate random tensor based on framework type.
if framework == Framework.JAX:
prng_key = jax.random.PRNGKey(random_seed)

return jax.random.uniform(
key=prng_key,
shape=shape,
dtype=dtype_converted,
minval=minval,
maxval=maxval,
)
if jnp.issubdtype(dtype_converted, jnp.integer):
return jax.random.randint(
key=prng_key,
shape=shape,
dtype=dtype_converted,
minval=int(minval),
maxval=int(maxval),
)
else:
return jax.random.uniform(
key=prng_key,
shape=shape,
dtype=dtype_converted,
minval=minval,
maxval=maxval,
)
else:
raise ValueError(f"Unsupported framework: {framework.value}.")

Expand Down
83 changes: 71 additions & 12 deletions tests/jax/ops/test_convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,32 +9,91 @@
from infra import random_tensor, run_op_test
from jax._src.typing import DTypeLike

# Allow 64bit precision in jax which is disabled by default.
jax.config.update("jax_enable_x64", True)

# NOTE Use test_data_types.py as reference for all supported data types.


# TODO we need to parametrize with all supported dtypes.
@pytest.mark.parametrize(
"from_dtype",
[
"bfloat16",
"float32",
jnp.uint16,
jnp.uint32,
pytest.param(
jnp.uint64,
marks=pytest.mark.skip(
reason=(
"Cannot get the device from a tensor with host storage. "
"See issue https://github.com/tenstorrent/tt-xla/issues/171"
)
),
),
pytest.param(
jnp.int16,
marks=pytest.mark.skip(
reason=(
"Cannot get the device from a tensor with host storage. "
"See issue https://github.com/tenstorrent/tt-xla/issues/171"
)
),
),
pytest.param(
jnp.int32,
marks=pytest.mark.skip(
reason=(
"Cannot get the device from a tensor with host storage. "
"See issue https://github.com/tenstorrent/tt-xla/issues/171"
)
),
),
pytest.param(
jnp.int64,
marks=pytest.mark.skip(
reason=(
"Cannot get the device from a tensor with host storage. "
"See issue https://github.com/tenstorrent/tt-xla/issues/171"
)
),
),
jnp.float32,
jnp.bfloat16,
],
)
@pytest.mark.parametrize(
"to_dtype",
[
"uint32",
"uint64",
"int32",
"int64",
"bfloat16",
"float32",
"float64",
pytest.param(
jnp.uint16,
marks=pytest.mark.skip(
reason=(
"Fails due to low comparison metrics. "
"See issue https://github.com/tenstorrent/tt-xla/issues/172"
)
),
),
jnp.uint32,
jnp.uint64,
pytest.param(
jnp.int16,
marks=pytest.mark.skip(
reason=(
"Fails due to low comparison metrics. "
"See issue https://github.com/tenstorrent/tt-xla/issues/172"
)
),
),
jnp.int32,
jnp.int64,
jnp.float32,
jnp.bfloat16,
],
)
def test_convert(from_dtype: DTypeLike, to_dtype: DTypeLike):
def convert(x: jax.Array) -> jax.Array:
return jlx.convert_element_type(x, new_dtype=jnp.dtype(to_dtype))
return jlx.convert_element_type(x, new_dtype=to_dtype)

x_shape = (32, 32) # Shape does not make any impact here, thus not parametrized.
input = random_tensor(x_shape, dtype=from_dtype)
input = random_tensor(x_shape, from_dtype, minval=0.0, maxval=10.0)

run_op_test(convert, [input])
85 changes: 51 additions & 34 deletions tests/jax/test_data_types.py
Original file line number Diff line number Diff line change
@@ -1,48 +1,65 @@
# SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
#
# SPDX-License-Identifier: Apache-2.0

from typing import Union
"""
This file contains sanity tests which create arrays of various dtypes, in order not to
parametrize each test additionally with dtypes.
"""

import jax
import jax.numpy as jnp
import pytest
from infra import run_op_test, supported_dtypes
from infra import run_op_test
from jax._src.typing import DTypeLike

# Convenience alias.
scalar = Union[int, float]


@pytest.mark.parametrize("dtype", supported_dtypes)
@pytest.mark.skip(
"Passes locally but fails on CI due to AssertionError: Unexpected XLA layout override"
)
def test_scalar_dtype(dtype: DTypeLike):
"""
This test just returns a scalar of a certain dtype. It will fail if dtype is
unsupported.
"""

def add(x: scalar) -> scalar:
return x

in0 = dtype(1) # Dummy scalar used as input.
run_op_test(add, [in0])
# Allow 64bit precision in jax which is disabled by default.
jax.config.update("jax_enable_x64", True)


@pytest.mark.parametrize("dtype", supported_dtypes)
@pytest.mark.skip(
"Passes locally but fails on CI due to AssertionError: Unexpected XLA layout override"
@pytest.mark.parametrize(
"dtype",
[
pytest.param(
jnp.uint8,
marks=pytest.mark.xfail(reason="Unsupported data type"),
),
jnp.uint16,
jnp.uint32,
jnp.uint64,
pytest.param(
jnp.int8,
marks=pytest.mark.xfail(reason="Unsupported data type"),
),
jnp.int16,
jnp.int32,
jnp.int64,
pytest.param(
jnp.float16,
marks=pytest.mark.xfail(reason="Unsupported data type"),
),
jnp.float32,
pytest.param(
jnp.float64,
marks=pytest.mark.skip(
reason=(
"Executable expected parameter 0 of size 8 but got buffer "
"with incompatible size 4. See issue "
"https://github.com/tenstorrent/tt-xla/issues/170"
)
),
),
jnp.bfloat16,
],
)
def test_array_dtype(dtype: DTypeLike):
"""
This test just returns an array of a certain dtype. It will fail if dtype is
unsupported.
"""
def test_dtypes(dtype: DTypeLike):
def scalar() -> jax.Array:
"""
This test just returns a scalar of a certain dtype. It will fail if dtype is
unsupported.
def array(x: jax.Array) -> jax.Array:
return x
Scalars are actually 0-dim arrays. They can be created the same way arrays are,
using `jax.array(<some-value>, dtype)` or using `dtype(<some-value>)`.
"""
return jnp.array(1, dtype) # same as dtype(1)

in0 = jnp.ones((32, 32), dtype) # Dummy array used as input.
run_op_test(array, [in0])
run_op_test(scalar, [])
65 changes: 62 additions & 3 deletions tests/jax/test_ranks.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,66 @@
# SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
#
# SPDX-License-Identifier: Apache-2.0
"""
This file contains sanity tests for some representative ops to make sure they work for
various ranks, in order not to parametrize each test additionally with ranks.
"""

# TODO this file should contain sanity tests for some representative ops to make sure
# they work for higher ranks, in order not to parametrize each test additionally with
# ranks. See issue https://github.com/tenstorrent/tt-xla/issues/135.
import jax
import pytest
from infra import run_op_test_with_random_inputs
from jax import numpy as jnp


@pytest.mark.parametrize(
"x_shape",
[
pytest.param(
(),
marks=pytest.mark.skip(
reason=(
"Unexpected XLA layout override. "
"See issue https://github.com/tenstorrent/tt-xla/issues/173"
)
),
),
(32,),
(32, 32),
(1, 32, 32),
(1, 3, 32, 32),
],
)
def test_unary_op(x_shape: tuple):
"""Using negative as it is trivial, since this test only focuses on ranks."""

def negate(x: jax.Array) -> jax.Array:
return jnp.negative(x)

run_op_test_with_random_inputs(negate, [x_shape])


@pytest.mark.parametrize(
"shape",
[
pytest.param(
(),
marks=pytest.mark.skip(
reason=(
"Unexpected XLA layout override. "
"See issue https://github.com/tenstorrent/tt-xla/issues/173"
)
),
),
(32,),
(32, 32),
(1, 32, 32),
(1, 3, 32, 32),
],
)
def test_binary_op(shape: tuple):
"""Using add as it is trivial, since this test only focuses on ranks."""

def add(x: jax.Array, y: jax.Array) -> jax.Array:
return x + y

run_op_test_with_random_inputs(add, [shape, shape])
Loading

0 comments on commit bacd7b4

Please sign in to comment.