From d9f86a9aac3f84c8b20b0f48931d8fda21ea3ade Mon Sep 17 00:00:00 2001
From: Bill Nell
Date: Fri, 4 Oct 2024 20:51:20 -0400
Subject: [PATCH] Try to handle older versions of pytorch

---
 vllm/_custom_ops.py | 45 +++++++++++++++++++++++++--------------------
 1 file changed, 25 insertions(+), 20 deletions(-)

diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py
index 24e008dc3802..d5137fcfc6d4 100644
--- a/vllm/_custom_ops.py
+++ b/vllm/_custom_ops.py
@@ -25,6 +25,11 @@
     import vllm._moe_C  # noqa: F401
     supports_moe_ops = True
 
+import torch.library
+try:
+    from torch.library import register_fake
+except ImportError:
+    from torch.library import impl_abstract as register_fake
 
 
 def hint_on_error(fn):
@@ -266,7 +271,7 @@ def gptq_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
 
 if hasattr(torch.ops._C, "gptq_gemm"):
 
-    @torch.library.register_fake("_C::gptq_gemm")
+    @register_fake("_C::gptq_gemm")
     def _gptq_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
                         b_gptq_qzeros: torch.Tensor,
                         b_gptq_scales: torch.Tensor, b_g_idx: torch.Tensor,
@@ -301,7 +306,7 @@ def gptq_marlin_24_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
 
 if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
 
-    @torch.library.register_fake("_C::gptq_marlin_24_gemm")
+    @register_fake("_C::gptq_marlin_24_gemm")
     def _gptq_marlin_24_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
                                   b_meta: torch.Tensor, b_scales: torch.Tensor,
                                   workspace: torch.Tensor,
@@ -309,7 +314,7 @@ def _gptq_marlin_24_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
                                   size_n: int, size_k: int) -> torch.Tensor:
         return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype)
 
-    @torch.library.register_fake("_C::gptq_marlin_gemm")
+    @register_fake("_C::gptq_marlin_gemm")
     def _gptq_marlin_gemm_fake(a: torch.Tensor,
                                b_q_weight: torch.Tensor,
                                b_scales: torch.Tensor,
@@ -326,12 +331,12 @@ def _gptq_marlin_gemm_fake(a: torch.Tensor,
                                use_fp32_reduce: bool = False) -> torch.Tensor:
         return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype)
 
-    @torch.library.register_fake("_C::ggml_dequantize")
+    @register_fake("_C::ggml_dequantize")
     def _ggml_dequantize_fake(W: torch.Tensor, quant_type: int, m: int,
                               n: int) -> torch.Tensor:
         return torch.empty((m, n), dtype=torch.float16, device=W.device)
 
-    @torch.library.register_fake("_C::ggml_mul_mat_vec_a8")
+    @register_fake("_C::ggml_mul_mat_vec_a8")
     def _ggml_mul_mat_vec_a8_fake(
         W: torch.Tensor,
         X: torch.Tensor,
@@ -340,7 +345,7 @@ def _ggml_mul_mat_vec_a8_fake(
     ) -> torch.Tensor:
         return torch.empty((1, row), dtype=torch.float16, device=W.device)
 
-    @torch.library.register_fake("_C::ggml_mul_mat_a8")
+    @register_fake("_C::ggml_mul_mat_a8")
     def _ggml_mul_mat_a8_fake(
         W: torch.Tensor,
         X: torch.Tensor,
@@ -350,7 +355,7 @@ def _ggml_mul_mat_a8_fake(
         row: int,
     ) -> torch.Tensor:
         batch = X.size(0)
         return torch.empty((batch, row), dtype=torch.float16, device=W.device)
 
-    @torch.library.register_fake("_C::marlin_qqq_gemm")
+    @register_fake("_C::marlin_qqq_gemm")
     def _marlin_qqq_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
                               s_tok: torch.Tensor, s_ch: torch.Tensor,
                               s_group: torch.Tensor, workspace: torch.Tensor,
@@ -360,7 +365,7 @@ def _marlin_qqq_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
                            dtype=torch.float16,
                            device=a.device)
 
-    @torch.library.register_fake("_C::marlin_gemm")
+    @register_fake("_C::marlin_gemm")
     def _marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
                           b_scales: torch.Tensor, workspace: torch.Tensor,
                           size_m: int, size_n: int,
@@ -369,7 +374,7 @@ def _marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
                            dtype=torch.float16,
                            device=a.device)
 
-    @torch.library.register_fake("_C::awq_dequantize")
+    @register_fake("_C::awq_dequantize")
     def _awq_dequantize_fake(qweight: torch.Tensor, scales: torch.Tensor,
                              zeros: torch.Tensor, split_k_iters: int,
                              thx: int, thy: int) -> torch.Tensor:
@@ -380,7 +385,7 @@ def _awq_dequantize_fake(qweight: torch.Tensor, scales: torch.Tensor,
                            dtype=scales.dtype,
                            device=scales.device)
 
-    @torch.library.register_fake("_C::awq_gemm")
+    @register_fake("_C::awq_gemm")
     def _awq_gemm_fake(input: torch.Tensor, qweight: torch.Tensor,
                        qzeros: torch.Tensor, scales: torch.Tensor,
                        split_k_iters: int) -> torch.Tensor:
@@ -389,7 +394,7 @@ def _awq_gemm_fake(input: torch.Tensor, qweight: torch.Tensor,
                            dtype=input.dtype,
                            device=input.device).sum(0)
 
-    @torch.library.register_fake("_C::aqlm_gemm")
+    @register_fake("_C::aqlm_gemm")
     def _aqlm_gemm_fake(input: torch.Tensor, codes: torch.Tensor,
                         codebooks: torch.Tensor, scales: torch.Tensor,
                         codebook_partition_sizes: List[int],
@@ -405,7 +410,7 @@ def _aqlm_gemm_fake(input: torch.Tensor, codes: torch.Tensor,
         output_sizes.append(-1)
         return flat_output.reshape(tuple(output_sizes))
 
-    @torch.library.register_fake("_C::aqlm_dequant")
+    @register_fake("_C::aqlm_dequant")
     def _aqlm_dequant_fake(
             codes: torch.Tensor, codebooks: torch.Tensor,
             codebook_partition_sizes: List[int]) -> torch.Tensor:
@@ -415,14 +420,14 @@ def _aqlm_dequant_fake(
                            dtype=codebooks.dtype,
                            device=codebooks.device)
 
-    @torch.library.register_fake("_C::fp8_marlin_gemm")
+    @register_fake("_C::fp8_marlin_gemm")
     def _fp8_marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
                               b_scales: torch.Tensor, workspace: torch.Tensor,
                               num_bits: int, size_m: int, size_n: int,
                               size_k: int) -> torch.Tensor:
         return torch.empty((size_m, size_n), dtype=a.dtype, device=a.device)
 
-    @torch.library.register_fake("_C::machete_gemm")
+    @register_fake("_C::machete_gemm")
     def machete_gemm_fake(
             a: torch.Tensor,
             # Should be the tensor returned by machete_prepack_B
@@ -440,13 +445,13 @@ def machete_gemm_fake(
         n = b_q.size(1)
         return torch.empty((m, n), device=a.device, dtype=a.dtype)
 
-    @torch.library.register_fake("_C::machete_prepack_B")
+    @register_fake("_C::machete_prepack_B")
     def machete_prepack_B_fake(b_q_weight: torch.Tensor,
                                b_type: ScalarType) -> torch.Tensor:
         return torch.empty_like(b_q_weight,
                                 memory_format=torch.contiguous_format)
 
-    @torch.library.register_fake("_C::causal_conv1d_fwd")
+    @register_fake("_C::causal_conv1d_fwd")
     def causal_conv1d_fwd_fake(x: torch.Tensor, weight: torch.Tensor,
                                bias_: Optional[torch.Tensor],
                                conv_states: Optional[torch.Tensor],
@@ -456,7 +461,7 @@ def causal_conv1d_fwd_fake(x: torch.Tensor, weight: torch.Tensor,
                                silu_activation: bool) -> torch.Tensor:
         return torch.empty_like(x)
 
-    @torch.library.register_fake("_C::causal_conv1d_update")
+    @register_fake("_C::causal_conv1d_update")
     def causal_conv1d_update_fake(
             x: torch.Tensor, conv_state: torch.Tensor, weight: torch.Tensor,
             bias_: Optional[torch.Tensor], silu_activation: bool,
@@ -464,7 +469,7 @@ def causal_conv1d_update_fake(
             conv_state_indices: Optional[torch.Tensor]) -> torch.Tensor:
         return torch.empty_like(x)
 
-    @torch.library.register_fake("_C::selective_scan_fwd")
+    @register_fake("_C::selective_scan_fwd")
     def selective_scan_fwd_fake(u: torch.Tensor, delta: torch.Tensor,
                                 A: torch.Tensor, B: torch.Tensor,
                                 C: torch.Tensor, D_: Optional[torch.Tensor],
@@ -639,7 +644,7 @@ def machete_prepack_B(b_q_weight: torch.Tensor,
 
 if hasattr(torch.ops._C, "permute_cols"):
 
-    @torch.library.register_fake("_C::permute_cols")
@register_fake("_C::permute_cols") def _permute_cols_fake(a: torch.Tensor, perm: torch.Tensor) -> torch.Tensor: return torch.empty_like(a) @@ -837,7 +842,7 @@ def topk_softmax(topk_weights: torch.Tensor, topk_ids: torch.Tensor, if supports_moe_ops and hasattr(torch.ops._moe_C, "marlin_gemm_moe"): - @torch.library.register_fake("_moe_C::marlin_gemm_moe") + @register_fake("_moe_C::marlin_gemm_moe") def marlin_gemm_moe_fake(a: torch.Tensor, b_q_weights: torch.Tensor, sorted_ids: torch.Tensor, topk_weights: torch.Tensor,