Skip to content

Commit

Permalink
Support stop_token_ids in sglang API (#1092)
Browse files Browse the repository at this point in the history
  • Loading branch information
hnyls2002 authored Aug 15, 2024
1 parent 1c2b5f5 commit 73cf683
Show file tree
Hide file tree
Showing 6 changed files with 34 additions and 12 deletions.
6 changes: 6 additions & 0 deletions python/sglang/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ def gen(
name: Optional[str] = None,
max_tokens: Optional[int] = None,
stop: Optional[Union[str, List[str]]] = None,
stop_token_ids: Optional[List[int]] = None,
temperature: Optional[float] = None,
top_p: Optional[float] = None,
top_k: Optional[int] = None,
Expand Down Expand Up @@ -98,6 +99,7 @@ def gen(
name,
max_tokens,
stop,
stop_token_ids,
temperature,
top_p,
top_k,
Expand All @@ -117,6 +119,7 @@ def gen_int(
name: Optional[str] = None,
max_tokens: Optional[int] = None,
stop: Optional[Union[str, List[str]]] = None,
stop_token_ids: Optional[List[int]] = None,
temperature: Optional[float] = None,
top_p: Optional[float] = None,
top_k: Optional[int] = None,
Expand All @@ -132,6 +135,7 @@ def gen_int(
name,
max_tokens,
stop,
stop_token_ids,
temperature,
top_p,
top_k,
Expand All @@ -151,6 +155,7 @@ def gen_string(
name: Optional[str] = None,
max_tokens: Optional[int] = None,
stop: Optional[Union[str, List[str]]] = None,
stop_token_ids: Optional[List[int]] = None,
temperature: Optional[float] = None,
top_p: Optional[float] = None,
top_k: Optional[int] = None,
Expand All @@ -166,6 +171,7 @@ def gen_string(
name,
max_tokens,
stop,
stop_token_ids,
temperature,
top_p,
top_k,
Expand Down
6 changes: 4 additions & 2 deletions python/sglang/lang/interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
SglConstantText,
SglExpr,
SglExprList,
SglFunction,
SglGen,
SglImage,
SglRoleBegin,
Expand Down Expand Up @@ -181,8 +180,10 @@ def __init__(
num_api_spec_tokens=None,
use_thread=True,
):
from sglang.lang.backend.base_backend import BaseBackend

self.sid = uuid.uuid4().hex
self.backend = backend
self.backend: BaseBackend = backend
self.arguments: Dict[str, Any] = arguments
self.default_sampling_para = default_sampling_para
self.stream = stream
Expand Down Expand Up @@ -658,6 +659,7 @@ def _resolve_sampling_params(self, sampling_params):
for item in [
"max_new_tokens",
"stop",
"stop_token_ids",
"temperature",
"top_p",
"top_k",
Expand Down
11 changes: 10 additions & 1 deletion python/sglang/lang/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
class SglSamplingParams:
max_new_tokens: int = 128
stop: Union[str, List[str]] = ()
stop_token_ids: Optional[List[int]] = ()
temperature: float = 1.0
top_p: float = 1.0
top_k: int = -1 # -1 means disable
Expand All @@ -37,6 +38,7 @@ def clone(self):
return SglSamplingParams(
self.max_new_tokens,
self.stop,
self.stop_token_ids,
self.temperature,
self.top_p,
self.top_k,
Expand Down Expand Up @@ -108,6 +110,7 @@ def to_srt_kwargs(self):
return {
"max_new_tokens": self.max_new_tokens,
"stop": self.stop,
"stop_token_ids": self.stop_token_ids,
"temperature": self.temperature,
"top_p": self.top_p,
"top_k": self.top_k,
Expand Down Expand Up @@ -141,7 +144,8 @@ def run(
self,
*args,
max_new_tokens: int = 128,
stop: Union[str, List[str]] = (),
stop: Union[str, List[str]] = [],
stop_token_ids: Optional[List[int]] = [],
temperature: float = 1.0,
top_p: float = 1.0,
top_k: int = -1,
Expand All @@ -161,6 +165,7 @@ def run(
default_sampling_para = SglSamplingParams(
max_new_tokens=max_new_tokens,
stop=stop,
stop_token_ids=stop_token_ids,
temperature=temperature,
top_p=top_p,
top_k=top_k,
Expand All @@ -181,6 +186,7 @@ def run_batch(
*,
max_new_tokens: int = 128,
stop: Union[str, List[str]] = (),
stop_token_ids: Optional[List[int]] = [],
temperature: float = 1.0,
top_p: float = 1.0,
top_k: int = -1,
Expand Down Expand Up @@ -218,6 +224,7 @@ def run_batch(
default_sampling_para = SglSamplingParams(
max_new_tokens=max_new_tokens,
stop=stop,
stop_token_ids=stop_token_ids,
temperature=temperature,
top_p=top_p,
top_k=top_k,
Expand Down Expand Up @@ -397,6 +404,7 @@ def __init__(
name: Optional[str] = None,
max_new_tokens: Optional[int] = None,
stop: Optional[Union[str, List[str]]] = None,
stop_token_ids: Optional[List[int]] = None,
temperature: Optional[float] = None,
top_p: Optional[float] = None,
top_k: Optional[int] = None,
Expand All @@ -416,6 +424,7 @@ def __init__(
self.sampling_params = SglSamplingParams(
max_new_tokens=max_new_tokens,
stop=stop,
stop_token_ids=stop_token_ids,
temperature=temperature,
top_p=top_p,
top_k=top_k,
Expand Down
10 changes: 6 additions & 4 deletions python/sglang/srt/managers/schedule_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,10 +235,12 @@ def check_finished(self):
return

last_token_id = self.output_ids[-1]
if self.tokenizer is None:
matched_eos = last_token_id in self.sampling_params.stop_token_ids
else:
matched_eos = last_token_id == self.tokenizer.eos_token_id

matched_eos = last_token_id in self.sampling_params.stop_token_ids

if self.tokenizer is not None:
matched_eos |= last_token_id == self.tokenizer.eos_token_id

if matched_eos and not self.sampling_params.ignore_eos:
self.finished_reason = FINISH_MATCHED_TOKEN(matched=last_token_id)
return
Expand Down
11 changes: 7 additions & 4 deletions python/sglang/test/test_programs.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,13 +106,16 @@ def decode_json(s):
from sglang.lang.ir import REGEX_FLOAT, REGEX_INT, REGEX_STR

s += "Generate a JSON object to describe the basic city information of Paris.\n"
s += "Here are the JSON object:\n"

# NOTE: we recommend using dtype gen or whole regex string to control the output

with s.var_scope("json_output"):
s += "{\n"
s += ' "name": ' + sgl.gen(regex=REGEX_STR + ",") + "\n"
s += ' "population": ' + sgl.gen(regex=REGEX_INT + ",") + "\n"
s += ' "area": ' + sgl.gen(regex=REGEX_INT + ",") + "\n"
s += ' "latitude": ' + sgl.gen(regex=REGEX_FLOAT) + "\n"
s += ' "name": ' + sgl.gen(regex=REGEX_STR) + ",\n"
s += ' "population": ' + sgl.gen(regex=REGEX_INT, stop=[" ", "\n"]) + ",\n"
s += ' "area": ' + sgl.gen(regex=REGEX_INT, stop=[" ", "\n"]) + ",\n"
s += ' "latitude": ' + sgl.gen(regex=REGEX_FLOAT, stop=[" ", "\n"]) + "\n"
s += "}"

ret = decode_json.run(temperature=0.0)
Expand Down
2 changes: 1 addition & 1 deletion test/srt/test_moe_serving_throughput.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def test_default_without_radix_cache(self):

if os.getenv("SGLANG_IS_IN_CI", "false") == "true":
# A100 (PCIE) performance
assert res["output_throughput"] > 940
assert res["output_throughput"] > 930

def test_default_with_chunked_prefill(self):
res = self.run_test(
Expand Down

0 comments on commit 73cf683

Please sign in to comment.