From 80c751e7f68ade3d4c6391a0f3fce9ce970ddad0 Mon Sep 17 00:00:00 2001 From: Robert Shaw <114415538+robertgshaw2-neuralmagic@users.noreply.github.com> Date: Fri, 3 Jan 2025 12:25:38 -0500 Subject: [PATCH] [V1] Simplify Shutdown (#11659) --- tests/v1/engine/test_engine_core_client.py | 6 --- vllm/entrypoints/llm.py | 5 --- vllm/v1/engine/async_llm.py | 3 -- vllm/v1/engine/core.py | 1 - vllm/v1/engine/core_client.py | 34 ++++++++-------- vllm/v1/engine/llm_engine.py | 7 ---- vllm/v1/utils.py | 46 +++++++++++----------- 7 files changed, 42 insertions(+), 60 deletions(-) diff --git a/tests/v1/engine/test_engine_core_client.py b/tests/v1/engine/test_engine_core_client.py index 729975e4ea8c4..20d4e6f63b339 100644 --- a/tests/v1/engine/test_engine_core_client.py +++ b/tests/v1/engine/test_engine_core_client.py @@ -142,9 +142,6 @@ def test_engine_core_client(monkeypatch, multiprocessing_mode: bool): client.abort_requests([request.request_id]) - # Shutdown the client. - client.shutdown() - @pytest.mark.asyncio async def test_engine_core_client_asyncio(monkeypatch): @@ -200,6 +197,3 @@ async def test_engine_core_client_asyncio(monkeypatch): else: assert len(outputs[req_id]) == MAX_TOKENS, ( f"{len(outputs[req_id])=}, {MAX_TOKENS=}") - - # Shutdown the client. - client.shutdown() diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index fadf297e9f6aa..7c0de3b3e5481 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -232,11 +232,6 @@ def __init__( self.request_counter = Counter() - def __del__(self): - if hasattr(self, 'llm_engine') and self.llm_engine and hasattr( - self.llm_engine, "shutdown"): - self.llm_engine.shutdown() - @staticmethod def get_engine_class() -> Type[LLMEngine]: if envs.VLLM_USE_V1: diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index 3f097ca7f439c..ff7a0c28dd91a 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -103,9 +103,6 @@ def sigquit_handler(signum, frame): self.output_handler: Optional[asyncio.Task] = None - def __del__(self): - self.shutdown() - @classmethod def from_engine_args( cls, diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index 5840541d774ba..13a50a4f855e2 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -203,7 +203,6 @@ def signal_handler(signum, frame): finally: if engine_core is not None: engine_core.shutdown() - engine_core = None def run_busy_loop(self): """Core busy loop of the EngineCore.""" diff --git a/vllm/v1/engine/core_client.py b/vllm/v1/engine/core_client.py index 3293205e110af..e009f3448bf69 100644 --- a/vllm/v1/engine/core_client.py +++ b/vllm/v1/engine/core_client.py @@ -1,4 +1,6 @@ -from typing import List, Optional, Type +import weakref +from abc import ABC, abstractmethod +from typing import List, Type import msgspec import zmq @@ -18,7 +20,7 @@ logger = init_logger(__name__) -class EngineCoreClient: +class EngineCoreClient(ABC): """ EngineCoreClient: subclasses handle different methods for pushing and pulling from the EngineCore for asyncio / multiprocessing. @@ -52,8 +54,9 @@ def make_client( return InprocClient(vllm_config, executor_class, log_stats) + @abstractmethod def shutdown(self): - pass + ... def get_output(self) -> List[EngineCoreOutput]: raise NotImplementedError @@ -107,9 +110,6 @@ def abort_requests(self, request_ids: List[str]) -> None: def shutdown(self): self.engine_core.shutdown() - def __del__(self): - self.shutdown() - def profile(self, is_start: bool = True) -> None: self.engine_core.profile(is_start) @@ -139,10 +139,14 @@ def __init__( self.decoder = msgspec.msgpack.Decoder(EngineCoreOutputs) # ZMQ setup. - if asyncio_mode: - self.ctx = zmq.asyncio.Context() - else: - self.ctx = zmq.Context() # type: ignore[attr-defined] + self.ctx = ( + zmq.asyncio.Context() # type: ignore[attr-defined] + if asyncio_mode else zmq.Context()) # type: ignore[attr-defined] + + # Note(rob): shutdown function cannot be a bound method, + # else the gc cannot collect the object. + self._finalizer = weakref.finalize(self, lambda x: x.destroy(linger=0), + self.ctx) # Paths and sockets for IPC. output_path = get_open_zmq_ipc_path() @@ -153,7 +157,6 @@ def __init__( zmq.constants.PUSH) # Start EngineCore in background process. - self.proc_handle: Optional[BackgroundProcHandle] self.proc_handle = BackgroundProcHandle( input_path=input_path, output_path=output_path, @@ -166,12 +169,11 @@ def __init__( }) def shutdown(self): - # Shut down the zmq context. - self.ctx.destroy(linger=0) - - if hasattr(self, "proc_handle") and self.proc_handle: + """Clean up background resources.""" + if hasattr(self, "proc_handle"): self.proc_handle.shutdown() - self.proc_handle = None + + self._finalizer() class SyncMPClient(MPClient): diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py index a19109559eabf..1f49de67d7493 100644 --- a/vllm/v1/engine/llm_engine.py +++ b/vllm/v1/engine/llm_engine.py @@ -205,10 +205,3 @@ def get_tokenizer_group( f"found type: {type(tokenizer_group)}") return tokenizer_group - - def __del__(self): - self.shutdown() - - def shutdown(self): - if engine_core := getattr(self, "engine_core", None): - engine_core.shutdown() diff --git a/vllm/v1/utils.py b/vllm/v1/utils.py index 19e0dd17237c9..b0a7affbebb7e 100644 --- a/vllm/v1/utils.py +++ b/vllm/v1/utils.py @@ -1,3 +1,4 @@ +import multiprocessing import os import weakref from collections.abc import Sequence @@ -91,8 +92,6 @@ def __init__( target_fn: Callable, process_kwargs: Dict[Any, Any], ): - self._finalizer = weakref.finalize(self, self.shutdown) - context = get_mp_context() reader, writer = context.Pipe(duplex=False) @@ -102,11 +101,11 @@ def __init__( process_kwargs["ready_pipe"] = writer process_kwargs["input_path"] = input_path process_kwargs["output_path"] = output_path - self.input_path = input_path - self.output_path = output_path - # Run Detokenizer busy loop in background process. + # Run busy loop in background process. self.proc = context.Process(target=target_fn, kwargs=process_kwargs) + self._finalizer = weakref.finalize(self, shutdown, self.proc, + input_path, output_path) self.proc.start() # Wait for startup. @@ -114,21 +113,24 @@ def __init__( raise RuntimeError(f"{process_name} initialization failed. " "See root cause above.") - def __del__(self): - self.shutdown() - def shutdown(self): - # Shutdown the process if needed. - if hasattr(self, "proc") and self.proc.is_alive(): - self.proc.terminate() - self.proc.join(5) - - if self.proc.is_alive(): - kill_process_tree(self.proc.pid) - - # Remove zmq ipc socket files - ipc_sockets = [self.output_path, self.input_path] - for ipc_socket in ipc_sockets: - socket_file = ipc_socket.replace("ipc://", "") - if os and os.path.exists(socket_file): - os.remove(socket_file) + self._finalizer() + + +# Note(rob): shutdown function cannot be a bound method, +# else the gc cannot collect the object. +def shutdown(proc: multiprocessing.Process, input_path: str, output_path: str): + # Shutdown the process. + if proc.is_alive(): + proc.terminate() + proc.join(5) + + if proc.is_alive(): + kill_process_tree(proc.pid) + + # Remove zmq ipc socket files. + ipc_sockets = [output_path, input_path] + for ipc_socket in ipc_sockets: + socket_file = ipc_socket.replace("ipc://", "") + if os and os.path.exists(socket_file): + os.remove(socket_file)