Add optional JAX support to the logits-processor base class, its benchmark and tests.
NOTE(review): hunks were re-flowed from a whitespace-flattened patch; the post-image
blob ids on the `index` lines predate the fixes below and are informational only.

diff --git a/benchmarks/bench_processors.py b/benchmarks/bench_processors.py
index 5b4901540..db1e4a8f1 100644
--- a/benchmarks/bench_processors.py
+++ b/benchmarks/bench_processors.py
@@ -9,6 +9,12 @@ except ImportError:
     pass
 
+try:
+    import jax
+    import jax.numpy as jnp
+except ImportError:
+    pass
+
 
 
 def is_mlx_lm_allowed():
     try:
@@ -18,6 +24,14 @@ def is_mlx_lm_allowed():
     return mx.metal.is_available()
 
 
+def is_jax_allowed():
+    try:
+        import jax  # noqa: F401
+    except ImportError:
+        return False
+    return True
+
+
 def get_mock_processor_inputs(array_library, num_tokens=30000):
     """
     logits: (4, 30,000 ) dtype=float
@@ -43,6 +57,13 @@ def get_mock_processor_inputs(array_library, num_tokens=30000):
         input_ids = mx.random.randint(
             low=0, high=num_tokens, shape=(4, 2048), dtype=mx.int32
         )
+    elif array_library == "jax":
+        logits = jax.random.uniform(
+            key=jax.random.PRNGKey(0), shape=(4, num_tokens), dtype=jnp.float32
+        )
+        input_ids = jax.random.randint(
+            key=jax.random.PRNGKey(1), minval=0, maxval=num_tokens, shape=(4, 2048)
+        )
     else:
         raise ValueError
 
@@ -67,6 +88,8 @@ class LogitsProcessorPassthroughBenchmark:
         params += ["mlx"]
     if torch.cuda.is_available():
         params += ["torch_cuda"]
+    if is_jax_allowed():
+        params += ["jax"]
 
     def setup(self, array_library):
         self.logits_processor = HalvingLogitsProcessor()
diff --git a/outlines/processors/base_logits_processor.py b/outlines/processors/base_logits_processor.py
index 9a52abecd..eec7de121 100644
--- a/outlines/processors/base_logits_processor.py
+++ b/outlines/processors/base_logits_processor.py
@@ -20,6 +20,16 @@ def is_mlx_array_type(array_type):
     return issubclass(array_type, mx.array)
 
 
+def is_jax_array_type(array_type):
+    try:
+        import jax
+    except ImportError:
+        return False
+    # jax.Array is the public array base class; jaxlib.xla_extension is a
+    # private module that `import jaxlib` alone is not guaranteed to load.
+    return issubclass(array_type, jax.Array)
+
+
 class OutlinesLogitsProcessor(Protocol):
     """
     Base class for logits processors which normalizes types of logits:
@@ -101,6 +111,12 @@ def _to_torch(tensor_like: Array) -> torch.Tensor:
             # https://ml-explore.github.io/mlx/build/html/usage/numpy.html
             return torch.from_dlpack(tensor_like)
 
+        elif is_jax_array_type(type(tensor_like)):
+            # jax arrays implement the DLPack protocol (__dlpack__), so torch
+            # can consume them directly, mirroring the mlx branch above; no
+            # intermediate capsule from jax.dlpack.to_dlpack is needed.
+            return torch.from_dlpack(tensor_like)
+
         else:
             raise TypeError(
                 "LogitsProcessor must be called with either np.NDArray, "
@@ -129,6 +145,11 @@ def _from_torch(tensor: torch.Tensor, target_type: Type) -> Array:
             # numpy doesn't support bfloat16, mlx doesn't support direct conversion from torch
             return mx.array(tensor.float().numpy())
 
+        elif is_jax_array_type(target_type):
+            import jax
+
+            return jax.dlpack.from_dlpack(tensor)
+
         else:
             raise TypeError(
                 f"Failed to convert torch tensors to target_type `{target_type}`"
diff --git a/pyproject.toml b/pyproject.toml
index 4972f09ef..294fbe4b8 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -109,6 +109,9 @@ enable_incomplete_feature = ["Unpack"]
 [[tool.mypy.overrides]]
 module = [
     "exllamav2.*",
+    "jax",
+    "jaxlib",
+    "jax.numpy",
     "jinja2",
     "jsonschema.*",
     "openai.*",
diff --git a/tests/processors/test_base_processor.py b/tests/processors/test_base_processor.py
new file mode 100644
index 000000000..cd9f48278
--- /dev/null
+++ b/tests/processors/test_base_processor.py
@@ -0,0 +1,74 @@
+from typing import List
+
+import numpy as np
+import pytest
+import torch
+
+from outlines.processors.base_logits_processor import OutlinesLogitsProcessor
+
+arrays = {
+    "list": [[1.0, 2.0], [3.0, 4.0]],
+    "np": np.array([[1, 2], [3, 4]], dtype=np.float32),
+    "torch": torch.tensor([[1, 2], [3, 4]], dtype=torch.float32),
+}
+
+# mlx and jax are optional dependencies: register their arrays only when the
+# import succeeds so this module can still be collected without them.
+try:
+    import mlx.core as mx
+
+    arrays["mlx"] = mx.array([[1, 2], [3, 4]], dtype=mx.float32)
+except ImportError:
+    pass
+
+try:
+    import jax.numpy as jnp
+
+    arrays["jax"] = jnp.array([[1, 2], [3, 4]], dtype=jnp.float32)
+except ImportError:
+    pass
+
+
+# Mock implementation of the abstract class for testing
+class MockLogitsProcessor(OutlinesLogitsProcessor):
+    def process_logits(
+        self, input_ids: List[List[int]], logits: torch.Tensor
+    ) -> torch.Tensor:
+        # For testing purposes, let's just return logits multiplied by 2
+        return logits * 2
+
+
+@pytest.fixture
+def processor():
+    """Fixture for creating an instance of the MockLogitsProcessor."""
+    return MockLogitsProcessor()
+
+
+@pytest.mark.parametrize("array_type", arrays.keys())
+def test_to_torch(array_type, processor):
+    data = arrays[array_type]
+    torch_tensor = processor._to_torch(data)
+    assert isinstance(torch_tensor, torch.Tensor)
+    assert torch.allclose(
+        torch_tensor.cpu(), torch.tensor([[1, 2], [3, 4]], dtype=torch.float32)
+    )
+
+
+@pytest.mark.parametrize("array_type", arrays.keys())
+def test_from_torch(array_type, processor):
+    torch_tensor = torch.tensor([[1, 2], [3, 4]], dtype=torch.float32)
+    data = processor._from_torch(torch_tensor, type(arrays[array_type]))
+    assert isinstance(data, type(arrays[array_type]))
+    assert np.allclose(data, arrays[array_type])
+
+
+@pytest.mark.parametrize("array_type", arrays.keys())
+def test_call(array_type, processor):
+    input_ids = arrays[array_type]
+    logits = arrays[array_type]
+    processed_logits = processor(input_ids, logits)
+
+    assert isinstance(processed_logits, type(arrays[array_type]))
+    assert np.allclose(
+        np.array(processed_logits), np.array([[2.0, 4.0], [6.0, 8.0]], dtype=np.float32)
+    )