From 33e0a2540a6bff23cbc6a4b8f7a6784a2bc87d47 Mon Sep 17 00:00:00 2001
From: youkaichao
Date: Thu, 21 Nov 2024 19:13:31 -0800
Subject: [PATCH] [9/N] torch.compile LLM usage (#10552)

Signed-off-by: youkaichao
---
 tests/tpu/test_compilation.py |  5 ++---
 vllm/entrypoints/llm.py       | 15 ++++++++++++++-
 2 files changed, 16 insertions(+), 4 deletions(-)

diff --git a/tests/tpu/test_compilation.py b/tests/tpu/test_compilation.py
index 65bee85e7a1ea..b7124ebc1b0f3 100644
--- a/tests/tpu/test_compilation.py
+++ b/tests/tpu/test_compilation.py
@@ -4,7 +4,7 @@
 
 import depyf
 
-from vllm.config import CompilationConfig, CompilationLevel
+from vllm.config import CompilationLevel
 
 temp_dir = tempfile.mkdtemp()
 with depyf.prepare_debug(temp_dir):
@@ -34,8 +34,7 @@
     # all the control
     llm = LLM(model="google/gemma-2b",
               enforce_eager=True,
-              compilation_config=CompilationConfig(
-                  level=CompilationLevel.DYNAMO_AS_IS))
+              compilation_config={"level": CompilationLevel.DYNAMO_AS_IS})
    outputs = llm.generate(prompts, sampling_params)
    for output, answer in zip(outputs, answers):
        prompt = output.prompt
diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py
index 86b0b6893f1d9..2446a64a02eb2 100644
--- a/vllm/entrypoints/llm.py
+++ b/vllm/entrypoints/llm.py
@@ -1,4 +1,5 @@
 import itertools
+import json
 import warnings
 from contextlib import contextmanager
 from typing import (Any, ClassVar, Dict, List, Optional, Sequence, Tuple, Type,
@@ -9,6 +10,7 @@
 from vllm import envs
 from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput,
                               BeamSearchSequence, get_beam_search_score)
+from vllm.config import CompilationConfig
 from vllm.engine.arg_utils import (EngineArgs, HfOverrides, PoolerConfig,
                                    TaskOption)
 from vllm.engine.llm_engine import LLMEngine
@@ -107,13 +109,16 @@ class LLM:
         hf_overrides: If a dictionary, contains arguments to be forwarded to the
             HuggingFace config. If a callable, it is called to update the
             HuggingFace config.
+        compilation_config: Either an integer or a dictionary. If it is an integer,
+            it is used as the level of compilation optimization. If it is a dictionary,
+            it can specify the full compilation configuration.
         **kwargs: Arguments for :class:`~vllm.EngineArgs`. (See
             :ref:`engine_args`)
 
     Note:
         This class is intended to be used for offline inference. For online
         serving, use the :class:`~vllm.AsyncLLMEngine` class instead.
-    """
+    """  # noqa
 
     DEPRECATE_LEGACY: ClassVar[bool] = False
     """A flag to toggle whether to deprecate the legacy generate/encode API."""
@@ -166,6 +171,7 @@ def __init__(
         # After positional args are removed, move this right below `model`
         task: TaskOption = "auto",
         override_pooler_config: Optional[PoolerConfig] = None,
+        compilation_config: Optional[Union[int, Dict[str, Any]]] = None,
         **kwargs,
     ) -> None:
         '''
@@ -178,6 +184,12 @@
         if "disable_log_stats" not in kwargs:
             kwargs["disable_log_stats"] = True
 
+        if compilation_config is not None:
+            compilation_config_instance = CompilationConfig.from_cli(
+                json.dumps(compilation_config))
+        else:
+            compilation_config_instance = None
+
         engine_args = EngineArgs(
             model=model,
             task=task,
@@ -202,6 +214,7 @@
             hf_overrides=hf_overrides,
             mm_processor_kwargs=mm_processor_kwargs,
             override_pooler_config=override_pooler_config,
+            compilation_config=compilation_config_instance,
             **kwargs,
         )
         # Logic to switch between engines is done at runtime instead of import
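
Below is a brief usage sketch (not part of the patch) of how the new
`compilation_config` keyword added by this change could be passed to `LLM`,
either as a dictionary (mirroring the updated TPU test above) or as a plain
integer level; the model name, prompt, and sampling settings are placeholders:

    # Sketch only: assumes this patch is applied; model/prompt are placeholders.
    from vllm import LLM, SamplingParams
    from vllm.config import CompilationLevel

    # Dictionary form: serialized with json.dumps() and parsed by
    # CompilationConfig.from_cli(), per the new __init__ logic above.
    llm = LLM(model="google/gemma-2b",
              enforce_eager=True,
              compilation_config={"level": CompilationLevel.DYNAMO_AS_IS})

    # Integer form: interpreted as the compilation optimization level,
    # as described in the new docstring entry.
    # llm = LLM(model="google/gemma-2b", compilation_config=3)

    outputs = llm.generate(["Hello, my name is"],
                           SamplingParams(temperature=0.0, max_tokens=16))
    print(outputs[0].outputs[0].text)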