From 4816d20aa43fdc4abf66c28f6690a1953d8adbe9 Mon Sep 17 00:00:00 2001 From: Roger Wang <136131678+ywang96@users.noreply.github.com> Date: Thu, 12 Dec 2024 07:51:53 -0800 Subject: [PATCH] [V1] Fix torch profiling for offline inference (#11125) Signed-off-by: Roger Wang --- examples/offline_inference_with_profiler.py | 31 +++++++++++++-------- vllm/v1/engine/core_client.py | 4 +-- 2 files changed, 21 insertions(+), 14 deletions(-) diff --git a/examples/offline_inference_with_profiler.py b/examples/offline_inference_with_profiler.py index 1f00d26808771..abcfa8e8f2f2a 100644 --- a/examples/offline_inference_with_profiler.py +++ b/examples/offline_inference_with_profiler.py @@ -1,4 +1,5 @@ import os +import time from vllm import LLM, SamplingParams @@ -15,19 +16,25 @@ # Create a sampling params object. sampling_params = SamplingParams(temperature=0.8, top_p=0.95) -# Create an LLM. -llm = LLM(model="facebook/opt-125m", tensor_parallel_size=1) +if __name__ == "__main__": -llm.start_profile() + # Create an LLM. + llm = LLM(model="facebook/opt-125m", tensor_parallel_size=1) -# Generate texts from the prompts. The output is a list of RequestOutput objects -# that contain the prompt, generated text, and other information. -outputs = llm.generate(prompts, sampling_params) + llm.start_profile() -llm.stop_profile() + # Generate texts from the prompts. The output is a list of RequestOutput + # objects that contain the prompt, generated text, and other information. + outputs = llm.generate(prompts, sampling_params) -# Print the outputs. -for output in outputs: - prompt = output.prompt - generated_text = output.outputs[0].text - print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + llm.stop_profile() + + # Print the outputs. + for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + + # Add a buffer to wait for profiler in the background process + # (in case MP is on) to finish writing profiling output. + time.sleep(10) diff --git a/vllm/v1/engine/core_client.py b/vllm/v1/engine/core_client.py index 8eb9a27438d53..a66ae111be8c5 100644 --- a/vllm/v1/engine/core_client.py +++ b/vllm/v1/engine/core_client.py @@ -105,7 +105,7 @@ def shutdown(self): def __del__(self): self.shutdown() - async def profile(self, is_start=True) -> None: + def profile(self, is_start=True) -> None: self.engine_core.profile(is_start) @@ -212,7 +212,7 @@ def add_request(self, request: EngineCoreRequest) -> None: def abort_requests(self, request_ids: List[str]) -> None: self._send_input(EngineCoreRequestType.ABORT, request_ids) - async def profile(self, is_start=True) -> None: + def profile(self, is_start=True) -> None: self._send_input(EngineCoreRequestType.PROFILE, EngineCoreProfile(is_start))