Skip to content

Commit

Permalink
[V1] Simplify Shutdown (vllm-project#11659)
Browse files Browse the repository at this point in the history
  • Loading branch information
robertgshaw2-redhat authored Jan 3, 2025
1 parent e1a5c2f commit 80c751e
Show file tree
Hide file tree
Showing 7 changed files with 42 additions and 60 deletions.
6 changes: 0 additions & 6 deletions tests/v1/engine/test_engine_core_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,9 +142,6 @@ def test_engine_core_client(monkeypatch, multiprocessing_mode: bool):

client.abort_requests([request.request_id])

# Shutdown the client.
client.shutdown()


@pytest.mark.asyncio
async def test_engine_core_client_asyncio(monkeypatch):
Expand Down Expand Up @@ -200,6 +197,3 @@ async def test_engine_core_client_asyncio(monkeypatch):
else:
assert len(outputs[req_id]) == MAX_TOKENS, (
f"{len(outputs[req_id])=}, {MAX_TOKENS=}")

# Shutdown the client.
client.shutdown()
5 changes: 0 additions & 5 deletions vllm/entrypoints/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,11 +232,6 @@ def __init__(

self.request_counter = Counter()

def __del__(self):
if hasattr(self, 'llm_engine') and self.llm_engine and hasattr(
self.llm_engine, "shutdown"):
self.llm_engine.shutdown()

@staticmethod
def get_engine_class() -> Type[LLMEngine]:
if envs.VLLM_USE_V1:
Expand Down
3 changes: 0 additions & 3 deletions vllm/v1/engine/async_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,9 +103,6 @@ def sigquit_handler(signum, frame):

self.output_handler: Optional[asyncio.Task] = None

def __del__(self):
self.shutdown()

@classmethod
def from_engine_args(
cls,
Expand Down
1 change: 0 additions & 1 deletion vllm/v1/engine/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,6 @@ def signal_handler(signum, frame):
finally:
if engine_core is not None:
engine_core.shutdown()
engine_core = None

def run_busy_loop(self):
"""Core busy loop of the EngineCore."""
Expand Down
34 changes: 18 additions & 16 deletions vllm/v1/engine/core_client.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from typing import List, Optional, Type
import weakref
from abc import ABC, abstractmethod
from typing import List, Type

import msgspec
import zmq
Expand All @@ -18,7 +20,7 @@
logger = init_logger(__name__)


class EngineCoreClient:
class EngineCoreClient(ABC):
"""
EngineCoreClient: subclasses handle different methods for pushing
and pulling from the EngineCore for asyncio / multiprocessing.
Expand Down Expand Up @@ -52,8 +54,9 @@ def make_client(

return InprocClient(vllm_config, executor_class, log_stats)

@abstractmethod
def shutdown(self):
pass
...

def get_output(self) -> List[EngineCoreOutput]:
raise NotImplementedError
Expand Down Expand Up @@ -107,9 +110,6 @@ def abort_requests(self, request_ids: List[str]) -> None:
def shutdown(self):
self.engine_core.shutdown()

def __del__(self):
self.shutdown()

def profile(self, is_start: bool = True) -> None:
self.engine_core.profile(is_start)

Expand Down Expand Up @@ -139,10 +139,14 @@ def __init__(
self.decoder = msgspec.msgpack.Decoder(EngineCoreOutputs)

# ZMQ setup.
if asyncio_mode:
self.ctx = zmq.asyncio.Context()
else:
self.ctx = zmq.Context() # type: ignore[attr-defined]
self.ctx = (
zmq.asyncio.Context() # type: ignore[attr-defined]
if asyncio_mode else zmq.Context()) # type: ignore[attr-defined]

# Note(rob): shutdown function cannot be a bound method,
# else the gc cannot collect the object.
self._finalizer = weakref.finalize(self, lambda x: x.destroy(linger=0),
self.ctx)

# Paths and sockets for IPC.
output_path = get_open_zmq_ipc_path()
Expand All @@ -153,7 +157,6 @@ def __init__(
zmq.constants.PUSH)

# Start EngineCore in background process.
self.proc_handle: Optional[BackgroundProcHandle]
self.proc_handle = BackgroundProcHandle(
input_path=input_path,
output_path=output_path,
Expand All @@ -166,12 +169,11 @@ def __init__(
})

def shutdown(self):
# Shut down the zmq context.
self.ctx.destroy(linger=0)

if hasattr(self, "proc_handle") and self.proc_handle:
"""Clean up background resources."""
if hasattr(self, "proc_handle"):
self.proc_handle.shutdown()
self.proc_handle = None

self._finalizer()


class SyncMPClient(MPClient):
Expand Down
7 changes: 0 additions & 7 deletions vllm/v1/engine/llm_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,10 +205,3 @@ def get_tokenizer_group(
f"found type: {type(tokenizer_group)}")

return tokenizer_group

def __del__(self):
self.shutdown()

def shutdown(self):
if engine_core := getattr(self, "engine_core", None):
engine_core.shutdown()
46 changes: 24 additions & 22 deletions vllm/v1/utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import multiprocessing
import os
import weakref
from collections.abc import Sequence
Expand Down Expand Up @@ -91,8 +92,6 @@ def __init__(
target_fn: Callable,
process_kwargs: Dict[Any, Any],
):
self._finalizer = weakref.finalize(self, self.shutdown)

context = get_mp_context()
reader, writer = context.Pipe(duplex=False)

Expand All @@ -102,33 +101,36 @@ def __init__(
process_kwargs["ready_pipe"] = writer
process_kwargs["input_path"] = input_path
process_kwargs["output_path"] = output_path
self.input_path = input_path
self.output_path = output_path

# Run Detokenizer busy loop in background process.
# Run busy loop in background process.
self.proc = context.Process(target=target_fn, kwargs=process_kwargs)
self._finalizer = weakref.finalize(self, shutdown, self.proc,
input_path, output_path)
self.proc.start()

# Wait for startup.
if reader.recv()["status"] != "READY":
raise RuntimeError(f"{process_name} initialization failed. "
"See root cause above.")

def __del__(self):
self.shutdown()

def shutdown(self):
# Shutdown the process if needed.
if hasattr(self, "proc") and self.proc.is_alive():
self.proc.terminate()
self.proc.join(5)

if self.proc.is_alive():
kill_process_tree(self.proc.pid)

# Remove zmq ipc socket files
ipc_sockets = [self.output_path, self.input_path]
for ipc_socket in ipc_sockets:
socket_file = ipc_socket.replace("ipc://", "")
if os and os.path.exists(socket_file):
os.remove(socket_file)
self._finalizer()


# Note(rob): shutdown function cannot be a bound method,
# else the gc cannot collect the object.
def shutdown(proc: multiprocessing.Process, input_path: str, output_path: str):
# Shutdown the process.
if proc.is_alive():
proc.terminate()
proc.join(5)

if proc.is_alive():
kill_process_tree(proc.pid)

# Remove zmq ipc socket files.
ipc_sockets = [output_path, input_path]
for ipc_socket in ipc_sockets:
socket_file = ipc_socket.replace("ipc://", "")
if os and os.path.exists(socket_file):
os.remove(socket_file)

0 comments on commit 80c751e

Please sign in to comment.