[V1] Enable Inductor when using piecewise CUDA graphs (vllm-project#10268)

Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
WoosukKwon authored Nov 12, 2024
1 parent 8a06428 commit 1f55e05
Showing 1 changed file with 7 additions and 4 deletions.
11 changes: 7 additions & 4 deletions vllm/v1/worker/gpu_model_runner.py
@@ -404,14 +404,17 @@ def execute_model(
 
     def load_model(self) -> None:
         if self.use_cuda_graph:
-            # FIXME(woosuk): Currently, we do not use inductor to reduce the
-            # compilation time and any potential issues with the inductor.
-            os.environ["VLLM_CUSTOM_OPS"] = "all"
+            # NOTE(woosuk): Currently, we use inductor because the piecewise
+            # CUDA graphs do not work properly with the custom CUDA kernels.
+            # FIXME(woosuk): Disable inductor to reduce the compilation time
+            # and avoid any potential issues with the inductor.
+            os.environ["VLLM_CUSTOM_OPS"] = "none"
             set_compilation_config(
                 CompilationConfig(
                     use_cudagraph=True,
                     non_cudagraph_ops=["vllm.unified_v1_flash_attention"],
-                    use_inductor=False,
+                    use_inductor=True,
                     enable_fusion=False,
                 ))
 
         logger.info("Starting to load model %s...", self.model_config.model)
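The effect of the change can be summarized in a short sketch. The helper below is hypothetical (it is not vLLM's API) and only illustrates the relationship the commit encodes: when piecewise CUDA graphs are in use, all custom CUDA ops are disabled via the `VLLM_CUSTOM_OPS` environment variable and Inductor takes over code generation, because the custom kernels do not work properly under piecewise graph capture.

```python
import os


def select_compilation_mode(use_cuda_graph: bool) -> dict:
    """Hypothetical helper illustrating the commit's logic.

    Returns a plain dict standing in for the fields passed to
    CompilationConfig in the real code.
    """
    if use_cuda_graph:
        # Custom CUDA kernels conflict with piecewise CUDA graph capture,
        # so disable them all and let Inductor compile the model instead.
        os.environ["VLLM_CUSTOM_OPS"] = "none"
        return {
            "use_cudagraph": True,
            "non_cudagraph_ops": ["vllm.unified_v1_flash_attention"],
            "use_inductor": True,
            "enable_fusion": False,
        }
    # Without CUDA graphs, custom ops can stay enabled (the pre-commit
    # behavior) and Inductor is not required.
    os.environ["VLLM_CUSTOM_OPS"] = "all"
    return {"use_cudagraph": False, "use_inductor": False}
```

Note that `use_inductor` and `VLLM_CUSTOM_OPS` move together: enabling Inductor only works here because the custom ops are simultaneously turned off.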
