From e4d652ea3ed9b2a60c1582cb2e2605695e61280f Mon Sep 17 00:00:00 2001 From: youkaichao Date: Thu, 10 Oct 2024 12:39:36 -0700 Subject: [PATCH] [torch.compile] integration with compilation control (#9058) --- .buildkite/test-pipeline.yaml | 20 ++-- tests/compile/test_basic_correctness.py | 48 +++++++++ tests/compile/test_full_graph.py | 15 ++- tests/compile/test_full_graph_multi_gpu.py | 22 ---- tests/compile/test_full_graph_smoke.py | 13 --- tests/compile/utils.py | 24 ++--- tests/tpu/test_compilation.py | 4 +- tests/tpu/test_custom_dispatcher.py | 13 ++- vllm/compilation/backends.py | 115 ++++++++++++++++++++- vllm/compilation/compile_context.py | 23 +++++ vllm/compilation/decorators.py | 85 +++++++++++++++ vllm/compilation/levels.py | 9 ++ vllm/compilation/wrapper.py | 27 ++++- vllm/envs.py | 16 +-- vllm/model_executor/custom_op.py | 3 +- vllm/model_executor/models/gemma2.py | 2 + vllm/model_executor/models/llama.py | 2 + vllm/model_executor/models/llava.py | 8 +- vllm/platforms/tpu.py | 14 +++ vllm/plugins/__init__.py | 14 ++- vllm/sequence.py | 7 +- vllm/worker/model_runner.py | 18 +++- 22 files changed, 404 insertions(+), 98 deletions(-) create mode 100644 tests/compile/test_basic_correctness.py delete mode 100644 tests/compile/test_full_graph_multi_gpu.py delete mode 100644 tests/compile/test_full_graph_smoke.py create mode 100644 vllm/compilation/compile_context.py create mode 100644 vllm/compilation/decorators.py create mode 100644 vllm/compilation/levels.py diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index ccc5003e66beb..ae8e03a2fdf8f 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -121,7 +121,9 @@ steps: - vllm/core/ - tests/distributed - tests/spec_decode/e2e/test_integration_dist_tp4 + - tests/compile commands: + - pytest -v -s compile/test_basic_correctness.py - pytest -v -s distributed/test_pynccl.py - pytest -v -s spec_decode/e2e/test_integration_dist_tp4.py @@ -231,14 +233,16 @@ steps: - vllm/ - tests/compile commands: - - pytest -v -s compile/test_full_graph_smoke.py + - pytest -v -s compile/test_basic_correctness.py -- label: "PyTorch Fullgraph Test" # 18min - source_file_dependencies: - - vllm/ - - tests/compile - commands: - - pytest -v -s compile/test_full_graph.py +# TODO: re-write in comparison tests, and fix symbolic shape +# for quantization ops. +# - label: "PyTorch Fullgraph Test" # 18min +# source_file_dependencies: +# - vllm/ +# - tests/compile +# commands: +# - pytest -v -s compile/test_full_graph.py - label: Kernels Test %N # 1h each mirror_hardwares: [amd] @@ -394,7 +398,7 @@ steps: - tests/distributed/ - vllm/compilation commands: - - pytest -v -s ./compile/test_full_graph_multi_gpu.py + - pytest -v -s ./compile/test_basic_correctness.py - pytest -v -s ./compile/test_wrapper.py - VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py | grep -q 'Same node test passed' - TARGET_TEST_SUITE=L4 VLLM_ALLOW_DEPRECATED_BLOCK_MANAGER_V1=1 pytest basic_correctness/ -v -s -m distributed_2_gpus diff --git a/tests/compile/test_basic_correctness.py b/tests/compile/test_basic_correctness.py new file mode 100644 index 0000000000000..b6ec7413978f4 --- /dev/null +++ b/tests/compile/test_basic_correctness.py @@ -0,0 +1,48 @@ +from typing import Dict, List, Optional + +import pytest + +from vllm.compilation.levels import CompilationLevel +from vllm.utils import cuda_device_count_stateless + +from ..utils import compare_all_settings + + +# we cannot afford testing the full Catesian product +# of all models and all levels +@pytest.mark.parametrize( + "model, model_args, pp_size, tp_size, attn_backend, method, fullgraph", + [ + ("meta-llama/Meta-Llama-3-8B", [], 2, 2, "FLASH_ATTN", "generate", + True), + ("nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Dyn-Per-Token-2048-Samples", + ["--quantization", "compressed-tensors" + ], 1, 1, "FLASH_ATTN", "generate", True), + ("google/gemma-2-2b-it", [], 1, 2, "FLASHINFER", "generate", True), + # TODO: add multi-modality test for llava + ("llava-hf/llava-1.5-7b-hf", [], 2, 1, "FLASHINFER", "generate", False) + ]) +def test_compile_correctness(model, model_args, pp_size, tp_size, attn_backend, + method, fullgraph): + # this test is run under multiple suits, with different GPUs. + # make sure we only run the test with correct CUDA devices. + # don't use "<", as it will duplicate the tests. + if cuda_device_count_stateless() != pp_size * tp_size: + pytest.skip("Not correct CUDA devices for the test.") + import os + os.environ["VLLM_ATTENTION_BACKEND"] = attn_backend + if not fullgraph: + os.environ["VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE"] = "0" + all_args = [["--enforce-eager"] + model_args + ["--max_model_len", "1024"] + + ["-pp", str(pp_size)] + ["-tp", str(tp_size)]] * 3 + # don't test VLLM_TORCH_COMPILE_LEVEL == 3 case + # inductor will change the output, so we cannot compare them. + all_envs: List[Optional[Dict[str, str]]] = [{ + "VLLM_TORCH_COMPILE_LEVEL": + str(level) + } for level in [ + CompilationLevel.NO_COMPILATION, + CompilationLevel.DYNAMO_AS_IS, + CompilationLevel.DYNAMO_ONCE, + ]] + compare_all_settings(model, all_args, all_envs, method=method) diff --git a/tests/compile/test_full_graph.py b/tests/compile/test_full_graph.py index 5dd65ad7236f9..f28f9145bb442 100644 --- a/tests/compile/test_full_graph.py +++ b/tests/compile/test_full_graph.py @@ -1,13 +1,20 @@ import pytest -from vllm.compilation.backends import vllm_backend +from vllm.compilation.levels import CompilationLevel +from ..utils import fork_new_process_for_each_test from .utils import TEST_MODELS, check_full_graph_support @pytest.mark.parametrize("model_info", TEST_MODELS) -@pytest.mark.parametrize("backend", ["eager", vllm_backend]) -def test_full_graph(model_info, backend): +@pytest.mark.parametrize( + "optimization_level", + [CompilationLevel.DYNAMO_ONCE, CompilationLevel.INDUCTOR]) +@fork_new_process_for_each_test +def test_full_graph(model_info, optimization_level): model = model_info[0] model_kwargs = model_info[1] - check_full_graph_support(model, model_kwargs, backend, tp_size=1) + check_full_graph_support(model, + model_kwargs, + optimization_level, + tp_size=1) diff --git a/tests/compile/test_full_graph_multi_gpu.py b/tests/compile/test_full_graph_multi_gpu.py deleted file mode 100644 index e9883d5254e72..0000000000000 --- a/tests/compile/test_full_graph_multi_gpu.py +++ /dev/null @@ -1,22 +0,0 @@ -import pytest - -from vllm.compilation.backends import vllm_backend -from vllm.utils import cuda_device_count_stateless - -from ..utils import fork_new_process_for_each_test -from .utils import TEST_MODELS_SMOKE, check_full_graph_support - - -@pytest.mark.parametrize("model_info", TEST_MODELS_SMOKE) -@pytest.mark.parametrize("tp_size", [2]) -@pytest.mark.parametrize("backend", ["eager", vllm_backend]) -@fork_new_process_for_each_test -def test_full_graph_multi_gpu(model_info, tp_size, backend): - model = model_info[0] - model_kwargs = model_info[1] - - # Skip the test if there are not enough CUDA devices. - if cuda_device_count_stateless() < tp_size: - pytest.skip("Not enough CUDA devices for the test.") - - check_full_graph_support(model, model_kwargs, backend, tp_size=tp_size) diff --git a/tests/compile/test_full_graph_smoke.py b/tests/compile/test_full_graph_smoke.py deleted file mode 100644 index 0c5a95b4ead4c..0000000000000 --- a/tests/compile/test_full_graph_smoke.py +++ /dev/null @@ -1,13 +0,0 @@ -import pytest - -from vllm.compilation.backends import vllm_backend - -from .utils import TEST_MODELS_SMOKE, check_full_graph_support - - -@pytest.mark.parametrize("model_info", TEST_MODELS_SMOKE) -@pytest.mark.parametrize("backend", ["eager", vllm_backend]) -def test_full_graph(model_info, backend): - model = model_info[0] - model_kwargs = model_info[1] - check_full_graph_support(model, model_kwargs, backend, tp_size=1) diff --git a/tests/compile/utils.py b/tests/compile/utils.py index 2d06a0946d911..5386eb0e3795d 100644 --- a/tests/compile/utils.py +++ b/tests/compile/utils.py @@ -4,16 +4,9 @@ from tests.quantization.utils import is_quant_method_supported from vllm import LLM, SamplingParams -from vllm.plugins import set_torch_compile_backend +from vllm.compilation.levels import CompilationLevel from vllm.utils import is_hip -TEST_MODELS_SMOKE = [ - ("nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Dyn-Per-Token-2048-Samples", { - "quantization": "compressed-tensors" - }), - ("meta-llama/Meta-Llama-3-8B", {}), -] - TEST_MODELS = [ ("facebook/opt-125m", {}), ("nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change", { @@ -68,20 +61,21 @@ })) -def check_full_graph_support(model, model_kwargs, backend, tp_size=1): +def check_full_graph_support(model, + model_kwargs, + optimization_level, + tp_size=1): # make sure these models can be captured in full graph mode - if "VLLM_TEST_DYNAMO_GRAPH_CAPTURE" not in os.environ: - os.environ["VLLM_TEST_DYNAMO_GRAPH_CAPTURE"] = "1" - os.environ["VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE"] = "1" + os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(optimization_level) + os.environ["VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE"] = "1" # Inductor doesn't support fp8/gptq_marlin_24 yet. quantization = model_kwargs.get("quantization") if (quantization == "fp8" or quantization == "gptq_marlin" - or quantization == "gptq_marlin_24") and backend != "eager": + or quantization == "gptq_marlin_24" + ) and optimization_level >= CompilationLevel.INDUCTOR: return - set_torch_compile_backend(backend) - prompts = [ "Hello, my name is", "The president of the United States is", diff --git a/tests/tpu/test_compilation.py b/tests/tpu/test_compilation.py index d8df86b2aaa14..86d9af88e49ea 100644 --- a/tests/tpu/test_compilation.py +++ b/tests/tpu/test_compilation.py @@ -5,9 +5,11 @@ import depyf +from vllm.compilation.levels import CompilationLevel + # disable custom dispatcher, let Dynamo takes over # all the control -os.environ['VLLM_DYNAMO_USE_CUSTOM_DISPATCHER'] = "0" +os.environ['VLLM_TORCH_COMPILE_LEVEL'] = str(CompilationLevel.DYNAMO_AS_IS) temp_dir = tempfile.mkdtemp() with depyf.prepare_debug(temp_dir): diff --git a/tests/tpu/test_custom_dispatcher.py b/tests/tpu/test_custom_dispatcher.py index 69ab67abdd12b..923d0f1680802 100644 --- a/tests/tpu/test_custom_dispatcher.py +++ b/tests/tpu/test_custom_dispatcher.py @@ -1,5 +1,7 @@ import os +from vllm.compilation.levels import CompilationLevel + from ..utils import compare_two_settings # --enforce-eager on TPU causes graph compilation @@ -9,8 +11,9 @@ def test_custom_dispatcher(): - compare_two_settings("google/gemma-2b", - arg1=["--enforce-eager"], - arg2=["--enforce-eager"], - env1={"VLLM_DYNAMO_USE_CUSTOM_DISPATCHER": "0"}, - env2={}) + compare_two_settings( + "google/gemma-2b", + arg1=["--enforce-eager"], + arg2=["--enforce-eager"], + env1={"VLLM_TORCH_COMPILE_LEVEL": str(CompilationLevel.DYNAMO_ONCE)}, + env2={"VLLM_TORCH_COMPILE_LEVEL": str(CompilationLevel.DYNAMO_AS_IS)}) diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index de0b1d8a75757..4780358cea517 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -1,8 +1,17 @@ +import copy import operator +from typing import Callable, Dict, List, Optional, Tuple, Union import torch import torch.fx as fx +from vllm.logger import init_logger + +from .compile_context import get_compile_context +from .levels import CompilationLevel + +logger = init_logger(__name__) + def fix_functionalization(graph: fx.Graph): """ @@ -148,9 +157,113 @@ def fix_functionalization(graph: fx.Graph): # print(graph.python_code(root_module="self", verbose=True).src, file=f) -def vllm_backend(graph, example_inputs): +def wrap_inductor(graph, example_inputs, additional_inductor_config): from torch._inductor import config current_config = config.shallow_copy_dict() from torch._inductor.compile_fx import compile_fx + + if additional_inductor_config is not None: + current_config.update(additional_inductor_config) + if current_config['post_grad_custom_post_pass'] is not None: + logger.warning( + "post_grad_custom_post_pass is already set in the config. " + "Overwriting it with the fix_functionalization") current_config['post_grad_custom_post_pass'] = fix_functionalization return compile_fx(graph, example_inputs, config_patches=current_config) + + +def vllm_backend( + graph, + example_inputs, + additional_inductor_config: Optional[Dict] = None) -> Callable: + + context = get_compile_context() + context = copy.deepcopy(context) if context is not None else [] + sizes_to_specialize: List[int] = context + + # flags for all the seen shapes, whether we need to specialize + runtime_shapes_to_compile_flags: Dict[Tuple[int, ...], bool] = {} + + # if we need to specialize, the compiled graph for that shape + runtime_shapes_to_compiled_graph: Dict[Tuple[int, ...], Callable] = {} + + # this is the first compilation, we will compile a graph with + # dynamic shape, as the caller will mark first dimension as dynamic + logger.info("Compiling a graph for general shapes") + graph_for_symbolic_shape = wrap_inductor(graph, example_inputs, + additional_inductor_config) + + # TODO: Dynamo does not pass all dynamic shapes. + # Need to investigate why. It works now because all the dynamic + # shapes have the same value, and either of them can be used. + sym_shape_indices = [ + i for i, x in enumerate(example_inputs) if isinstance(x, torch.SymInt) + ] + + first_run = True + + # this is the function we return to Dynamo to run finally + def compiled_graph_wrapper(*args): + + runtime_shapes: Tuple[int, + ...] = tuple(args[i] for i in sym_shape_indices) + + nonlocal first_run + nonlocal runtime_shapes_to_compile_flags + nonlocal runtime_shapes_to_compiled_graph + + if first_run: + # the first compilation is for profiling, we directly run it + first_run = False + return graph_for_symbolic_shape(*args) + + if runtime_shapes not in runtime_shapes_to_compile_flags: + # we haven't seen this shape before + # query if we need to specialize for this shape + # we only specialize for the first dimension. + # TODO: investigate if any model needs to specialize + # beyond the first dimension + runtime_shapes_to_compile_flags[runtime_shapes] = runtime_shapes[ + 0] in sizes_to_specialize + + if not runtime_shapes_to_compile_flags[runtime_shapes]: + # we don't need to specialize for this shape + return graph_for_symbolic_shape(*args) + + if runtime_shapes not in runtime_shapes_to_compiled_graph: + # we need to specialize for this shape, and we haven't compiled + # compile the graph for this shape + logger.info("Compiling a graph for shapes %s", runtime_shapes) + runtime_shapes_to_compiled_graph[runtime_shapes] = wrap_inductor( + graph, args, additional_inductor_config) + + return runtime_shapes_to_compiled_graph[runtime_shapes](*args) + + return compiled_graph_wrapper + + +def select_default_backend(level: int) -> Union[str, Callable]: + if level in [CompilationLevel.DYNAMO_AS_IS, CompilationLevel.DYNAMO_ONCE]: + backend = "eager" + return backend + assert level in [ + CompilationLevel.INDUCTOR, CompilationLevel.INDUCTOR_MAX_AUTOTUNE + ], f"Invalid level {level}" + + from vllm.compilation.backends import vllm_backend + from vllm.plugins import get_inductor_additional_configs + additional_configs = get_inductor_additional_configs() + + if level == CompilationLevel.INDUCTOR_MAX_AUTOTUNE: + if "max_autotune" in additional_configs and not additional_configs[ + "max_autotune"]: + logger.warning( + "max_autotune is disabled, but is overridden by level %s", + CompilationLevel.INDUCTOR_MAX_AUTOTUNE) + additional_configs['max_autotune'] = True + + from functools import partial + backend = partial(vllm_backend, + additional_inductor_config=additional_configs) + + return backend diff --git a/vllm/compilation/compile_context.py b/vllm/compilation/compile_context.py new file mode 100644 index 0000000000000..29db3d4c637b9 --- /dev/null +++ b/vllm/compilation/compile_context.py @@ -0,0 +1,23 @@ +from contextlib import contextmanager +from typing import Any + +_compile_context: Any = None + + +def get_compile_context() -> Any: + """Get the current compile context.""" + return _compile_context + + +@contextmanager +def set_compile_context(context: Any): + """A context manager that stores the current compile context, + usually it is a list of sizes to specialize. + """ + global _compile_context + prev_context = _compile_context + _compile_context = context + try: + yield + finally: + _compile_context = prev_context diff --git a/vllm/compilation/decorators.py b/vllm/compilation/decorators.py new file mode 100644 index 0000000000000..b790e5550adb7 --- /dev/null +++ b/vllm/compilation/decorators.py @@ -0,0 +1,85 @@ +from typing import List, Optional, Union + +import torch + +import vllm.envs as envs +from vllm.attention import AttentionMetadata +from vllm.compilation.levels import CompilationLevel +from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher +from vllm.sequence import IntermediateTensors +from vllm.utils import supports_dynamo + + +def support_compile_llama_style(cls: type): + """ + A decorator to add support for compiling the forward method of a class. + If a module's **forward signature** is compatible with llama, this + decorator can be used to enable the compilation of the forward method. + """ + + # for CompilationLevel.DYNAMO_AS_IS , the upper level model runner + # will handle the compilation, so we don't need to do anything here. + if envs.VLLM_TORCH_COMPILE_LEVEL in [ + CompilationLevel.NO_COMPILATION, CompilationLevel.DYNAMO_AS_IS + ] or not supports_dynamo(): + return cls + + # take care of method resolution order + # make sure super().__init__ is called on the base class + # other than TorchCompileWrapperWithCustomDispatcher + cls.__bases__ = cls.__bases__ + (TorchCompileWrapperWithCustomDispatcher, ) + + old_init = cls.__init__ + + def __init__(self, *args, **kwargs): + old_init(self, *args, **kwargs) + TorchCompileWrapperWithCustomDispatcher.__init__(self) + + cls.__init__ = __init__ + + def __call__( + self, + input_ids: Optional[torch.Tensor], + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + # torch.compiler.is_compiling() means we are inside the compilation + # e.g. TPU has the compilation logic in model runner, so we don't + # need to compile the model inside. + if torch.compiler.is_compiling(): + return self.forward(input_ids, positions, kv_caches, attn_metadata, + intermediate_tensors, inputs_embeds) + + # the first compilation needs to have dynamic shapes marked + if len(self.compiled_codes) < 1: + if input_ids is not None: + torch._dynamo.mark_dynamic(input_ids, 0) + torch._dynamo.mark_dynamic(positions, 0) + if inputs_embeds is not None: + torch._dynamo.mark_dynamic(inputs_embeds, 0) + if intermediate_tensors is not None: + for tensors in intermediate_tensors.tensors.values(): + torch._dynamo.mark_dynamic(tensors, 0) + + # if we don't use custom dispatcher, we can directly call the + # compiled function and let torch.compile handle the dispatching, + # with the overhead of guard evaluation and recompilation. + if len(self.compiled_codes) < 1 or not self.use_custom_dispatcher: + return self.compiled_callable(input_ids, positions, kv_caches, + attn_metadata, intermediate_tensors, + inputs_embeds) + + # usually, capturing the model once is enough, and then we can + # dispatch to the compiled code directly, without going through + # the Dynamo guard mechanism. + with self.dispatch_to_code(0): + model_output = self.forward(input_ids, positions, kv_caches, + attn_metadata, intermediate_tensors, + inputs_embeds) + return model_output + + cls.__call__ = __call__ + return cls diff --git a/vllm/compilation/levels.py b/vllm/compilation/levels.py new file mode 100644 index 0000000000000..162bf5ae64997 --- /dev/null +++ b/vllm/compilation/levels.py @@ -0,0 +1,9 @@ +# constants for the levels of the compilation process + + +class CompilationLevel: + NO_COMPILATION = 0 + DYNAMO_AS_IS = 1 + DYNAMO_ONCE = 2 + INDUCTOR = 3 + INDUCTOR_MAX_AUTOTUNE = 4 diff --git a/vllm/compilation/wrapper.py b/vllm/compilation/wrapper.py index e923bd36ccc08..1594b64a61b94 100644 --- a/vllm/compilation/wrapper.py +++ b/vllm/compilation/wrapper.py @@ -3,12 +3,14 @@ from abc import abstractmethod from contextlib import contextmanager from types import CodeType -from typing import Callable, List +from typing import Callable, List, Optional import torch import vllm.envs as envs +from .levels import CompilationLevel + class TorchCompileWrapperWithCustomDispatcher: """ @@ -23,7 +25,26 @@ class TorchCompileWrapperWithCustomDispatcher: `torch.compile` over the forward method. """ - def __init__(self, compiled_callable: Callable): + def __init__(self, compiled_callable: Optional[Callable] = None): + + if compiled_callable is None: + # default compilation settings + # compiling the forward method + + # choose the compile backend + + # if the user has set the backend, use it + from vllm.plugins import get_torch_compile_backend + backend = get_torch_compile_backend() + if backend is None: + from vllm.compilation.backends import select_default_backend + backend = select_default_backend(envs.VLLM_TORCH_COMPILE_LEVEL) + + compiled_callable = torch.compile( + self.forward, + fullgraph=envs.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE, + backend=backend) + self.compiled_callable = compiled_callable self.original_code_object = self.__class__.forward.__code__ self.compiled_codes: List[CodeType] = [] @@ -33,7 +54,7 @@ def __init__(self, compiled_callable: Callable): # subclasses can use this to switch between the custom dispatcher # and the default Dynamo guard mechanism. self.use_custom_dispatcher: bool = \ - envs.VLLM_DYNAMO_USE_CUSTOM_DISPATCHER + envs.VLLM_TORCH_COMPILE_LEVEL >= CompilationLevel.DYNAMO_ONCE def __call__(self, *args, **kwargs): """Implement the dispatch logic here, beyond the torch.compile level. diff --git a/vllm/envs.py b/vllm/envs.py index 97767bf5b5ad9..8b541e5b78c01 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -65,6 +65,7 @@ VLLM_ALLOW_RUNTIME_LORA_UPDATING: bool = False VLLM_SKIP_P2P_CHECK: bool = False VLLM_ALLOW_DEPRECATED_BLOCK_MANAGER_V1: bool = False + VLLM_TORCH_COMPILE_LEVEL: int = 0 def get_default_cache_root(): @@ -198,23 +199,12 @@ def get_default_config_root(): lambda: (os.environ.get("VLLM_USE_TRITON_FLASH_ATTN", "True").lower() in ("true", "1")), - # Internal flag to enable Dynamo graph capture - "VLLM_TEST_DYNAMO_GRAPH_CAPTURE": - lambda: int(os.environ.get("VLLM_TEST_DYNAMO_GRAPH_CAPTURE", "0")), - "VLLM_DYNAMO_USE_CUSTOM_DISPATCHER": - lambda: - (os.environ.get("VLLM_DYNAMO_USE_CUSTOM_DISPATCHER", "True").lower() in - ("true", "1")), - - # Internal flag to control whether we use custom op, - # or use the native pytorch implementation - "VLLM_TEST_COMPILE_NO_CUSTOM_OPS": - lambda: int(os.environ.get("VLLM_TEST_COMPILE_NO_CUSTOM_OPS", "0")), - # Internal flag to enable Dynamo fullgraph capture "VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE": lambda: bool( os.environ.get("VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE", "1") != "0"), + "VLLM_TORCH_COMPILE_LEVEL": + lambda: int(os.environ.get("VLLM_TORCH_COMPILE_LEVEL", "0")), # local rank of the process in the distributed setting, used to determine # the GPU device id diff --git a/vllm/model_executor/custom_op.py b/vllm/model_executor/custom_op.py index 9102b5e19ebec..d0e90245ad010 100644 --- a/vllm/model_executor/custom_op.py +++ b/vllm/model_executor/custom_op.py @@ -1,6 +1,7 @@ import torch.nn as nn import vllm.envs as envs +from vllm.compilation.levels import CompilationLevel from vllm.platforms import current_platform from vllm.utils import is_cpu, is_hip, is_xpu @@ -55,7 +56,7 @@ def dispatch_forward(self): # NOTE(woosuk): Here we assume that vLLM was built for only one # specific backend. Currently, we do not support dynamic dispatching. - if envs.VLLM_TEST_COMPILE_NO_CUSTOM_OPS: + if envs.VLLM_TORCH_COMPILE_LEVEL >= CompilationLevel.INDUCTOR: return self.forward_native if is_hip(): diff --git a/vllm/model_executor/models/gemma2.py b/vllm/model_executor/models/gemma2.py index c442b6d2e7c96..edc71435b551f 100644 --- a/vllm/model_executor/models/gemma2.py +++ b/vllm/model_executor/models/gemma2.py @@ -21,6 +21,7 @@ from transformers import Gemma2Config from vllm.attention import Attention, AttentionMetadata +from vllm.compilation.decorators import support_compile_llama_style from vllm.config import CacheConfig, LoRAConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.logger import init_logger @@ -238,6 +239,7 @@ def forward( return hidden_states, residual +@support_compile_llama_style class Gemma2Model(nn.Module): def __init__( diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 2a79a9edf2111..3f17e9004c30f 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -28,6 +28,7 @@ from transformers import LlamaConfig from vllm.attention import Attention, AttentionMetadata +from vllm.compilation.decorators import support_compile_llama_style from vllm.config import CacheConfig, LoRAConfig from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) @@ -265,6 +266,7 @@ def forward( return hidden_states, residual +@support_compile_llama_style class LlamaModel(nn.Module): def __init__( diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py index a3acb93dc3c11..864b9ff66a84e 100644 --- a/vllm/model_executor/models/llava.py +++ b/vllm/model_executor/models/llava.py @@ -365,6 +365,8 @@ def forward( input_ids = None inputs_embeds = None else: + # always pass the input via `inputs_embeds` + # to make sure the computation graph is consistent image_input = self._parse_and_validate_image_input(**kwargs) if image_input is not None: @@ -375,10 +377,10 @@ def forward( inputs_embeds = merge_multimodal_embeddings( input_ids, inputs_embeds, vision_embeddings, self.config.image_token_index) - - input_ids = None else: - inputs_embeds = None + inputs_embeds = self.language_model.model.get_input_embeddings( + input_ids) + input_ids = None hidden_states = self.language_model.model(input_ids, positions, diff --git a/vllm/platforms/tpu.py b/vllm/platforms/tpu.py index a35777f91cac9..8ba973b28263f 100644 --- a/vllm/platforms/tpu.py +++ b/vllm/platforms/tpu.py @@ -1,7 +1,21 @@ +import os + import torch +import vllm.envs as envs +from vllm.compilation.levels import CompilationLevel +from vllm.plugins import set_torch_compile_backend + from .interface import Platform, PlatformEnum +if "VLLM_TORCH_COMPILE_LEVEL" not in os.environ: + os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(CompilationLevel.DYNAMO_ONCE) + +assert envs.VLLM_TORCH_COMPILE_LEVEL < CompilationLevel.INDUCTOR,\ + "TPU does not support Inductor." + +set_torch_compile_backend("openxla") + class TpuPlatform(Platform): _enum = PlatformEnum.TPU diff --git a/vllm/plugins/__init__.py b/vllm/plugins/__init__.py index 7939688ef0da3..211fedbc6e2ec 100644 --- a/vllm/plugins/__init__.py +++ b/vllm/plugins/__init__.py @@ -1,5 +1,5 @@ import logging -from typing import Callable, Optional, Union +from typing import Callable, Dict, Optional, Union import vllm.envs as envs @@ -42,3 +42,15 @@ def set_torch_compile_backend(backend: Union[Callable, str]): def get_torch_compile_backend() -> Optional[Union[Callable, str]]: return _torch_compile_backend + + +_inductor_additional_configs: Dict = {} + + +def set_inductor_additional_configs(configs: Dict): + global _inductor_additional_configs + _inductor_additional_configs = configs + + +def get_inductor_additional_configs() -> Dict: + return _inductor_additional_configs diff --git a/vllm/sequence.py b/vllm/sequence.py index 0c27ffca36cfd..51be9466e66be 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -1137,10 +1137,9 @@ def __eq__(self, other: object) -> bool: return self.embeddings == other.embeddings -class IntermediateTensors( - msgspec.Struct, - omit_defaults=True, # type: ignore[call-arg] - array_like=True): # type: ignore[call-arg] +# cannot use msgspec.Struct here because Dynamo does not support it +@dataclass +class IntermediateTensors: """For all pipeline stages except the last, we need to return the hidden states and residuals to be sent to the next stage. This data structure contains the hidden states and residuals for a request. diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 0bd2958816718..5bc7100732291 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -18,6 +18,8 @@ from vllm.attention import AttentionMetadata, get_attn_backend from vllm.attention.backends.abstract import AttentionState from vllm.attention.backends.utils import CommonAttentionState +from vllm.compilation.compile_context import set_compile_context +from vllm.compilation.levels import CompilationLevel from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, ObservabilityConfig, ParallelConfig, PromptAdapterConfig, SchedulerConfig) @@ -1126,10 +1128,10 @@ def load_model(self) -> None: "provided. Defaulting to scaling factors of 1.0. " "This may lead to less accurate results!") - if envs.VLLM_TEST_DYNAMO_GRAPH_CAPTURE and supports_dynamo(): - from vllm.compilation.backends import vllm_backend + if envs.VLLM_TORCH_COMPILE_LEVEL == CompilationLevel.DYNAMO_AS_IS \ + and supports_dynamo(): from vllm.plugins import get_torch_compile_backend - backend = get_torch_compile_backend() or vllm_backend + backend = get_torch_compile_backend() or "eager" self.model = torch.compile( self.model, fullgraph=envs.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE, @@ -1289,7 +1291,15 @@ def profile_run(self) -> None: batch_size=batch_size, dtype=self.model_config.dtype, device=self.device) - self.execute_model(model_input, kv_caches, intermediate_tensors) + + graph_batch_size = self.max_batchsize_to_capture + batch_size_capture_list = [ + bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= graph_batch_size + ] + if self.model_config.enforce_eager: + batch_size_capture_list = [] + with set_compile_context(batch_size_capture_list): + self.execute_model(model_input, kv_caches, intermediate_tensors) torch.cuda.synchronize() return