diff --git a/python/sglang/api.py b/python/sglang/api.py index 2242b4a4c6a..887ffce76e6 100644 --- a/python/sglang/api.py +++ b/python/sglang/api.py @@ -62,6 +62,7 @@ def gen( name: Optional[str] = None, max_tokens: Optional[int] = None, stop: Optional[Union[str, List[str]]] = None, + stop_token_ids: Optional[List[int]] = None, temperature: Optional[float] = None, top_p: Optional[float] = None, top_k: Optional[int] = None, @@ -98,6 +99,7 @@ def gen( name, max_tokens, stop, + stop_token_ids, temperature, top_p, top_k, @@ -117,6 +119,7 @@ def gen_int( name: Optional[str] = None, max_tokens: Optional[int] = None, stop: Optional[Union[str, List[str]]] = None, + stop_token_ids: Optional[List[int]] = None, temperature: Optional[float] = None, top_p: Optional[float] = None, top_k: Optional[int] = None, @@ -132,6 +135,7 @@ def gen_int( name, max_tokens, stop, + stop_token_ids, temperature, top_p, top_k, @@ -151,6 +155,7 @@ def gen_string( name: Optional[str] = None, max_tokens: Optional[int] = None, stop: Optional[Union[str, List[str]]] = None, + stop_token_ids: Optional[List[int]] = None, temperature: Optional[float] = None, top_p: Optional[float] = None, top_k: Optional[int] = None, @@ -166,6 +171,7 @@ def gen_string( name, max_tokens, stop, + stop_token_ids, temperature, top_p, top_k, diff --git a/python/sglang/lang/interpreter.py b/python/sglang/lang/interpreter.py index cf53fac3035..844c9d062b8 100644 --- a/python/sglang/lang/interpreter.py +++ b/python/sglang/lang/interpreter.py @@ -20,7 +20,6 @@ SglConstantText, SglExpr, SglExprList, - SglFunction, SglGen, SglImage, SglRoleBegin, @@ -181,8 +180,10 @@ def __init__( num_api_spec_tokens=None, use_thread=True, ): + from sglang.lang.backend.base_backend import BaseBackend + self.sid = uuid.uuid4().hex - self.backend = backend + self.backend: BaseBackend = backend self.arguments: Dict[str, Any] = arguments self.default_sampling_para = default_sampling_para self.stream = stream @@ -658,6 +659,7 @@ def _resolve_sampling_params(self, sampling_params): for item in [ "max_new_tokens", "stop", + "stop_token_ids", "temperature", "top_p", "top_k", diff --git a/python/sglang/lang/ir.py b/python/sglang/lang/ir.py index 0166b86870f..9db5f2719ea 100644 --- a/python/sglang/lang/ir.py +++ b/python/sglang/lang/ir.py @@ -18,6 +18,7 @@ class SglSamplingParams: max_new_tokens: int = 128 stop: Union[str, List[str]] = () + stop_token_ids: Optional[List[int]] = () temperature: float = 1.0 top_p: float = 1.0 top_k: int = -1 # -1 means disable @@ -37,6 +38,7 @@ def clone(self): return SglSamplingParams( self.max_new_tokens, self.stop, + self.stop_token_ids, self.temperature, self.top_p, self.top_k, @@ -108,6 +110,7 @@ def to_srt_kwargs(self): return { "max_new_tokens": self.max_new_tokens, "stop": self.stop, + "stop_token_ids": self.stop_token_ids, "temperature": self.temperature, "top_p": self.top_p, "top_k": self.top_k, @@ -141,7 +144,8 @@ def run( self, *args, max_new_tokens: int = 128, - stop: Union[str, List[str]] = (), + stop: Union[str, List[str]] = [], + stop_token_ids: Optional[List[int]] = [], temperature: float = 1.0, top_p: float = 1.0, top_k: int = -1, @@ -161,6 +165,7 @@ def run( default_sampling_para = SglSamplingParams( max_new_tokens=max_new_tokens, stop=stop, + stop_token_ids=stop_token_ids, temperature=temperature, top_p=top_p, top_k=top_k, @@ -181,6 +186,7 @@ def run_batch( *, max_new_tokens: int = 128, stop: Union[str, List[str]] = (), + stop_token_ids: Optional[List[int]] = [], temperature: float = 1.0, top_p: float = 1.0, top_k: int = -1, @@ -218,6 +224,7 @@ def run_batch( default_sampling_para = SglSamplingParams( max_new_tokens=max_new_tokens, stop=stop, + stop_token_ids=stop_token_ids, temperature=temperature, top_p=top_p, top_k=top_k, @@ -397,6 +404,7 @@ def __init__( name: Optional[str] = None, max_new_tokens: Optional[int] = None, stop: Optional[Union[str, List[str]]] = None, + stop_token_ids: Optional[List[int]] = None, temperature: Optional[float] = None, top_p: Optional[float] = None, top_k: Optional[int] = None, @@ -416,6 +424,7 @@ def __init__( self.sampling_params = SglSamplingParams( max_new_tokens=max_new_tokens, stop=stop, + stop_token_ids=stop_token_ids, temperature=temperature, top_p=top_p, top_k=top_k, diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 9037f5a6eac..9e86c9b1880 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -235,10 +235,12 @@ def check_finished(self): return last_token_id = self.output_ids[-1] - if self.tokenizer is None: - matched_eos = last_token_id in self.sampling_params.stop_token_ids - else: - matched_eos = last_token_id == self.tokenizer.eos_token_id + + matched_eos = last_token_id in self.sampling_params.stop_token_ids + + if self.tokenizer is not None: + matched_eos |= last_token_id == self.tokenizer.eos_token_id + if matched_eos and not self.sampling_params.ignore_eos: self.finished_reason = FINISH_MATCHED_TOKEN(matched=last_token_id) return diff --git a/python/sglang/test/test_programs.py b/python/sglang/test/test_programs.py index 6e39f0aa995..ce402558550 100644 --- a/python/sglang/test/test_programs.py +++ b/python/sglang/test/test_programs.py @@ -106,13 +106,16 @@ def decode_json(s): from sglang.lang.ir import REGEX_FLOAT, REGEX_INT, REGEX_STR s += "Generate a JSON object to describe the basic city information of Paris.\n" + s += "Here are the JSON object:\n" + + # NOTE: we recommend using dtype gen or whole regex string to control the output with s.var_scope("json_output"): s += "{\n" - s += ' "name": ' + sgl.gen(regex=REGEX_STR + ",") + "\n" - s += ' "population": ' + sgl.gen(regex=REGEX_INT + ",") + "\n" - s += ' "area": ' + sgl.gen(regex=REGEX_INT + ",") + "\n" - s += ' "latitude": ' + sgl.gen(regex=REGEX_FLOAT) + "\n" + s += ' "name": ' + sgl.gen(regex=REGEX_STR) + ",\n" + s += ' "population": ' + sgl.gen(regex=REGEX_INT, stop=[" ", "\n"]) + ",\n" + s += ' "area": ' + sgl.gen(regex=REGEX_INT, stop=[" ", "\n"]) + ",\n" + s += ' "latitude": ' + sgl.gen(regex=REGEX_FLOAT, stop=[" ", "\n"]) + "\n" s += "}" ret = decode_json.run(temperature=0.0) diff --git a/test/srt/test_moe_serving_throughput.py b/test/srt/test_moe_serving_throughput.py index 713eba7abb8..80b445f490f 100644 --- a/test/srt/test_moe_serving_throughput.py +++ b/test/srt/test_moe_serving_throughput.py @@ -84,7 +84,7 @@ def test_default_without_radix_cache(self): if os.getenv("SGLANG_IS_IN_CI", "false") == "true": # A100 (PCIE) performance - assert res["output_throughput"] > 940 + assert res["output_throughput"] > 930 def test_default_with_chunked_prefill(self): res = self.run_test(