diff --git a/lightllm/common/basemodel/basemodel.py b/lightllm/common/basemodel/basemodel.py
index b557b8c31..49c025f0f 100755
--- a/lightllm/common/basemodel/basemodel.py
+++ b/lightllm/common/basemodel/basemodel.py
@@ -57,9 +57,8 @@ def __init__(self, kvargs):
         self.return_all_prompt_logics = kvargs.get("return_all_prompt_logics", False)
         assert not (self.is_token_healing and self.return_all_prompt_logics), "can not be true in same time"
         self.use_dynamic_prompt_cache = kvargs.get("use_dynamic_prompt_cache", False)
-        enable_chunked_prefill = not kvargs.get("disable_chunked_prefill", False)  # chunked prefill is default on.
+        enable_chunked_prefill = kvargs.get("enable_chunked_prefill", False)  # chunked prefill is off by default.
         self.use_dynamic_prompt_cache = self.use_dynamic_prompt_cache or enable_chunked_prefill
-        print(f"enable_chunked_prefill: {self.use_dynamic_prompt_cache}")
         self.data_type = kvargs.get("data_type", "float16")
         self.graph_max_batch_size = kvargs.get("graph_max_batch_size", 16)
         self.graph_max_len_in_batch = kvargs.get("graph_max_len_in_batch", 8192)
diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py
index 5a182d855..ff44a4a67 100644
--- a/lightllm/server/api_cli.py
+++ b/lightllm/server/api_cli.py
@@ -145,7 +145,7 @@ def make_argument_parser() -> argparse.ArgumentParser:
     parser.add_argument("--use_dynamic_prompt_cache", action="store_true", help="use_dynamic_prompt_cache test")
     parser.add_argument("--chunked_prefill_size", type=int, default=8192, help="chunked prefill size")
-    parser.add_argument("--disable_chunked_prefill", action="store_true", help="whether to disable chunked prefill")
+    parser.add_argument("--enable_chunked_prefill", action="store_true", help="whether to enable chunked prefill")
     parser.add_argument("--diverse_mode", action="store_true", help="diversity generation mode")
     parser.add_argument("--token_healing_mode", action="store_true", help="code model infer mode")
     parser.add_argument("--simple_constraint_mode", action="store_true", help="output constraint mode")
diff --git a/lightllm/server/api_start.py b/lightllm/server/api_start.py
index 982759634..89983a5de 100644
--- a/lightllm/server/api_start.py
+++ b/lightllm/server/api_start.py
@@ -125,7 +125,7 @@ def normal_or_p_d_start(args):
     else:
        args.visual_nccl_ports = args.visual_nccl_ports[: args.visual_dp]

-    if args.disable_chunked_prefill:
+    if not args.enable_chunked_prefill:
         # normal mode
         if args.batch_max_tokens is None:
             args.batch_max_tokens = args.max_req_total_len
diff --git a/lightllm/server/core/objs/req.py b/lightllm/server/core/objs/req.py
index 6ff955ac0..ba577e02e 100644
--- a/lightllm/server/core/objs/req.py
+++ b/lightllm/server/core/objs/req.py
@@ -290,7 +290,7 @@ def post_init(
         return


-class SplitFuseReq(Req):
+class ChunkedPrefillReq(Req):
     _pack_ = 4

     def get_tuple_tokens(self, is_busy, router_max_new_token_len):
diff --git a/lightllm/server/core/objs/shm_req_manager.py b/lightllm/server/core/objs/shm_req_manager.py
index 5f6a154ea..6b625dcc7 100644
--- a/lightllm/server/core/objs/shm_req_manager.py
+++ b/lightllm/server/core/objs/shm_req_manager.py
@@ -30,7 +30,7 @@ def __init__(self):

     def get_req_class_type(self):
         args: StartArgs = get_env_start_args()
-        if not args.disable_chunked_prefill:
+        if args.enable_chunked_prefill:
             return ChunkedPrefillReq
         if args.token_healing_mode:
             return TokenHealingReq
diff --git a/lightllm/server/core/objs/start_args_type.py b/lightllm/server/core/objs/start_args_type.py
index b834d586b..36c4cc93b 100644
--- a/lightllm/server/core/objs/start_args_type.py
+++ b/lightllm/server/core/objs/start_args_type.py
@@ -38,7 +38,7 @@ class StartArgs:
     router_max_wait_tokens: int = field(default=10)
     use_dynamic_prompt_cache: bool = field(default=False)
     chunked_prefill_size: int = field(default=256)
-    disable_chunked_prefill: bool = field(default=False)
+    enable_chunked_prefill: bool = field(default=False)
     diverse_mode: bool = field(default=False)
     token_healing_mode: bool = field(default=False)
     simple_constraint_mode: bool = field(default=False)
diff --git a/lightllm/server/router/batch.py b/lightllm/server/router/batch.py
index 7e87afe9d..99234ba6c 100644
--- a/lightllm/server/router/batch.py
+++ b/lightllm/server/router/batch.py
@@ -22,7 +22,7 @@ def input_tokens(self):
         return batch_input_tokens

     def get_batch_decode_need_tokens(self):
-        new_batch_decode_need_tokens = [0 for _ in range(self.dp_size)]  # only meaningful in splitfuse mode
+        new_batch_decode_need_tokens = [0 for _ in range(self.dp_size)]  # for chunked prefill
         for req in self.reqs:
             req_dp_index = req.sample_params.suggested_dp_index
diff --git a/lightllm/server/router/manager.py b/lightllm/server/router/manager.py
index 0b63c45e1..0ef828207 100644
--- a/lightllm/server/router/manager.py
+++ b/lightllm/server/router/manager.py
@@ -76,6 +76,7 @@ def __init__(self, args, router_port, detokenization_port, model_rpc_ports, metr
         self.is_token_healing = self.args.token_healing_mode
         self.chunked_prefill_size = args.chunked_prefill_size
+        self.enable_chunked_prefill = args.enable_chunked_prefill
         self.stats_tool = Stats(not args.disable_log_stats, args.log_stats_interval)
         self.metric_client = MetricClient(metric_port)
@@ -314,16 +315,14 @@ async def _prefill_batch(self, batch: Batch):
         start_time = time.time()
         self.metric_client.counter_inc("lightllm_batch_inference_count", "prefill")
         reqs = [r.to_router_rpc_obj() for r in batch.reqs]
-        self.overlap_event.set()
-        await self.model_rpc_client.prefill(reqs)
-        batch.filter_out_finished_req(self.shm_req_manager)
-        chunked_reqs = batch.filter_out_chunked_req()
-        self.req_queue.pretend(chunked_reqs)
-        if len(batch.reqs) == 0:
-            # all requests are chunked requests
-            return
-        # send a None packet to trigger detokenization
-        self.send_to_detokenization.send_pyobj(None, protocol=pickle.HIGHEST_PROTOCOL)
+        # In chunked prefill mode the prefill step does not need to run here;
+        # it is executed in _decode_batch instead.
+        if not self.enable_chunked_prefill:
+            self.overlap_event.set()
+            await self.model_rpc_client.prefill(reqs)
+            batch.filter_out_finished_req(self.shm_req_manager)
+            # send a None packet to trigger detokenization
+            self.send_to_detokenization.send_pyobj(None, protocol=pickle.HIGHEST_PROTOCOL)

         logger.debug(f"Prefill Batch: {batch.simple_log()} \n")
         self.metric_client.histogram_observe(
diff --git a/lightllm/server/router/model_infer/mode_backend/base_backend.py b/lightllm/server/router/model_infer/mode_backend/base_backend.py
index 12e5588ef..175d4adfc 100644
--- a/lightllm/server/router/model_infer/mode_backend/base_backend.py
+++ b/lightllm/server/router/model_infer/mode_backend/base_backend.py
@@ -60,7 +60,7 @@ def init_model(self, kvargs):
         self.dp_size = kvargs.get("dp_size", 1)
         self.load_way = kvargs["load_way"]
         self.mode = kvargs["mode"]
-        self.disable_chunked_prefill = kvargs.get("disable_chunked_prefill", False)
+        self.enable_chunked_prefill = kvargs.get("enable_chunked_prefill", False)
         self.chunked_prefill_size = kvargs.get("chunked_prefill_size", None)
         self.return_all_prompt_logprobs = kvargs.get("return_all_prompt_logprobs", False)
         self.use_dynamic_prompt_cache = kvargs.get("use_dynamic_prompt_cache", False)
diff --git a/lightllm/server/router/req_queue/__init__.py b/lightllm/server/router/req_queue/__init__.py
index b60c1914e..fb61f42d5 100644
--- a/lightllm/server/router/req_queue/__init__.py
+++ b/lightllm/server/router/req_queue/__init__.py
@@ -1,6 +1,7 @@
 from .continues_batch.impl import ContinuesBatchQueue
 from .continues_batch.beam_impl import BeamContinuesBatchQueue
 from .continues_batch.pd_decode_impl import ContinuesBatchQueueForPDDecode
+from .chunked_prefill.impl import ChunkedPrefillQueue
 from .dp_base_queue import DpQueue
@@ -10,6 +11,8 @@ def build_req_queue(args, router, dp_size: int):
         queue_class = ContinuesBatchQueueForPDDecode
     if args.diverse_mode:
         queue_class = BeamContinuesBatchQueue
+    if args.enable_chunked_prefill:
+        queue_class = ChunkedPrefillQueue
     if args.token_healing_mode:
         queue_class = ContinuesBatchQueue
     if args.simple_constraint_mode:
diff --git a/lightllm/server/router/req_queue/base_queue.py b/lightllm/server/router/req_queue/base_queue.py
index 2a9ac7426..9bff747e8 100644
--- a/lightllm/server/router/req_queue/base_queue.py
+++ b/lightllm/server/router/req_queue/base_queue.py
@@ -18,7 +18,7 @@ def __init__(self, args, router, dp_index, dp_size) -> None:
         self.batch_max_tokens = args.batch_max_tokens
         self.running_max_req_size = args.running_max_req_size  # Maximum number of concurrent requests
         self.chunked_prefill_size = args.chunked_prefill_size  # Maximum number of tokens that can be prefilled
-        self.disable_chunked_prefill = args.disable_chunked_prefill
+        self.enable_chunked_prefill = args.enable_chunked_prefill
         self.waiting_req_list: List[Req] = []  # List of queued requests
         self.router_token_ratio = args.router_token_ratio  # ratio to determine whether the router is busy
         self.router_max_new_token_len = args.router_max_new_token_len
diff --git a/lightllm/server/router/req_queue/chunked_prefill/__init__.py b/lightllm/server/router/req_queue/chunked_prefill/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/lightllm/server/router/req_queue/chunked_prefill/impl.py b/lightllm/server/router/req_queue/chunked_prefill/impl.py
new file mode 100644
index 000000000..49e481348
--- /dev/null
+++ b/lightllm/server/router/req_queue/chunked_prefill/impl.py
@@ -0,0 +1,120 @@
+import uuid
+import numpy as np
+from ...batch import Batch, Req
+from lightllm.server.router.req_queue.base_queue import BaseQueue
+
+
+class ChunkedPrefillQueue(BaseQueue):
+    def __init__(self, args, router, dp_index, dp_size) -> None:
+        super().__init__(args, router, dp_index, dp_size)
+
+    def _init_cache_list(self, current_batch: Batch, is_busy):
+        if current_batch is not None:
+            self.cache_len_list = [
+                req.get_tuple_tokens(is_busy, self.router_max_new_token_len)
+                for req in current_batch.reqs
+                if req.sample_params.suggested_dp_index == self.dp_index
+            ]
+        else:
+            self.cache_len_list = []
+        return
+
+    # @calculate_time(show=True, min_cost_ms=0.1)
+    def _can_add_new_req(self, req: Req, is_busy, new_batch_first_router_need_tokens):
+        self.cache_len_list.append(req.get_tuple_tokens(is_busy, self.router_max_new_token_len))  # hard to analyze
+        self.cache_len_list.sort(key=lambda x: -x[1])
+
+        left_out_len_array = np.array([e[1] for e in self.cache_len_list])
+        # assert left_out_len_array.min() >= 0
+        has_run_len_array = np.array([e[0] for e in self.cache_len_list])
+        cum_run_len_array = np.cumsum(has_run_len_array)
+        size_array = np.arange(1, len(self.cache_len_list) + 1, 1)
+
+        need_max_token_num = (left_out_len_array * size_array + cum_run_len_array).max()
+        ok_token_num = (
+            need_max_token_num + self.router.shared_token_load.get_frozened_token_count(self.dp_index)
+            < self.max_total_tokens
+        )
+
+        if not req.is_paused:
+            ok_req_num = len(self.cache_len_list) + len(self.pause_req_dict) <= self.running_max_req_size
+        else:
+            ok_req_num = len(self.cache_len_list) + len(self.pause_req_dict) - 1 <= self.running_max_req_size
+
+        new_batch_first_router_need_tokens += req.get_first_router_need_tokens()
+        # chunked prefill decode ok
+        ok_splitfuse_decode = new_batch_first_router_need_tokens <= self.batch_max_tokens
+
+        if ok_token_num and ok_req_num and ok_splitfuse_decode:
+            self.router.shared_token_load.set_estimated_peak_token_count(need_max_token_num, self.dp_index)
+            self.router.shared_token_load.set_dynamic_max_load(
+                (need_max_token_num + self.router.shared_token_load.get_frozened_token_count(self.dp_index))
+                / self.max_total_tokens,
+                self.dp_index,
+            )
+            return True, new_batch_first_router_need_tokens
+        else:
+            return False, new_batch_first_router_need_tokens
+
+    # @calculate_time(show=True, min_cost_ms=10)
+    def generate_new_batch(self, current_batch: Batch):
+
+        # If the number of requests already scheduled has reached the limit, do not schedule any new requests.
+        exist_req_num = self.get_batch_dp_req_size(current_batch) + len(self.pause_req_dict)
+        req_is_full = exist_req_num >= self.running_max_req_size
+        if req_is_full:
+            return None
+
+        is_busy = self.is_busy()
+
+        # Number of tokens the current batch needs to decode one more step. This only matters in
+        # chunked prefill mode, where prefill and decode run together, so the existing batch has to
+        # be considered together with new requests.
+        new_batch_first_router_need_tokens = (
+            0 if current_batch is None else current_batch.get_batch_decode_need_tokens()[self.dp_index]
+        )
+
+        self._init_cache_list(current_batch, is_busy)
+        can_run_list = []
+        abort_req_list = []
+        aborted_count = 0
+        for req in self.waiting_req_list:
+            if req.is_aborted and not req.is_paused:
+                # Because of the management complexity, only requests that have never been scheduled
+                # to run can be dropped from the queue directly when they are aborted.
+                # Paused requests must be resumed first and are then filtered out by the router
+                # manager. Keep this approach for now, otherwise managed tokens could leak.
+                aborted_count += 1
+                abort_req_list.append(req)
+                continue
+            ok_insert, new_batch_first_router_need_tokens = self._can_add_new_req(
+                req, is_busy, new_batch_first_router_need_tokens
+            )
+            if ok_insert:
+                can_run_list.append(req)
+                if req.is_paused:
+                    self.pause_req_dict.pop(req.request_id)
+                    req.is_paused = False
+            else:
+                break
+
+        if len(can_run_list) != 0:
+            new_batch = Batch(uuid.uuid4().int, can_run_list, dp_size=self.dp_size)
+            for req in abort_req_list:
+                self.router.shm_req_manager.put_back_req_obj(req)
+            self.waiting_req_list = self.waiting_req_list[len(can_run_list) + aborted_count :]
+            return new_batch
+        else:
+            return None
+
+    def _calcu_batch_token_load_batch_not_none(self, current_batch: Batch):
+        is_busy = self.is_busy()
+        self._init_cache_list(current_batch, is_busy)
+        self.cache_len_list.sort(key=lambda x: -x[1])
+        left_out_len_array = np.array([e[1] for e in self.cache_len_list])
+        has_run_len_array = np.array([e[0] for e in self.cache_len_list])
+        cum_run_len_array = np.cumsum(has_run_len_array)
+        size_array = np.arange(1, len(self.cache_len_list) + 1, 1)
+        need_max_token_num = (left_out_len_array * size_array + cum_run_len_array).max()
+        return (
+            need_max_token_num,
+            (need_max_token_num + self.router.shared_token_load.get_frozened_token_count(self.dp_index))
+            / self.max_total_tokens,
+        )
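
The core of the new ChunkedPrefillQueue is the peak-token bound computed in _can_add_new_req: requests are sorted by how many tokens they may still generate, and the maximum over every prefix of that ordering bounds how large the KV cache of the running set can grow. A minimal standalone sketch of that estimate follows; the helper name estimate_peak_token_need and the example numbers are hypothetical and not part of the patch, only the formula mirrors the code above.

import numpy as np

def estimate_peak_token_need(cache_len_list):
    """Sketch of the peak-token bound used by ChunkedPrefillQueue._can_add_new_req.

    cache_len_list holds one (has_run_len, left_out_len) tuple per request:
    tokens already occupying the KV cache, and the maximum number of tokens
    the request may still generate.
    """
    # Sort by remaining output, longest first, mirroring
    # self.cache_len_list.sort(key=lambda x: -x[1]) in the patch.
    ordered = sorted(cache_len_list, key=lambda x: -x[1])
    left_out = np.array([e[1] for e in ordered])
    has_run = np.array([e[0] for e in ordered])
    cum_run = np.cumsum(has_run)
    size = np.arange(1, len(ordered) + 1)
    # When the i-th request (the one with the smallest remaining budget in the
    # first i) finishes, each of those i requests has added at most left_out[i-1]
    # tokens on top of what the prefix already held; the max over all prefixes
    # bounds the peak cache usage of the batch.
    return int((left_out * size + cum_run).max())

# Hypothetical example: max(100 + 1 * 50, (100 + 20) + 2 * 10) = 150
assert estimate_peak_token_need([(100, 50), (20, 10)]) == 150

_can_add_new_req admits a request only when this bound plus the frozen token count tracked by shared_token_load stays below max_total_tokens, the running request count stays within running_max_req_size, and the accumulated first-router token need stays within batch_max_tokens.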