From 4934eeb0682a8842eb7944356a0f42ff96f89101 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Thu, 12 Dec 2024 21:11:52 +0000 Subject: [PATCH] added ipc --- vllm/engine/multiprocessing/__init__.py | 21 +---------- vllm/engine/multiprocessing/client.py | 19 +++++----- vllm/engine/multiprocessing/engine.py | 11 +++--- vllm/engine/multiprocessing/ipc.py | 50 +++++++++++++++++++++++++ 4 files changed, 67 insertions(+), 34 deletions(-) create mode 100644 vllm/engine/multiprocessing/ipc.py diff --git a/vllm/engine/multiprocessing/__init__.py b/vllm/engine/multiprocessing/__init__.py index eb3f66aa6bb17..af003bc6eff15 100644 --- a/vllm/engine/multiprocessing/__init__.py +++ b/vllm/engine/multiprocessing/__init__.py @@ -1,6 +1,3 @@ -import hashlib -import hmac -import secrets from dataclasses import dataclass from enum import Enum from typing import List, Mapping, Optional, Union, overload @@ -17,29 +14,13 @@ VLLM_RPC_SUCCESS_STR = "SUCCESS" -# TODO: switch to SECRET_KEY = secrets.token_bytes(16) -# and pass the SECRET_KEY to the background process. 
-SECRET_KEY = b"my_key" + IPC_INPUT_EXT = "_input_socket" IPC_OUTPUT_EXT = "_output_socket" IPC_HEALTH_EXT = "_health_socket" IPC_DATA_EXT = "_data_socket" -def sign(msg: bytes) -> bytes: - """Compute the HMAC digest of msg, given signing key `key`""" - return hmac.HMAC( - SECRET_KEY, - msg, - digestmod=hashlib.sha256, - ).digest() - - -def check_signed(sig: bytes, msg: bytes) -> bool: - correct_sig = sign(msg) - return hmac.compare_digest(sig, correct_sig) - - class MQEngineDeadError(RuntimeError): pass diff --git a/vllm/engine/multiprocessing/client.py b/vllm/engine/multiprocessing/client.py index c5e04c11e49e7..8eba0caca5c2b 100644 --- a/vllm/engine/multiprocessing/client.py +++ b/vllm/engine/multiprocessing/client.py @@ -21,13 +21,16 @@ # yapf: disable from vllm.engine.async_llm_engine import ( build_guided_decoding_logits_processor_async) -from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT, check_signed, +from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT, IPC_HEALTH_EXT, IPC_INPUT_EXT, IPC_OUTPUT_EXT, RPC_REQUEST_T, VLLM_RPC_SUCCESS_STR, RPCAbortRequest, RPCError, RPCProcessRequest, RPCStartupRequest, RPCStartupResponse, RPCUProfileRequest) +from vllm.engine.multiprocessing.ipc import (send_signed_async, + recv_signed_async) + from vllm.engine.protocol import EngineClient # yapf: enable from vllm.envs import VLLM_RPC_TIMEOUT @@ -192,9 +195,7 @@ async def run_output_handler_loop(self): ENGINE_DEAD_ERROR(self._errored_with)) return - sig, message = await self.output_socket.recv_multipart(copy=False) - if not check_signed(sig, message): - raise Exception + message = await recv_signed_async(self.output_socket) request_outputs = pickle.loads(message) is_error = isinstance(request_outputs, @@ -293,7 +294,7 @@ async def _send_get_data_rpc_request(request: RPCStartupRequest, """Send an RPC request that is expecting data back.""" # Ping RPCServer with a request. 
- await socket.send_multipart((pickle.dumps(request), ), copy=False) + await send_signed_async(socket, pickle.dumps(request)) # Make sure the server responds in time. if await socket.poll(timeout=VLLM_RPC_TIMEOUT) == 0: @@ -301,7 +302,7 @@ async def _send_get_data_rpc_request(request: RPCStartupRequest, f"{VLLM_RPC_TIMEOUT} ms") # Await the data from the Server. - frame = await socket.recv(copy=False) + frame = await recv_signed_async(socket) data = pickle.loads(frame.buffer) if isinstance(data, BaseException): @@ -319,7 +320,7 @@ async def _send_one_way_rpc_request(request: RPC_REQUEST_T, if socket.closed: raise MQClientClosedError() - await socket.send_multipart((pickle.dumps(request), )) + await send_signed_async(socket, pickle.dumps(request)) async def _await_ack(self, error_message: str, socket: Socket): """Await acknowledgement that a request succeeded.""" @@ -340,7 +341,7 @@ async def _check_success(error_message: str, socket: Socket): if socket.closed: raise MQClientClosedError() - frame = await socket.recv(copy=False) + frame = await recv_signed_async(socket, error_message) response = pickle.loads(frame.buffer) # Raise error if unsuccessful @@ -628,7 +629,7 @@ async def _process_request( # 3) Send the RPCGenerateRequest to the MQLLMEngine. parts = (request_bytes, lp_bytes) if lp_bytes else (request_bytes, ) - await self.input_socket.send_multipart(parts, copy=False) + await send_signed_async(self.input_socket, parts) # 4) Stream the RequestOutputs from the output queue. 
Note # that the output_loop pushes RequestOutput objects to this diff --git a/vllm/engine/multiprocessing/engine.py b/vllm/engine/multiprocessing/engine.py index 29e6b69cd9aba..c72207e4ade66 100644 --- a/vllm/engine/multiprocessing/engine.py +++ b/vllm/engine/multiprocessing/engine.py @@ -11,13 +11,15 @@ from vllm.engine.llm_engine import LLMEngine # yapf conflicts with isort for this block # yapf: disable -from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, sign, IPC_DATA_EXT, +from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT, IPC_HEALTH_EXT, IPC_INPUT_EXT, IPC_OUTPUT_EXT, REQUEST_OUTPUTS_T, VLLM_RPC_SUCCESS_STR, RPCAbortRequest, RPCError, RPCProcessRequest, RPCStartupRequest, RPCStartupResponse, RPCUProfileRequest) +from vllm.engine.multiprocessing.ipc import send + # yapf: enable from vllm.executor.gpu_executor import GPUExecutor from vllm.logger import init_logger @@ -311,19 +313,18 @@ def _send_outputs(self, outputs: REQUEST_OUTPUTS_T): pass output_bytes = pickle.dumps(outputs) - sig = sign(output_bytes) - self.output_socket.send_multipart((sig, output_bytes), copy=False) + send(self.output_socket, output_bytes) def _send_healthy(self): """Send HEALTHY message to RPCClient.""" if not self.heartbeat_socket.closed: - self.heartbeat_socket.send_multipart(HEALTHY_RESPONSE, copy=False) + send(self.heartbeat_socket, HEALTHY_RESPONSE) def _send_unhealthy(self, error: BaseException): """Send UNHEALTHY message to RPCClient.""" if not self.heartbeat_socket.closed: error_bytes = pickle.dumps(error) - self.heartbeat_socket.send_multipart((error_bytes, ), copy=False) + send(self.heartbeat_socket, error_bytes) def _async_socket_engine_callback(self, request_outputs: REQUEST_OUTPUTS_T): diff --git a/vllm/engine/multiprocessing/ipc.py b/vllm/engine/multiprocessing/ipc.py new file mode 100644 index 0000000000000..e9193ab1040d9 --- /dev/null +++ b/vllm/engine/multiprocessing/ipc.py @@ -0,0 +1,50 @@ +import hashlib +import hmac +import secrets + 
+from typing import Optional, Sequence, Tuple, Union
+
+import zmq
+import zmq.asyncio
+
+# TODO: switch to SECRET_KEY = secrets.token_bytes(16)
+# and pass the SECRET_KEY to the background process.
+# NOTE: until then the key is a fixed constant that is visible in the
+# source, so the signatures below only catch corrupted or misrouted
+# frames; they do not authenticate the peer.
+SECRET_KEY = b"my_key"
+
+# A payload is a single bytes-like frame or a sequence of such frames.
+Payload = Union[bytes, Sequence[bytes]]
+
+
+def sign(msg: bytes) -> bytes:
+    """Compute the HMAC-SHA256 digest of `msg` keyed with `SECRET_KEY`."""
+    return hmac.HMAC(SECRET_KEY, msg, digestmod=hashlib.sha256).digest()
+
+
+def check_signed(sig: bytes, msg: bytes) -> bool:
+    """Return True if `sig` is the signature `sign` produces for `msg`.
+
+    The comparison goes through `hmac.compare_digest`.
+    """
+    return hmac.compare_digest(sig, sign(msg))
+
+
+def _as_parts(msg: Payload) -> Tuple[bytes, ...]:
+    """Normalize a payload to a non-empty tuple of frames.
+
+    Raises:
+        ValueError: if `msg` is an empty tuple or list.
+    """
+    parts = tuple(msg) if isinstance(msg, (tuple, list)) else (msg, )
+    if not parts:
+        raise ValueError("Cannot send an empty multipart message.")
+    return parts
+
+
+def _sign_parts(parts: Sequence[bytes]) -> bytes:
+    """Compute the signature covering one or more frames.
+
+    A single frame is signed exactly like `sign(frame)`, so single-frame
+    messages keep the `(sig, msg)` wire format. With several frames, each
+    one is fed to the HMAC preceded by its byte length, so that moving
+    bytes from one frame to its neighbour changes the signature.
+    """
+    if len(parts) == 1:
+        return sign(parts[0])
+    mac = hmac.new(SECRET_KEY, digestmod=hashlib.sha256)
+    for part in parts:
+        mac.update(memoryview(part).nbytes.to_bytes(8, "big"))
+        mac.update(part)
+    return mac.digest()
+
+
+def _verify(frames, error_message: Optional[str] = None):
+    """Validate the leading signature frame and return the payload.
+
+    Args:
+        frames: the list returned by `recv_multipart`; the first element
+            is the signature, the remaining ones are the payload frames.
+        error_message: optional caller context appended to the error text.
+
+    Returns:
+        The payload frame itself when there is exactly one, otherwise the
+        list of payload frames.
+
+    Raises:
+        ValueError: if there is no payload frame or the signature does
+            not match.
+    """
+    if len(frames) < 2:
+        raise ValueError(
+            "Expected a signature frame followed by at least one payload "
+            f"frame, got {len(frames)} frame(s).")
+    sig, *parts = frames
+    if not hmac.compare_digest(sig, _sign_parts(parts)):
+        detail = f" {error_message}" if error_message else ""
+        raise ValueError("Message signature is invalid." + detail)
+    return parts[0] if len(parts) == 1 else parts
+
+
+def send_signed(socket: zmq.Socket, msg: Payload):
+    """Send `msg` on `socket`, preceded by a signature frame.
+
+    `msg` may be a single frame or a tuple/list of frames.
+    """
+    parts = _as_parts(msg)
+    socket.send_multipart((_sign_parts(parts), *parts), copy=False)
+
+
+# engine.py imports this helper under the name `send`.
+send = send_signed
+
+
+def recv_signed(socket: zmq.Socket, error_message: Optional[str] = None):
+    """Receive a signed message from `socket` and return its payload.
+
+    Frames are received with `copy=False`, so the payload is returned as
+    the frame object(s) handed back by `recv_multipart`.
+
+    Raises:
+        ValueError: if the message is malformed or its signature is
+            invalid; `error_message`, when given, is appended to the text.
+    """
+    return _verify(socket.recv_multipart(copy=False), error_message)
+
+
+async def send_signed_async(socket: zmq.asyncio.Socket, msg: Payload):
+    """Send `msg` on an asyncio socket, preceded by a signature frame.
+
+    `msg` may be a single frame or a tuple/list of frames.
+    """
+    parts = _as_parts(msg)
+    await socket.send_multipart((_sign_parts(parts), *parts), copy=False)
+
+
+async def recv_signed_async(socket: zmq.asyncio.Socket,
+                            error_message: Optional[str] = None):
+    """Receive a signed message from an asyncio socket.
+
+    Frames are received with `copy=False`, so the payload is returned as
+    the frame object(s) handed back by `recv_multipart`.
+
+    Raises:
+        ValueError: if the message is malformed or its signature is
+            invalid; `error_message`, when given, is appended to the text.
+    """
+    return _verify(await socket.recv_multipart(copy=False), error_message)