diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index daec46760117d..d5d02fdeb7f4b 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -183,7 +183,16 @@ steps: - vllm/ - tests/v1 commands: - - VLLM_USE_V1=1 pytest -v -s v1 + # split the test to avoid interference + - VLLM_USE_V1=1 pytest -v -s v1/core + - VLLM_USE_V1=1 pytest -v -s v1/engine + - VLLM_USE_V1=1 pytest -v -s v1/sample + - VLLM_USE_V1=1 pytest -v -s v1/worker + - VLLM_USE_V1=1 pytest -v -s v1/test_stats.py + - VLLM_USE_V1=1 pytest -v -s v1/test_utils.py + # TODO: accuracy does not match, whether setting + # VLLM_USE_FLASHINFER_SAMPLER or not on H100. + - VLLM_USE_V1=1 pytest -v -s v1/e2e - label: Examples Test # 25min working_dir: "/vllm-workspace/examples" diff --git a/Dockerfile b/Dockerfile index cb9cf0da5be65..0b9f74e08dc68 100644 --- a/Dockerfile +++ b/Dockerfile @@ -149,7 +149,8 @@ RUN --mount=type=cache,target=/root/.cache/pip \ #################### vLLM installation IMAGE #################### # image with vLLM installed -FROM nvidia/cuda:${CUDA_VERSION}-base-ubuntu22.04 AS vllm-base +# TODO: Restore to base image after FlashInfer AOT wheel fixed +FROM nvidia/cuda:${CUDA_VERSION}-devel-ubuntu22.04 AS vllm-base ARG CUDA_VERSION=12.4.1 ARG PYTHON_VERSION=3.12 WORKDIR /vllm-workspace @@ -194,12 +195,30 @@ RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/dist --mount=type=cache,target=/root/.cache/pip \ python3 -m pip install dist/*.whl --verbose +# How to build this FlashInfer wheel: +# $ export FLASHINFER_ENABLE_AOT=1 +# $ # Note we remove 7.0 from the arch list compared to the list below, since FlashInfer only supports sm75+ +# $ export TORCH_CUDA_ARCH_LIST='7.5 8.0 8.6 8.9 9.0+PTX' +# $ git clone https://github.com/flashinfer-ai/flashinfer.git --recursive +# $ cd flashinfer +# $ git checkout 524304395bd1d8cd7d07db083859523fcaa246a4 +# $ python3 setup.py bdist_wheel --dist-dir=dist --verbose + RUN --mount=type=cache,target=/root/.cache/pip \ . /etc/environment && \ if [ "$TARGETPLATFORM" != "linux/arm64" ]; then \ - python3 -m pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.1.6/flashinfer-0.1.6+cu121torch2.4-cp${PYTHON_VERSION_STR}-cp${PYTHON_VERSION_STR}-linux_x86_64.whl; \ + python3 -m pip install https://wheels.vllm.ai/flashinfer/524304395bd1d8cd7d07db083859523fcaa246a4/flashinfer_python-0.2.0.post1-cp${PYTHON_VERSION_STR}-cp${PYTHON_VERSION_STR}-linux_x86_64.whl; \ fi COPY examples examples + +# Although we build Flashinfer with AOT mode, there's still +# some issues w.r.t. JIT compilation. Therefore we need to +# install build dependencies for JIT compilation. +# TODO: Remove this once FlashInfer AOT wheel is fixed +COPY requirements-build.txt requirements-build.txt +RUN --mount=type=cache,target=/root/.cache/pip \ + python3 -m pip install -r requirements-build.txt + #################### vLLM installation IMAGE #################### #################### TEST IMAGE #################### diff --git a/tests/basic_correctness/test_basic_correctness.py b/tests/basic_correctness/test_basic_correctness.py index 31a101e48e026..23285040642a8 100644 --- a/tests/basic_correctness/test_basic_correctness.py +++ b/tests/basic_correctness/test_basic_correctness.py @@ -61,9 +61,10 @@ def test_models( if backend == "FLASHINFER" and current_platform.is_rocm(): pytest.skip("Flashinfer does not support ROCm/HIP.") - if backend == "XFORMERS" and model == "google/gemma-2-2b-it": + if backend in ("XFORMERS", + "FLASHINFER") and model == "google/gemma-2-2b-it": pytest.skip( - "XFORMERS does not support gemma2 with full context length.") + f"{backend} does not support gemma2 with full context length.") os.environ["VLLM_ATTENTION_BACKEND"] = backend diff --git a/tests/compile/test_basic_correctness.py b/tests/compile/test_basic_correctness.py index 87d5aefea6cb4..1945479fc3031 100644 --- a/tests/compile/test_basic_correctness.py +++ b/tests/compile/test_basic_correctness.py @@ -58,7 +58,7 @@ class TestSetting: model_args=["--task", "embed"], pp_size=1, tp_size=1, - attn_backend="FLASHINFER", + attn_backend="FLASH_ATTN", method="encode", fullgraph=True, ), diff --git a/tests/kernels/test_flashinfer.py b/tests/kernels/test_flashinfer.py index a2c8f71665737..1645ef911d697 100644 --- a/tests/kernels/test_flashinfer.py +++ b/tests/kernels/test_flashinfer.py @@ -133,17 +133,19 @@ def test_flashinfer_decode_with_paged_kv( use_tensor_cores=( (num_query_heads//num_kv_heads) > 4) ) - wrapper.begin_forward(kv_indptr, - kv_indices, - kv_last_page_lens, - num_query_heads, - num_kv_heads, - head_size, - block_size, - "NONE", - data_type=dtype) - - output = wrapper.forward(query, key_value_cache, logits_soft_cap=soft_cap) + wrapper.plan(kv_indptr, + kv_indices, + kv_last_page_lens, + num_query_heads, + num_kv_heads, + head_size, + block_size, + "NONE", + q_data_type=dtype, + kv_data_type=dtype, + logits_soft_cap=soft_cap) + + output = wrapper.run(query, key_value_cache) ref_output = ref_paged_attn(query=query, key_cache=key_cache, @@ -228,7 +230,7 @@ def test_flashinfer_prefill_with_paged_kv(seq_lens: List[Tuple[int, int]], workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8) wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper( workspace_buffer, "NHD") - wrapper.begin_forward( + wrapper.plan( qo_indptr, kv_indptr, kv_indices, @@ -237,12 +239,14 @@ def test_flashinfer_prefill_with_paged_kv(seq_lens: List[Tuple[int, int]], num_kv_heads, head_size, block_size, + q_data_type=dtype, + kv_data_type=dtype, + logits_soft_cap=soft_cap, ) - output = wrapper.forward( + output = wrapper.run( query, key_value_cache, - logits_soft_cap=soft_cap, ) ref_output = ref_paged_attn(query=query, @@ -253,7 +257,7 @@ def test_flashinfer_prefill_with_paged_kv(seq_lens: List[Tuple[int, int]], block_tables=block_tables, scale=scale, soft_cap=soft_cap) - torch.testing.assert_close(output, ref_output, atol=1e-2, rtol=1e-2), \ + torch.testing.assert_close(output, ref_output, atol=5e-2, rtol=1e-2), \ f"{torch.max(torch.abs(output - ref_output))}" @@ -332,7 +336,7 @@ def test_flashinfer_prefill_with_paged_fp8_kv( workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8) wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper( workspace_buffer, "NHD") - wrapper.begin_forward( + wrapper.plan( qo_indptr, kv_indptr, kv_indices, @@ -341,13 +345,12 @@ def test_flashinfer_prefill_with_paged_fp8_kv( num_kv_heads, head_size, block_size, + q_data_type=dtype, + kv_data_type=kv_cache_dtype, + logits_soft_cap=soft_cap, ) - output = wrapper.forward(query, - kv_cache_fp8, - logits_soft_cap=soft_cap, - k_scale=k_scale, - v_scale=v_scale) + output = wrapper.run(query, kv_cache_fp8, k_scale=k_scale, v_scale=v_scale) ref_output = ref_paged_attn(query=query, key_cache=key_cache.squeeze(1), @@ -360,7 +363,7 @@ def test_flashinfer_prefill_with_paged_fp8_kv( del query del block_tables # verify prefill fp8 - torch.testing.assert_close(output, ref_output, atol=1e-2, rtol=1e-2), \ + torch.testing.assert_close(output, ref_output, atol=5e-2, rtol=1e-2), \ f"{torch.max(torch.abs(output - ref_output))}" @@ -439,21 +442,18 @@ def test_flashinfer_decode_with_paged_fp8_kv( wrapper = flashinfer.\ BatchDecodeWithPagedKVCacheWrapper(workspace_buffer, "NHD", use_tensor_cores=use_tensor_cores) - wrapper.begin_forward(kv_indptr, - kv_indices, - kv_last_page_lens, - num_query_heads, - num_kv_heads, - head_size, - block_size, - "NONE", - data_type=dtype, - q_data_type=dtype) - output = wrapper.forward(query, - kv_cache_fp8, - logits_soft_cap=soft_cap, - k_scale=k_scale, - v_scale=v_scale) + wrapper.plan(kv_indptr, + kv_indices, + kv_last_page_lens, + num_query_heads, + num_kv_heads, + head_size, + block_size, + "NONE", + q_data_type=dtype, + kv_data_type=kv_cache_dtype, + logits_soft_cap=soft_cap) + output = wrapper.run(query, kv_cache_fp8, k_scale=k_scale, v_scale=v_scale) key_cache = key_value_cache[:, 0, :, :, :].squeeze(1) value_cache = key_value_cache[:, 1, :, :, :].squeeze(1) diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index 3135b0b405343..7cccef9608218 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -1,3 +1,4 @@ +import dataclasses from collections import defaultdict from contextlib import contextmanager from dataclasses import dataclass @@ -13,9 +14,11 @@ from vllm.vllm_flash_attn import flash_attn_varlen_func FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024 except ImportError: - BatchDecodeWithPagedKVCacheWrapper = None - CUDAGraphBatchDecodeWithPagedKVCacheWrapper = None - BatchPrefillWithPagedKVCacheWrapper = None + # Avoid turning these types into variables during type checking + if not TYPE_CHECKING: + BatchDecodeWithPagedKVCacheWrapper = None + CUDAGraphBatchDecodeWithPagedKVCacheWrapper = None + BatchPrefillWithPagedKVCacheWrapper = None FLASHINFER_WORKSPACE_BUFFER_SIZE = 0 import torch @@ -30,7 +33,9 @@ from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping, compute_slot_mapping_start_idx, is_block_tables_empty) +from vllm.attention.layer import Attention from vllm.attention.ops.paged_attn import PagedAttention +from vllm.config import VllmConfig, get_current_vllm_config from vllm.utils import (async_tensor_h2d, get_kv_cache_torch_dtype, make_tensor_with_pad) @@ -99,6 +104,72 @@ def get_fp8_dtype_for_flashinfer(kv_cache_dtype: str) -> torch.dtype: raise ValueError(f"Unrecognized FP8 dtype: {kv_cache_dtype}") +@dataclass +class PerLayerParameters: + """ + Currently, FlashInfer backend only support models in which all layers share + the same values for the following hyperparameters. + """ + + window_left: int + logits_soft_cap: Optional[float] + sm_scale: float + + +def get_per_layer_parameters( + vllm_config: VllmConfig) -> Dict[str, PerLayerParameters]: + """ + Scan all attention layers and determine some hyperparameters + to use during `plan`. + """ + + layers = vllm_config.compilation_config.static_forward_context + per_layer_params: Dict[str, PerLayerParameters] = {} + + for key, layer in layers.items(): + assert isinstance(layer, Attention) + + impl = layer.impl + assert isinstance(impl, FlashInferImpl) + + # Infer hyperparameters from the attention layer + window_size = impl.sliding_window + window_left = window_size[0] if window_size is not None else -1 + logits_soft_cap = impl.logits_soft_cap + sm_scale = impl.scale + + per_layer_params[key] = PerLayerParameters(window_left, + logits_soft_cap, sm_scale) + + return per_layer_params + + +def infer_global_hyperparameters( + per_layer_params: Dict[str, PerLayerParameters]) -> PerLayerParameters: + """ + Currently, FlashInfer backend only support models in which all layers share + the same values for the following hyperparameters: + - `window_left` + - `logits_soft_cap` + - `sm_scale` + + So this function asserts that all layers share the same values for these + hyperparameters and returns the global values. + """ + + assert len(per_layer_params) > 0, "No attention layers found in the model." + + param_sets = list(per_layer_params.values()) + global_params = param_sets[0] + for params in param_sets: + assert params == global_params, ( + "FlashInfer backend currently only supports models in which all " + "layers share the same values for the following hyperparameters: " + "`window_left`, `logits_soft_cap`, `sm_scale`.") + + return global_params + + class FlashInferState(AttentionState): def __init__(self, runner): @@ -108,6 +179,11 @@ def __init__(self, runner): self._decode_wrapper = None self._prefill_wrapper = None + # Global hyperparameters shared by all attention layers + self.global_hyperparameters: Optional[PerLayerParameters] = None + + self.vllm_config = get_current_vllm_config() + def _get_workspace_buffer(self): if self._workspace_buffer is None: self._workspace_buffer = torch.empty( @@ -215,6 +291,9 @@ def graph_capture_get_metadata_for_batch( batch_size + 1, dtype=torch.int32) + global_params = infer_global_hyperparameters( + get_per_layer_parameters(self.vllm_config)) + attn_metadata = self.runner.attn_backend.make_metadata( num_prefills=0, slot_mapping=self._graph_slot_mapping[:batch_size], @@ -238,7 +317,9 @@ def graph_capture_get_metadata_for_batch( q_data_type=self.runner.model_config.dtype, use_cuda_graph=True, decode_wrapper=self._graph_decode_wrapper, - prefill_wrapper=None) + prefill_wrapper=None, + **dataclasses.asdict(global_params), + ) attn_metadata.begin_forward() return attn_metadata @@ -325,9 +406,28 @@ class FlashInferMetadata(AttentionMetadata): data_type: torch.dtype = None # The data type of the query q_data_type: torch.dtype = None - device: torch.device = torch.device("cuda") + # FlashInfer 0.2 encourages passing host tensors + device: torch.device = torch.device("cpu") is_profile_run: bool = False + # The FlashInfer backend currently supports only models in which all layers + # share the same following hyperparameters: + + # The left (inclusive) window size for the attention window, when + # set to `-1`, the window size will be set to the full length of + # the sequence. Defaults to `-1`. + window_left: int = -1 + # The attention logits soft capping value (used in Gemini, Grok and + # Gemma-2, etc.), if not provided, will be set to `0`. If greater + # than 0, the logits will be capped according to formula: + # $$\texttt{logits\_soft\_cap} \times + # \mathrm{tanh}(x / \texttt{logits\_soft\_cap})$$, + # where $x$ is the input logits. + logits_soft_cap: Optional[float] = None + # The scale used in softmax, if not provided, will be set to + # `1.0 / sqrt(head_dim)`. + sm_scale: Optional[float] = None + def __post_init__(self): # Refer to # https://github.com/flashinfer-ai/flashinfer/blob/3d55c71a62052c590c130897d3a3db49b14fcc34/include/flashinfer/utils.cuh#L157 @@ -363,14 +463,21 @@ def begin_forward(self): self.block_table_bound = self.block_table_bound.to(self.device) self.seq_lens_tensor = self.seq_lens_tensor.to(self.device) self.paged_kv_indices = self.paged_kv_indices.to(self.device) - self.prefill_wrapper.end_forward() - self.prefill_wrapper.begin_forward( + self.prefill_wrapper.plan( self.query_start_loc, self.paged_kv_indptr[:self.num_prefills + 1], self.paged_kv_indices, self.paged_kv_last_page_len[:self.num_prefills], - self.num_qo_heads, self.num_kv_heads, self.head_dim, - self.page_size) + self.num_qo_heads, + self.num_kv_heads, + self.head_dim, + self.page_size, + causal=True, + sm_scale=self.sm_scale, + window_left=self.window_left, + logits_soft_cap=self.logits_soft_cap, + q_data_type=self.q_data_type, + kv_data_type=self.data_type) if self.num_decode_tokens > 0: assert self.paged_kv_indices is not None assert self.paged_kv_indptr is not None @@ -386,8 +493,7 @@ def begin_forward(self): self.seq_lens_tensor = self.seq_lens_tensor.to(self.device) assert self.decode_wrapper is not None - self.decode_wrapper.end_forward() - self.decode_wrapper.begin_forward( + self.decode_wrapper.plan( self.paged_kv_indptr[self.num_prefills:], self.paged_kv_indices, self.paged_kv_last_page_len[self.num_prefills:], @@ -397,8 +503,11 @@ def begin_forward(self): self.page_size, # Disable flashinfer's pos encoding and use vllm's rope. pos_encoding_mode="NONE", + window_left=self.window_left, + logits_soft_cap=self.logits_soft_cap, + sm_scale=self.sm_scale, # kv-cache data type. - data_type=self.data_type, + kv_data_type=self.data_type, # query data type. q_data_type=self.q_data_type) @@ -496,6 +605,11 @@ def __init__(self, input_builder: "ModelInputForGPUBuilder"): self.sliding_window = input_builder.sliding_window self.block_size = input_builder.block_size + # Global hyperparameters shared by all attention layers + self.global_hyperparameters: Optional[PerLayerParameters] = None + + self.vllm_config = get_current_vllm_config() + def prepare(self): self.slot_mapping: List[int] = [] self.prefill_seq_lens: List[int] = [] @@ -528,6 +642,20 @@ def prepare(self): self.total_blocks = 0 self.is_profile_run: bool = False + if self.global_hyperparameters is None: + # Infer global hyperparameters, since currently we only support + # models in which all layers share the same values for the + # following hyperparameters: + # - `window_left` + # - `logits_soft_cap` + # - `sm_scale` + inferred_params = infer_global_hyperparameters( + get_per_layer_parameters(self.vllm_config)) + self.global_hyperparameters = inferred_params + self.window_left = inferred_params.window_left + self.logits_soft_cap = inferred_params.logits_soft_cap + self.sm_scale = inferred_params.sm_scale + def _add_seq_group( self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup", chunked_prefill_enabled: bool): @@ -756,7 +884,11 @@ def build(self, seq_lens: List[int], query_lens: List[int], data_type=kv_cache_dtype, q_data_type=self.runner.model_config.dtype, use_cuda_graph=use_captured_graph, - is_profile_run=self.is_profile_run) + is_profile_run=self.is_profile_run, + window_left=self.window_left, + logits_soft_cap=self.logits_soft_cap, + sm_scale=self.sm_scale, + ) class FlashInferImpl(AttentionImpl): @@ -885,25 +1017,34 @@ def forward( else: assert prefill_meta is not None assert prefill_meta.prefill_wrapper is not None - prefill_output = prefill_meta.prefill_wrapper.forward( + + assert prefill_meta.prefill_wrapper._causal + assert prefill_meta.prefill_wrapper._window_left == window_left + assert prefill_meta.prefill_wrapper._logits_soft_cap == ( + logits_soft_cap or 0.0) + assert prefill_meta.prefill_wrapper._sm_scale == softmax_scale + + prefill_output = prefill_meta.prefill_wrapper.run( query, kv_cache, - logits_soft_cap=logits_soft_cap, - causal=True, k_scale=layer._k_scale_float, v_scale=layer._v_scale_float, - window_left=window_left) + ) if decode_meta := attn_metadata.decode_metadata: assert decode_meta is not None assert decode_meta.decode_wrapper is not None - decode_output = decode_meta.decode_wrapper.forward( + + assert decode_meta.decode_wrapper._window_left == window_left + assert decode_meta.decode_wrapper._logits_soft_cap == ( + logits_soft_cap or 0.0) + assert decode_meta.decode_wrapper._sm_scale == softmax_scale + + decode_output = decode_meta.decode_wrapper.run( decode_query, kv_cache, - sm_scale=softmax_scale, - logits_soft_cap=logits_soft_cap, k_scale=layer._k_scale_float, v_scale=layer._v_scale_float, - window_left=window_left) + ) if prefill_output is None and decode_output is not None: # Decode only batch. diff --git a/vllm/config.py b/vllm/config.py index 7a58d64bcc6e2..dc1d611115489 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -310,14 +310,15 @@ def __init__( (self.hf_text_config.model_type in ["gemma2", "cohere2"])) if (not self.disable_sliding_window and has_interleaved_attention): - if envs.VLLM_ATTENTION_BACKEND == "XFORMERS": + if (backend := + envs.VLLM_ATTENTION_BACKEND) in ("XFORMERS", "FLASHINFER"): sliding_window_len_min = get_min_sliding_window( self.hf_text_config.sliding_window) logger.warning_once( f"{self.hf_text_config.model_type} has interleaved " "attention, which is currently not supported by the " - "XFORMERS backend. Disabling sliding window and capping " + f"{backend} backend. Disabling sliding window and capping " "the max length to the sliding window size " f"({sliding_window_len_min}).") self.disable_sliding_window = True @@ -3310,7 +3311,7 @@ def __str__(self): @contextmanager -def set_current_vllm_config(vllm_config: VllmConfig): +def set_current_vllm_config(vllm_config: VllmConfig, check_compile=False): """ Temporarily set the current VLLM config. Used during model initialization. @@ -3330,7 +3331,8 @@ def set_current_vllm_config(vllm_config: VllmConfig): vllm_config.compilation_config.enabled_custom_ops) logger.debug("disabled custom ops: %s", vllm_config.compilation_config.disabled_custom_ops) - if vllm_config.compilation_config.level == CompilationLevel.PIECEWISE \ + if check_compile and \ + vllm_config.compilation_config.level == CompilationLevel.PIECEWISE \ and compilation_counter.num_models_seen == num_models_seen: # If the model supports compilation, # compilation_counter.num_models_seen should be increased diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py index e9779878710ee..527b4307f3670 100644 --- a/vllm/model_executor/model_loader/loader.py +++ b/vllm/model_executor/model_loader/loader.py @@ -114,7 +114,7 @@ def _initialize_model( all_params = [param.name for param in signatures.parameters.values()] if "vllm_config" in all_params and "prefix" in all_params: # new-style model class - with set_current_vllm_config(vllm_config): + with set_current_vllm_config(vllm_config, check_compile=True): return model_class(vllm_config=vllm_config, prefix=prefix) msg = ("vLLM model class should accept `vllm_config` and `prefix` as " @@ -142,7 +142,7 @@ def _initialize_model( kwargs["lora_config"] = vllm_config.lora_config if "scheduler_config" in all_params: kwargs["scheduler_config"] = vllm_config.scheduler_config - with set_current_vllm_config(vllm_config): + with set_current_vllm_config(vllm_config, check_compile=True): return model_class(**kwargs) diff --git a/vllm/model_executor/model_loader/tensorizer.py b/vllm/model_executor/model_loader/tensorizer.py index 5b4757072353f..e359aef9dcb7f 100644 --- a/vllm/model_executor/model_loader/tensorizer.py +++ b/vllm/model_executor/model_loader/tensorizer.py @@ -288,7 +288,8 @@ def _init_model(self): model_args.torch_dtype = self.tensorizer_config.dtype assert self.tensorizer_config.model_class is not None # TODO: Do we need to consider old-style model class? - with no_init_or_tensor(), set_current_vllm_config(self.vllm_config): + with no_init_or_tensor(), set_current_vllm_config(self.vllm_config, + check_compile=True): return self.tensorizer_config.model_class( vllm_config=self.vllm_config, ) diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index c6e6693c54f57..6eeb4aa17051f 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -8,7 +8,8 @@ import torch import torch.nn as nn -from vllm.config import ObservabilityConfig, VllmConfig +from vllm.config import (ObservabilityConfig, VllmConfig, + set_current_vllm_config) from vllm.distributed import broadcast_tensor_dict, get_pp_group, get_tp_group from vllm.logger import init_logger from vllm.lora.request import LoRARequest @@ -498,8 +499,11 @@ def __init__( group. """ self.rpc_rank = rpc_rank - self.vllm_config = vllm_config self.worker: Optional[WorkerBase] = None + # do not store this `vllm_config`, `init_worker` will set the final + # one. TODO: investigate if we can remove this field in + # `WorkerWrapperBase`, `init_cached_hf_modules` should be + # unnecessary now. if vllm_config.model_config is not None: # it can be None in tests trust_remote_code = vllm_config.model_config.trust_remote_code @@ -533,6 +537,9 @@ def init_worker(self, all_kwargs: List[Dict[str, Any]]) -> None: Arguments are passed to the worker class constructor. """ kwargs = all_kwargs[self.rpc_rank] + self.vllm_config = kwargs.get("vllm_config", None) + assert self.vllm_config is not None, ( + "vllm_config is required to initialize the worker") enable_trace_function_call_for_thread(self.vllm_config) from vllm.plugins import load_general_plugins @@ -546,8 +553,10 @@ def init_worker(self, all_kwargs: List[Dict[str, Any]]) -> None: bytes) worker_class = cloudpickle.loads( self.vllm_config.parallel_config.worker_cls) - self.worker = worker_class(**kwargs) - assert self.worker is not None + with set_current_vllm_config(self.vllm_config): + # To make vLLM config available during worker initialization + self.worker = worker_class(**kwargs) + assert self.worker is not None def execute_method(self, method: Union[str, bytes], *args, **kwargs): try: