diff --git a/docs/requirements-docs.txt b/docs/requirements-docs.txt index 64cf6ef8fc19d..8217bc3ba3ded 100644 --- a/docs/requirements-docs.txt +++ b/docs/requirements-docs.txt @@ -3,6 +3,7 @@ sphinx-book-theme==1.0.1 sphinx-copybutton==0.5.2 myst-parser==3.0.1 sphinx-argparse==0.4.0 +sphinx-design==0.6.1 sphinx-togglebutton==0.3.2 msgspec cloudpickle diff --git a/docs/source/api/multimodal/index.md b/docs/source/api/multimodal/index.md index 51e24795a34cf..14efdb506d76f 100644 --- a/docs/source/api/multimodal/index.md +++ b/docs/source/api/multimodal/index.md @@ -7,7 +7,7 @@ vLLM provides experimental support for multi-modal models through the {mod}`vllm Multi-modal inputs can be passed alongside text and token prompts to [supported models](#supported-mm-models) via the `multi_modal_data` field in {class}`vllm.inputs.PromptType`. -Looking to add your own multi-modal model? Please follow the instructions listed [here](#enabling-multimodal-inputs). +Looking to add your own multi-modal model? Please follow the instructions listed [here](#supports-multimodal). ## Module Contents diff --git a/docs/source/api/multimodal/inputs.md b/docs/source/api/multimodal/inputs.md index 3d89666113229..76b2fb95a5009 100644 --- a/docs/source/api/multimodal/inputs.md +++ b/docs/source/api/multimodal/inputs.md @@ -3,7 +3,7 @@ ## User-facing inputs ```{eval-rst} -.. autodata:: vllm.multimodal.MultiModalDataDict +.. autodata:: vllm.multimodal.inputs.MultiModalDataDict ``` ## Internal data structures diff --git a/docs/source/conf.py b/docs/source/conf.py index 1ce11fe057071..bff0141ffbce8 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -43,6 +43,7 @@ "sphinx.ext.autosummary", "myst_parser", "sphinxarg.ext", + "sphinx_design", "sphinx_togglebutton", ] myst_enable_extensions = [ diff --git a/docs/source/contributing/model/index.md b/docs/source/contributing/model/index.md index a2d601c83cf47..245e13b795ec4 100644 --- a/docs/source/contributing/model/index.md +++ b/docs/source/contributing/model/index.md @@ -2,7 +2,7 @@ # Adding a New Model -This section provides more information on how to integrate a [HuggingFace Transformers](https://github.com/huggingface/transformers) model into vLLM. +This section provides more information on how to integrate a [PyTorch](https://pytorch.org/) model into vLLM. ```{toctree} :caption: Contents diff --git a/docs/source/contributing/model/multimodal.md b/docs/source/contributing/model/multimodal.md index e5dcd1223b361..76ab73e43d24b 100644 --- a/docs/source/contributing/model/multimodal.md +++ b/docs/source/contributing/model/multimodal.md @@ -1,6 +1,6 @@ -(enabling-multimodal-inputs)= +(supports-multimodal)= -# Enabling Multimodal Inputs +# Multi-Modal Support This document walks you through the steps to extend a basic model so that it accepts [multi-modal inputs](#multimodal-inputs). @@ -37,103 +37,355 @@ Further update the model as follows: ) -> SamplerOutput: ``` -## 2. Register input mappers +## 2. Specify processing information -For each modality type that the model accepts as input, decorate the model class with {meth}`MULTIMODAL_REGISTRY.register_input_mapper `. -This decorator accepts a function that maps multi-modal inputs to the keyword arguments you have previously defined in {meth}`~torch.nn.Module.forward`. +Next, create a subclass of {class}`~vllm.multimodal.processing.BaseProcessingInfo` +to provide basic information related to HF processing. 
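+
+In outline, the subclass only needs to override two abstract methods, which the following subsections walk through in detail. Here is a minimal sketch (`YourProcessingInfo` is a placeholder name, and the method bodies are worked out below):
+
+```python
+from collections.abc import Mapping
+from typing import Optional
+
+from vllm.multimodal.processing import BaseProcessingInfo
+
+
+class YourProcessingInfo(BaseProcessingInfo):
+
+    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
+        # How many input items of each modality a single prompt may contain.
+        return {"image": None}
+
+    def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
+        # Worst-case number of placeholder feature tokens per input item;
+        # see the following subsections for how to compute this.
+        raise NotImplementedError
+```
+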
-```diff - from vllm.model_executor.models.interfaces import SupportsMultiModal -+ from vllm.multimodal import MULTIMODAL_REGISTRY +### Maximum number of input items -+ @MULTIMODAL_REGISTRY.register_image_input_mapper() - class YourModelForImage2Seq(nn.Module, SupportsMultiModal): +You need to override the abstract method {meth}`~vllm.multimodal.processing.BaseProcessingInfo.get_supported_mm_limits` +to return the maximum number of input items for each modality supported by the model. + +For example, if the model supports any number of images but only one video per prompt: + +```python +def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + return {"image": None, "video": 1} ``` -A default mapper is available for each modality in the core vLLM library. This input mapper will be used if you do not provide your own function. +### Maximum number of placeholder feature tokens + +Also, override the abstract method {meth}`~vllm.multimodal.processing.BaseProcessingInfo.get_mm_max_tokens_per_item` +to return the maximum number of placeholder feature tokens per input item for each modality. + +When calling the model, the output embeddings from the visual encoder are assigned to the input positions +containing placeholder feature tokens. Therefore, the number of placeholder feature tokens should be equal +to the size of the output embeddings. + +::::{tab-set} +:::{tab-item} Basic example: LLaVA +:sync: llava + +Looking at the code of HF's `LlavaForConditionalGeneration`: + +```python +# https://github.com/huggingface/transformers/blob/v4.47.1/src/transformers/models/llava/modeling_llava.py#L530-L544 +n_image_tokens = (input_ids == self.config.image_token_index).sum().item() +n_image_features = image_features.shape[0] * image_features.shape[1] + +if n_image_tokens != n_image_features: + raise ValueError( + f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" + ) +special_image_mask = ( + (input_ids == self.config.image_token_index) + .unsqueeze(-1) + .expand_as(inputs_embeds) + .to(inputs_embeds.device) +) +image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) +inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) +``` -```{seealso} -[Input Processing Pipeline](#input-processing-pipeline) +The number of placeholder feature tokens per image is `image_features.shape[1]`. +`image_features` is calculated inside the `get_image_features` method: + +```python +# https://github.com/huggingface/transformers/blob/v4.47.1/src/transformers/models/llava/modeling_llava.py#L290-L300 +image_outputs = self.vision_tower(pixel_values, output_hidden_states=True) + +selected_image_feature = image_outputs.hidden_states[vision_feature_layer] +if vision_feature_select_strategy == "default": + selected_image_feature = selected_image_feature[:, 1:] +elif vision_feature_select_strategy == "full": + selected_image_feature = selected_image_feature +else: + raise ValueError(f"Unexpected select feature strategy: {self.config.vision_feature_select_strategy}") +image_features = self.multi_modal_projector(selected_image_feature) +return image_features ``` -## 3. Register maximum number of multi-modal tokens +We can infer that `image_features.shape[1]` is based on `image_outputs.hidden_states.shape[1]` from the vision tower +(`CLIPVisionModel` for the [`llava-hf/llava-1.5-7b-hf`](https://huggingface.co/llava-hf/llava-1.5-7b-hf) model). 
+Moreover, we only need the sequence length (the second dimension of the tensor) to get `image_features.shape[1]`. +The sequence length is determined by the initial hidden states in `CLIPVisionTransformer` since the attention +mechanism doesn't change the sequence length of the output hidden states. + +```python +# https://github.com/huggingface/transformers/blob/v4.47.1/src/transformers/models/clip/modeling_clip.py#L1094-L1102 +hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding) +hidden_states = self.pre_layrnorm(hidden_states) + +encoder_outputs = self.encoder( + inputs_embeds=hidden_states, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, +) +``` -For each modality type that the model accepts as input, calculate the maximum possible number of tokens per data item -and register it via {meth}`INPUT_REGISTRY.register_dummy_data `. +To find the sequence length, we turn to the code of `CLIPVisionEmbeddings`: + +```python +# https://github.com/huggingface/transformers/blob/v4.47.1/src/transformers/models/clip/modeling_clip.py#L247-L257 +target_dtype = self.patch_embedding.weight.dtype +patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype)) # shape = [*, width, grid, grid] +patch_embeds = patch_embeds.flatten(2).transpose(1, 2) + +class_embeds = self.class_embedding.expand(batch_size, 1, -1) +embeddings = torch.cat([class_embeds, patch_embeds], dim=1) +if interpolate_pos_encoding: + embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width) +else: + embeddings = embeddings + self.position_embedding(self.position_ids) +return embeddings +``` -```diff - from vllm.inputs import INPUT_REGISTRY - from vllm.model_executor.models.interfaces import SupportsMultiModal - from vllm.multimodal import MULTIMODAL_REGISTRY +We can infer that `embeddings.shape[1] == self.num_positions`, where - @MULTIMODAL_REGISTRY.register_image_input_mapper() -+ @MULTIMODAL_REGISTRY.register_max_image_tokens() - @INPUT_REGISTRY.register_dummy_data() - class YourModelForImage2Seq(nn.Module, SupportsMultiModal): +```python +# https://github.com/huggingface/transformers/blob/v4.47.1/src/transformers/models/clip/modeling_clip.py#L195-L196 +self.num_patches = (self.image_size // self.patch_size) ** 2 +self.num_positions = self.num_patches + 1 ``` -Here are some examples: +Overall, the number of placeholder feature tokens for an image can be calculated as: -- Image inputs (static feature size): [LLaVA-1.5 Model](gh-file:vllm/model_executor/models/llava.py) -- Image inputs (dynamic feature size): [LLaVA-NeXT Model](gh-file:vllm/model_executor/models/llava_next.py) +```python +def get_num_image_tokens( + self, + *, + image_width: int, + image_height: int, +) -> int: + hf_config = self.get_hf_config() + hf_processor = self.get_hf_processor() -```{seealso} -[Input Processing Pipeline](#input-processing-pipeline) + image_size = hf_config.vision_config.image_size + patch_size = hf_config.vision_config.patch_size + + num_image_tokens = (image_size // patch_size) ** 2 + 1 + if hf_processor.vision_feature_select_strategy == "default": + num_image_tokens -= 1 + + return num_image_tokens ``` -## 4. (Optional) Register dummy data +Notice that the number of image tokens doesn't depend on the image width and height. +So, we can calculate the maximum number of image tokens using any image size: -During startup, dummy data is passed to the vLLM model to allocate memory. 
This only consists of text input by default, which may not be applicable to multi-modal models. -In such cases, you can define your own dummy data by registering a factory method via {meth}`INPUT_REGISTRY.register_dummy_data `. +```python +def get_image_size_with_most_features(self) -> ImageSize: + hf_config = self.get_hf_config() + width = height = hf_config.image_size + return ImageSize(width=width, height=height) -```diff - from vllm.inputs import INPUT_REGISTRY - from vllm.model_executor.models.interfaces import SupportsMultiModal - from vllm.multimodal import MULTIMODAL_REGISTRY +def get_max_image_tokens(self) -> int: + target_width, target_height = self.get_image_size_with_most_features() - @MULTIMODAL_REGISTRY.register_image_input_mapper() - @MULTIMODAL_REGISTRY.register_max_image_tokens() -+ @INPUT_REGISTRY.register_dummy_data() - class YourModelForImage2Seq(nn.Module, SupportsMultiModal): + return self.get_num_image_tokens( + image_width=target_width, + image_height=target_height, + ) +``` + +And thus, we can override the method as: + +```python +def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]: + return {"image": self.get_max_image_tokens()} ``` ```{note} -The dummy data should have the maximum possible number of multi-modal tokens, as described in the previous step. +Our [actual code](gh-file:vllm/model_executor/models/llava.py) is more abstracted to support vision encoders other than CLIP. ``` +::: +:::: + +## 3. Specify dummy inputs + +Then, inherit {class}`~vllm.multimodal.profiling.BaseDummyInputsBuilder` to construct dummy inputs for +HF processing as well as memory profiling. + +### For memory profiling + +Override the abstract method {meth}`~vllm.multimodal.profiling.BaseDummyInputsBuilder.get_dummy_processor_inputs` +to construct dummy inputs for memory profiling. This dummy input should result in the worst-case memory usage of +the model so that vLLM can reserve the correct amount of memory for it. + +Assuming that the memory usage increases with the number of tokens, the dummy input can be constructed based +on the code for {meth}`~vllm.multimodal.processing.BaseProcessingInfo.get_mm_max_tokens_per_item`. + +::::{tab-set} +:::{tab-item} Basic example: LLaVA +:sync: llava +Making use of the `get_image_size_with_most_features` method implemented in the previous section: + +```python +def get_dummy_processor_inputs( + self, + seq_len: int, + mm_counts: Mapping[str, int], +) -> ProcessorInputs: + num_images = mm_counts.get("image", 0) + + processor = self.info.get_hf_processor() + image_token = processor.image_token + + hf_config = self.get_hf_config() + target_width, target_height = self.info.get_image_size_with_most_features() + + mm_data = { + "image": + self._get_dummy_images(width=target_width, + height=target_height, + num_images=num_images) + } + + return ProcessorInputs( + prompt_text=image_token * num_images, + mm_data=mm_data, + ) +``` +::: +:::: -Here are some examples: +## 4. Specify processing details -- Image inputs (static feature size): [LLaVA-1.5 Model](gh-file:vllm/model_executor/models/llava.py) -- Image inputs (dynamic feature size): [LLaVA-NeXT Model](gh-file:vllm/model_executor/models/llava_next.py) +Afterwards, create a subclass of {class}`~vllm.multimodal.processing.BaseMultiModalProcessor` +to fill in the missing details about HF processing. ```{seealso} -[Input Processing Pipeline](#input-processing-pipeline) +[Multi-Modal Data Processing](#mm-processing) ``` -## 5. 
(Optional) Register input processor
+### Multi-modal fields
+
+Override {class}`~vllm.multimodal.processing.BaseMultiModalProcessor._get_mm_fields_config` to
+return a schema of the tensors outputted by the HF processor that are related to the input multi-modal items.
+
+::::{tab-set}
+:::{tab-item} Basic example: LLaVA
+:sync: llava
+
+Looking at the model's `forward` method:
+
+```python
+# https://github.com/huggingface/transformers/blob/v4.47.1/src/transformers/models/llava/modeling_llava.py#L387-L404
+def forward(
+    self,
+    input_ids: torch.LongTensor = None,
+    pixel_values: torch.FloatTensor = None,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_values: Optional[List[torch.FloatTensor]] = None,
+    inputs_embeds: Optional[torch.FloatTensor] = None,
+    vision_feature_layer: Optional[int] = None,
+    vision_feature_select_strategy: Optional[str] = None,
+    labels: Optional[torch.LongTensor] = None,
+    use_cache: Optional[bool] = None,
+    output_attentions: Optional[bool] = None,
+    output_hidden_states: Optional[bool] = None,
+    return_dict: Optional[bool] = None,
+    cache_position: Optional[torch.LongTensor] = None,
+    num_logits_to_keep: int = 0,
+) -> Union[Tuple, LlavaCausalLMOutputWithPast]:
+```
-Sometimes, there is a need to process inputs at the {class}`~vllm.LLMEngine` level before they are passed to the model executor.
-This is often due to the fact that unlike implementations in HuggingFace Transformers, the reshaping and/or expansion of multi-modal embeddings needs to take place outside model's {meth}`~torch.nn.Module.forward` call.
-You can register input processors via {meth}`INPUT_REGISTRY.register_input_processor `.
+The only related keyword argument is `pixel_values`, which directly corresponds to the input images.
+The shape of `pixel_values` is `(N, C, H, W)` where `N` is the number of images.
+So, we override the method as follows:
+
+```python
+def _get_mm_fields_config(
+    self,
+    hf_inputs: BatchFeature,
+    hf_processor_mm_kwargs: Mapping[str, object],
+) -> Mapping[str, MultiModalFieldConfig]:
+    return dict(
+        pixel_values=MultiModalFieldConfig.batched("image"),
+    )
+```
-```diff
- from vllm.inputs import INPUT_REGISTRY
- from vllm.model_executor.models.interfaces import SupportsMultiModal
- from vllm.multimodal import MULTIMODAL_REGISTRY
+
+```{note}
+Our [actual code](gh-file:vllm/model_executor/models/llava.py) additionally supports
+pre-computed image embeddings, which can be passed to the model via the `image_embeds` argument.
+```
+:::
+::::
- @MULTIMODAL_REGISTRY.register_image_input_mapper()
- @MULTIMODAL_REGISTRY.register_max_image_tokens()
- @INPUT_REGISTRY.register_dummy_data()
-+ @INPUT_REGISTRY.register_input_processor()
- class YourModelForImage2Seq(nn.Module, SupportsMultiModal):
+### Prompt replacements
+
+Override {class}`~vllm.multimodal.processing.BaseMultiModalProcessor._get_prompt_replacements` to
+return a list of {class}`~vllm.multimodal.processing.PromptReplacement` instances.
+
+Each {class}`~vllm.multimodal.processing.PromptReplacement` instance specifies a find-and-replace
+operation performed by the HF processor.
+ +::::{tab-set} +:::{tab-item} Basic example: LLaVA +:sync: llava + +Looking at HF's `LlavaProcessor`: + +```python +# https://github.com/huggingface/transformers/blob/v4.47.1/src/transformers/models/llava/processing_llava.py#L167-L170 +prompt_strings = [] +for sample in text: + sample = sample.replace(self.image_token, self.image_token * num_image_tokens) + prompt_strings.append(sample) ``` -A common use case of input processors is inserting placeholder tokens to leverage the vLLM framework for attention mask generation. -Here are some examples: +It simply repeats each input `image_token` a number of times equal to the number of placeholder feature tokens (`num_image_tokens`). +Based on this, we override the method as follows: + +```python +def _get_prompt_replacements( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + out_mm_kwargs: MultiModalKwargs, +) -> list[PromptReplacement]: + hf_config = self.info.get_hf_config() + image_token_id = hf_config.image_token_index + + def get_replacement(item_idx: int): + images = mm_items.get_items("image", ImageProcessorItems) + + image_size = images.get_image_size(item_idx) + num_image_tokens = self.info.get_num_image_tokens( + image_width=image_size.width, + image_height=image_size.height, + ) + + return [image_token_id] * num_image_tokens + + return [ + PromptReplacement( + modality="image", + target=[image_token_id], + replacement=get_replacement, + ), + ] +``` +::: +:::: -- Insert static number of image tokens: [LLaVA-1.5 Model](gh-file:vllm/model_executor/models/llava.py) -- Insert dynamic number of image tokens: [LLaVA-NeXT Model](gh-file:vllm/model_executor/models/llava_next.py) +## 5. Register processor-related classes -```{seealso} -[Input Processing Pipeline](#input-processing-pipeline) +After you have defined {class}`~vllm.multimodal.processing.BaseProcessingInfo` (Step 2), +{class}`~vllm.multimodal.profiling.BaseDummyInputsBuilder` (Step 3), +and {class}`~vllm.multimodal.processing.BaseMultiModalProcessor` (Step 4), +decorate the model class with {meth}`MULTIMODAL_REGISTRY.register_processor ` +to register them to the multi-modal registry: + +```diff + from vllm.model_executor.models.interfaces import SupportsMultiModal ++ from vllm.multimodal import MULTIMODAL_REGISTRY + ++ @MULTIMODAL_REGISTRY.register_processor(YourMultiModalProcessor, ++ info=YourProcessingInfo, ++ dummy_inputs=YourDummyInputsBuilder) + class YourModelForImage2Seq(nn.Module, SupportsMultiModal): ``` diff --git a/docs/source/contributing/model/registration.md b/docs/source/contributing/model/registration.md index fe5aa94c52896..6a9262669cd29 100644 --- a/docs/source/contributing/model/registration.md +++ b/docs/source/contributing/model/registration.md @@ -48,7 +48,7 @@ ModelRegistry.register_model("YourModelForCausalLM", "your_code:YourModelForCaus ```{important} If your model is a multimodal model, ensure the model class implements the {class}`~vllm.model_executor.models.interfaces.SupportsMultiModal` interface. -Read more about that [here](#enabling-multimodal-inputs). +Read more about that [here](#supports-multimodal). ``` ```{note} diff --git a/docs/source/design/input_processing/input_processing_pipeline.md b/docs/source/design/input_processing/input_processing_pipeline.md deleted file mode 100644 index bb16920e3d0c0..0000000000000 --- a/docs/source/design/input_processing/input_processing_pipeline.md +++ /dev/null @@ -1,19 +0,0 @@ -(input-processing-pipeline)= - -# Input Processing Pipeline - -1. 
Input data is passed to {class}`~vllm.LLMEngine` (or {class}`~vllm.AsyncLLMEngine`).
-
-2. Tokenize the data if necessary.
-
-3. Process the inputs using {meth}`INPUT_REGISTRY.process_input `.
-
-   - For example, add placeholder tokens to reserve KV cache for multi-modal embeddings.
-
-4. Send the processed inputs to {class}`~vllm.executor.executor_base.ExecutorBase`.
-
-5. Distribute the inputs via {class}`~vllm.worker.worker_base.WorkerBase` to {class}`~vllm.worker.model_runner_base.ModelRunnerBase`.
-
-6. If the data contains multi-modal data, convert it into keyword arguments using {meth}`MULTIMODAL_REGISTRY.map_input `.
-
-   - For example, convert a {class}`PIL.Image.Image` input to its pixel values for a vision model.
diff --git a/docs/source/design/input_processing/model_inputs_index.md b/docs/source/design/input_processing/model_inputs_index.md
deleted file mode 100644
index cb415366e5a66..0000000000000
--- a/docs/source/design/input_processing/model_inputs_index.md
+++ /dev/null
@@ -1,43 +0,0 @@
-(input-processing)=
-
-# Input Processing
-
-```{eval-rst}
-.. currentmodule:: vllm.inputs
-```
-
-Each model can override parts of vLLM's [input processing pipeline](#input-processing-pipeline) via
-{data}`~vllm.inputs.INPUT_REGISTRY` and {data}`~vllm.multimodal.MULTIMODAL_REGISTRY`.
-
-Currently, this mechanism is only utilized in [multi-modal](#multi-modality) models for preprocessing multi-modal input
-data in addition to input prompt, but it can be extended to text-only language models when needed.
-
-## Guides
-
-```{toctree}
-:maxdepth: 1
-
-input_processing_pipeline
-```
-
-## Module Contents
-
-### LLM Engine Inputs
-
-```{eval-rst}
-.. autoclass:: vllm.inputs.DecoderOnlyInputs
-    :members:
-    :show-inheritance:
-```
-
-### Registry
-
-```{eval-rst}
-.. autodata:: vllm.inputs.INPUT_REGISTRY
-```
-
-```{eval-rst}
-.. automodule:: vllm.inputs.registry
-    :members:
-    :show-inheritance:
-```
diff --git a/docs/source/design/mm_processing.md b/docs/source/design/mm_processing.md
new file mode 100644
index 0000000000000..a0d01205e638c
--- /dev/null
+++ b/docs/source/design/mm_processing.md
@@ -0,0 +1,64 @@
+(mm-processing)=
+
+# Multi-Modal Data Processing
+
+To enable various optimizations in vLLM such as [chunked prefill](#chunked-prefill) and [prefix caching](#automatic-prefix-caching), we use {class}`~vllm.multimodal.processing.BaseMultiModalProcessor` to provide the correspondence between placeholder feature tokens (e.g. `<image>`) and multi-modal inputs (e.g. the raw input image) based on the outputs of the HF processor.
+
+Here are the main features of {class}`~vllm.multimodal.processing.BaseMultiModalProcessor`:
+
+## Prompt Replacement Detection
+
+One of the main responsibilities of the HF processor is to replace input placeholder tokens (e.g. `<image>` for a single image) with feature placeholder tokens (e.g. `<image><image>...<image>`, the number of which is equal to the feature size). The information about which tokens have been replaced is key to finding the correspondence between placeholder feature tokens and multi-modal inputs.
+
+In vLLM, this information is specified using {class}`~vllm.multimodal.processing.PromptReplacement` in {meth}`~vllm.multimodal.processing.BaseMultiModalProcessor._get_prompt_replacements`. Given this specification, we can automatically detect whether HF has replaced the input placeholder tokens by checking whether the feature placeholder tokens exist in the prompt.
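+
+For instance, a model whose HF processor expands each image placeholder into a fixed-length run of feature placeholder tokens could declare the following (an illustrative sketch only; `image_token_id` and `num_feature_tokens` are made-up placeholder values rather than those of any particular model):
+
+```python
+from vllm.multimodal.processing import PromptReplacement
+
+# Hypothetical values; a real model derives these from its HF config and processor.
+image_token_id = 32000
+num_feature_tokens = 576
+
+replacement = PromptReplacement(
+    modality="image",
+    target=[image_token_id],  # the single placeholder token in the user prompt
+    replacement=[image_token_id] * num_feature_tokens,  # what HF expands it into
+)
+```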
+
+## Tokenized Prompt Inputs
+
+To enable tokenization in a separate process, we support passing input token IDs alongside multi-modal data.
+
+### The problem
+
+Consider that HF processors follow these main steps:
+
+1. Tokenize the text
+2. Process multi-modal inputs
+3. Perform prompt replacement
+
+And we require that:
+
+- For text + multi-modal inputs, apply all steps 1--3.
+- For tokenized + multi-modal inputs, apply only steps 2--3.
+
+How can we achieve this without rewriting HF processors? We can try to call the HF processor several times on different inputs:
+
+- For text + multi-modal inputs, simply call the HF processor directly.
+- For tokenized + multi-modal inputs, call the processor only on the multi-modal inputs.
+
+While HF processors support text + multi-modal inputs natively, this is not so for tokenized + multi-modal inputs: an error is thrown if the number of input placeholder tokens does not correspond to the number of multi-modal inputs.
+
+Moreover, since the tokenized text has not passed through the HF processor, we have to apply step 3 ourselves to keep the output tokens and multi-modal data consistent with each other.
+
+(mm-dummy-text)=
+
+### Dummy text
+
+We work around the first issue by requiring each model to define how to generate dummy text based on the number of multi-modal inputs, via {meth}`~vllm.multimodal.profiling.BaseDummyInputsBuilder.get_dummy_processor_inputs`. This lets us generate dummy text corresponding to the multi-modal inputs and pass both to the HF processor together to obtain the processed multi-modal data.
+
+(mm-automatic-prompt-replacement)=
+
+### Automatic prompt replacement
+
+We address the second issue by implementing model-agnostic code in
+{meth}`~vllm.multimodal.processing.BaseMultiModalProcessor._apply_prompt_replacements` to automatically replace input placeholder tokens with feature placeholder tokens based on the specification outputted by {meth}`~vllm.multimodal.processing.BaseMultiModalProcessor._get_prompt_replacements`.
+
+### Summary
+
+With the help of dummy text and automatic prompt replacement, our multi-modal processor can finally accept both text and token prompts with multi-modal data. The detailed logic is shown in {meth}`~vllm.multimodal.processing.BaseMultiModalProcessor._apply_hf_processor_main`.
+
+## Processor Output Caching
+
+Some HF processors, such as the one for Qwen2-VL, are [very slow](gh-issue:9238). To alleviate this problem, we cache the multi-modal outputs of the HF processor to avoid processing the same multi-modal input (e.g. image) again.
+
+When new data is passed in, we first check which items are in the cache and which ones are missing. The missing items are passed into the HF processor in a single batch and cached, before being merged with the existing items in the cache.
+
+Since we only process the missing multi-modal data items, the number of input placeholder tokens no longer corresponds to the number of multi-modal inputs, so they can't be passed alongside the text prompt to the HF processor. Therefore, we process the text and multi-modal inputs separately, using [dummy text](#mm-dummy-text) to avoid HF errors. Since this skips HF's prompt replacement code, we apply [automatic prompt replacement](#mm-automatic-prompt-replacement) afterwards to keep the output tokens and multi-modal data consistent with each other.
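+
+Conceptually, this is a partition-and-merge over the incoming items, as in the toy sketch below (illustrative only; this is not vLLM's actual cache implementation, which also accounts for the model and the processing arguments when keying entries):
+
+```python
+from collections.abc import Callable, Hashable
+from typing import Any
+
+
+def process_with_cache(
+    items: list[Hashable],
+    cache: dict[Hashable, Any],
+    process_batch: Callable[[list[Hashable]], list[Any]],
+) -> list[Any]:
+    # Find the items whose processed outputs are not cached yet.
+    missing = [item for item in items if item not in cache]
+
+    # Process all missing items in a single batch and cache their outputs.
+    if missing:
+        for item, output in zip(missing, process_batch(missing)):
+            cache[item] = output
+
+    # Merge with the already-cached outputs, preserving the original order.
+    return [cache[item] for item in items]
+```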
diff --git a/docs/source/index.md b/docs/source/index.md index 356fa4b7fd573..de74276a50fb6 100644 --- a/docs/source/index.md +++ b/docs/source/index.md @@ -145,7 +145,7 @@ design/arch_overview design/huggingface_integration design/plugin_system design/kernel/paged_attention -design/input_processing/model_inputs_index +design/mm_processing design/automatic_prefix_caching design/multiprocessing ``` diff --git a/docs/source/serving/multimodal_inputs.md b/docs/source/serving/multimodal_inputs.md index 7e96ed46f2dcc..a06f121a6899a 100644 --- a/docs/source/serving/multimodal_inputs.md +++ b/docs/source/serving/multimodal_inputs.md @@ -14,7 +14,7 @@ and [open an issue on GitHub](https://github.com/vllm-project/vllm/issues/new/ch To input multi-modal data, follow this schema in {class}`vllm.inputs.PromptType`: - `prompt`: The prompt should follow the format that is documented on HuggingFace. -- `multi_modal_data`: This is a dictionary that follows the schema defined in {class}`vllm.multimodal.MultiModalDataDict`. +- `multi_modal_data`: This is a dictionary that follows the schema defined in {class}`vllm.multimodal.inputs.MultiModalDataDict`. ### Image diff --git a/vllm/config.py b/vllm/config.py index 13b5390008a35..59b509d5a961e 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -2124,8 +2124,7 @@ class MultiModalConfig: limit_per_prompt: Mapping[str, int] = field(default_factory=dict) """ - The maximum number of multi-modal input instances allowed per prompt - for each :class:`~vllm.multimodal.MultiModalPlugin`. + The maximum number of input items allowed per prompt for each modality. """ def compute_hash(self) -> str: diff --git a/vllm/inputs/__init__.py b/vllm/inputs/__init__.py index aaeecab7ffde1..a0dd89f69bacd 100644 --- a/vllm/inputs/__init__.py +++ b/vllm/inputs/__init__.py @@ -11,9 +11,6 @@ """ The global :class:`~InputRegistry` which is used by :class:`~vllm.LLMEngine` to dispatch data processing according to the target model. - -See also: - :ref:`input-processing-pipeline` """ __all__ = [ diff --git a/vllm/inputs/registry.py b/vllm/inputs/registry.py index aad0dfab94a01..4b73ade7af5f0 100644 --- a/vllm/inputs/registry.py +++ b/vllm/inputs/registry.py @@ -313,9 +313,6 @@ def dummy_data_for_profiling( The model is identified by ``model_config``. - See also: - :ref:`enabling-multimodal-inputs` - Note: This should be called after :meth:`~MultiModalRegistry.init_mm_limits_per_prompt`. @@ -384,10 +381,8 @@ def register_input_processor(self, processor: InputProcessor): Register an input processor to a model class. The provided function is invoked on each input to the model. This - happens before :meth:`~vllm.multimodal.MultiModalRegistry.map_input`. - - See also: - :ref:`input-processing-pipeline` + happens before + :meth:`~vllm.multimodal.registry.MultiModalRegistry.map_input`. """ def wrapper(model_cls: N) -> N: @@ -429,9 +424,6 @@ def process_input(self, model_config: "ModelConfig", Apply an input processor to an instance of model inputs. The model is identified by ``model_config``. 
- - See also: - :ref:`input-processing-pipeline` """ # Avoid circular import from vllm.model_executor.model_loader import get_model_architecture diff --git a/vllm/multimodal/__init__.py b/vllm/multimodal/__init__.py index 343b9322ecc5e..1d7f5d57fa24e 100644 --- a/vllm/multimodal/__init__.py +++ b/vllm/multimodal/__init__.py @@ -8,10 +8,10 @@ MULTIMODAL_REGISTRY = MultiModalRegistry() """ The global :class:`~MultiModalRegistry` is used by model runners to -dispatch data processing according to its modality and the target model. +dispatch data processing according to the target model. See also: - :ref:`input-processing-pipeline` + :ref:`mm-processing` """ __all__ = [ diff --git a/vllm/multimodal/base.py b/vllm/multimodal/base.py index 4941fbac963ca..fd3ec7e0ec8ce 100644 --- a/vllm/multimodal/base.py +++ b/vllm/multimodal/base.py @@ -90,10 +90,6 @@ def register_input_mapper( invoked to transform the data into a dictionary of model inputs. If `None` is provided, then the default input mapper is used instead. - - See also: - - :ref:`input-processing-pipeline` - - :ref:`enabling-multimodal-inputs` """ def wrapper(model_cls: N) -> N: @@ -126,10 +122,6 @@ def map_input( Raises: TypeError: If the data type is not supported. - - See also: - - :ref:`input-processing-pipeline` - - :ref:`enabling-multimodal-inputs` """ # Avoid circular import @@ -186,9 +178,6 @@ def register_max_multimodal_tokens( for a model class. If `None` is provided, then the default calculation is used instead. - - See also: - :ref:`enabling-multimodal-inputs` """ def wrapper(model_cls: N) -> N: @@ -218,9 +207,6 @@ def get_max_multimodal_tokens(self, model_config: "ModelConfig") -> int: If this registry is not applicable to the model, `0` is returned. The model is identified by ``model_config``. - - See also: - :ref:`enabling-multimodal-inputs` """ # Avoid circular import from vllm.model_executor.model_loader import get_model_architecture diff --git a/vllm/multimodal/inputs.py b/vllm/multimodal/inputs.py index 8680e4175593b..4b63703585214 100644 --- a/vllm/multimodal/inputs.py +++ b/vllm/multimodal/inputs.py @@ -493,7 +493,8 @@ def get_items(self, modality: str) -> Sequence[MultiModalKwargsItem]: class MultiModalInputsV2(TypedDict): """ - Represents the outputs of :class:`vllm.multimodal.MultiModalProcessor`, + Represents the outputs of + :class:`vllm.multimodal.processing.BaseMultiModalProcessor`, ready to be passed to vLLM internals. """ diff --git a/vllm/multimodal/registry.py b/vllm/multimodal/registry.py index 9eceefb08c93f..804a91da8c889 100644 --- a/vllm/multimodal/registry.py +++ b/vllm/multimodal/registry.py @@ -100,8 +100,7 @@ def __getitem__(self, key: "ModelConfig") -> Dict[str, int]: class MultiModalRegistry: """ - A registry that dispatches data processing to the - :class:`~vllm.multimodal.MultiModalPlugin` for each modality. + A registry that dispatches data processing according to the model. """ DEFAULT_PLUGINS = (ImagePlugin(), AudioPlugin(), VideoPlugin()) @@ -367,8 +366,7 @@ def register_processor( invoked to transform the data into a dictionary of model inputs. See also: - - :ref:`input-processing-pipeline` - - :ref:`enabling-multimodal-inputs` + :ref:`mm-processing` """ def wrapper(model_cls: N) -> N: @@ -398,6 +396,9 @@ def _get_model_cls(self, model_config: "ModelConfig"): def has_processor(self, model_config: "ModelConfig") -> bool: """ Test whether a multi-modal processor is defined for a specific model. 
+ + See also: + :ref:`mm-processing` """ return self._get_model_cls(model_config) in self._processor_factories @@ -408,6 +409,9 @@ def create_processor( ) -> BaseMultiModalProcessor[BaseProcessingInfo]: """ Create a multi-modal processor for a specific model and tokenizer. + + See also: + :ref:`mm-processing` """ model_cls = self._get_model_cls(model_config) factories = self._processor_factories[model_cls]