Skip to content

Commit

Permalink
added ipc
Browse files Browse the repository at this point in the history
  • Loading branch information
robertgshaw2-redhat committed Dec 12, 2024
1 parent 0f3cf42 commit 4934eeb
Show file tree
Hide file tree
Showing 4 changed files with 67 additions and 34 deletions.
21 changes: 1 addition & 20 deletions vllm/engine/multiprocessing/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,3 @@
import hashlib
import hmac
import secrets
from dataclasses import dataclass
from enum import Enum
from typing import List, Mapping, Optional, Union, overload
Expand All @@ -17,29 +14,13 @@

VLLM_RPC_SUCCESS_STR = "SUCCESS"

# TODO: switch to SECRET_KEY = secrets.token_bytes(16)
# and pass the SECRET_KEY to the background process.
SECRET_KEY = b"my_key"

IPC_INPUT_EXT = "_input_socket"
IPC_OUTPUT_EXT = "_output_socket"
IPC_HEALTH_EXT = "_health_socket"
IPC_DATA_EXT = "_data_socket"


def sign(msg: bytes) -> bytes:
"""Compute the HMAC digest of msg, given signing key `key`"""
return hmac.HMAC(
SECRET_KEY,
msg,
digestmod=hashlib.sha256,
).digest()


def check_signed(sig: bytes, msg: bytes) -> bool:
correct_sig = sign(msg)
return hmac.compare_digest(sig, correct_sig)


class MQEngineDeadError(RuntimeError):
pass

Expand Down
19 changes: 10 additions & 9 deletions vllm/engine/multiprocessing/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,16 @@
# yapf: disable
from vllm.engine.async_llm_engine import (
build_guided_decoding_logits_processor_async)
from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT, check_signed,
from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT,
IPC_HEALTH_EXT, IPC_INPUT_EXT,
IPC_OUTPUT_EXT, RPC_REQUEST_T,
VLLM_RPC_SUCCESS_STR, RPCAbortRequest,
RPCError, RPCProcessRequest,
RPCStartupRequest, RPCStartupResponse,
RPCUProfileRequest)
from vllm.engine.multiprocessing.ipc import (send_signed_async,
recv_signed_async)

from vllm.engine.protocol import EngineClient
# yapf: enable
from vllm.envs import VLLM_RPC_TIMEOUT
Expand Down Expand Up @@ -192,9 +195,7 @@ async def run_output_handler_loop(self):
ENGINE_DEAD_ERROR(self._errored_with))
return

sig, message = await self.output_socket.recv_multipart(copy=False)
if not check_signed(sig, message):
raise Exception
message = await recv_signed_async(self.output_socket)
request_outputs = pickle.loads(message)

is_error = isinstance(request_outputs,
Expand Down Expand Up @@ -293,15 +294,15 @@ async def _send_get_data_rpc_request(request: RPCStartupRequest,
"""Send an RPC request that is expecting data back."""

# Ping RPCServer with a request.
await socket.send_multipart((pickle.dumps(request), ), copy=False)
await send_signed_async(socket, pickle.dumps(request))

# Make sure the server responds in time.
if await socket.poll(timeout=VLLM_RPC_TIMEOUT) == 0:
raise TimeoutError("RPCServer didn't reply within "
f"{VLLM_RPC_TIMEOUT} ms")

# Await the data from the Server.
frame = await socket.recv(copy=False)
frame = await recv_signed_async(socket)
data = pickle.loads(frame.buffer)

if isinstance(data, BaseException):
Expand All @@ -319,7 +320,7 @@ async def _send_one_way_rpc_request(request: RPC_REQUEST_T,
if socket.closed:
raise MQClientClosedError()

await socket.send_multipart((pickle.dumps(request), ))
await send_signed_async(socket, pickle.dumps(request))

async def _await_ack(self, error_message: str, socket: Socket):
"""Await acknowledgement that a request succeeded."""
Expand All @@ -340,7 +341,7 @@ async def _check_success(error_message: str, socket: Socket):
if socket.closed:
raise MQClientClosedError()

frame = await socket.recv(copy=False)
frame = await recv_signed_async(socket, error_message)
response = pickle.loads(frame.buffer)

# Raise error if unsuccessful
Expand Down Expand Up @@ -628,7 +629,7 @@ async def _process_request(
# 3) Send the RPCGenerateRequest to the MQLLMEngine.
parts = (request_bytes,
lp_bytes) if lp_bytes else (request_bytes, )
await self.input_socket.send_multipart(parts, copy=False)
await send_signed_async(self.input_socket, parts)

# 4) Stream the RequestOutputs from the output queue. Note
# that the output_loop pushes RequestOutput objects to this
Expand Down
11 changes: 6 additions & 5 deletions vllm/engine/multiprocessing/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,15 @@
from vllm.engine.llm_engine import LLMEngine
# yapf conflicts with isort for this block
# yapf: disable
from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, sign, IPC_DATA_EXT,
from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT,
IPC_HEALTH_EXT, IPC_INPUT_EXT,
IPC_OUTPUT_EXT, REQUEST_OUTPUTS_T,
VLLM_RPC_SUCCESS_STR, RPCAbortRequest,
RPCError, RPCProcessRequest,
RPCStartupRequest, RPCStartupResponse,
RPCUProfileRequest)
from vllm.engine.multiprocessing.ipc import send

# yapf: enable
from vllm.executor.gpu_executor import GPUExecutor
from vllm.logger import init_logger
Expand Down Expand Up @@ -311,19 +313,18 @@ def _send_outputs(self, outputs: REQUEST_OUTPUTS_T):
pass

output_bytes = pickle.dumps(outputs)
sig = sign(output_bytes)
self.output_socket.send_multipart((sig, output_bytes), copy=False)
send(self.output_socket, output_bytes)

def _send_healthy(self):
"""Send HEALTHY message to RPCClient."""
if not self.heartbeat_socket.closed:
self.heartbeat_socket.send_multipart(HEALTHY_RESPONSE, copy=False)
send(self.heartbeat_socket, HEALTHY_RESPONSE)

def _send_unhealthy(self, error: BaseException):
"""Send UNHEALTHY message to RPCClient."""
if not self.heartbeat_socket.closed:
error_bytes = pickle.dumps(error)
self.heartbeat_socket.send_multipart((error_bytes, ), copy=False)
send(self.heartbeat_socket, error_bytes)

def _async_socket_engine_callback(self,
request_outputs: REQUEST_OUTPUTS_T):
Expand Down
50 changes: 50 additions & 0 deletions vllm/engine/multiprocessing/ipc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import hashlib
import hmac
import secrets

import zmq
import zmq.asyncio

# TODO: switch to SECRET_KEY = secrets.token_bytes(16)
# and pass the SECRET_KEY to the background process.
SECRET_KEY = b"my_key"

def sign(msg: bytes) -> bytes:
"""Compute the HMAC digest of msg, given signing key `key`"""
return hmac.HMAC(
SECRET_KEY,
msg,
digestmod=hashlib.sha256,
).digest()

def check_signed(sig: bytes, msg: bytes) -> bool:
correct_sig = sign(msg)
return hmac.compare_digest(sig, correct_sig)

def send_signed(socket: zmq.Socket, msg: bytes):
"""Send signed message to socket."""

sig = sign(msg)
socket.send_multipart((sig, msg), copy=False)

def recv_signed(socket: zmq.Socket):
"""Get signed message from socket."""

sig, msg = socket.recv_multipart(copy=False)
if not check_signed(sig, msg):
raise ValueError("Message signature is invalid.")
return msg

async def send_signed_async(socket: zmq.asyncio.Socket, msg: bytes):
"""Send signed message to asyncio socket."""

sig = sign(msg)
await socket.send_multipart((sig, msg), copy=False)

async def recv_signed_async(socket: zmq.asyncio.Socket):
"""Get signed message from asyncio socket."""

sig, msg = await socket.recv_multipart(copy=False)
if not check_signed(sig, msg):
raise ValueError("Message signature is invalid.")
return msg

0 comments on commit 4934eeb

Please sign in to comment.