Commit

Merge in awesome docs from sgl-project#3859 by @ShaoZhang0115 and add unittests.

Lucas Pickup committed Feb 28, 2025
1 parent 210fbdc commit e165bf7
Showing 8 changed files with 413 additions and 12 deletions.
138 changes: 138 additions & 0 deletions docs/backend/reasoning_parser.md
@@ -0,0 +1,138 @@
# Reasoning Parser

SGLang supports parsing reasoning content out from the "normal" content for reasoning models such as [DeepSeek R1](https://huggingface.co/deepseek-ai/DeepSeek-R1).

The contract follows the [DeepSeek API design](https://api-docs.deepseek.com/guides/reasoning_model) established with the release of DeepSeek-R1:

- `reasoning_content`: The content of the CoT.
- `content`: The content of the final answer.

## Supported Models

Currently, SGLang supports the following reasoning models:
- [DeepSeek R1 series](https://huggingface.co/collections/deepseek-ai/deepseek-r1-678e1e131c0169c0bc89728d): The reasoning content is wrapped with `<think>` and `</think>` tags.
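
For example, the raw completion from such a model typically has the following shape (an illustrative transcript, not actual model output):

```
<think>
The user asks for 1 + 3. Adding the two numbers gives 4.
</think>
The sum of 1 and 3 is 4.
```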

## Usage

There are two ways to enable reasoning parsing:

1) Enable the reasoning parser when starting the SGLang server by setting the `--enable-reasoning` and `--reasoning-parser` options. The `--reasoning-parser` option specifies which parser is used to extract the reasoning content and the final answer.

```bash
python -m sglang.launch_server --host 0.0.0.0 \
--model-path deepseek-ai/DeepSeek-R1-Distill-Qwen-14B \
--enable-reasoning --reasoning-parser deepseek-r1
```

2) Enable it on a per-request basis by setting the `separate_reasoning` field in the body of a `/chat/completions` request.

```bash
curl -X POST http://0.0.0.0:30000/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{"messages":[{"role":"user","content":"Compute 1+3"}],"max_tokens":100,"model":"deepseek-r1","stream":true,"separate_reasoning":true}'
```
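
Each streamed chunk then carries the reasoning trace in `delta.reasoning_content`; an illustrative (abridged) chunk might look like:

```
data: {"id":"chatcmpl-...","object":"chat.completion.chunk","choices":[{"index":0,"delta":{"reasoning_content":"First, I recognize that the problem requires adding 1 and 3."},"finish_reason":null}]}
```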

A second body parameter, `"stream_reasoning": false`, buffers the reasoning trace and sends it as a single chunk once the closing `</think>` tag is reached.
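
For example, a streaming request that returns the whole reasoning trace in a single chunk could look like the following sketch, which reuses the request above and only adds the `stream_reasoning` field:

```bash
curl -X POST http://0.0.0.0:30000/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{"messages":[{"role":"user","content":"Compute 1+3"}],"max_tokens":100,"model":"deepseek-r1","stream":true,"separate_reasoning":true,"stream_reasoning":false}'
```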

### Non-streaming Request

Make a request to the reasoning model and receive the reasoning content and the final answer as separate fields.

Using the OpenAI Python client:
```python
import openai

client = openai.Client(base_url="http://localhost:30000/v1", api_key="None")

response = client.chat.completions.create(
    model="deepseek-r1:14b",
    messages=[{"role": "user", "content": "Compute 1+3"}],
    max_tokens=1024,
    stream=False
)

response.choices[0].message.reasoning_content
# 'First, I recognize that the problem requires adding the numbers 1 and 3.\n\nNext, I identify the numbers to be added, which are 1 and 3.\n\nThen, I perform the addition operation: 1 plus 3 equals 4.\n\nFinally, I conclude that the sum of 1 and 3 is 4.\n'
response.choices[0].message.content
# '\n\nTo compute \\(1 + 3\\), follow these simple steps:\n\n1. **Identify the numbers to add:** \n The numbers are **1** and **3**.\n\n2. **Add the numbers together:** \n \\[\n 1 + 3 = 4\n \\]\n\n3. **Write the final answer:** \n The sum of \\(1 + 3\\) is \\(\\boxed{4}\\).'
```

### Streaming Request

`reasoning_content` is available in the `delta` field of the streaming response.

Using the OpenAI Python client:

```python
# ... Initialize the client as before ...

response = client.chat.completions.create(
    model="deepseek-r1:14b",
    messages=[{"role": "user", "content": "Compute 1+3"}],
    max_tokens=1024,
    stream=True
)
reasoning_content = ""
content = ""
for chunk in response:
    if chunk.choices[0].delta.content:
        content += chunk.choices[0].delta.content
    elif chunk.choices[0].delta.reasoning_content:
        reasoning_content += chunk.choices[0].delta.reasoning_content

reasoning_content
# 'I need to calculate the sum of 1 and 3. \n\nFirst, I identify the numbers involved in the addition: 1 and 3.\n\nNext, I add these two numbers together to find the total.\n\nFinally, the result of the addition is 4.\n'
content
# '\n\n**Solution:**\n\nWe need to compute the sum of 1 and 3.\n\n1. **Identify the numbers to add:**\n - Number 1\n - Number 3\n\n2. **Add the numbers together:**\n \\[\n 1 + 3 = 4\n \\]\n\n3. **Final Answer:**\n \\[\n \\boxed{4}\n \\]'
```


## Supporting New Reasoning Models

To support a new reasoning model, implement its reasoning parser as a subclass of `BaseReasoningParser` in `python/sglang/srt/reasoning_parser.py`.

```python
from typing import Optional, Tuple


class BaseReasoningParser:
    """Base class for reasoning parser."""

    def __init__(self):
        self._buffer = ""

    def detect_and_parse(self, text: str) -> Tuple[Optional[str], Optional[str]]:
        """Detect and parse the text, return reasoning_content and content."""
        raise NotImplementedError

    def parse_streaming_increment(
        self, new_text: str
    ) -> Tuple[Optional[str], Optional[str]]:
        """Parse the new text incrementally, return reasoning_content and content."""
        raise NotImplementedError
```
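
For illustration, here is a minimal sketch of a subclass for a model that wraps its reasoning in `<think>...</think>` tags. The class name and the buffering behaviour are assumptions for this example, not the parser that ships with SGLang, and handling of tags split across streaming chunks is omitted for brevity.

```python
class ThinkTagReasoningParser(BaseReasoningParser):
    """Sketch: split reasoning wrapped in <think>...</think> from the final answer."""

    def __init__(self):
        super().__init__()
        self._in_reasoning = True  # assume the model opens with a <think> block

    def detect_and_parse(self, text):
        reasoning, sep, content = text.partition("</think>")
        reasoning = reasoning.replace("<think>", "", 1)
        # Without a closing tag, treat the whole text as reasoning.
        return (reasoning, content) if sep else (reasoning, None)

    def parse_streaming_increment(self, new_text):
        if not self._in_reasoning:
            # The reasoning block is finished; everything else is normal content.
            return None, new_text
        self._buffer += new_text
        if "</think>" not in self._buffer:
            out, self._buffer = self._buffer.replace("<think>", "", 1), ""
            return out, None
        reasoning, _, rest = self._buffer.partition("</think>")
        self._in_reasoning = False
        self._buffer = ""
        return reasoning.replace("<think>", "", 1), rest
```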

Then register the parser for the new model in the `ReasoningParserDict` of `ReasoningParser` so it can be selected by name:

```python
from typing import Dict, Optional, Tuple, Type


class ReasoningParser:
    """Reasoning parser for different reasoning models."""

    # Specify the reasoning parser for each reasoning model here
    ReasoningParserDict: Dict[str, Type[BaseReasoningParser]] = {
        "deepseek-r1": DeepSeekR1ReasoningParser
    }

    def __init__(self, reasoning_parser: str):
        self.parser = self.ReasoningParserDict[reasoning_parser]()

    def parse_non_stream(self, full_text: str) -> Tuple[Optional[str], Optional[str]]:
        """
        Non-streaming parsing for reasoning models.
        Return: reasoning_content, content
        """
        return self.parser.detect_and_parse(full_text)

    def parse_stream_chunk(self, chunk_text: str):
        """
        Streaming parsing for reasoning models.
        Return: reasoning_content, content
        """
        return self.parser.parse_streaming_increment(chunk_text)
```
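
As a sketch of how a new entry could be wired up and exercised (the `"my-reasoning-model"` key and the `ThinkTagReasoningParser` class from the example above are illustrative, and the return shape follows the excerpt above):

```python
# Register the hypothetical parser under a new model key.
ReasoningParser.ReasoningParserDict["my-reasoning-model"] = ThinkTagReasoningParser

parser = ReasoningParser("my-reasoning-model")
reasoning, answer = parser.parse_non_stream(
    "<think>1 plus 3 equals 4.</think>The answer is 4."
)
# reasoning -> "1 plus 3 equals 4."
# answer    -> "The answer is 4."
```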
27 changes: 19 additions & 8 deletions python/sglang/srt/openai_api/adapter.py
@@ -1039,7 +1039,12 @@ def v1_chat_generate_request(


def v1_chat_generate_response(
request, ret, to_file=False, cache_report=False, tool_call_parser=None
request,
ret,
to_file=False,
cache_report=False,
tool_call_parser=None,
reasoning_parser=None,
):
choices = []

@@ -1099,11 +1104,13 @@ def v1_chat_generate_response(
tools = request.tools
model = request.model

if request.separate_reasoning and is_reasoning_model(model):
if reasoning_parser or (
request.separate_reasoning and is_reasoning_model(model)
):
try:
parser = ReasoningParser(model, True)
parse_result = parser.parse_non_stream(text)
ret_item["text"] = (
text = (
None
if parse_result.normal_text and len(parse_result.normal_text) == 0
else parse_result.normal_text
@@ -1146,7 +1153,7 @@ def v1_chat_generate_response(
"index": 0,
"message": {
"role": "assistant",
"content": ret_item["text"] if tool_calls is None else None,
"content": text if tool_calls is None else None,
"tool_calls": tool_calls,
},
"logprobs": choice_logprobs,
@@ -1164,7 +1171,7 @@
index=idx,
message=ChatMessage(
role="assistant",
content=ret_item["text"] if tool_calls is None else None,
content=text if tool_calls is None else None,
tool_calls=tool_calls,
reasoning_content=reasoning_text,
),
@@ -1307,8 +1314,9 @@ async def generate_stream_resp():
if is_first:
# First chunk with role
is_first = False
if request.separate_reasoning and is_reasoning_model(
request.model
if tokenizer_manager.server_args.reasoning_parser or (
request.separate_reasoning
and is_reasoning_model(request.model)
):
delta = DeltaMessage(role="assistant", reasoning_content="")
else:
@@ -1339,7 +1347,9 @@ async def generate_stream_resp():
delta = text[len(stream_buffer) :]
new_stream_buffer = stream_buffer + delta

if request.separate_reasoning and is_reasoning_model(request.model):
if tokenizer_manager.server_args.reasoning_parser or (
request.separate_reasoning and is_reasoning_model(request.model)
):
if index not in reasoning_parser_dict:
reasoning_parser_dict[index] = ReasoningParser(
request.model, request.stream_reasoning
@@ -1530,6 +1540,7 @@ async def generate_stream_resp():
ret,
cache_report=tokenizer_manager.server_args.enable_cache_report,
tool_call_parser=tokenizer_manager.server_args.tool_call_parser,
reasoning_parser=tokenizer_manager.server_args.reasoning_parser,
)

return response
4 changes: 2 additions & 2 deletions python/sglang/srt/openai_api/protocol.py
@@ -346,8 +346,8 @@ class ToolCall(BaseModel):
class ChatMessage(BaseModel):
role: Optional[str] = None
content: Optional[str] = None
tool_calls: Optional[List[ToolCall]] = Field(default=None, examples=[None])
reasoning_content: Optional[str] = None
tool_calls: Optional[List[ToolCall]] = Field(default=None, examples=[None])


class ChatCompletionResponseChoice(BaseModel):
@@ -370,8 +370,8 @@ class ChatCompletionResponse(BaseModel):
class DeltaMessage(BaseModel):
role: Optional[str] = None
content: Optional[str] = None
tool_calls: Optional[List[ToolCall]] = Field(default=None, examples=[None])
reasoning_content: Optional[str] = None
tool_calls: Optional[List[ToolCall]] = Field(default=None, examples=[None])


class ChatCompletionResponseStreamChoice(BaseModel):
7 changes: 5 additions & 2 deletions python/sglang/srt/reasoning_parser.py
@@ -1,12 +1,15 @@
import re
from typing import Dict, Optional
from typing import Dict

REASONING_MODELS = ["deepseek-r1"]


def is_reasoning_model(model_name: str) -> bool:
"""Checks if the model is a reasoning model."""
return model_name.lower() in REASONING_MODELS
for model in REASONING_MODELS:
if re.match(f".*{model}.*", model_name, re.IGNORECASE):
return True
return False


class StreamingParseResult:
9 changes: 9 additions & 0 deletions python/sglang/srt/server_args.py
@@ -23,6 +23,7 @@
import torch

from sglang.srt.hf_transformers_utils import check_gguf_file
from sglang.srt.reasoning_parser import REASONING_MODELS
from sglang.srt.utils import (
get_amdgpu_memory_capacity,
get_hpu_memory_capacity,
@@ -95,6 +96,7 @@ class ServerArgs:
api_key: Optional[str] = None
file_storage_pth: str = "sglang_storage"
enable_cache_report: bool = False
reasoning_parser: Optional[str] = None

# Data parallelism
dp_size: int = 1
@@ -606,6 +608,13 @@ def add_cli_args(parser: argparse.ArgumentParser):
action="store_true",
help="Return number of cached tokens in usage.prompt_tokens_details for each openai request.",
)
parser.add_argument(
"--reasoning-parser",
type=str,
choices=REASONING_MODELS,
default=ServerArgs.reasoning_parser,
help="Specify the parser for reasoning models, supported parsers are: {REASONING_MODELS}.",
)

# Data parallelism
parser.add_argument(
1 change: 1 addition & 0 deletions python/sglang/test/test_utils.py
@@ -34,6 +34,7 @@
DEFAULT_SMALL_EMBEDDING_MODEL_NAME_FOR_TEST = "Alibaba-NLP/gte-Qwen2-1.5B-instruct"
DEFAULT_MLA_MODEL_NAME_FOR_TEST = "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct"
DEFAULT_MLA_FP8_MODEL_NAME_FOR_TEST = "neuralmagic/DeepSeek-Coder-V2-Lite-Instruct-FP8"
DEFAULT_REASONING_MODEL_NAME_FOR_TEST = "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH = 1000
DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_TP1 = "meta-llama/Llama-3.1-8B-Instruct,mistralai/Mistral-7B-Instruct-v0.3,deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct,google/gemma-2-27b-it"
DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_TP2 = "meta-llama/Llama-3.1-70B-Instruct,mistralai/Mixtral-8x7B-Instruct-v0.1,Qwen/Qwen2-57B-A14B-Instruct"
1 change: 1 addition & 0 deletions test/srt/run_suite.py
@@ -32,6 +32,7 @@
"test_openai_server.py",
"test_pytorch_sampling_backend.py",
"test_radix_attention.py",
"test_reasoning_content.py",
"test_regex_constrained.py",
"test_release_memory_occupation.py",
"test_request_length_validation.py",