From 0d38f0a0d6163a3af2acb124644bad75a64c1836 Mon Sep 17 00:00:00 2001
From: Faraz Shahsavan
Date: Tue, 17 Dec 2024 21:32:20 +0000
Subject: [PATCH] Create the entry + arch structure for the compressor and
 skip 2:4 tests on devices below 9.0 SM capability

---
 CMakeLists.txt                                |  7 ++--
 csrc/cutlass_extensions/common.cpp            | 11 +++++
 csrc/cutlass_extensions/common.hpp            |  2 +
 csrc/quantization/cutlass_w8a8/common.hpp     | 27 ------------
 .../cutlass_w8a8/scaled_mm_entry.cu           | 12 +-----
 ...compressor.cu => sparse_compressor_c3x.cu} |  2 +-
 .../sparse/cutlass/sparse_compressor_entry.cu | 42 +++++++++++++++++++
 csrc/sparse/cutlass/sparse_scaled_mm_entry.cu | 14 ++-----
 tests/kernels/test_semi_structured.py         |  3 ++
 tests/quantization/test_compressed_tensors.py |  9 ++++
 tests/weight_loading/models.txt               |  4 +-
 .../run_model_weight_loading_test.sh          |  4 ++
 tests/weight_loading/test_weight_loading.py   |  7 ++++
 13 files changed, 90 insertions(+), 54 deletions(-)
 create mode 100644 csrc/cutlass_extensions/common.cpp
 delete mode 100644 csrc/quantization/cutlass_w8a8/common.hpp
 rename csrc/sparse/cutlass/{sparse_compressor.cu => sparse_compressor_c3x.cu} (98%)
 create mode 100644 csrc/sparse/cutlass/sparse_compressor_entry.cu

diff --git a/CMakeLists.txt b/CMakeLists.txt
index 38166e9b7f7e9..755b9d2a951dd 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -199,7 +199,8 @@ set(VLLM_EXT_SRC
   "csrc/quantization/gguf/gguf_kernel.cu"
   "csrc/cuda_utils_kernels.cu"
   "csrc/prepare_inputs/advance_step.cu"
-  "csrc/torch_bindings.cpp")
+  "csrc/torch_bindings.cpp"
+  "csrc/cutlass_extensions/common.cpp")
 
 if(VLLM_GPU_LANG STREQUAL "CUDA")
   SET(CUTLASS_ENABLE_HEADERS_ONLY ON CACHE BOOL "Enable only the header library")
@@ -242,7 +243,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
     "csrc/permute_cols.cu"
     "csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu"
     "csrc/sparse/cutlass/sparse_scaled_mm_entry.cu"
-    "csrc/sparse/cutlass/sparse_compressor.cu")
+    "csrc/sparse/cutlass/sparse_compressor_entry.cu")
 
   set_gencode_flags_for_srcs(
     SRCS "${VLLM_EXT_SRC}"
@@ -278,7 +279,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
   cuda_archs_loose_intersection(SCALED_MM_3X_ARCHS "9.0;9.0a" "${CUDA_ARCHS}")
   if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0 AND SCALED_MM_3X_ARCHS)
     set(SRCS "csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu"
-      "csrc/sparse/cutlass/sparse_compressor.cu"
+      "csrc/sparse/cutlass/sparse_compressor_c3x.cu"
       "csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu")
     set_gencode_flags_for_srcs(
       SRCS "${SRCS}"
diff --git a/csrc/cutlass_extensions/common.cpp b/csrc/cutlass_extensions/common.cpp
new file mode 100644
index 0000000000000..3d2093ab94297
--- /dev/null
+++ b/csrc/cutlass_extensions/common.cpp
@@ -0,0 +1,11 @@
+#include "cutlass_extensions/common.hpp"
+
+int32_t get_sm_version_num() {
+  int32_t major_capability, minor_capability;
+  cudaDeviceGetAttribute(&major_capability, cudaDevAttrComputeCapabilityMajor,
+                         0);
+  cudaDeviceGetAttribute(&minor_capability, cudaDevAttrComputeCapabilityMinor,
+                         0);
+  int32_t version_num = major_capability * 10 + minor_capability;
+  return version_num;
+}
\ No newline at end of file
diff --git a/csrc/cutlass_extensions/common.hpp b/csrc/cutlass_extensions/common.hpp
index 11c8486647c5e..a18ebdd8b84fc 100644
--- a/csrc/cutlass_extensions/common.hpp
+++ b/csrc/cutlass_extensions/common.hpp
@@ -30,3 +30,5 @@ inline int get_cuda_max_shared_memory_per_block_opt_in(int const device) {
                          device);
   return max_shared_mem_per_block_opt_in;
 }
+
+int32_t get_sm_version_num();
diff --git a/csrc/quantization/cutlass_w8a8/common.hpp b/csrc/quantization/cutlass_w8a8/common.hpp
deleted file mode 100644
index bf04bb400790f..0000000000000
--- a/csrc/quantization/cutlass_w8a8/common.hpp
+++ /dev/null
@@ -1,27 +0,0 @@
-#pragma once
-
-#include "cutlass/cutlass.h"
-#include <climits>
-
-/**
- * Helper function for checking CUTLASS errors
- */
-#define CUTLASS_CHECK(status)                        \
-  {                                                  \
-    TORCH_CHECK(status == cutlass::Status::kSuccess, \
-                cutlassGetStatusString(status))      \
-  }
-
-inline uint32_t next_pow_2(uint32_t const num) {
-  if (num <= 1) return num;
-  return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1));
-}
-
-inline int get_cuda_max_shared_memory_per_block_opt_in(int const device) {
-  int max_shared_mem_per_block_opt_in = 0;
-  cudaDeviceGetAttribute(&max_shared_mem_per_block_opt_in,
-                         cudaDevAttrMaxSharedMemoryPerBlockOptin,
-                         device);
-  return max_shared_mem_per_block_opt_in;
-}
-
diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu
index 97a969cf5e3e0..4f7b6588ef3f7 100644
--- a/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu
+++ b/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu
@@ -3,6 +3,8 @@
 #include <c10/cuda/CUDAGuard.h>
 #include <torch/all.h>
 
+#include "cutlass_extensions/common.hpp"
+
 void cutlass_scaled_mm_sm75(torch::Tensor& c, torch::Tensor const& a,
                             torch::Tensor const& b,
                             torch::Tensor const& a_scales,
@@ -79,16 +81,6 @@ bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability) {
   return false;
 }
 
-int32_t get_sm_version_num() {
-  int32_t major_capability, minor_capability;
-  cudaDeviceGetAttribute(&major_capability, cudaDevAttrComputeCapabilityMajor,
-                         0);
-  cudaDeviceGetAttribute(&minor_capability, cudaDevAttrComputeCapabilityMinor,
-                         0);
-  int32_t version_num = major_capability * 10 + minor_capability;
-  return version_num;
-}
-
 void cutlass_scaled_mm(torch::Tensor& c, torch::Tensor const& a,
                        torch::Tensor const& b, torch::Tensor const& a_scales,
                        torch::Tensor const& b_scales,
diff --git a/csrc/sparse/cutlass/sparse_compressor.cu b/csrc/sparse/cutlass/sparse_compressor_c3x.cu
similarity index 98%
rename from csrc/sparse/cutlass/sparse_compressor.cu
rename to csrc/sparse/cutlass/sparse_compressor_c3x.cu
index 3e31e19d52e1f..7367553c73c42 100644
--- a/csrc/sparse/cutlass/sparse_compressor.cu
+++ b/csrc/sparse/cutlass/sparse_compressor_c3x.cu
@@ -146,7 +146,7 @@ bool cutlass_sparse_compress(torch::Tensor& a_nzs, torch::Tensor& a_meta,
   return true;
 }
 
-bool cutlass_sparse_compress_entry(torch::Tensor& a_nzs, torch::Tensor& a_meta,
+bool cutlass_sparse_compress_sm90(torch::Tensor& a_nzs, torch::Tensor& a_meta,
                                   torch::Tensor const& a) {
   if (a.dtype() == torch::kBFloat16) {
     return cutlass_sparse_compress(a_nzs, a_meta,
diff --git a/csrc/sparse/cutlass/sparse_compressor_entry.cu b/csrc/sparse/cutlass/sparse_compressor_entry.cu
new file mode 100644
index 0000000000000..532c35faae7dc
--- /dev/null
+++ b/csrc/sparse/cutlass/sparse_compressor_entry.cu
@@ -0,0 +1,42 @@
+#include <cudaTypedefs.h>
+
+#include <c10/cuda/CUDAGuard.h>
+#include <torch/all.h>
+
+#include "cutlass_extensions/common.hpp"
+
+#if defined ENABLE_SCALED_MM_C3X && ENABLE_SCALED_MM_C3X
+bool cutlass_sparse_compress_sm90(torch::Tensor& a_nzs, torch::Tensor& a_meta,
+                                  torch::Tensor const& a);
+#endif
+
+bool cutlass_sparse_compress_entry(torch::Tensor& a_nzs, torch::Tensor& a_meta,
+                                   torch::Tensor const& a) {
+  // Checks for conformality
+  TORCH_CHECK(a.dim() == 2 && a_meta.dim() == 2 && a_nzs.dim() == 2);
+  TORCH_CHECK(a.size(0) == a_nzs.size(0) && a.size(0) == a_meta.size(0) &&
+              a_nzs.size(1) * 2 == a.size(1) &&
+              a_meta.size(1) * 2 * 4 == a.size(1));
+  // Considering elemsPerMetaElem = 8b / 2b_per_nz = 4
+
+  // Check for strides and alignment
+  TORCH_CHECK(a.stride(1) == 1 && a_nzs.stride(1) == 1 &&
+              a_meta.stride(1) == 1);  // Row-major
+  TORCH_CHECK(a.stride(0) % 8 == 0);   // 8 Byte Alignment for Compression
+
+  at::cuda::OptionalCUDAGuard const device_guard(device_of(a));
+  int32_t version_num = get_sm_version_num();
+
+  // Guard against compilation issues for sm90 kernels
+#if defined ENABLE_SCALED_MM_C3X && ENABLE_SCALED_MM_C3X
+  if (version_num >= 90) {
+    return cutlass_sparse_compress_sm90(a_nzs, a_meta, a);
+  }
+#endif
+
+  TORCH_CHECK_NOT_IMPLEMENTED(
+      false,
+      "No compiled cutlass_scaled_sparse_mm for a compute capability less than "
+      "CUDA device capability: ",
+      version_num);
+}
diff --git a/csrc/sparse/cutlass/sparse_scaled_mm_entry.cu b/csrc/sparse/cutlass/sparse_scaled_mm_entry.cu
index 6e067e6529cf9..4c930b603c9e4 100644
--- a/csrc/sparse/cutlass/sparse_scaled_mm_entry.cu
+++ b/csrc/sparse/cutlass/sparse_scaled_mm_entry.cu
@@ -3,6 +3,8 @@
 #include <c10/cuda/CUDAGuard.h>
 #include <torch/all.h>
 
+#include "cutlass_extensions/common.hpp"
+
 #if defined ENABLE_SCALED_MM_C3X && ENABLE_SCALED_MM_C3X
 void cutlass_scaled_sparse_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
                                    torch::Tensor const& b,
@@ -12,16 +14,6 @@ void cutlass_scaled_sparse_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
                                    c10::optional<torch::Tensor> const& bias);
 #endif
 
-int32_t test_get_sm_version_num() {
-  int32_t major_capability, minor_capability;
-  cudaDeviceGetAttribute(&major_capability, cudaDevAttrComputeCapabilityMajor,
-                         0);
-  cudaDeviceGetAttribute(&minor_capability, cudaDevAttrComputeCapabilityMinor,
-                         0);
-  int32_t version_num = major_capability * 10 + minor_capability;
-  return version_num;
-}
-
 void cutlass_scaled_sparse_mm(torch::Tensor& c, torch::Tensor const& a,
                               torch::Tensor const& bt_nzs,
                               torch::Tensor const& bt_meta,
@@ -48,7 +40,7 @@ void cutlass_scaled_sparse_mm(torch::Tensor& c, torch::Tensor const& a,
   }
 
   at::cuda::OptionalCUDAGuard const device_guard(device_of(a));
-  int32_t version_num = test_get_sm_version_num();
+  int32_t version_num = get_sm_version_num();
 
   // Guard against compilation issues for sm90 kernels
 #if defined ENABLE_SCALED_MM_C3X && ENABLE_SCALED_MM_C3X
diff --git a/tests/kernels/test_semi_structured.py b/tests/kernels/test_semi_structured.py
index dd9f444ed504f..34244a8fe4ca7 100644
--- a/tests/kernels/test_semi_structured.py
+++ b/tests/kernels/test_semi_structured.py
@@ -4,6 +4,7 @@
 """
 from typing import Optional, Tuple, Type
 
+import pytest
 import torch
 
 from vllm import _custom_ops as ops
@@ -101,6 +102,8 @@ def baseline_scaled_mm(a: torch.Tensor,
     return output
 
 
+@pytest.mark.skipif(not current_platform.has_device_capability(90),
+                    reason="Sparse FP8 is not yet supported on this GPU type.")
 # Test working with a subset of A and B for sparse matmul
 def test_cutlass_sparse_subset():
     big_m = 1024
diff --git a/tests/quantization/test_compressed_tensors.py b/tests/quantization/test_compressed_tensors.py
index 557d72172f346..21fec990aa873 100644
--- a/tests/quantization/test_compressed_tensors.py
+++ b/tests/quantization/test_compressed_tensors.py
@@ -14,6 +14,7 @@
     CompressedTensorsW4A16Sparse24, CompressedTensorsW8A8Fp8,
     CompressedTensorsW8A8Int8, CompressedTensorsW8A16Fp8,
     CompressedTensorsWNA16)
+from vllm.platforms import current_platform
 
 
 @pytest.mark.parametrize(
@@ -211,6 +212,8 @@ def test_compressed_tensors_kv_cache(vllm_runner):
         assert output
 
 
+@pytest.mark.skipif(not current_platform.has_device_capability(90),
+                    reason="Sparse FP8 is not yet supported on this GPU type.")
 def _test_2of4_quant_models(qkv_proj, weight_strategy, input_strategy):
     assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod)
     assert isinstance(qkv_proj.scheme, CompressedTensors24)
@@ -224,6 +227,8 @@ def _test_2of4_quant_models(qkv_proj, weight_strategy, input_strategy):
     assert sparsity_map.get("Linear").sparsity_structure == "2:4"
 
 
+@pytest.mark.skipif(not current_platform.has_device_capability(90),
+                    reason="Sparse FP8 is not yet supported on this GPU type.")
 @pytest.mark.parametrize("args_2of4", [
     ("nm-testing/Meta-Llama-3-8B-Instruct-FP8-Dynamic-2of4-testing", "channel",
      "token"),
@@ -249,6 +254,8 @@ def test_compressed_tensors_2of4_quant_fp8(vllm_runner, args_2of4):
         assert output
 
 
+@pytest.mark.skipif(not current_platform.has_device_capability(90),
+                    reason="Sparse FP8 is not yet supported on this GPU type.")
 @pytest.mark.parametrize("args_2of4", [
     ("nm-testing/TinyLlama-1.1B-Chat-v1.0-INT8-Dynamic-IA-Per-Channel-Weight-testing",
      "channel", "token"),
@@ -272,6 +279,8 @@ def test_compressed_tensors_2of4_quant_int8(vllm_runner, args_2of4):
         assert output
 
 
+@pytest.mark.skipif(not current_platform.has_device_capability(90),
+                    reason="Sparse FP8 is not yet supported on this GPU type.")
 @pytest.mark.parametrize(
     "args_2of4",
     [("nm-testing/TinyLlama-1.1B-Chat-v1.0-2of4-Sparse-Dense-Compressor")])
diff --git a/tests/weight_loading/models.txt b/tests/weight_loading/models.txt
index 9363a5fef0e0f..a06956ce18a93 100644
--- a/tests/weight_loading/models.txt
+++ b/tests/weight_loading/models.txt
@@ -21,8 +21,8 @@ compressed-tensors, nm-testing/Phi-3-mini-128k-instruct-FP8, main
 compressed-tensors, neuralmagic/Phi-3-medium-128k-instruct-quantized.w4a16, main
 compressed-tensors, nm-testing/TinyLlama-1.1B-Chat-v1.0-actorder-group, main
 compressed-tensors, mgoin/DeepSeek-Coder-V2-Lite-Instruct-FP8, main
-compressed-tensors, nm-testing/SparseLlama-3.1-8B-gsm8k-pruned.2of4-FP8-Dynamic-testing, main
-compressed-tensors, nm-testing/SparseLlama-3.1-8B-gsm8k-pruned.2of4-W8A8-testing, main
+compressed-tensors, nm-testing/SparseLlama-3.1-8B-gsm8k-pruned.2of4-FP8-Dynamic-testing, main, 90
+compressed-tensors, nm-testing/SparseLlama-3.1-8B-gsm8k-pruned.2of4-W8A8-testing, main, 90
 awq, casperhansen/mixtral-instruct-awq, main
 awq_marlin, casperhansen/mixtral-instruct-awq, main
 fp8, neuralmagic/Meta-Llama-3-8B-Instruct-FP8-KV, main
diff --git a/tests/weight_loading/run_model_weight_loading_test.sh b/tests/weight_loading/run_model_weight_loading_test.sh
index a4d0c44c22b51..693128640e07d 100755
--- a/tests/weight_loading/run_model_weight_loading_test.sh
+++ b/tests/weight_loading/run_model_weight_loading_test.sh
@@ -26,6 +26,10 @@ do
   export QUANTIZATION=${array[0]}
   export MODEL_NAME=${array[1]}
  export REVISION=${array[2]}
+  # If array length is larger than 3, then MIN_CAPABILITY is provided
+  if [ ${#array[@]} -gt 3 ]; then
+    export MIN_CAPABILITY=${array[3]}
+  fi
   pytest -s weight_loading/test_weight_loading.py || LOCAL_SUCCESS=$?
 
   if [[ $LOCAL_SUCCESS == 0 ]]; then
diff --git a/tests/weight_loading/test_weight_loading.py b/tests/weight_loading/test_weight_loading.py
index d8bca05e204c0..199731bdc21fe 100644
--- a/tests/weight_loading/test_weight_loading.py
+++ b/tests/weight_loading/test_weight_loading.py
@@ -1,14 +1,21 @@
 import os
 
+import pytest
 import torch
 
+from vllm.platforms import current_platform
+
 MAX_MODEL_LEN = 1024
 MODEL_NAME = os.environ.get("MODEL_NAME",
                             "robertgshaw2/zephyr-7b-beta-channelwise-gptq")
 REVISION = os.environ.get("REVISION", "main")
 QUANTIZATION = os.environ.get("QUANTIZATION", "gptq_marlin")
+MIN_CAPABILITY = os.environ.get("MIN_CAPABILITY", "89")
 
 
+@pytest.mark.skipif(
+    not current_platform.has_device_capability(int(MIN_CAPABILITY)),
+    reason="Current system does not have minimum capability.")
 def test_weight_loading(vllm_runner):
     """
     Test parameter weight loading with tp>1.
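
Note on the conformality checks added in csrc/sparse/cutlass/sparse_compressor_entry.cu: the 2:4 layout keeps two of every four values per row, and each 8-bit metadata element stores 2 bits per kept value (elemsPerMetaElem = 4), so for a row-major (M, K) input a_nzs must have K/2 columns and a_meta K/8 columns. The Python sketch below only illustrates that shape arithmetic; the helper name is hypothetical and is not part of this patch.

    import torch


    def expected_compressed_shapes(a: torch.Tensor):
        """Expected a_nzs / a_meta shapes for 2:4 compression of a row-major
        (M, K) tensor, mirroring the TORCH_CHECKs in cutlass_sparse_compress_entry.
        Illustrative sketch only, not part of this patch."""
        m, k = a.shape
        assert k % 8 == 0, "K must be divisible by 8 for the metadata layout"
        nzs_shape = (m, k // 2)   # 2:4 sparsity keeps 2 of every 4 values
        meta_shape = (m, k // 8)  # elemsPerMetaElem = 8b / 2b_per_nz = 4
        return nzs_shape, meta_shape


    # Example: a 16x128 bf16 input yields a 16x64 values tensor and 16x16 metadata.
    print(expected_compressed_shapes(torch.empty(16, 128, dtype=torch.bfloat16)))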