From 0f3cf421fed51c9a6d43cbb8338dfafcf04da5f9 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Thu, 12 Dec 2024 19:25:38 +0000 Subject: [PATCH] stash --- examples/openai_completion_client.py | 9 ++++----- vllm/engine/multiprocessing/__init__.py | 20 ++++++++++++++++++++ vllm/engine/multiprocessing/client.py | 8 +++++--- vllm/engine/multiprocessing/engine.py | 6 ++++-- 4 files changed, 33 insertions(+), 10 deletions(-) diff --git a/examples/openai_completion_client.py b/examples/openai_completion_client.py index 58519f978d340..8effc00120d43 100644 --- a/examples/openai_completion_client.py +++ b/examples/openai_completion_client.py @@ -2,7 +2,7 @@ # Modify OpenAI's API key and API base to use vLLM's API server. openai_api_key = "EMPTY" -openai_api_base = "http://localhost:8000/v1" +openai_api_base = "http://localhost:8001/v1" client = OpenAI( # defaults to os.environ.get("OPENAI_API_KEY") @@ -14,14 +14,13 @@ model = models.data[0].id # Completion API -stream = False +stream = True completion = client.completions.create( model=model, prompt="A robot may not injure a human being", echo=False, - n=2, - stream=stream, - logprobs=3) + n=1, + stream=stream) print("Completion results:") if stream: diff --git a/vllm/engine/multiprocessing/__init__.py b/vllm/engine/multiprocessing/__init__.py index 420f540d0b5f4..eb3f66aa6bb17 100644 --- a/vllm/engine/multiprocessing/__init__.py +++ b/vllm/engine/multiprocessing/__init__.py @@ -1,3 +1,6 @@ +import hashlib +import hmac +import secrets from dataclasses import dataclass from enum import Enum from typing import List, Mapping, Optional, Union, overload @@ -14,12 +17,29 @@ VLLM_RPC_SUCCESS_STR = "SUCCESS" +# TODO: switch to SECRET_KEY = secrets.token_bytes(16) +# and pass the SECRET_KEY to the background process. 
+SECRET_KEY = b"my_key" IPC_INPUT_EXT = "_input_socket" IPC_OUTPUT_EXT = "_output_socket" IPC_HEALTH_EXT = "_health_socket" IPC_DATA_EXT = "_data_socket" +def sign(msg: bytes) -> bytes: + """Compute the HMAC-SHA256 digest of msg using the module SECRET_KEY.""" + return hmac.HMAC( + SECRET_KEY, + msg, + digestmod=hashlib.sha256, + ).digest() + + +def check_signed(sig: bytes, msg: bytes) -> bool: + correct_sig = sign(msg) + return hmac.compare_digest(sig, correct_sig) + + class MQEngineDeadError(RuntimeError): pass diff --git a/vllm/engine/multiprocessing/client.py b/vllm/engine/multiprocessing/client.py index 0a046c71e86e8..c5e04c11e49e7 100644 --- a/vllm/engine/multiprocessing/client.py +++ b/vllm/engine/multiprocessing/client.py @@ -21,7 +21,7 @@ # yapf: disable from vllm.engine.async_llm_engine import ( build_guided_decoding_logits_processor_async) -from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT, +from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT, check_signed, IPC_HEALTH_EXT, IPC_INPUT_EXT, IPC_OUTPUT_EXT, RPC_REQUEST_T, VLLM_RPC_SUCCESS_STR, RPCAbortRequest, @@ -192,8 +192,10 @@ async def run_output_handler_loop(self): ENGINE_DEAD_ERROR(self._errored_with)) return - message: Frame = await self.output_socket.recv(copy=False) - request_outputs = pickle.loads(message.buffer) + sig, message = await self.output_socket.recv_multipart(copy=False) + if not check_signed(sig, message): + raise Exception + request_outputs = pickle.loads(message) is_error = isinstance(request_outputs, (BaseException, RPCError)) diff --git a/vllm/engine/multiprocessing/engine.py b/vllm/engine/multiprocessing/engine.py index 49a90b321dac4..29e6b69cd9aba 100644 --- a/vllm/engine/multiprocessing/engine.py +++ b/vllm/engine/multiprocessing/engine.py @@ -1,5 +1,6 @@ import pickle import signal +import hmac from contextlib import contextmanager from typing import Iterator, List, Optional, Union @@ -10,7 +11,7 @@ from vllm.engine.llm_engine import LLMEngine # 
yapf conflicts with isort for this block # yapf: disable -from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT, +from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, sign, IPC_DATA_EXT, IPC_HEALTH_EXT, IPC_INPUT_EXT, IPC_OUTPUT_EXT, REQUEST_OUTPUTS_T, VLLM_RPC_SUCCESS_STR, RPCAbortRequest, @@ -310,7 +311,8 @@ def _send_outputs(self, outputs: REQUEST_OUTPUTS_T): pass output_bytes = pickle.dumps(outputs) - self.output_socket.send_multipart((output_bytes, ), copy=False) + sig = sign(output_bytes) + self.output_socket.send_multipart((sig, output_bytes), copy=False) def _send_healthy(self): """Send HEALTHY message to RPCClient."""