Sync with upstream@v0.4.3-60-gbaa15a9e #47

Merged: 30 commits, Jun 10, 2024

Commits (30)
4efff03
Bugfix: fix broken of download models from modelscope (#5233)
liuyhwangyh Jun 6, 2024
abe855d
[Kernel] Retune Mixtral 8x22b configs for FP8 on H100 (#5294)
pcmoritz Jun 6, 2024
828da0d
[Frontend] enable passing multiple LoRA adapters at once to generate(…
mgoldey Jun 6, 2024
a31cab7
[Core] Avoid copying prompt/output tokens if no penalties are used (#…
Yard1 Jun 7, 2024
ccdc490
[Core] Change LoRA embedding sharding to support loading methods (#5038)
Yard1 Jun 7, 2024
1506374
[Misc] Missing error message for custom ops import (#5282)
DamonFool Jun 7, 2024
baa15a9
[Feature][Frontend]: Add support for `stream_options` in `ChatComplet…
Etelis Jun 7, 2024
388596c
[Misc][Utils] allow get_open_port to be called for multiple times (#5…
youkaichao Jun 7, 2024
8d75fe4
[Kernel] Switch fp8 layers to use the CUTLASS kernels (#5183)
tlrmchlsmth Jun 7, 2024
18a277b
Remove Ray health check (#4693)
Yard1 Jun 7, 2024
dc49fb8
Addition of lacked ignored_seq_groups in _schedule_chunked_prefill (#…
JamesLim-sy Jun 7, 2024
ca3ea51
[Kernel] Dynamic Per-Token Activation Quantization (#5037)
dsikka Jun 7, 2024
7a9cb29
[Frontend] Add OpenAI Vision API Support (#5237)
ywang96 Jun 7, 2024
6840a71
[Misc] Remove unused cuda_utils.h in CPU backend (#5345)
DamonFool Jun 7, 2024
767c727
fix DbrxFusedNormAttention missing cache_config (#5340)
Calvinnncy97 Jun 7, 2024
e69ded7
[Bug Fix] Fix the support check for FP8 CUTLASS (#5352)
cli99 Jun 8, 2024
b3376e5
[Misc] Add args for selecting distributed executor to benchmarks (#5335)
BKitor Jun 8, 2024
c96fc06
[ROCm][AMD] Use pytorch sdpa math backend to do naive attention (#4965)
hongxiayang Jun 8, 2024
9fb900f
[CI/Test] improve robustness of test (hf_runner) (#5347)
youkaichao Jun 8, 2024
8ea5e44
[CI/Test] improve robustness of test (vllm_runner) (#5357)
youkaichao Jun 8, 2024
c09dade
[Misc][Breaking] Change FP8 checkpoint format from act_scale -> input…
mgoin Jun 8, 2024
0373e18
[Core][CUDA Graph] add output buffer for cudagraph (#5074)
youkaichao Jun 9, 2024
5d7e3d0
[mis][ci/test] fix flaky test in test_sharded_state_loader.py (#5361)
youkaichao Jun 9, 2024
5467ac3
[Kernel][Misc] Use TORCH_LIBRARY instead of PYBIND11_MODULE for custo…
bnellnm Jun 9, 2024
45f92c0
[Bugfix] Fix KeyError: 1 When Using LoRA adapters (#5164)
BlackBird-Coding Jun 9, 2024
5884c2b
[Misc] Update to comply with the new `compressed-tensors` config (#5350)
dsikka Jun 10, 2024
68bc817
[Frontend][Misc] Enforce Pixel Values as Input Type for VLMs in API S…
ywang96 Jun 10, 2024
c81da5f
[misc][typo] fix typo (#5372)
youkaichao Jun 10, 2024
0bfa1c4
[Misc] Improve error message when LoRA parsing fails (#5194)
DarkLight1337 Jun 10, 2024
6b29d6f
[Model] Initial support for LLaVA-NeXT (#4199)
DarkLight1337 Jun 10, 2024

Files changed

10 changes: 2 additions & 8 deletions .buildkite/test-pipeline.yaml
@@ -46,6 +46,7 @@ steps:
- TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_chunked_prefill_distributed.py
- TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_chunked_prefill_distributed.py
- pytest -v -s spec_decode/e2e/test_integration_dist.py
- CUDA_VISIBLE_DEVICES=0,1 pytest -v -s test_sharded_state_loader.py

- label: Distributed Tests (Multiple Groups)
#mirror_hardwares: [amd]
@@ -138,14 +139,7 @@ steps:
num_gpus: 4
# This test runs llama 13B, so it is required to run on 4 GPUs.
commands:
# Temporarily run this way because we cannot clean up GPU mem usage
# for multi GPU tests.
# TODO(sang): Fix it.
- pytest -v -s lora/test_long_context.py::test_rotary_emb_replaced
- pytest -v -s lora/test_long_context.py::test_batched_rope_kernel
- pytest -v -s lora/test_long_context.py::test_self_consistency
- pytest -v -s lora/test_long_context.py::test_quality
- pytest -v -s lora/test_long_context.py::test_max_len
- pytest -v -s -x lora/test_long_context.py

- label: Tensorizer Test
#mirror_hardwares: [amd]
22 changes: 6 additions & 16 deletions CMakeLists.txt
@@ -66,19 +66,6 @@ endif()
#
find_package(Torch REQUIRED)

#
# Normally `torch.utils.cpp_extension.CUDAExtension` would add
# `libtorch_python.so` for linking against an extension. Torch's cmake
# configuration does not include this library (presumably since the cmake
# config is used for standalone C++ binaries that link against torch).
# The `libtorch_python.so` library defines some of the glue code between
# torch/python via pybind and is required by VLLM extensions for this
# reason. So, add it by manually with `find_library` using torch's
# installed library path.
#
find_library(torch_python_LIBRARY torch_python PATHS
"${TORCH_INSTALL_PREFIX}/lib")

#
# Forward the non-CUDA device extensions to external CMake scripts.
#
@@ -171,7 +158,7 @@ set(VLLM_EXT_SRC
"csrc/quantization/fp8/common.cu"
"csrc/cuda_utils_kernels.cu"
"csrc/moe_align_block_size_kernels.cu"
"csrc/pybind.cpp")
"csrc/torch_bindings.cpp")

if(VLLM_GPU_LANG STREQUAL "CUDA")
include(FetchContent)
@@ -218,14 +205,15 @@ define_gpu_extension_target(
COMPILE_FLAGS ${VLLM_GPU_FLAGS}
ARCHITECTURES ${VLLM_GPU_ARCHES}
INCLUDE_DIRECTORIES ${CUTLASS_INCLUDE_DIR};${CUTLASS_TOOLS_UTIL_INCLUDE_DIR}
USE_SABI 3
WITH_SOABI)

#
# _moe_C extension
#

set(VLLM_MOE_EXT_SRC
"csrc/moe/moe_ops.cpp"
"csrc/moe/torch_bindings.cpp"
"csrc/moe/topk_softmax_kernels.cu")

define_gpu_extension_target(
@@ -235,6 +223,7 @@ define_gpu_extension_target(
SOURCES ${VLLM_MOE_EXT_SRC}
COMPILE_FLAGS ${VLLM_GPU_FLAGS}
ARCHITECTURES ${VLLM_GPU_ARCHES}
USE_SABI 3
WITH_SOABI)

#
Expand All @@ -249,7 +238,7 @@ set(VLLM_PUNICA_EXT_SRC
"csrc/punica/bgmv/bgmv_fp32_bf16_bf16.cu"
"csrc/punica/bgmv/bgmv_fp32_fp16_fp16.cu"
"csrc/punica/punica_ops.cu"
"csrc/punica/punica_pybind.cpp")
"csrc/punica/torch_bindings.cpp")

#
# Copy GPU compilation flags+update for punica
Expand Down Expand Up @@ -286,6 +275,7 @@ if (VLLM_PUNICA_GPU_ARCHES)
SOURCES ${VLLM_PUNICA_EXT_SRC}
COMPILE_FLAGS ${VLLM_PUNICA_GPU_FLAGS}
ARCHITECTURES ${VLLM_PUNICA_GPU_ARCHES}
USE_SABI 3
WITH_SOABI)
else()
message(WARNING "Unable to create _punica_C target because none of the "
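
The CMakeLists.txt changes above track this sync's switch from pybind11 extension modules to torch-library bindings (commit 5467ac3): binding sources move from csrc/pybind.cpp to csrc/torch_bindings.cpp and each target is built against Python's stable ABI via USE_SABI 3. As a rough illustration of what such a binding file contains, here is a minimal C++ sketch; the op name and schema are hypothetical, not copied from vLLM.

```cpp
// Minimal sketch of a TORCH_LIBRARY-based binding file (hypothetical op name
// and schema; vLLM's real definitions live in csrc/torch_bindings.cpp).
#include <torch/library.h>
#include <torch/all.h>

// Hypothetical kernel entry point implemented in a .cu file.
void silu_and_mul(torch::Tensor& out, const torch::Tensor& input);

// Register the operator schema with the dispatcher. Unlike PYBIND11_MODULE,
// this generates no CPython-version-specific glue, so the extension can
// target the stable ABI (USE_SABI 3 above).
TORCH_LIBRARY(my_ext, ops) {
  ops.def("silu_and_mul(Tensor(a!) out, Tensor input) -> ()");
}

// Bind the CUDA implementation for that schema.
TORCH_LIBRARY_IMPL(my_ext, CUDA, ops) {
  ops.impl("silu_and_mul", &silu_and_mul);
}
```

On the Python side such an op is reached through torch.ops.my_ext.silu_and_mul rather than an imported pybind11 module, which is also why the Dockerfile.rocm hunk below now copies .abi3.so artifacts.
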
6 changes: 3 additions & 3 deletions Dockerfile.rocm
@@ -106,9 +106,9 @@ RUN --mount=type=cache,target=/root/.cache/pip \
pip install -U -r requirements-rocm.txt \
&& patch /opt/rocm/include/hip/amd_detail/amd_hip_bf16.h ./rocm_patch/rocm_bf16.patch \
&& python3 setup.py install \
&& cp build/lib.linux-x86_64-cpython-39/vllm/_C.cpython-39-x86_64-linux-gnu.so vllm/ \
&& cp build/lib.linux-x86_64-cpython-39/vllm/_punica_C.cpython-39-x86_64-linux-gnu.so vllm/ \
&& cp build/lib.linux-x86_64-cpython-39/vllm/_moe_C.cpython-39-x86_64-linux-gnu.so vllm/ \
&& cp build/lib.linux-x86_64-cpython-39/vllm/_C.abi3.so vllm/ \
&& cp build/lib.linux-x86_64-cpython-39/vllm/_punica_C.abi3.so vllm/ \
&& cp build/lib.linux-x86_64-cpython-39/vllm/_moe_C.abi3.so vllm/ \
&& cd ..


10 changes: 9 additions & 1 deletion benchmarks/benchmark_latency.py
@@ -36,7 +36,8 @@ def main(args: argparse.Namespace):
enable_chunked_prefill=args.enable_chunked_prefill,
download_dir=args.download_dir,
block_size=args.block_size,
gpu_memory_utilization=args.gpu_memory_utilization)
gpu_memory_utilization=args.gpu_memory_utilization,
distributed_executor_backend=args.distributed_executor_backend)

sampling_params = SamplingParams(
n=args.n,
@@ -221,5 +222,12 @@ def run_to_completion(profile_dir: Optional[str] = None):
help='the fraction of GPU memory to be used for '
'the model executor, which can range from 0 to 1.'
'If unspecified, will use the default value of 0.9.')
parser.add_argument(
'--distributed-executor-backend',
choices=['ray', 'mp'],
default=None,
help='Backend to use for distributed serving. When more than 1 GPU '
'is used, will be automatically set to "ray" if installed '
'or "mp" (multiprocessing) otherwise.')
args = parser.parse_args()
main(args)
13 changes: 11 additions & 2 deletions benchmarks/benchmark_throughput.py
@@ -78,6 +78,7 @@ def run_vllm(
enable_prefix_caching: bool,
enable_chunked_prefill: bool,
max_num_batched_tokens: int,
distributed_executor_backend: Optional[str],
gpu_memory_utilization: float = 0.9,
download_dir: Optional[str] = None,
) -> float:
@@ -100,6 +101,7 @@
download_dir=download_dir,
enable_chunked_prefill=enable_chunked_prefill,
max_num_batched_tokens=max_num_batched_tokens,
distributed_executor_backend=distributed_executor_backend,
)

# Add the requests to the engine.
@@ -225,8 +227,8 @@ def main(args: argparse.Namespace):
args.enforce_eager, args.kv_cache_dtype,
args.quantization_param_path, args.device,
args.enable_prefix_caching, args.enable_chunked_prefill,
args.max_num_batched_tokens, args.gpu_memory_utilization,
args.download_dir)
args.max_num_batched_tokens, args.distributed_executor_backend,
args.gpu_memory_utilization, args.download_dir)
elif args.backend == "hf":
assert args.tensor_parallel_size == 1
elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
@@ -368,6 +370,13 @@ def main(args: argparse.Namespace):
type=str,
default=None,
help='Path to save the throughput results in JSON format.')
parser.add_argument(
'--distributed-executor-backend',
choices=['ray', 'mp'],
default=None,
help='Backend to use for distributed serving. When more than 1 GPU '
'is used, will be automatically set to "ray" if installed '
'or "mp" (multiprocessing) otherwise.')
args = parser.parse_args()
if args.tokenizer is None:
args.tokenizer = args.model
12 changes: 6 additions & 6 deletions cmake/cpu_extension.cmake
@@ -12,7 +12,7 @@ include_directories("${CMAKE_SOURCE_DIR}/csrc")
#
# Check the compile flags
#
list(APPEND CXX_COMPILE_FLAGS
list(APPEND CXX_COMPILE_FLAGS
"-fopenmp"
"-DVLLM_CPU_EXTENSION")

@@ -44,8 +44,8 @@ if (AVX512_FOUND)

find_isa(${CPUINFO} "avx512_bf16" AVX512BF16_FOUND)
if (AVX512BF16_FOUND OR ENABLE_AVX512BF16)
if (CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND
CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 12.3)
if (CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND
CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 12.3)
list(APPEND CXX_COMPILE_FLAGS "-mavx512bf16")
else()
message(WARNING "Disable AVX512-BF16 ISA support, requires gcc/g++ >= 12.3")
@@ -73,18 +73,18 @@ set(VLLM_EXT_SRC
"csrc/cpu/cache.cpp"
"csrc/cpu/layernorm.cpp"
"csrc/cpu/pos_encoding.cpp"
"csrc/cpu/pybind.cpp")
"csrc/cpu/torch_bindings.cpp")

define_gpu_extension_target(
_C
DESTINATION vllm
LANGUAGE CXX
SOURCES ${VLLM_EXT_SRC}
COMPILE_FLAGS ${CXX_COMPILE_FLAGS}
WITH_SOABI
USE_SABI 3
WITH_SOABI
)

add_custom_target(default)
message(STATUS "Enabling C extension.")
add_dependencies(default _C)

11 changes: 8 additions & 3 deletions cmake/utils.cmake
@@ -5,7 +5,7 @@
macro (find_python_from_executable EXECUTABLE SUPPORTED_VERSIONS)
file(REAL_PATH ${EXECUTABLE} EXECUTABLE)
set(Python_EXECUTABLE ${EXECUTABLE})
find_package(Python COMPONENTS Interpreter Development.Module)
find_package(Python COMPONENTS Interpreter Development.Module Development.SABIModule)
if (NOT Python_FOUND)
message(FATAL_ERROR "Unable to find python matching: ${EXECUTABLE}.")
endif()
@@ -294,14 +294,15 @@ endmacro()
# INCLUDE_DIRECTORIES <dirs> - Extra include directories.
# LIBRARIES <libraries> - Extra link libraries.
# WITH_SOABI - Generate library with python SOABI suffix name.
# USE_SABI <version> - Use python stable api <version>
#
# Note: optimization level/debug info is set via cmake build type.
#
function (define_gpu_extension_target GPU_MOD_NAME)
cmake_parse_arguments(PARSE_ARGV 1
GPU
"WITH_SOABI"
"DESTINATION;LANGUAGE"
"DESTINATION;LANGUAGE;USE_SABI"
"SOURCES;ARCHITECTURES;COMPILE_FLAGS;INCLUDE_DIRECTORIES;LIBRARIES")

# Add hipify preprocessing step when building with HIP/ROCm.
Expand All @@ -315,7 +316,11 @@ function (define_gpu_extension_target GPU_MOD_NAME)
set(GPU_WITH_SOABI)
endif()

Python_add_library(${GPU_MOD_NAME} MODULE "${GPU_SOURCES}" ${GPU_WITH_SOABI})
if (GPU_USE_SABI)
Python_add_library(${GPU_MOD_NAME} MODULE USE_SABI ${GPU_USE_SABI} ${GPU_WITH_SOABI} "${GPU_SOURCES}")
else()
Python_add_library(${GPU_MOD_NAME} MODULE ${GPU_WITH_SOABI} "${GPU_SOURCES}")
endif()

if (GPU_LANGUAGE STREQUAL "HIP")
# Make this target dependent on the hipify preprocessor step.
2 changes: 1 addition & 1 deletion csrc/activation_kernels.cu
@@ -1,5 +1,5 @@
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include <torch/all.h>
#include <c10/cuda/CUDAGuard.h>

#include <cmath>
34 changes: 18 additions & 16 deletions csrc/attention/attention_kernels.cu
@@ -17,7 +17,7 @@
* limitations under the License.
*/

#include <torch/extension.h>
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <algorithm>
@@ -808,16 +808,17 @@ void paged_attention_v1(
torch::Tensor&
key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
torch::Tensor&
value_cache, // [num_blocks, num_heads, head_size, block_size]
int num_kv_heads, // [num_heads]
float scale,
value_cache, // [num_blocks, num_heads, head_size, block_size]
int64_t num_kv_heads, // [num_heads]
double scale,
torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq]
torch::Tensor& seq_lens, // [num_seqs]
int block_size, int max_seq_len,
int64_t block_size, int64_t max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, float kv_scale, const int tp_rank,
const int blocksparse_local_blocks, const int blocksparse_vert_stride,
const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
const std::string& kv_cache_dtype, double kv_scale, const int64_t tp_rank,
const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step) {
const bool is_block_sparse = (blocksparse_vert_stride > 1);

DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype,
@@ -972,16 +973,17 @@ void paged_attention_v2(
torch::Tensor&
key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
torch::Tensor&
value_cache, // [num_blocks, num_heads, head_size, block_size]
int num_kv_heads, // [num_heads]
float scale,
value_cache, // [num_blocks, num_heads, head_size, block_size]
int64_t num_kv_heads, // [num_heads]
double scale,
torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq]
torch::Tensor& seq_lens, // [num_seqs]
int block_size, int max_seq_len,
int64_t block_size, int64_t max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, float kv_scale, const int tp_rank,
const int blocksparse_local_blocks, const int blocksparse_vert_stride,
const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
const std::string& kv_cache_dtype, double kv_scale, const int64_t tp_rank,
const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step) {
const bool is_block_sparse = (blocksparse_vert_stride > 1);
DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype,
CALL_V2_LAUNCHER_BLOCK_SIZE)
@@ -990,4 +992,4 @@ void paged_attention_v2(
#undef WARP_SIZE
#undef MAX
#undef MIN
#undef DIVIDE_ROUND_UP
#undef DIVIDE_ROUND_UP
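
The widened scalar parameters in the paged-attention entry points above (int to int64_t, float to double) follow from the same binding change: the dispatcher's schema language only has 64-bit integer and double scalar types, so functions registered through TORCH_LIBRARY must accept those widths and narrow internally. A sketch of the pattern, with made-up names:

```cpp
// Illustrative pattern only (hypothetical names): ops registered with the
// dispatcher take int64_t/double scalars and narrow to the widths the CUDA
// kernel expects.
#include <cstdint>
#include <torch/all.h>

void paged_op(torch::Tensor& out, const torch::Tensor& query,
              int64_t block_size, double scale) {
  // The schema types "int" and "float" map to int64_t and double in C++,
  // so any 32-bit conversion happens here, inside the op.
  const int block_size32 = static_cast<int>(block_size);
  const float scale32 = static_cast<float>(scale);
  // ... dispatch to the templated kernel with block_size32 / scale32 ...
  (void)out; (void)query; (void)block_size32; (void)scale32;
}
```
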
14 changes: 9 additions & 5 deletions csrc/cache.h
@@ -1,21 +1,25 @@
#pragma once

#include <torch/extension.h>
#include <torch/all.h>

#include <map>
#include <vector>

void swap_blocks(torch::Tensor& src, torch::Tensor& dst,
const torch::Tensor& block_mapping);

void copy_blocks(std::vector<torch::Tensor>& key_caches,
std::vector<torch::Tensor>& value_caches,
// Note: the key_caches and value_caches vectors are constant but
// not the Tensors they contain. The vectors need to be const refs
// in order to satisfy pytorch's C++ operator registration code.
void copy_blocks(std::vector<torch::Tensor> const& key_caches,
std::vector<torch::Tensor> const& value_caches,
const torch::Tensor& block_mapping);

void reshape_and_cache(torch::Tensor& key, torch::Tensor& value,
torch::Tensor& key_cache, torch::Tensor& value_cache,
torch::Tensor& slot_mapping,
const std::string& kv_cache_dtype, const float kv_scale);
const std::string& kv_cache_dtype,
const double kv_scale);

void reshape_and_cache_flash(torch::Tensor& key, torch::Tensor& value,
torch::Tensor& key_cache,
@@ -25,4 +29,4 @@ void reshape_and_cache_flash(torch::Tensor& key, torch::Tensor& value,

// Just for unittest
void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache,
const float scale, const std::string& kv_cache_dtype);
const double scale, const std::string& kv_cache_dtype);
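
The new comment in cache.h explains why copy_blocks now takes const references: PyTorch's operator registration binds a Tensor[] schema argument to a const vector of tensors, even though the tensors inside it are mutated. A sketch of how such a declaration pairs with a schema definition (the schema string here is an assumption, not vLLM's exact one):

```cpp
// Sketch only: pairing the const-ref C++ signature with a Tensor[] schema.
#include <torch/library.h>
#include <torch/all.h>
#include <vector>

// Matches the declaration style in cache.h: the vectors are const refs,
// while the alias annotations in the schema mark the contained tensors
// as mutable.
void copy_blocks(std::vector<torch::Tensor> const& key_caches,
                 std::vector<torch::Tensor> const& value_caches,
                 const torch::Tensor& block_mapping);

TORCH_LIBRARY_FRAGMENT(my_cache_ext, ops) {
  ops.def(
      "copy_blocks(Tensor(a!)[] key_caches, Tensor(b!)[] value_caches, "
      "Tensor block_mapping) -> ()");
}
```
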
13 changes: 8 additions & 5 deletions csrc/cache_kernels.cu
@@ -1,4 +1,4 @@
#include <torch/extension.h>
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

@@ -95,8 +95,11 @@ __global__ void copy_blocks_kernel(int64_t* key_cache_ptrs,

} // namespace vllm

void copy_blocks(std::vector<torch::Tensor>& key_caches,
std::vector<torch::Tensor>& value_caches,
// Note: the key_caches and value_caches vectors are constant but
// not the Tensors they contain. The vectors need to be const refs
// in order to satisfy pytorch's C++ operator registration code.
void copy_blocks(std::vector<torch::Tensor> const& key_caches,
std::vector<torch::Tensor> const& value_caches,
const torch::Tensor& block_mapping) {
int num_layers = key_caches.size();
TORCH_CHECK(num_layers == value_caches.size());
@@ -255,7 +258,7 @@ void reshape_and_cache(
torch::Tensor&
value_cache, // [num_blocks, num_heads, head_size, block_size]
torch::Tensor& slot_mapping, // [num_tokens]
const std::string& kv_cache_dtype, const float kv_scale) {
const std::string& kv_cache_dtype, const double kv_scale) {
int num_tokens = key.size(0);
int num_heads = key.size(1);
int head_size = key.size(2);
@@ -334,7 +337,7 @@ __global__ void convert_fp8_kernel(const Tin* __restrict__ src_cache,

// Only for testing.
void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache,
const float kv_scale, const std::string& kv_cache_dtype) {
const double kv_scale, const std::string& kv_cache_dtype) {
torch::Device src_device = src_cache.device();
torch::Device dst_device = dst_cache.device();
TORCH_CHECK(src_device.is_cuda(), "src must be on a GPU")