[Model] PP support for embedding models and update docs (#9090)
Co-authored-by: Roger Wang <136131678+ywang96@users.noreply.github.com>
DarkLight1337 and ywang96 authored Oct 6, 2024
1 parent f22619f commit b22b798
Showing 12 changed files with 610 additions and 449 deletions.
60 changes: 56 additions & 4 deletions docs/source/models/supported_models.rst
@@ -7,10 +7,12 @@ vLLM supports a variety of generative Transformer models in `HuggingFace Transfo
The following is the list of model architectures that are currently supported by vLLM.
Alongside each architecture, we include some popular models that use it.

----
Text-only Language Models
^^^^^^^^^^^^^^^^^^^^^^^^^

Text Generation
---------------

Decoder-only Language Models
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.. list-table::
  :widths: 25 25 50 5 5
  :header-rows: 1
@@ -40,6 +42,11 @@ Decoder-only Language Models
    - :code:`bigscience/bloom`, :code:`bigscience/bloomz`, etc.
    -
    - ✅︎
  * - :code:`BartForConditionalGeneration`
    - BART
    - :code:`facebook/bart-base`, :code:`facebook/bart-large-cnn`, etc.
    -
    -
  * - :code:`ChatGLMModel`
    - ChatGLM
    - :code:`THUDM/chatglm2-6b`, :code:`THUDM/chatglm3-6b`, etc.
@@ -259,11 +266,55 @@ Decoder-only Language Models
.. note::
  Currently, the ROCm version of vLLM supports Mistral and Mixtral only for context lengths up to 4096.

.. _supported_vlms:
Text Embedding
--------------

.. list-table::
  :widths: 25 25 50 5 5
  :header-rows: 1

  * - Architecture
    - Models
    - Example HuggingFace Models
    - :ref:`LoRA <lora>`
    - :ref:`PP <distributed_serving>`
  * - :code:`Gemma2Model`
    - Gemma2-based
    - :code:`BAAI/bge-multilingual-gemma2`, etc.
    -
    - ✅︎
  * - :code:`MistralModel`
    - Mistral-based
    - :code:`intfloat/e5-mistral-7b-instruct`, etc.
    -
    - ✅︎

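The PP column above indicates that these embedding models can now be split across pipeline stages. A minimal offline sketch, assuming a host with at least two GPUs (the parallel sizes shown are illustrative, not requirements):

.. code-block:: python

    from vllm import LLM

    # Illustrative parallel layout; any split that fits your GPUs works.
    llm = LLM(model="intfloat/e5-mistral-7b-instruct", pipeline_parallel_size=2)

    # For embedding models, ``encode`` returns one embedding vector per prompt.
    outputs = llm.encode(["Hello, my name is", "The capital of France is"])
    for output in outputs:
        print(len(output.outputs.embedding))
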
Reward Modeling
---------------

.. list-table::
  :widths: 25 25 50 5 5
  :header-rows: 1

  * - Architecture
    - Models
    - Example HuggingFace Models
    - :ref:`LoRA <lora>`
    - :ref:`PP <distributed_serving>`
  * - :code:`Qwen2ForRewardModel`
    - Qwen2-based
    - :code:`Qwen/Qwen2.5-Math-RM-72B`, etc.
    -
    - ✅︎

.. note::
  As an interim measure, these models are supported via Embeddings API. See `this RFC <https://github.com/vllm-project/vllm/issues/8967>`_ for upcoming changes.

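As a concrete illustration of the interim route described in the note, a reward model served through the OpenAI-compatible server can be queried via the embeddings endpoint. This is only a sketch; the server command, address, and parallel size below are assumptions for illustration:

.. code-block:: python

    # Assumes the server was started with something along the lines of:
    #   vllm serve Qwen/Qwen2.5-Math-RM-72B --tensor-parallel-size 4 --trust-remote-code
    from openai import OpenAI

    client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
    response = client.embeddings.create(
        model="Qwen/Qwen2.5-Math-RM-72B",
        input=["Problem: 1 + 1 = ? Answer: 2"],
    )
    # Per the note above, the reward output is returned through the embedding
    # field until a dedicated API lands.
    print(response.data[0].embedding[:8])
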
Multimodal Language Models
^^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. _supported_vlms:

.. list-table::
  :widths: 25 25 25 25 5 5
  :header-rows: 1
@@ -378,6 +429,7 @@ Multimodal Language Models
  For :code:`openbmb/MiniCPM-V-2`, the official repo doesn't work yet, so we need to use a fork (:code:`HwwwH/MiniCPM-V-2`) for now.
  For more details, please see: https://github.com/vllm-project/vllm/pull/4087#issuecomment-2250397630

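For the MiniCPM-V-2 workaround mentioned above, loading the fork is otherwise an ordinary model load. A sketch; :code:`trust_remote_code` is assumed to be needed because the repo ships custom modeling code:

.. code-block:: python

    from vllm import LLM

    # Use the fork until the official openbmb/MiniCPM-V-2 repo works with vLLM.
    llm = LLM(model="HwwwH/MiniCPM-V-2", trust_remote_code=True)
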
----

If your model uses one of the above model architectures, you can seamlessly run your model with vLLM.
Otherwise, please refer to :ref:`Adding a New Model <adding_a_new_model>` and :ref:`Enabling Multimodal Inputs <enabling_multimodal_inputs>`
7 changes: 3 additions & 4 deletions docs/source/models/vlm.rst
@@ -6,10 +6,9 @@ Using VLMs
vLLM provides experimental support for Vision Language Models (VLMs). See the :ref:`list of supported VLMs here <supported_vlms>`.
This document shows you how to run and serve these models using vLLM.

.. important::
  We are actively iterating on VLM support. Expect breaking changes to VLM usage and development in upcoming releases without prior deprecation.

  We are continuously improving user & developer experience for VLMs. Please `open an issue on GitHub <https://github.com/vllm-project/vllm/issues/new/choose>`_ if you have any feedback or feature requests.
.. note::
  We are actively iterating on VLM support. See `this RFC <https://github.com/vllm-project/vllm/issues/4194>`_ for upcoming changes,
  and `open an issue on GitHub <https://github.com/vllm-project/vllm/issues/new/choose>`_ if you have any feedback or feature requests.

Offline Inference
-----------------
146 changes: 114 additions & 32 deletions tests/distributed/test_pipeline_parallel.py
@@ -7,7 +7,7 @@
"""
import os
from dataclasses import dataclass
from typing import List, NamedTuple, Optional
from typing import List, Literal, NamedTuple, Optional

import pytest

@@ -97,22 +97,23 @@ def iter_params(self, model_name: str):
self.trust_remote_code, self.tokenizer_mode)


# NOTE: You can adjust tp_base and/or pp_base locally to fit the model on your GPUs.
# The values displayed here are only a rough indicator of the size of the model.

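# For example, a local override in the spirit of the NOTE above could look like
# the following (the tp_base value is a placeholder, mirroring how fast() is
# parameterized elsewhere in this file):
#
#     "databricks/dbrx-instruct": PPTestSettings.fast(tp_base=4),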
# yapf: disable
GENERATION_MODEL_SETTINGS = {
    # [DETAILED TESTS]
    "meta-llama/Meta-Llama-3-8B": PPTestSettings.detailed(),
    # [FAST TESTS]
    # Uses Llama
    # "BAAI/AquilaChat-7B": PPTestSettings.fast(),
    # TODO: Test on larger GPU
    # "Snowflake/snowflake-arctic-instruct": PPTestSettings.fast(trust_remote_code=True), # noqa: E501
    "Snowflake/snowflake-arctic-instruct": PPTestSettings.fast(tp_base=8, trust_remote_code=True), # noqa: E501
    "baichuan-inc/Baichuan-7B": PPTestSettings.fast(trust_remote_code=True),
    "baichuan-inc/Baichuan2-13B-Chat": PPTestSettings.fast(trust_remote_code=True), # noqa: E501
    "bigscience/bloomz-1b1": PPTestSettings.fast(),
    "THUDM/chatglm3-6b": PPTestSettings.fast(trust_remote_code=True),
    "CohereForAI/c4ai-command-r-v01": PPTestSettings.fast(tp_base=2, trust_remote_code=True), # noqa: E501
    # TODO: Test on larger GPU
    # "databricks/dbrx-instruct": PPTestSettings.fast(),
    "databricks/dbrx-instruct": PPTestSettings.fast(tp_base=8),
    "Deci/DeciLM-7B-instruct": PPTestSettings.fast(trust_remote_code=True),
    "deepseek-ai/deepseek-llm-7b-chat": PPTestSettings.fast(),
    "deepseek-ai/DeepSeek-V2-Lite-Chat": PPTestSettings.fast(trust_remote_code=True), # noqa: E501
@@ -161,8 +162,9 @@ def iter_params(self, model_name: str):

EMBEDDING_MODEL_SETTINGS = { # type: ignore[var-annotated]
    # [FAST TESTS]
    # Uses Llama
    # "intfloat/e5-mistral-7b-instruct": PPTestSettings.fast(),
    "intfloat/e5-mistral-7b-instruct": PPTestSettings.fast(),
    "BAAI/bge-multilingual-gemma2": PPTestSettings.fast(),
    "Qwen/Qwen2.5-Math-RM-72B": PPTestSettings.fast(tp_base=4, trust_remote_code=True), # noqa: E501
}

MULTIMODAL_MODEL_SETTINGS = {
@@ -192,40 +194,35 @@ def iter_params(self, model_name: str):
}
# yapf: enable

MODEL_SETTINGS = {
    **GENERATION_MODEL_SETTINGS,
    **EMBEDDING_MODEL_SETTINGS,
    **MULTIMODAL_MODEL_SETTINGS,
}

# You can update this on your local machine to run specific tests
# NOTE: You can update this on your local machine to run specific tests
TEST_MODELS = [
    # [LANGUAGE GENERATION]
    "meta-llama/Meta-Llama-3-8B",
    "facebook/chameleon-7b",
    "ibm/PowerLM-3b",
    # [LANGUAGE EMBEDDING]
    "intfloat/e5-mistral-7b-instruct",
    "BAAI/bge-multilingual-gemma2",
    # [MULTIMODAL GENERATION]
    "OpenGVLab/InternVL2-1B",
    "microsoft/Phi-3-vision-128k-instruct",
    "mistralai/Pixtral-12B-2409",
    "fixie-ai/ultravox-v0_3",
]


@pytest.mark.parametrize(
("model_name", "parallel_setup", "distributed_backend",
"trust_remote_code", "tokenizer_mode"),
[
params for model_name, settings in MODEL_SETTINGS.items()
for params in settings.iter_params(model_name)
if model_name in TEST_MODELS
],
)
@fork_new_process_for_each_test
def test_compare_tp(model_name: str, parallel_setup: ParallelSetup,
distributed_backend: str, trust_remote_code: bool,
tokenizer_mode: Optional[str], num_gpus_available):
def _compare_tp(
    model_name: str,
    parallel_setup: ParallelSetup,
    distributed_backend: str,
    trust_remote_code: bool,
    tokenizer_mode: Optional[str],
    num_gpus_available: int,
    *,
    method: Literal["generate", "encode"] = "encode",
):
    tp_size, pp_size, eager_mode, chunked_prefill = parallel_setup

    if num_gpus_available < tp_size:
        pytest.skip(f"Need at least {tp_size} GPUs to run the test")
    if num_gpus_available < tp_size * pp_size:
        pytest.skip(f"Need at least {tp_size} x {pp_size} GPUs")
    if VLLM_MULTI_NODE and distributed_backend == "mp":
        pytest.skip("Skipping multi-node pipeline parallel test for "
                    "multiprocessing distributed backend")
@@ -286,10 +283,95 @@ def test_compare_tp(model_name: str, parallel_setup: ParallelSetup,
    ]

    try:
        compare_two_settings(model_name, pp_args, tp_args, pp_env)
        compare_two_settings(model_name,
                             pp_args,
                             tp_args,
                             pp_env,
                             method=method)
    except Exception:
        if pp_env is None:
            raise
        else:
            # Ray ADAG tests are flaky, so we don't want to fail the test
            logger.exception("Ray ADAG tests failed")


@pytest.mark.parametrize(
("model_name", "parallel_setup", "distributed_backend",
"trust_remote_code", "tokenizer_mode"),
[
params for model_name, settings in GENERATION_MODEL_SETTINGS.items()
for params in settings.iter_params(model_name)
if model_name in TEST_MODELS
],
)
@fork_new_process_for_each_test
def test_tp_language_generation(
model_name: str,
parallel_setup: ParallelSetup,
distributed_backend: str,
trust_remote_code: bool,
tokenizer_mode: Optional[str],
num_gpus_available,
):
_compare_tp(model_name,
parallel_setup,
distributed_backend,
trust_remote_code,
tokenizer_mode,
num_gpus_available,
method="generate")


@pytest.mark.parametrize(
("model_name", "parallel_setup", "distributed_backend",
"trust_remote_code", "tokenizer_mode"),
[
params for model_name, settings in EMBEDDING_MODEL_SETTINGS.items()
for params in settings.iter_params(model_name)
if model_name in TEST_MODELS
],
)
@fork_new_process_for_each_test
def test_tp_language_embedding(
model_name: str,
parallel_setup: ParallelSetup,
distributed_backend: str,
trust_remote_code: bool,
tokenizer_mode: Optional[str],
num_gpus_available,
):
_compare_tp(model_name,
parallel_setup,
distributed_backend,
trust_remote_code,
tokenizer_mode,
num_gpus_available,
method="encode")


@pytest.mark.parametrize(
("model_name", "parallel_setup", "distributed_backend",
"trust_remote_code", "tokenizer_mode"),
[
params for model_name, settings in MULTIMODAL_MODEL_SETTINGS.items()
for params in settings.iter_params(model_name)
if model_name in TEST_MODELS
],
)
@fork_new_process_for_each_test
def test_tp_multimodal_generation(
model_name: str,
parallel_setup: ParallelSetup,
distributed_backend: str,
trust_remote_code: bool,
tokenizer_mode: Optional[str],
num_gpus_available,
):
_compare_tp(model_name,
parallel_setup,
distributed_backend,
trust_remote_code,
tokenizer_mode,
num_gpus_available,
method="generate")
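# To run only the new embedding pipeline-parallel test locally, something like
# the following should work (assuming a multi-GPU machine and the usual vLLM
# test environment):
#
#     pytest -x -s tests/distributed/test_pipeline_parallel.py::test_tp_language_embedding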