diff --git a/benchmarks/cutlass_benchmarks/utils.py b/benchmarks/cutlass_benchmarks/utils.py
index 49e8a1bdfe2ef..2f7ccee5ddb36 100644
--- a/benchmarks/cutlass_benchmarks/utils.py
+++ b/benchmarks/cutlass_benchmarks/utils.py
@@ -52,11 +52,6 @@ def make_rand_sparse_tensors(dtype: torch.dtype, m: int, n: int,
     a = torch.randn((m, k), device='cuda') * 5
     b = torch.randn((n, k), device='cuda').t() * 5
 
-    # # Initialize a to all ones
-    # a = torch.ones((m, k), device='cuda')
-    # # Initialize b to all ones
-    # b = torch.ones((n, k), device='cuda')
-
     b = prune_to_2_4(b.t()).t()
 
     if dtype == torch.int8:
diff --git a/csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu b/csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu
index 53a9804d69ff2..a4f5eaed4134f 100644
--- a/csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu
+++ b/csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu
@@ -297,7 +297,8 @@ void cutlass_scaled_sparse_mm_sm90_epilogue(torch::Tensor& out,
   }
 }
 
-void cutlass_scaled_sparse_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
+void cutlass_scaled_sparse_mm_sm90(torch::Tensor& out,
+                                   torch::Tensor const& a,
                                    torch::Tensor const& e,
                                    torch::Tensor const& b,
                                    torch::Tensor const& a_scales,
@@ -306,36 +307,35 @@ void cutlass_scaled_sparse_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
   TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
   TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
   if (bias) {
-    TORCH_CHECK(bias->dtype() == c.dtype(),
-                "currently bias dtype must match output dtype ", c.dtype());
+    TORCH_CHECK(bias->dtype() == out.dtype(),
+                "currently bias dtype must match output dtype ", out.dtype());
     return cutlass_scaled_sparse_mm_sm90_epilogue<c3x::ScaledEpilogueBias>(
-        c, a, e, b, a_scales, b_scales, *bias);
+        out, a, e, b, a_scales, b_scales, *bias);
   } else {
     return cutlass_scaled_sparse_mm_sm90_epilogue<c3x::ScaledEpilogue>(
-        c, a, e, b, a_scales, b_scales);
+        out, a, e, b, a_scales, b_scales);
   }
 }
 
-// void cutlass_scaled_sparse_mm_azp_sm90(torch::Tensor& out, torch::Tensor
-// const& a,
-//                                  torch::Tensor const& e,
-//                                  torch::Tensor const& b,
-//                                  torch::Tensor const& a_scales,
-//                                  torch::Tensor const& b_scales,
-//                                  torch::Tensor const& azp_adj,
-//                                  c10::optional<torch::Tensor> const& azp,
-//                                  c10::optional<torch::Tensor> const& bias) {
-//   TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
-//   TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
-
-//   if (azp) {
-//     return
-//     cutlass_scaled_sparse_mm_sm90_epilogue<c3x::ScaledEpilogueBiasAzpToken>(
-//         out, a, e, b, a_scales, b_scales, azp_adj, *azp, bias);
-//   } else {
-//     return cutlass_scaled_sparse_mm_sm90_epilogue<c3x::ScaledEpilogueBiasAzp>(
-//         out, a, e, b, a_scales, b_scales, azp_adj, bias);
-//   }
-// }
+void cutlass_scaled_sparse_mm_azp_sm90(torch::Tensor& out,
+                                       torch::Tensor const& a,
+                                       torch::Tensor const& e,
+                                       torch::Tensor const& b,
+                                       torch::Tensor const& a_scales,
+                                       torch::Tensor const& b_scales,
+                                       torch::Tensor const& azp_adj,
+                                       c10::optional<torch::Tensor> const& azp,
+                                       c10::optional<torch::Tensor> const& bias) {
+  TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
+  TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
+
+  if (azp) {
+    return cutlass_scaled_sparse_mm_sm90_epilogue<c3x::ScaledEpilogueBiasAzpToken>(
+        out, a, e, b, a_scales, b_scales, azp_adj, *azp, bias);
+  } else {
+    return cutlass_scaled_sparse_mm_sm90_epilogue<c3x::ScaledEpilogueBiasAzp>(
+        out, a, e, b, a_scales, b_scales, azp_adj, bias);
+  }
+}
 
 #endif
diff --git a/csrc/sparse/cutlass/sparse_scaled_mm_c3x.cuh b/csrc/sparse/cutlass/sparse_scaled_mm_c3x.cuh
index 76bd3fadd90a1..e66c24627e067 100644
--- a/csrc/sparse/cutlass/sparse_scaled_mm_c3x.cuh
+++ b/csrc/sparse/cutlass/sparse_scaled_mm_c3x.cuh
@@ -364,8 +364,6 @@ struct cutlass_3x_gemm {
   using ElementAB = ElementAB_;
   using ElementD = ElementD_;
   using ElementAcc = AccType;
-  // typename std::conditional<std::is_same_v<ElementAB, int8_t>, int32_t,
-  //                           float>::type;
 
   using EpilogueDescriptor =
       cutlass::epilogue::collective::detail::EpilogueDescriptor<
@@ -432,9 +430,6 @@ void cutlass_sparse_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
   int64_t ldb = b.stride(1);
   int64_t ldc = out.stride(1);
 
-  // using StrideB = Stride<int64_t, Int<1>, int64_t>;
-  // using StrideC = typename Gemm::StrideC;
-
   using LayoutA = typename Gemm::GemmKernel::CollectiveMainloop::LayoutA;
   using LayoutE = typename Gemm::GemmKernel::CollectiveMainloop::LayoutE;
   using StrideB = typename Gemm::GemmKernel::StrideB;
diff --git a/csrc/sparse/cutlass/sparse_scaled_mm_entry.cu b/csrc/sparse/cutlass/sparse_scaled_mm_entry.cu
index 9c2aed2eb3079..8a92d3a598964 100644
--- a/csrc/sparse/cutlass/sparse_scaled_mm_entry.cu
+++ b/csrc/sparse/cutlass/sparse_scaled_mm_entry.cu
@@ -52,8 +52,8 @@ void cutlass_scaled_sparse_mm(torch::Tensor& c, torch::Tensor const& a,
   // Check for strides and alignment
   TORCH_CHECK(a.stride(1) == 1);                      // Row-major
-  // TORCH_CHECK(b.stride(0) == 1 && c.stride(0) == 1);  // Column-major
-  // TORCH_CHECK(c.stride(0) % 16 == 0);  // 16 Byte Alignment
+  TORCH_CHECK(b.stride(0) == 1 && c.stride(0) == 1);  // Column-major
+  TORCH_CHECK(c.stride(1) % 16 == 0);                 // 16 Byte Alignment
   TORCH_CHECK(b.stride(1) % 16 == 0);                 // 16 Byte Alignment
   TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py
index 12e1b26d87081..5cd0059a4df89 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py
@@ -163,11 +163,11 @@ def apply_weights(self,
             input_scale = layer.input_scale
             q_input = x
 
-        out = ops.cutlass_scaled_sparse_mm(a=layer.weight,
+        out = ops.cutlass_scaled_sparse_mm(a=q_input,
+                                           b=layer.weight,
                                            e=layer.meta,
-                                           b=q_input.t(),
-                                           scale_a=layer.weight_scale,
-                                           scale_b=input_scale,
+                                           scale_a=input_scale,
+                                           scale_b=layer.weight_scale,
                                            out_dtype=self.output_dtype,
                                            bias=bias)
         assert out.is_contiguous()
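
Illustrative sketch (not part of the diff above): the re-enabled checks in sparse_scaled_mm_entry.cu require a row-major activation, column-major b and c, and leading strides divisible by 16. The PyTorch snippet below constructs tensors whose strides satisfy exactly those conditions; the shapes, dtypes, and variable names are assumptions chosen for illustration, only the stride conditions mirror the TORCH_CHECKs in the hunk.

    # Sketch: tensor layouts that pass the stride checks in cutlass_scaled_sparse_mm.
    # Shapes and dtypes are illustrative; only the asserted conditions come from the patch.
    import torch

    m, n, k = 32, 64, 128

    a = torch.empty(m, k, dtype=torch.int8, device="cuda")          # row-major activation
    b = torch.empty(n, k, dtype=torch.int8, device="cuda").t()      # column-major view, shape (k, n)
    c = torch.empty(n, m, dtype=torch.bfloat16, device="cuda").t()  # column-major output, shape (m, n)

    assert a.stride(1) == 1                                  # Row-major
    assert b.stride(0) == 1 and c.stride(0) == 1             # Column-major
    assert c.stride(1) % 16 == 0 and b.stride(1) % 16 == 0   # aligned leading dimensions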
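
Illustrative sketch (not part of the diff above): the last hunk passes the quantized activation as a and the 2:4-compressed weight as b, with scale_a/scale_b reordered to match. The snippet below shows that routing in isolation; _FakeLayer and apply_sparse_gemm are hypothetical stand-ins, and it assumes ops is vLLM's _custom_ops module (the actual binding of ops in the scheme file is not shown in this diff). Only the keyword names and their activation/weight pairing come from the patched code.

    # Sketch of the corrected argument routing: activation -> a / scale_a,
    # compressed weight -> b / scale_b. _FakeLayer stands in for the quantized layer.
    from dataclasses import dataclass

    import torch
    from vllm import _custom_ops as ops  # assumed binding for `ops`

    @dataclass
    class _FakeLayer:
        weight: torch.Tensor        # 2:4-compressed weight values
        meta: torch.Tensor          # 2:4 sparsity metadata
        weight_scale: torch.Tensor
        input_scale: torch.Tensor

    def apply_sparse_gemm(layer: _FakeLayer, q_input: torch.Tensor,
                          out_dtype: torch.dtype, bias=None) -> torch.Tensor:
        # Activation first, compressed weight second; scales follow the same order.
        return ops.cutlass_scaled_sparse_mm(a=q_input,
                                            b=layer.weight,
                                            e=layer.meta,
                                            scale_a=layer.input_scale,
                                            scale_b=layer.weight_scale,
                                            out_dtype=out_dtype,
                                            bias=bias)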