From 151a6093b8ad389fc416f7a2815dbe4702715dad Mon Sep 17 00:00:00 2001 From: Michael Goin Date: Wed, 6 Nov 2024 15:02:34 -0700 Subject: [PATCH 01/13] Add hf_transfer to testing image --- Dockerfile | 2 ++ 1 file changed, 2 insertions(+) diff --git a/Dockerfile b/Dockerfile index 343364da2ebf5..922183e21444d 100644 --- a/Dockerfile +++ b/Dockerfile @@ -48,6 +48,8 @@ COPY requirements-common.txt requirements-common.txt COPY requirements-cuda.txt requirements-cuda.txt RUN --mount=type=cache,target=/root/.cache/pip \ python3 -m pip install -r requirements-cuda.txt +RUN --mount=type=cache,target=/root/.cache/pip \ + python3 -m pip install hf_transfer # cuda arch list used by torch From db7db4aab9fd23e818d89ca9037099d30c071a5a Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Thu, 7 Nov 2024 14:00:21 +0800 Subject: [PATCH 02/13] [Misc] Consolidate ModelConfig code related to HF config (#10104) Signed-off-by: DarkLight1337 --- docs/source/serving/compatibility_matrix.rst | 2 +- tests/test_config.py | 38 ++++++++++++++++++++ vllm/config.py | 14 ++++---- vllm/inputs/preprocess.py | 2 +- vllm/transformers_utils/config.py | 9 +++++ vllm/utils.py | 4 --- vllm/worker/cpu_model_runner.py | 9 +---- vllm/worker/cpu_worker.py | 5 +-- vllm/worker/model_runner.py | 23 +++++------- vllm/worker/worker.py | 5 +-- 10 files changed, 68 insertions(+), 43 deletions(-) diff --git a/docs/source/serving/compatibility_matrix.rst b/docs/source/serving/compatibility_matrix.rst index cab19e4ec5b6c..f629b3ca78318 100644 --- a/docs/source/serving/compatibility_matrix.rst +++ b/docs/source/serving/compatibility_matrix.rst @@ -359,7 +359,7 @@ Feature x Hardware - ✅ - ✅ - ✅ - - `✗ `__ + - ✅ - ✗ * - :abbr:`logP (Logprobs)` - ✅ diff --git a/tests/test_config.py b/tests/test_config.py index 69918b67607d9..5211049bf0011 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -165,3 +165,41 @@ def test_rope_customization(): assert getattr(longchat_model_config.hf_config, "rope_scaling", None) == TEST_ROPE_SCALING assert longchat_model_config.max_model_len == 4096 + + +@pytest.mark.parametrize(("model_id", "is_encoder_decoder"), [ + ("facebook/opt-125m", False), + ("facebook/bart-base", True), + ("meta-llama/Llama-3.2-1B", False), + ("meta-llama/Llama-3.2-11B-Vision", True), +]) +def test_is_encoder_decoder(model_id, is_encoder_decoder): + config = ModelConfig( + model_id, + task="auto", + tokenizer=model_id, + tokenizer_mode="auto", + trust_remote_code=False, + dtype="float16", + seed=0, + ) + + assert config.is_encoder_decoder == is_encoder_decoder + + +@pytest.mark.parametrize(("model_id", "uses_mrope"), [ + ("facebook/opt-125m", False), + ("Qwen/Qwen2-VL-2B-Instruct", True), +]) +def test_uses_mrope(model_id, uses_mrope): + config = ModelConfig( + model_id, + task="auto", + tokenizer=model_id, + tokenizer_mode="auto", + trust_remote_code=False, + dtype="float16", + seed=0, + ) + + assert config.uses_mrope == uses_mrope diff --git a/vllm/config.py b/vllm/config.py index 91bbbfec4b7b3..c7fad3a261858 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -15,7 +15,8 @@ from vllm.tracing import is_otel_available, otel_import_error_traceback from vllm.transformers_utils.config import (ConfigFormat, get_config, get_hf_image_processor_config, - get_hf_text_config) + get_hf_text_config, + is_encoder_decoder, uses_mrope) from vllm.utils import (GiB_bytes, cuda_device_count_stateless, get_cpu_memory, print_warning_once) @@ -667,12 +668,13 @@ def get_multimodal_config(self) -> "MultiModalConfig": return self.multimodal_config @property 
- def is_encoder_decoder_model(self) -> bool: + def is_encoder_decoder(self) -> bool: """Extract the HF encoder/decoder model flag.""" - return getattr( - self.hf_config, "is_encoder_decoder", - False) or (hasattr(self.hf_config, "text_config") and getattr( - self.hf_config.text_config, "is_encoder_decoder", False)) + return is_encoder_decoder(self.hf_config) + + @property + def uses_mrope(self) -> bool: + return uses_mrope(self.hf_config) @property def is_multimodal_model(self) -> bool: diff --git a/vllm/inputs/preprocess.py b/vllm/inputs/preprocess.py index a5c787a56b5a9..509b0448b9e51 100644 --- a/vllm/inputs/preprocess.py +++ b/vllm/inputs/preprocess.py @@ -580,4 +580,4 @@ async def preprocess_async( ) def is_encoder_decoder_model(self): - return self.model_config.is_encoder_decoder_model + return self.model_config.is_encoder_decoder diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index 1a5870aa4f84c..415d8bf7cc2bb 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -129,6 +129,15 @@ def uses_mrope(config: PretrainedConfig) -> bool: return "mrope_section" in rope_scaling +def is_encoder_decoder(config: PretrainedConfig) -> bool: + """Detect if the model with this config is used as an encoder/decoder.""" + text_config = getattr(config, "text_config", None) + if text_config is not None: + return is_encoder_decoder(text_config) + + return getattr(config, "is_encoder_decoder", False) + + def get_config( model: Union[str, Path], trust_remote_code: bool, diff --git a/vllm/utils.py b/vllm/utils.py index d78130873d3dc..13d7f6d475346 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -88,9 +88,6 @@ "currently supported with encoder/" "decoder models.") -STR_NOT_IMPL_ENC_DEC_CPU = ("CPU is not currently supported with " - "encoder/decoder models.") - # Efficiently import all enc/dec error strings # rather than having to import all of the above STR_NOT_IMPL_ENC_DEC_ERR_STRS = { @@ -105,7 +102,6 @@ "STR_NOT_IMPL_ENC_DEC_SPEC_DEC": STR_NOT_IMPL_ENC_DEC_SPEC_DEC, "STR_NOT_IMPL_ENC_DEC_BACKEND": STR_NOT_IMPL_ENC_DEC_BACKEND, "STR_NOT_IMPL_ENC_DEC_PROMPT_ADAPTER": STR_NOT_IMPL_ENC_DEC_PROMPT_ADAPTER, - "STR_NOT_IMPL_ENC_DEC_CPU": STR_NOT_IMPL_ENC_DEC_CPU } # Constants related to forcing the attention backend selection diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py index fdd72a452f2ad..26a15ed645c43 100644 --- a/vllm/worker/cpu_model_runner.py +++ b/vllm/worker/cpu_model_runner.py @@ -18,7 +18,6 @@ MultiModalInputs, MultiModalPlaceholderMap) from vllm.sequence import (IntermediateTensors, SequenceData, SequenceGroupMetadata) -from vllm.transformers_utils.config import uses_mrope from vllm.utils import make_tensor_with_pad from vllm.worker.model_runner_base import ( ModelRunnerBase, ModelRunnerInputBase, ModelRunnerInputBuilderBase, @@ -163,7 +162,7 @@ def _compute_multi_modal_input(self, seq_group: SequenceGroupMetadata, # special processing for mrope position deltas. mrope_positions = None - if self.runner.model_is_mrope: + if self.runner.model_config.uses_mrope: image_grid_thw = mm_kwargs.get("image_grid_thw", None) video_grid_thw = mm_kwargs.get("video_grid_thw", None) assert image_grid_thw is not None or video_grid_thw is not None, ( @@ -446,12 +445,6 @@ def __init__( # Lazy initialization. self.model: nn.Module # Set after init_Model - @property - def model_is_mrope(self) -> bool: - """Detect if the model has "mrope" rope_scaling type. 
- mrope requires keep "rope_deltas" between prompt and decoding phases.""" - return uses_mrope(self.model_config.hf_config) - def load_model(self) -> None: self.model = get_model(vllm_config=self.vllm_config) diff --git a/vllm/worker/cpu_worker.py b/vllm/worker/cpu_worker.py index 3778707ae07e8..2914f520d823c 100644 --- a/vllm/worker/cpu_worker.py +++ b/vllm/worker/cpu_worker.py @@ -151,7 +151,7 @@ def __init__( self.local_omp_cpuid = omp_cpuids.split("|")[rank] ModelRunnerClass: Type[CPUModelRunner] = CPUModelRunner - if self._is_encoder_decoder_model(): + if self.model_config.is_encoder_decoder: ModelRunnerClass = CPUEncoderDecoderModelRunner self.model_runner: CPUModelRunner = ModelRunnerClass( vllm_config=vllm_config, @@ -188,9 +188,6 @@ def stop_profile(self): raise RuntimeError("Profiler is not enabled.") self.profiler.stop() - def _is_encoder_decoder_model(self): - return self.model_config.is_encoder_decoder_model - def init_device(self) -> None: if self.local_omp_cpuid != "all": ret = torch.ops._C_utils.init_cpu_threads_env(self.local_omp_cpuid) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 1e8ea4e8e79cf..a1ec2e85be7b8 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -47,7 +47,6 @@ LRUCacheWorkerPromptAdapterManager) from vllm.sampling_params import SamplingParams from vllm.sequence import IntermediateTensors, SequenceGroupMetadata -from vllm.transformers_utils.config import uses_mrope from vllm.utils import (DeviceMemoryProfiler, GiB_bytes, PyObjectCache, async_tensor_h2d, flatten_2d_lists, is_pin_memory_available, supports_dynamo, @@ -493,7 +492,7 @@ def _compute_lens(self, inter_data: InterDataForSeqGroup, seq_idx: int, context_len = seq_data.get_num_computed_tokens() seq_len = min(seq_len, context_len + token_chunk_size) elif self.runner.scheduler_config.is_multi_step or \ - self.runner.model_config.is_encoder_decoder_model: + self.runner.model_config.is_encoder_decoder: context_len = seq_len - 1 else: context_len = seq_data.get_num_computed_tokens() @@ -666,7 +665,7 @@ def _compute_multi_modal_input(self, inter_data: InterDataForSeqGroup, inter_data.multi_modal_placeholder_maps = placeholder_maps # special processing for mrope position deltas. - if self.runner.model_is_mrope: + if self.runner.model_config.uses_mrope: image_grid_thw = mm_kwargs.get("image_grid_thw", None) video_grid_thw = mm_kwargs.get("video_grid_thw", None) assert image_grid_thw is not None or video_grid_thw is not None, ( @@ -711,7 +710,7 @@ def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata): encoder_seq_len = 0 - if self.runner.model_config.is_encoder_decoder_model: + if self.runner.model_config.is_encoder_decoder: encoder_seq_len = seq_group_metadata.encoder_seq_data.get_len() inter_data = self.init_cached_inter_data( @@ -837,7 +836,7 @@ def build(self) -> ModelInputForGPU: if not inter_data.is_prompt: max_decode_seq_len = max(max_decode_seq_len, max(inter_data.seq_lens)) - if self.runner.model_config.is_encoder_decoder_model: + if self.runner.model_config.is_encoder_decoder: max_encoder_seq_len = max(max_encoder_seq_len, inter_data.encoder_seq_len) @@ -1375,12 +1374,6 @@ def list_prompt_adapters(self) -> Set[int]: raise RuntimeError("PromptAdapter is not enabled.") return self.prompt_adapter_manager.list_adapters() - @property - def model_is_mrope(self) -> bool: - """Detect if the model has "mrope" rope_scaling type. 
- mrope requires keep "rope_deltas" between prompt and decoding phases.""" - return uses_mrope(self.model_config.hf_config) - @torch.inference_mode() def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None: """Cuda graph capture a model. @@ -1411,7 +1404,7 @@ def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None: max_batch_size = self.max_batchsize_to_capture input_tokens = torch.zeros(max_batch_size, dtype=torch.long).cuda() input_positions = torch.zeros(max_batch_size, dtype=torch.long).cuda() - if self.model_is_mrope: + if self.model_config.uses_mrope: input_positions = torch.tile(input_positions, (3, 1)) # Prepare dummy previous_hidden_states only if needed by the model. # This is used by draft models such as EAGLE. @@ -1447,7 +1440,7 @@ def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None: self.attn_state.graph_capture_get_metadata_for_batch( batch_size, is_encoder_decoder_model=self.model_config. - is_encoder_decoder_model)) + is_encoder_decoder)) if self.lora_config: lora_mapping = LoRAMapping( @@ -1466,7 +1459,7 @@ def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None: graph_runner = CUDAGraphRunner( self.model, self.attn_backend.get_name(), self.attn_state.graph_clone(batch_size), - self.model_config.is_encoder_decoder_model) + self.model_config.is_encoder_decoder) capture_inputs = { "input_ids": @@ -1497,7 +1490,7 @@ def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None: self.model.get_seqlen_agnostic_capture_inputs( batch_size) }) - if self.model_config.is_encoder_decoder_model: + if self.model_config.is_encoder_decoder: # add the additional inputs to capture for # encoder-decoder models. self._update_inputs_to_capture_for_enc_dec_model( diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 8928936b4f9fc..d8c8011a585d8 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -77,7 +77,7 @@ def __init__( ModelRunnerClass = model_runner_cls elif model_config.task == "embedding": ModelRunnerClass = EmbeddingModelRunner - elif self._is_encoder_decoder_model(): + elif self.model_config.is_encoder_decoder: ModelRunnerClass = EncoderDecoderModelRunner self.model_runner: GPUModelRunnerBase = ModelRunnerClass( vllm_config=self.vllm_config, @@ -119,9 +119,6 @@ def stop_profile(self): raise RuntimeError("Profiler is not enabled.") self.profiler.stop() - def _is_encoder_decoder_model(self): - return self.model_config.is_encoder_decoder_model - def init_device(self) -> None: if self.device_config.device.type == "cuda": # torch.distributed.all_reduce does not free the input tensor until From 104d729656fe746d1b91a0528e51e5efc8d14b4a Mon Sep 17 00:00:00 2001 From: Russell Bryant Date: Thu, 7 Nov 2024 01:54:46 -0500 Subject: [PATCH 03/13] [CI/Build] re-add codespell to CI (#10083) Signed-off-by: Russell Bryant --- .github/workflows/codespell.yml | 45 +++++++++++++++++++++++++++++++++ 1 file changed, 45 insertions(+) create mode 100644 .github/workflows/codespell.yml diff --git a/.github/workflows/codespell.yml b/.github/workflows/codespell.yml new file mode 100644 index 0000000000000..dfb087ff66913 --- /dev/null +++ b/.github/workflows/codespell.yml @@ -0,0 +1,45 @@ +name: codespell + +on: + # Trigger the workflow on push or pull request, + # but only for the main branch + push: + branches: + - main + paths: + - "**/*.py" + - "**/*.md" + - "**/*.rst" + - pyproject.toml + - requirements-lint.txt + - .github/workflows/codespell.yml + pull_request: + branches: + - main + paths: + - "**/*.py" + - "**/*.md" 
+ - "**/*.rst" + - pyproject.toml + - requirements-lint.txt + - .github/workflows/codespell.yml + +jobs: + codespell: + runs-on: ubuntu-latest + strategy: + matrix: + python-version: ["3.12"] + steps: + - uses: actions/checkout@eef61447b9ff4aafe5dcd4e0bbf5d482be7e7871 # v4.2.1 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@f677139bbe7f9c59b41e40162b753c062f5d49a3 # v5.2.0 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -r requirements-lint.txt + - name: Spelling check with codespell + run: | + codespell --toml pyproject.toml From d7263a1bb837648bec67d99ed35db56c58832d3f Mon Sep 17 00:00:00 2001 From: Rafael Vasquez Date: Thu, 7 Nov 2024 02:50:35 -0500 Subject: [PATCH 04/13] Doc: Improve benchmark documentation (#9927) Signed-off-by: Rafael Vasquez --- docs/source/dev/profiling/profiling_index.rst | 5 +-- docs/source/index.rst | 4 +-- docs/source/performance/benchmarks.rst | 33 +++++++++++++++++++ .../performance_benchmark/benchmarks.rst | 23 ------------- 4 files changed, 38 insertions(+), 27 deletions(-) create mode 100644 docs/source/performance/benchmarks.rst delete mode 100644 docs/source/performance_benchmark/benchmarks.rst diff --git a/docs/source/dev/profiling/profiling_index.rst b/docs/source/dev/profiling/profiling_index.rst index 9e8b2f1817567..a422b1fcda521 100644 --- a/docs/source/dev/profiling/profiling_index.rst +++ b/docs/source/dev/profiling/profiling_index.rst @@ -1,5 +1,6 @@ -Profiling vLLM -================================= +============== +Profiling vLLM +============== We support tracing vLLM workers using the ``torch.profiler`` module. You can enable tracing by setting the ``VLLM_TORCH_PROFILER_DIR`` environment variable to the directory where you want to save the traces: ``VLLM_TORCH_PROFILER_DIR=/mnt/traces/`` diff --git a/docs/source/index.rst b/docs/source/index.rst index 51add1fd4d0ab..38dad25e18c02 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -126,9 +126,9 @@ Documentation .. toctree:: :maxdepth: 1 - :caption: Performance benchmarks + :caption: Performance - performance_benchmark/benchmarks + performance/benchmarks .. toctree:: :maxdepth: 2 diff --git a/docs/source/performance/benchmarks.rst b/docs/source/performance/benchmarks.rst new file mode 100644 index 0000000000000..6d4d7b544cb5d --- /dev/null +++ b/docs/source/performance/benchmarks.rst @@ -0,0 +1,33 @@ +.. _benchmarks: + +================ +Benchmark Suites +================ + +vLLM contains two sets of benchmarks: + ++ :ref:`Performance benchmarks ` ++ :ref:`Nightly benchmarks ` + + +.. _performance_benchmarks: + +Performance Benchmarks +---------------------- + +The performance benchmarks are used for development to confirm whether new changes improve performance under various workloads. They are triggered on every commit with both the ``perf-benchmarks`` and ``ready`` labels, and when a PR is merged into vLLM. + +The latest performance results are hosted on the public `vLLM Performance Dashboard `_. + +More information on the performance benchmarks and their parameters can be found `here `__. + +.. _nightly_benchmarks: + +Nightly Benchmarks +------------------ + +These compare vLLM's performance against alternatives (``tgi``, ``trt-llm``, and ``lmdeploy``) when there are major updates of vLLM (e.g., bumping up to a new version). 
They are primarily intended for consumers to evaluate when to choose vLLM over other options and are triggered on every commit with both the ``perf-benchmarks`` and ``nightly-benchmarks`` labels. + +The latest nightly benchmark results are shared in major release blog posts such as `vLLM v0.6.0 `_. + +More information on the nightly benchmarks and their parameters can be found `here `__. \ No newline at end of file diff --git a/docs/source/performance_benchmark/benchmarks.rst b/docs/source/performance_benchmark/benchmarks.rst deleted file mode 100644 index e5c8d6a55de63..0000000000000 --- a/docs/source/performance_benchmark/benchmarks.rst +++ /dev/null @@ -1,23 +0,0 @@ -.. _benchmarks: - -Benchmark suites of vLLM -======================== - - - -vLLM contains two sets of benchmarks: - -+ **Performance benchmarks**: benchmark vLLM's performance under various workloads at a high frequency (when a pull request (PR for short) of vLLM is being merged). See `vLLM performance dashboard `_ for the latest performance results. - -+ **Nightly benchmarks**: compare vLLM's performance against alternatives (tgi, trt-llm, and lmdeploy) when there are major updates of vLLM (e.g., bumping up to a new version). The latest results are available in the `vLLM GitHub README `_. - - -Trigger a benchmark -------------------- - -The performance benchmarks and nightly benchmarks can be triggered by submitting a PR to vLLM, and label the PR with `perf-benchmarks` and `nightly-benchmarks`. - - -.. note:: - - Please refer to `vLLM performance benchmark descriptions `_ and `vLLM nightly benchmark descriptions `_ for detailed descriptions on benchmark environment, workload and metrics. From 6192e9b8fef8492c3e52bd65c7d954a1ef9b40c8 Mon Sep 17 00:00:00 2001 From: Hanzhi Zhou Date: Wed, 6 Nov 2024 23:50:47 -0800 Subject: [PATCH 05/13] [Core][Distributed] Refactor ipc buffer init in CustomAllreduce (#10030) Signed-off-by: Hanzhi Zhou --- csrc/custom_all_reduce.cu | 119 +++++++-------- csrc/custom_all_reduce.cuh | 87 +++++------ csrc/custom_all_reduce_test.cu | 24 +-- csrc/ops.h | 22 ++- csrc/torch_bindings.cpp | 21 +-- tests/distributed/test_custom_all_reduce.py | 4 +- tools/profiler/visualize_layerwise_profile.py | 32 ++-- vllm/_custom_ops.py | 29 ++-- .../device_communicators/custom_all_reduce.py | 140 +++++++----------- 9 files changed, 218 insertions(+), 260 deletions(-) diff --git a/csrc/custom_all_reduce.cu b/csrc/custom_all_reduce.cu index 9b82bec44c3c6..123278bfed71d 100644 --- a/csrc/custom_all_reduce.cu +++ b/csrc/custom_all_reduce.cu @@ -5,32 +5,29 @@ #include "custom_all_reduce.cuh" -// fake pointer type, must match fptr_t type in ops.h +// Fake pointer type, must match fptr_t type in ops.h. +// We use this type alias to indicate when pointers are passed in as int64_t. 
using fptr_t = int64_t; static_assert(sizeof(void*) == sizeof(fptr_t)); -fptr_t init_custom_ar(torch::Tensor& meta, torch::Tensor& rank_data, - const std::vector& handles, - const std::vector& offsets, int64_t rank, +fptr_t init_custom_ar(const std::vector& fake_ipc_ptrs, + torch::Tensor& rank_data, int64_t rank, bool full_nvlink) { - int world_size = offsets.size(); + int world_size = fake_ipc_ptrs.size(); if (world_size > 8) throw std::invalid_argument("world size > 8 is not supported"); if (world_size % 2 != 0) throw std::invalid_argument("Odd num gpus is not supported for now"); - if (world_size != handles.size()) - throw std::invalid_argument( - "handles length should equal to offsets length"); if (rank < 0 || rank >= world_size) throw std::invalid_argument("invalid rank passed in"); - cudaIpcMemHandle_t ipc_handles[8]; + vllm::Signal* ipc_ptrs[8]; for (int i = 0; i < world_size; i++) { - std::memcpy(&ipc_handles[i], handles[i].data(), sizeof(cudaIpcMemHandle_t)); + ipc_ptrs[i] = reinterpret_cast(fake_ipc_ptrs[i]); } - return (fptr_t) new vllm::CustomAllreduce( - reinterpret_cast(meta.data_ptr()), rank_data.data_ptr(), - rank_data.numel(), ipc_handles, offsets, rank, full_nvlink); + return (fptr_t) new vllm::CustomAllreduce(ipc_ptrs, rank_data.data_ptr(), + rank_data.numel(), rank, world_size, + full_nvlink); } /** @@ -55,26 +52,48 @@ bool _is_weak_contiguous(torch::Tensor& t) { t.numel() * t.element_size()); } -void _all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out, - cudaStream_t stream) { +/** + * Performs an out-of-place allreduce and stores result in out. + * + * If _reg_buffer is null, assumes inp.data_ptr() is already IPC-registered. + * Otherwise, _reg_buffer is assumed to be IPC-registered and inp is first + * copied into _reg_buffer. 
+ */ +void all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out, + fptr_t _reg_buffer, int64_t reg_buffer_sz_bytes) { auto fa = reinterpret_cast(_fa); + const at::cuda::OptionalCUDAGuard device_guard(device_of(inp)); + auto stream = c10::cuda::getCurrentCUDAStream().stream(); + + TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type()); + TORCH_CHECK_EQ(inp.numel(), out.numel()); TORCH_CHECK(_is_weak_contiguous(out)); + TORCH_CHECK(_is_weak_contiguous(inp)); + auto input_size = inp.numel() * inp.element_size(); + auto reg_buffer = reinterpret_cast(_reg_buffer); + if (reg_buffer) { + TORCH_CHECK_LE(input_size, reg_buffer_sz_bytes); + AT_CUDA_CHECK(cudaMemcpyAsync(reg_buffer, inp.data_ptr(), input_size, + cudaMemcpyDeviceToDevice, stream)); + } else { + reg_buffer = inp.data_ptr(); + } switch (out.scalar_type()) { case at::ScalarType::Float: { - fa->allreduce(stream, reinterpret_cast(inp.data_ptr()), + fa->allreduce(stream, reinterpret_cast(reg_buffer), reinterpret_cast(out.data_ptr()), out.numel()); break; } case at::ScalarType::Half: { - fa->allreduce(stream, reinterpret_cast(inp.data_ptr()), + fa->allreduce(stream, reinterpret_cast(reg_buffer), reinterpret_cast(out.data_ptr()), out.numel()); break; } #if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__)) case at::ScalarType::BFloat16: { fa->allreduce( - stream, reinterpret_cast(inp.data_ptr()), + stream, reinterpret_cast(reg_buffer), reinterpret_cast(out.data_ptr()), out.numel()); break; } @@ -85,57 +104,41 @@ void _all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out, } } -void all_reduce_reg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out) { - const at::cuda::OptionalCUDAGuard device_guard(device_of(inp)); - auto stream = c10::cuda::getCurrentCUDAStream().stream(); - TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type()); - TORCH_CHECK_EQ(inp.numel(), out.numel()); - _all_reduce(_fa, inp, out, stream); -} - -void all_reduce_unreg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& reg_buffer, - torch::Tensor& out) { - const at::cuda::OptionalCUDAGuard device_guard(device_of(inp)); - auto stream = c10::cuda::getCurrentCUDAStream().stream(); - - auto input_size = inp.numel() * inp.element_size(); - TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type()); - TORCH_CHECK_EQ(inp.numel(), out.numel()); - TORCH_CHECK(input_size <= reg_buffer.numel() * reg_buffer.element_size(), - "registered buffer is too small to contain the input"); - AT_CUDA_CHECK(cudaMemcpyAsync(reg_buffer.data_ptr(), inp.data_ptr(), - input_size, cudaMemcpyDeviceToDevice, stream)); - _all_reduce(_fa, reg_buffer, out, stream); -} - void dispose(fptr_t _fa) { - auto fa = reinterpret_cast(_fa); - delete fa; + delete reinterpret_cast(_fa); } int64_t meta_size() { return sizeof(vllm::Signal); } -void register_buffer(fptr_t _fa, torch::Tensor& t, - const std::vector& handles, - const std::vector& offsets) { +void register_buffer(fptr_t _fa, const std::vector& fake_ipc_ptrs) { auto fa = reinterpret_cast(_fa); - fa->register_buffer(handles, offsets, t.data_ptr()); + TORCH_CHECK(fake_ipc_ptrs.size() == fa->world_size_); + void* ipc_ptrs[8]; + for (int i = 0; i < fake_ipc_ptrs.size(); i++) { + ipc_ptrs[i] = reinterpret_cast(fake_ipc_ptrs[i]); + } + fa->register_buffer(ipc_ptrs); } -std::tuple> get_graph_buffer_ipc_meta( - fptr_t _fa) { +// Use vector to represent byte data for python binding compatibility. 
+std::tuple, std::vector> +get_graph_buffer_ipc_meta(fptr_t _fa) { auto fa = reinterpret_cast(_fa); - auto [handle_bytes, offsets] = fa->get_graph_buffer_ipc_meta(); - auto options = - torch::TensorOptions().dtype(torch::kUInt8).device(torch::kCPU); - auto handles = - torch::empty({static_cast(handle_bytes.size())}, options); - std::memcpy(handles.data_ptr(), handle_bytes.data(), handle_bytes.size()); - return {handles, std::move(offsets)}; + auto [handle, offsets] = fa->get_graph_buffer_ipc_meta(); + std::vector bytes(handle.begin(), handle.end()); + return std::make_tuple(bytes, offsets); } -void register_graph_buffers(fptr_t _fa, const std::vector& handles, +// Use vector to represent byte data for python binding compatibility. +void register_graph_buffers(fptr_t _fa, + const std::vector>& handles, const std::vector>& offsets) { auto fa = reinterpret_cast(_fa); - fa->register_graph_buffers(handles, offsets); + std::vector bytes; + bytes.reserve(handles.size()); + for (int i = 0; i < handles.size(); i++) { + bytes.emplace_back(handles[i].begin(), handles[i].end()); + } + bytes.reserve(handles.size()); + fa->register_graph_buffers(bytes, offsets); } diff --git a/csrc/custom_all_reduce.cuh b/csrc/custom_all_reduce.cuh index a2f7e43300002..6be4d4f2b2eb8 100644 --- a/csrc/custom_all_reduce.cuh +++ b/csrc/custom_all_reduce.cuh @@ -285,46 +285,52 @@ class CustomAllreduce { int world_size_; bool full_nvlink_; - // below are device pointers RankSignals sg_; + // Stores an map from a pointer to its peer pointters from all ranks. std::unordered_map buffers_; Signal* self_sg_; - // stores the registered device pointers from all ranks + // Stores rank data from all ranks. This is mainly for cuda graph purposes. + // For cuda graph to work, all kernel arguments must be fixed during graph + // capture time. However, the peer pointers are not known during graph capture + // time. Therefore, during capture, we increment the rank data pointer and use + // that as the argument to the kernel. The kernel arguments are stored in + // graph_unreg_buffers_. The actual peer pointers will be filled in at the + // memory pointed to by the pointers in graph_unreg_buffers_ when + // the IPC handles are exchanged between ranks. + // + // The overall process looks like this: + // 1. Graph capture. + // 2. Each rank obtains the IPC handles for each addresses used during cuda + // graph capture using get_graph_buffer_ipc_meta. + // 3. (In Python) all gather the IPC handles. + // 4. Obtain the peer pointers by opening the IPC handles, and store them in + // the rank data array at corresponding positions. RankData *d_rank_data_base_, *d_rank_data_end_; std::vector graph_unreg_buffers_; // a map from IPC handles to opened IPC pointers std::map ipc_handles_; /** - * meta is a pointer to device metadata and temporary buffer for allreduce. + * Signals are an array of ipc-enabled buffers from all ranks. + * For each of the buffer, the layout is as follows: + * | -- sizeof(Signal) -- | ------ a few MB ----- | + * The first section is for allreduce synchronization, and the second section + * is for storing the intermediate results required by some allreduce algos. * - * There's a total of sizeof(Signal) of prefix before the actual data, - * so meta + 1 points to actual temporary buffer. - * - * note: this class does not own any device memory. Any required buffers - * are passed in from the constructor + * Note: this class does not own any device memory. Any required buffers + * are passed in from the constructor. 
*/ - CustomAllreduce(Signal* meta, void* rank_data, size_t rank_data_sz, - const cudaIpcMemHandle_t* handles, - const std::vector& offsets, int rank, - bool full_nvlink = true) + CustomAllreduce(Signal** signals, void* rank_data, size_t rank_data_sz, + int rank, int world_size, bool full_nvlink = true) : rank_(rank), - world_size_(offsets.size()), + world_size_(world_size), full_nvlink_(full_nvlink), - self_sg_(meta), + self_sg_(signals[rank]), d_rank_data_base_(reinterpret_cast(rank_data)), d_rank_data_end_(d_rank_data_base_ + rank_data_sz / sizeof(RankData)) { for (int i = 0; i < world_size_; i++) { - Signal* rank_sg; - if (i != rank_) { - char* handle = open_ipc_handle(&handles[i]); - handle += offsets[i]; - rank_sg = (Signal*)handle; - } else { - rank_sg = self_sg_; - } - sg_.signals[i] = rank_sg; + sg_.signals[i] = signals[i]; } } @@ -341,11 +347,10 @@ class CustomAllreduce { return it->second; } - std::pair, std::vector> - get_graph_buffer_ipc_meta() { + std::pair> get_graph_buffer_ipc_meta() { auto num_buffers = graph_unreg_buffers_.size(); auto handle_sz = sizeof(cudaIpcMemHandle_t); - std::vector handles(handle_sz * num_buffers, 0); + std::string handles(handle_sz * num_buffers, static_cast(0)); std::vector offsets(num_buffers); for (int i = 0; i < num_buffers; i++) { auto ptr = graph_unreg_buffers_[i]; @@ -370,26 +375,22 @@ class CustomAllreduce { std::to_string(d_rank_data_base_ + num - d_rank_data_end_)); } - void register_buffer(const std::vector& handles, - const std::vector& offsets, void* self) { + /** + * Register already-shared IPC pointers. + */ + void register_buffer(void** ptrs) { check_rank_data_capacity(); RankData data; for (int i = 0; i < world_size_; i++) { - if (i != rank_) { - char* handle = open_ipc_handle(handles[i].data()); - handle += offsets[i]; - data.ptrs[i] = handle; - } else { - data.ptrs[i] = self; - } + data.ptrs[i] = ptrs[i]; } auto d_data = d_rank_data_base_++; CUDACHECK( cudaMemcpy(d_data, &data, sizeof(RankData), cudaMemcpyHostToDevice)); - buffers_[self] = d_data; + buffers_[ptrs[rank_]] = d_data; } - // note: when registering graph buffers, we intentionally choose to not + // Note: when registering graph buffers, we intentionally choose to not // deduplicate the addresses. That means if the allocator reuses some // addresses, they will be registered again. This is to account for the remote // possibility of different allocation patterns between ranks. For example, @@ -424,11 +425,13 @@ class CustomAllreduce { } /** - * This is the result after careful grid search. Using 36 blocks give the best - * or close to the best runtime on the devices I tried: A100, A10, A30, T4, - * V100. You'll notice that NCCL kernels also only take a small amount of SMs. - * Not quite sure the underlying reason, but my guess is that too many SMs - * will cause contention on NVLink bus. + * Performs allreduce, assuming input has already been registered. + * + * Block and grid default configs are results after careful grid search. Using + * 36 blocks give the best or close to the best runtime on the devices I + * tried: A100, A10, A30, T4, V100. You'll notice that NCCL kernels also only + * take a small amount of SMs. Not quite sure the underlying reason, but my + * guess is that too many SMs will cause contention on NVLink bus. 
*/ template void allreduce(cudaStream_t stream, T* input, T* output, int size, diff --git a/csrc/custom_all_reduce_test.cu b/csrc/custom_all_reduce_test.cu index 376687e91cfda..b59ea40d980f4 100644 --- a/csrc/custom_all_reduce_test.cu +++ b/csrc/custom_all_reduce_test.cu @@ -135,24 +135,26 @@ void run(int myRank, int nRanks, ncclComm_t& comm, int threads, int block_limit, void* rank_data; size_t rank_data_sz = 16 * 1024 * 1024; CUDACHECK(cudaMalloc(&rank_data, rank_data_sz)); - std::vector offsets(nRanks, 0); - vllm::CustomAllreduce fa(buffer, rank_data, rank_data_sz, data_handles, - offsets, myRank); + vllm::Signal* ipc_ptrs[8]; + for (int i = 0; i < nRanks; i++) { + if (i == myRank) + ipc_ptrs[i] = buffer; + else + CUDACHECK(cudaIpcOpenMemHandle((void**)&ipc_ptrs[i], data_handles[i], + cudaIpcMemLazyEnablePeerAccess)); + } + vllm::CustomAllreduce fa(ipc_ptrs, rank_data, rank_data_sz, myRank, nRanks); auto* self_data = reinterpret_cast(reinterpret_cast(buffer) + sizeof(vllm::Signal) + data_size * sizeof(T)); // hack buffer registration { - std::vector handles; - handles.reserve(nRanks); + void* data[8]; for (int i = 0; i < nRanks; i++) { - char* begin = (char*)&data_handles[i]; - char* end = (char*)&data_handles[i + 1]; - handles.emplace_back(begin, end); + data[i] = + ((char*)ipc_ptrs[i]) + sizeof(vllm::Signal) + data_size * sizeof(T); } - std::vector offsets(nRanks, - sizeof(vllm::Signal) + data_size * sizeof(T)); - fa.register_buffer(handles, offsets, self_data); + fa.register_buffer(data); } double* ground_truth; diff --git a/csrc/ops.h b/csrc/ops.h index c50eb39a3dacc..e0775ee1891df 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -199,20 +199,16 @@ void causal_conv1d_fwd(const at::Tensor& x, const at::Tensor& weight, #ifndef USE_ROCM using fptr_t = int64_t; -fptr_t init_custom_ar(torch::Tensor& meta, torch::Tensor& rank_data, - const std::vector& handles, - const std::vector& offsets, int64_t rank, - bool full_nvlink); -void all_reduce_reg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out); -void all_reduce_unreg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& reg_buffer, - torch::Tensor& out); +fptr_t init_custom_ar(const std::vector& fake_ipc_ptrs, + torch::Tensor& rank_data, int64_t rank, bool full_nvlink); +void all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out, + fptr_t reg_buffer, int64_t reg_buffer_sz_bytes); void dispose(fptr_t _fa); int64_t meta_size(); -void register_buffer(fptr_t _fa, torch::Tensor& t, - const std::vector& handles, - const std::vector& offsets); -std::tuple> get_graph_buffer_ipc_meta( - fptr_t _fa); -void register_graph_buffers(fptr_t _fa, const std::vector& handles, +void register_buffer(fptr_t _fa, const std::vector& fake_ipc_ptrs); +std::tuple, std::vector> +get_graph_buffer_ipc_meta(fptr_t _fa); +void register_graph_buffers(fptr_t _fa, + const std::vector>& handles, const std::vector>& offsets); #endif diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index b8185c24d5628..971a45d50ffa4 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -411,27 +411,18 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cuda_utils), cuda_utils) { TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _custom_ar), custom_ar) { // Custom all-reduce kernels custom_ar.def( - "init_custom_ar(Tensor meta, Tensor rank_data, " - "str[] handles, int[] offsets, int rank, " - "bool full_nvlink) -> int"); + "init_custom_ar(int[] ipc_tensors, Tensor rank_data, " + "int rank, bool full_nvlink) -> int"); custom_ar.impl("init_custom_ar", 
torch::kCUDA, &init_custom_ar); - - custom_ar.def("all_reduce_reg(int fa, Tensor inp, Tensor! out) -> ()"); - custom_ar.impl("all_reduce_reg", torch::kCUDA, &all_reduce_reg); - custom_ar.def( - "all_reduce_unreg(int fa, Tensor inp, Tensor reg_buffer, Tensor! out) -> " - "()"); - custom_ar.impl("all_reduce_unreg", torch::kCUDA, &all_reduce_unreg); + "all_reduce(int fa, Tensor inp, Tensor! out, int reg_buffer, " + "int reg_buffer_sz_bytes) -> ()"); + custom_ar.impl("all_reduce", torch::kCUDA, &all_reduce); custom_ar.def("dispose", &dispose); custom_ar.def("meta_size", &meta_size); - custom_ar.def( - "register_buffer(int fa, Tensor t, str[] handles, " - "int[] offsets) -> ()"); - custom_ar.impl("register_buffer", torch::kCUDA, ®ister_buffer); - + custom_ar.def("register_buffer", ®ister_buffer); custom_ar.def("get_graph_buffer_ipc_meta", &get_graph_buffer_ipc_meta); custom_ar.def("register_graph_buffers", ®ister_graph_buffers); } diff --git a/tests/distributed/test_custom_all_reduce.py b/tests/distributed/test_custom_all_reduce.py index 95435e753058a..86ca1948ef94a 100644 --- a/tests/distributed/test_custom_all_reduce.py +++ b/tests/distributed/test_custom_all_reduce.py @@ -95,13 +95,13 @@ def eager_allreduce(tp_size, pp_size, rank, distributed_init_port): inp = torch.ones(sz, dtype=torch.float32, device=device) out = inp for _ in range(num_communication): - out = fa.all_reduce_unreg(out) + out = fa.all_reduce(out, registered=False) torch.testing.assert_close(out, inp * (tp_size**num_communication)) inp = torch.ones(sz * 4, dtype=torch.bfloat16, device=device) out = inp for _ in range(num_communication): - out = fa.all_reduce_unreg(out) + out = fa.all_reduce(out, registered=False) torch.testing.assert_close(out, inp * (tp_size**num_communication)) diff --git a/tools/profiler/visualize_layerwise_profile.py b/tools/profiler/visualize_layerwise_profile.py index efd6beee865c2..adc44474aa4c1 100644 --- a/tools/profiler/visualize_layerwise_profile.py +++ b/tools/profiler/visualize_layerwise_profile.py @@ -196,8 +196,8 @@ def is_cross_device_reduce_1stage(op_name: str): def is_cross_device_reduce_2stage(op_name: str): return "cross_device_reduce_2stage" in op_name - def is_custom_ar_all_reduce_unreg(op_name: str): - return "_C_custom_ar::all_reduce_unreg" in op_name + def is_custom_ar_all_reduce(op_name: str): + return "_C_custom_ar::all_reduce" in op_name def is_reduce_kernel(op_name: str): return "reduce_kernel" in op_name @@ -246,9 +246,9 @@ def is_reduce_kernel(op_name: str): filter(lambda x: is_cross_device_reduce_2stage(x), ops)) ops = list(filter(lambda x: x not in cross_device_reduce_2stage_ops, ops)) - custom_ar_all_reduce_unreg_ops = list( - filter(lambda x: is_custom_ar_all_reduce_unreg(x), ops)) - ops = list(filter(lambda x: x not in custom_ar_all_reduce_unreg_ops, ops)) + custom_ar_all_reduce_ops = list( + filter(lambda x: is_custom_ar_all_reduce(x), ops)) + ops = list(filter(lambda x: x not in custom_ar_all_reduce_ops, ops)) reduce_kernel_ops = list(filter(lambda x: is_reduce_kernel(x), ops)) ops = list(filter(lambda x: x not in reduce_kernel_ops, ops)) @@ -289,21 +289,21 @@ def is_reduce_kernel(op_name: str): if len(cross_device_reduce_2stage_ops): trace_df['cross_device_reduce_2stage_ops'] = trace_df[ cross_device_reduce_2stage_ops].agg("sum", axis=1) - if len(custom_ar_all_reduce_unreg_ops): - trace_df['custom_ar_all_reduce_unreg_ops'] = trace_df[ - custom_ar_all_reduce_unreg_ops].agg("sum", axis=1) + if len(custom_ar_all_reduce_ops): + trace_df['custom_ar_all_reduce_ops'] = 
trace_df[ + custom_ar_all_reduce_ops].agg("sum", axis=1) if len(reduce_kernel_ops): trace_df['reduce_kernel_ops'] = trace_df[reduce_kernel_ops].agg("sum", axis=1) - trace_df.drop( - attention_ops + quant_ops + gemm_ops + rms_norm_ops + vocab_embed_ops + - mem_ops + elementwise_ops + nccl_all_reduce_ops + nccl_gather_ops + - nccl_broadcast_ops + nccl_other_ops + cross_device_reduce_1stage_ops + - cross_device_reduce_2stage_ops + custom_ar_all_reduce_unreg_ops + - reduce_kernel_ops, - axis=1, - inplace=True) + trace_df.drop(attention_ops + quant_ops + gemm_ops + rms_norm_ops + + vocab_embed_ops + mem_ops + elementwise_ops + + nccl_all_reduce_ops + nccl_gather_ops + nccl_broadcast_ops + + nccl_other_ops + cross_device_reduce_1stage_ops + + cross_device_reduce_2stage_ops + custom_ar_all_reduce_ops + + reduce_kernel_ops, + axis=1, + inplace=True) return trace_df diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 682e08db99fa9..767d45ede7e87 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -912,20 +912,16 @@ def get_max_shared_memory_per_block_device_attribute(device: int) -> int: # custom ar -def init_custom_ar(meta: torch.Tensor, rank_data: torch.Tensor, - handles: List[str], offsets: List[int], rank: int, - full_nvlink: bool) -> int: - return torch.ops._C_custom_ar.init_custom_ar(meta, rank_data, handles, - offsets, rank, full_nvlink) +def init_custom_ar(ipc_tensors: List[torch.Tensor], rank_data: torch.Tensor, + rank: int, full_nvlink: bool) -> int: + return torch.ops._C_custom_ar.init_custom_ar(ipc_tensors, rank_data, rank, + full_nvlink) -def all_reduce_reg(fa: int, inp: torch.Tensor, out: torch.Tensor) -> None: - torch.ops._C_custom_ar.all_reduce_reg(fa, inp, out) - - -def all_reduce_unreg(fa: int, inp: torch.Tensor, reg_buffer: torch.Tensor, - out: torch.Tensor) -> None: - torch.ops._C_custom_ar.all_reduce_unreg(fa, inp, reg_buffer, out) +def all_reduce(fa: int, inp: torch.Tensor, out: torch.Tensor, reg_buffer: int, + reg_buffer_sz_bytes: int) -> None: + torch.ops._C_custom_ar.all_reduce(fa, inp, out, reg_buffer, + reg_buffer_sz_bytes) def dispose(fa: int) -> None: @@ -936,16 +932,15 @@ def meta_size() -> int: return torch.ops._C_custom_ar.meta_size() -def register_buffer(fa: int, t: torch.Tensor, handles: List[str], - offsets: List[int]) -> None: - return torch.ops._C_custom_ar.register_buffer(fa, t, handles, offsets) +def register_buffer(fa: int, ipc_tensors: List[int]) -> None: + return torch.ops._C_custom_ar.register_buffer(fa, ipc_tensors) -def get_graph_buffer_ipc_meta(fa: int) -> Tuple[List[str], List[int]]: +def get_graph_buffer_ipc_meta(fa: int) -> Tuple[List[int], List[int]]: return torch.ops._C_custom_ar.get_graph_buffer_ipc_meta(fa) -def register_graph_buffers(fa: int, handles: List[str], +def register_graph_buffers(fa: int, handles: List[List[int]], offsets: List[List[int]]) -> None: torch.ops._C_custom_ar.register_graph_buffers(fa, handles, offsets) diff --git a/vllm/distributed/device_communicators/custom_all_reduce.py b/vllm/distributed/device_communicators/custom_all_reduce.py index 3b5d92561cf25..62929dc0feaaf 100644 --- a/vllm/distributed/device_communicators/custom_all_reduce.py +++ b/vllm/distributed/device_communicators/custom_all_reduce.py @@ -1,6 +1,6 @@ import ctypes from contextlib import contextmanager -from typing import Any, List, Optional, Union +from typing import List, Optional, Union import torch import torch.distributed as dist @@ -147,18 +147,14 @@ def __init__(self, return self.disabled = False - # buffers memory are owned by 
this Python class and passed to C++ - # meta data composes of two parts: meta data for synchronization - # (256 bytes) and a temporary buffer for storing intermediate - # allreduce results. - self.meta = torch.zeros(ops.meta_size() + max_size, - dtype=torch.uint8, - device=self.device) + # Buffers memory are owned by this Python class and passed to C++. + # Meta data composes of two parts: meta data for synchronization and a + # temporary buffer for storing intermediate allreduce results. + self.meta_ptrs = self.create_shared_buffer(ops.meta_size() + max_size, + group=group) # This is a pre-registered IPC buffer. In eager mode, input tensors # are first copied into this buffer before allreduce is performed - self.buffer = torch.empty(max_size, - dtype=torch.uint8, - device=self.device) + self.buffer_ptrs = self.create_shared_buffer(max_size, group=group) # This is a buffer for storing the tuples of pointers pointing to # IPC buffers from all ranks. Each registered tuple has size of # 8*world_size bytes where world_size is at most 8. Allocating 8MB @@ -170,16 +166,19 @@ def __init__(self, self.max_size = max_size self.rank = rank self.world_size = world_size - handles, offsets = self._get_ipc_meta(self.meta) self.full_nvlink = full_nvlink - self._ptr = ops.init_custom_ar(self.meta, self.rank_data, handles, - offsets, rank, self.full_nvlink) - self.register_buffer(self.buffer) + self._ptr = ops.init_custom_ar(self.meta_ptrs, self.rank_data, rank, + self.full_nvlink) + ops.register_buffer(self._ptr, self.buffer_ptrs) @staticmethod def create_shared_buffer( size_in_bytes: int, group: Optional[ProcessGroup] = None) -> List[int]: + """ + Creates a shared buffer and returns a list of pointers + representing the buffer on all processes in the group. + """ lib = CudaRTLibrary() pointer = lib.cudaMalloc(size_in_bytes) handle = lib.cudaIpcGetMemHandle(pointer) @@ -220,60 +219,24 @@ def capture(self): if not self.disabled: self.register_graph_buffers() - def _get_ipc_meta(self, inp: torch.Tensor): - data = inp.untyped_storage()._share_cuda_() - handle = data[1] - # https://github.com/pytorch/pytorch/pull/130890 changes - # the binary format of the ipc handle - # it starts from pytorch 2.5 - if len(handle) > 64: - assert len(handle) == 66 - # only support SHAREABLE_HANDLE_VERSION = 1 - assert int(handle[0]) == 1 - # only support SHAREABLE_CUDA_MALLOC = 'c' - assert handle[1] == ord("c") - handle = handle[2:] - # TODO: support expandable segment - shard_data = ( - handle, # ipc handle to base ptr - data[3], # offset of base ptr - ) - return self._gather_ipc_meta(shard_data) - - def _gather_ipc_meta(self, shard_data): - # Note: don't use `[[None]] * self.world_size` here - # because it will create a list of the same reference - all_data: List[Optional[Any]] = [[None] - for i in range(self.world_size)] - all_data[self.rank][0] = shard_data - - ranks = dist.get_process_group_ranks(group=self.group) - ranks.sort() + def register_graph_buffers(self): + handle, offset = ops.get_graph_buffer_ipc_meta(self._ptr) + logger.info("Registering %d cuda graph addresses", len(offset)) + # We cannot directly use `dist.all_gather_object` here + # because it is incompatible with `gloo` backend under inference mode. + # see https://github.com/pytorch/pytorch/issues/126032 for details. 
+ all_data = [[None, None] + for _ in range(dist.get_world_size(group=self.group))] + all_data[self.rank] = [handle, offset] + ranks = sorted(dist.get_process_group_ranks(group=self.group)) for i, rank in enumerate(ranks): dist.broadcast_object_list(all_data[i], src=rank, group=self.group, device="cpu") - - # we cannot directly use `dist.all_gather_object` here - # because it is incompatible with `gloo` backend under inference mode. - # see https://github.com/pytorch/pytorch/issues/126032 for details. - - handles = [] - offsets = [] - for i in range(len(all_data)): - handles.append(all_data[i][0][0]) # type: ignore - offsets.append(all_data[i][0][1]) # type: ignore - return handles, offsets - - def register_buffer(self, inp: torch.Tensor): - handles, offsets = self._get_ipc_meta(inp) - ops.register_buffer(self._ptr, inp, handles, offsets) - - def register_graph_buffers(self): - handle, offset = ops.get_graph_buffer_ipc_meta(self._ptr) - handles, offsets = self._gather_ipc_meta((bytes(handle), offset)) - logger.info("Registering %d cuda graph addresses", len(offset)) + # Unpack list of tuples to tuple of lists. + handles = [d[0] for d in all_data] # type: ignore + offsets = [d[1] for d in all_data] # type: ignore ops.register_graph_buffers(self._ptr, handles, offsets) def should_custom_ar(self, inp: torch.Tensor): @@ -291,45 +254,50 @@ def should_custom_ar(self, inp: torch.Tensor): return inp_size < self.max_size return False - # all reduce, assuming inp tensor is IPC registered with register_buffer, - # or, in the context of cuda graphs, register_graph_buffers - def all_reduce_reg(self, inp: torch.Tensor, out: torch.Tensor = None): - if out is None: - out = torch.empty_like(inp) - ops.all_reduce_reg(self._ptr, inp, out) - return out - - # all reduce, assuming inp tensor is NOT IPC registered - def all_reduce_unreg(self, inp: torch.Tensor, out: torch.Tensor = None): + def all_reduce(self, + inp: torch.Tensor, + *, + out: torch.Tensor = None, + registered: bool = False): + """Performs an out-of-place all reduce. + + If registered is True, this assumes inp's pointer is already + IPC-registered. Otherwise, inp is first copied into a pre-registered + buffer. + """ if out is None: out = torch.empty_like(inp) - ops.all_reduce_unreg(self._ptr, inp, self.buffer, out) + if registered: + ops.all_reduce(self._ptr, inp, out, 0, 0) + else: + ops.all_reduce(self._ptr, inp, out, self.buffer_ptrs[self.rank], + self.max_size) return out def custom_all_reduce(self, input: torch.Tensor) -> Optional[torch.Tensor]: - # when custom allreduce is disabled, this will be None + """The main allreduce API that provides support for cuda graph.""" + # When custom allreduce is disabled, this will be None. if self.disabled or not self.should_custom_ar(input): return None if self._IS_CAPTURING: if torch.cuda.is_current_stream_capturing(): - return self.all_reduce_reg(input) + return self.all_reduce(input, registered=True) else: - # if warm up, mimic the allocation pattern - # since custom allreduce is out-of-place + # If warm up, mimic the allocation pattern since custom + # allreduce is out-of-place. 
return torch.empty_like(input) else: - # note: outside of cuda graph context, - # custom allreduce incurs a cost of cudaMemcpy, which should - # be small(<=1% of overall latency) compared to the performance - # gains of using custom kernels - return self.all_reduce_unreg(input) - - return None + # Note: outside of cuda graph context, custom allreduce incurs a + # cost of cudaMemcpy, which should be small (<=1% of overall + # latency) compared to the performance gain of using custom kernels + return self.all_reduce(input, registered=False) def close(self): if not self.disabled and self._ptr: ops.dispose(self._ptr) self._ptr = 0 + self.free_shared_buffer(self.meta_ptrs) + self.free_shared_buffer(self.buffer_ptrs) def __del__(self): self.close() From e036e527a08fbf00ba725b12c9ebff6cd9bfab52 Mon Sep 17 00:00:00 2001 From: Russell Bryant Date: Thu, 7 Nov 2024 02:54:16 -0500 Subject: [PATCH 06/13] [CI/Build] Improve mypy + python version matrix (#10041) Signed-off-by: Russell Bryant --- .github/workflows/mypy.yaml | 2 +- pyproject.toml | 4 +--- tools/mypy.sh | 5 +++-- 3 files changed, 5 insertions(+), 6 deletions(-) diff --git a/.github/workflows/mypy.yaml b/.github/workflows/mypy.yaml index 28d2e5fb8dbd9..fbee6bb03fc8e 100644 --- a/.github/workflows/mypy.yaml +++ b/.github/workflows/mypy.yaml @@ -43,4 +43,4 @@ jobs: - name: Mypy run: | echo "::add-matcher::.github/workflows/matchers/mypy.json" - tools/mypy.sh 1 + tools/mypy.sh 1 ${{ matrix.python-version }} diff --git a/pyproject.toml b/pyproject.toml index 1aebc543a733a..bae8645502dea 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -55,14 +55,12 @@ ignore = [ ] [tool.mypy] -python_version = "3.9" - ignore_missing_imports = true check_untyped_defs = true follow_imports = "silent" # After fixing type errors resulting from follow_imports: "skip" -> "silent", -# move the directory here and remove it from format.sh and mypy.yaml +# move the directory here and remove it from tools/mypy.sh files = [ "vllm/*.py", "vllm/adapter_commons", diff --git a/tools/mypy.sh b/tools/mypy.sh index 14b0976a27da5..7e8f7d402cdd5 100755 --- a/tools/mypy.sh +++ b/tools/mypy.sh @@ -1,6 +1,7 @@ #!/bin/bash CI=${1:-0} +PYTHON_VERSION=${2:-3.9} if [ $CI -eq 1 ]; then set -e @@ -9,10 +10,10 @@ fi run_mypy() { echo "Running mypy on $1" if [ $CI -eq 1 ] && [ -z "$1" ]; then - mypy "$@" + mypy --python-version "${PYTHON_VERSION}" "$@" return fi - mypy --follow-imports skip "$@" + mypy --follow-imports skip --python-version "${PYTHON_VERSION}" "$@" } run_mypy # Note that this is less strict than CI From aa9078fa035abfac54179cbdca8b741e49c8cd0b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fl=C3=A1via=20B=C3=A9o?= <119421251+flaviabeo@users.noreply.github.com> Date: Thu, 7 Nov 2024 05:42:40 -0300 Subject: [PATCH 07/13] Adds method to read the pooling types from model's files (#9506) Signed-off-by: Flavia Beo Signed-off-by: Max de Bayser Co-authored-by: Max de Bayser --- examples/fp8/quantizer/quantize.py | 4 +- tests/engine/test_arg_utils.py | 7 + .../test_model_load_with_params.py | 50 ++++++ tests/test_config.py | 72 ++++++++ tests/utils.py | 14 +- vllm/config.py | 28 ++- vllm/engine/arg_utils.py | 3 +- vllm/model_executor/layers/pooler.py | 14 +- vllm/transformers_utils/config.py | 170 ++++++++++++++++-- .../tokenizer_group/__init__.py | 5 + 10 files changed, 342 insertions(+), 25 deletions(-) create mode 100644 tests/model_executor/test_model_load_with_params.py diff --git a/examples/fp8/quantizer/quantize.py b/examples/fp8/quantizer/quantize.py index 
15f1a06b1219b..d75cc8b3d1cf7 100644 --- a/examples/fp8/quantizer/quantize.py +++ b/examples/fp8/quantizer/quantize.py @@ -230,7 +230,7 @@ def calibrate_loop(): def main(args): if not torch.cuda.is_available(): - raise EnvironmentError("GPU is required for inference.") + raise OSError("GPU is required for inference.") random.seed(RAND_SEED) np.random.seed(RAND_SEED) @@ -314,7 +314,7 @@ def main(args): # Workaround for wo quantization if args.qformat in ["int8_wo", "int4_wo", "full_prec"]: - with open(f"{export_path}/config.json", 'r') as f: + with open(f"{export_path}/config.json") as f: tensorrt_llm_config = json.load(f) if args.qformat == "int8_wo": tensorrt_llm_config["quantization"]["quant_algo"] = 'W8A16' diff --git a/tests/engine/test_arg_utils.py b/tests/engine/test_arg_utils.py index f7dc167fea6e4..e92e2588d01cb 100644 --- a/tests/engine/test_arg_utils.py +++ b/tests/engine/test_arg_utils.py @@ -30,6 +30,13 @@ def test_limit_mm_per_prompt_parser(arg, expected): assert args.limit_mm_per_prompt == expected +def test_valid_pooling_config(): + parser = EngineArgs.add_cli_args(FlexibleArgumentParser()) + args = parser.parse_args(["--pooling-type=MEAN"]) + engine_args = EngineArgs.from_cli_args(args=args) + assert engine_args.pooling_type == 'MEAN' + + @pytest.mark.parametrize( ("arg"), [ diff --git a/tests/model_executor/test_model_load_with_params.py b/tests/model_executor/test_model_load_with_params.py new file mode 100644 index 0000000000000..7e5e2780d3916 --- /dev/null +++ b/tests/model_executor/test_model_load_with_params.py @@ -0,0 +1,50 @@ +import os + +import pytest + +from vllm.model_executor.layers.pooler import PoolingType +from vllm.model_executor.models.bert import BertEmbeddingModel +from vllm.platforms import current_platform + +MAX_MODEL_LEN = 128 +MODEL_NAME = os.environ.get("MODEL_NAME", "BAAI/bge-base-en-v1.5") +REVISION = os.environ.get("REVISION", "main") + + +@pytest.mark.skipif(current_platform.is_rocm(), + reason="Xformers backend is not supported on ROCm.") +def test_model_loading_with_params(vllm_runner): + """ + Test parameter weight loading with tp>1. 
+ """ + with vllm_runner(model_name=MODEL_NAME, + revision=REVISION, + dtype="float16", + max_model_len=MAX_MODEL_LEN) as model: + output = model.encode("Write a short story about a robot that" + " dreams for the first time.\n") + + model_config = model.model.llm_engine.model_config + + model_tokenizer = model.model.llm_engine.tokenizer + + # asserts on the bert model config file + assert model_config.encoder_config["max_seq_length"] == 512 + assert model_config.encoder_config["do_lower_case"] + + # asserts on the pooling config files + assert model_config.pooler_config.pooling_type == PoolingType.CLS.name + assert model_config.pooler_config.pooling_norm + + # asserts on the tokenizer loaded + assert model_tokenizer.tokenizer_id == "BAAI/bge-base-en-v1.5" + assert model_tokenizer.tokenizer_config["do_lower_case"] + assert model_tokenizer.tokenizer.model_max_length == 512 + + model = model.model.llm_engine.model_executor\ + .driver_worker.model_runner.model + assert isinstance(model, BertEmbeddingModel) + assert model._pooler.pooling_type == PoolingType.CLS + assert model._pooler.normalize + # assert output + assert output diff --git a/tests/test_config.py b/tests/test_config.py index 5211049bf0011..66bdb883657c5 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -1,6 +1,8 @@ import pytest from vllm.config import ModelConfig +from vllm.model_executor.layers.pooler import PoolingType +from vllm.platforms import current_platform @pytest.mark.parametrize(("model_id", "expected_task"), [ @@ -102,6 +104,76 @@ def test_get_sliding_window(): assert mistral_model_config.get_sliding_window() == TEST_SLIDING_WINDOW +@pytest.mark.skipif(current_platform.is_rocm(), + reason="Xformers backend is not supported on ROCm.") +def test_get_pooling_config(): + model_id = "sentence-transformers/all-MiniLM-L12-v2" + minilm_model_config = ModelConfig( + model_id, + task="auto", + tokenizer=model_id, + tokenizer_mode="auto", + trust_remote_code=False, + seed=0, + dtype="float16", + revision=None, + ) + + minilm_pooling_config = minilm_model_config._init_pooler_config( + pooling_type=None, + pooling_norm=None, + pooling_returned_token_ids=None, + pooling_softmax=None, + pooling_step_tag_id=None) + + assert minilm_pooling_config.pooling_norm + assert minilm_pooling_config.pooling_type == PoolingType.MEAN.name + + +@pytest.mark.skipif(current_platform.is_rocm(), + reason="Xformers backend is not supported on ROCm.") +def test_get_pooling_config_from_args(): + model_id = "sentence-transformers/all-MiniLM-L12-v2" + minilm_model_config = ModelConfig(model_id, + task="auto", + tokenizer=model_id, + tokenizer_mode="auto", + trust_remote_code=False, + seed=0, + dtype="float16", + revision=None) + + minilm_pooling_config = minilm_model_config._init_pooler_config( + pooling_type='CLS', + pooling_norm=True, + pooling_returned_token_ids=None, + pooling_softmax=None, + pooling_step_tag_id=None) + + assert minilm_pooling_config.pooling_norm + assert minilm_pooling_config.pooling_type == PoolingType.CLS.name + + +@pytest.mark.skipif(current_platform.is_rocm(), + reason="Xformers backend is not supported on ROCm.") +def test_get_bert_tokenization_sentence_transformer_config(): + bge_model_config = ModelConfig( + model="BAAI/bge-base-en-v1.5", + task="auto", + tokenizer="BAAI/bge-base-en-v1.5", + tokenizer_mode="auto", + trust_remote_code=False, + seed=0, + dtype="float16", + revision=None, + ) + + bert_bge_model_config = bge_model_config._get_encoder_config() + + assert bert_bge_model_config["max_seq_length"] == 512 + 
assert bert_bge_model_config["do_lower_case"] + + def test_rope_customization(): TEST_ROPE_SCALING = {"rope_type": "dynamic", "factor": 2.0} TEST_ROPE_THETA = 16_000_000.0 diff --git a/tests/utils.py b/tests/utils.py index 00c7dabe16a7b..a893667e144a6 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -15,6 +15,7 @@ import pytest import requests import torch +import torch.nn.functional as F from openai.types.completion import Completion from typing_extensions import ParamSpec @@ -515,13 +516,14 @@ def compare_all_settings(model: str, ref_result = copy.deepcopy(ref_result) compare_result = copy.deepcopy(compare_result) if "embedding" in ref_result and method == "encode": - ref_embedding = torch.tensor(ref_result["embedding"]) - compare_embedding = torch.tensor( - compare_result["embedding"]) - mse = ((ref_embedding - compare_embedding)**2).mean() - assert mse < 1e-6, ( + sim = F.cosine_similarity( + torch.tensor(ref_result["embedding"]), + torch.tensor(compare_result["embedding"]), + dim=0, + ) + assert sim >= 0.999, ( f"Embedding for {model=} are not the same.\n" - f"mse={mse}\n") + f"cosine_similarity={sim}\n") del ref_result["embedding"] del compare_result["embedding"] assert ref_result == compare_result, ( diff --git a/vllm/config.py b/vllm/config.py index c7fad3a261858..e844a46bf06e6 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -13,10 +13,10 @@ from vllm.model_executor.models import ModelRegistry from vllm.platforms import current_platform from vllm.tracing import is_otel_available, otel_import_error_traceback -from vllm.transformers_utils.config import (ConfigFormat, get_config, - get_hf_image_processor_config, - get_hf_text_config, - is_encoder_decoder, uses_mrope) +from vllm.transformers_utils.config import ( + ConfigFormat, get_config, get_hf_image_processor_config, + get_hf_text_config, get_pooling_config, + get_sentence_transformer_tokenizer_config, is_encoder_decoder, uses_mrope) from vllm.utils import (GiB_bytes, cuda_device_count_stateless, get_cpu_memory, print_warning_once) @@ -197,6 +197,7 @@ def __init__( code_revision, rope_scaling, rope_theta, config_format) self.hf_text_config = get_hf_text_config(self.hf_config) + self.encoder_config = self._get_encoder_config() self.hf_image_processor_config = get_hf_image_processor_config( self.model, revision) self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype) @@ -229,7 +230,8 @@ def __init__( max_model_len=max_model_len, disable_sliding_window=self.disable_sliding_window, sliding_window_len=self.get_hf_config_sliding_window(), - spec_target_max_model_len=spec_target_max_model_len) + spec_target_max_model_len=spec_target_max_model_len, + encoder_config=self.encoder_config) self.served_model_name = get_served_model_name(model, served_model_name) self.multimodal_config = self._init_multimodal_config( @@ -273,6 +275,10 @@ def _init_multimodal_config( return None + def _get_encoder_config(self): + return get_sentence_transformer_tokenizer_config( + self.model, self.revision) + def _init_pooler_config( self, pooling_type: Optional[str] = None, @@ -282,6 +288,14 @@ def _init_pooler_config( pooling_returned_token_ids: Optional[List[int]] = None ) -> Optional["PoolerConfig"]: if self.task == "embedding": + pooling_config = get_pooling_config(self.model, self.revision) + if pooling_config is not None: + # override if user does not + # specifies pooling_type and/or pooling_norm + if pooling_type is None: + pooling_type = pooling_config["pooling_type"] + if pooling_norm is None: + pooling_norm = pooling_config["normalize"] 
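            # pooling_config comes from the repository's sentence-transformers
            # metadata (see get_pooling_config in
            # vllm/transformers_utils/config.py); it only fills in values the
            # caller left as None, so explicit pooling arguments still take
            # precedence over the repo defaults.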
return PoolerConfig( pooling_type=pooling_type, pooling_norm=pooling_norm, @@ -1795,6 +1809,7 @@ def _get_and_verify_max_len( disable_sliding_window: bool, sliding_window_len: Optional[Union[int, List[Optional[int]]]], spec_target_max_model_len: Optional[int] = None, + encoder_config: Optional[Any] = None, ) -> int: """Get and verify the model's maximum length.""" derived_max_model_len = float("inf") @@ -1877,6 +1892,9 @@ def _get_and_verify_max_len( "original_max_position_embeddings"] derived_max_model_len *= scaling_factor + if encoder_config and "max_seq_length" in encoder_config: + derived_max_model_len = encoder_config["max_seq_length"] + # If the user specified a max length, make sure it is smaller than the # derived length from the HF model config. if max_model_len is None: diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index b556c0eed3776..8c5b442e9f624 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -16,6 +16,7 @@ VllmConfig) from vllm.executor.executor_base import ExecutorBase from vllm.logger import init_logger +from vllm.model_executor.layers.pooler import PoolingType from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS from vllm.platforms import current_platform from vllm.transformers_utils.config import ( @@ -863,7 +864,7 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: parser.add_argument( '--pooling-type', - choices=['LAST', 'ALL', 'CLS', 'STEP'], + choices=[pt.name for pt in PoolingType], default=None, help='Used to configure the pooling method in the embedding model.' ) diff --git a/vllm/model_executor/layers/pooler.py b/vllm/model_executor/layers/pooler.py index 1c9772b41cbef..024badbc17b96 100644 --- a/vllm/model_executor/layers/pooler.py +++ b/vllm/model_executor/layers/pooler.py @@ -16,6 +16,7 @@ class PoolingType(IntEnum): ALL = 1 CLS = 2 STEP = 3 + MEAN = 4 class Pooler(nn.Module): @@ -27,7 +28,7 @@ class Pooler(nn.Module): 3. Returns structured results as `PoolerOutput`. Attributes: - pooling_type: The type of pooling to use (LAST, ALL, CLS). + pooling_type: The type of pooling to use. normalize: Whether to normalize the pooled data. 
""" @@ -97,6 +98,17 @@ def forward( for prompt_len in prompt_lens: pooled_data.append(hidden_states[offset:offset + prompt_len]) offset += prompt_len + elif self.pooling_type == PoolingType.MEAN: + # Calculate mean pooling + cumsum = torch.cumsum(hidden_states, dim=0) + start_indices = torch.cat([ + torch.tensor([0], device=hidden_states.device), + torch.cumsum(prompt_lens[:-1], dim=0) + ]) + end_indices = torch.cumsum(prompt_lens, dim=0) + pooled_data = ( + cumsum[end_indices - 1] - cumsum[start_indices] + + hidden_states[start_indices]) / prompt_lens.unsqueeze(1) elif self.pooling_type == PoolingType.STEP: if self.returned_token_ids is not None and len( self.returned_token_ids) > 0: diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index 415d8bf7cc2bb..6b38ee31c2657 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -6,6 +6,9 @@ import huggingface_hub from huggingface_hub import (file_exists, hf_hub_download, try_to_load_from_cache) +from huggingface_hub.utils import (EntryNotFoundError, LocalEntryNotFoundError, + RepositoryNotFoundError, + RevisionNotFoundError) from transformers import GenerationConfig, PretrainedConfig from transformers.models.auto.image_processing_auto import ( get_image_processor_config) @@ -213,7 +216,7 @@ def get_config( raise e elif config_format == ConfigFormat.MISTRAL: - config = load_params_config(model, revision) + config = load_params_config(model, revision, token=kwargs.get("token")) else: raise ValueError(f"Unsupported config format: {config_format}") @@ -243,6 +246,158 @@ def get_config( return config +def get_hf_file_to_dict(file_name: str, + model: Union[str, Path], + revision: Optional[str] = 'main', + token: Optional[str] = None): + """ + Downloads a file from the Hugging Face Hub and returns + its contents as a dictionary. + + Parameters: + - file_name (str): The name of the file to download. + - model (str): The name of the model on the Hugging Face Hub. + - revision (str): The specific version of the model. + - token (str): The Hugging Face authentication token. + + Returns: + - config_dict (dict): A dictionary containing + the contents of the downloaded file. + """ + file_path = Path(model) / file_name + + if file_or_path_exists(model=model, + config_name=file_name, + revision=revision, + token=token): + + if not file_path.is_file(): + try: + hf_hub_file = hf_hub_download(model, + file_name, + revision=revision) + except (RepositoryNotFoundError, RevisionNotFoundError, + EntryNotFoundError, LocalEntryNotFoundError) as e: + logger.debug("File or repository not found in hf_hub_download", + e) + return None + file_path = Path(hf_hub_file) + + with open(file_path) as file: + return json.load(file) + return None + + +def get_pooling_config(model: str, + revision: Optional[str] = 'main', + token: Optional[str] = None): + """ + This function gets the pooling and normalize + config from the model - only applies to + sentence-transformers models. + + Args: + model (str): The name of the Hugging Face model. + revision (str, optional): The specific version + of the model to use. Defaults to 'main'. + + Returns: + dict: A dictionary containing the pooling + type and whether normalization is used. 
+ """ + + modules_file_name = "modules.json" + modules_dict = get_hf_file_to_dict(modules_file_name, model, revision, + token) + + if modules_dict is None: + return None + + pooling = next((item for item in modules_dict + if item["type"] == "sentence_transformers.models.Pooling"), + None) + normalize = bool( + next((item for item in modules_dict + if item["type"] == "sentence_transformers.models.Normalize"), + False)) + + if pooling: + + pooling_file_name = "{}/config.json".format(pooling["path"]) + pooling_dict = get_hf_file_to_dict(pooling_file_name, model, revision, + token) + pooling_type_name = next( + (item for item, val in pooling_dict.items() if val is True), None) + + if pooling_type_name is not None: + pooling_type_name = get_pooling_config_name(pooling_type_name) + + return {"pooling_type": pooling_type_name, "normalize": normalize} + + return None + + +def get_pooling_config_name(pooling_name: str) -> Union[str, None]: + if "pooling_mode_" in pooling_name: + pooling_name = pooling_name.replace("pooling_mode_", "") + + if "_" in pooling_name: + pooling_name = pooling_name.split("_")[0] + + if "lasttoken" in pooling_name: + pooling_name = "last" + + supported_pooling_types = ['LAST', 'ALL', 'CLS', 'STEP', 'MEAN'] + pooling_type_name = pooling_name.upper() + + try: + if pooling_type_name in supported_pooling_types: + return pooling_type_name + except NotImplementedError as e: + logger.debug("Pooling type not supported", e) + return None + return None + + +def get_sentence_transformer_tokenizer_config(model: str, + revision: Optional[str] = 'main', + token: Optional[str] = None): + """ + Returns the tokenization configuration dictionary for a + given Sentence Transformer BERT model. + + Parameters: + - model (str): The name of the Sentence Transformer + BERT model. + - revision (str, optional): The revision of the m + odel to use. Defaults to 'main'. + - token (str): A Hugging Face access token. + + Returns: + - dict: A dictionary containing the configuration parameters + for the Sentence Transformer BERT model. 
+ """ + for config_name in [ + "sentence_bert_config.json", + "sentence_roberta_config.json", + "sentence_distilbert_config.json", + "sentence_camembert_config.json", + "sentence_albert_config.json", + "sentence_xlm-roberta_config.json", + "sentence_xlnet_config.json", + ]: + encoder_dict = get_hf_file_to_dict(config_name, model, revision, token) + if encoder_dict: + break + + if not encoder_dict: + return None + + if all(k in encoder_dict for k in ("max_seq_length", "do_lower_case")): + return encoder_dict + return None + + def maybe_register_config_serialize_by_value(trust_remote_code: bool) -> None: """Try to register HF model configuration class to serialize by value @@ -305,20 +460,15 @@ def _reduce_modelconfig(mc: ModelConfig): exc_info=e) -def load_params_config(model, revision) -> PretrainedConfig: +def load_params_config(model: Union[str, Path], + revision: Optional[str], + token: Optional[str] = None) -> PretrainedConfig: # This function loads a params.json config which # should be used when loading models in mistral format config_file_name = "params.json" - config_path = Path(model) / config_file_name - - if not config_path.is_file(): - config_path = Path( - hf_hub_download(model, config_file_name, revision=revision)) - - with open(config_path) as file: - config_dict = json.load(file) + config_dict = get_hf_file_to_dict(config_file_name, model, revision, token) config_mapping = { "dim": "hidden_size", diff --git a/vllm/transformers_utils/tokenizer_group/__init__.py b/vllm/transformers_utils/tokenizer_group/__init__.py index 9a4149251d747..6a114b513f382 100644 --- a/vllm/transformers_utils/tokenizer_group/__init__.py +++ b/vllm/transformers_utils/tokenizer_group/__init__.py @@ -25,6 +25,11 @@ def init_tokenizer_from_configs(model_config: ModelConfig, trust_remote_code=model_config.trust_remote_code, revision=model_config.tokenizer_revision) + if (model_config.encoder_config is not None + and "do_lower_case" in model_config.encoder_config): + init_kwargs["do_lower_case"] = model_config.encoder_config[ + "do_lower_case"] + return get_tokenizer_group(parallel_config.tokenizer_pool_config, **init_kwargs) From 0dfba97b42032987fd6bd3d304ac22dd314c89b1 Mon Sep 17 00:00:00 2001 From: Lei Yang Date: Thu, 7 Nov 2024 17:07:19 +0800 Subject: [PATCH 08/13] [Frontend] Fix multiple values for keyword argument error (#10075) (#10076) Signed-off-by: Lei --- vllm/entrypoints/openai/serving_engine.py | 23 +++++++++++------------ 1 file changed, 11 insertions(+), 12 deletions(-) diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index e7aeac8f8c018..e31dc2ced61fb 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -443,29 +443,28 @@ async def _preprocess_chat( tokenizer, ) + _chat_template_kwargs: Dict[str, Any] = dict( + chat_template=chat_template, + add_generation_prompt=add_generation_prompt, + continue_final_message=continue_final_message, + tools=tool_dicts, + documents=documents, + ) + _chat_template_kwargs.update(chat_template_kwargs or {}) + request_prompt: Union[str, List[int]] is_mistral_tokenizer = isinstance(tokenizer, MistralTokenizer) if is_mistral_tokenizer: request_prompt = apply_mistral_chat_template( tokenizer, messages=messages, - chat_template=chat_template, - add_generation_prompt=add_generation_prompt, - continue_final_message=continue_final_message, - tools=tool_dicts, - documents=documents, - **(chat_template_kwargs or {}), + **_chat_template_kwargs, ) else: request_prompt = 
apply_hf_chat_template( tokenizer, conversation=conversation, - chat_template=chat_template, - add_generation_prompt=add_generation_prompt, - continue_final_message=continue_final_message, - tools=tool_dicts, - documents=documents, - **(chat_template_kwargs or {}), + **_chat_template_kwargs, ) mm_data = await mm_data_future From a6f332d0d9ac3e795949da7703f203b6b1a42797 Mon Sep 17 00:00:00 2001 From: "Li, Jiang" Date: Thu, 7 Nov 2024 18:42:50 +0800 Subject: [PATCH 09/13] [Hardware][CPU][bugfix] Fix half dtype support on AVX2-only target (#10108) Signed-off-by: jiang1.li --- cmake/cpu_extension.cmake | 2 +- csrc/cpu/cpu_types_x86.hpp | 10 ++++++++++ 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/cmake/cpu_extension.cmake b/cmake/cpu_extension.cmake index 776a0bb11ae64..5912c5c02ede7 100644 --- a/cmake/cpu_extension.cmake +++ b/cmake/cpu_extension.cmake @@ -93,7 +93,7 @@ if (AVX512_FOUND AND NOT AVX512_DISABLED) FetchContent_Declare( oneDNN GIT_REPOSITORY https://github.com/oneapi-src/oneDNN.git - GIT_TAG v3.5.3 + GIT_TAG v3.6 GIT_PROGRESS TRUE GIT_SHALLOW TRUE ) diff --git a/csrc/cpu/cpu_types_x86.hpp b/csrc/cpu/cpu_types_x86.hpp index 12d5757b495be..4bb4eb0f491ac 100644 --- a/csrc/cpu/cpu_types_x86.hpp +++ b/csrc/cpu/cpu_types_x86.hpp @@ -432,6 +432,16 @@ struct FP32Vec16 : public Vec { explicit FP32Vec16(const FP32Vec8 &data) : reg_low(data.reg), reg_high(data.reg) {} + explicit FP32Vec16(const FP16Vec16 &v) { + __m128i low = _mm256_extractf128_si256(v.reg, 0); + __m128i high = _mm256_extractf128_si256(v.reg, 1); + + reg_low = _mm256_cvtph_ps(low); + reg_high = _mm256_cvtph_ps(high); + } + + explicit FP32Vec16(const FP16Vec8 &v) : FP32Vec16(FP32Vec8(v)) {} + explicit FP32Vec16(const BF16Vec16 &v) { __m128i low = _mm256_extractf128_si256(v.reg, 0); __m128i high = _mm256_extractf128_si256(v.reg, 1); From 999df95b4eefb920cd3539a7fa3a21b2911f3650 Mon Sep 17 00:00:00 2001 From: Jiahao Li Date: Thu, 7 Nov 2024 18:50:44 +0800 Subject: [PATCH 10/13] [Bugfix] Make image processor respect `mm_processor_kwargs` for Qwen2-VL (#10112) Signed-off-by: Jiahao Li --- vllm/model_executor/models/qwen2_vl.py | 33 ++++++++++++++++++-------- 1 file changed, 23 insertions(+), 10 deletions(-) diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index af263262bd239..0e820cf123139 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -22,8 +22,8 @@ # limitations under the License. 
"""Inference-only Qwen2-VL model compatible with HuggingFace weights.""" from functools import partial -from typing import (Any, Callable, Iterable, List, Literal, Mapping, Optional, - Tuple, Type, TypedDict, Union) +from typing import (Any, Callable, Dict, Iterable, List, Literal, Mapping, + Optional, Tuple, Type, TypedDict, Union) import torch import torch.nn as nn @@ -558,6 +558,17 @@ def forward( # === Vision input helpers === # +def get_mm_processor_kwargs( + min_pixels: Optional[int] = None, + max_pixels: Optional[int] = None) -> Dict[str, int]: + mm_processor_kwargs = {} + if min_pixels: + mm_processor_kwargs["min_pixels"] = min_pixels + if max_pixels: + mm_processor_kwargs["max_pixels"] = max_pixels + return mm_processor_kwargs + + def mm_input_mapper_for_qwen2_vl( ctx: InputContext, data: MultiModalData[object], @@ -575,12 +586,8 @@ def mm_input_mapper_for_qwen2_vl( model_config = ctx.model_config # Handle mm processor kwargs; we pass these at creation time # because preprocess() in transformers doesn't expose them - mm_processor_kwargs = {} - if min_pixels: - mm_processor_kwargs["min_pixels"] = min_pixels - if max_pixels: - mm_processor_kwargs["max_pixels"] = max_pixels - + mm_processor_kwargs = get_mm_processor_kwargs(min_pixels=min_pixels, + max_pixels=max_pixels) image_processor = cached_get_image_processor( model_config.model, trust_remote_code=model_config.trust_remote_code, @@ -683,7 +690,10 @@ def get_max_qwen2_vl_mm_tokens(ctx: InputContext, *, min_pixels=None, max_pixels=None) -> int: - image_processor = cached_get_image_processor(ctx.model_config.model) + mm_processor_kwargs = get_mm_processor_kwargs(min_pixels=min_pixels, + max_pixels=max_pixels) + image_processor = cached_get_image_processor(ctx.model_config.model, + **mm_processor_kwargs) max_resized_height, max_resized_width, max_llm_image_tokens = \ _get_max_image_info(image_processor, data_type_key=data_type_key, mm_count=1, min_pixels=min_pixels, @@ -705,7 +715,10 @@ def dummy_data_for_qwen2_vl( min_pixels: Optional[int] = None, max_pixels: Optional[int] = None ) -> Tuple[SequenceData, Optional[MultiModalDataDict]]: - image_processor = cached_get_image_processor(ctx.model_config.model) + mm_processor_kwargs = get_mm_processor_kwargs(min_pixels=min_pixels, + max_pixels=max_pixels) + image_processor = cached_get_image_processor(ctx.model_config.model, + **mm_processor_kwargs) num_images = mm_counts["image"] max_resized_height, max_resized_width, max_llm_image_tokens = \ From a62bc0109c3864b9dc770dc637e3acd332c730ea Mon Sep 17 00:00:00 2001 From: Atlas <163425173+spliii@users.noreply.github.com> Date: Thu, 7 Nov 2024 19:20:30 +0800 Subject: [PATCH 11/13] [Misc] Add Gamma-Distribution Request Generation Support for Serving Benchmark. (#10105) Signed-off-by: Mozhou Co-authored-by: Roger Wang <136131678+ywang96@users.noreply.github.com> --- benchmarks/benchmark_serving.py | 57 ++++++++++++++++++++++++++++++--- 1 file changed, 52 insertions(+), 5 deletions(-) diff --git a/benchmarks/benchmark_serving.py b/benchmarks/benchmark_serving.py index ff06622628219..bdb8ea8e2a5dc 100644 --- a/benchmarks/benchmark_serving.py +++ b/benchmarks/benchmark_serving.py @@ -297,8 +297,33 @@ def sample_random_requests( async def get_request( input_requests: List[Tuple[str, int, int]], request_rate: float, + burstiness: float = 1.0, ) -> AsyncGenerator[Tuple[str, int, int], None]: + """ + Asynchronously generates requests at a specified rate + with OPTIONAL burstiness. 
+ + Args: + input_requests: + A list of input requests, each represented as a tuple. + request_rate: + The rate at which requests are generated (requests/s). + burstiness (optional): + The burstiness factor of the request generation. + Only takes effect when request_rate is not inf. + Default value is 1, which follows a Poisson process. + Otherwise, the request intervals follow a gamma distribution. + A lower burstiness value (0 < burstiness < 1) results + in more bursty requests, while a higher burstiness value + (burstiness > 1) results in a more uniform arrival of requests. + """ input_requests = iter(input_requests) + + # Calculate scale parameter theta to maintain the desired request_rate. + assert burstiness > 0, ( + f"A positive burstiness factor is expected, but given {burstiness}.") + theta = 1.0 / (request_rate * burstiness) + for request in input_requests: yield request @@ -306,8 +331,9 @@ async def get_request( # If the request rate is infinity, then we don't need to wait. continue - # Sample the request interval from the exponential distribution. - interval = np.random.exponential(1.0 / request_rate) + # Sample the request interval from the gamma distribution. + # If burstiness is 1, it follows exponential distribution. + interval = np.random.gamma(shape=burstiness, scale=theta) # The next request will be sent after the interval. await asyncio.sleep(interval) @@ -426,6 +452,7 @@ async def benchmark( logprobs: Optional[int], best_of: int, request_rate: float, + burstiness: float, disable_tqdm: bool, profile: bool, selected_percentile_metrics: List[str], @@ -480,7 +507,13 @@ async def benchmark( if profile_output.success: print("Profiler started") + if burstiness == 1.0: + distribution = "Poisson process" + else: + distribution = "Gamma distribution" + print(f"Traffic request rate: {request_rate}") + print(f"Burstiness factor: {burstiness} ({distribution})") print(f"Maximum request concurrency: {max_concurrency}") pbar = None if disable_tqdm else tqdm(total=len(input_requests)) @@ -502,7 +535,7 @@ async def limited_request_func(request_func_input, pbar): benchmark_start_time = time.perf_counter() tasks: List[asyncio.Task] = [] - async for request in get_request(input_requests, request_rate): + async for request in get_request(input_requests, request_rate, burstiness): prompt, prompt_len, output_len, mm_content = request request_func_input = RequestFuncInput(model=model_id, prompt=prompt, @@ -769,6 +802,7 @@ def main(args: argparse.Namespace): logprobs=args.logprobs, best_of=args.best_of, request_rate=args.request_rate, + burstiness=args.burstiness, disable_tqdm=args.disable_tqdm, profile=args.profile, selected_percentile_metrics=args.percentile_metrics.split(","), @@ -807,6 +841,7 @@ def main(args: argparse.Namespace): # Traffic result_json["request_rate"] = ( args.request_rate if args.request_rate < float("inf") else "inf") + result_json["burstiness"] = args.burstiness result_json["max_concurrency"] = args.max_concurrency # Merge with benchmark result @@ -922,8 +957,20 @@ def main(args: argparse.Namespace): default=float("inf"), help="Number of requests per second. If this is inf, " "then all the requests are sent at time 0. " - "Otherwise, we use Poisson process to synthesize " - "the request arrival times.", + "Otherwise, we use Poisson process or gamma distribution " + "to synthesize the request arrival times.", + ) + parser.add_argument( + "--burstiness", + type=float, + default=1.0, + help="Burstiness factor of the request generation. 
" + "Only take effect when request_rate is not inf. " + "Default value is 1, which follows Poisson process. " + "Otherwise, the request intervals follow a gamma distribution. " + "A lower burstiness value (0 < burstiness < 1) results in more " + "bursty requests. A higher burstiness value (burstiness > 1) " + "results in a more uniform arrival of requests.", ) parser.add_argument("--seed", type=int, default=0) parser.add_argument( From ae62fd17c0023f7ec363c1141787b8c017937c44 Mon Sep 17 00:00:00 2001 From: Maximilien de Bayser Date: Thu, 7 Nov 2024 12:09:02 -0300 Subject: [PATCH 12/13] [Frontend] Tool calling parser for Granite 3.0 models (#9027) Signed-off-by: Max de Bayser --- .../serving/openai_compatible_server.md | 44 ++-- examples/tool_chat_template_granite.jinja | 40 ++++ tests/tool_use/conftest.py | 6 + tests/tool_use/utils.py | 37 +-- .../openai/tool_parsers/__init__.py | 5 +- .../tool_parsers/granite_tool_parser.py | 215 ++++++++++++++++++ 6 files changed, 314 insertions(+), 33 deletions(-) create mode 100644 examples/tool_chat_template_granite.jinja create mode 100644 vllm/entrypoints/openai/tool_parsers/granite_tool_parser.py diff --git a/docs/source/serving/openai_compatible_server.md b/docs/source/serving/openai_compatible_server.md index 0b5f75caf2475..a196f8b1e574e 100644 --- a/docs/source/serving/openai_compatible_server.md +++ b/docs/source/serving/openai_compatible_server.md @@ -160,14 +160,7 @@ this, unless explicitly specified. :func: create_parser_for_docs :prog: vllm serve ``` -## Tool Calling in the Chat Completion API -### Named Function Calling -vLLM supports only named function calling in the chat completion API by default. It does so using Outlines, so this is -enabled by default, and will work with any supported model. You are guaranteed a validly-parsable function call - not a -high-quality one. -To use a named function, you need to define the functions in the `tools` parameter of the chat completion request, and -specify the `name` of one of the tools in the `tool_choice` parameter of the chat completion request. ### Config file @@ -196,12 +189,22 @@ The order of priorities is `command line > config file values > defaults`. --- ## Tool calling in the chat completion API -vLLM supports only named function calling in the chat completion API. The `tool_choice` options `auto` and `required` are **not yet supported** but on the roadmap. + +vLLM supports named function calling and `auto` tool choice in the chat completion API. The `tool_choice` options `required` is **not yet supported** but on the roadmap. It is the callers responsibility to prompt the model with the tool information, vLLM will not automatically manipulate the prompt. + +### Named Function Calling +vLLM supports named function calling in the chat completion API by default. It does so using Outlines, so this is +enabled by default, and will work with any supported model. You are guaranteed a validly-parsable function call - not a +high-quality one. + vLLM will use guided decoding to ensure the response matches the tool parameter object defined by the JSON schema in the `tools` parameter. +To use a named function, you need to define the functions in the `tools` parameter of the chat completion request, and +specify the `name` of one of the tools in the `tool_choice` parameter of the chat completion request. + ### Automatic Function Calling To enable this feature, you should set the following flags: @@ -275,6 +278,21 @@ it works better with vLLM. 
Recommended flags: `--tool-call-parser llama3_json --chat-template examples/tool_chat_template_llama3_json.jinja` +#### IBM Granite + +Supported models: +* `ibm-granite/granite-3.0-8b-instruct` + +Recommended flags: `--tool-call-parser granite --chat-template examples/tool_chat_template_granite.jinja` + +`examples/tool_chat_template_granite.jinja`: this is a modified chat template from the original on Huggingface. Parallel function calls are supported. + +* `ibm-granite/granite-20b-functioncalling` + +Recommended flags: `--tool-call-parser granite-20b-fc --chat-template examples/tool_chat_template_granite_20b_fc.jinja` + +`examples/tool_chat_template_granite_20b_fc.jinja`: this is a modified chat template from the original on Huggingface, which is not vLLM compatible. It blends function description elements from the Hermes template and follows the same system prompt as "Response Generation" mode from [the paper](https://arxiv.org/abs/2407.00121). Parallel function calls are supported. + #### InternLM Models (`internlm`) @@ -297,16 +315,6 @@ AI21's Jamba-1.5 models are supported. Flags: `--tool-call-parser jamba` -#### IBM Granite (`granite-20b-fc`) - -Supported models: -* `ibm-granite/granite-20b-functioncalling` - -Flags: `--tool-call-parser granite-20b-fc --chat-template examples/tool_chat_template_granite_20b_fc.jinja` - -The example chat template deviates slightly from the original on Huggingface, which is not vLLM compatible. It blends function description elements from the Hermes template and follows the same system prompt as "Response Generation" mode from [the paper](https://arxiv.org/abs/2407.00121). Parallel function calls are supported. - - ### How to write a tool parser plugin A tool parser plugin is a Python file containing one or more ToolParser implementations. You can write a ToolParser similar to the `Hermes2ProToolParser` in vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py. 
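Sketching what such a plugin can look like (the registration name `example_json`, the class name, and the assumed output format `[{"name": ..., "arguments": {...}}]` are illustrative only; the imports, the `register_module` decorator, and the method signatures mirror the `GraniteToolParser` added later in this patch series):

```python
import json
from typing import Sequence, Union

from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                              DeltaMessage,
                                              ExtractedToolCallInformation,
                                              FunctionCall, ToolCall)
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
    ToolParser, ToolParserManager)
from vllm.transformers_utils.tokenizer import AnyTokenizer


@ToolParserManager.register_module("example_json")
class ExampleJsonToolParser(ToolParser):
    """Parses model output shaped like [{"name": ..., "arguments": {...}}]."""

    def __init__(self, tokenizer: AnyTokenizer):
        super().__init__(tokenizer)

    def extract_tool_calls(
            self, model_output: str,
            request: ChatCompletionRequest) -> ExtractedToolCallInformation:
        stripped = model_output.strip()
        if not stripped.startswith("["):
            # Not a tool call: hand the text back as ordinary content.
            return ExtractedToolCallInformation(tools_called=False,
                                                tool_calls=[],
                                                content=model_output)
        try:
            raw_calls = json.loads(stripped)
        except json.JSONDecodeError:
            return ExtractedToolCallInformation(tools_called=False,
                                                tool_calls=[],
                                                content=model_output)
        tool_calls = [
            ToolCall(type="function",
                     function=FunctionCall(
                         name=call["name"],
                         # function call args are JSON, passed as a string
                         arguments=json.dumps(call["arguments"])))
            for call in raw_calls
        ]
        return ExtractedToolCallInformation(tools_called=True,
                                            tool_calls=tool_calls,
                                            content=None)

    def extract_tool_calls_streaming(
            self, previous_text: str, current_text: str, delta_text: str,
            previous_token_ids: Sequence[int],
            current_token_ids: Sequence[int],
            delta_token_ids: Sequence[int],
            request: ChatCompletionRequest) -> Union[DeltaMessage, None]:
        # Streaming is omitted in this sketch: the delta is passed through as
        # plain content. A production parser reconstructs partial JSON here.
        return DeltaMessage(content=delta_text)
```

A production-quality parser would also stream tool-call deltas incrementally; the `GraniteToolParser` introduced in this series shows one way to do that with `partial_json_parser`.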
diff --git a/examples/tool_chat_template_granite.jinja b/examples/tool_chat_template_granite.jinja new file mode 100644 index 0000000000000..2cc19e77188dc --- /dev/null +++ b/examples/tool_chat_template_granite.jinja @@ -0,0 +1,40 @@ +{%- if tools %} + {{- '<|start_of_role|>available_tools<|end_of_role|> +' }} + {%- for tool in tools %} + {{- tool | tojson(indent=4) }} + {%- if not loop.last %} + {{- ' + +' }} + {%- endif %} + {%- endfor %} + {{- '<|end_of_text|> +' }} +{%- endif %} + +{%- for message in messages %} + {%- if message['role'] == 'system' %} + {{- '<|start_of_role|>system<|end_of_role|>' + message['content'] + '<|end_of_text|> +' }} + {%- elif message['role'] == 'user' %} + {{- '<|start_of_role|>user<|end_of_role|>' + message['content'] + '<|end_of_text|> +' }} + {%- elif message['role'] == 'assistant_tool_call' or (message['role'] == 'assistant' and message.tool_calls is defined) %} + {{- '<|start_of_role|>assistant<|end_of_role|>' }} + {% for tc in message.tool_calls %} + {{- '<|tool_call|> ' + {'name': tc.function.name, 'arguments': tc.function.arguments}|tojson }} + {% endfor %} + {{- '<|end_of_text|> +' }} + {%- elif message['role'] == 'assistant' %} + {{- '<|start_of_role|>assistant<|end_of_role|>' + message['content'] + '<|end_of_text|> +' }} + {%- elif message['role'] == 'tool_response' or message['role'] == 'tool' %} + {{- '<|start_of_role|>tool_response<|end_of_role|>' + message['content'] + '<|end_of_text|> +' }} + {%- endif %} + {%- if loop.last and add_generation_prompt %} + {{- '<|start_of_role|>assistant<|end_of_role|>' }} + {%- endif %} +{%- endfor %} diff --git a/tests/tool_use/conftest.py b/tests/tool_use/conftest.py index ab6a29eba1b3f..294acf202a232 100644 --- a/tests/tool_use/conftest.py +++ b/tests/tool_use/conftest.py @@ -3,6 +3,7 @@ from huggingface_hub import snapshot_download from tests.utils import RemoteOpenAIServer +from vllm.platforms import current_platform from .utils import ARGS, CONFIGS, ServerConfig @@ -11,6 +12,11 @@ @pytest.fixture(scope="session", params=CONFIGS.keys()) def server_config(request): config = CONFIGS[request.param] + + if current_platform.is_rocm() and not config.get("supports_rocm", True): + pytest.skip("The {} model can't be tested on the ROCm platform".format( + config["model"])) + # download model and tokenizer using transformers snapshot_download(config["model"]) yield CONFIGS[request.param] diff --git a/tests/tool_use/utils.py b/tests/tool_use/utils.py index d9ee0b1d54b0a..576555b368afe 100644 --- a/tests/tool_use/utils.py +++ b/tests/tool_use/utils.py @@ -13,6 +13,7 @@ class ServerConfig(TypedDict, total=False): arguments: List[str] system_prompt: Optional[str] supports_parallel: Optional[bool] + supports_rocm: Optional[bool] def patch_system_prompt(messages: List[Dict[str, Any]], @@ -36,7 +37,7 @@ def ensure_system_prompt(messages: List[Dict[str, Any]], # universal args for all models go here. also good if you need to test locally # and change type or KV cache quantization or something. -ARGS: List[str] = ["--enable-auto-tool-choice", "--max-model-len", "8096"] +ARGS: List[str] = ["--enable-auto-tool-choice", "--max-model-len", "1024"] CONFIGS: Dict[str, ServerConfig] = { "hermes": { @@ -88,18 +89,28 @@ def ensure_system_prompt(messages: List[Dict[str, Any]], "without calling a tool. DO NOT CALL A TOOL THAT IS IRRELEVANT " "to the user's question - just respond to it normally." 
}, - ## FIXME: temporary disabled due to lack of hardware specification - ## for individual runs - #"granite20b": { - # "model": - # "ibm-granite/granite-20b-functioncalling", - # "arguments": [ - # "--tool-call-parser", "granite-20b-fc", "--chat-template", - # str(VLLM_PATH / "examples/tool_chat_template_granite_20b_fc.jinja") - # ], - # "supports_parallel": - # False, - #}, + "granite20b": { + "model": + "mbayser/granite-20b-functioncalling-FP8-KV", + "arguments": [ + "--tool-call-parser", "granite-20b-fc", "--chat-template", + str(VLLM_PATH / + "examples/tool_chat_template_granite_20b_fc.jinja"), + "--max_num_seqs", "1", "--enforce-eager", "--cpu-offload-gb", "20" + ], + "supports_parallel": + False, + "supports_rocm": + False, + }, + "granite8b": { + "model": + "ibm-granite/granite-3.0-8b-instruct", + "arguments": [ + "--tool-call-parser", "granite", "--chat-template", + str(VLLM_PATH / "examples/tool_chat_template_granite.jinja") + ], + }, "internlm": { "model": "internlm/internlm2_5-7b-chat", diff --git a/vllm/entrypoints/openai/tool_parsers/__init__.py b/vllm/entrypoints/openai/tool_parsers/__init__.py index 1b299ce655570..2187862e8380b 100644 --- a/vllm/entrypoints/openai/tool_parsers/__init__.py +++ b/vllm/entrypoints/openai/tool_parsers/__init__.py @@ -1,5 +1,6 @@ from .abstract_tool_parser import ToolParser, ToolParserManager from .granite_20b_fc_tool_parser import Granite20bFCToolParser +from .granite_tool_parser import GraniteToolParser from .hermes_tool_parser import Hermes2ProToolParser from .internlm2_tool_parser import Internlm2ToolParser from .jamba_tool_parser import JambaToolParser @@ -8,6 +9,6 @@ __all__ = [ "ToolParser", "ToolParserManager", "Granite20bFCToolParser", - "Hermes2ProToolParser", "MistralToolParser", "Internlm2ToolParser", - "Llama3JsonToolParser", "JambaToolParser" + "GraniteToolParser", "Hermes2ProToolParser", "MistralToolParser", + "Internlm2ToolParser", "Llama3JsonToolParser", "JambaToolParser" ] diff --git a/vllm/entrypoints/openai/tool_parsers/granite_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/granite_tool_parser.py new file mode 100644 index 0000000000000..b5854ca39ab47 --- /dev/null +++ b/vllm/entrypoints/openai/tool_parsers/granite_tool_parser.py @@ -0,0 +1,215 @@ +import json +from typing import Dict, Sequence, Union + +import partial_json_parser +from partial_json_parser.core.options import Allow + +from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, + DeltaFunctionCall, DeltaMessage, + DeltaToolCall, + ExtractedToolCallInformation, + FunctionCall, ToolCall) +from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( + ToolParser, ToolParserManager) +from vllm.entrypoints.openai.tool_parsers.utils import (consume_space, + find_common_prefix, + is_complete_json, + partial_json_loads) +from vllm.logger import init_logger +from vllm.transformers_utils.tokenizer import AnyTokenizer +from vllm.utils import random_uuid + +logger = init_logger(__name__) + + +@ToolParserManager.register_module("granite") +class GraniteToolParser(ToolParser): + """ + Tool call parser for the granite 3.0 models. Intended + for use with the examples/tool_chat_template_granite.jinja + template. 
+ + Used when --enable-auto-tool-choice --tool-call-parser granite + are all set + """ + + def __init__(self, tokenizer: AnyTokenizer): + super().__init__(tokenizer) + + def extract_tool_calls( + self, model_output: str, + request: ChatCompletionRequest) -> ExtractedToolCallInformation: + stripped = model_output.strip() + if not stripped or stripped[0] != '[': + return ExtractedToolCallInformation(tools_called=False, + tool_calls=[], + content=model_output) + try: + raw_function_calls = json.loads(stripped) + if not isinstance(raw_function_calls, list): + raise Exception( + f"Expected dict or list, got {type(raw_function_calls)}") + + logger.debug("Extracted %d tool calls", len(raw_function_calls)) + tool_calls = [ + ToolCall( + type="function", + function=FunctionCall( + name=function_call["name"], + # function call args are JSON but as a string + arguments=json.dumps(function_call["arguments"]), + ), + ) for function_call in raw_function_calls + ] + + return ExtractedToolCallInformation( + tools_called=True, + tool_calls=tool_calls, + content=None, + ) + + except Exception as e: + logger.error("Error in extracting tool call from response %s", e) + return ExtractedToolCallInformation(tools_called=False, + tool_calls=[], + content=model_output) + + def extract_tool_calls_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: Sequence[int], + current_token_ids: Sequence[int], + delta_token_ids: Sequence[int], + request: ChatCompletionRequest, + ) -> Union[DeltaMessage, None]: + + start_idx = consume_space(0, current_text) + if not current_text or current_text[start_idx] != '[': + return DeltaMessage(content=delta_text) + + # bit mask flags for partial JSON parsing. If the name hasn't been + # sent yet, don't allow sending + # an incomplete string since OpenAI only ever (as far as I have + # seen) allows sending the entire tool/ function name at once. + flags = Allow.ALL if self.current_tool_name_sent \ + else Allow.ALL & ~Allow.STR + try: + tool_call_arr = None + is_complete = None + try: + tool_calls, end_idx = partial_json_loads( + current_text[start_idx:], flags) + if type(tool_calls) is list: + tool_call_arr = tool_calls + else: + return DeltaMessage(content=delta_text) + + is_complete = [True] * len(tool_calls) + if not is_complete_json( + current_text[start_idx:start_idx + end_idx]): + is_complete[-1] = False + except partial_json_parser.core.exceptions.MalformedJSON: + logger.debug('not enough tokens to parse into JSON yet') + return None + + # case -- if no tokens have been streamed for the tool, e.g. + # only the array brackets, stream nothing + if not tool_call_arr: + return None + + # select as the current tool call the one we're on the state at + current_tool_call: Dict = tool_call_arr[self.current_tool_id] + + delta = None + # case: we are starting a new tool in the array + # -> array has > 0 length AND length has moved past cursor + if len(tool_call_arr) > self.current_tool_id + 1: + + # if we're moving on to a new call, first make sure we + # haven't missed anything in the previous one that was + # auto-generated due to JSON completions, but wasn't + # streamed to the client yet. 
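                # (The flush below emits whatever argument text from the
                # previous call has not been streamed yet; per-tool state is
                # then reset for the new call.)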
+ if self.current_tool_id >= 0: + cur_arguments = current_tool_call.get("arguments") + if cur_arguments: + cur_args_json = json.dumps(cur_arguments) + sent = len( + self.streamed_args_for_tool[self.current_tool_id]) + argument_diff = cur_args_json[sent:] + + logger.debug("got arguments diff: %s", argument_diff) + delta = DeltaMessage(tool_calls=[ + DeltaToolCall(index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=argument_diff). + model_dump(exclude_none=True)) + ]) + self.streamed_args_for_tool[ + self.current_tool_id] += argument_diff + + # re-set stuff pertaining to progress in the current tool + self.current_tool_id = len(tool_call_arr) - 1 + self.current_tool_name_sent = False + self.streamed_args_for_tool.append("") + logger.debug("starting on new tool %d", self.current_tool_id) + return delta + + # if the current tool name hasn't been sent, send if available + # - otherwise send nothing + elif not self.current_tool_name_sent: + function_name = current_tool_call.get("name") + if function_name: + + delta = DeltaMessage(tool_calls=[ + DeltaToolCall(index=self.current_tool_id, + type="function", + id=f"chatcmpl-tool-{random_uuid()}", + function=DeltaFunctionCall( + name=function_name).model_dump( + exclude_none=True)) + ]) + self.current_tool_name_sent = True + + # now we know we're on the same tool call and we're streaming + # arguments + else: + cur_arguments = current_tool_call.get("arguments") + + if cur_arguments: + sent = len( + self.streamed_args_for_tool[self.current_tool_id]) + cur_args_json = json.dumps(cur_arguments) + prev_arguments = self.prev_tool_call_arr[ + self.current_tool_id].get("arguments") + + argument_diff = None + if is_complete[self.current_tool_id]: + argument_diff = cur_args_json[sent:] + elif prev_arguments: + prev_args_json = json.dumps(prev_arguments) + if cur_args_json != prev_args_json: + prefix = find_common_prefix( + prev_args_json, cur_args_json) + argument_diff = prefix[sent:] + + if argument_diff is not None: + delta = DeltaMessage(tool_calls=[ + DeltaToolCall(index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=argument_diff). + model_dump(exclude_none=True)) + ]) + self.streamed_args_for_tool[ + self.current_tool_id] += argument_diff + + self.prev_tool_call_arr = tool_call_arr + return delta + + except Exception as e: + logger.error("Error trying to handle streaming tool call: %s", e) + logger.debug( + "Skipping chunk as a result of tool streaming extraction " + "error") + return None From 2f0ed5e15a796e77900bc1f406f168f63f0e3af0 Mon Sep 17 00:00:00 2001 From: Michael Goin Date: Thu, 7 Nov 2024 10:17:52 -0500 Subject: [PATCH 13/13] Update Dockerfile --- Dockerfile | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/Dockerfile b/Dockerfile index 922183e21444d..56e213ad87ab2 100644 --- a/Dockerfile +++ b/Dockerfile @@ -48,8 +48,6 @@ COPY requirements-common.txt requirements-common.txt COPY requirements-cuda.txt requirements-cuda.txt RUN --mount=type=cache,target=/root/.cache/pip \ python3 -m pip install -r requirements-cuda.txt -RUN --mount=type=cache,target=/root/.cache/pip \ - python3 -m pip install hf_transfer # cuda arch list used by torch @@ -193,6 +191,11 @@ ADD . 
/vllm-workspace/ RUN --mount=type=cache,target=/root/.cache/pip \ python3 -m pip install -r requirements-dev.txt +# enable fast downloads from hf (for testing) +RUN --mount=type=cache,target=/root/.cache/pip \ + python3 -m pip install hf_transfer +ENV HF_HUB_ENABLE_HF_TRANSFER 1 + # doc requires source code # we hide them inside `test_docs/` , so that this source code # will not be imported by other tests
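For context on the two Dockerfile changes above, a minimal sketch of how the flag is consumed at test time (the model id is only an example; `hf_transfer` must be importable for the accelerated path to kick in):

```python
import os

# Must be set before huggingface_hub is imported; the test image above bakes
# it in via ENV HF_HUB_ENABLE_HF_TRANSFER.
os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1")

from huggingface_hub import snapshot_download

# With hf_transfer installed, huggingface_hub switches to the Rust-based
# downloader, which speeds up pulling test models in CI.
snapshot_download("facebook/opt-125m")
```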