Skip to content

Commit

Permalink
Import fix
Browse files Browse the repository at this point in the history
  • Loading branch information
ajakovljevicTT committed Feb 6, 2025
1 parent 1fc0db9 commit 148c3ae
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 6 deletions.
2 changes: 1 addition & 1 deletion tests/infra/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,6 @@
from .comparison import ComparisonConfig
from .graph_tester import run_graph_test, run_graph_test_with_random_inputs
from .model_tester import ModelTester, RunMode
from .multichip_tester import run_multichip_test_with_random_inputs
from .op_tester import run_op_test, run_op_test_with_random_inputs
from .utils import random_tensor, supported_dtypes, make_partition_spec
from .multichip_tester import run_multichip_test_with_random_inputs
8 changes: 4 additions & 4 deletions tests/jax/multichip/manual/all_gather.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,14 @@

import jax
import jax.numpy as jnp
from jax.sharding import PartitionSpec
from infra import run_multichip_test_with_random_inputs
from infra import run_multichip_test_with_random_inputs, make_partition_spec
import pytest
from tests.utils import compile_failed, make_partition_spec
from utils import compile_fail
from tests.utils import make_partition_spec


@pytest.mark.parametrize("x_shape", [(8192, 784)])
@pytest.mark.skip(reason=compile_failed("Multichip still in development"))
@pytest.mark.skip(reason=compile_fail("Multichip still in development"))
def test_all_gather(x_shape: tuple):
def fwd(batch):
act = jax.lax.all_gather(batch, "batch", axis=0, tiled=True)
Expand Down
2 changes: 1 addition & 1 deletion tests/jax/multichip/manual/unary_eltwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
import jax
import jax.numpy as jnp
from infra import run_multichip_test_with_random_inputs, make_partition_spec
from utils import compile_fail
import pytest
from utils import compile_fail
from tests.utils import make_partition_spec


Expand Down

0 comments on commit 148c3ae

Please sign in to comment.