
Commit

refactor chunked prefill
shihaobai committed Feb 12, 2025
1 parent 7b22df4 commit 0828452
Showing 13 changed files with 141 additions and 20 deletions.
3 changes: 1 addition & 2 deletions lightllm/common/basemodel/basemodel.py
@@ -57,9 +57,8 @@ def __init__(self, kvargs):
self.return_all_prompt_logics = kvargs.get("return_all_prompt_logics", False)
assert not (self.is_token_healing and self.return_all_prompt_logics), "can not be true in same time"
self.use_dynamic_prompt_cache = kvargs.get("use_dynamic_prompt_cache", False)
- enable_chunked_prefill = not kvargs.get("disable_chunked_prefill", False)  # chunked prefill is default on.
+ enable_chunked_prefill = kvargs.get("enable_chunked_prefill", False)  # chunked prefill is now off by default.
self.use_dynamic_prompt_cache = self.use_dynamic_prompt_cache or enable_chunked_prefill
print(f"enable_chunked_prefill: {self.use_dynamic_prompt_cache}")
self.data_type = kvargs.get("data_type", "float16")
self.graph_max_batch_size = kvargs.get("graph_max_batch_size", 16)
self.graph_max_len_in_batch = kvargs.get("graph_max_len_in_batch", 8192)
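For orientation, the net effect of this hunk is that chunked prefill is now opt-in via the enable_chunked_prefill kvarg, and enabling it still forces the dynamic prompt cache on. A minimal standalone sketch of that flag logic (resolve_cache_flags is a hypothetical helper, not part of the diff):

def resolve_cache_flags(kvargs: dict) -> bool:
    # Enabling chunked prefill implicitly enables the dynamic prompt cache.
    use_dynamic_prompt_cache = kvargs.get("use_dynamic_prompt_cache", False)
    enable_chunked_prefill = kvargs.get("enable_chunked_prefill", False)
    return use_dynamic_prompt_cache or enable_chunked_prefill

assert resolve_cache_flags({"enable_chunked_prefill": True}) is True
assert resolve_cache_flags({}) is False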
2 changes: 1 addition & 1 deletion lightllm/server/api_cli.py
@@ -145,7 +145,7 @@ def make_argument_parser() -> argparse.ArgumentParser:
parser.add_argument("--use_dynamic_prompt_cache", action="store_true", help="use_dynamic_prompt_cache test")

parser.add_argument("--chunked_prefill_size", type=int, default=8192, help="chunked prefill size")
parser.add_argument("--disable_chunked_prefill", action="store_true", help="whether to disable chunked prefill")
parser.add_argument("--enable_chunked_prefill", action="store_true", help="whether to disable chunked prefill")
parser.add_argument("--diverse_mode", action="store_true", help="diversity generation mode")
parser.add_argument("--token_healing_mode", action="store_true", help="code model infer mode")
parser.add_argument("--simple_constraint_mode", action="store_true", help="output constraint mode")
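A self-contained sketch of just the two arguments touched here, useful for checking the new flag semantics in isolation (the real parser in lightllm/server/api_cli.py defines many more options):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--chunked_prefill_size", type=int, default=8192, help="chunked prefill size")
parser.add_argument("--enable_chunked_prefill", action="store_true", help="whether to enable chunked prefill")

# With the rename, chunked prefill stays off unless the flag is passed explicitly.
args = parser.parse_args([])
assert args.enable_chunked_prefill is False

args = parser.parse_args(["--enable_chunked_prefill", "--chunked_prefill_size", "4096"])
assert args.enable_chunked_prefill is True and args.chunked_prefill_size == 4096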
2 changes: 1 addition & 1 deletion lightllm/server/api_start.py
@@ -125,7 +125,7 @@ def normal_or_p_d_start(args):
else:
args.visual_nccl_ports = args.visual_nccl_ports[: args.visual_dp]

- if args.disable_chunked_prefill:
+ if not args.enable_chunked_prefill:
# normal mode (chunked prefill disabled)
if args.batch_max_tokens is None:
args.batch_max_tokens = args.max_req_total_len
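Only the non-chunked branch of the batch_max_tokens defaulting is visible in this hunk; the sketch below reproduces it and adds an assumed chunked-prefill branch for illustration (the real else-branch is not shown in the diff):

def default_batch_max_tokens(enable_chunked_prefill, batch_max_tokens,
                             max_req_total_len, chunked_prefill_size):
    if not enable_chunked_prefill:
        # normal mode: a single prefill must be able to hold a whole request
        if batch_max_tokens is None:
            batch_max_tokens = max_req_total_len
    else:
        # chunked prefill mode (assumed): a chunk-sized budget is enough
        if batch_max_tokens is None:
            batch_max_tokens = chunked_prefill_size
    return batch_max_tokens

assert default_batch_max_tokens(False, None, 16384, 8192) == 16384
assert default_batch_max_tokens(True, None, 16384, 8192) == 8192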
2 changes: 1 addition & 1 deletion lightllm/server/core/objs/req.py
@@ -290,7 +290,7 @@ def post_init(
return


- class SplitFuseReq(Req):
+ class ChunkedPrefillReq(Req):
_pack_ = 4

def get_tuple_tokens(self, is_busy, router_max_new_token_len):
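This hunk only renames SplitFuseReq to ChunkedPrefillReq; the interesting method is get_tuple_tokens, whose result is consumed by ChunkedPrefillQueue below as a (has_run_len, left_out_len) pair. A hedged illustration of that contract (the actual formula inside Req is not part of this diff):

def example_tuple_tokens(input_len, output_len, max_new_tokens):
    # Assumed shape only: tokens already occupying the KV cache, and the
    # worst-case number of tokens the request may still generate.
    has_run_len = input_len + output_len
    left_out_len = max(max_new_tokens - output_len, 1)
    return has_run_len, left_out_len

assert example_tuple_tokens(1000, 20, 256) == (1020, 236)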
2 changes: 1 addition & 1 deletion lightllm/server/core/objs/shm_req_manager.py
@@ -30,7 +30,7 @@ def __init__(self):

def get_req_class_type(self):
args: StartArgs = get_env_start_args()
- if not args.disable_chunked_prefill:
+ if args.enable_chunked_prefill:
return ChunkedPrefillReq
if args.token_healing_mode:
return TokenHealingReq
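Because the checks in get_req_class_type run top to bottom and return early, enable_chunked_prefill now takes precedence over the other modes shown here. A toy version of that dispatch (class names as strings, default branch assumed):

def pick_req_class(enable_chunked_prefill, token_healing_mode):
    if enable_chunked_prefill:
        return "ChunkedPrefillReq"
    if token_healing_mode:
        return "TokenHealingReq"
    return "Req"  # assumed default

assert pick_req_class(True, True) == "ChunkedPrefillReq"
assert pick_req_class(False, True) == "TokenHealingReq"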
2 changes: 1 addition & 1 deletion lightllm/server/core/objs/start_args_type.py
@@ -38,7 +38,7 @@ class StartArgs:
router_max_wait_tokens: int = field(default=10)
use_dynamic_prompt_cache: bool = field(default=False)
chunked_prefill_size: int = field(default=256)
- disable_chunked_prefill: bool = field(default=False)
+ enable_chunked_prefill: bool = field(default=False)
diverse_mode: bool = field(default=False)
token_healing_mode: bool = field(default=False)
simple_constraint_mode: bool = field(default=False)
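Note that the dataclass default for chunked_prefill_size (256) differs from the CLI default (8192); presumably the CLI value is what takes effect when the server is started through api_cli. A fragment-level sketch of just the fields touched by this commit (the real StartArgs has many more):

from dataclasses import dataclass, field

@dataclass
class ChunkedPrefillStartArgs:
    chunked_prefill_size: int = field(default=256)
    enable_chunked_prefill: bool = field(default=False)

args = ChunkedPrefillStartArgs(enable_chunked_prefill=True)
assert args.enable_chunked_prefill and args.chunked_prefill_size == 256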
2 changes: 1 addition & 1 deletion lightllm/server/router/batch.py
@@ -22,7 +22,7 @@ def input_tokens(self):
return batch_input_tokens

def get_batch_decode_need_tokens(self):
- new_batch_decode_need_tokens = [0 for _ in range(self.dp_size)]  # only meaningful in splitfuse mode
+ new_batch_decode_need_tokens = [0 for _ in range(self.dp_size)]  # for chunked prefill

for req in self.reqs:
req_dp_index = req.sample_params.suggested_dp_index
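The method initializes one token budget per data-parallel rank and then accumulates per-request needs into the matching slot; the accumulation itself is cut off by the diff view, so the increment below is an assumption for illustration:

def batch_decode_need_tokens(reqs, dp_size):
    need = [0 for _ in range(dp_size)]
    for dp_index, decode_need in reqs:  # (suggested_dp_index, tokens needed for the next step)
        need[dp_index] += decode_need   # assumed accumulation
    return need

assert batch_decode_need_tokens([(0, 1), (1, 3), (0, 2)], dp_size=2) == [3, 3]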
19 changes: 9 additions & 10 deletions lightllm/server/router/manager.py
@@ -76,6 +76,7 @@ def __init__(self, args, router_port, detokenization_port, model_rpc_ports, metr

self.is_token_healing = self.args.token_healing_mode
self.chunked_prefill_size = args.chunked_prefill_size
+ self.enable_chunked_prefill = args.enable_chunked_prefill

self.stats_tool = Stats(not args.disable_log_stats, args.log_stats_interval)
self.metric_client = MetricClient(metric_port)
@@ -314,16 +315,14 @@ async def _prefill_batch(self, batch: Batch):
start_time = time.time()
self.metric_client.counter_inc("lightllm_batch_inference_count", "prefill")
reqs = [r.to_router_rpc_obj() for r in batch.reqs]
- self.overlap_event.set()
- await self.model_rpc_client.prefill(reqs)
- batch.filter_out_finished_req(self.shm_req_manager)
- chunked_reqs = batch.filter_out_chunked_req()
- self.req_queue.pretend(chunked_reqs)
- if len(batch.reqs) == 0:
-     # all requests are chunked requests
-     return
- # send a None packet to trigger detokenization
- self.send_to_detokenization.send_pyobj(None, protocol=pickle.HIGHEST_PROTOCOL)
+ # In chunked prefill mode, the prefill operation does not need to execute here;
+ # it is executed in _decode_batch instead.
+ if not self.enable_chunked_prefill:
+     self.overlap_event.set()
+     await self.model_rpc_client.prefill(reqs)
+     batch.filter_out_finished_req(self.shm_req_manager)
+     # send a None packet to trigger detokenization
+     self.send_to_detokenization.send_pyobj(None, protocol=pickle.HIGHEST_PROTOCOL)

logger.debug(f"Prefill Batch: {batch.simple_log()} \n")
self.metric_client.histogram_observe(
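The upshot of this hunk: with chunked prefill enabled, _prefill_batch no longer calls the model RPC at all, and the chunk-by-chunk prefill work happens as part of the decode path instead. A hedged sketch of that control flow (names outside the visible diff are assumptions):

async def prefill_step(router, batch):
    reqs = [r.to_router_rpc_obj() for r in batch.reqs]
    if not router.enable_chunked_prefill:
        # normal mode: run the whole prefill here
        router.overlap_event.set()
        await router.model_rpc_client.prefill(reqs)
        batch.filter_out_finished_req(router.shm_req_manager)
    # chunked prefill mode: nothing to do here; _decode_batch advances
    # each request's prefill one chunk at a time (assumed behaviour).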
@@ -60,7 +60,7 @@ def init_model(self, kvargs):
self.dp_size = kvargs.get("dp_size", 1)
self.load_way = kvargs["load_way"]
self.mode = kvargs["mode"]
- self.disable_chunked_prefill = kvargs.get("disable_chunked_prefill", False)
+ self.enable_chunked_prefill = kvargs.get("enable_chunked_prefill", False)
self.chunked_prefill_size = kvargs.get("chunked_prefill_size", None)
self.return_all_prompt_logprobs = kvargs.get("return_all_prompt_logprobs", False)
self.use_dynamic_prompt_cache = kvargs.get("use_dynamic_prompt_cache", False)
3 changes: 3 additions & 0 deletions lightllm/server/router/req_queue/__init__.py
@@ -1,6 +1,7 @@
from .continues_batch.impl import ContinuesBatchQueue
from .continues_batch.beam_impl import BeamContinuesBatchQueue
from .continues_batch.pd_decode_impl import ContinuesBatchQueueForPDDecode
+ from .chunked_prefill.impl import ChunkedPrefillQueue
from .dp_base_queue import DpQueue


@@ -10,6 +11,8 @@ def build_req_queue(args, router, dp_size: int):
queue_class = ContinuesBatchQueueForPDDecode
if args.diverse_mode:
queue_class = BeamContinuesBatchQueue
+ if args.enable_chunked_prefill:
+     queue_class = ChunkedPrefillQueue
if args.token_healing_mode:
queue_class = ContinuesBatchQueue
if args.simple_constraint_mode:
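build_req_queue selects the queue class through a chain of independent if statements, so the last matching flag wins; with this change, token_healing_mode would still override ChunkedPrefillQueue if both flags were set. A toy reproduction of that last-assignment-wins behaviour (class names as strings, default assumed):

def pick_queue_class(diverse_mode, enable_chunked_prefill, token_healing_mode):
    queue_class = "ContinuesBatchQueue"  # assumed default
    if diverse_mode:
        queue_class = "BeamContinuesBatchQueue"
    if enable_chunked_prefill:
        queue_class = "ChunkedPrefillQueue"
    if token_healing_mode:
        queue_class = "ContinuesBatchQueue"
    return queue_class

assert pick_queue_class(False, True, False) == "ChunkedPrefillQueue"
assert pick_queue_class(False, True, True) == "ContinuesBatchQueue"  # later flag overrides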
2 changes: 1 addition & 1 deletion lightllm/server/router/req_queue/base_queue.py
@@ -18,7 +18,7 @@ def __init__(self, args, router, dp_index, dp_size) -> None:
self.batch_max_tokens = args.batch_max_tokens
self.running_max_req_size = args.running_max_req_size # Maximum number of concurrent requests
self.chunked_prefill_size = args.chunked_prefill_size # Maximum number of tokens that can be prefilled
- self.disable_chunked_prefill = args.disable_chunked_prefill
+ self.enable_chunked_prefill = args.enable_chunked_prefill
self.waiting_req_list: List[Req] = [] # List of queued requests
self.router_token_ratio = args.router_token_ratio # ratio to determine whether the router is busy
self.router_max_new_token_len = args.router_max_new_token_len
Empty file.
120 changes: 120 additions & 0 deletions lightllm/server/router/req_queue/chunked_prefill/impl.py
@@ -0,0 +1,120 @@
import uuid
import numpy as np
from ...batch import Batch, Req
from lightllm.server.router.req_queue.base_queue import BaseQueue


class ChunkedPrefillQueue(BaseQueue):
def __init__(self, args, router, dp_index, dp_size) -> None:
super().__init__(args, router, dp_index, dp_size)

def _init_cache_list(self, current_batch: Batch, is_busy):
if current_batch is not None:
self.cache_len_list = [
req.get_tuple_tokens(is_busy, self.router_max_new_token_len)
for req in current_batch.reqs
if req.sample_params.suggested_dp_index == self.dp_index
]
else:
self.cache_len_list = []
return

# @calculate_time(show=True, min_cost_ms=0.1)
def _can_add_new_req(self, req: Req, is_busy, new_batch_first_router_need_tokens):
self.cache_len_list.append(req.get_tuple_tokens(is_busy, self.router_max_new_token_len))  # hard to analyze
self.cache_len_list.sort(key=lambda x: -x[1])

left_out_len_array = np.array([e[1] for e in self.cache_len_list])
# assert left_out_len_array.min() >= 0
has_run_len_array = np.array([e[0] for e in self.cache_len_list])
cum_run_len_array = np.cumsum(has_run_len_array)
size_array = np.arange(1, len(self.cache_len_list) + 1, 1)

need_max_token_num = (left_out_len_array * size_array + cum_run_len_array).max()
ok_token_num = (
need_max_token_num + self.router.shared_token_load.get_frozened_token_count(self.dp_index)
< self.max_total_tokens
)

if not req.is_paused:
ok_req_num = len(self.cache_len_list) + len(self.pause_req_dict) <= self.running_max_req_size
else:
ok_req_num = len(self.cache_len_list) + len(self.pause_req_dict) - 1 <= self.running_max_req_size

new_batch_first_router_need_tokens += req.get_first_router_need_tokens()
# splitfuse decode ok
ok_splitfuse_decode = new_batch_first_router_need_tokens <= self.batch_max_tokens

if ok_token_num and ok_req_num and ok_splitfuse_decode:
self.router.shared_token_load.set_estimated_peak_token_count(need_max_token_num, self.dp_index)
self.router.shared_token_load.set_dynamic_max_load(
(need_max_token_num + self.router.shared_token_load.get_frozened_token_count(self.dp_index))
/ self.max_total_tokens,
self.dp_index,
)
return True, new_batch_first_router_need_tokens
else:
return False, new_batch_first_router_need_tokens

# @calculate_time(show=True, min_cost_ms=10)
def generate_new_batch(self, current_batch: Batch):

# If the number of requests already scheduled has reached the upper limit, do not schedule any new requests.
exist_req_num = self.get_batch_dp_req_size(current_batch) + len(self.pause_req_dict)
req_is_full = exist_req_num >= self.running_max_req_size
if req_is_full:
return None

is_busy = self.is_busy()

# Compute the number of tokens the current batch needs for its next decode step. This only matters in
# chunked prefill (splitfuse-like) mode, where prefill and decode run together, so the existing batch must be accounted for jointly.
new_batch_first_router_need_tokens = (
0 if current_batch is None else current_batch.get_batch_decode_need_tokens()[self.dp_index]
)

self._init_cache_list(current_batch, is_busy)
can_run_list = []
abort_req_list = []
aborted_count = 0
for req in self.waiting_req_list:
if req.is_aborted and not req.is_paused:
# Due to bookkeeping complexity, only requests that have never been scheduled can be dropped from the queue directly on abort.
# Paused requests must be resumed first and are then filtered by the router manager. Keep this approach for now, otherwise tracked tokens would leak.
aborted_count += 1
abort_req_list.append(req)
continue
ok_insert, new_batch_first_router_need_tokens = self._can_add_new_req(
req, is_busy, new_batch_first_router_need_tokens
)
if ok_insert:
can_run_list.append(req)
if req.is_paused:
self.pause_req_dict.pop(req.request_id)
req.is_paused = False
else:
break

if len(can_run_list) != 0:
new_batch = Batch(uuid.uuid4().int, can_run_list, dp_size=self.dp_size)
for req in abort_req_list:
self.router.shm_req_manager.put_back_req_obj(req)
self.waiting_req_list = self.waiting_req_list[len(can_run_list) + aborted_count :]
return new_batch
else:
return None

def _calcu_batch_token_load_batch_not_none(self, current_batch: Batch):
is_busy = self.is_busy()
self._init_cache_list(current_batch, is_busy)
self.cache_len_list.sort(key=lambda x: -x[1])
left_out_len_array = np.array([e[1] for e in self.cache_len_list])
has_run_len_array = np.array([e[0] for e in self.cache_len_list])
cum_run_len_array = np.cumsum(has_run_len_array)
size_array = np.arange(1, len(self.cache_len_list) + 1, 1)
need_max_token_num = (left_out_len_array * size_array + cum_run_len_array).max()
return (
need_max_token_num,
(need_max_token_num + self.router.shared_token_load.get_frozened_token_count(self.dp_index))
/ self.max_total_tokens,
)
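The core of the admission check in _can_add_new_req is the peak-token estimate: with requests sorted by descending remaining output length, the i-th entry's remaining tokens must be able to coexist with the i requests that run at least as long, plus everything those requests have already consumed. A small standalone worked example of that computation, using the same NumPy expressions as above:

import numpy as np

# Each entry is (has_run_len, left_out_len), sorted by descending left_out_len.
cache_len_list = [(1000, 500), (200, 400), (50, 100)]

left_out_len_array = np.array([e[1] for e in cache_len_list])   # [500, 400, 100]
has_run_len_array = np.array([e[0] for e in cache_len_list])    # [1000, 200, 50]
cum_run_len_array = np.cumsum(has_run_len_array)                # [1000, 1200, 1250]
size_array = np.arange(1, len(cache_len_list) + 1, 1)           # [1, 2, 3]

# Candidate peaks: 500*1+1000=1500, 400*2+1200=2000, 100*3+1250=1550
need_max_token_num = (left_out_len_array * size_array + cum_run_len_array).max()
assert need_max_token_num == 2000
# Admission succeeds only if this worst case (plus frozen tokens) fits within
# max_total_tokens; otherwise the request stays in the waiting queue.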
