diff --git a/tests/jax/models/llama/openllama_3b_v2/test_openllama_3b_v2.py b/tests/jax/models/llama/openllama_3b_v2/test_openllama_3b_v2.py index 30d01081..a07d5b7d 100644 --- a/tests/jax/models/llama/openllama_3b_v2/test_openllama_3b_v2.py +++ b/tests/jax/models/llama/openllama_3b_v2/test_openllama_3b_v2.py @@ -26,7 +26,8 @@ def training_tester() -> LLamaTester: # ----- Tests ----- -@pytest.mark.xfail(reason="failed to legalize operation 'stablehlo.reduce'") +# @pytest.mark.xfail(reason="failed to legalize operation 'stablehlo.reduce'") +@pytest.mark.skip(reason="OOMs in CI") def test_openllama3b_inference( inference_tester: LLamaTester, ):