Skip to content

Commit

Permalink
Move scheduler code from tp_worker.py to scheduler.py (#1538)
Browse files Browse the repository at this point in the history
  • Loading branch information
merrymercy authored Sep 30, 2024
1 parent acaffd2 commit f86c1e6
Show file tree
Hide file tree
Showing 8 changed files with 933 additions and 870 deletions.
16 changes: 12 additions & 4 deletions python/sglang/bench_latency.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,9 +167,13 @@ def prepare_inputs_for_correctness_test(bench_args, tokenizer):
assert len(input_ids[i]) > bench_args.cut_len

tmp_input_ids = input_ids[i][: bench_args.cut_len]
req = Req(rid=i, origin_input_text=prompts[i], origin_input_ids=tmp_input_ids)
req = Req(
rid=i,
origin_input_text=prompts[i],
origin_input_ids=tmp_input_ids,
sampling_params=sampling_params,
)
req.prefix_indices = []
req.sampling_params = sampling_params
req.fill_ids = req.origin_input_ids
req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
reqs.append(req)
Expand Down Expand Up @@ -199,9 +203,13 @@ def prepare_synthetic_inputs_for_latency_test(batch_size, input_len):

reqs = []
for i in range(len(input_ids)):
req = Req(rid=i, origin_input_text="", origin_input_ids=list(input_ids[i]))
req = Req(
rid=i,
origin_input_text="",
origin_input_ids=list(input_ids[i]),
sampling_params=sampling_params,
)
req.prefix_indices = []
req.sampling_params = sampling_params
req.fill_ids = req.origin_input_ids
req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
reqs.append(req)
Expand Down
11 changes: 3 additions & 8 deletions python/sglang/srt/managers/io_struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
processes (TokenizerManager, DetokenizerManager, Controller).
"""

import copy
import uuid
from dataclasses import dataclass
from typing import Dict, List, Optional, Union
Expand Down Expand Up @@ -53,12 +52,12 @@ class GenerateReqInput:
stream: bool = False
# The modalities of the image data [image, multi-images, video]
modalities: Optional[List[str]] = None

is_single: bool = True

# LoRA related
lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None

# Whether it is a single request or a batch request
is_single: bool = True

def post_init(self):
if (self.text is None and self.input_ids is None) or (
self.text is not None and self.input_ids is not None
Expand Down Expand Up @@ -307,10 +306,6 @@ class BatchTokenIDOut:
meta_info: List[Dict]
finished_reason: List[BaseFinishReason]

def __post_init__(self):
# deepcopy meta_info to avoid modification in place
self.meta_info = copy.deepcopy(self.meta_info)


@dataclass
class BatchStrOut:
Expand Down
9 changes: 5 additions & 4 deletions python/sglang/srt/managers/schedule_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
from sglang.srt.model_executor.forward_batch_info import ForwardMode
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.srt.server_args import ServerArgs

INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
Expand Down Expand Up @@ -143,6 +144,7 @@ def __init__(
rid: str,
origin_input_text: str,
origin_input_ids: Tuple[int],
sampling_params: SamplingParams,
lora_path: Optional[str] = None,
):
# Input and output info
Expand All @@ -152,6 +154,8 @@ def __init__(
self.origin_input_ids = origin_input_ids
self.output_ids = [] # Each decode stage's output ids
self.fill_ids = None # fill_ids = origin_input_ids + output_ids

self.sampling_params = sampling_params
self.lora_path = lora_path

# Memory info
Expand All @@ -160,6 +164,7 @@ def __init__(
# Check finish
self.tokenizer = None
self.finished_reason = None
self.stream = False

# For incremental decoding
# ----- | --------- read_ids -------|
Expand Down Expand Up @@ -187,10 +192,6 @@ def __init__(
self.extend_input_len = 0
self.last_node = None

# Sampling parameters
self.sampling_params = None
self.stream = False

# Logprobs (arguments)
self.return_logprob = False
self.logprob_start_len = 0
Expand Down
Loading

0 comments on commit f86c1e6

Please sign in to comment.