
Commit

Merge branch 'main' into lwilkinson/machete-end2end
mgoin committed Sep 21, 2024
2 parents a019473 + ec4aaad commit 306b283
Showing 64 changed files with 963 additions and 430 deletions.
1 change: 1 addition & 0 deletions .buildkite/run-amd-test.sh
@@ -100,6 +100,7 @@ if [[ $commands == *" entrypoints/openai "* ]]; then
--ignore=entrypoints/openai/test_accuracy.py \
--ignore=entrypoints/openai/test_audio.py \
--ignore=entrypoints/openai/test_encoder_decoder.py \
--ignore=entrypoints/openai/test_embedding.py \
--ignore=entrypoints/openai/test_oot_registration.py "}
fi

1 change: 1 addition & 0 deletions .github/workflows/scripts/build.sh
@@ -15,5 +15,6 @@ $python_executable -m pip install -r requirements-cuda.txt
export MAX_JOBS=1
# Make sure release wheels are built for the following architectures
export TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 8.9 9.0+PTX"
export VLLM_FA_CMAKE_GPU_ARCHES="80-real;90-real"
# Build
$python_executable setup.py bdist_wheel --dist-dir=dist
5 changes: 5 additions & 0 deletions .gitignore
@@ -1,6 +1,9 @@
# vllm commit id, generated by setup.py
vllm/commit_id.py

# vllm-flash-attn built from source
vllm/vllm_flash_attn/

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
@@ -12,6 +15,8 @@ __pycache__/
# Distribution / packaging
.Python
build/
cmake-build-*/
CMakeUserPresets.json
develop-eggs/
dist/
downloads/
98 changes: 73 additions & 25 deletions CMakeLists.txt
@@ -1,5 +1,16 @@
cmake_minimum_required(VERSION 3.26)

# When building directly using CMake, make sure you run the install step
# (it places the .so files in the correct location).
#
# Example:
# mkdir build && cd build
# cmake -G Ninja -DVLLM_PYTHON_EXECUTABLE=`which python3` -DCMAKE_INSTALL_PREFIX=.. ..
# cmake --build . --target install
#
# If you want to only build one target, make sure to install it manually:
# cmake --build . --target _C
# cmake --install . --component _C
project(vllm_extensions LANGUAGES CXX)

# CUDA by default, can be overridden by using -DVLLM_TARGET_DEVICE=... (used by setup.py)
@@ -13,6 +24,9 @@ include(${CMAKE_CURRENT_LIST_DIR}/cmake/utils.cmake)
# Suppress potential warnings about unused manually-specified variables
set(ignoreMe "${VLLM_PYTHON_PATH}")

# Prevent installation of dependencies (cutlass) by default.
install(CODE "set(CMAKE_INSTALL_LOCAL_ONLY TRUE)" ALL_COMPONENTS)

#
# Supported python versions. These versions will be searched in order, the
# first match will be selected. These should be kept in sync with setup.py.
@@ -70,19 +84,6 @@ endif()
find_package(Torch REQUIRED)

#
# Add the `default` target which detects which extensions should be
# built based on platform/architecture. This is the same logic that
# setup.py uses to select which extensions should be built and should
# be kept in sync.
#
# The `default` target makes direct use of cmake easier since knowledge
# of which extensions are supported has been factored in, e.g.
#
# mkdir build && cd build
# cmake -G Ninja -DVLLM_PYTHON_EXECUTABLE=`which python3` -DCMAKE_LIBRARY_OUTPUT_DIRECTORY=../vllm ..
# cmake --build . --target default
#
add_custom_target(default)
message(STATUS "Enabling core extension.")

# Define _core_C extension
@@ -100,8 +101,6 @@ define_gpu_extension_target(
USE_SABI 3
WITH_SOABI)

add_dependencies(default _core_C)

#
# Forward the non-CUDA device extensions to external CMake scripts.
#
Expand Down Expand Up @@ -167,6 +166,8 @@ if(NVCC_THREADS AND VLLM_GPU_LANG STREQUAL "CUDA")
list(APPEND VLLM_GPU_FLAGS "--threads=${NVCC_THREADS}")
endif()

include(FetchContent)

#
# Define other extension targets
#
Expand All @@ -190,7 +191,6 @@ set(VLLM_EXT_SRC
"csrc/torch_bindings.cpp")

if(VLLM_GPU_LANG STREQUAL "CUDA")
include(FetchContent)
SET(CUTLASS_ENABLE_HEADERS_ONLY ON CACHE BOOL "Enable only the header library")
FetchContent_Declare(
cutlass
@@ -284,6 +284,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
csrc/quantization/machete/machete_pytorch.cu)
endif()

message(STATUS "Enabling C extension.")
define_gpu_extension_target(
_C
DESTINATION vllm
@@ -314,6 +315,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
"csrc/moe/marlin_moe_ops.cu")
endif()

message(STATUS "Enabling moe extension.")
define_gpu_extension_target(
_moe_C
DESTINATION vllm
@@ -324,7 +326,6 @@ define_gpu_extension_target(
USE_SABI 3
WITH_SOABI)


if(VLLM_GPU_LANG STREQUAL "HIP")
#
# _rocm_C extension
@@ -344,16 +345,63 @@ if(VLLM_GPU_LANG STREQUAL "HIP")
WITH_SOABI)
endif()

# vllm-flash-attn currently only supported on CUDA
if (NOT VLLM_TARGET_DEVICE STREQUAL "cuda")
return()
endif ()

if(VLLM_GPU_LANG STREQUAL "CUDA" OR VLLM_GPU_LANG STREQUAL "HIP")
message(STATUS "Enabling C extension.")
add_dependencies(default _C)
#
# Build vLLM flash attention from source
#
# IMPORTANT: This has to be the last thing we do, because vllm-flash-attn uses the same macros/functions as vLLM.
# Because functions all belong to the global scope, vllm-flash-attn's functions overwrite vLLM's.
# They should be identical but if they aren't, this is a massive footgun.
#
# The vllm-flash-attn install rules are nested under vllm to make sure the library gets installed in the correct place.
# To only install vllm-flash-attn, use --component vllm_flash_attn_c.
# If no component is specified, vllm-flash-attn is still installed.

message(STATUS "Enabling moe extension.")
add_dependencies(default _moe_C)
# If VLLM_FLASH_ATTN_SRC_DIR is set, vllm-flash-attn is installed from that directory instead of downloading.
# This is to enable local development of vllm-flash-attn within vLLM.
# It can be set as an environment variable or passed as a cmake argument.
# The environment variable takes precedence.
if (DEFINED ENV{VLLM_FLASH_ATTN_SRC_DIR})
set(VLLM_FLASH_ATTN_SRC_DIR $ENV{VLLM_FLASH_ATTN_SRC_DIR})
endif()

if(VLLM_GPU_LANG STREQUAL "HIP")
message(STATUS "Enabling rocm extension.")
add_dependencies(default _rocm_C)
if(VLLM_FLASH_ATTN_SRC_DIR)
FetchContent_Declare(vllm-flash-attn SOURCE_DIR ${VLLM_FLASH_ATTN_SRC_DIR})
else()
FetchContent_Declare(
vllm-flash-attn
GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git
GIT_TAG 013f0c4fc47e6574060879d9734c1df8c5c273bd
GIT_PROGRESS TRUE
)
endif()

# Set the parent build flag so that the vllm-flash-attn library does not redo compile flag and arch initialization.
set(VLLM_PARENT_BUILD ON)

# Make sure vllm-flash-attn install rules are nested under vllm/
install(CODE "set(CMAKE_INSTALL_LOCAL_ONLY FALSE)" COMPONENT vllm_flash_attn_c)
install(CODE "set(OLD_CMAKE_INSTALL_PREFIX \"\${CMAKE_INSTALL_PREFIX}\")" COMPONENT vllm_flash_attn_c)
install(CODE "set(CMAKE_INSTALL_PREFIX \"\${CMAKE_INSTALL_PREFIX}/vllm/\")" COMPONENT vllm_flash_attn_c)

# Fetch the vllm-flash-attn library
FetchContent_MakeAvailable(vllm-flash-attn)
message(STATUS "vllm-flash-attn is available at ${vllm-flash-attn_SOURCE_DIR}")

# Restore the install prefix
install(CODE "set(CMAKE_INSTALL_PREFIX \"\${OLD_CMAKE_INSTALL_PREFIX}\")" COMPONENT vllm_flash_attn_c)
install(CODE "set(CMAKE_INSTALL_LOCAL_ONLY TRUE)" COMPONENT vllm_flash_attn_c)

# Copy over the vllm-flash-attn python files
install(
DIRECTORY ${vllm-flash-attn_SOURCE_DIR}/vllm_flash_attn/
DESTINATION vllm/vllm_flash_attn
COMPONENT vllm_flash_attn_c
FILES_MATCHING PATTERN "*.py"
)

# Nothing after vllm-flash-attn, see comment about macros above
3 changes: 3 additions & 0 deletions Dockerfile
@@ -48,6 +48,9 @@ RUN --mount=type=cache,target=/root/.cache/pip \
# see https://github.com/pytorch/pytorch/pull/123243
ARG torch_cuda_arch_list='7.0 7.5 8.0 8.6 8.9 9.0+PTX'
ENV TORCH_CUDA_ARCH_LIST=${torch_cuda_arch_list}
# Override the arch list for flash-attn to reduce the binary size
ARG vllm_fa_cmake_gpu_arches='80-real;90-real'
ENV VLLM_FA_CMAKE_GPU_ARCHES=${vllm_fa_cmake_gpu_arches}
#################### BASE BUILD IMAGE ####################

#################### WHEEL BUILD IMAGE ####################
4 changes: 2 additions & 2 deletions Dockerfile.neuron
@@ -1,5 +1,5 @@
# default base image
ARG BASE_IMAGE="public.ecr.aws/neuron/pytorch-inference-neuronx:2.1.2-neuronx-py310-sdk2.19.1-ubuntu20.04"
ARG BASE_IMAGE="public.ecr.aws/neuron/pytorch-inference-neuronx:2.1.2-neuronx-py310-sdk2.20.0-ubuntu20.04"

FROM $BASE_IMAGE

@@ -20,7 +20,7 @@ RUN python3 -m pip install --upgrade pip
RUN python3 -m pip install --no-cache-dir fastapi ninja tokenizers pandas
RUN python3 -m pip install sentencepiece transformers==4.36.2 -U
RUN python3 -m pip install transformers-neuronx --extra-index-url=https://pip.repos.neuron.amazonaws.com -U
RUN python3 -m pip install --pre neuronx-cc==2.12.* --extra-index-url=https://pip.repos.neuron.amazonaws.com -U
RUN python3 -m pip install --pre neuronx-cc==2.15.* --extra-index-url=https://pip.repos.neuron.amazonaws.com -U

COPY ./vllm /app/vllm/vllm
COPY ./setup.py /app/vllm/setup.py
8 changes: 4 additions & 4 deletions benchmarks/benchmark_latency.py
@@ -11,7 +11,7 @@

from vllm import LLM, SamplingParams
from vllm.engine.arg_utils import DEVICE_OPTIONS, EngineArgs
from vllm.inputs import PromptInputs
from vllm.inputs import PromptType
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
from vllm.utils import FlexibleArgumentParser

@@ -61,7 +61,7 @@ def main(args: argparse.Namespace):
dummy_prompt_token_ids = np.random.randint(10000,
size=(args.batch_size,
args.input_len))
dummy_inputs: List[PromptInputs] = [{
dummy_prompts: List[PromptType] = [{
"prompt_token_ids": batch
} for batch in dummy_prompt_token_ids.tolist()]

@@ -74,13 +74,13 @@ def run_to_completion(profile_dir: Optional[str] = None):
],
on_trace_ready=torch.profiler.tensorboard_trace_handler(
str(profile_dir))) as p:
llm.generate(dummy_inputs,
llm.generate(dummy_prompts,
sampling_params=sampling_params,
use_tqdm=False)
print(p.key_averages())
else:
start_time = time.perf_counter()
llm.generate(dummy_inputs,
llm.generate(dummy_prompts,
sampling_params=sampling_params,
use_tqdm=False)
end_time = time.perf_counter()
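
The rename from PromptInputs to PromptType recurs in the docs changes below. As a hedged sketch of the renamed type in use (the model name and token ids are illustrative placeholders, not part of this commit), the same generate() call accepts both plain-text prompts and pre-tokenized prompts:

# Hedged sketch: PromptType covers text prompts and token-id prompts alike.
from typing import List

from vllm import LLM, SamplingParams
from vllm.inputs import PromptType

llm = LLM(model="facebook/opt-125m")  # placeholder model
sampling_params = SamplingParams(temperature=0.0, max_tokens=16)

prompts: List[PromptType] = [
    {"prompt": "Hello, my name is"},        # text prompt
    {"prompt_token_ids": [1, 2, 3, 4, 5]},  # pre-tokenized prompt, as in the benchmark above
]

outputs = llm.generate(prompts, sampling_params=sampling_params)
for output in outputs:
    print(output.outputs[0].text)
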
2 changes: 1 addition & 1 deletion cmake/utils.cmake
@@ -364,5 +364,5 @@ function (define_gpu_extension_target GPU_MOD_NAME)
target_link_libraries(${GPU_MOD_NAME} PRIVATE ${TORCH_LIBRARIES})
endif()

install(TARGETS ${GPU_MOD_NAME} LIBRARY DESTINATION ${GPU_DESTINATION})
install(TARGETS ${GPU_MOD_NAME} LIBRARY DESTINATION ${GPU_DESTINATION} COMPONENT ${GPU_MOD_NAME})
endfunction()
9 changes: 7 additions & 2 deletions collect_env.py
@@ -285,9 +285,14 @@ def summarize_vllm_build_flags():


def get_gpu_topo(run_lambda):
output = None

if get_platform() == 'linux':
return run_and_read_all(run_lambda, 'nvidia-smi topo -m')
return None
output = run_and_read_all(run_lambda, 'nvidia-smi topo -m')
if output is None:
output = run_and_read_all(run_lambda, 'rocm-smi --showtopo')

return output


# example outputs of CPU infos
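
For context on the fallback above: run_and_read_all is an existing helper in collect_env.py that returns a command's output or None on failure, which is what lets the rocm-smi call act as a fallback when nvidia-smi yields nothing. A rough, hypothetical stand-in (a sketch, not the actual helper or its signature):

import subprocess

def run_and_read_all_sketch(cmd: str):
    # Hypothetical stand-in: run a shell command and return its stdout,
    # or None if the command fails or prints nothing.
    try:
        result = subprocess.run(cmd, shell=True, capture_output=True, text=True)
    except OSError:
        return None
    if result.returncode != 0:
        return None
    output = result.stdout.strip()
    return output if output else None
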
2 changes: 1 addition & 1 deletion docs/source/dev/multimodal/multimodal_index.rst
@@ -8,7 +8,7 @@ Multi-Modality
vLLM provides experimental support for multi-modal models through the :mod:`vllm.multimodal` package.

Multi-modal inputs can be passed alongside text and token prompts to :ref:`supported models <supported_vlms>`
via the ``multi_modal_data`` field in :class:`vllm.inputs.PromptInputs`.
via the ``multi_modal_data`` field in :class:`vllm.inputs.PromptType`.

Currently, vLLM only has built-in support for image data. You can extend vLLM to process additional modalities
by following :ref:`this guide <adding_multimodal_plugin>`.
2 changes: 1 addition & 1 deletion docs/source/dev/offline_inference/llm_inputs.rst
@@ -1,7 +1,7 @@
LLM Inputs
==========

.. autodata:: vllm.inputs.PromptInputs
.. autodata:: vllm.inputs.PromptType

.. autoclass:: vllm.inputs.TextPrompt
:show-inheritance:
4 changes: 2 additions & 2 deletions docs/source/getting_started/amd-installation.rst
@@ -83,7 +83,7 @@ Option 2: Build from source

For installing PyTorch, you can start from a fresh docker image, e.g, `rocm/pytorch:rocm6.1.2_ubuntu20.04_py3.9_pytorch_staging`, `rocm/pytorch-nightly`.

Alternatively, you can install PyTorch using PyTorch wheels. You can check PyTorch installation guild in PyTorch `Getting Started <https://pytorch.org/get-started/locally/>`_
Alternatively, you can install PyTorch using PyTorch wheels. You can check PyTorch installation guide in PyTorch `Getting Started <https://pytorch.org/get-started/locally/>`_


1. Install `Triton flash attention for ROCm <https://github.com/ROCm/triton>`_
@@ -104,7 +104,7 @@ Alternatively, wheels intended for vLLM use can be accessed under the releases.
$ cd vllm
$ pip install -U -r requirements-rocm.txt
$ python setup.py develop # This may take 5-10 minutes. Currently, `pip install .`` does not work for ROCm installation
$ python setup.py develop # This may take 5-10 minutes. Currently, `pip install .` does not work for ROCm installation
.. tip::
4 changes: 2 additions & 2 deletions docs/source/getting_started/neuron-installation.rst
@@ -3,8 +3,8 @@
Installation with Neuron
========================

vLLM 0.3.3 onwards supports model inferencing and serving on AWS Trainium/Inferentia with Neuron SDK.
At the moment Paged Attention is not supported in Neuron SDK, but naive continuous batching is supported in transformers-neuronx.
vLLM 0.3.3 onwards supports model inferencing and serving on AWS Trainium/Inferentia with Neuron SDK with continuous batching.
Paged Attention and Chunked Prefill are currently in development and will be available soon.
Data types currently supported in Neuron SDK are FP16 and BF16.

Requirements
2 changes: 1 addition & 1 deletion docs/source/index.rst
@@ -43,7 +43,7 @@ vLLM is flexible and easy to use with:
* Tensor parallelism and pipeline parallelism support for distributed inference
* Streaming outputs
* OpenAI-compatible API server
* Support NVIDIA GPUs, AMD CPUs and GPUs, Intel CPUs and GPUs, PowerPC CPUs, TPU, and AWS Neuron.
* Support NVIDIA GPUs, AMD CPUs and GPUs, Intel CPUs and GPUs, PowerPC CPUs, TPU, and AWS Trainium and Inferentia Accelerators.
* Prefix caching support
* Multi-lora support

4 changes: 4 additions & 0 deletions docs/source/models/supported_models.rst
@@ -127,6 +127,10 @@ Decoder-only Language Models
- Nemotron-3, Nemotron-4, Minitron
- :code:`nvidia/Minitron-8B-Base`, :code:`mgoin/Nemotron-4-340B-Base-hf-FP8`, etc.
- ✅︎
* - :code:`OLMoEForCausalLM`
- OLMoE
- :code:`allenai/OLMoE-1B-7B-0924`, :code:`allenai/OLMoE-1B-7B-0924-Instruct`, etc.
-
* - :code:`OLMoForCausalLM`
- OLMo
- :code:`allenai/OLMo-1B-hf`, :code:`allenai/OLMo-7B-hf`, etc.
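
A hedged usage sketch for the OLMoE entry added above (the checkpoint name comes from the table; everything else is an illustrative placeholder, not part of this commit):

# Hedged sketch: load the newly supported OLMoE checkpoint with the standard LLM API.
from vllm import LLM, SamplingParams

llm = LLM(model="allenai/OLMoE-1B-7B-0924")
params = SamplingParams(temperature=0.0, max_tokens=32)
outputs = llm.generate("The capital of France is", params)
print(outputs[0].outputs[0].text)
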
2 changes: 1 addition & 1 deletion docs/source/models/vlm.rst
@@ -27,7 +27,7 @@ The :class:`~vllm.LLM` class can be instantiated in much the same way as languag
We have removed all vision language related CLI args in the ``0.5.1`` release. **This is a breaking change**, so please update your code to follow
the above snippet. Specifically, ``image_feature_size`` can no longer be specified as we now calculate that internally for each model.

To pass an image to the model, note the following in :class:`vllm.inputs.PromptInputs`:
To pass an image to the model, note the following in :class:`vllm.inputs.PromptType`:

* ``prompt``: The prompt should follow the format that is documented on HuggingFace.
* ``multi_modal_data``: This is a dictionary that follows the schema defined in :class:`vllm.multimodal.MultiModalDataDict`.
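
A hedged sketch combining the two fields above (the model name, chat format, and image path are placeholders; the exact prompt template depends on the model's documentation):

# Hedged sketch: pass an image through multi_modal_data alongside a text prompt.
from PIL import Image

from vllm import LLM

llm = LLM(model="llava-hf/llava-1.5-7b-hf")  # placeholder vision-language model
image = Image.open("example.jpg")            # placeholder local image file

outputs = llm.generate({
    "prompt": "USER: <image>\nWhat is shown in this image?\nASSISTANT:",
    "multi_modal_data": {"image": image},
})
print(outputs[0].outputs[0].text)
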
1 change: 0 additions & 1 deletion requirements-cuda.txt
@@ -8,4 +8,3 @@ torch == 2.4.0
# These must be updated alongside torch
torchvision == 0.19 # Required for phi3v processor. See https://github.com/pytorch/vision?tab=readme-ov-file#installation for corresponding version
xformers == 0.0.27.post2; platform_system == 'Linux' and platform_machine == 'x86_64' # Requires PyTorch 2.4.0
vllm-flash-attn == 2.6.1; platform_system == 'Linux' and platform_machine == 'x86_64' # Requires PyTorch 2.4.0
4 changes: 2 additions & 2 deletions requirements-neuron.txt
@@ -2,6 +2,6 @@
-r requirements-common.txt

# Dependencies for Neuron devices
transformers-neuronx >= 0.9.0
torch-neuronx >= 2.1.0
transformers-neuronx >= 0.12.0
torch-neuronx >= 2.1.2
neuronx-cc