Skip to content

Commit

Permalink
Allow casting to the same axis type
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 729667271
  • Loading branch information
yashk2810 authored and Google-ML-Automation committed Feb 21, 2025
1 parent 6e83de5 commit 629426f
Showing 1 changed file with 0 additions and 3 deletions.
3 changes: 0 additions & 3 deletions jax/_src/pjit.py
Original file line number Diff line number Diff line change
Expand Up @@ -2834,9 +2834,6 @@ def _get_new_mesh(axes: str | tuple[str, ...] | None,
if not isinstance(axes, tuple):
axes = (axes,)
for a in axes:
if cur_mesh._name_to_type[a] == axis_type:
raise ValueError(f'Axes {a} cannot be casted to type {axis_type} since '
f'it already is of type {axis_type}.')
if (error_on_manual_to_auto_explict and
cur_mesh._name_to_type[a] == mesh_lib.AxisTypes.Manual and
axis_type in {mesh_lib.AxisTypes.Auto, mesh_lib.AxisTypes.Explicit}):
Expand Down

0 comments on commit 629426f

Please sign in to comment.