diff --git a/benchmarks/benchmark_throughput_async.py b/benchmarks/benchmark_throughput_async.py
new file mode 100644
index 0000000000000..0b9c2e16a3706
--- /dev/null
+++ b/benchmarks/benchmark_throughput_async.py
@@ -0,0 +1,479 @@
+"""Benchmark offline inference throughput."""
+import argparse
+import json
+import random
+import time
+from typing import List, Optional, Tuple
+
+import torch
+import uvloop
+from tqdm import tqdm
+from transformers import (AutoModelForCausalLM, AutoTokenizer,
+                          PreTrainedTokenizerBase)
+
+from vllm.entrypoints.openai.api_server import build_async_engine_client_from_engine_args
+from vllm.utils import merge_async_iterators
+from vllm.engine.arg_utils import EngineArgs, AsyncEngineArgs
+from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
+from vllm.utils import FlexibleArgumentParser
+
+
+def sample_requests(
+    dataset_path: str,
+    num_requests: int,
+    tokenizer: PreTrainedTokenizerBase,
+    fixed_output_len: Optional[int],
+) -> List[Tuple[str, int, int]]:
+    if fixed_output_len is not None and fixed_output_len < 4:
+        raise ValueError("output_len too small")
+
+    # Load the dataset.
+    with open(dataset_path) as f:
+        dataset = json.load(f)
+    # Filter out the conversations with less than 2 turns.
+    dataset = [data for data in dataset if len(data["conversations"]) >= 2]
+    # Only keep the first two turns of each conversation.
+    dataset = [(data["conversations"][0]["value"],
+                data["conversations"][1]["value"]) for data in dataset]
+
+    # Shuffle the dataset.
+    random.shuffle(dataset)
+
+    # Filter out sequences that are too long or too short
+    filtered_dataset: List[Tuple[str, int, int]] = []
+    for i in range(len(dataset)):
+        if len(filtered_dataset) == num_requests:
+            break
+
+        # Tokenize the prompts and completions.
+        prompt = dataset[i][0]
+        prompt_token_ids = tokenizer(prompt).input_ids
+        completion = dataset[i][1]
+        completion_token_ids = tokenizer(completion).input_ids
+        prompt_len = len(prompt_token_ids)
+        output_len = len(completion_token_ids
+                         ) if fixed_output_len is None else fixed_output_len
+        if prompt_len < 4 or output_len < 4:
+            # Prune too short sequences.
+            continue
+        if prompt_len > 1024 or prompt_len + output_len > 2048:
+            # Prune too long sequences.
+            continue
+        filtered_dataset.append((prompt, prompt_len, output_len))
+
+    return filtered_dataset
+
+
+async def run_vllm(
+    requests: List[Tuple[str, int, int]],
+    model: str,
+    tokenizer: str,
+    quantization: Optional[str],
+    tensor_parallel_size: int,
+    seed: int,
+    n: int,
+    use_beam_search: bool,
+    trust_remote_code: bool,
+    dtype: str,
+    max_model_len: Optional[int],
+    enforce_eager: bool,
+    kv_cache_dtype: str,
+    quantization_param_path: Optional[str],
+    device: str,
+    enable_prefix_caching: bool,
+    enable_chunked_prefill: bool,
+    max_num_batched_tokens: int,
+    distributed_executor_backend: Optional[str],
+    gpu_memory_utilization: float = 0.9,
+    num_scheduler_steps: int = 1,
+    use_v2_block_manager: bool = False,
+    download_dir: Optional[str] = None,
+    load_format: str = EngineArgs.load_format,
+    disable_async_output_proc: bool = False,
+) -> float:
+    from vllm import LLM, SamplingParams
+    engine_args = AsyncEngineArgs(
+        model=model,
+        tokenizer=tokenizer,
+        quantization=quantization,
+        tensor_parallel_size=tensor_parallel_size,
+        seed=seed,
+        trust_remote_code=trust_remote_code,
+        dtype=dtype,
+        max_model_len=max_model_len,
+        gpu_memory_utilization=gpu_memory_utilization,
+        enforce_eager=enforce_eager,
+        kv_cache_dtype=kv_cache_dtype,
+        quantization_param_path=quantization_param_path,
+        device=device,
+        enable_prefix_caching=enable_prefix_caching,
+        download_dir=download_dir,
+        enable_chunked_prefill=enable_chunked_prefill,
+        max_num_batched_tokens=max_num_batched_tokens,
+        distributed_executor_backend=distributed_executor_backend,
+        load_format=load_format,
+        num_scheduler_steps=num_scheduler_steps,
+        use_v2_block_manager=use_v2_block_manager,
+        disable_async_output_proc=disable_async_output_proc,
+        worker_use_ray=False,
+        engine_use_ray=False,
+        disable_log_requests=True,
+    )
+
+    decoupled = True
+
+    async with build_async_engine_client_from_engine_args(engine_args,
+                                                          not decoupled) as llm:
+
+        # Add the requests to the engine.
+        prompts: List[str] = []
+        sampling_params: List[SamplingParams] = []
+        for prompt, _, output_len in requests:
+            prompts.append(prompt)
+            sampling_params.append(
+                SamplingParams(
+                    n=n,
+                    temperature=0.0 if use_beam_search else 1.0,
+                    top_p=1.0,
+                    use_beam_search=use_beam_search,
+                    ignore_eos=True,
+                    max_tokens=output_len,
+                ))
+
+        generators = []
+        start = time.perf_counter()
+        for i, (prompt, sp) in enumerate(zip(prompts, sampling_params)):
+            generator = llm.generate(prompt, sp, request_id=f"test{i}")
+            generators.append(generator)
+        all_gens = merge_async_iterators(*generators)
+        async for i, res in all_gens:
+            pass
+        end = time.perf_counter()
+        return end - start
+
+
+def run_hf(
+    requests: List[Tuple[str, int, int]],
+    model: str,
+    tokenizer: PreTrainedTokenizerBase,
+    n: int,
+    use_beam_search: bool,
+    max_batch_size: int,
+    trust_remote_code: bool,
+) -> float:
+    assert not use_beam_search
+    llm = AutoModelForCausalLM.from_pretrained(
+        model, torch_dtype=torch.float16, trust_remote_code=trust_remote_code)
+    if llm.config.model_type == "llama":
+        # To enable padding in the HF backend.
+        tokenizer.pad_token = tokenizer.eos_token
+    llm = llm.cuda()
+
+    pbar = tqdm(total=len(requests))
+    start = time.perf_counter()
+    batch: List[str] = []
+    max_prompt_len = 0
+    max_output_len = 0
+    for i in range(len(requests)):
+        prompt, prompt_len, output_len = requests[i]
+        # Add the prompt to the batch.
+        batch.append(prompt)
+        max_prompt_len = max(max_prompt_len, prompt_len)
+        max_output_len = max(max_output_len, output_len)
+        if len(batch) < max_batch_size and i != len(requests) - 1:
+            # Check if we can add more requests to the batch.
+            _, next_prompt_len, next_output_len = requests[i + 1]
+            if (max(max_prompt_len, next_prompt_len) +
+                    max(max_output_len, next_output_len)) <= 2048:
+                # We can add more requests to the batch.
+                continue
+
+        # Generate the sequences.
+        input_ids = tokenizer(batch, return_tensors="pt",
+                              padding=True).input_ids
+        llm_outputs = llm.generate(
+            input_ids=input_ids.cuda(),
+            do_sample=not use_beam_search,
+            num_return_sequences=n,
+            temperature=1.0,
+            top_p=1.0,
+            use_cache=True,
+            max_new_tokens=max_output_len,
+        )
+        # Include the decoding time.
+        tokenizer.batch_decode(llm_outputs, skip_special_tokens=True)
+        pbar.update(len(batch))
+
+        # Clear the batch.
+        batch = []
+        max_prompt_len = 0
+        max_output_len = 0
+    end = time.perf_counter()
+    return end - start
+
+
+def run_mii(
+    requests: List[Tuple[str, int, int]],
+    model: str,
+    tensor_parallel_size: int,
+    output_len: int,
+) -> float:
+    from mii import client, serve
+    llm = serve(model, tensor_parallel=tensor_parallel_size)
+    prompts = [prompt for prompt, _, _ in requests]
+
+    start = time.perf_counter()
+    llm.generate(prompts, max_new_tokens=output_len)
+    end = time.perf_counter()
+    client = client(model)
+    client.terminate_server()
+    return end - start
+
+
+def main(args: argparse.Namespace):
+    print(args)
+    random.seed(args.seed)
+
+    # Sample the requests.
+    tokenizer = AutoTokenizer.from_pretrained(
+        args.tokenizer, trust_remote_code=args.trust_remote_code)
+    if args.dataset is None:
+        # Synthesize a prompt with the given input length.
+        prompt = "hi" * (args.input_len - 1)
+        requests = [(prompt, args.input_len, args.output_len)
+                    for _ in range(args.num_prompts)]
+    else:
+        requests = sample_requests(args.dataset, args.num_prompts, tokenizer,
+                                   args.output_len)
+
+    if args.backend == "vllm":
+        coro = run_vllm(
+            requests, args.model, args.tokenizer, args.quantization,
+            args.tensor_parallel_size, args.seed, args.n, args.use_beam_search,
+            args.trust_remote_code, args.dtype, args.max_model_len,
+            args.enforce_eager, args.kv_cache_dtype,
+            args.quantization_param_path, args.device,
+            args.enable_prefix_caching, args.enable_chunked_prefill,
+            args.max_num_batched_tokens, args.distributed_executor_backend,
+            args.gpu_memory_utilization, args.num_scheduler_steps,
+            args.use_v2_block_manager, args.download_dir, args.load_format,
+            args.disable_async_output_proc)
+
+        elapsed_time = uvloop.run(coro)
+    elif args.backend == "hf":
+        assert args.tensor_parallel_size == 1
+        elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
+                              args.use_beam_search, args.hf_max_batch_size,
+                              args.trust_remote_code)
+    elif args.backend == "mii":
+        elapsed_time = run_mii(requests, args.model, args.tensor_parallel_size,
+                               args.output_len)
+    else:
+        raise ValueError(f"Unknown backend: {args.backend}")
+    total_num_tokens = sum(prompt_len + output_len
+                           for _, prompt_len, output_len in requests)
+    print(f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, "
+          f"{total_num_tokens / elapsed_time:.2f} tokens/s")
+
+    # Output JSON results if specified
+    if args.output_json:
+        results = {
+            "elapsed_time": elapsed_time,
+            "num_requests": len(requests),
+            "total_num_tokens": total_num_tokens,
+            "requests_per_second": len(requests) / elapsed_time,
+            "tokens_per_second": total_num_tokens / elapsed_time,
+        }
+        with open(args.output_json, "w") as f:
+            json.dump(results, f, indent=4)
+
+
+if __name__ == "__main__":
+    parser = FlexibleArgumentParser(description="Benchmark the throughput.")
+    parser.add_argument("--backend",
+                        type=str,
+                        choices=["vllm", "hf", "mii"],
+                        default="vllm")
+    parser.add_argument("--dataset",
+                        type=str,
+                        default=None,
+                        help="Path to the dataset.")
+    parser.add_argument("--input-len",
+                        type=int,
+                        default=None,
+                        help="Input prompt length for each request")
+    parser.add_argument("--output-len",
+                        type=int,
+                        default=None,
+                        help="Output length for each request. Overrides the "
+                        "output length from the dataset.")
+    parser.add_argument("--model", type=str, default="facebook/opt-125m")
+    parser.add_argument("--tokenizer", type=str, default=None)
+    parser.add_argument('--quantization',
+                        '-q',
+                        choices=[*QUANTIZATION_METHODS, None],
+                        default=None)
+    parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1)
+    parser.add_argument("--n",
+                        type=int,
+                        default=1,
+                        help="Number of generated sequences per prompt.")
+    parser.add_argument("--use-beam-search", action="store_true")
+    parser.add_argument("--num-prompts",
+                        type=int,
+                        default=1000,
+                        help="Number of prompts to process.")
+    parser.add_argument("--seed", type=int, default=0)
+    parser.add_argument("--hf-max-batch-size",
+                        type=int,
+                        default=None,
+                        help="Maximum batch size for HF backend.")
+    parser.add_argument('--trust-remote-code',
+                        action='store_true',
+                        help='trust remote code from huggingface')
+    parser.add_argument(
+        '--max-model-len',
+        type=int,
+        default=None,
+        help='Maximum length of a sequence (including prompt and output). '
+        'If None, will be derived from the model.')
+    parser.add_argument(
+        '--dtype',
+        type=str,
+        default='auto',
+        choices=['auto', 'half', 'float16', 'bfloat16', 'float', 'float32'],
+        help='data type for model weights and activations. '
+        'The "auto" option will use FP16 precision '
+        'for FP32 and FP16 models, and BF16 precision '
+        'for BF16 models.')
+    parser.add_argument('--gpu-memory-utilization',
+                        type=float,
+                        default=0.9,
+                        help='the fraction of GPU memory to be used for '
+                        'the model executor, which can range from 0 to 1.'
+                        'If unspecified, will use the default value of 0.9.')
+    parser.add_argument("--enforce-eager",
+                        action="store_true",
+                        help="enforce eager execution")
+    parser.add_argument(
+        '--kv-cache-dtype',
+        type=str,
+        choices=['auto', 'fp8', 'fp8_e5m2', 'fp8_e4m3'],
+        default="auto",
+        help='Data type for kv cache storage. If "auto", will use model '
+        'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. '
+        'ROCm (AMD GPU) supports fp8 (=fp8_e4m3)')
+    parser.add_argument(
+        '--quantization-param-path',
+        type=str,
+        default=None,
+        help='Path to the JSON file containing the KV cache scaling factors. '
+        'This should generally be supplied, when KV cache dtype is FP8. '
+        'Otherwise, KV cache scaling factors default to 1.0, which may cause '
+        'accuracy issues. FP8_E5M2 (without scaling) is only supported on '
+        'cuda version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is '
+        'instead supported for common inference criteria.')
+    parser.add_argument(
+        "--device",
+        type=str,
+        default="auto",
+        choices=["auto", "cuda", "cpu", "openvino", "tpu", "xpu"],
+        help='device type for vLLM execution, supporting CUDA, OpenVINO and '
+        'CPU.')
+    parser.add_argument(
+        "--num-scheduler-steps",
+        type=int,
+        default=1,
+        help="Maximum number of forward steps per scheduler call.")
+    parser.add_argument("--use-v2-block-manager",
+                        action='store_true',
+                        help="Enable block manager v2.")
+    parser.add_argument(
+        "--enable-prefix-caching",
+        action='store_true',
+        help="Enable automatic prefix caching for vLLM backend.")
+    parser.add_argument("--enable-chunked-prefill",
+                        action='store_true',
+                        help="enable chunked prefill for vLLM backend.")
+    parser.add_argument('--max-num-batched-tokens',
+                        type=int,
+                        default=None,
+                        help='maximum number of batched tokens per '
+                        'iteration')
+    parser.add_argument('--download-dir',
+                        type=str,
+                        default=None,
+                        help='directory to download and load the weights, '
+                        'default to the default cache dir of huggingface')
+    parser.add_argument(
+        '--output-json',
+        type=str,
+        default=None,
+        help='Path to save the throughput results in JSON format.')
+    parser.add_argument(
+        '--distributed-executor-backend',
+        choices=['ray', 'mp'],
+        default=None,
+        help='Backend to use for distributed serving. When more than 1 GPU '
+        'is used, will be automatically set to "ray" if installed '
+        'or "mp" (multiprocessing) otherwise.')
+    parser.add_argument(
+        '--load-format',
+        type=str,
+        default=EngineArgs.load_format,
+        choices=[
+            'auto', 'pt', 'safetensors', 'npcache', 'dummy', 'tensorizer',
+            'bitsandbytes'
+        ],
+        help='The format of the model weights to load.\n\n'
+        '* "auto" will try to load the weights in the safetensors format '
+        'and fall back to the pytorch bin format if safetensors format '
+        'is not available.\n'
+        '* "pt" will load the weights in the pytorch bin format.\n'
+        '* "safetensors" will load the weights in the safetensors format.\n'
+        '* "npcache" will load the weights in pytorch format and store '
+        'a numpy cache to speed up the loading.\n'
+        '* "dummy" will initialize the weights with random values, '
+        'which is mainly for profiling.\n'
+        '* "tensorizer" will load the weights using tensorizer from '
+        'CoreWeave. See the Tensorize vLLM Model script in the Examples '
+        'section for more information.\n'
+        '* "bitsandbytes" will load the weights using bitsandbytes '
+        'quantization.\n')
+    parser.add_argument(
+        "--disable-async-output-proc",
+        action='store_true',
+        default=False,
+        help="Disable async output processor for vLLM backend.")
+    args = parser.parse_args()
+    if args.tokenizer is None:
+        args.tokenizer = args.model
+    if args.dataset is None:
+        assert args.input_len is not None
+        assert args.output_len is not None
+    else:
+        assert args.input_len is None
+
+    if args.backend == "vllm":
+        if args.hf_max_batch_size is not None:
+            raise ValueError("HF max batch size is only for HF backend.")
+    elif args.backend == "hf":
+        if args.hf_max_batch_size is None:
+            raise ValueError("HF max batch size is required for HF backend.")
+        if args.quantization is not None:
+            raise ValueError("Quantization is only for vLLM backend.")
+    elif args.backend == "mii":
+        if args.dtype != "auto":
+            raise ValueError("dtype must be auto for MII backend.")
+        if args.n != 1:
+            raise ValueError("n must be 1 for MII backend.")
+        if args.use_beam_search:
+            raise ValueError("Beam search is not supported for MII backend.")
+        if args.quantization is not None:
+            raise ValueError("Quantization is only for vLLM backend.")
+        if args.hf_max_batch_size is not None:
+            raise ValueError("HF max batch size is only for HF backend.")
+        if args.tokenizer != args.model:
+            raise ValueError("Tokenizer must be the same as the model for MII "
+                             "backend.")
+    main(args)
diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py
index 8e8371ef1559a..e99e4bd951089 100644
--- a/vllm/entrypoints/openai/api_server.py
+++ b/vllm/entrypoints/openai/api_server.py
@@ -96,6 +96,22 @@ async def _force_log():
 @asynccontextmanager
 async def build_async_engine_client(
         args: Namespace) -> AsyncIterator[Optional[AsyncEngineClient]]:
+
+    # Context manager to handle async_engine_client lifecycle
+    # Ensures everything is shutdown and cleaned up on error/exit
+    global engine_args
+    engine_args = AsyncEngineArgs.from_cli_args(args)
+
+    async with build_async_engine_client_from_engine_args(
+            engine_args, args.disable_frontend_multiprocessing) as engine:
+        yield engine
+
+
+@asynccontextmanager
+async def build_async_engine_client_from_engine_args(
+    engine_args: AsyncEngineArgs,
+    disable_frontend_multiprocessing: bool = False,
+) -> AsyncIterator[Optional[AsyncEngineClient]]:
     """
     Create AsyncEngineClient, either:
         - in-process using the AsyncLLMEngine Directly
@@ -104,22 +120,21 @@ async def build_async_engine_client(
 
     Returns the Client or None if the creation failed.
     """
-    # Context manager to handle async_engine_client lifecycle
-    # Ensures everything is shutdown and cleaned up on error/exit
-    global engine_args
-    engine_args = AsyncEngineArgs.from_cli_args(args)
-
     # Backend itself still global for the silly lil' health handler
     global async_engine_client
 
     # If manually triggered or embedding model, use AsyncLLMEngine in process.
     # TODO: support embedding model via RPC.
-    if (model_is_embedding(args.model, args.trust_remote_code,
-                           args.quantization)
-            or args.disable_frontend_multiprocessing):
+    if (model_is_embedding(engine_args.model, engine_args.trust_remote_code,
+                           engine_args.quantization)
+            or disable_frontend_multiprocessing):
         async_engine_client = AsyncLLMEngine.from_engine_args(
             engine_args, usage_context=UsageContext.OPENAI_API_SERVER)
-        yield async_engine_client
+        try:
+            yield async_engine_client
+        finally:
+            async_engine_client.shutdown_background_loop()
+            async_engine_client = None  #TODO
         return
 
     # Otherwise, use the multiprocessing AsyncLLMEngine.
@@ -192,6 +207,8 @@ async def build_async_engine_client(
             from prometheus_client import multiprocess
             multiprocess.mark_process_dead(rpc_server_process.pid)
 
+            async_engine_client = None  #TODO
+
 
 router = APIRouter()
 
diff --git a/vllm/entrypoints/openai/rpc/client.py b/vllm/entrypoints/openai/rpc/client.py
index a472e12e8ca48..c457555c54b9c 100644
--- a/vllm/entrypoints/openai/rpc/client.py
+++ b/vllm/entrypoints/openai/rpc/client.py
@@ -1,11 +1,13 @@
 import asyncio
+import pickle
 from contextlib import contextmanager, suppress
-from typing import Any, AsyncGenerator, Mapping, Optional
+from typing import Any, AsyncGenerator, Iterator, Mapping, Optional
 from uuid import uuid4
 
 import cloudpickle
 import zmq
 import zmq.asyncio
+from zmq.asyncio import Socket
 
 from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig,
                          ParallelConfig, SchedulerConfig)
@@ -115,18 +117,21 @@ def __init__(self, rpc_path: str):
         self.context.set(zmq.constants.MAX_SOCKETS, socket_limit)
 
         # IPC connection to RPC Server (uses unix sockets).
-        self.to_rpc_server = self.context.socket(zmq.constants.DEALER)
+        self.to_rpc_server: Socket = self.context.socket(zmq.constants.DEALER)
         self.to_rpc_server.set_hwm(VLLM_RPC_ZMQ_HWM)
         self.to_rpc_server.bind(rpc_path)
 
         # In process proxy to RPC Server (uses memory-based messaging).
-        self.from_api_server = self.context.socket(zmq.constants.ROUTER)
+        self.from_api_server: Socket = self.context.socket(
+            zmq.constants.ROUTER)
         self.from_api_server.set_hwm(VLLM_RPC_ZMQ_HWM)
         self.from_api_server.bind(INPROC_PROXY_PATH)
 
         # Asyncio background task for the proxy.
-        self.proxy_task = asyncio.create_task(
+        self.proxy_in_task = asyncio.create_task(
             self.run_proxy(self.from_api_server, self.to_rpc_server))
+        self.proxy_out_task = asyncio.create_task(
+            self.run_proxy(self.to_rpc_server, self.from_api_server))
 
         # Since we open 1 inproc socket per request, we have a hard cap on
         # the number of requests that can run in vLLM w. frontend
@@ -136,20 +141,11 @@ def __init__(self, rpc_path: str):
         # 1 for generate(), 1 for abort(), do_log_stats(), check_health()
         self.limit_concurrency = socket_limit // 2 - 2
 
-    async def run_proxy(self, socket_from, socket_to):
+    async def run_proxy(self, socket_from: Socket, socket_to: Socket):
         """Background task that runs a proxy"""
-        poller = zmq.asyncio.Poller()
-        poller.register(socket_from, zmq.constants.POLLIN)
-        poller.register(socket_to, zmq.constants.POLLIN)
         while True:
-            events_lst = await poller.poll()
-            events = dict(events_lst)
-            if socket_from in events:
-                identity, msg = await socket_from.recv_multipart()
-                await socket_to.send_multipart([identity, msg])
-            if socket_to in events:
-                identity, msg = await socket_to.recv_multipart()
-                await socket_from.send_multipart([identity, msg])
+            frames = await socket_from.recv_multipart(copy=False)
+            await socket_to.send_multipart(frames, copy=False)
 
     async def setup(self):
         """Setup the client before it starts sending server requests."""
@@ -180,7 +176,7 @@ def close(self):
         self.context.destroy()
 
     @contextmanager
-    def to_proxy_socket(self):
+    def to_proxy_socket(self) -> Iterator[Socket]:
         # Connect to the RPCServer via the proxy.
 
         # Raise a sensible error if the client was already closed.
@@ -208,7 +204,8 @@ async def _send_get_data_rpc_request(self, request: RPCUtilityRequest,
 
         with self.to_proxy_socket() as socket:
             # Ping RPCServer with a request.
-            await socket.send_multipart([cloudpickle.dumps(request)])
+            await socket.send_multipart((cloudpickle.dumps(request), ),
+                                        copy=False)
 
             # Make sure the server responds
             if await socket.poll(timeout=self._data_timeout) == 0:
@@ -216,7 +213,8 @@ async def _send_get_data_rpc_request(self, request: RPCUtilityRequest,
                                    f"{self._data_timeout} ms")
 
             # Await the data from the Server.
-            data = cloudpickle.loads(await socket.recv())
+            frame = await socket.recv(copy=False)
+            data = pickle.loads(frame.buffer)
 
         if isinstance(data, Exception):
             # Re-raise exceptions returned by the server
@@ -234,23 +232,22 @@ async def _send_get_data_rpc_request(self, request: RPCUtilityRequest,
 
         return data
 
-    async def _send_one_way_rpc_request(
-            self,
-            request: RPC_REQUEST_TYPE,
-            error_message: str,
-            socket: Optional[zmq.asyncio.Socket] = None):
+    async def _send_one_way_rpc_request(self,
+                                        request: RPC_REQUEST_TYPE,
+                                        error_message: str,
+                                        socket: Optional[Socket] = None):
         """Send one-way RPC request to trigger an action."""
 
-        async def do_rpc_call(socket: zmq.asyncio.Socket,
-                              request: RPC_REQUEST_TYPE):
+        async def do_rpc_call(socket: Socket, request: RPC_REQUEST_TYPE):
 
-            await socket.send_multipart([cloudpickle.dumps(request)])
+            await socket.send_multipart((cloudpickle.dumps(request), ))
 
             if await socket.poll(timeout=self._data_timeout) == 0:
                 raise TimeoutError("Server didn't reply within "
                                    f"{self._data_timeout} ms")
 
-            return cloudpickle.loads(await socket.recv())
+            frame = await socket.recv(copy=False)
+            return pickle.loads(frame.buffer)
 
         # Make a new socket connection.
         if socket is None:
@@ -386,21 +383,19 @@ async def generate(
         try:
             with self.to_proxy_socket() as socket:
                 # Send RPCGenerateRequest to the RPCServer.
-                await socket.send_multipart([
-                    cloudpickle.dumps(
-                        RPCGenerateRequest(
-                            inputs=inputs,
-                            sampling_params=sampling_params,
-                            request_id=request_id,
-                            lora_request=lora_request,
-                            trace_headers=trace_headers,
-                            prompt_adapter_request=prompt_adapter_request))
-                ])
+                await socket.send_multipart((cloudpickle.dumps(
+                    RPCGenerateRequest(
+                        inputs=inputs,
+                        sampling_params=sampling_params,
+                        request_id=request_id,
+                        lora_request=lora_request,
+                        trace_headers=trace_headers,
+                        prompt_adapter_request=prompt_adapter_request)), ))
 
                 # Stream back the results from the RPC Server.
                 while not finished:
-                    message = await socket.recv()
-                    request_output = cloudpickle.loads(message)
+                    message = await socket.recv(copy=False)
+                    request_output = pickle.loads(message.buffer)
 
                     if isinstance(request_output, Exception):
                         # On exception, check if the server is still healthy
@@ -424,9 +419,7 @@ async def generate(
             if not finished and not self._errored:
                 await self.abort(request_id)
 
-    async def check_health(self,
-                           socket: Optional[zmq.asyncio.Socket] = None
-                           ) -> None:
+    async def check_health(self, socket: Optional[Socket] = None) -> None:
         """Raise if unhealthy"""
 
         await self._send_one_way_rpc_request(
@@ -451,4 +444,4 @@ async def stop_profile(self) -> None:
 
         await self._send_one_way_rpc_request(
             request=RPCUtilityRequest.STOP_PROFILE,
-            error_message="RPCRequest STOP_PROFILE failed.")
\ No newline at end of file
+            error_message="RPCRequest STOP_PROFILE failed.")
diff --git a/vllm/entrypoints/openai/rpc/server.py b/vllm/entrypoints/openai/rpc/server.py
index 738d12bbef051..d0d52e9fd9c1c 100644
--- a/vllm/entrypoints/openai/rpc/server.py
+++ b/vllm/entrypoints/openai/rpc/server.py
@@ -1,4 +1,5 @@
 import asyncio
+import pickle
 import signal
 from typing import Any, Coroutine, Union
 
@@ -7,6 +8,7 @@
 import zmq
 import zmq.asyncio
 from typing_extensions import Never
+from zmq.asyncio import Socket
 
 from vllm import AsyncEngineArgs, AsyncLLMEngine
 from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig,
@@ -35,7 +37,7 @@ def __init__(self, async_engine_args: AsyncEngineArgs,
         self.context = zmq.asyncio.Context()
 
         # Init socket.
-        self.socket = self.context.socket(zmq.constants.DEALER)
+        self.socket: Socket = self.context.socket(zmq.constants.DEALER)
         self.socket.set_hwm(VLLM_RPC_ZMQ_HWM)
         self.socket.connect(rpc_path)
 
@@ -63,30 +65,31 @@ async def get_config(self, identity, request):
             else:
                 raise ValueError("Unknown Config Request: %s", request)
 
-            await self.socket.send_multipart(
-                [identity, cloudpickle.dumps(config)])
+            await self.socket.send_multipart((identity, pickle.dumps(config)),
+                                             copy=False)
 
         except Exception as e:
-            await self.socket.send_multipart([identity, cloudpickle.dumps(e)])
+            await self.socket.send_multipart((identity, pickle.dumps(e)),
+                                             copy=False)
 
     async def is_tracing_enabled(self, identity):
         """Send the is_tracing_enabled flag"""
         tracing_flag = await self.engine.is_tracing_enabled()
 
         await self.socket.send_multipart(
-            [identity, cloudpickle.dumps(tracing_flag)])
+            (identity, pickle.dumps(tracing_flag)))
 
     async def do_log_stats(self, identity):
         """Log stats and confirm success."""
         await self.engine.do_log_stats()
 
         await self.socket.send_multipart(
-            [identity, cloudpickle.dumps(VLLM_RPC_SUCCESS_STR)])
+            (identity, pickle.dumps(VLLM_RPC_SUCCESS_STR)))
 
     async def is_server_ready(self, identity):
         """Notify the client that we are ready."""
         await self.socket.send_multipart(
-            [identity, cloudpickle.dumps(VLLM_RPC_SUCCESS_STR)])
+            (identity, pickle.dumps(VLLM_RPC_SUCCESS_STR)))
 
     async def abort(self, identity, request: RPCAbortRequest):
         """Abort request and notify the client of success."""
@@ -96,7 +99,7 @@ async def abort(self, identity, request: RPCAbortRequest):
             result: Union[str, Exception] = VLLM_RPC_SUCCESS_STR
         except Exception as e:
             result = e
-        await self.socket.send_multipart([identity, cloudpickle.dumps(result)])
+        await self.socket.send_multipart((identity, pickle.dumps(result)))
 
     async def generate(self, identity, generate_request: RPCGenerateRequest):
         try:
@@ -110,39 +113,41 @@ async def generate(self, identity, generate_request: RPCGenerateRequest):
 
             async for request_output in results_generator:
                 await self.socket.send_multipart(
-                    [identity, cloudpickle.dumps(request_output)])
+                    (identity, pickle.dumps(request_output)), copy=False)
 
         except Exception as e:
-            await self.socket.send_multipart([identity, cloudpickle.dumps(e)])
+            await self.socket.send_multipart((identity, pickle.dumps(e)),
+                                             copy=False)
 
     async def check_health(self, identity):
         try:
             await self.engine.check_health()
             await self.socket.send_multipart(
-                [identity, cloudpickle.dumps(VLLM_RPC_SUCCESS_STR)])
+                (identity, pickle.dumps(VLLM_RPC_SUCCESS_STR)))
 
         except Exception as e:
-            await self.socket.send_multipart([identity, cloudpickle.dumps(e)])
+            await self.socket.send_multipart((identity, pickle.dumps(e)),
+                                             copy=False)
 
     async def start_profile(self, identity):
         logger.info("Starting profiler...")
         await self.engine.start_profile()
         logger.info("Profiler started.")
 
-        await self.socket.send_multipart([
+        await self.socket.send_multipart((
             identity,
-            cloudpickle.dumps(VLLM_RPC_SUCCESS_STR),
-        ])
+            pickle.dumps(VLLM_RPC_SUCCESS_STR),
+        ))
 
     async def stop_profile(self, identity):
         logger.info("Stopping profiler...")
         await self.engine.stop_profile()
         logger.info("Profiler stopped.")
 
-        await self.socket.send_multipart([
+        await self.socket.send_multipart((
             identity,
-            cloudpickle.dumps(VLLM_RPC_SUCCESS_STR),
-        ])
+            pickle.dumps(VLLM_RPC_SUCCESS_STR),
+        ))
 
     def _make_handler_coro(self, identity,
                            message) -> Coroutine[Any, Any, Never]:
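
The benchmark's run_vllm coroutine fans every llm.generate() stream into a single consumer loop through vllm.utils.merge_async_iterators. For readers who want to see the shape of that fan-in pattern outside vLLM, here is a minimal, stdlib-only sketch; merge_generators, drain, fake_stream, and demo are illustrative names for this sketch, not vLLM APIs.

```python
import asyncio
from typing import Any, AsyncIterator, Tuple


async def merge_generators(
        *gens: AsyncIterator[Any]) -> AsyncIterator[Tuple[int, Any]]:
    """Yield (producer_index, item) pairs as soon as any producer has one."""
    queue: asyncio.Queue = asyncio.Queue()
    sentinel = object()

    async def drain(idx: int, gen: AsyncIterator[Any]) -> None:
        # Copy one producer's items into the shared queue, then signal done.
        async for item in gen:
            await queue.put((idx, item))
        await queue.put((idx, sentinel))

    tasks = [asyncio.create_task(drain(i, g)) for i, g in enumerate(gens)]
    remaining = len(tasks)
    while remaining:
        idx, item = await queue.get()
        if item is sentinel:
            remaining -= 1
            continue
        yield idx, item


async def demo() -> None:
    async def fake_stream(tag: str, n: int) -> AsyncIterator[str]:
        # Stand-in for the incremental RequestOutput stream of llm.generate().
        for step in range(n):
            await asyncio.sleep(0)
            yield f"{tag}-{step}"

    async for idx, out in merge_generators(fake_stream("a", 3),
                                           fake_stream("b", 2)):
        print(idx, out)


if __name__ == "__main__":
    asyncio.run(demo())
```

As in run_vllm, the benchmark only drains the merged stream (async for ... pass) and times the wall clock around it, so throughput reflects the engine rather than any per-token post-processing.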