-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Rewrote tests for scalars and dtypes. Wrote tests for ranks.
Fixed #135.
- Loading branch information
1 parent
a1154f5
commit bacd7b4
Showing
6 changed files
with
229 additions
and
119 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,48 +1,65 @@ | ||
# SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC | ||
# | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
from typing import Union | ||
""" | ||
This file contains sanity tests which create arrays of various dtypes, in order not to | ||
parametrize each test additionally with dtypes. | ||
""" | ||
|
||
import jax | ||
import jax.numpy as jnp | ||
import pytest | ||
from infra import run_op_test, supported_dtypes | ||
from infra import run_op_test | ||
from jax._src.typing import DTypeLike | ||
|
||
# Convenience alias. | ||
scalar = Union[int, float] | ||
|
||
|
||
@pytest.mark.parametrize("dtype", supported_dtypes) | ||
@pytest.mark.skip( | ||
"Passes locally but fails on CI due to AssertionError: Unexpected XLA layout override" | ||
) | ||
def test_scalar_dtype(dtype: DTypeLike): | ||
""" | ||
This test just returns a scalar of a certain dtype. It will fail if dtype is | ||
unsupported. | ||
""" | ||
|
||
def add(x: scalar) -> scalar: | ||
return x | ||
|
||
in0 = dtype(1) # Dummy scalar used as input. | ||
run_op_test(add, [in0]) | ||
# Allow 64bit precision in jax which is disabled by default. | ||
jax.config.update("jax_enable_x64", True) | ||
|
||
|
||
@pytest.mark.parametrize("dtype", supported_dtypes) | ||
@pytest.mark.skip( | ||
"Passes locally but fails on CI due to AssertionError: Unexpected XLA layout override" | ||
@pytest.mark.parametrize( | ||
"dtype", | ||
[ | ||
pytest.param( | ||
jnp.uint8, | ||
marks=pytest.mark.xfail(reason="Unsupported data type"), | ||
), | ||
jnp.uint16, | ||
jnp.uint32, | ||
jnp.uint64, | ||
pytest.param( | ||
jnp.int8, | ||
marks=pytest.mark.xfail(reason="Unsupported data type"), | ||
), | ||
jnp.int16, | ||
jnp.int32, | ||
jnp.int64, | ||
pytest.param( | ||
jnp.float16, | ||
marks=pytest.mark.xfail(reason="Unsupported data type"), | ||
), | ||
jnp.float32, | ||
pytest.param( | ||
jnp.float64, | ||
marks=pytest.mark.skip( | ||
reason=( | ||
"Executable expected parameter 0 of size 8 but got buffer " | ||
"with incompatible size 4. See issue " | ||
"https://github.com/tenstorrent/tt-xla/issues/170" | ||
) | ||
), | ||
), | ||
jnp.bfloat16, | ||
], | ||
) | ||
def test_array_dtype(dtype: DTypeLike): | ||
""" | ||
This test just returns an array of a certain dtype. It will fail if dtype is | ||
unsupported. | ||
""" | ||
def test_dtypes(dtype: DTypeLike): | ||
def scalar() -> jax.Array: | ||
""" | ||
This test just returns a scalar of a certain dtype. It will fail if dtype is | ||
unsupported. | ||
def array(x: jax.Array) -> jax.Array: | ||
return x | ||
Scalars are actually 0-dim arrays. They can be created the same way arrays are, | ||
using `jax.array(<some-value>, dtype)` or using `dtype(<some-value>)`. | ||
""" | ||
return jnp.array(1, dtype) # same as dtype(1) | ||
|
||
in0 = jnp.ones((32, 32), dtype) # Dummy array used as input. | ||
run_op_test(array, [in0]) | ||
run_op_test(scalar, []) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,66 @@ | ||
# SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC | ||
# | ||
# SPDX-License-Identifier: Apache-2.0 | ||
""" | ||
This file contains sanity tests for some representative ops to make sure they work for | ||
various ranks, in order not to parametrize each test additionally with ranks. | ||
""" | ||
|
||
# TODO this file should contain sanity tests for some representative ops to make sure | ||
# they work for higher ranks, in order not to parametrize each test additionally with | ||
# ranks. See issue https://github.com/tenstorrent/tt-xla/issues/135. | ||
import jax | ||
import pytest | ||
from infra import run_op_test_with_random_inputs | ||
from jax import numpy as jnp | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"x_shape", | ||
[ | ||
pytest.param( | ||
(), | ||
marks=pytest.mark.skip( | ||
reason=( | ||
"Unexpected XLA layout override. " | ||
"See issue https://github.com/tenstorrent/tt-xla/issues/173" | ||
) | ||
), | ||
), | ||
(32,), | ||
(32, 32), | ||
(1, 32, 32), | ||
(1, 3, 32, 32), | ||
], | ||
) | ||
def test_unary_op(x_shape: tuple): | ||
"""Using negative as it is trivial, since this test only focuses on ranks.""" | ||
|
||
def negate(x: jax.Array) -> jax.Array: | ||
return jnp.negative(x) | ||
|
||
run_op_test_with_random_inputs(negate, [x_shape]) | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"shape", | ||
[ | ||
pytest.param( | ||
(), | ||
marks=pytest.mark.skip( | ||
reason=( | ||
"Unexpected XLA layout override. " | ||
"See issue https://github.com/tenstorrent/tt-xla/issues/173" | ||
) | ||
), | ||
), | ||
(32,), | ||
(32, 32), | ||
(1, 32, 32), | ||
(1, 3, 32, 32), | ||
], | ||
) | ||
def test_binary_op(shape: tuple): | ||
"""Using add as it is trivial, since this test only focuses on ranks.""" | ||
|
||
def add(x: jax.Array, y: jax.Array) -> jax.Array: | ||
return x + y | ||
|
||
run_op_test_with_random_inputs(add, [shape, shape]) |
Oops, something went wrong.