Skip to content

Commit

Permalink
[JAX] Generate more readable error for failed device deserialization …
Browse files Browse the repository at this point in the history
…in colocated Python

When deserializing a colocated Python function or input/output sharding, we
often need to deserialize a device using a device id. This is done by looking
up a CPU device map; this lookup can fail if the device id was referring to a
non-CPU device. Unfortunately, we would see a simple error message like
`KeyError: np.int64(0)` that does not give a context of the problem.

This change adds a slightly more context to the exception so that the error is
more actionable.

PiperOrigin-RevId: 729172296
  • Loading branch information
hyeontaek authored and Google-ML-Automation committed Feb 20, 2025
1 parent b796847 commit 71f9764
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 7 deletions.
29 changes: 22 additions & 7 deletions jax/experimental/colocated_python/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from __future__ import annotations

import collections
import functools
import io
from typing import Any, Callable, Sequence

Expand Down Expand Up @@ -77,16 +78,29 @@ def _get_cpu_device_map() -> dict[int, jax.Device]:
return cpu_device_map


def _lookup_cpu_device(
cpu_device_map: dict[int, jax.Device], device_id: int
) -> jax.Device:
"""Returns a CPU device with the given device ID."""
d = cpu_device_map.get(device_id)
if d is None:
raise ValueError(
f"Invalid device ID {device_id}. Device list must contain only CPU"
" devices."
)
return d


def _reduce_mesh(
mesh: jax.sharding.Mesh,
) -> tuple[Callable[..., jax.sharding.Mesh], Any]:
def make_mesh(
mesh_device_ids: np.ndarray, axis_names: Any
) -> jax.sharding.Mesh:
cpu_device_map = _get_cpu_device_map()
mesh_devices = np.vectorize(lambda device_id: cpu_device_map[device_id])(
mesh_device_ids
)
mesh_devices = np.vectorize(
functools.partial(_lookup_cpu_device, cpu_device_map)
)(mesh_device_ids)
return jax.sharding.Mesh(mesh_devices, axis_names)

mesh_device_ids = np.vectorize(lambda d: d.id, otypes=[int])(mesh.devices)
Expand All @@ -98,9 +112,9 @@ def _reduce_device_list(
) -> tuple[Callable[..., DeviceList], Any]:
def make_device_list(device_ids: Sequence[int]) -> DeviceList:
cpu_device_map = _get_cpu_device_map()
devices = np.vectorize(lambda device_id: cpu_device_map[device_id])(
device_ids
)
devices = np.vectorize(
functools.partial(_lookup_cpu_device, cpu_device_map)
)(device_ids)
return DeviceList(tuple(devices))

device_ids = [d.id for d in device_list]
Expand All @@ -113,7 +127,8 @@ def _reduce_single_device_sharding(

def make_single_device_sharding(device_id: int):
cpu_device_map = _get_cpu_device_map()
return jax.sharding.SingleDeviceSharding(cpu_device_map[device_id])
device = _lookup_cpu_device(cpu_device_map, device_id)
return jax.sharding.SingleDeviceSharding(device)

return make_single_device_sharding, (sharding.device_set.pop().id,)

Expand Down
21 changes: 21 additions & 0 deletions tests/colocated_python_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -475,6 +475,27 @@ def f(x):
self.assertEqual(out_ints[0], 1002)
self.assertEqual(out_ints[1], 1003)

def testDetectInvalidMeshDevice(self):
cpu_devices = _colocated_cpu_devices(jax.local_devices())
if jax.local_devices()[0].id == cpu_devices[0].id:
self.skipTest(
"This test only works in a setup where accelerator and CPU devices"
" use different device IDs."
)

# mesh contains non-CPU devices. To be used in colocated Python, it should
# have contained CPU devices only.
mesh = jax.sharding.Mesh(jax.local_devices(), "x")
sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())

@colocated_python.colocated_python
def make_zero() -> jax.Array:
return jax.make_array_from_callback((), sharding, lambda _: np.array(0))

with self.assertRaisesRegex(ValueError, "Invalid device ID"):
make_zero = make_zero.specialize(devices=cpu_devices)
jax.block_until_ready(make_zero())


if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())

0 comments on commit 71f9764

Please sign in to comment.