From 36078fb24760de837b95f4bea87ede00c0fd91e8 Mon Sep 17 00:00:00 2001
From: Liangsheng Yin
Date: Tue, 17 Sep 2024 16:33:53 -0700
Subject: [PATCH] fix schedule bug (#1450)

---
 .../sglang/srt/managers/policy_scheduler.py | 141 ++++++------------
 python/sglang/srt/managers/tp_worker.py     |  17 ++-
 2 files changed, 59 insertions(+), 99 deletions(-)

diff --git a/python/sglang/srt/managers/policy_scheduler.py b/python/sglang/srt/managers/policy_scheduler.py
index b58c0e7b3b8..0c2b21acbe7 100644
--- a/python/sglang/srt/managers/policy_scheduler.py
+++ b/python/sglang/srt/managers/policy_scheduler.py
@@ -119,19 +119,32 @@ def __init__(
         self.running_batch = running_batch
         self.new_token_ratio = new_token_ratio
         self.rem_total_tokens = rem_total_tokens - mixed_with_decode_tokens
-        self.rem_total_tokens_ = self.rem_total_tokens
-        self.total_tokens = rem_total_tokens
         self.rem_input_tokens = rem_input_tokens - mixed_with_decode_tokens
         self.rem_chunk_tokens = rem_chunk_tokens
         if self.rem_chunk_tokens is not None:
             self.rem_chunk_tokens -= mixed_with_decode_tokens

+        self.cur_rem_tokens = rem_total_tokens - mixed_with_decode_tokens
+
         self.req_states = None
         self.can_run_list = []
         self.new_inflight_req = None
         self.log_hit_tokens = 0
         self.log_input_tokens = 0

+        if running_batch is not None:
+            # Pre-remove the tokens which will be occupied by the running requests
+            self.rem_total_tokens -= sum(
+                [
+                    min(
+                        (r.sampling_params.max_new_tokens - len(r.output_ids)),
+                        CLIP_MAX_NEW_TOKENS,
+                    )
+                    * self.new_token_ratio
+                    for r in running_batch.reqs
+                ]
+            )
+
     def no_remaining_tokens(self):
         return (
             self.rem_total_tokens <= 0
@@ -141,31 +154,14 @@ def no_remaining_tokens(self):
                 if self.rem_chunk_tokens is not None
                 else False
             )
-        )
-
-    def remove_running_tokens(self, running_batch: ScheduleBatch):
-        self.rem_total_tokens -= sum(
-            [
-                min(
-                    (r.sampling_params.max_new_tokens - len(r.output_ids)),
-                    CLIP_MAX_NEW_TOKENS,
-                )
-                * self.new_token_ratio
-                for r in running_batch.reqs
-            ]
-        )
-        self.rem_total_tokens_ -= sum(
-            [
-                r.sampling_params.max_new_tokens - len(r.output_ids)
-                for r in running_batch.reqs
-            ]
+            or self.cur_rem_tokens <= 0
         )

     def _prefill_one_req(
         self, prefix_len: int, extend_input_len: int, max_new_tokens: int
     ):
         self.rem_total_tokens -= extend_input_len + max_new_tokens
-        self.rem_total_tokens_ -= extend_input_len + max_new_tokens
+        self.cur_rem_tokens -= extend_input_len
         self.rem_input_tokens -= extend_input_len
         if self.rem_chunk_tokens is not None:
             self.rem_chunk_tokens -= extend_input_len
@@ -173,29 +169,7 @@ def _prefill_one_req(
         self.log_hit_tokens += prefix_len
         self.log_input_tokens += extend_input_len

-    def add_inflight_req_ignore_eos(self, req: Req):
-        truncated = req.extend_input_len > self.rem_chunk_tokens
-        req.extend_input_len = min(req.extend_input_len, self.rem_chunk_tokens)
-        req.fill_ids = req.fill_ids[: len(req.prefix_indices) + req.extend_input_len]
-        self.can_run_list.append(req)
-
-        self._prefill_one_req(
-            0,
-            req.extend_input_len,
-            (
-                min(req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS)
-                if not truncated
-                else 0
-            ),
-        )
-
-        # Return if chunked prefill not finished
-        return req if truncated else None
-
     def add_inflight_req(self, req: Req):
-        if req.sampling_params.ignore_eos:
-            return self.add_inflight_req_ignore_eos(req)
-
         truncated = req.extend_input_len > self.rem_chunk_tokens
         req.extend_input_len = min(req.extend_input_len, self.rem_chunk_tokens)
         req.fill_ids = req.fill_ids[: len(req.prefix_indices) + req.extend_input_len]
@@ -225,7 +199,7 @@ def _lock_node(self, last_node: TreeNode):
             self.rem_total_tokens += delta

     def add_one_req_ignore_eos(self, req: Req):
-        def get_req_state(r):
+        def add_req_state(r, insert_sort=False):
             new_token_ratio = (
                 1.0 if r.sampling_params.ignore_eos else self.new_token_ratio
             )
@@ -235,56 +209,37 @@ def get_req_state(r):
             tokens_occupied = len(r.origin_input_ids) + len(r.output_ids)

             if tokens_left > 0:
-                return (tokens_left, tokens_occupied)
-
-            return None
-
-        # Quick Check
-        can_run = False
-        if (
-            req.extend_input_len + req.sampling_params.max_new_tokens
-            <= self.rem_total_tokens
-        ):
-            can_run = True
-
-        if not can_run:
-            if self.req_states is None:
-                self.req_states = []
-                if self.running_batch is not None:
-                    for r in self.running_batch.reqs:
-                        state = get_req_state(r)
-                        if state is not None:
-                            self.req_states.append(state)
-                for r in self.can_run_list:
-                    state = get_req_state(r)
-                    if state is not None:
-                        self.req_states.append(state)
-                state = get_req_state(req)
-                if state is not None:
-                    self.req_states.append(state)
-
-                self.req_states.sort(key=lambda x: x[0])
-            else:
-                state = get_req_state(req)
-                if state is not None:
-                    for i, (tokens_left, tokens_occupied) in enumerate(self.req_states):
-                        if tokens_left >= state[0]:
-                            self.req_states.insert(i, state)
+                if not insert_sort:
+                    self.req_states.append((tokens_left, tokens_occupied))
+                else:
+                    for i in range(len(self.req_states)):
+                        if tokens_left <= self.req_states[i][0]:
                             break
-                    else:
-                        self.req_states.append(state)
-
-            tokens_freed = 0
-            for i, (tokens_left, tokens_occupied) in enumerate(self.req_states):
-                decode_steps = (
-                    self.req_states[i + 1][0]
-                    if i + 1 < len(self.req_states)
-                    else tokens_left
-                )
-                bs = len(self.req_states) - i
-                if self.total_tokens + tokens_freed - decode_steps * bs <= 0:
-                    return False
-                tokens_freed += tokens_occupied
+                    self.req_states.insert(i, (tokens_left, tokens_occupied))
+
+        if self.req_states is None:
+            self.req_states = []
+            add_req_state(req)
+            if self.running_batch is not None:
+                for r in self.running_batch.reqs:
+                    add_req_state(r)
+            for r in self.can_run_list:
+                add_req_state(r)
+            self.req_states.sort(key=lambda x: x[0])
+        else:
+            add_req_state(req, insert_sort=True)
+
+        tokens_freed = 0
+        for i, (tokens_left, tokens_occupied) in enumerate(self.req_states):
+            decode_steps = (
+                self.req_states[i + 1][0]
+                if i + 1 < len(self.req_states)
+                else tokens_left
+            )
+            bs = len(self.req_states) - i
+            if self.cur_rem_tokens + tokens_freed - decode_steps * bs <= 0:
+                return False
+            tokens_freed += tokens_occupied

         if req.extend_input_len <= self.rem_chunk_tokens:
             self.can_run_list.append(req)
diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py
index 8053f27d0b0..fe9afc9f316 100644
--- a/python/sglang/srt/managers/tp_worker.py
+++ b/python/sglang/srt/managers/tp_worker.py
@@ -445,9 +445,6 @@ def get_new_prefill_batch(self) -> Optional[ScheduleBatch]:
             num_mixed_running,
         )

-        if self.running_batch is not None:
-            adder.remove_running_tokens(self.running_batch)
-
         has_inflight = self.current_inflight_req is not None
         if self.current_inflight_req is not None:
             self.current_inflight_req.init_next_round_input(
@@ -465,9 +462,6 @@ def get_new_prefill_batch(self) -> Optional[ScheduleBatch]:
             )

         for req in self.waiting_queue:
-            if adder.no_remaining_tokens():
-                break
-            req.init_next_round_input(None if prefix_computed else self.tree_cache)
             if (
                 self.lora_paths is not None
                 and len(
@@ -478,6 +472,10 @@ def get_new_prefill_batch(self) -> Optional[ScheduleBatch]:
                 > self.max_loras_per_batch
             ):
                 break
+
+            if adder.no_remaining_tokens():
+                break
+            req.init_next_round_input(None if prefix_computed else self.tree_cache)
             res = adder.add_one_req(req)
             if (
                 not res
@@ -507,6 +505,11 @@ def get_new_prefill_batch(self) -> Optional[ScheduleBatch]:
         else:
             tree_cache_hit_rate = 0.0

+        num_used = self.max_total_num_tokens - (
+            self.token_to_kv_pool.available_size()
+            + self.tree_cache.evictable_size()
+        )
+
         if num_mixed_running > 0:
             logger.info(
                 f"Prefill batch"
@@ -515,6 +518,7 @@ def get_new_prefill_batch(self) -> Optional[ScheduleBatch]:
                 f"#new-token: {adder.log_input_tokens}, "
                 f"#cached-token: {adder.log_hit_tokens}, "
                 f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
+                f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
                 f"#queue-req: {len(self.waiting_queue) - len(can_run_list) + has_inflight}"
             )
         else:
@@ -524,6 +528,7 @@ def get_new_prefill_batch(self) -> Optional[ScheduleBatch]:
                 f"#new-token: {adder.log_input_tokens}, "
                 f"#cached-token: {adder.log_hit_tokens}, "
                 f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
+                f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
                 f"#running-req: {running_bs}, "
                 f"#queue-req: {len(self.waiting_queue) - len(can_run_list) + has_inflight}"
             )
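
For readers of the patch, the admission test that add_one_req_ignore_eos now runs against cur_rem_tokens can be read as the standalone sketch below. It is illustrative only and not part of the commit: the function name can_schedule and its free-standing parameters are made up for the example, while the (tokens_left, tokens_occupied) states and the decode_steps * bs check mirror the patched code.

# Illustrative sketch only -- not part of the patch. `can_schedule` is a
# hypothetical helper; the tuples and the check follow the new
# add_one_req_ignore_eos logic.
def can_schedule(req_states, cur_rem_tokens):
    # req_states: one (tokens_left, tokens_occupied) pair per request, where
    # tokens_left is the decode budget still unused and tokens_occupied is the
    # prompt + output tokens the request currently holds.
    states = sorted(req_states, key=lambda s: s[0])

    tokens_freed = 0
    for i, (tokens_left, tokens_occupied) in enumerate(states):
        # Until the next request finishes, all `bs` still-running requests keep
        # decoding, so roughly decode_steps * bs new tokens must fit in the
        # pool, offset by the tokens freed by requests that finished earlier.
        decode_steps = states[i + 1][0] if i + 1 < len(states) else tokens_left
        bs = len(states) - i
        if cur_rem_tokens + tokens_freed - decode_steps * bs <= 0:
            return False
        tokens_freed += tokens_occupied
    return True

In other words, a new request is admitted only if, at every projected finish point of the sorted requests, the remaining token budget still covers the decode tokens of the requests that are still running.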