diff --git a/csrc/quantization/fp8/common.cuh b/csrc/quantization/fp8/common.cuh index ca192b1db6528..15bd5b6ed1564 100644 --- a/csrc/quantization/fp8/common.cuh +++ b/csrc/quantization/fp8/common.cuh @@ -3,6 +3,7 @@ #include "quantization/vectorization.cuh" #include +#include #ifndef USE_ROCM #include @@ -17,6 +18,7 @@ using FP8_TYPE = c10::Float8_e4m3fnuz; // issue when running dynamic quantization. Here use 224.0f for rocm. constexpr auto FP8_E4M3_MAX = 224.0f; #endif +constexpr static auto kFp8Type = c10::CppTypeToScalarType::value; namespace vllm { diff --git a/csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu b/csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu index 7e2c8f9f83a7e..3c4f183bf4b59 100644 --- a/csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu +++ b/csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu @@ -144,11 +144,11 @@ void rms_norm_dynamic_per_token_quant( torch::Tensor& scales, // [num_tokens] double const var_epsilon, // Variance epsilon used in norm calculation std::optional scale_ub, std::optional residual) { - TORCH_CHECK(out.dtype() == torch::kFloat8_e4m3fn || - out.dtype() == torch::kInt8); + TORCH_CHECK(out.dtype() == kFp8Type || out.dtype() == torch::kInt8); + TORCH_CHECK(out.is_contiguous() && input.is_contiguous()); if (scale_ub.has_value()) { - TORCH_CHECK(out.dtype() == torch::kFloat8_e4m3fn); + TORCH_CHECK(out.dtype() == kFp8Type); } TORCH_CHECK(scales.dtype() == torch::kFloat32); diff --git a/csrc/quantization/fused_kernels/layernorm_utils.cuh b/csrc/quantization/fused_kernels/layernorm_utils.cuh index f729fe58c0c08..cec6b54edb569 100644 --- a/csrc/quantization/fused_kernels/layernorm_utils.cuh +++ b/csrc/quantization/fused_kernels/layernorm_utils.cuh @@ -20,7 +20,7 @@ template __device__ void compute_rms(float* rms, scalar_t const* __restrict__ input, int32_t const hidden_size, float const epsilon, scalar_t const* __restrict__ residual = nullptr) { - int64_t const token_offset = blockIdx.x * hidden_size; + int64_t const token_offset = blockIdx.x * static_cast(hidden_size); // sum of squares float ss = 0.0f; @@ -53,7 +53,8 @@ __device__ void compute_dynamic_per_token_scales( float const rms, float const* __restrict__ scale_ub, float const min_scaling_factor, int32_t const hidden_size, scalar_t const* __restrict__ residual = nullptr) { - int64_t const token_offset = blockIdx.x * hidden_size; + int64_t const token_offset = blockIdx.x * static_cast(hidden_size); + ; constexpr scalar_out_t qmax{std::numeric_limits::max()}; float block_absmax_val_maybe = 0.0f; @@ -99,7 +100,8 @@ __device__ void norm_and_quant(scalar_out_t* __restrict__ output, float const rms, float const scale, int32_t const hidden_size, scalar_t* __restrict__ residual = nullptr) { - int64_t const token_offset = blockIdx.x * hidden_size; + int64_t const token_offset = blockIdx.x * static_cast(hidden_size); + ; for (int32_t i = threadIdx.x; i < hidden_size; i += blockDim.x) { float x = static_cast(input[token_offset + i]); @@ -123,7 +125,7 @@ template __device__ void compute_rms(float* rms, scalar_t const* __restrict__ input, int32_t const hidden_size, float const epsilon, scalar_t const* __restrict__ residual = nullptr) { - int64_t const token_offset = blockIdx.x * hidden_size; + int64_t const token_offset = blockIdx.x * static_cast(hidden_size); // Vectorized input/output to better utilize memory bandwidth. vec4_t const* vec_input = @@ -184,7 +186,8 @@ __device__ void compute_dynamic_per_token_scales( float const rms, float const* __restrict__ scale_ub, float const min_scaling_factor, int32_t const hidden_size, scalar_t const* __restrict__ residual = nullptr) { - int64_t const token_offset = blockIdx.x * hidden_size; + int64_t const token_offset = blockIdx.x * static_cast(hidden_size); + ; // Vectorized input/weight/residual to better utilize memory bandwidth. vec4_t const* vec_input = @@ -263,7 +266,8 @@ __device__ void norm_and_quant(scalar_out_t* __restrict__ output, float const rms, float const scale, int32_t const hidden_size, scalar_t* __restrict__ residual = nullptr) { - int64_t const token_offset = blockIdx.x * hidden_size; + int64_t const token_offset = blockIdx.x * static_cast(hidden_size); + ; // Vectorized input/output/weight/residual to better utilize memory bandwidth. vec4_t const* vec_input = diff --git a/tests/kernels/test_fused_quant_layernorm.py b/tests/kernels/test_fused_quant_layernorm.py index ff8e807ecb600..baf8d73fdbffb 100644 --- a/tests/kernels/test_fused_quant_layernorm.py +++ b/tests/kernels/test_fused_quant_layernorm.py @@ -9,10 +9,15 @@ DTYPES = [torch.bfloat16, torch.float] QUANT_DTYPES = [torch.int8, torch.float8_e4m3fn] -NUM_TOKENS = [1, 7, 83, 2048, 4096] # Arbitrary values for testing -HIDDEN_SIZES = [1, 3, 4, 16, 64, 2048, 5120, - 5137] # Arbitrary values for testing -HIDDEN_SIZES += list(range(1024, 1033)) # vectorized conversion edge cases +VEC_HIDDEN_SIZES = range(1024, 1030) +# Avoid combinatorial explosion with full Cartesian product +NUM_TOKENS_HIDDEN_SIZES = [ + *[(1, i) for i in [1, 64, *VEC_HIDDEN_SIZES, 5120, 5137]], + *[(83, i) for i in [1, 1033, 2048, 5120]], + *[(2048, i) for i in [1, 64, *VEC_HIDDEN_SIZES, 5137]], + *[(4096, i) for i in [1, 64, 5137]], +] + ADD_RESIDUAL = [False, True] SCALE_UBS = [True, False] SEEDS = [0] @@ -100,8 +105,7 @@ def ops_impl(weight: torch.Tensor, scale_ub) -@pytest.mark.parametrize("num_tokens", NUM_TOKENS) -@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) +@pytest.mark.parametrize("num_tokens, hidden_size", NUM_TOKENS_HIDDEN_SIZES) @pytest.mark.parametrize("add_residual", ADD_RESIDUAL) @pytest.mark.parametrize("scale_ub", SCALE_UBS) @pytest.mark.parametrize("dtype", DTYPES)