From d2dcb1e9325c466041424858743017f67df426a6 Mon Sep 17 00:00:00 2001
From: Andrej Jakovljevic
Date: Mon, 3 Feb 2025 11:32:30 +0000
Subject: [PATCH] Adding first test

---
 tests/infra/__init__.py                     |   1 +
 tests/infra/device_runner.py                |  53 +++++++++
 tests/infra/multichip_tester.py             | 116 ++++++++++++++++++++
 tests/jax/multichip/manual/all_gather.py    |  33 ++++++
 tests/jax/multichip/manual/unary_eltwise.py |  38 +++++++
 5 files changed, 241 insertions(+)
 create mode 100644 tests/infra/multichip_tester.py
 create mode 100644 tests/jax/multichip/manual/all_gather.py
 create mode 100644 tests/jax/multichip/manual/unary_eltwise.py

diff --git a/tests/infra/__init__.py b/tests/infra/__init__.py
index 8e7183b5..aec7b36f 100644
--- a/tests/infra/__init__.py
+++ b/tests/infra/__init__.py
@@ -8,3 +8,4 @@
 from .model_tester import ModelTester, RunMode
 from .op_tester import run_op_test, run_op_test_with_random_inputs
 from .utils import random_tensor, supported_dtypes
+from .multichip_tester import run_multichip_test_with_random_inputs
diff --git a/tests/infra/device_runner.py b/tests/infra/device_runner.py
index 08551893..47b06ba4 100644
--- a/tests/infra/device_runner.py
+++ b/tests/infra/device_runner.py
@@ -10,6 +10,8 @@
 from .types import Tensor
 from .workload import Workload
 
+from jax.sharding import Mesh, PartitionSpec, NamedSharding
+
 import inspect
 
 
@@ -18,6 +20,11 @@ class DeviceRunner:
     Class providing methods to put and run workload on any supported device.
     """
 
+    @staticmethod
+    def run_manual(workload: Workload) -> Tensor:
+        """Runs `workload` on TT device."""
+        return DeviceRunner._run_manual(workload)
+
     @staticmethod
     def run_on_tt_device(workload: Workload, device_num: int = 0) -> Tensor:
         """Runs `workload` on TT device."""
@@ -63,6 +70,43 @@ def put_tensors_on_gpu(*tensors: Tensor) -> Sequence[Tensor]:
         """Puts `tensors` on GPU."""
         raise NotImplementedError("Support for GPUs not implemented")
 
+    @staticmethod
+    def put_with_none_sharding(
+        workload: Workload,
+        mesh: jax.sharding.Mesh,
+        in_specs: Sequence[jax.sharding.PartitionSpec],
+    ) -> Workload:
+        """Places workload inputs on the mesh with replicated (all-None) shardings for multichip runs."""
+        args_on_device = []
+        spec_index = 0
+        for arg in workload.args:
+            if not isinstance(arg, Tensor):
+                args_on_device.append(arg)
+            else:
+                args_on_device.append(
+                    DeviceRunner._put_tensor_none_sharding(
+                        arg, mesh, in_specs[spec_index]
+                    )
+                )
+                spec_index += 1
+
+        kwargs_on_device = {}
+        for key, value in workload.kwargs.items():
+            if not isinstance(value, Tensor):
+                kwargs_on_device[key] = value
+            else:
+                kwargs_on_device[key] = DeviceRunner._put_tensor_none_sharding(
+                    value, mesh, in_specs[spec_index]
+                )
+                spec_index += 1
+
+        return Workload(workload.executable, args_on_device, kwargs_on_device)
+
+    @staticmethod
+    def _run_manual(workload: Workload) -> Tensor:
+        """Runs `workload` on a device."""
+        return workload.execute().block_until_ready()
+
     @staticmethod
     def _run_on_device(
         workload: Workload, device_type: DeviceType, device_num: int = 0
@@ -74,6 +118,15 @@ def _run_on_device(
         with jax.default_device(device):
             return device_workload.execute()
 
+    @staticmethod
+    def _put_tensor_none_sharding(
+        tensor: Tensor, mesh: jax.sharding.Mesh, in_spec: jax.sharding.PartitionSpec
+    ) -> Tensor:
+        """Needed for multichip: uses jax.device_put to give the input tensor a replicated (all-None) sharding."""
+        none_tuple = (None,) * len(in_spec)
+        none_spec = PartitionSpec(*none_tuple)
+        return jax.device_put(tensor, NamedSharding(mesh, none_spec), may_alias=True)
+
     @staticmethod
     def _put_on_device(
         workload: Workload, device_type: DeviceType, device_num: int = 0
diff --git a/tests/infra/multichip_tester.py b/tests/infra/multichip_tester.py
new file mode 100644
index 00000000..fe2775a2
--- /dev/null
+++ b/tests/infra/multichip_tester.py
@@ -0,0 +1,116 @@
+# SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
+#
+# SPDX-License-Identifier: Apache-2.0
+
+from __future__ import annotations
+
+import jax
+from jax.experimental.shard_map import shard_map
+from jax.sharding import NamedSharding
+from typing import Callable, Sequence
+
+from .base_tester import BaseTester
+from .comparison import ComparisonConfig
+from .device_runner import DeviceRunner
+from .workload import Workload
+
+
+class MultichipTester(BaseTester):
+    """Specific tester for multichip ops."""
+
+    def __init__(
+        self,
+        mesh: jax.Mesh,
+        in_specs: tuple,
+        out_specs: jax.sharding.PartitionSpec,
+        comparison_config: ComparisonConfig = ComparisonConfig(),
+    ) -> None:
+        self.mesh = mesh
+        self.in_specs = in_specs
+        self.out_specs = out_specs
+        super().__init__(comparison_config)
+
+    def _compile_cpu(
+        self, executable: Callable, static_argnames: Sequence[str] = None
+    ) -> Callable:
+        """Sets up `executable` for just-in-time compile - specifically for CPU."""
+        return jax.jit(executable, static_argnames=static_argnames)
+
+    def _compile(
+        self, executable: Callable, static_argnames: Sequence[str] = None
+    ) -> Callable:
+        """Sets up `executable` for just-in-time compile."""
+        module_sharded = shard_map(
+            executable, mesh=self.mesh, in_specs=self.in_specs, out_specs=self.out_specs
+        )
+        output_sharding = NamedSharding(self.mesh, self.out_specs)
+        return jax.jit(
+            module_sharded,
+            out_shardings=output_sharding,
+            static_argnames=static_argnames,
+        )
+
+    def test(self, workload: Workload, cpu_workload: Workload) -> None:
+        """
+        Runs test by running `workload` on TT device and CPU and comparing the results.
+        """
+        compiled_executable = self._compile(workload.executable)
+        cpu_compiled_executable = self._compile_cpu(cpu_workload.executable)
+
+        cpu_compiled_workload = Workload(
+            cpu_compiled_executable, cpu_workload.args, cpu_workload.kwargs
+        )
+
+        compiled_workload = Workload(
+            compiled_executable, workload.args, workload.kwargs
+        )
+
+        non_sharded_workload = DeviceRunner.put_with_none_sharding(
+            compiled_workload, self.mesh, in_specs=self.in_specs
+        )
+
+        tt_res = DeviceRunner.run_manual(non_sharded_workload)
+        cpu_res = DeviceRunner.run_on_cpu(cpu_compiled_workload)
+
+        self._compare(tt_res, cpu_res)
+
+    def test_with_random_inputs(
+        self,
+        f: Callable,
+        golden_f: Callable,
+        input_shapes: Sequence[tuple],
+        minval: float = 0.0,
+        maxval: float = 1.0,
+    ) -> None:
+        """
+        Tests `f` by running it with random inputs in range [`minval`, `maxval`) on
+        TT device and CPU and comparing the results.
+        """
+        inputs = [
+            jax.random.uniform(
+                key=jax.random.key(0), shape=shape, minval=minval, maxval=maxval
+            )
+            for shape in input_shapes
+        ]
+        workload = Workload(f, inputs)
+        cpu_workload = Workload(golden_f, inputs)
+        self.test(workload, cpu_workload)
+
+
+def run_multichip_test_with_random_inputs(
+    mesh_test: Callable,
+    golden_test: Callable,
+    input_shapes: Sequence[tuple],
+    mesh: jax.Mesh,
+    in_specs: tuple,
+    out_specs: jax.sharding.PartitionSpec,
+    minval: float = 0.0,
+    maxval: float = 1.0,
+    comparison_config: ComparisonConfig = ComparisonConfig(),
+) -> None:
+    """
+    Tests `mesh_test` with random inputs in range [`minval`, `maxval`) by running it on
+    TT device and CPU and comparing the results based on `comparison_config`.
+ """ + tester = MultichipTester(mesh, in_specs, out_specs, comparison_config) + tester.test_with_random_inputs(mesh_test, golden_test, input_shapes, minval, maxval) diff --git a/tests/jax/multichip/manual/all_gather.py b/tests/jax/multichip/manual/all_gather.py new file mode 100644 index 00000000..bb2ccc41 --- /dev/null +++ b/tests/jax/multichip/manual/all_gather.py @@ -0,0 +1,33 @@ +# SPDX-FileCopyrightText: (c) 2025 Tenstorrent AI ULC +# +# SPDX-License-Identifier: Apache-2.0 + +import jax +import jax.numpy as jnp +from jax import jit +from jax.experimental.shard_map import shard_map +from jax.sharding import PartitionSpec +from functools import partial +from infra import run_multichip_test_with_random_inputs +import pytest + + +@pytest.mark.parametrize("x_shape", [(8192, 784)]) +@pytest.mark.skip(reason="Compilation fails") +def all_gather_test(x_shape: tuple): + def fwd(batch): + act = jax.lax.all_gather(batch, "batch", axis=0, tiled=True) + return act + + def golden_fwd(batch): + return jnp.tile(batch, (2, 1)) + + devices = jax.devices("tt") + mesh = jax.make_mesh((2,), ("batch"), devices=devices) + + in_specs = (PartitionSpec("batch"),) + out_specs = PartitionSpec("batch") + + run_multichip_test_with_random_inputs( + fwd, golden_fwd, [x_shape], mesh, in_specs, out_specs + ) diff --git a/tests/jax/multichip/manual/unary_eltwise.py b/tests/jax/multichip/manual/unary_eltwise.py new file mode 100644 index 00000000..f1903737 --- /dev/null +++ b/tests/jax/multichip/manual/unary_eltwise.py @@ -0,0 +1,38 @@ +# SPDX-FileCopyrightText: (c) 2025 Tenstorrent AI ULC +# +# SPDX-License-Identifier: Apache-2.0 + +import jax +import jax.numpy as jnp +from jax import jit +from jax.experimental.shard_map import shard_map +from jax.sharding import PartitionSpec +from functools import partial +from infra import run_multichip_test_with_random_inputs +import pytest + + +@pytest.mark.parametrize("x_shape", [(256, 256)]) +@pytest.mark.skip(reason="Multichip still in development") +def unary_eltwise_test(x_shape: tuple): + def fwd(a_block): + b_block = jnp.negative(a_block) + stitched_result = jax.lax.psum(b_block, ("x", "y")) + return stitched_result + + def fwd_single_device(a_block): + a1, a2 = jnp.split(a_block, 2, axis=1) + + b1, b2 = jnp.negative(a1), jnp.negative(a2) + + stitched_result = b1 + b2 + return stitched_result + + devices = jax.devices("tt") + mesh = jax.make_mesh((1, 2), ("x", "y"), devices=devices) + in_specs = (PartitionSpec("x", "y"),) + out_specs = PartitionSpec(None, None) + + run_multichip_test_with_random_inputs( + fwd, fwd_single_device, [x_shape], mesh, in_specs, out_specs + )