-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
83c3808
commit d2dcb1e
Showing
5 changed files
with
241 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,116 @@ | ||
# SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC | ||
# | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
from __future__ import annotations | ||
|
||
import jax | ||
from jax.experimental.shard_map import shard_map | ||
from jax.sharding import NamedSharding | ||
from typing import Callable, Sequence | ||
|
||
from .base_tester import BaseTester | ||
from .comparison import ComparisonConfig | ||
from .device_runner import DeviceRunner | ||
from .workload import Workload | ||
|
||
|
||
class MultichipTester(BaseTester): | ||
"""Specific tester for ops.""" | ||
|
||
def __init__( | ||
self, | ||
mesh: jax.Mesh, | ||
in_specs: tuple, | ||
out_specs: jax.sharding.PartitionSpec, | ||
comparison_config: ComparisonConfig = ComparisonConfig(), | ||
) -> None: | ||
self.mesh = mesh | ||
self.in_specs = in_specs | ||
self.out_specs = out_specs | ||
super().__init__(comparison_config) | ||
|
||
def _compile_cpu( | ||
self, executable: Callable, static_argnames: Sequence[str] = None | ||
) -> Callable: | ||
"""Sets up `executable` for just-in-time compile - specifically for CPU.""" | ||
return jax.jit(executable, static_argnames=static_argnames) | ||
|
||
def _compile( | ||
self, executable: Callable, static_argnames: Sequence[str] = None | ||
) -> Callable: | ||
"""Sets up `executable` for just-in-time compile.""" | ||
module_sharded = shard_map( | ||
executable, mesh=self.mesh, in_specs=self.in_specs, out_specs=self.out_specs | ||
) | ||
output_sharding = NamedSharding(self.mesh, self.out_specs) | ||
return jax.jit( | ||
module_sharded, | ||
out_shardings=output_sharding, | ||
static_argnames=static_argnames, | ||
) | ||
|
||
def test(self, workload: Workload, cpu_workload: Workload) -> None: | ||
""" | ||
Runs test by running `workload` on TT device and CPU and comparing the results. | ||
""" | ||
compiled_executable = self._compile(workload.executable) | ||
cpu_compiled_executable = self._compile_cpu(cpu_workload.executable) | ||
|
||
cpu_compiled_workload = Workload( | ||
cpu_compiled_executable, cpu_workload.args, cpu_workload.kwargs | ||
) | ||
|
||
compiled_workload = Workload( | ||
compiled_executable, workload.args, workload.kwargs | ||
) | ||
|
||
non_sharded_workload = DeviceRunner.put_with_none_sharding( | ||
compiled_workload, self.mesh, in_specs=self.in_specs | ||
) | ||
|
||
tt_res = DeviceRunner.run_manual(non_sharded_workload) | ||
cpu_res = DeviceRunner.run_on_cpu(cpu_compiled_workload) | ||
|
||
self._compare(tt_res, cpu_res) | ||
|
||
def test_with_random_inputs( | ||
self, | ||
f: Callable, | ||
golden_f: Callable, | ||
input_shapes: Sequence[tuple], | ||
minval: float = 0.0, | ||
maxval: float = 1.0, | ||
) -> None: | ||
""" | ||
Tests `f` by running it with random inputs in range [`minval`, `maxval`) on | ||
TT device and CPU and comparing the results. | ||
""" | ||
inputs = [ | ||
jax.random.uniform( | ||
key=jax.random.key(0), shape=shape, minval=minval, maxval=maxval | ||
) | ||
for shape in input_shapes | ||
] | ||
workload = Workload(f, inputs) | ||
cpu_workload = Workload(golden_f, inputs) | ||
self.test(workload, cpu_workload) | ||
|
||
|
||
def run_multichip_test_with_random_inputs( | ||
mesh_test: Callable, | ||
golden_test: Callable, | ||
input_shapes: Sequence[tuple], | ||
mesh: jax.Mesh, | ||
in_specs: tuple, | ||
out_specs: jax.sharding.PartitionSpec, | ||
minval: float = 0.0, | ||
maxval: float = 1.0, | ||
comparison_config: ComparisonConfig = ComparisonConfig(), | ||
) -> None: | ||
""" | ||
Tests `mesh_test` with random inputs in range [`minval`, `maxval`) by running it on | ||
TT device and CPU and comparing the results based on `comparison_config`. | ||
""" | ||
tester = MultichipTester(mesh, in_specs, out_specs, comparison_config) | ||
tester.test_with_random_inputs(mesh_test, golden_test, input_shapes, minval, maxval) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
# SPDX-FileCopyrightText: (c) 2025 Tenstorrent AI ULC | ||
# | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
import jax | ||
import jax.numpy as jnp | ||
from jax import jit | ||
from jax.experimental.shard_map import shard_map | ||
from jax.sharding import PartitionSpec | ||
from functools import partial | ||
from infra import run_multichip_test_with_random_inputs | ||
import pytest | ||
|
||
|
||
@pytest.mark.parametrize("x_shape", [(8192, 784)]) | ||
@pytest.mark.skip(reason="Compilation fails") | ||
def all_gather_test(x_shape: tuple): | ||
def fwd(batch): | ||
act = jax.lax.all_gather(batch, "batch", axis=0, tiled=True) | ||
return act | ||
|
||
def golden_fwd(batch): | ||
return jnp.tile(batch, (2, 1)) | ||
|
||
devices = jax.devices("tt") | ||
mesh = jax.make_mesh((2,), ("batch"), devices=devices) | ||
|
||
in_specs = (PartitionSpec("batch"),) | ||
out_specs = PartitionSpec("batch") | ||
|
||
run_multichip_test_with_random_inputs( | ||
fwd, golden_fwd, [x_shape], mesh, in_specs, out_specs | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
# SPDX-FileCopyrightText: (c) 2025 Tenstorrent AI ULC | ||
# | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
import jax | ||
import jax.numpy as jnp | ||
from jax import jit | ||
from jax.experimental.shard_map import shard_map | ||
from jax.sharding import PartitionSpec | ||
from functools import partial | ||
from infra import run_multichip_test_with_random_inputs | ||
import pytest | ||
|
||
|
||
@pytest.mark.parametrize("x_shape", [(256, 256)]) | ||
@pytest.mark.skip(reason="Multichip still in development") | ||
def unary_eltwise_test(x_shape: tuple): | ||
def fwd(a_block): | ||
b_block = jnp.negative(a_block) | ||
stitched_result = jax.lax.psum(b_block, ("x", "y")) | ||
return stitched_result | ||
|
||
def fwd_single_device(a_block): | ||
a1, a2 = jnp.split(a_block, 2, axis=1) | ||
|
||
b1, b2 = jnp.negative(a1), jnp.negative(a2) | ||
|
||
stitched_result = b1 + b2 | ||
return stitched_result | ||
|
||
devices = jax.devices("tt") | ||
mesh = jax.make_mesh((1, 2), ("x", "y"), devices=devices) | ||
in_specs = (PartitionSpec("x", "y"),) | ||
out_specs = PartitionSpec(None, None) | ||
|
||
run_multichip_test_with_random_inputs( | ||
fwd, fwd_single_device, [x_shape], mesh, in_specs, out_specs | ||
) |