Implement stop_at for mlxlm models #1318

Open · wants to merge 2 commits into main.
17 changes: 14 additions & 3 deletions outlines/models/mlxlm.py
@@ -93,13 +93,16 @@ def stream(
         if seed is not None:
             raise NotImplementedError("The `mlx-lm` library does not support seed.")
         if stop_at is not None:
-            raise NotImplementedError("The `mlx-lm` library does not support stop_at.")
+            stop_at = stop_at if isinstance(stop_at, list) else [stop_at]
+        else:
+            stop_at = []
 
         generate_kwargs = {
             "temp": temperature,
             "top_p": top_p,
             "sampler": sampler,
             "logits_processor": logits_processor,
+            "stop_at": stop_at,
         }
 
         # Adapted from
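Note: the normalization above lets callers pass either a single stop string or a list of them, matching how `stop_at` behaves elsewhere in Outlines. A minimal standalone sketch of the intended behavior (the `normalize_stop_at` helper name is ours, for illustration only):

    from typing import List, Optional, Union

    def normalize_stop_at(stop_at: Optional[Union[str, List[str]]]) -> List[str]:
        # None means no stop sequences; a bare string becomes a one-element list.
        if stop_at is None:
            return []
        return stop_at if isinstance(stop_at, list) else [stop_at]

    assert normalize_stop_at(None) == []
    assert normalize_stop_at("stop") == ["stop"]
    assert normalize_stop_at(["a", "b"]) == ["a", "b"]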
@@ -113,8 +116,6 @@ def stream(
             self.generate_step(prompt_tokens, **generate_kwargs),
             range(max_tokens),
         ):
-            if token == self.tokenizer.eos_token_id:
-                break
             detokenizer.add_token(token)
             yield detokenizer.last_segment
 
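Removing the EOS check from `stream` is deliberate: the same check reappears inside `generate_step` (below), so end-of-sequence and `stop_at` termination are handled in one place rather than split across the two methods.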
@@ -128,6 +129,7 @@ def generate_step(
         top_p: Optional[float],
         sampler: str,
         logits_processor: "OutlinesLogitsProcessor",
+        stop_at: List[str],
     ) -> Generator[Tuple[int, float], None, None]:
         """
         Adapted from
@@ -173,6 +175,10 @@ def sample(logits: "mx.array") -> Tuple["mx.array", float]:
         unprocessed_input_ids = prompt
         generated_ids: List[int] = []
 
+        def should_stop(token):
+            text = self.mlx_tokenizer.decode(generated_ids + [token])
+            return any(s in text for s in stop_at)
+
         while True:
             logits = self.model(unprocessed_input_ids[None], cache=cache)
             logits = logits[:, -1, :]
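`should_stop` decodes the entire generated prefix plus the candidate token instead of decoding the new token alone, so a stop string that straddles a token boundary is still detected. A toy illustration with a made-up two-character tokenizer (not the real `mlx_tokenizer`):

    from typing import List

    class ToyTokenizer:
        # Hypothetical vocabulary: token 0 = "st", token 1 = "op", token 2 = "go".
        vocab = ["st", "op", "go"]

        def decode(self, ids: List[int]) -> str:
            return "".join(self.vocab[i] for i in ids)

    tok = ToyTokenizer()
    generated_ids = [0]  # "st"
    candidate = 1        # "op"

    assert "stop" not in tok.decode([candidate])              # new token alone misses it
    assert "stop" in tok.decode(generated_ids + [candidate])  # full prefix catches it

One trade-off worth noting: re-decoding the full prefix on every step is quadratic in sequence length; decoding only a tail window roughly as long as the longest stop string would avoid that, at the cost of some care around token boundaries.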
@@ -187,6 +193,11 @@ def sample(logits: "mx.array") -> Tuple["mx.array", float]:
             new_token = new_token_single.item()
             yield new_token, prob
 
+            if new_token == self.tokenizer.eos_token_id or (
+                stop_at and should_stop(new_token)
+            ):
+                break
+
             generated_ids.append(new_token)
             unprocessed_input_ids = new_token_single
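Because the candidate token is yielded before the stop check runs, the token that completes a stop sequence is still emitted, so the stop string itself appears in the output (the new test below relies on this by asserting `"stop" in output`). Callers who want the text cut at the stop sequence can trim it themselves; a small sketch (the `trim_at_stop` helper is hypothetical, not part of this PR):

    from typing import List

    def trim_at_stop(text: str, stop_at: List[str]) -> str:
        # Cut at the earliest occurrence of any stop string, if present.
        cut = len(text)
        for s in stop_at:
            idx = text.find(s)
            if idx != -1:
                cut = min(cut, idx)
        return text[:cut]

    assert trim_at_stop("hello stop world", ["stop"]) == "hello "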
14 changes: 14 additions & 0 deletions tests/models/test_mlxlm.py
@@ -53,6 +53,20 @@ def test_mlxlm_generate():
     assert len(output) > 0
 
 
+@pytest.mark.skipif(not HAS_MLX, reason="MLX tests require Apple Silicon")
+def test_mlxlm_generate_with_stop_at():
+    from outlines.generate.api import GenerationParameters, SamplingParameters
+
+    model = mlxlm(TEST_MODEL)
+    prompt = 'Write a sentence and end with "stop":'
+
+    gen_params = GenerationParameters(max_tokens=50, stop_at="stop", seed=None)
+    sampling_params = SamplingParameters(sampler="greedy")
+
+    output = model.generate(prompt, gen_params, None, sampling_params)
+    assert "stop" in output
+
+
 @pytest.mark.skipif(not HAS_MLX, reason="MLX tests require Apple Silicon")
 def test_mlxlm_stream():
     from outlines.generate.api import GenerationParameters, SamplingParameters
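For context, the end-to-end usage this enables looks roughly like the following, assuming the standard Outlines text-generation API (the model id is a placeholder):

    import outlines

    model = outlines.models.mlxlm("mlx-community/<some-model>")  # placeholder id
    generator = outlines.generate.text(model)

    # stop_at accepts a single string or a list of strings.
    output = generator('Write a sentence and end with "stop":', stop_at="stop")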