
Reasoning parser #3859

Open · wants to merge 22 commits into main
Conversation

ShaoZhang0115

Motivation

Rewrite #3202

Modifications

  1. Add --enable-reasoning and --reasoning-parser options for DeepSeek R1 series models.
  2. Return reasoning_content as in the official API (ref: https://api-docs.deepseek.com/zh-cn/guides/reasoning_model) in both streaming and non-streaming chat completions.
    Example:
python -m sglang.launch_server --host 0.0.0.0 \
--model-path deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B \
--tp 1 --enable-reasoning --reasoning-parser deepseek-r1 
curl --location --request POST 'http://localhost:30000/v1/chat/completions' \
--header 'Content-Type: application/json' \
--header 'Authorization: Bearer YOUR_API_KEY' \
--data '{
    "model": "default",
    "messages": [
        {
            "role": "user",
            "content": "Calculate 1 + 3"
        }
    ],
    "stream": false
}'

Get response:

{
    "id": "53de20f7f1244195826e7b52011c37a4",
    "object": "chat.completion",
    "created": 1740507802,
    "model": "default",
    "choices": [
        {
            "index": 0,
            "message": {
                "role": "assistant",
                "content": "\n\n**Solution:**\n\nTo calculate \\(1 + 3\\), follow these easy steps:\n\n1. **Identify the numbers to add:**  \n   You have the number **1** and the number **3**.\n\n2. **Add the numbers together:**  \n   \\[\n   1 + 3 = 4\n   \\]\n\n3. **Final Answer:**  \n   \\[\n   \\boxed{4}\n   \\]",
                "reasoning_content": "To calculate the sum of 1 and 3, I will begin by identifying the two numbers involved in the addition. The first number is 1, and the second number is 3.\n\nNext, I will add these two numbers together. Adding 1 and 3 gives me a total of 4.\n\nTherefore, the result of 1 plus 3 is 4.\n",
                "tool_calls": null
            },
            "logprobs": null,
            "finish_reason": "stop",
            "matched_stop": 151643
        }
    ],
    "usage": {
        "prompt_tokens": 11,
        "total_tokens": 179,
        "completion_tokens": 168,
        "prompt_tokens_details": null
    }
}
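For reference, a minimal sketch of pulling both fields out of a response shaped like the one above (the payload here is a stub mirroring the example; the field names follow the official API):

```python
import json

# Stub of a chat-completion response carrying both "content" and
# "reasoning_content", shaped like the example above.
raw = json.dumps({
    "choices": [
        {
            "message": {
                "role": "assistant",
                "content": "The answer is 4.",
                "reasoning_content": "1 + 3 = 4.",
            }
        }
    ]
})

response = json.loads(raw)
message = response["choices"][0]["message"]
answer = message["content"]
# reasoning_content may be absent when the parser is disabled, so use .get().
reasoning = message.get("reasoning_content")
```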

Docs will be updated as soon as possible.

Checklist

Comment on lines 32 to 33
self.think_start_token = "<think>"
self.think_end_token = "</think>"
Collaborator

Can we extend this to all reasoning models? Not just dpsk R1. There might be different thinking tokens.


I think different reasoning models need different parsers, and I added docs for it.

@xihuai18

  • Add Docs
  • Test with streaming and non-streaming cases, with truncated or non-truncated max-tokens for reasoning.
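To illustrate the truncation case in the second bullet: when max-tokens cuts generation off inside the think block, the closing tag never arrives, so everything seen so far should count as reasoning. This is a standalone sketch of that behavior, not sglang's actual parser:

```python
# Sketch of the truncated vs. non-truncated cases. Token strings follow the
# R1 defaults; the function shape is illustrative, not sglang's API.
def split_reasoning(text, start="<think>", end="</think>"):
    if text.startswith(start):
        text = text[len(start):]
    if end in text:
        reasoning, _, content = text.partition(end)
        return reasoning, content
    # Truncated by max-tokens: no end tag yet, so it is all reasoning.
    return text, ""

# Non-truncated: reasoning and content both present.
r1, c1 = split_reasoning("<think>add 1 and 3</think>4")
# Truncated mid-reasoning: no closing tag.
r2, c2 = split_reasoning("<think>add 1 and")
```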

@xihuai18

However, I cannot pass my tests with --enable-torch-compile, which is confusing.

@xihuai18

> However, I cannot pass my tests with --enable-torch-compile, which is confusing.

possible related issue: #3730 (comment)

self.think_start_token = think_start_token
self.think_end_token = think_end_token
self.pattern = re.compile(
    rf"{self.think_start_token}(.*?){self.think_end_token}", re.DOTALL
)
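For what it's worth, here is how that DOTALL pattern behaves on a typical R1 output (a self-contained rerun of the quoted lines, with the tag strings inlined):

```python
import re

think_start_token = "<think>"
think_end_token = "</think>"
# re.DOTALL lets the non-greedy (.*?) span newlines inside the think block.
pattern = re.compile(rf"{think_start_token}(.*?){think_end_token}", re.DOTALL)

text = "<think>step 1\nstep 2</think>final answer"
m = pattern.search(text)
reasoning = m.group(1)   # everything between the tags
content = text[m.end():] # everything after the closing tag
```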

The most recent tokenizer hardcodes the opening <think> tag: https://huggingface.co/deepseek-ai/DeepSeek-R1/commit/8a58a132790c9935686eb97f042afa8013451c9f

This means the text coming back from inference won't include <think>. This is why I updated #3202 to assume the model is reasoning until </think> is seen; it also strips out <think> to handle the old chat template.

Contributor

@tot0

The PR added the start token if it is missing:

            # Add the start token to the beginning of the text.
            text = self.think_start_token + text

You can see it in detect_and_parse

```bash
python -m sglang.launch_server --host 0.0.0.0 \
--model-path deepseek-ai/DeepSeek-R1-Distill-Qwen-14B \
--enable-reasoning --reasoning-parser deepseek-r1
```

Appreciate the docs I was too lazy to add!

Would you consider also supporting the separate_reasoning contract? For my use case we want inference users to be able to control whether reasoning_content is separated, rather than set it as default behavior on sglang launch, which I understand some sglang users will want to do.


You mean adding a separate_reasoning parameter to requests?


Separating reasoning and non-reasoning outputs is super useful, and would love for that to be a toggle rather than always on or always off.
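For concreteness, a per-request toggle could look like the payload below; the separate_reasoning field is the proposed contract under discussion, not an existing sglang request parameter:

```json
{
    "model": "default",
    "messages": [{"role": "user", "content": "Calculate 1 + 3"}],
    "stream": false,
    "separate_reasoning": true
}
```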


Happy to merge the great changes from this PR into #3202 to try to get the best of both worlds?
Or vice versa, @ShaoZhang0115?


Updated #3202 to combine functionality from this PR, and added some unit tests.

@maximegmd

How does that work with grammars? Does the grammar kick in only after the reasoning parser?

@tot0

tot0 commented Feb 27, 2025

> How does that work with grammars? Does the grammar kick in only after the reasoning parser?

Have a similar question about this as well, though I don't think it's specific to this PR or #3202: the reasoning parsers (and tool parsers) operate on the text coming out of the underlying engine into the API layer.
As far as I can tell (after taking a look at #3298), the grammar engine choice is passed down to the underlying engine via sampling_params, and enforcement is done there. This suggests that not enforcing grammar constraints until a reasoning model is done reasoning would involve exposing the ReasoningParser's knowledge of the "end reasoning" token (</think> for R1) to the underlying engine.

cc @JC1DA and @mmoskal

@maximegmd

> How does that work with grammars? Does the grammar kick in only after the reasoning parser?

> Have a similar question about this as well, though I don't think it's specific to this PR or #3202: the reasoning parsers (and tool parsers) operate on the text coming out of the underlying engine into the API layer. As far as I can tell (after taking a look at #3298), the grammar engine choice is passed down to the underlying engine via sampling_params, and enforcement is done there. This suggests that not enforcing grammar constraints until a reasoning model is done reasoning would involve exposing the ReasoningParser's knowledge of the "end reasoning" token (</think> for R1) to the underlying engine.

Ideally we would be able to pass a grammar for reasoning and a grammar for content, but I believe the default grammar behavior should apply only to the content.
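A conceptual sketch of the gating both comments point at: hold off grammar masking until the reasoning-end token has been emitted. Everything here (the matcher interface, the token id) is invented for illustration; a real integration would live in the engine's sampling loop.

```python
# Hypothetical token id for "</think>"; the real id depends on the tokenizer.
THINK_END_TOKEN_ID = 151649

def constrain_logits(token_ids, logits, matcher):
    """Apply the grammar mask only once reasoning has finished."""
    if THINK_END_TOKEN_ID not in token_ids:
        return logits  # still reasoning: leave sampling unconstrained
    return matcher.mask(logits)  # content phase: enforce the grammar

class AllowOnly:
    # Toy stand-in for a grammar matcher: permits a fixed set of token ids.
    def __init__(self, allowed):
        self.allowed = set(allowed)

    def mask(self, logits):
        return [l if i in self.allowed else float("-inf")
                for i, l in enumerate(logits)]

matcher = AllowOnly({0, 2})
# No end-of-reasoning token seen yet: logits pass through untouched.
free = constrain_logits([5, 7], [0.1, 0.2, 0.3], matcher)
# End-of-reasoning token seen: disallowed tokens are masked out.
masked = constrain_logits([5, THINK_END_TOKEN_ID], [0.1, 0.2, 0.3], matcher)
```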

for chunk in response:
    if chunk.choices[0].delta.content:
        content += chunk.choices[0].delta.content
    elif chunk.choices[0].delta.reasoning_content:
Contributor

Is this functioning correctly now? When I tested this feature with vLLM, it triggered an error from the OpenAI Python client.

Please note that it is not compatible with the OpenAI Python client library. You can use the requests library to make streaming requests.
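Following that suggestion, the SSE stream can be consumed without the OpenAI client. This sketch stubs the wire lines for clarity; in practice they would come from requests.post(..., stream=True).iter_lines() against /v1/chat/completions, with the chunk shape assumed from the streaming deltas discussed in this thread:

```python
import json

def accumulate(sse_lines):
    """Collect content and reasoning_content from SSE-style data lines."""
    content, reasoning = "", ""
    for line in sse_lines:
        if not line.startswith("data: ") or line == "data: [DONE]":
            continue
        delta = json.loads(line[len("data: "):])["choices"][0]["delta"]
        # Either field may be missing or null in any given chunk.
        content += delta.get("content") or ""
        reasoning += delta.get("reasoning_content") or ""
    return content, reasoning

# Stubbed stream: two reasoning chunks, then one content chunk.
sse_lines = [
    'data: {"choices": [{"delta": {"reasoning_content": "1 + 3 "}}]}',
    'data: {"choices": [{"delta": {"reasoning_content": "is 4."}}]}',
    'data: {"choices": [{"delta": {"content": "4"}}]}',
    "data: [DONE]",
]
content, reasoning = accumulate(sse_lines)
```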

tot0 pushed a commit to tot0/sglang that referenced this pull request Feb 28, 2025