From 629426f89cd03b6662b9f118a6055d094e42b874 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Fri, 21 Feb 2025 14:52:53 -0800 Subject: [PATCH] Allow casting to the same axis type PiperOrigin-RevId: 729667271 --- jax/_src/pjit.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index 6f8be4f1dee1..cec36739b29c 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -2834,9 +2834,6 @@ def _get_new_mesh(axes: str | tuple[str, ...] | None, if not isinstance(axes, tuple): axes = (axes,) for a in axes: - if cur_mesh._name_to_type[a] == axis_type: - raise ValueError(f'Axes {a} cannot be casted to type {axis_type} since ' - f'it already is of type {axis_type}.') if (error_on_manual_to_auto_explict and cur_mesh._name_to_type[a] == mesh_lib.AxisTypes.Manual and axis_type in {mesh_lib.AxisTypes.Auto, mesh_lib.AxisTypes.Explicit}):