diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py
index 4aceb19b50776..5591893d267a2 100644
--- a/vllm/engine/async_llm_engine.py
+++ b/vllm/engine/async_llm_engine.py
@@ -8,6 +8,7 @@
 from transformers import PreTrainedTokenizer
 
 from vllm.config import DecodingConfig, ModelConfig
+from vllm.core.scheduler import SchedulerOutputs
 from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.engine.llm_engine import LLMEngine
 from vllm.executor.ray_utils import initialize_ray_cluster, ray
@@ -15,7 +16,7 @@
 from vllm.lora.request import LoRARequest
 from vllm.outputs import RequestOutput
 from vllm.sampling_params import SamplingParams
-from vllm.sequence import MultiModalData
+from vllm.sequence import MultiModalData, SamplerOutput
 from vllm.usage.usage_lib import UsageContext
 
 logger = init_logger(__name__)
@@ -224,8 +225,7 @@ async def step_async(self) -> List[RequestOutput]:
             scheduler_outputs.ignored_seq_groups, seq_group_metadata_list)
 
         # Log stats.
-        if self.log_stats:
-            self.stat_logger.log(self._get_stats(scheduler_outputs))
+        self.do_log_stats(scheduler_outputs, output)
 
         return request_outputs
 
@@ -707,9 +707,13 @@ async def get_decoding_config(self) -> DecodingConfig:
         else:
             return self.engine.get_decoding_config()
 
-    async def do_log_stats(self) -> None:
+    async def do_log_stats(
+            self,
+            scheduler_outputs: Optional[SchedulerOutputs] = None,
+            model_output: Optional[List[SamplerOutput]] = None) -> None:
         if self.engine_use_ray:
-            await self.engine.do_log_stats.remote()  # type: ignore
+            await self.engine.do_log_stats.remote(  # type: ignore
+                scheduler_outputs, model_output)
         else:
             self.engine.do_log_stats()
 
diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py
index 4caecb8a51598..19e7143ac2b45 100644
--- a/vllm/engine/llm_engine.py
+++ b/vllm/engine/llm_engine.py
@@ -597,16 +597,18 @@ def step(self) -> List[RequestOutput]:
             scheduler_outputs.ignored_seq_groups, seq_group_metadata_list)
 
         # Log stats.
-        if self.log_stats:
-            self.stat_logger.log(
-                self._get_stats(scheduler_outputs, model_output=output))
+        self.do_log_stats(scheduler_outputs, output)
 
         return request_outputs
 
-    def do_log_stats(self) -> None:
+    def do_log_stats(
+            self,
+            scheduler_outputs: Optional[SchedulerOutputs] = None,
+            model_output: Optional[List[SamplerOutput]] = None) -> None:
         """Forced log when no requests active."""
         if self.log_stats:
-            self.stat_logger.log(self._get_stats(scheduler_outputs=None))
+            self.stat_logger.log(
+                self._get_stats(scheduler_outputs, model_output))
 
     def _get_stats(
         self,