Skip to content

Commit

Permalink
[9/N] torch.compile LLM usage (vllm-project#10552)
Browse files — browse the repository at this point in the history
Signed-off-by: youkaichao <youkaichao@gmail.com>
  • Loading branch information
youkaichao authored Nov 22, 2024
1 parent aed0748 commit 33e0a25
Show / hide file tree
Showing 2 changed files with 16 additions and 4 deletions.
5 changes: 2 additions & 3 deletions tests/tpu/test_compilation.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import depyf

from vllm.config import CompilationConfig, CompilationLevel
from vllm.config import CompilationLevel

temp_dir = tempfile.mkdtemp()
with depyf.prepare_debug(temp_dir):
Expand Down Expand Up @@ -34,8 +34,7 @@
# all the control
llm = LLM(model="google/gemma-2b",
enforce_eager=True,
compilation_config=CompilationConfig(
level=CompilationLevel.DYNAMO_AS_IS))
compilation_config={"level": CompilationLevel.DYNAMO_AS_IS})
outputs = llm.generate(prompts, sampling_params)
for output, answer in zip(outputs, answers):
prompt = output.prompt
Expand Down
15 changes: 14 additions & 1 deletion vllm/entrypoints/llm.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import itertools
import json
import warnings
from contextlib import contextmanager
from typing import (Any, ClassVar, Dict, List, Optional, Sequence, Tuple, Type,
Expand All @@ -9,6 +10,7 @@
from vllm import envs
from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput,
BeamSearchSequence, get_beam_search_score)
from vllm.config import CompilationConfig
from vllm.engine.arg_utils import (EngineArgs, HfOverrides, PoolerConfig,
TaskOption)
from vllm.engine.llm_engine import LLMEngine
Expand Down Expand Up @@ -107,13 +109,16 @@ class LLM:
hf_overrides: If a dictionary, contains arguments to be forwarded to the
HuggingFace config. If a callable, it is called to update the
HuggingFace config.
compilation_config: Either an integer or a dictionary. If it is an integer,
it is used as the level of compilation optimization. If it is a dictionary,
it can specify the full compilation configuration.
**kwargs: Arguments for :class:`~vllm.EngineArgs`. (See
:ref:`engine_args`)
Note:
This class is intended to be used for offline inference. For online
serving, use the :class:`~vllm.AsyncLLMEngine` class instead.
"""
""" # noqa

DEPRECATE_LEGACY: ClassVar[bool] = False
"""A flag to toggle whether to deprecate the legacy generate/encode API."""
Expand Down Expand Up @@ -166,6 +171,7 @@ def __init__(
# After positional args are removed, move this right below `model`
task: TaskOption = "auto",
override_pooler_config: Optional[PoolerConfig] = None,
compilation_config: Optional[Union[int, Dict[str, Any]]] = None,
**kwargs,
) -> None:
'''
Expand All @@ -178,6 +184,12 @@ def __init__(
if "disable_log_stats" not in kwargs:
kwargs["disable_log_stats"] = True

if compilation_config is not None:
compilation_config_instance = CompilationConfig.from_cli(
json.dumps(compilation_config))
else:
compilation_config_instance = None

engine_args = EngineArgs(
model=model,
task=task,
Expand All @@ -202,6 +214,7 @@ def __init__(
hf_overrides=hf_overrides,
mm_processor_kwargs=mm_processor_kwargs,
override_pooler_config=override_pooler_config,
compilation_config=compilation_config_instance,
**kwargs,
)
# Logic to switch between engines is done at runtime instead of import
Expand Down

0 comments on commit 33e0a25

Please sign in to comment.