Try to handle older versions of pytorch
bnellnm committed Oct 5, 2024
1 parent 663874e commit d9f86a9
Showing 1 changed file with 25 additions and 20 deletions.
45 changes: 25 additions & 20 deletions vllm/_custom_ops.py
@@ -25,6 +25,11 @@
import vllm._moe_C # noqa: F401
supports_moe_ops = True

import torch.library
try:
    from torch.library import register_fake
except ImportError:
    from torch.library import impl_abstract as register_fake

def hint_on_error(fn):

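For readers unfamiliar with the pattern, the sketch below shows what this commit's compatibility shim enables: alias register_fake to torch.library.register_fake where it exists, fall back to the older torch.library.impl_abstract otherwise, and then register a fake (meta) kernel through the alias. The toy op mylib::row_sum, its CPU implementation, and the final print are illustrative assumptions only (not vLLM code), and the sketch assumes a reasonably recent PyTorch; the real registrations in this file target ops from the compiled _C / _moe_C extensions.

import torch
import torch.library

# Compatibility shim mirroring the change above: prefer register_fake,
# fall back to the older impl_abstract, which serves the same purpose here.
try:
    from torch.library import register_fake
except ImportError:
    from torch.library import impl_abstract as register_fake

# Hypothetical custom op, used only for illustration.
torch.library.define("mylib::row_sum", "(Tensor x) -> Tensor")


@torch.library.impl("mylib::row_sum", "cpu")
def _row_sum_cpu(x: torch.Tensor) -> torch.Tensor:
    # Real kernel: sums each row.
    return x.sum(dim=1)


@register_fake("mylib::row_sum")
def _row_sum_fake(x: torch.Tensor) -> torch.Tensor:
    # Fake (meta) kernel: only describes the output's shape/dtype/device,
    # just like the *_fake functions registered throughout this file.
    return torch.empty((x.size(0),), dtype=x.dtype, device=x.device)


if __name__ == "__main__":
    print(torch.ops.mylib.row_sum(torch.ones(2, 3)))  # tensor([3., 3.])
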
@@ -266,7 +271,7 @@ def gptq_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,

if hasattr(torch.ops._C, "gptq_gemm"):

@torch.library.register_fake("_C::gptq_gemm")
@register_fake("_C::gptq_gemm")
def _gptq_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
b_gptq_qzeros: torch.Tensor,
b_gptq_scales: torch.Tensor, b_g_idx: torch.Tensor,
@@ -301,15 +306,15 @@ def gptq_marlin_24_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,

if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):

@torch.library.register_fake("_C::gptq_marlin_24_gemm")
@register_fake("_C::gptq_marlin_24_gemm")
def _gptq_marlin_24_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
b_meta: torch.Tensor, b_scales: torch.Tensor,
workspace: torch.Tensor,
b_q_type: ScalarType, size_m: int,
size_n: int, size_k: int) -> torch.Tensor:
return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype)

@torch.library.register_fake("_C::gptq_marlin_gemm")
@register_fake("_C::gptq_marlin_gemm")
def _gptq_marlin_gemm_fake(a: torch.Tensor,
b_q_weight: torch.Tensor,
b_scales: torch.Tensor,
@@ -326,12 +331,12 @@ def _gptq_marlin_gemm_fake(a: torch.Tensor,
use_fp32_reduce: bool = False) -> torch.Tensor:
return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype)

@torch.library.register_fake("_C::ggml_dequantize")
@register_fake("_C::ggml_dequantize")
def _ggml_dequantize_fake(W: torch.Tensor, quant_type: int, m: int,
n: int) -> torch.Tensor:
return torch.empty((m, n), dtype=torch.float16, device=W.device)

@torch.library.register_fake("_C::ggml_mul_mat_vec_a8")
@register_fake("_C::ggml_mul_mat_vec_a8")
def _ggml_mul_mat_vec_a8_fake(
W: torch.Tensor,
X: torch.Tensor,
@@ -340,7 +345,7 @@ def _ggml_mul_mat_vec_a8_fake(
) -> torch.Tensor:
return torch.empty((1, row), dtype=torch.float16, device=W.device)

@torch.library.register_fake("_C::ggml_mul_mat_a8")
@register_fake("_C::ggml_mul_mat_a8")
def _ggml_mul_mat_a8_fake(
W: torch.Tensor,
X: torch.Tensor,
@@ -350,7 +355,7 @@ def _ggml_mul_mat_a8_fake(
batch = X.size(0)
return torch.empty((batch, row), dtype=torch.float16, device=W.device)

@torch.library.register_fake("_C::marlin_qqq_gemm")
@register_fake("_C::marlin_qqq_gemm")
def _marlin_qqq_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
s_tok: torch.Tensor, s_ch: torch.Tensor,
s_group: torch.Tensor, workspace: torch.Tensor,
@@ -360,7 +365,7 @@ def _marlin_qqq_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
dtype=torch.float16,
device=a.device)

@torch.library.register_fake("_C::marlin_gemm")
@register_fake("_C::marlin_gemm")
def _marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
b_scales: torch.Tensor, workspace: torch.Tensor,
size_m: int, size_n: int,
@@ -369,7 +374,7 @@ def _marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
dtype=torch.float16,
device=a.device)

@torch.library.register_fake("_C::awq_dequantize")
@register_fake("_C::awq_dequantize")
def _awq_dequantize_fake(qweight: torch.Tensor, scales: torch.Tensor,
zeros: torch.Tensor, split_k_iters: int, thx: int,
thy: int) -> torch.Tensor:
@@ -380,7 +385,7 @@ def _awq_dequantize_fake(qweight: torch.Tensor, scales: torch.Tensor,
dtype=scales.dtype,
device=scales.device)

@torch.library.register_fake("_C::awq_gemm")
@register_fake("_C::awq_gemm")
def _awq_gemm_fake(input: torch.Tensor, qweight: torch.Tensor,
qzeros: torch.Tensor, scales: torch.Tensor,
split_k_iters: int) -> torch.Tensor:
@@ -389,7 +394,7 @@ def _awq_gemm_fake(input: torch.Tensor, qweight: torch.Tensor,
dtype=input.dtype,
device=input.device).sum(0)

@torch.library.register_fake("_C::aqlm_gemm")
@register_fake("_C::aqlm_gemm")
def _aqlm_gemm_fake(input: torch.Tensor, codes: torch.Tensor,
codebooks: torch.Tensor, scales: torch.Tensor,
codebook_partition_sizes: List[int],
@@ -405,7 +410,7 @@ def _aqlm_gemm_fake(input: torch.Tensor, codes: torch.Tensor,
output_sizes.append(-1)
return flat_output.reshape(tuple(output_sizes))

@torch.library.register_fake("_C::aqlm_dequant")
@register_fake("_C::aqlm_dequant")
def _aqlm_dequant_fake(
codes: torch.Tensor, codebooks: torch.Tensor,
codebook_partition_sizes: List[int]) -> torch.Tensor:
@@ -415,14 +420,14 @@ def _aqlm_dequant_fake(
dtype=codebooks.dtype,
device=codebooks.device)

@torch.library.register_fake("_C::fp8_marlin_gemm")
@register_fake("_C::fp8_marlin_gemm")
def _fp8_marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
b_scales: torch.Tensor, workspace: torch.Tensor,
num_bits: int, size_m: int, size_n: int,
size_k: int) -> torch.Tensor:
return torch.empty((size_m, size_n), dtype=a.dtype, device=a.device)

@torch.library.register_fake("_C::machete_gemm")
@register_fake("_C::machete_gemm")
def machete_gemm_fake(
a: torch.Tensor,
# Should be the tensor returned by machete_prepack_B
@@ -440,13 +445,13 @@ def machete_gemm_fake(
n = b_q.size(1)
return torch.empty((m, n), device=a.device, dtype=a.dtype)

@torch.library.register_fake("_C::machete_prepack_B")
@register_fake("_C::machete_prepack_B")
def machete_prepack_B_fake(b_q_weight: torch.Tensor,
b_type: ScalarType) -> torch.Tensor:
return torch.empty_like(b_q_weight,
memory_format=torch.contiguous_format)

@torch.library.register_fake("_C::causal_conv1d_fwd")
@register_fake("_C::causal_conv1d_fwd")
def causal_conv1d_fwd_fake(x: torch.Tensor, weight: torch.Tensor,
bias_: Optional[torch.Tensor],
conv_states: Optional[torch.Tensor],
Expand All @@ -456,15 +461,15 @@ def causal_conv1d_fwd_fake(x: torch.Tensor, weight: torch.Tensor,
silu_activation: bool) -> torch.Tensor:
return torch.empty_like(x)

@torch.library.register_fake("_C::causal_conv1d_update")
@register_fake("_C::causal_conv1d_update")
def causal_conv1d_update_fake(
x: torch.Tensor, conv_state: torch.Tensor, weight: torch.Tensor,
bias_: Optional[torch.Tensor], silu_activation: bool,
cache_seqlens: Optional[torch.Tensor],
conv_state_indices: Optional[torch.Tensor]) -> torch.Tensor:
return torch.empty_like(x)

@torch.library.register_fake("_C::selective_scan_fwd")
@register_fake("_C::selective_scan_fwd")
def selective_scan_fwd_fake(u: torch.Tensor, delta: torch.Tensor,
A: torch.Tensor, B: torch.Tensor,
C: torch.Tensor, D_: Optional[torch.Tensor],
@@ -639,7 +644,7 @@ def machete_prepack_B(b_q_weight: torch.Tensor,

if hasattr(torch.ops._C, "permute_cols"):

@torch.library.register_fake("_C::permute_cols")
@register_fake("_C::permute_cols")
def _permute_cols_fake(a: torch.Tensor,
perm: torch.Tensor) -> torch.Tensor:
return torch.empty_like(a)
@@ -837,7 +842,7 @@ def topk_softmax(topk_weights: torch.Tensor, topk_ids: torch.Tensor,

if supports_moe_ops and hasattr(torch.ops._moe_C, "marlin_gemm_moe"):

@torch.library.register_fake("_moe_C::marlin_gemm_moe")
@register_fake("_moe_C::marlin_gemm_moe")
def marlin_gemm_moe_fake(a: torch.Tensor, b_q_weights: torch.Tensor,
sorted_ids: torch.Tensor,
topk_weights: torch.Tensor,

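As background for the many *_fake registrations in this diff: fake kernels let the custom ops participate in shape and dtype propagation during tracing (for example under torch.compile or FakeTensorMode) without ever invoking the compiled CUDA kernels. A small illustration follows, reusing the hypothetical mylib::row_sum op from the sketch earlier on this page; it is an editor's example, not part of the commit.

import torch
from torch._subclasses.fake_tensor import FakeTensorMode

# Under FakeTensorMode only the registered fake kernel runs, so just the
# output metadata (shape/dtype/device) is produced.
with FakeTensorMode():
    x = torch.empty(4, 8)
    y = torch.ops.mylib.row_sum(x)
    print(y.shape)  # torch.Size([4])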