diff --git a/docs/backend/reasoning_parser.md b/docs/backend/reasoning_parser.md new file mode 100644 index 00000000000..607acd6c012 --- /dev/null +++ b/docs/backend/reasoning_parser.md @@ -0,0 +1,127 @@ +# Reasoning Parser + +SGLang supports parsing the reasoning content from reasoning models like [DeepSeek R1](https://huggingface.co/deepseek-ai/DeepSeek-R1) for convenient output processing in the downstream applications. + +Following the official [DeepSeek API design](https://api-docs.deepseek.com/guides/reasoning_model), SGLang offers reasoning content and final conclusions: + +- `reasoning_content`: The content of the CoT. +- `content`: The content of the final answer. + +## Supported Models + +Currently, SGLang supports the following reasoning models: + +- [DeepSeek R1 series](https://huggingface.co/collections/deepseek-ai/deepseek-r1-678e1e131c0169c0bc89728d): The reasoning content is wrapped with `<think>` and `</think>` tags. + +## Usage + +You need to enable the reasoning parser in the SGLang API server by setting the `--enable-reasoning` and `--reasoning-parser` options. The `--reasoning-parser` option specifies the reasoning parser to extract the reasoning content and final answer. + +```bash +python -m sglang.launch_server --host 0.0.0.0 \ +--model-path deepseek-ai/DeepSeek-R1-Distill-Qwen-14B \ +--enable-reasoning --reasoning-parser deepseek-r1 +``` + +### Non-streaming Request + +Make a request to the reasoning model, get the reasoning content and final answer. 
+ +Using OpenAI python api: +```python +import openai + +client = openai.Client(base_url="http://localhost:30000/v1", api_key="None") + +response = client.chat.completions.create( + model="deepseek-r1:14b", + messages=[{"role": "user", "content": "Compute 1+3"}], + max_tokens=1024, + stream=False +) + +response.choices[0].message.reasoning_content +# 'First, I recognize that the problem requires adding the numbers 1 and 3.\n\nNext, I identify the numbers to be added, which are 1 and 3.\n\nThen, I perform the addition operation: 1 plus 3 equals 4.\n\nFinally, I conclude that the sum of 1 and 3 is 4.\n' +response.choices[0].message.content +# \n\nTo compute \\(1 + 3\\), follow these simple steps:\n\n1. **Identify the numbers to add:** \n The numbers are **1** and **3**.\n\n2. **Add the numbers together:** \n \\[\n 1 + 3 = 4\n \\]\n\n3. **Write the final answer:** \n The sum of \\(1 + 3\\) is \\(\\boxed{4}\\).' +``` + +### Streaming Request + +`reasoning_content` is available in the `delta` field of the streaming response. + +Using OpenAI python api: + +```python +# ... Initialize the client as before ... + +response = client.chat.completions.create( + model="deepseek-r1:14b", + messages=[{"role": "user", "content": "Compute 1+3"}], + max_tokens=1024, + stream=True +) +reasoning_content = "" +content = "" +for chunk in response: + if chunk.choices[0].delta.content: + content += chunk.choices[0].delta.content + elif chunk.choices[0].delta.reasoning_content: + reasoning_content += chunk.choices[0].delta.reasoning_content + +reasoning_content +# 'I need to calculate the sum of 1 and 3. \n\nFirst, I identify the numbers involved in the addition: 1 and 3.\n\nNext, I add these two numbers together to find the total.\n\nFinally, the result of the addition is 4.\n' +content +# '\n\n**Solution:**\n\nWe need to compute the sum of 1 and 3.\n\n1. **Identify the numbers to add:**\n - Number 1\n - Number 3\n\n2. **Add the numbers together:**\n \\[\n 1 + 3 = 4\n \\]\n\n3. 
**Final Answer:**\n \\[\n \\boxed{4}\n \\]' +``` + + +## Supporting More Reasoning Models + +For future reasoning models, you can implement the reasoning parser as a subclass of `BaseReasoningParser` in `python/sglang/srt/reasoning_parser.py`. + +```python +class BaseReasoningParser: + """Base class for reasoning parser.""" + + def __init__(self): + self._buffer = "" + + def detect_and_parse(self, text: str) -> Tuple[Optional[str], Optional[str]]: + """Detect and parse the text, return reasoning_content and content.""" + raise NotImplementedError + + def parse_streaming_increment( + self, new_text: str + ) -> Tuple[Optional[str], Optional[str]]: + """Parse the new text incrementally, return reasoning_content and content.""" + raise NotImplementedError +``` + +And specify the reasoning parser for new reasoning models accordingly. + +```python +class ReasoningParser: + """Reasoning parser for different reasoning models.""" + + # Specify the reasoning parser for each reasoning model here + ReasoningParserDict: Dict[str, Type[BaseReasoningParser]] = { + "deepseek-r1": DeepSeekR1ReasoningParser + } + + def __init__(self, reasoning_parser: str): + self.parser = self.ReasoningParserDict[reasoning_parser]() + + def parse_non_stream(self, full_text: str) -> Tuple[Optional[str], Optional[str]]: + """ + Non-streaming parsing for reasoning models. + Return: reasoning_content, content + """ + return self.parser.detect_and_parse(full_text) + + def parse_stream_chunk(self, chunk_text: str): + """ + Streaming parsing for reasoning models. 
+ Return: reasoning_content, content + """ + return self.parser.parse_streaming_increment(chunk_text) +``` diff --git a/python/sglang/srt/openai_api/adapter.py b/python/sglang/srt/openai_api/adapter.py index 0556f852a32..3fc400cb5b6 100644 --- a/python/sglang/srt/openai_api/adapter.py +++ b/python/sglang/srt/openai_api/adapter.py @@ -74,6 +74,7 @@ TopLogprob, UsageInfo, ) +from sglang.srt.reasoning_parser import ReasoningParser from sglang.utils import get_exception_traceback logger = logging.getLogger(__name__) @@ -1038,7 +1039,12 @@ def v1_chat_generate_request( def v1_chat_generate_response( - request, ret, to_file=False, cache_report=False, tool_call_parser=None + request, + ret, + to_file=False, + cache_report=False, + tool_call_parser=None, + reasoning_parser=None, ): choices = [] @@ -1084,11 +1090,21 @@ def v1_chat_generate_response( else: choice_logprobs = None - finish_reason = ret_item["meta_info"]["finish_reason"] - - tool_calls = None + reasoning_content = None text = ret_item["text"] + if reasoning_parser: + try: + parser = ReasoningParser(reasoning_parser) + reasoning_content, text = parser.parse_non_stream(text) + except Exception as e: + logger.error(f"Exception: {e}") + return create_error_response( + HTTPStatus.BAD_REQUEST, + "Failed to parse reasoning content", + ) + finish_reason = ret_item["meta_info"]["finish_reason"] + tool_calls = None if isinstance(request, list): tool_choice = request[idx].tool_choice tools = request[idx].tools @@ -1124,7 +1140,10 @@ def v1_chat_generate_response( "index": 0, "message": { "role": "assistant", - "content": ret_item["text"] if tool_calls is None else None, + "content": text if tool_calls is None else None, + "reasoning_content": ( + reasoning_content if tool_calls is None else None + ), "tool_calls": tool_calls, }, "logprobs": choice_logprobs, @@ -1140,7 +1159,8 @@ def v1_chat_generate_response( index=idx, message=ChatMessage( role="assistant", - content=ret_item["text"] if tool_calls is None else None, + 
content=text if tool_calls is None else None, + reasoning_content=reasoning_content if tool_calls is None else None, tool_calls=tool_calls, ), logprobs=choice_logprobs, @@ -1208,6 +1228,7 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request): if adapted_request.stream: parser_dict = {} + reasoning_parser_dict = {} async def generate_stream_resp(): is_firsts = {} @@ -1302,6 +1323,16 @@ async def generate_stream_resp(): delta = text[len(stream_buffer) :] new_stream_buffer = stream_buffer + delta + reasoning_content = None + if tokenizer_manager.server_args.enable_reasoning: + if index not in reasoning_parser_dict: + reasoning_parser_dict[index] = ReasoningParser( + tokenizer_manager.server_args.reasoning_parser + ) + reasoning_content, delta = reasoning_parser_dict[ + index + ].parse_stream_chunk(delta) + if request.tool_choice != "none" and request.tools: if index not in parser_dict: parser_dict[index] = FunctionCallParser( @@ -1313,11 +1344,14 @@ async def generate_stream_resp(): # parse_increment => returns (normal_text, calls) normal_text, calls = parser.parse_stream_chunk(delta) - # 1) if there's normal_text, output it as normal content + # 1) if there's normal_text, output it as normal content, the reasoning content is also included if normal_text: choice_data = ChatCompletionResponseStreamChoice( index=index, - delta=DeltaMessage(content=normal_text), + delta=DeltaMessage( + content=normal_text, + reasoning_content=reasoning_content, + ), finish_reason=( finish_reason["type"] if finish_reason else "" ), @@ -1386,7 +1420,9 @@ async def generate_stream_resp(): # No tool calls => just treat this as normal text choice_data = ChatCompletionResponseStreamChoice( index=index, - delta=DeltaMessage(content=delta), + delta=DeltaMessage( + content=delta, reasoning_content=reasoning_content + ), finish_reason=( finish_reason["type"] if finish_reason else "" ), @@ -1456,6 +1492,7 @@ async def generate_stream_resp(): ret, 
cache_report=tokenizer_manager.server_args.enable_cache_report, tool_call_parser=tokenizer_manager.server_args.tool_call_parser, + reasoning_parser=tokenizer_manager.server_args.reasoning_parser, ) return response diff --git a/python/sglang/srt/openai_api/protocol.py b/python/sglang/srt/openai_api/protocol.py index 95b34527edb..6e2ffe2015b 100644 --- a/python/sglang/srt/openai_api/protocol.py +++ b/python/sglang/srt/openai_api/protocol.py @@ -344,6 +344,7 @@ class ToolCall(BaseModel): class ChatMessage(BaseModel): role: Optional[str] = None content: Optional[str] = None + reasoning_content: Optional[str] = None tool_calls: Optional[List[ToolCall]] = Field(default=None, examples=[None]) @@ -367,6 +368,7 @@ class ChatCompletionResponse(BaseModel): class DeltaMessage(BaseModel): role: Optional[str] = None content: Optional[str] = None + reasoning_content: Optional[str] = None tool_calls: Optional[List[ToolCall]] = Field(default=None, examples=[None]) diff --git a/python/sglang/srt/reasoning_parser.py b/python/sglang/srt/reasoning_parser.py new file mode 100644 index 00000000000..5dc3c0eb52f --- /dev/null +++ b/python/sglang/srt/reasoning_parser.py @@ -0,0 +1,93 @@ +import json +import logging +import re +from typing import Any, Dict, List, Optional, Tuple, Type + + +class BaseReasoningParser: + """Base class for reasoning parser.""" + + def __init__(self, think_start_token: str, think_end_token: str): + self._buffer = "" + self.think_start_token = think_start_token + self.think_end_token = think_end_token + self.pattern = re.compile( + rf"{self.think_start_token}(.*?){self.think_end_token}", re.DOTALL + ) + self.is_reasoning = True + + def parse_streaming_increment( + self, new_text: str + ) -> Tuple[Optional[str], Optional[str]]: + """Parse the new text incrementally, return reasoning_content and content.""" + # Should parse + if self.is_reasoning: + self._buffer += new_text + + # Reasoning continues + if self.think_end_token not in self._buffer: + return new_text, 
"" + # Reasoning ends + else: + reasoning_part = new_text.split(self.think_end_token)[0] + content_part = new_text.split(self.think_end_token)[1] + + self.is_reasoning = False + self._buffer = "" + + return reasoning_part, content_part + + else: + return "", new_text + + def detect_and_parse(self, text: str) -> Tuple[Optional[str], Optional[str]]: + """Detect and parse the text, return reasoning_content and content.""" + if self.think_end_token not in text: + return text, "" + else: + # Add the start token to the beginning of the text. + text = self.think_start_token + text + + reasoning_content = self.pattern.findall(text)[0] + content = text[ + len(self.think_start_token) + + len(reasoning_content) + + len(self.think_end_token) : + ] + + return reasoning_content, content + + +class DeepSeekR1ReasoningParser(BaseReasoningParser): + """ + DeepSeekR1 reasoning parser, which use "" and "" to detect the reasoning part. + Referring to https://github.com/deepseek-ai/DeepSeek-R1?tab=readme-ov-file#usage-recommendations~. + """ + + def __init__(self): + super().__init__("", "") + + +class ReasoningParser: + """Reasoning parser for different reasoning models.""" + + ReasoningParserDict: Dict[str, Type[BaseReasoningParser]] = { + "deepseek-r1": DeepSeekR1ReasoningParser + } + + def __init__(self, reasoning_parser: str): + self.parser = self.ReasoningParserDict[reasoning_parser]() + + def parse_non_stream(self, full_text: str) -> Tuple[Optional[str], Optional[str]]: + """ + Non-streaming parsing for reasoning models. + Return: reasoning_content, content + """ + return self.parser.detect_and_parse(full_text) + + def parse_stream_chunk(self, chunk_text: str): + """ + Streaming parsing for reasoning models. 
+ Return: reasoning_content, content + """ + return self.parser.parse_streaming_increment(chunk_text) diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index fd2188dcce7..e2585e9fe22 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -95,6 +95,8 @@ class ServerArgs: api_key: Optional[str] = None file_storage_pth: str = "sglang_storage" enable_cache_report: bool = False + enable_reasoning: bool = False + reasoning_parser: Optional[str] = None # Data parallelism dp_size: int = 1 @@ -282,6 +284,12 @@ def __post_init__(self): if is_hip(): self.triton_attention_num_kv_splits = 16 + # API Related + if self.enable_reasoning and not self.reasoning_parser: + raise ValueError( + "Reasoning parser must be specified when reasoning is enabled." + ) + @staticmethod def add_cli_args(parser: argparse.ArgumentParser): # Model and port args @@ -606,6 +614,18 @@ def add_cli_args(parser: argparse.ArgumentParser): action="store_true", help="Return number of cached tokens in usage.prompt_tokens_details for each openai request.", ) + parser.add_argument( + "--enable-reasoning", + action="store_true", + help="Enable the reasoning feature.", + ) + parser.add_argument( + "--reasoning-parser", + type=str, + choices=["deepseek-r1"], + default=ServerArgs.reasoning_parser, + help="Specify the parser for reasoning models, supported parsers are: deepseek-r1.", + ) # Data parallelism parser.add_argument(