Skip to content

Commit

Permalink
[torch.compile] integration with compilation control (vllm-project#9058)
Browse files Browse the repository at this point in the history
  • Loading branch information
youkaichao authored Oct 10, 2024
1 parent 78c0b41 commit e4d652e
Show file tree
Hide file tree
Showing 22 changed files with 404 additions and 98 deletions.
20 changes: 12 additions & 8 deletions .buildkite/test-pipeline.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,9 @@ steps:
- vllm/core/
- tests/distributed
- tests/spec_decode/e2e/test_integration_dist_tp4
- tests/compile
commands:
- pytest -v -s compile/test_basic_correctness.py
- pytest -v -s distributed/test_pynccl.py
- pytest -v -s spec_decode/e2e/test_integration_dist_tp4.py

Expand Down Expand Up @@ -231,14 +233,16 @@ steps:
- vllm/
- tests/compile
commands:
- pytest -v -s compile/test_full_graph_smoke.py
- pytest -v -s compile/test_basic_correctness.py

- label: "PyTorch Fullgraph Test" # 18min
source_file_dependencies:
- vllm/
- tests/compile
commands:
- pytest -v -s compile/test_full_graph.py
# TODO: re-write in comparison tests, and fix symbolic shape
# for quantization ops.
# - label: "PyTorch Fullgraph Test" # 18min
# source_file_dependencies:
# - vllm/
# - tests/compile
# commands:
# - pytest -v -s compile/test_full_graph.py

- label: Kernels Test %N # 1h each
mirror_hardwares: [amd]
Expand Down Expand Up @@ -394,7 +398,7 @@ steps:
- tests/distributed/
- vllm/compilation
commands:
- pytest -v -s ./compile/test_full_graph_multi_gpu.py
- pytest -v -s ./compile/test_basic_correctness.py
- pytest -v -s ./compile/test_wrapper.py
- VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py | grep -q 'Same node test passed'
- TARGET_TEST_SUITE=L4 VLLM_ALLOW_DEPRECATED_BLOCK_MANAGER_V1=1 pytest basic_correctness/ -v -s -m distributed_2_gpus
Expand Down
48 changes: 48 additions & 0 deletions tests/compile/test_basic_correctness.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
from typing import Dict, List, Optional

import pytest

from vllm.compilation.levels import CompilationLevel
from vllm.utils import cuda_device_count_stateless

from ..utils import compare_all_settings


# we cannot afford testing the full Catesian product
# of all models and all levels
@pytest.mark.parametrize(
"model, model_args, pp_size, tp_size, attn_backend, method, fullgraph",
[
("meta-llama/Meta-Llama-3-8B", [], 2, 2, "FLASH_ATTN", "generate",
True),
("nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Dyn-Per-Token-2048-Samples",
["--quantization", "compressed-tensors"
], 1, 1, "FLASH_ATTN", "generate", True),
("google/gemma-2-2b-it", [], 1, 2, "FLASHINFER", "generate", True),
# TODO: add multi-modality test for llava
("llava-hf/llava-1.5-7b-hf", [], 2, 1, "FLASHINFER", "generate", False)
])
def test_compile_correctness(model, model_args, pp_size, tp_size, attn_backend,
method, fullgraph):
# this test is run under multiple suits, with different GPUs.
# make sure we only run the test with correct CUDA devices.
# don't use "<", as it will duplicate the tests.
if cuda_device_count_stateless() != pp_size * tp_size:
pytest.skip("Not correct CUDA devices for the test.")
import os
os.environ["VLLM_ATTENTION_BACKEND"] = attn_backend
if not fullgraph:
os.environ["VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE"] = "0"
all_args = [["--enforce-eager"] + model_args + ["--max_model_len", "1024"]
+ ["-pp", str(pp_size)] + ["-tp", str(tp_size)]] * 3
# don't test VLLM_TORCH_COMPILE_LEVEL == 3 case
# inductor will change the output, so we cannot compare them.
all_envs: List[Optional[Dict[str, str]]] = [{
"VLLM_TORCH_COMPILE_LEVEL":
str(level)
} for level in [
CompilationLevel.NO_COMPILATION,
CompilationLevel.DYNAMO_AS_IS,
CompilationLevel.DYNAMO_ONCE,
]]
compare_all_settings(model, all_args, all_envs, method=method)
15 changes: 11 additions & 4 deletions tests/compile/test_full_graph.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,20 @@
import pytest

from vllm.compilation.backends import vllm_backend
from vllm.compilation.levels import CompilationLevel

from ..utils import fork_new_process_for_each_test
from .utils import TEST_MODELS, check_full_graph_support


@pytest.mark.parametrize("model_info", TEST_MODELS)
@pytest.mark.parametrize("backend", ["eager", vllm_backend])
def test_full_graph(model_info, backend):
@pytest.mark.parametrize(
"optimization_level",
[CompilationLevel.DYNAMO_ONCE, CompilationLevel.INDUCTOR])
@fork_new_process_for_each_test
def test_full_graph(model_info, optimization_level):
model = model_info[0]
model_kwargs = model_info[1]
check_full_graph_support(model, model_kwargs, backend, tp_size=1)
check_full_graph_support(model,
model_kwargs,
optimization_level,
tp_size=1)
22 changes: 0 additions & 22 deletions tests/compile/test_full_graph_multi_gpu.py

This file was deleted.

13 changes: 0 additions & 13 deletions tests/compile/test_full_graph_smoke.py

This file was deleted.

24 changes: 9 additions & 15 deletions tests/compile/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,9 @@

from tests.quantization.utils import is_quant_method_supported
from vllm import LLM, SamplingParams
from vllm.plugins import set_torch_compile_backend
from vllm.compilation.levels import CompilationLevel
from vllm.utils import is_hip

TEST_MODELS_SMOKE = [
("nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Dyn-Per-Token-2048-Samples", {
"quantization": "compressed-tensors"
}),
("meta-llama/Meta-Llama-3-8B", {}),
]

TEST_MODELS = [
("facebook/opt-125m", {}),
("nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change", {
Expand Down Expand Up @@ -68,20 +61,21 @@
}))


def check_full_graph_support(model, model_kwargs, backend, tp_size=1):
def check_full_graph_support(model,
model_kwargs,
optimization_level,
tp_size=1):
# make sure these models can be captured in full graph mode
if "VLLM_TEST_DYNAMO_GRAPH_CAPTURE" not in os.environ:
os.environ["VLLM_TEST_DYNAMO_GRAPH_CAPTURE"] = "1"
os.environ["VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE"] = "1"
os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(optimization_level)
os.environ["VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE"] = "1"

# Inductor doesn't support fp8/gptq_marlin_24 yet.
quantization = model_kwargs.get("quantization")
if (quantization == "fp8" or quantization == "gptq_marlin"
or quantization == "gptq_marlin_24") and backend != "eager":
or quantization == "gptq_marlin_24"
) and optimization_level >= CompilationLevel.INDUCTOR:
return

set_torch_compile_backend(backend)

prompts = [
"Hello, my name is",
"The president of the United States is",
Expand Down
4 changes: 3 additions & 1 deletion tests/tpu/test_compilation.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,11 @@

import depyf

from vllm.compilation.levels import CompilationLevel

# disable custom dispatcher, let Dynamo takes over
# all the control
os.environ['VLLM_DYNAMO_USE_CUSTOM_DISPATCHER'] = "0"
os.environ['VLLM_TORCH_COMPILE_LEVEL'] = str(CompilationLevel.DYNAMO_AS_IS)

temp_dir = tempfile.mkdtemp()
with depyf.prepare_debug(temp_dir):
Expand Down
13 changes: 8 additions & 5 deletions tests/tpu/test_custom_dispatcher.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import os

from vllm.compilation.levels import CompilationLevel

from ..utils import compare_two_settings

# --enforce-eager on TPU causes graph compilation
Expand All @@ -9,8 +11,9 @@


def test_custom_dispatcher():
compare_two_settings("google/gemma-2b",
arg1=["--enforce-eager"],
arg2=["--enforce-eager"],
env1={"VLLM_DYNAMO_USE_CUSTOM_DISPATCHER": "0"},
env2={})
compare_two_settings(
"google/gemma-2b",
arg1=["--enforce-eager"],
arg2=["--enforce-eager"],
env1={"VLLM_TORCH_COMPILE_LEVEL": str(CompilationLevel.DYNAMO_ONCE)},
env2={"VLLM_TORCH_COMPILE_LEVEL": str(CompilationLevel.DYNAMO_AS_IS)})
115 changes: 114 additions & 1 deletion vllm/compilation/backends.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,17 @@
import copy
import operator
from typing import Callable, Dict, List, Optional, Tuple, Union

import torch
import torch.fx as fx

from vllm.logger import init_logger

from .compile_context import get_compile_context
from .levels import CompilationLevel

logger = init_logger(__name__)


def fix_functionalization(graph: fx.Graph):
"""
Expand Down Expand Up @@ -148,9 +157,113 @@ def fix_functionalization(graph: fx.Graph):
# print(graph.python_code(root_module="self", verbose=True).src, file=f)


def vllm_backend(graph, example_inputs):
def wrap_inductor(graph, example_inputs, additional_inductor_config):
from torch._inductor import config
current_config = config.shallow_copy_dict()
from torch._inductor.compile_fx import compile_fx

if additional_inductor_config is not None:
current_config.update(additional_inductor_config)
if current_config['post_grad_custom_post_pass'] is not None:
logger.warning(
"post_grad_custom_post_pass is already set in the config. "
"Overwriting it with the fix_functionalization")
current_config['post_grad_custom_post_pass'] = fix_functionalization
return compile_fx(graph, example_inputs, config_patches=current_config)


def vllm_backend(
graph,
example_inputs,
additional_inductor_config: Optional[Dict] = None) -> Callable:

context = get_compile_context()
context = copy.deepcopy(context) if context is not None else []
sizes_to_specialize: List[int] = context

# flags for all the seen shapes, whether we need to specialize
runtime_shapes_to_compile_flags: Dict[Tuple[int, ...], bool] = {}

# if we need to specialize, the compiled graph for that shape
runtime_shapes_to_compiled_graph: Dict[Tuple[int, ...], Callable] = {}

# this is the first compilation, we will compile a graph with
# dynamic shape, as the caller will mark first dimension as dynamic
logger.info("Compiling a graph for general shapes")
graph_for_symbolic_shape = wrap_inductor(graph, example_inputs,
additional_inductor_config)

# TODO: Dynamo does not pass all dynamic shapes.
# Need to investigate why. It works now because all the dynamic
# shapes have the same value, and either of them can be used.
sym_shape_indices = [
i for i, x in enumerate(example_inputs) if isinstance(x, torch.SymInt)
]

first_run = True

# this is the function we return to Dynamo to run finally
def compiled_graph_wrapper(*args):

runtime_shapes: Tuple[int,
...] = tuple(args[i] for i in sym_shape_indices)

nonlocal first_run
nonlocal runtime_shapes_to_compile_flags
nonlocal runtime_shapes_to_compiled_graph

if first_run:
# the first compilation is for profiling, we directly run it
first_run = False
return graph_for_symbolic_shape(*args)

if runtime_shapes not in runtime_shapes_to_compile_flags:
# we haven't seen this shape before
# query if we need to specialize for this shape
# we only specialize for the first dimension.
# TODO: investigate if any model needs to specialize
# beyond the first dimension
runtime_shapes_to_compile_flags[runtime_shapes] = runtime_shapes[
0] in sizes_to_specialize

if not runtime_shapes_to_compile_flags[runtime_shapes]:
# we don't need to specialize for this shape
return graph_for_symbolic_shape(*args)

if runtime_shapes not in runtime_shapes_to_compiled_graph:
# we need to specialize for this shape, and we haven't compiled
# compile the graph for this shape
logger.info("Compiling a graph for shapes %s", runtime_shapes)
runtime_shapes_to_compiled_graph[runtime_shapes] = wrap_inductor(
graph, args, additional_inductor_config)

return runtime_shapes_to_compiled_graph[runtime_shapes](*args)

return compiled_graph_wrapper


def select_default_backend(level: int) -> Union[str, Callable]:
if level in [CompilationLevel.DYNAMO_AS_IS, CompilationLevel.DYNAMO_ONCE]:
backend = "eager"
return backend
assert level in [
CompilationLevel.INDUCTOR, CompilationLevel.INDUCTOR_MAX_AUTOTUNE
], f"Invalid level {level}"

from vllm.compilation.backends import vllm_backend
from vllm.plugins import get_inductor_additional_configs
additional_configs = get_inductor_additional_configs()

if level == CompilationLevel.INDUCTOR_MAX_AUTOTUNE:
if "max_autotune" in additional_configs and not additional_configs[
"max_autotune"]:
logger.warning(
"max_autotune is disabled, but is overridden by level %s",
CompilationLevel.INDUCTOR_MAX_AUTOTUNE)
additional_configs['max_autotune'] = True

from functools import partial
backend = partial(vllm_backend,
additional_inductor_config=additional_configs)

return backend
23 changes: 23 additions & 0 deletions vllm/compilation/compile_context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from contextlib import contextmanager
from typing import Any

_compile_context: Any = None


def get_compile_context() -> Any:
"""Get the current compile context."""
return _compile_context


@contextmanager
def set_compile_context(context: Any):
"""A context manager that stores the current compile context,
usually it is a list of sizes to specialize.
"""
global _compile_context
prev_context = _compile_context
_compile_context = context
try:
yield
finally:
_compile_context = prev_context
Loading

0 comments on commit e4d652e

Please sign in to comment.