From 7046e4bf492fe5f77cb57ce181cb6c1c2301220d Mon Sep 17 00:00:00 2001 From: luka Date: Wed, 4 Dec 2024 22:43:24 +0000 Subject: [PATCH] PR comments Signed-off-by: luka --- tests/kernels/test_fused_quant_layernorm.py | 4 ++-- vllm/_custom_ops.py | 10 +++++----- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/tests/kernels/test_fused_quant_layernorm.py b/tests/kernels/test_fused_quant_layernorm.py index 15015063658ab..3997f4e9b8fe9 100644 --- a/tests/kernels/test_fused_quant_layernorm.py +++ b/tests/kernels/test_fused_quant_layernorm.py @@ -9,8 +9,8 @@ DTYPES = [torch.bfloat16, torch.float] QUANT_DTYPES = [torch.int8, torch.float8_e4m3fn] NUM_TOKENS = [1, 7, 83, 2048, 4096] # Arbitrary values for testing -HIDDEN_SIZES = [1, 2, 3, 4, 16, 64, 67, 768, 2048, 5120, 5137, 8192, - 8193] # Arbitrary values for testing +HIDDEN_SIZES = [1, 3, 4, 16, 64, 2048, 5120, + 5137] # Arbitrary values for testing HIDDEN_SIZES += list(range(1024, 1033)) # vectorized conversion edge cases ADD_RESIDUAL = [False, True] SCALE_UBS = [True, False] diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 3808fb9a87e56..8d5dfebc4c03b 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -22,7 +22,6 @@ supports_moe_ops = False with contextlib.suppress(ImportError): import vllm._moe_C # noqa: F401 - supports_moe_ops = True # neuron has torch version that doesn't even have impl_abstract @@ -242,6 +241,7 @@ def advance_step_flashinfer(num_seqs: int, num_queries: int, block_size: int, paged_kv_indptr: torch.Tensor, paged_kv_last_page_len: torch.Tensor, block_table_bound: torch.Tensor) -> None: + return torch.ops._C.advance_step_flashinfer( num_seqs, num_queries, block_size, input_tokens, sampled_token_ids, input_positions, seq_lens, slot_mapping, block_tables, @@ -737,7 +737,7 @@ def scaled_fp8_quant( shape: Union[Tuple[int, int], torch.Size] = input.shape # For rocm, the output fp8 dtype is torch.float_e3m3fnuz out_dtype: torch.dtype = torch.float8_e4m3fnuz \ - if current_platform.is_rocm() else torch.float8_e4m3fn + if current_platform.is_rocm() else torch.float8_e4m3fn if num_token_padding: shape = (max(num_token_padding, input.shape[0]), shape[1]) output = torch.empty(shape, device=input.device, dtype=out_dtype) @@ -1020,9 +1020,9 @@ def register_graph_buffers(fa: int, handles: List[List[int]], # the case when users use `import __annotations__` to turn type # hints into strings. if isinstance(v, fn_type) \ - and v.__code__.co_filename == __file__ \ - and any(arg is torch.Tensor or arg == "torch.Tensor" - for arg in v.__annotations__.values()): + and v.__code__.co_filename == __file__ \ + and any(arg is torch.Tensor or arg == "torch.Tensor" + for arg in v.__annotations__.values()): names_and_values_to_update[k] = hint_on_error(v) names_and_values.update(names_and_values_to_update)