
Commit

Adding first test
ajakovljevicTT committed Feb 4, 2025
1 parent 83c3808 commit d2dcb1e
Showing 5 changed files with 241 additions and 0 deletions.
1 change: 1 addition & 0 deletions tests/infra/__init__.py
@@ -8,3 +8,4 @@
from .model_tester import ModelTester, RunMode
from .op_tester import run_op_test, run_op_test_with_random_inputs
from .utils import random_tensor, supported_dtypes
from .multichip_tester import run_multichip_test_with_random_inputs
53 changes: 53 additions & 0 deletions tests/infra/device_runner.py
@@ -10,6 +10,8 @@
from .types import Tensor
from .workload import Workload

from jax.sharding import Mesh, PartitionSpec, NamedSharding

import inspect


@@ -18,6 +20,11 @@ class DeviceRunner:
    Class providing methods to put and run workloads on any supported device.
    """

    @staticmethod
    def run_manual(workload: Workload) -> Tensor:
"""Runs `workload` on TT device."""
        return DeviceRunner._run_manual(workload)

    @staticmethod
    def run_on_tt_device(workload: Workload, device_num: int = 0) -> Tensor:
        """Runs `workload` on TT device."""
@@ -63,6 +70,43 @@ def put_tensors_on_gpu(*tensors: Tensor) -> Sequence[Tensor]:
"""Puts `tensors` on GPU."""
raise NotImplementedError("Support for GPUs not implemented")

    @staticmethod
    def put_with_none_sharding(
        workload: Workload,
        mesh: jax.sharding.Mesh,
        in_specs: Sequence[jax.sharding.PartitionSpec],
    ) -> Tensor:
        """Puts `workload` tensor inputs on `mesh` with a fully replicated (all-None) sharding, for multichip workloads."""
        args_on_device = []
        spec_index = 0
        for arg in workload.args:
            if not isinstance(arg, Tensor):
                args_on_device.append(arg)
            else:
                args_on_device.append(
                    DeviceRunner._put_tensor_none_sharding(
                        arg, mesh, in_specs[spec_index]
                    )
                )
                spec_index += 1

        kwargs_on_device = {}
        for key, value in workload.kwargs.items():
            if not isinstance(value, Tensor):
                kwargs_on_device[key] = value
            else:
                kwargs_on_device[key] = DeviceRunner._put_tensor_none_sharding(
                    value, mesh, in_specs[spec_index]
                )
                spec_index += 1

        return Workload(workload.executable, args_on_device, kwargs_on_device)

    @staticmethod
    def _run_manual(workload: Workload) -> Tensor:
        """Runs `workload` on a device."""
        return workload.execute().block_until_ready()

    @staticmethod
    def _run_on_device(
        workload: Workload, device_type: DeviceType, device_num: int = 0
@@ -74,6 +118,15 @@ def _run_on_device(
        with jax.default_device(device):
            return device_workload.execute()

    @staticmethod
    def _put_tensor_none_sharding(
        tensor: Tensor, mesh: jax.sharding.Mesh, in_spec: jax.sharding.PartitionSpec
    ) -> Tensor:
        """Needed for multichip: uses `jax.device_put` to place `tensor` on `mesh` with a fully replicated (all-None) sharding."""
        none_tuple = (None,) * len(in_spec)
        none_spec = PartitionSpec(*none_tuple)
        return jax.device_put(tensor, NamedSharding(mesh, none_spec), may_alias=True)

    @staticmethod
    def _put_on_device(
        workload: Workload, device_type: DeviceType, device_num: int = 0
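
For context, here is a minimal sketch of what the all-None placement in `_put_tensor_none_sharding` amounts to: the tensor ends up replicated across every device in the mesh rather than sharded. The device count, axis name, and shapes below are illustrative assumptions, not values from this commit.

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Assumption: at least two devices of any backend are visible to JAX.
mesh = Mesh(np.array(jax.devices()[:2]), ("x",))

x = jnp.ones((8, 4))
# A PartitionSpec consisting only of None entries replicates the tensor on every
# device of the mesh; the later shard_map decides how the computation is split.
replicated = jax.device_put(x, NamedSharding(mesh, PartitionSpec(None, None)))
print(replicated.sharding)
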
116 changes: 116 additions & 0 deletions tests/infra/multichip_tester.py
@@ -0,0 +1,116 @@
# SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
#
# SPDX-License-Identifier: Apache-2.0

from __future__ import annotations

import jax
from jax.experimental.shard_map import shard_map
from jax.sharding import NamedSharding
from typing import Callable, Sequence

from .base_tester import BaseTester
from .comparison import ComparisonConfig
from .device_runner import DeviceRunner
from .workload import Workload


class MultichipTester(BaseTester):
"""Specific tester for ops."""

def __init__(
self,
mesh: jax.Mesh,
in_specs: tuple,
out_specs: jax.sharding.PartitionSpec,
comparison_config: ComparisonConfig = ComparisonConfig(),
) -> None:
self.mesh = mesh
self.in_specs = in_specs
self.out_specs = out_specs
super().__init__(comparison_config)

def _compile_cpu(
self, executable: Callable, static_argnames: Sequence[str] = None
) -> Callable:
"""Sets up `executable` for just-in-time compile - specifically for CPU."""
return jax.jit(executable, static_argnames=static_argnames)

def _compile(
self, executable: Callable, static_argnames: Sequence[str] = None
) -> Callable:
"""Sets up `executable` for just-in-time compile."""
module_sharded = shard_map(
executable, mesh=self.mesh, in_specs=self.in_specs, out_specs=self.out_specs
)
output_sharding = NamedSharding(self.mesh, self.out_specs)
return jax.jit(
module_sharded,
out_shardings=output_sharding,
static_argnames=static_argnames,
)

def test(self, workload: Workload, cpu_workload: Workload) -> None:
"""
Runs test by running `workload` on TT device and CPU and comparing the results.
"""
compiled_executable = self._compile(workload.executable)
cpu_compiled_executable = self._compile_cpu(cpu_workload.executable)

cpu_compiled_workload = Workload(
cpu_compiled_executable, cpu_workload.args, cpu_workload.kwargs
)

compiled_workload = Workload(
compiled_executable, workload.args, workload.kwargs
)

non_sharded_workload = DeviceRunner.put_with_none_sharding(
compiled_workload, self.mesh, in_specs=self.in_specs
)

tt_res = DeviceRunner.run_manual(non_sharded_workload)
cpu_res = DeviceRunner.run_on_cpu(cpu_compiled_workload)

self._compare(tt_res, cpu_res)

def test_with_random_inputs(
self,
f: Callable,
golden_f: Callable,
input_shapes: Sequence[tuple],
minval: float = 0.0,
maxval: float = 1.0,
) -> None:
"""
Tests `f` by running it with random inputs in range [`minval`, `maxval`) on
TT device and CPU and comparing the results.
"""
inputs = [
jax.random.uniform(
key=jax.random.key(0), shape=shape, minval=minval, maxval=maxval
)
for shape in input_shapes
]
workload = Workload(f, inputs)
cpu_workload = Workload(golden_f, inputs)
self.test(workload, cpu_workload)


def run_multichip_test_with_random_inputs(
    mesh_test: Callable,
    golden_test: Callable,
    input_shapes: Sequence[tuple],
    mesh: jax.Mesh,
    in_specs: tuple,
    out_specs: jax.sharding.PartitionSpec,
    minval: float = 0.0,
    maxval: float = 1.0,
    comparison_config: ComparisonConfig = ComparisonConfig(),
) -> None:
    """
    Tests `mesh_test` against `golden_test` with random inputs in range
    [`minval`, `maxval`), running the former on a TT device mesh and the latter on
    CPU and comparing the results based on `comparison_config`.
    """
    tester = MultichipTester(mesh, in_specs, out_specs, comparison_config)
    tester.test_with_random_inputs(mesh_test, golden_test, input_shapes, minval, maxval)
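
For intuition, here is a minimal sketch of the `shard_map`-plus-`jit` composition that `_compile` performs, reduced to a toy op; the two-device mesh, axis name, shapes, and function names are illustrative assumptions rather than values from this commit.

import numpy as np
import jax
import jax.numpy as jnp
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Assumption: two devices of any backend are visible to JAX.
mesh = Mesh(np.array(jax.devices()[:2]), ("batch",))
in_specs = (PartitionSpec("batch"),)
out_specs = PartitionSpec("batch")

def double(block):
    # Each device sees only its own (2, 4) block of the (4, 4) input.
    return block * 2.0

sharded = shard_map(double, mesh=mesh, in_specs=in_specs, out_specs=out_specs)
compiled = jax.jit(sharded, out_shardings=NamedSharding(mesh, out_specs))
print(compiled(jnp.ones((4, 4))).shape)  # (4, 4), reassembled from per-device blocks
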
33 changes: 33 additions & 0 deletions tests/jax/multichip/manual/all_gather.py
@@ -0,0 +1,33 @@
# SPDX-FileCopyrightText: (c) 2025 Tenstorrent AI ULC
#
# SPDX-License-Identifier: Apache-2.0

import jax
import jax.numpy as jnp
from jax import jit
from jax.experimental.shard_map import shard_map
from jax.sharding import PartitionSpec
from functools import partial
from infra import run_multichip_test_with_random_inputs
import pytest


@pytest.mark.parametrize("x_shape", [(8192, 784)])
@pytest.mark.skip(reason="Compilation fails")
def all_gather_test(x_shape: tuple):
    def fwd(batch):
        act = jax.lax.all_gather(batch, "batch", axis=0, tiled=True)
        return act

    def golden_fwd(batch):
        return jnp.tile(batch, (2, 1))

    devices = jax.devices("tt")
    mesh = jax.make_mesh((2,), ("batch",), devices=devices)

    in_specs = (PartitionSpec("batch"),)
    out_specs = PartitionSpec("batch")

    run_multichip_test_with_random_inputs(
        fwd, golden_fwd, [x_shape], mesh, in_specs, out_specs
    )
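
As a sanity check on the golden function above, here is a CPU-only sketch of the shapes involved; the small array is an illustrative stand-in for the (8192, 784) input, not something run in this commit.

import jax.numpy as jnp

# With the input split into 2 "batch" shards, all_gather(..., tiled=True) rebuilds
# the full array on every device, and out_specs=PartitionSpec("batch") concatenates
# the per-device results again, which is exactly jnp.tile(x, (2, 1)).
x = jnp.arange(8.0).reshape(4, 2)                   # stand-in for the (8192, 784) input
shards = jnp.split(x, 2, axis=0)                    # what each of the 2 devices holds
gathered_on_each = jnp.concatenate(shards, axis=0)  # per-device all_gather result
assembled = jnp.concatenate([gathered_on_each] * 2, axis=0)
assert (assembled == jnp.tile(x, (2, 1))).all()
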
38 changes: 38 additions & 0 deletions tests/jax/multichip/manual/unary_eltwise.py
@@ -0,0 +1,38 @@
# SPDX-FileCopyrightText: (c) 2025 Tenstorrent AI ULC
#
# SPDX-License-Identifier: Apache-2.0

import jax
import jax.numpy as jnp
from jax import jit
from jax.experimental.shard_map import shard_map
from jax.sharding import PartitionSpec
from functools import partial
from infra import run_multichip_test_with_random_inputs
import pytest


@pytest.mark.parametrize("x_shape", [(256, 256)])
@pytest.mark.skip(reason="Multichip still in development")
def unary_eltwise_test(x_shape: tuple):
    def fwd(a_block):
        b_block = jnp.negative(a_block)
        stitched_result = jax.lax.psum(b_block, ("x", "y"))
        return stitched_result

    def fwd_single_device(a_block):
        a1, a2 = jnp.split(a_block, 2, axis=1)

        b1, b2 = jnp.negative(a1), jnp.negative(a2)

        stitched_result = b1 + b2
        return stitched_result

    devices = jax.devices("tt")
    mesh = jax.make_mesh((1, 2), ("x", "y"), devices=devices)
    in_specs = (PartitionSpec("x", "y"),)
    out_specs = PartitionSpec(None, None)

    run_multichip_test_with_random_inputs(
        fwd, fwd_single_device, [x_shape], mesh, in_specs, out_specs
    )
