From db1f07e8639badced65c6b85f812567f83442a74 Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Tue, 3 Sep 2024 11:03:53 -0400 Subject: [PATCH 01/77] GPTQ Fused MoE class --- .../layers/fused_moe/__init__.py | 3 +- vllm/model_executor/layers/fused_moe/layer.py | 155 +++++++++++++++++- 2 files changed, 156 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/__init__.py b/vllm/model_executor/layers/fused_moe/__init__.py index e9b5703ca28be..7f27e2660db65 100644 --- a/vllm/model_executor/layers/fused_moe/__init__.py +++ b/vllm/model_executor/layers/fused_moe/__init__.py @@ -1,11 +1,12 @@ from vllm.model_executor.layers.fused_moe.layer import ( - FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported) + FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported, GPTQFusedMoE) from vllm.triton_utils import HAS_TRITON __all__ = [ "FusedMoE", "FusedMoEMethodBase", "FusedMoeWeightScaleSupported", + "GPTQFusedMoE", ] if HAS_TRITON: diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 3df0b61a9ebe4..9643642b9b53e 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -498,4 +498,157 @@ def _load_fp8_scale(self, param: torch.nn.Parameter, param_data[expert_id][idx] = loaded_weight # If we are in the row parallel case (down_proj) else: - param_data[expert_id] = loaded_weight \ No newline at end of file + param_data[expert_id] = loaded_weight + + +class GPTQFusedMoE(torch.nn.Module): + """GPTQFusedMoE layer for GPTQ MoE models. + This layer contains both MergedColumnParallel weights (gate_up_proj / + w13) and RowParallelLinear weights (down_proj/ w2). + Note: Mixtral uses w1, w2, and w3 for gate, up, and down_proj. We + copy that naming convention here and handle any remapping in the + load_weights function in each model implementation. + Args: + num_experts: Number of experts in the model + top_k: Number of experts selected for each token + hidden_size: Input hidden state size of the transformer + intermediate_size: Intermediate size of the experts + params_dtype: Data type for the parameters. + reduce_results: Whether to all all_reduce on the output of the layer + renomalize: Whether to renormalize the logits in the fused_moe kernel + quant_config: Quantization configure. + """ + + def __init__( + self, + num_experts: int, + top_k: int, + hidden_size: int, + intermediate_size: int, + params_dtype: Optional[torch.dtype] = None, + reduce_results: bool = False, + renormalize: bool = True, + use_grouped_topk: bool = False, + num_expert_group: Optional[int] = None, + topk_group: Optional[int] = None, + quant_config: Optional[QuantizationConfig] = None, + tp_size: Optional[int] = None, + prefix: str = "", + ): + super().__init__() + + if params_dtype is None: + params_dtype = torch.get_default_dtype() + + self.tp_size = (tp_size if tp_size is not None else + get_tensor_model_parallel_world_size()) + self.top_k = top_k + self.num_experts = num_experts + self.intermediate_size = intermediate_size + self.intermediate_size_per_partition = intermediate_size // self.tp_size + self.reduce_results = reduce_results + self.renormalize = renormalize + assert (not use_grouped_topk and num_expert_group is None + and topk_group is None) + + if quant_config is None: + self.quant_method: Optional[ + QuantizeMethodBase] = UnquantizedFusedMoEMethod() + else: + self.quant_method = quant_config.get_quant_method(self, prefix) + assert self.quant_method is not None + + self.quant_method.create_weights( + layer=self, + num_experts=num_experts, + hidden_size=hidden_size, + intermediate_size=self.intermediate_size_per_partition, + params_dtype=params_dtype, + weight_loader=self.weight_loader, + ) + + def weight_loader(self, param: torch.nn.Parameter, + loaded_weight: torch.Tensor, weight_name: str, + shard_id: str, expert_id: int) -> None: + + if ("_qweight" in weight_name or "_scales" in weight_name + or "_qzeros" in weight_name): + if "w13" in weight_name: + shard_size = loaded_weight.size()[-1] + if shard_id == "w1": + param.data[expert_id, :, :shard_size] = loaded_weight + elif shard_id == "w2" or shard_id == "w3": + param.data[expert_id, :, shard_size:] = loaded_weight + else: + raise ValueError(f"Invalid shard_id: {shard_id}: " + "must be w1, w2, or w3.") + elif "w2" in weight_name: + param.data[expert_id][:] = loaded_weight + else: + raise ValueError(f"Invalid weight name: {weight_name}: " + "must contain 'w13' or 'w2'.") + elif "_g_idx" in weight_name: + if "w13" not in weight_name and "w2" not in weight_name: + raise ValueError(f"Invalid weight name: {weight_name}: " + "must contain 'w13' or 'w2'.") + param.data[expert_id] = loaded_weight + else: + raise ValueError(f"Invalid weight name: {weight_name}.") + + @staticmethod + def select_experts(hidden_states: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + use_grouped_topk: bool, + renormalize: bool, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None): + assert (not use_grouped_topk and topk_group is None + and num_expert_group is None) + from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk + + topk_weights, topk_ids = fused_topk(hidden_states=hidden_states, + gating_output=router_logits, + topk=top_k, + renormalize=renormalize) + + return topk_weights, topk_ids + + def forward(self, hidden_states: torch.Tensor, + router_logits: torch.Tensor): + assert self.quant_method is not None + + # Matrix multiply. + final_hidden_states = self.quant_method.apply( + layer=self, + x=hidden_states, + router_logits=router_logits, + top_k=self.top_k, + renormalize=self.renormalize, + use_grouped_topk=False, + topk_group=False, + num_expert_group=False) + + if self.reduce_results and self.tp_size > 1: + final_hidden_states = tensor_model_parallel_all_reduce( + final_hidden_states) + + return final_hidden_states + + @classmethod + def make_expert_params_mapping( + cls, ckpt_gate_proj_name: str, ckpt_down_proj_name: str, + ckpt_up_proj_name: str, + num_experts: int) -> List[Tuple[str, str, int, str]]: + + return [ + # (param_name, weight_name, expert_id, shard_id) + ("experts.w13_" if weight_name + in [ckpt_gate_proj_name, ckpt_up_proj_name] else "experts.w2_", + f"experts.{expert_id}.{weight_name}.", expert_id, shard_id) + for expert_id in range(num_experts) for shard_id, weight_name in [ + ("w1", ckpt_gate_proj_name), + ("w2", ckpt_down_proj_name), + ("w3", ckpt_up_proj_name), + ] + ] From 6753789bbe7e636a51a7a2adca10a24968bf76f1 Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Tue, 3 Sep 2024 12:41:52 -0400 Subject: [PATCH 02/77] Add GPTQMarlinMoEMethod to gptq_marlin.py --- .../layers/quantization/gptq_marlin.py | 304 +++++++++++++++++- 1 file changed, 289 insertions(+), 15 deletions(-) diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py index 94eb3f301541a..1588b2a6113ad 100644 --- a/vllm/model_executor/layers/quantization/gptq_marlin.py +++ b/vllm/model_executor/layers/quantization/gptq_marlin.py @@ -1,18 +1,25 @@ -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Union import torch from torch.nn import Parameter from vllm import _custom_ops as ops from vllm.logger import init_logger -from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase +from vllm.model_executor.layers.fused_moe.fused_marlin_moe import ( + fused_marlin_moe) +from vllm.model_executor.layers.fused_moe.layer import (FusedMoE, + FusedMoEMethodBase, + GPTQFusedMoE) +from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, + set_weight_attrs) from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.layers.quantization.utils.marlin_utils import ( apply_gptq_marlin_linear, check_marlin_supported, marlin_is_k_full, - marlin_make_empty_g_idx, marlin_make_workspace, marlin_permute_scales, - marlin_repeat_scales_on_all_ranks, marlin_sort_g_idx, replace_tensor, - verify_marlin_supported, verify_marlin_supports_shape) + marlin_make_empty_g_idx, marlin_make_workspace, marlin_moe_permute_scales, + marlin_permute_scales, marlin_repeat_scales_on_all_ranks, + marlin_sort_g_idx, replace_tensor, verify_marlin_supported, + verify_marlin_supports_shape) from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.parameter import (ChannelQuantScaleParameter, GroupQuantScaleParameter, @@ -33,8 +40,14 @@ class GPTQMarlinConfig(QuantizationConfig): (8, True): scalar_types.uint8b128, } - def __init__(self, weight_bits: int, group_size: int, desc_act: bool, - is_sym: bool, lm_head_quantized: bool) -> None: + def __init__( + self, + weight_bits: int, + group_size: int, + desc_act: bool, + is_sym: bool, + lm_head_quantized: bool, + ) -> None: if desc_act and group_size == -1: # In this case, act_order == True is the same as act_order == False # (since we have only one group per output channel) @@ -109,11 +122,14 @@ def override_quantization_method(cls, hf_quant_cfg, " faster inference") return None - def get_quant_method(self, layer: torch.nn.Module, - prefix: str) -> Optional["GPTQMarlinLinearMethod"]: - if (isinstance(layer, LinearBase) or - (isinstance(layer, ParallelLMHead) and self.lm_head_quantized)): + def get_quant_method( + self, layer: torch.nn.Module, prefix: str + ) -> Optional[Union["GPTQMarlinLinearMethod", "GPTQMarlinMoEMethod"]]: + if isinstance(layer, LinearBase) or (isinstance(layer, ParallelLMHead) + and self.lm_head_quantized): return GPTQMarlinLinearMethod(self) + elif isinstance(layer, GPTQFusedMoE): + return GPTQMarlinMoEMethod(self) return None def get_scaled_act_names(self) -> List[str]: @@ -179,7 +195,8 @@ def create_weights( output_size_per_partition=output_size_per_partition, input_size_per_partition=input_size_per_partition, input_size=input_size, - group_size=group_size) + group_size=group_size, + ) # Determine sharding if marlin_repeat_scales_on_all_ranks(self.quant_config.desc_act, @@ -299,7 +316,8 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: perm=layer.g_idx_sort_indices, size_k=layer.input_size_per_partition, size_n=layer.output_size_per_partition, - num_bits=self.quant_config.quant_type.size_bits) + num_bits=self.quant_config.quant_type.size_bits, + ) replace_tensor(layer, "qweight", marlin_qweight) # Permute scales from autogptq format to marlin format. @@ -308,7 +326,8 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: size_k=(layer.input_size if self.quant_config.desc_act else layer.input_size_per_partition), size_n=layer.output_size_per_partition, - group_size=self.quant_config.group_size) + group_size=self.quant_config.group_size, + ) replace_tensor(layer, "scales", marlin_scales) def apply( @@ -329,4 +348,259 @@ def apply( output_size_per_partition=layer.output_size_per_partition, input_size_per_partition=layer.input_size_per_partition, is_k_full=layer.is_k_full, - bias=bias) + bias=bias, + ) + + +class GPTQMarlinMoEMethod(FusedMoEMethodBase): + """MoE Marlin method with quantization.""" + + def __init__(self, quant_config: GPTQMarlinConfig) -> None: + self.quant_config = quant_config + + def create_weights( + self, + layer: torch.nn.Module, + num_experts: int, + hidden_size: int, + intermediate_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + # Currently assuming is_k_full is always True + # (input size per partition is the same as full input size) + # Supports only sym for now (no zp) + if self.quant_config.group_size != -1: + scales_size13 = hidden_size // self.quant_config.group_size + scales_size2 = intermediate_size // self.quant_config.group_size + else: + scales_size13 = 1 + scales_size2 = 1 + # Fused gate_up_proj (column parallel) + w13_qweight = torch.nn.Parameter( + torch.empty( + num_experts, + hidden_size // self.quant_config.pack_factor, + 2 * intermediate_size, + dtype=torch.int32, + ), + requires_grad=False, + ) + layer.register_parameter("w13_qweight", w13_qweight) + set_weight_attrs(w13_qweight, extra_weight_attrs) + # down_proj (row parallel) + w2_qweight = torch.nn.Parameter( + torch.empty( + num_experts, + intermediate_size // self.quant_config.pack_factor, + hidden_size, + dtype=torch.int32, + ), + requires_grad=False, + ) + layer.register_parameter("w2_qweight", w2_qweight) + set_weight_attrs(w2_qweight, extra_weight_attrs) + # up_proj scales + w13_scales = torch.nn.Parameter( + torch.empty(num_experts, + scales_size13, + 2 * intermediate_size, + dtype=params_dtype), + requires_grad=False, + ) + layer.register_parameter("w13_scales", w13_scales) + set_weight_attrs(w13_scales, extra_weight_attrs) + # down_proj scales + w2_scales = torch.nn.Parameter( + torch.empty(num_experts, + scales_size2, + hidden_size, + dtype=params_dtype), + requires_grad=False, + ) + layer.register_parameter("w2_scales", w2_scales) + set_weight_attrs(w2_scales, extra_weight_attrs) + # up_proj scales + w13_qzeros = torch.nn.Parameter( + torch.empty(num_experts, + scales_size13, + 2 * intermediate_size // self.quant_config.pack_factor, + dtype=params_dtype), + requires_grad=False, + ) + layer.register_parameter("w13_qzeros", w13_qzeros) + set_weight_attrs(w13_qzeros, extra_weight_attrs) + # down_proj scales + w2_qzeros = torch.nn.Parameter( + torch.empty(num_experts, + scales_size2, + hidden_size // self.quant_config.pack_factor, + dtype=params_dtype), + requires_grad=False, + ) + layer.register_parameter("w2_qzeros", w2_qzeros) + set_weight_attrs(w2_qzeros, extra_weight_attrs) + w13_g_idx = torch.nn.Parameter( + torch.empty( + num_experts, + hidden_size, + dtype=torch.int32, + ), + requires_grad=False, + ) + layer.register_parameter("w13_g_idx", w13_g_idx) + set_weight_attrs(w13_g_idx, extra_weight_attrs) + w2_g_idx = torch.nn.Parameter( + torch.empty( + num_experts, + intermediate_size, + dtype=torch.int32, + ), + requires_grad=False, + ) + layer.register_parameter("w2_g_idx", w2_g_idx) + set_weight_attrs(w2_g_idx, extra_weight_attrs) + w13_g_idx_sort_indices = torch.nn.Parameter( + torch.empty( + num_experts, + hidden_size, + dtype=torch.int32, + ), + requires_grad=False, + ) + layer.register_parameter("w13_g_idx_sort_indices", + w13_g_idx_sort_indices) + set_weight_attrs(w13_g_idx_sort_indices, extra_weight_attrs) + w2_g_idx_sort_indices = torch.nn.Parameter( + torch.empty( + num_experts, + intermediate_size, + dtype=torch.int32, + ), + requires_grad=False, + ) + layer.register_parameter("w2_g_idx_sort_indices", + w2_g_idx_sort_indices) + set_weight_attrs(w2_g_idx_sort_indices, extra_weight_attrs) + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + + # Process act_order + if self.quant_config.desc_act: + # Get sorting based on g_idx + num_experts = layer.w13_g_idx.shape[0] + w13_g_idx_sort_indices = torch.empty_like(layer.w13_g_idx) + w2_g_idx_sort_indices = torch.empty_like(layer.w2_g_idx) + w13_sorted_g_idx = torch.empty_like(layer.w13_g_idx) + w2_sorted_g_idx = torch.empty_like(layer.w2_g_idx) + for e in range(num_experts): + w13_g_idx_sort_indices[e] = torch.argsort( + layer.w13_g_idx[e]).to(torch.int32) + w2_g_idx_sort_indices[e] = torch.argsort(layer.w2_g_idx[e]).to( + torch.int32) + w13_sorted_g_idx[e] = layer.w13_g_idx[e][ + w13_g_idx_sort_indices[e]] + w2_sorted_g_idx[e] = layer.w2_g_idx[e][ + w2_g_idx_sort_indices[e]] + replace_tensor(layer, "w13_g_idx", w13_sorted_g_idx) + replace_tensor(layer, "w2_g_idx", w2_sorted_g_idx) + replace_tensor(layer, "w13_g_idx_sort_indices", + w13_g_idx_sort_indices) + replace_tensor(layer, "w2_g_idx_sort_indices", + w2_g_idx_sort_indices) + else: + # Reset g_idx related tensors + num_experts = layer.w13_g_idx.shape[0] + device = layer.w13_g_idx.device + layer.w13_g_idx = torch.nn.Parameter( + torch.empty((num_experts, 0), dtype=torch.int32, + device=device), + requires_grad=False, + ) + layer.w2_g_idx = torch.nn.Parameter( + torch.empty((num_experts, 0), dtype=torch.int32, + device=device), + requires_grad=False, + ) + layer.w13_g_idx_sort_indices = torch.nn.Parameter( + torch.empty((num_experts, 0), dtype=torch.int32, + device=device), + requires_grad=False, + ) + layer.w2_g_idx_sort_indices = torch.nn.Parameter( + torch.empty((num_experts, 0), dtype=torch.int32, + device=device), + requires_grad=False, + ) + # Repack weights + marlin_w13_qweight = ops.gptq_marlin_moe_repack( + layer.w13_qweight, + layer.w13_g_idx_sort_indices, + layer.w13_qweight.shape[1] * self.quant_config.pack_factor, + layer.w13_qweight.shape[2], + self.quant_config.quant_type.size_bits, + ) + replace_tensor(layer, "w13_qweight", marlin_w13_qweight) + marlin_w2_qweight = ops.gptq_marlin_moe_repack( + layer.w2_qweight, + layer.w2_g_idx_sort_indices, + layer.w2_qweight.shape[1] * self.quant_config.pack_factor, + layer.w2_qweight.shape[2], + self.quant_config.quant_type.size_bits, + ) + replace_tensor(layer, "w2_qweight", marlin_w2_qweight) + # Repack scales + marlin_w13_scales = marlin_moe_permute_scales( + s=layer.w13_scales, + size_k=(layer.intermediate_size if self.quant_config.desc_act else + layer.intermediate_size_per_partition), + size_n=layer.w13_scales.shape[2], + group_size=self.quant_config.group_size, + ) + replace_tensor(layer, "w13_scales", marlin_w13_scales) + marlin_w2_scales = marlin_moe_permute_scales( + s=layer.w2_scales, + size_k=layer.w2_scales.shape[1] * self.quant_config.pack_factor, + size_n=layer.w2_scales.shape[2], + group_size=self.quant_config.group_size, + ) + replace_tensor(layer, "w2_scales", marlin_w2_scales) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool = True, + use_grouped_topk: bool = False, + num_expert_group: Optional[int] = None, + topk_group: Optional[int] = None, + ) -> torch.Tensor: + + topk_weights, topk_ids = FusedMoE.select_experts( + hidden_states=x, + router_logits=router_logits, + use_grouped_topk=use_grouped_topk, + top_k=top_k, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + custom_routing_function=None) + + return fused_marlin_moe( + x, + layer.w13_qweight, + layer.w2_qweight, + router_logits, + layer.w13_g_idx, + layer.w2_g_idx, + layer.w13_g_idx_sort_indices, + layer.w2_g_idx_sort_indices, + topk_weights, + topk_ids, + renormalize=renormalize, + w1_scale=layer.w13_scales, + w2_scale=layer.w2_scales, + num_bits=self.quant_config.quant_type.size_bits, + ) From 7df4014ce516363202cf3646a9c0598fb9cdeed8 Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Wed, 4 Sep 2024 09:00:15 -0400 Subject: [PATCH 03/77] Use FusedMoE layer for all loads --- .../layers/fused_moe/__init__.py | 3 +- vllm/model_executor/layers/fused_moe/layer.py | 172 ++---------------- .../layers/quantization/gptq_marlin.py | 5 +- 3 files changed, 22 insertions(+), 158 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/__init__.py b/vllm/model_executor/layers/fused_moe/__init__.py index 7f27e2660db65..e9b5703ca28be 100644 --- a/vllm/model_executor/layers/fused_moe/__init__.py +++ b/vllm/model_executor/layers/fused_moe/__init__.py @@ -1,12 +1,11 @@ from vllm.model_executor.layers.fused_moe.layer import ( - FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported, GPTQFusedMoE) + FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported) from vllm.triton_utils import HAS_TRITON __all__ = [ "FusedMoE", "FusedMoEMethodBase", "FusedMoeWeightScaleSupported", - "GPTQFusedMoE", ] if HAS_TRITON: diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 9643642b9b53e..b0d7d4b538df3 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -334,6 +334,25 @@ def weight_loader(self, param: torch.nn.Parameter, loaded_weight = loaded_weight.t().contiguous() shard_dim = ~shard_dim + # GPTQ Values + if ("scales" in weight_name or "qweight" in weight_name + or "qzeros" in weight_name): + if (shard_id == "w1" or shard_id == "w3"): + shard_dim = 1 - shard_dim + self._load_model_weight_or_group_weight_scale( + shard_id=shard_id, + shard_dim=shard_dim, + loaded_weight=loaded_weight, + expert_data=expert_data, + tp_rank=tp_rank) + return + + if "g_idx" in weight_name: + self._load_single_value(param=param, + loaded_weight=loaded_weight, + expert_id=expert_id) + return + # Case weight_scales if "weight_scale" in weight_name: # load the weight scaling based on the quantization scheme @@ -499,156 +518,3 @@ def _load_fp8_scale(self, param: torch.nn.Parameter, # If we are in the row parallel case (down_proj) else: param_data[expert_id] = loaded_weight - - -class GPTQFusedMoE(torch.nn.Module): - """GPTQFusedMoE layer for GPTQ MoE models. - This layer contains both MergedColumnParallel weights (gate_up_proj / - w13) and RowParallelLinear weights (down_proj/ w2). - Note: Mixtral uses w1, w2, and w3 for gate, up, and down_proj. We - copy that naming convention here and handle any remapping in the - load_weights function in each model implementation. - Args: - num_experts: Number of experts in the model - top_k: Number of experts selected for each token - hidden_size: Input hidden state size of the transformer - intermediate_size: Intermediate size of the experts - params_dtype: Data type for the parameters. - reduce_results: Whether to all all_reduce on the output of the layer - renomalize: Whether to renormalize the logits in the fused_moe kernel - quant_config: Quantization configure. - """ - - def __init__( - self, - num_experts: int, - top_k: int, - hidden_size: int, - intermediate_size: int, - params_dtype: Optional[torch.dtype] = None, - reduce_results: bool = False, - renormalize: bool = True, - use_grouped_topk: bool = False, - num_expert_group: Optional[int] = None, - topk_group: Optional[int] = None, - quant_config: Optional[QuantizationConfig] = None, - tp_size: Optional[int] = None, - prefix: str = "", - ): - super().__init__() - - if params_dtype is None: - params_dtype = torch.get_default_dtype() - - self.tp_size = (tp_size if tp_size is not None else - get_tensor_model_parallel_world_size()) - self.top_k = top_k - self.num_experts = num_experts - self.intermediate_size = intermediate_size - self.intermediate_size_per_partition = intermediate_size // self.tp_size - self.reduce_results = reduce_results - self.renormalize = renormalize - assert (not use_grouped_topk and num_expert_group is None - and topk_group is None) - - if quant_config is None: - self.quant_method: Optional[ - QuantizeMethodBase] = UnquantizedFusedMoEMethod() - else: - self.quant_method = quant_config.get_quant_method(self, prefix) - assert self.quant_method is not None - - self.quant_method.create_weights( - layer=self, - num_experts=num_experts, - hidden_size=hidden_size, - intermediate_size=self.intermediate_size_per_partition, - params_dtype=params_dtype, - weight_loader=self.weight_loader, - ) - - def weight_loader(self, param: torch.nn.Parameter, - loaded_weight: torch.Tensor, weight_name: str, - shard_id: str, expert_id: int) -> None: - - if ("_qweight" in weight_name or "_scales" in weight_name - or "_qzeros" in weight_name): - if "w13" in weight_name: - shard_size = loaded_weight.size()[-1] - if shard_id == "w1": - param.data[expert_id, :, :shard_size] = loaded_weight - elif shard_id == "w2" or shard_id == "w3": - param.data[expert_id, :, shard_size:] = loaded_weight - else: - raise ValueError(f"Invalid shard_id: {shard_id}: " - "must be w1, w2, or w3.") - elif "w2" in weight_name: - param.data[expert_id][:] = loaded_weight - else: - raise ValueError(f"Invalid weight name: {weight_name}: " - "must contain 'w13' or 'w2'.") - elif "_g_idx" in weight_name: - if "w13" not in weight_name and "w2" not in weight_name: - raise ValueError(f"Invalid weight name: {weight_name}: " - "must contain 'w13' or 'w2'.") - param.data[expert_id] = loaded_weight - else: - raise ValueError(f"Invalid weight name: {weight_name}.") - - @staticmethod - def select_experts(hidden_states: torch.Tensor, - router_logits: torch.Tensor, - top_k: int, - use_grouped_topk: bool, - renormalize: bool, - topk_group: Optional[int] = None, - num_expert_group: Optional[int] = None): - assert (not use_grouped_topk and topk_group is None - and num_expert_group is None) - from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk - - topk_weights, topk_ids = fused_topk(hidden_states=hidden_states, - gating_output=router_logits, - topk=top_k, - renormalize=renormalize) - - return topk_weights, topk_ids - - def forward(self, hidden_states: torch.Tensor, - router_logits: torch.Tensor): - assert self.quant_method is not None - - # Matrix multiply. - final_hidden_states = self.quant_method.apply( - layer=self, - x=hidden_states, - router_logits=router_logits, - top_k=self.top_k, - renormalize=self.renormalize, - use_grouped_topk=False, - topk_group=False, - num_expert_group=False) - - if self.reduce_results and self.tp_size > 1: - final_hidden_states = tensor_model_parallel_all_reduce( - final_hidden_states) - - return final_hidden_states - - @classmethod - def make_expert_params_mapping( - cls, ckpt_gate_proj_name: str, ckpt_down_proj_name: str, - ckpt_up_proj_name: str, - num_experts: int) -> List[Tuple[str, str, int, str]]: - - return [ - # (param_name, weight_name, expert_id, shard_id) - ("experts.w13_" if weight_name - in [ckpt_gate_proj_name, ckpt_up_proj_name] else "experts.w2_", - f"experts.{expert_id}.{weight_name}.", expert_id, shard_id) - for expert_id in range(num_experts) for shard_id, weight_name in [ - ("w1", ckpt_gate_proj_name), - ("w2", ckpt_down_proj_name), - ("w3", ckpt_up_proj_name), - ] - ] diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py index 1588b2a6113ad..15530e692eb3d 100644 --- a/vllm/model_executor/layers/quantization/gptq_marlin.py +++ b/vllm/model_executor/layers/quantization/gptq_marlin.py @@ -8,8 +8,7 @@ from vllm.model_executor.layers.fused_moe.fused_marlin_moe import ( fused_marlin_moe) from vllm.model_executor.layers.fused_moe.layer import (FusedMoE, - FusedMoEMethodBase, - GPTQFusedMoE) + FusedMoEMethodBase) from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, set_weight_attrs) from vllm.model_executor.layers.quantization.base_config import ( @@ -128,7 +127,7 @@ def get_quant_method( if isinstance(layer, LinearBase) or (isinstance(layer, ParallelLMHead) and self.lm_head_quantized): return GPTQMarlinLinearMethod(self) - elif isinstance(layer, GPTQFusedMoE): + elif isinstance(layer, FusedMoE): return GPTQMarlinMoEMethod(self) return None From 2fa03e5f5f0916ec8c36d446dcde526bf27d2b99 Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Wed, 4 Sep 2024 11:25:08 -0400 Subject: [PATCH 04/77] Make sure that GPTQ runs through mixtral.py --- vllm/model_executor/layers/quantization/gptq_marlin.py | 6 +++--- vllm/model_executor/model_loader/utils.py | 2 +- vllm/model_executor/models/mixtral.py | 6 ++++-- 3 files changed, 8 insertions(+), 6 deletions(-) diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py index 15530e692eb3d..fbf384ea34dc1 100644 --- a/vllm/model_executor/layers/quantization/gptq_marlin.py +++ b/vllm/model_executor/layers/quantization/gptq_marlin.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Optional, Union +from typing import Any, Callable, Dict, List, Optional, Union import torch from torch.nn import Parameter @@ -551,8 +551,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: # Repack scales marlin_w13_scales = marlin_moe_permute_scales( s=layer.w13_scales, - size_k=(layer.intermediate_size if self.quant_config.desc_act else - layer.intermediate_size_per_partition), + size_k=layer.intermediate_size_per_partition, size_n=layer.w13_scales.shape[2], group_size=self.quant_config.group_size, ) @@ -575,6 +574,7 @@ def apply( use_grouped_topk: bool = False, num_expert_group: Optional[int] = None, topk_group: Optional[int] = None, + custom_routing_function: Optional[Callable] = None, ) -> torch.Tensor: topk_weights, topk_ids = FusedMoE.select_experts( diff --git a/vllm/model_executor/model_loader/utils.py b/vllm/model_executor/model_loader/utils.py index 4bb943ab3afe4..d247e4cf3f07b 100644 --- a/vllm/model_executor/model_loader/utils.py +++ b/vllm/model_executor/model_loader/utils.py @@ -23,7 +23,7 @@ def get_model_architecture( architectures = getattr(model_config.hf_config, "architectures", []) # Special handling for quantized Mixtral. # FIXME(woosuk): This is a temporary hack. - mixtral_supported = ["fp8", "compressed-tensors"] + mixtral_supported = ["fp8", "compressed-tensors", "gptq_marlin"] if (model_config.quantization is not None and model_config.quantization not in mixtral_supported and "MixtralForCausalLM" in architectures): diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index e744e36ac08bf..6413b56605ecf 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -435,7 +435,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): continue name = name.replace(weight_name, param_name) # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: + if name.endswith("bias") and name not in params_dict: continue # Skip layers on other devices. if is_pp_missing_parameter(name, self): @@ -454,6 +454,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # Skip layers on other devices. if is_pp_missing_parameter(name, self): continue + if name.endswith("bias") and name not in params_dict: + continue param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, @@ -464,7 +466,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): break else: # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: + if name.endswith("bias") and name not in params_dict: continue # Skip layers on other devices. if is_pp_missing_parameter(name, self): From 8a504d936aff8b3955f25ece553efb6366c52e3e Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Wed, 4 Sep 2024 12:40:52 -0400 Subject: [PATCH 05/77] enforce float16A/scales for marlin moe --- vllm/model_executor/layers/quantization/gptq_marlin.py | 4 ++-- vllm/model_executor/models/mixtral.py | 5 +++-- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py index fbf384ea34dc1..d52ff3131fdef 100644 --- a/vllm/model_executor/layers/quantization/gptq_marlin.py +++ b/vllm/model_executor/layers/quantization/gptq_marlin.py @@ -404,7 +404,7 @@ def create_weights( torch.empty(num_experts, scales_size13, 2 * intermediate_size, - dtype=params_dtype), + dtype=torch.half), requires_grad=False, ) layer.register_parameter("w13_scales", w13_scales) @@ -414,7 +414,7 @@ def create_weights( torch.empty(num_experts, scales_size2, hidden_size, - dtype=params_dtype), + dtype=torch.half), requires_grad=False, ) layer.register_parameter("w2_scales", w2_scales) diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index 6413b56605ecf..148ef393277e4 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -95,11 +95,12 @@ def __init__(self, def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: # NOTE: hidden_states can have either 1D or 2D shape. orig_shape = hidden_states.shape + orig_dtype = hidden_states.dtype hidden_states = hidden_states.view(-1, self.hidden_size) # router_logits: (num_tokens, n_experts) router_logits, _ = self.gate(hidden_states) - final_hidden_states = self.experts(hidden_states, router_logits) - return final_hidden_states.view(orig_shape) + final_hidden_states = self.experts(hidden_states.half(), router_logits) + return final_hidden_states.view(orig_shape).to(orig_dtype) class MixtralAttention(nn.Module): From ec47561fa40ddb9146a3c8b694c5f88016052652 Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Wed, 4 Sep 2024 13:13:32 -0400 Subject: [PATCH 06/77] cleanup --- vllm/model_executor/layers/quantization/gptq_marlin.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py index d52ff3131fdef..11012a326b045 100644 --- a/vllm/model_executor/layers/quantization/gptq_marlin.py +++ b/vllm/model_executor/layers/quantization/gptq_marlin.py @@ -598,7 +598,6 @@ def apply( layer.w2_g_idx_sort_indices, topk_weights, topk_ids, - renormalize=renormalize, w1_scale=layer.w13_scales, w2_scale=layer.w2_scales, num_bits=self.quant_config.quant_type.size_bits, From 2ad2e5608eeede10683412bbbfaf30b3a68019dc Mon Sep 17 00:00:00 2001 From: Cody Yu Date: Wed, 4 Sep 2024 11:53:25 -0700 Subject: [PATCH 07/77] [MISC] Consolidate FP8 kv-cache tests (#8131) --- .buildkite/run-cpu-test.sh | 7 +- .../basic_correctness/test_chunked_prefill.py | 43 +---- tests/models/test_fp8.py | 181 ++++++++---------- tests/models/test_fp8kv_flashinfer.py | 96 ---------- 4 files changed, 94 insertions(+), 233 deletions(-) delete mode 100644 tests/models/test_fp8kv_flashinfer.py diff --git a/.buildkite/run-cpu-test.sh b/.buildkite/run-cpu-test.sh index 8e4be08f3aba0..ca9cf15780e25 100644 --- a/.buildkite/run-cpu-test.sh +++ b/.buildkite/run-cpu-test.sh @@ -23,7 +23,12 @@ docker exec cpu-test-avx2 bash -c "python3 examples/offline_inference.py" # Run basic model test docker exec cpu-test bash -c " pip install pytest matplotlib einops transformers_stream_generator - pytest -v -s tests/models -m \"not vlm\" --ignore=tests/models/test_embedding.py --ignore=tests/models/test_oot_registration.py --ignore=tests/models/test_registry.py --ignore=tests/models/test_jamba.py --ignore=tests/models/test_danube3_4b.py" # Mamba and Danube3-4B on CPU is not supported + pytest -v -s tests/models -m \"not vlm\" --ignore=tests/models/test_embedding.py \ + --ignore=tests/models/test_oot_registration.py \ + --ignore=tests/models/test_registry.py \ + --ignore=tests/models/test_fp8.py \ + --ignore=tests/models/test_jamba.py \ + --ignore=tests/models/test_danube3_4b.py" # Mamba and Danube3-4B on CPU is not supported # online inference docker exec cpu-test bash -c " diff --git a/tests/basic_correctness/test_chunked_prefill.py b/tests/basic_correctness/test_chunked_prefill.py index a63ac380e8598..9c34b2a13fd53 100644 --- a/tests/basic_correctness/test_chunked_prefill.py +++ b/tests/basic_correctness/test_chunked_prefill.py @@ -16,18 +16,6 @@ "facebook/opt-125m", "meta-llama/Llama-2-7b-hf", ] -E5M2_KV_MODELS = [ - "facebook/opt-125m", - "meta-llama/Llama-2-7b-chat-hf", -] -E4M3_KV_MODELS = [ - "meta-llama/Llama-2-7b-chat-hf", "nm-testing/Qwen2-1.5B-Instruct-FP8-K-V", - "nm-testing/TinyLlama-1.1B-compressed-tensors-kv-cache-scheme" -] -KV_CACHE_QUANTIZATION_PATHS = { - "meta-llama/Llama-2-7b-chat-hf": - "./tests/fp8_kv/llama2-7b-fp8-kv/kv_cache_scales.json" -} @pytest.mark.parametrize("model", MODELS) @@ -78,10 +66,10 @@ def test_models( ) -@pytest.mark.parametrize("kv_cache_dtype,model", - [("fp8_e5m2", m) - for m in E5M2_KV_MODELS] + [("fp8_e4m3", m) - for m in E4M3_KV_MODELS]) +@pytest.mark.parametrize( + "kv_cache_dtype,model", + [("fp8_e4m3", + "nm-testing/TinyLlama-1.1B-compressed-tensors-kv-cache-scheme")]) # Due to low-precision numerical divergence, we only test logprob of 4 tokens @pytest.mark.parametrize("max_tokens", [4]) @pytest.mark.parametrize("chunked_prefill_token_size", [4, 16]) @@ -104,30 +92,15 @@ def test_models_with_fp8_kv_cache( disable_async_output_proc: bool, ) -> None: """ - Only checks log probs match between chunked-prefill and - non-chunked-prefill version of vLLM model runner. - - This test is used when there is discrepancy in kernels - / numerics (e.g. when using lower-precision types like FP8). + Check output logprobs match between no_chunked_prefill and chunked_prefill + with fp8 kv cache. General fp8 kv-cache tests are covered in test_fp8.py, + so here we only check chunked prefill. """ NUM_LOG_PROBS = 8 - if model == "facebook/opt-125m": - pytest.skip( - "#7378: CUDA illegal memory access (undiagnosed) facebook/opt-125m" - ) - if ((model, kv_cache_dtype, chunked_prefill_token_size) == ( - "nm-testing/Qwen2-1.5B-Instruct-FP8-K-V", "fp8_e4m3", 4)): - pytest.skip("flakey test, see: #7874 #8051") - max_num_seqs = chunked_prefill_token_size max_num_batched_tokens = chunked_prefill_token_size - extra_kwargs = {} - if model in KV_CACHE_QUANTIZATION_PATHS: - extra_kwargs["quantization_param_path"] = KV_CACHE_QUANTIZATION_PATHS[ - model] - with vllm_runner( model, tensor_parallel_size=tensor_parallel_size, @@ -135,7 +108,6 @@ def test_models_with_fp8_kv_cache( max_num_seqs=max_num_seqs, kv_cache_dtype=kv_cache_dtype, disable_async_output_proc=disable_async_output_proc, - **extra_kwargs, ) as vllm_model: no_chunked_prefill_outputs = vllm_model.generate_greedy_logprobs( example_prompts, max_tokens, NUM_LOG_PROBS) @@ -149,7 +121,6 @@ def test_models_with_fp8_kv_cache( max_num_seqs=max_num_seqs, kv_cache_dtype=kv_cache_dtype, disable_async_output_proc=disable_async_output_proc, - **extra_kwargs, ) as vllm_model: chunked_prefill_outputs = vllm_model.generate_greedy_logprobs( example_prompts, max_tokens, NUM_LOG_PROBS) diff --git a/tests/models/test_fp8.py b/tests/models/test_fp8.py index 4ab968c01da04..17acdb52322fd 100644 --- a/tests/models/test_fp8.py +++ b/tests/models/test_fp8.py @@ -3,116 +3,97 @@ Note: these tests will only pass on L4 GPU. """ import os -from typing import List +from typing import Optional import pytest -import torch -from transformers import AutoTokenizer +from tests.kernels.utils import override_backend_env_variable from tests.quantization.utils import is_quant_method_supported -from vllm import LLM, SamplingParams -os.environ["TOKENIZERS_PARALLELISM"] = "true" - -MAX_MODEL_LEN = 1024 - -MODELS = [ - "nm-testing/Meta-Llama-3-8B-Instruct-FP8-KV", - "meta-llama/Meta-Llama-3-8B-Instruct", -] +from ..models.utils import check_logprobs_close -EXPECTED_STRS_MAP = { - "nm-testing/Meta-Llama-3-8B-Instruct-FP8-KV": { - "auto": [ - 'LLaMA is a high-throughput and memory-efficient inference and serving engine for Large Language Models (', - 'Here are the major milestones in the development of artificial intelligence (AI) from 1950 to ', - 'Artificial intelligence (AI) and human intelligence (HI) process information in distinct ways, with both', - 'A neural network is a complex system modeled after the human brain, composed of interconnected nodes or "ne', - 'Zeta-5, a highly advanced robot designed for menial labor, whirred and beep', - 'The COVID-19 pandemic has had a profound impact on global economic structures and future business models. The', - 'The Mona Lisa, painted by Leonardo da Vinci in the early 16th century, is one of', - 'Here are the translations:\n\n**Japanese:** (Haya aki no tori, nemuri no' - ], - "fp8": [ - 'LLM (Large Language Model) is a type of artificial intelligence (AI) model that is trained', - 'Here are the major milestones in the development of artificial intelligence (AI) from 1950 to ', - 'Artificial intelligence (AI) and human intelligence (HI) differ significantly in how they process information.', - 'A neural network is a complex system made up of several basic components that work together to enable it to', - 'Zeta-5, a highly advanced robot designed for menial labor, had never experienced anything like', - 'The COVID-19 pandemic has had a profound impact on global economic structures and future business models. Here', - 'The Mona Lisa, painted by Leonardo da Vinci in the early 16th century, is one of', - 'Here are the translations:\n\n**Japanese:** (Haya kotori wa mushi o tsuk' - ] - }, - "meta-llama/Meta-Llama-3-8B-Instruct": { - "auto": [ - 'LLM (Large Language Model) is a type of artificial intelligence (AI) model that is trained', - 'Here are the major milestones in the development of artificial intelligence (AI) from 1950 to ', - 'Artificial intelligence (AI) and human intelligence (HI) differ significantly in how they process information.', - 'A neural network is a complex system modeled after the human brain, composed of interconnected nodes or "ne', - 'In the vast, sterile laboratory, Robot 3456-Alpha, or "Alpha" for short', - 'The COVID-19 pandemic has had a profound impact on global economic structures and future business models. The', - 'The Mona Lisa, painted by Leonardo da Vinci in the early 16th century, is one of', - 'Here are the translations:\n\n**Japanese:** (Haya aki wa mushi o tsukamu' - ], - "fp8": [ - 'LLM (Large Language Model) is a type of artificial intelligence (AI) model that is trained', - 'Here are the major milestones in the development of artificial intelligence (AI) from 1950 to ', - 'Artificial intelligence (AI) and human intelligence (HI) differ significantly in how they process information.', - 'A neural network is a complex system modeled after the human brain, consisting of interconnected nodes or "ne', - 'In the year 2154, robotics engineer Dr. Rachel Kim had spent years perfecting her latest', - 'The COVID-19 pandemic has had a profound impact on global economic structures and future business models. The', - 'The Mona Lisa, painted by Leonardo da Vinci in the early 16th century, is one of', - 'Here are the translations:\n\n**Japanese:** (Haya tori, mushi o tsukamu' - ] - }, -} +os.environ["TOKENIZERS_PARALLELISM"] = "true" -# This test compares against golden strings for exact match since -# there is no baseline implementation to compare against -# and is unstable w.r.t specifics of the fp8 implementation or -# the hardware being run on. -# Disabled to prevent it from breaking the build -@pytest.mark.skip( - reason= - "Prevent unstable test based on golden strings from breaking the build.") @pytest.mark.skipif(not is_quant_method_supported("fp8"), reason="fp8 is not supported on this GPU type.") -@pytest.mark.parametrize("model_name", MODELS) -@pytest.mark.parametrize("kv_cache_dtype", ["auto", "fp8"]) -def test_models(example_prompts, model_name, kv_cache_dtype) -> None: - model = LLM(model=model_name, - max_model_len=MAX_MODEL_LEN, - trust_remote_code=True, - enforce_eager=True, - quantization="fp8", - kv_cache_dtype=kv_cache_dtype) +@pytest.mark.parametrize( + "kv_cache_dtype,base_model,test_model,scale_path", + [ + # Test FP8 checkpoint w. fp8_e4m3 kv-cache scaling factors. + ("fp8_e4m3", "meta-llama/Meta-Llama-3-8B-Instruct", + "nm-testing/Meta-Llama-3-8B-Instruct-FP8-KV", None), + # Test FP16 checkpoint w. fp8_e5m2 kv-cache. + ("fp8_e5m2", "meta-llama/Meta-Llama-3-8B-Instruct", + "meta-llama/Meta-Llama-3-8B-Instruct", None), + # Test FP16 checkpoint w. fp8_e4m3 kv-cache scaling factors in json. + ("fp8_e4m3", "meta-llama/Llama-2-7b-chat-hf", + "meta-llama/Llama-2-7b-chat-hf", + "./tests/fp8_kv/llama2-7b-fp8-kv/kv_cache_scales.json") + ]) +# Due to low-precision numerical divergence, we only test logprob of 4 tokens +@pytest.mark.parametrize("max_tokens", [4]) +@pytest.mark.parametrize("enforce_eager", [False, True]) +@pytest.mark.parametrize("backend", ["FLASH_ATTN", "XFORMERS", "FLASHINFER"]) +# NOTE: Increasing this in this suite will fail CI because we currently cannot +# reset distributed env properly. Use a value > 1 just when you test. +@pytest.mark.parametrize("tensor_parallel_size", [1]) +# Due to low-precision numerical divergence, this test is too sensitive for +# the async postprocessor +@pytest.mark.parametrize("disable_async_output_proc", [True]) +def test_models( + vllm_runner, + example_prompts, + kv_cache_dtype: str, + base_model: str, + test_model: str, + scale_path: Optional[str], + max_tokens: int, + enforce_eager: bool, + backend: str, + tensor_parallel_size: int, + disable_async_output_proc: bool, + monkeypatch, +) -> None: + """ + Only checks log probs match to cover the discrepancy in + numerical sensitive kernels. + """ + override_backend_env_variable(monkeypatch, backend) + + MAX_MODEL_LEN = 1024 + NUM_LOG_PROBS = 8 + + with vllm_runner( + base_model, + max_model_len=MAX_MODEL_LEN, + tensor_parallel_size=tensor_parallel_size, + enforce_eager=enforce_eager, + kv_cache_dtype="auto", + disable_async_output_proc=disable_async_output_proc, + ) as vllm_model: + baseline_outputs = vllm_model.generate_greedy_logprobs( + example_prompts, max_tokens, NUM_LOG_PROBS) - tokenizer = AutoTokenizer.from_pretrained(model_name) - formatted_prompts = [ - tokenizer.apply_chat_template([{ - "role": "user", - "content": prompt - }], - tokenize=False, - add_generation_prompt=True) - for prompt in example_prompts - ] + extra_kwargs = {} + if scale_path is not None: + extra_kwargs["quantization_param_path"] = scale_path - params = SamplingParams(max_tokens=20, temperature=0) - generations: List[str] = [] - # Note: these need to be run 1 at a time due to numerical precision, - # since the expected strs were generated this way. - for prompt in formatted_prompts: - outputs = model.generate(prompt, params) - generations.append(outputs[0].outputs[0].text) - del model + with vllm_runner( + test_model, + max_model_len=MAX_MODEL_LEN, + tensor_parallel_size=tensor_parallel_size, + enforce_eager=enforce_eager, + kv_cache_dtype=kv_cache_dtype, + disable_async_output_proc=disable_async_output_proc, + **extra_kwargs, + ) as vllm_model: + test_outputs = vllm_model.generate_greedy_logprobs( + example_prompts, max_tokens, NUM_LOG_PROBS) - print(model_name, kv_cache_dtype, generations) - expected_strs = EXPECTED_STRS_MAP[model_name][kv_cache_dtype] - for i in range(len(example_prompts)): - generated_str = generations[i] - expected_str = expected_strs[i] - assert expected_str == generated_str, ( - f"Test{i}:\nExpected: {expected_str!r}\nvLLM: {generated_str!r}") + check_logprobs_close( + outputs_0_lst=baseline_outputs, + outputs_1_lst=test_outputs, + name_0="fp16_kv_cache", + name_1="fp8_kv_cache", + ) diff --git a/tests/models/test_fp8kv_flashinfer.py b/tests/models/test_fp8kv_flashinfer.py deleted file mode 100644 index ff2a44162b6c3..0000000000000 --- a/tests/models/test_fp8kv_flashinfer.py +++ /dev/null @@ -1,96 +0,0 @@ -# flake8: noqa -"""Tests fp8 models against ground truth generation -This verifies the flashinfer backend with fp8 -quantization and fp8 KV Cache without scaling -factors Note: these tests will only pass on H100 GPU. -""" -import os -from typing import List - -import pytest -from transformers import AutoTokenizer - -from tests.quantization.utils import is_quant_method_supported -from vllm import LLM, SamplingParams - -os.environ["TOKENIZERS_PARALLELISM"] = "true" - -MAX_MODEL_LEN = 1024 - -MODELS = [ - "nm-testing/Meta-Llama-3-8B-Instruct-FP8", -] - -EXPECTED_STRS_MAP = { - "nm-testing/Meta-Llama-3-8B-Instruct-FP8": { - "auto": [ - 'LLaMA is a high-throughput and memory-efficient inference and serving engine for Large Language Models (', - 'Here are the major milestones in the development of artificial intelligence (AI) from 1950 to ', - 'Artificial intelligence (AI) and human intelligence (HI) differ significantly in how they process information.', - 'A neural network is a complex system modeled after the human brain, consisting of interconnected nodes or "ne', - 'In the sterile, metallic halls of the robotics lab, a peculiar phenomenon occurred. Zeta-5', - 'The COVID-19 pandemic has had a profound impact on global economic structures and future business models. The', - 'The Mona Lisa, painted by Leonardo da Vinci in the early 16th century, is one of', - 'Here are the translations:\n\n**Japanese:** (Haya aki no tori, mushi o', - ], - "fp8": [ - 'LLM (Large Language Model) is a type of artificial intelligence (AI) model that is trained', - 'Here are the major milestones in the development of artificial intelligence (AI) from 1950 to ', - 'Artificial intelligence (AI) and human intelligence (HI) differ significantly in how they process information.', - 'A neural network is a complex system modeled after the human brain, composed of interconnected nodes or "ne', - 'Zeta-5, a highly advanced robot designed for menial labor, whirred and beep', - 'The COVID-19 pandemic has had a profound impact on global economic structures and future business models. Here', - 'The Mona Lisa, painted by Leonardo da Vinci in the early 16th century, is one of', - 'Here are the translations:\n\n**Japanese:** (Haya aki no tori, guri o', - ] - } -} - - -# This test compares against golden strings for exact match since -# there is no baseline implementation to compare against -# and is unstable w.r.t specifics of the fp8 implementation or -# the hardware being run on. -# No assert to prevent it from breaking the build -@pytest.mark.skipif(not is_quant_method_supported("fp8"), - reason="fp8 is not supported on this GPU type.") -@pytest.mark.parametrize("model_name", MODELS) -@pytest.mark.parametrize("kv_cache_dtype", ["auto", "fp8"]) -@pytest.mark.parametrize("backend", ["XFORMERS", "FLASHINFER"]) -def test_models(example_prompts, model_name, kv_cache_dtype, backend) -> None: - # Note that the golden strings may not work for FLASHINFER Backend. - # The intention is to test the path - os.environ["VLLM_ATTENTION_BACKEND"] = backend - model = LLM(model=model_name, - max_model_len=MAX_MODEL_LEN, - trust_remote_code=True, - quantization="fp8", - kv_cache_dtype=kv_cache_dtype) - - tokenizer = AutoTokenizer.from_pretrained(model_name) - formatted_prompts = [ - tokenizer.apply_chat_template([{ - "role": "user", - "content": prompt - }], - tokenize=False, - add_generation_prompt=True) - for prompt in example_prompts - ] - - params = SamplingParams(max_tokens=20, temperature=0) - generations: List[str] = [] - # Note: these need to be run 1 at a time due to numerical precision, - # since the expected strs were generated this way. - for prompt in formatted_prompts: - outputs = model.generate(prompt, params) - generations.append(outputs[0].outputs[0].text) - del model - - print(f"Testing: {model_name} with kv_cache_dtype: {kv_cache_dtype}") - expected_strs = EXPECTED_STRS_MAP[model_name][kv_cache_dtype] - for i in range(len(example_prompts)): - generated_str = generations[i] - expected_str = expected_strs[i] - print(f"generated_str\n: {generated_str}") - print(f"expected_str\n: {expected_str}") From d1dec6424307a6070bf3ab1700633996f20ef248 Mon Sep 17 00:00:00 2001 From: alexeykondrat <143633163+alexeykondrat@users.noreply.github.com> Date: Wed, 4 Sep 2024 14:57:54 -0400 Subject: [PATCH 08/77] [CI/Build][ROCm] Enabling LoRA tests on ROCm (#7369) Co-authored-by: Simon Mo --- .buildkite/run-amd-test.sh | 47 +++++++++++++++++++++++++++++----- .buildkite/test-pipeline.yaml | 3 +-- tests/lora/test_gemma.py | 4 +++ tests/lora/test_quant_model.py | 24 ++++++++++++----- 4 files changed, 64 insertions(+), 14 deletions(-) mode change 100644 => 100755 .buildkite/run-amd-test.sh diff --git a/.buildkite/run-amd-test.sh b/.buildkite/run-amd-test.sh old mode 100644 new mode 100755 index 5548071390aff..972c62a091aea --- a/.buildkite/run-amd-test.sh +++ b/.buildkite/run-amd-test.sh @@ -1,5 +1,5 @@ # This script runs test inside the corresponding ROCm docker container. -set -ex +set -o pipefail # Print ROCm version echo "--- Confirming Clean Initial State" @@ -70,16 +70,51 @@ HF_CACHE="$(realpath ~)/huggingface" mkdir -p ${HF_CACHE} HF_MOUNT="/root/.cache/huggingface" -docker run \ +commands=$@ +PARALLEL_JOB_COUNT=8 +# check if the command contains shard flag, we will run all shards in parallel because the host have 8 GPUs. +if [[ $commands == *"--shard-id="* ]]; then + for GPU in $(seq 0 $(($PARALLEL_JOB_COUNT-1))); do + #replace shard arguments + commands=${@//"--shard-id= "/"--shard-id=${GPU} "} + commands=${commands//"--num-shards= "/"--num-shards=${PARALLEL_JOB_COUNT} "} + docker run \ --device /dev/kfd --device /dev/dri \ --network host \ --shm-size=16gb \ --rm \ - -e HIP_VISIBLE_DEVICES=0 \ + -e HIP_VISIBLE_DEVICES=${GPU} \ -e HF_TOKEN \ -v ${HF_CACHE}:${HF_MOUNT} \ -e HF_HOME=${HF_MOUNT} \ - --name ${container_name} \ + --name ${container_name}_${GPU} \ ${image_name} \ - /bin/bash -c "${@}" - + /bin/bash -c "${commands}" \ + |& while read -r line; do echo ">>Shard $GPU: $line"; done & + PIDS+=($!) + done + #wait for all processes to finish and collect exit codes + for pid in ${PIDS[@]}; do + wait ${pid} + STATUS+=($?) + done + for st in ${STATUS[@]}; do + if [[ ${st} -ne 0 ]]; then + echo "One of the processes failed with $st" + exit ${st} + fi + done +else + docker run \ + --device /dev/kfd --device /dev/dri \ + --network host \ + --shm-size=16gb \ + --rm \ + -e HIP_VISIBLE_DEVICES=0 \ + -e HF_TOKEN \ + -v ${HF_CACHE}:${HF_MOUNT} \ + -e HF_HOME=${HF_MOUNT} \ + --name ${container_name} \ + ${image_name} \ + /bin/bash -c "${commands}" +fi diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 86eddb576c42a..65e1862ce8181 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -218,9 +218,9 @@ steps: - pytest -v -s spec_decode - label: LoRA Test %N # 30min each + mirror_hardwares: [amd] source_file_dependencies: - vllm/lora - - csrc/punica - tests/lora command: pytest -v -s lora --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT --ignore=lora/test_long_context.py parallelism: 4 @@ -360,7 +360,6 @@ steps: num_gpus: 4 source_file_dependencies: - vllm/lora - - csrc/punica - tests/lora/test_long_context commands: # FIXIT: find out which code initialize cuda before running the test diff --git a/tests/lora/test_gemma.py b/tests/lora/test_gemma.py index 709246179bfe4..58cac3156c9c1 100644 --- a/tests/lora/test_gemma.py +++ b/tests/lora/test_gemma.py @@ -1,7 +1,10 @@ from typing import List +import pytest + import vllm from vllm.lora.request import LoRARequest +from vllm.utils import is_hip MODEL_PATH = "google/gemma-7b" @@ -28,6 +31,7 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]: return generated_texts +@pytest.mark.xfail(is_hip(), reason="There can be output mismatch on ROCm") def test_gemma_lora(gemma_lora_files): llm = vllm.LLM(MODEL_PATH, max_model_len=1024, diff --git a/tests/lora/test_quant_model.py b/tests/lora/test_quant_model.py index 2370c693e9534..133e0d4514a6d 100644 --- a/tests/lora/test_quant_model.py +++ b/tests/lora/test_quant_model.py @@ -7,6 +7,7 @@ import vllm from vllm.lora.request import LoRARequest +from vllm.utils import is_hip from .conftest import cleanup @@ -17,12 +18,23 @@ class ModelWithQuantization: quantization: str -MODELS: List[ModelWithQuantization] = [ - ModelWithQuantization(model_path="TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ", - quantization="AWQ"), - ModelWithQuantization(model_path="TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ", - quantization="GPTQ"), -] +MODELS: List[ModelWithQuantization] +#AWQ quantization is currently not supported in ROCm. +if is_hip(): + MODELS = [ + ModelWithQuantization( + model_path="TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ", + quantization="GPTQ"), + ] +else: + MODELS = [ + ModelWithQuantization( + model_path="TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ", + quantization="AWQ"), + ModelWithQuantization( + model_path="TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ", + quantization="GPTQ"), + ] def do_sample(llm: vllm.LLM, From 561d6f8077c54c7af5dbf2ed92131ce9f7d9b56b Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Wed, 4 Sep 2024 13:05:50 -0700 Subject: [PATCH 09/77] [CI] Change test input in Gemma LoRA test (#8163) --- tests/lora/test_gemma.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/lora/test_gemma.py b/tests/lora/test_gemma.py index 58cac3156c9c1..f7c1d4f041c12 100644 --- a/tests/lora/test_gemma.py +++ b/tests/lora/test_gemma.py @@ -13,7 +13,7 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]: prompts = [ "Quote: Imagination is", "Quote: Be yourself;", - "Quote: So many books,", + "Quote: Painting is poetry that is seen rather than felt,", ] sampling_params = vllm.SamplingParams(temperature=0, max_tokens=32) outputs = llm.generate( @@ -41,7 +41,8 @@ def test_gemma_lora(gemma_lora_files): expected_lora_output = [ "more important than knowledge.\nAuthor: Albert Einstein\n", "everyone else is already taken.\nAuthor: Oscar Wilde\n", - "so little time\nAuthor: Frank Zappa\n", + "and poetry is painting that is felt rather than seen.\n" + "Author: Leonardo da Vinci\n", ] output1 = do_sample(llm, gemma_lora_files, lora_id=1) From e02ce498be2e11a165803d4590588ba98f129797 Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Wed, 4 Sep 2024 15:18:13 -0500 Subject: [PATCH 10/77] [Feature] OpenAI-Compatible Tools API + Streaming for Hermes & Mistral models (#5649) Co-authored-by: constellate Co-authored-by: Kyle Mistele --- .buildkite/test-pipeline.yaml | 10 + .../serving/openai_compatible_server.md | 58 ++- ...penai_chat_completion_client_with_tools.py | 162 +++++++++ examples/tool_chat_template_hermes.jinja | 129 +++++++ examples/tool_chat_template_mistral.jinja | 86 +++++ .../tool_chat_template_mistral_parallel.jinja | 94 +++++ requirements-common.txt | 1 + tests/tool_use/__init__.py | 0 tests/tool_use/conftest.py | 32 ++ tests/tool_use/test_chat_completions.py | 143 ++++++++ tests/tool_use/test_parallel_tool_calls.py | 193 ++++++++++ tests/tool_use/test_tool_calls.py | 192 ++++++++++ tests/tool_use/utils.py | 215 +++++++++++ vllm/entrypoints/chat_utils.py | 101 ++++- vllm/entrypoints/openai/api_server.py | 8 +- vllm/entrypoints/openai/cli_args.py | 18 + vllm/entrypoints/openai/protocol.py | 125 ++++++- vllm/entrypoints/openai/serving_chat.py | 275 ++++++++++++-- .../openai/serving_tokenization.py | 6 +- .../openai/tool_parsers/__init__.py | 5 + .../tool_parsers/abstract_tool_parser.py | 58 +++ .../openai/tool_parsers/hermes_tool_parser.py | 344 ++++++++++++++++++ .../tool_parsers/mistral_tool_parser.py | 293 +++++++++++++++ vllm/entrypoints/openai/tool_parsers/utils.py | 87 +++++ .../guided_decoding/__init__.py | 5 +- .../guided_decoding/outlines_decoding.py | 31 +- 26 files changed, 2588 insertions(+), 83 deletions(-) create mode 100644 examples/openai_chat_completion_client_with_tools.py create mode 100644 examples/tool_chat_template_hermes.jinja create mode 100644 examples/tool_chat_template_mistral.jinja create mode 100644 examples/tool_chat_template_mistral_parallel.jinja create mode 100644 tests/tool_use/__init__.py create mode 100644 tests/tool_use/conftest.py create mode 100644 tests/tool_use/test_chat_completions.py create mode 100644 tests/tool_use/test_parallel_tool_calls.py create mode 100644 tests/tool_use/test_tool_calls.py create mode 100644 tests/tool_use/utils.py create mode 100644 vllm/entrypoints/openai/tool_parsers/__init__.py create mode 100644 vllm/entrypoints/openai/tool_parsers/abstract_tool_parser.py create mode 100644 vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py create mode 100644 vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py create mode 100644 vllm/entrypoints/openai/tool_parsers/utils.py diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 65e1862ce8181..d50d8f32a816d 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -92,6 +92,7 @@ steps: - pytest -v -s entrypoints/openai - pytest -v -s entrypoints/test_chat_utils.py + - label: Distributed Tests (4 GPUs) # 10min working_dir: "/vllm-workspace/tests" num_gpus: 4 @@ -271,6 +272,15 @@ steps: - export VLLM_WORKER_MULTIPROC_METHOD=spawn - bash ./run-tests.sh -c configs/models-small.txt -t 1 +- label: OpenAI-Compatible Tool Use # 20 min + fast_check: false + mirror_hardwares: [ amd ] + source_file_dependencies: + - vllm/ + - tests/tool_use + commands: + - pytest -v -s tool_use + ##### 1 GPU test ##### ##### multi gpus test ##### diff --git a/docs/source/serving/openai_compatible_server.md b/docs/source/serving/openai_compatible_server.md index b2acde390083c..eb4ea0fb5655e 100644 --- a/docs/source/serving/openai_compatible_server.md +++ b/docs/source/serving/openai_compatible_server.md @@ -110,6 +110,14 @@ directory [here](https://github.com/vllm-project/vllm/tree/main/examples/) :func: create_parser_for_docs :prog: vllm serve ``` +## Tool Calling in the Chat Completion API +### Named Function Calling +vLLM supports only named function calling in the chat completion API by default. It does so using Outlines, so this is +enabled by default, and will work with any supported model. You are guaranteed a validly-parsable function call - not a +high-quality one. + +To use a named function, you need to define the functions in the `tools` parameter of the chat completion request, and +specify the `name` of one of the tools in the `tool_choice` parameter of the chat completion request. ### Config file @@ -140,10 +148,52 @@ The order of priorities is `command line > config file values > defaults`. ## Tool calling in the chat completion API vLLM supports only named function calling in the chat completion API. The `tool_choice` options `auto` and `required` are **not yet supported** but on the roadmap. -To use a named function you need to define the function in the `tools` parameter and call it in the `tool_choice` parameter. - -It is the callers responsibility to prompt the model with the tool information, vLLM will not automatically manipulate the prompt. **This may change in the future.** +It is the callers responsibility to prompt the model with the tool information, vLLM will not automatically manipulate the prompt. vLLM will use guided decoding to ensure the response matches the tool parameter object defined by the JSON schema in the `tools` parameter. -Please refer to the OpenAI API reference documentation for more information. + +### Automatic Function Calling +To enable this feature, you should set the following flags: +* `--enable-auto-tool-choice` -- **mandatory** Auto tool choice. tells vLLM that you want to enable the model to generate its own tool calls when it +deems appropriate. +* `--tool-call-parser` -- select the tool parser to use - currently either `hermes` or `mistral`. Additional tool parsers +will continue to be added in the future. +* `--chat-template` -- **optional** for auto tool choice. the path to the chat template which handles `tool`-role messages and `assistant`-role messages +that contain previously generated tool calls. Hermes and Mistral models have tool-compatible chat templates in their +`tokenizer_config.json` files, but you can specify a custom template. This argument can be set to `tool_use` if your model has a tool use-specific chat +template configured in the `tokenizer_config.json`. In this case, it will be used per the `transformers` specification. More on this [here](https://huggingface.co/docs/transformers/en/chat_templating#why-do-some-models-have-multiple-templates) +from HuggingFace; and you can find an example of this in a `tokenizer_config.json` [here](https://huggingface.co/NousResearch/Hermes-2-Pro-Llama-3-8B/blob/main/tokenizer_config.json) + +If your favorite tool-calling model is not supported, please feel free to contribute a parser & tool use chat template! + +#### Hermes Models +All Nous Research Hermes-series models newer than Hermes 2 Pro should be supported. +* `NousResearch/Hermes-2-Pro-*` +* `NousResearch/Hermes-2-Theta-*` +* `NousResearch/Hermes-3-*` + + +_Note that the Hermes 2 **Theta** models are known to have degraded tool call quality & capabilities due to the merge +step in their creation_. + +Flags: `--tool-call-parser hermes` + +#### Mistral Models +Supported models: +* `mistralai/Mistral-7B-Instruct-v0.3` (confirmed) +* Additional mistral function-calling models are compatible as well. + +Known issues: +1. Mistral 7B struggles to generate parallel tool calls correctly. +2. Mistral's `tokenizer_config.json` chat template requires tool call IDs that are exactly 9 digits, which is +much shorter than what vLLM generates. Since an exception is thrown when this condition +is not met, the following additional chat templates are provided: + +* `examples/tool_chat_template_mistral.jinja` - this is the "official" Mistral chat template, but tweaked so that +it works with vLLM's tool call IDs (provided `tool_call_id` fields are truncated to the last 9 digits) +* `examples/tool_chat_template_mistral_parallel.jinja` - this is a "better" version that adds a tool-use system prompt +when tools are provided, that results in much better reliability when working with parallel tool calling. + + +Recommended flags: `--tool-call-parser mistral --chat-template examples/tool_chat_template_mistral_parallel.jinja` diff --git a/examples/openai_chat_completion_client_with_tools.py b/examples/openai_chat_completion_client_with_tools.py new file mode 100644 index 0000000000000..2bbe42b6bd2ef --- /dev/null +++ b/examples/openai_chat_completion_client_with_tools.py @@ -0,0 +1,162 @@ +""" +Set up this example by starting a vLLM OpenAI-compatible server with tool call +options enabled. For example: + +IMPORTANT: for mistral, you must use one of the provided mistral tool call +templates, or your own - the model default doesn't work for tool calls with vLLM +See the vLLM docs on OpenAI server & tool calling for more details. + +vllm serve --model mistralai/Mistral-7B-Instruct-v0.3 \ + --chat-template examples/tool_chat_template_mistral.jinja \ + --enable-auto-tool-choice --tool-call-parser mistral + +OR +vllm serve --model NousResearch/Hermes-2-Pro-Llama-3-8B \ + --chat-template examples/tool_chat_template_hermes.jinja \ + --enable-auto-tool-choice --tool-call-parser hermes +""" +import json + +from openai import OpenAI + +# Modify OpenAI's API key and API base to use vLLM's API server. +openai_api_key = "EMPTY" +openai_api_base = "http://localhost:8000/v1" + +client = OpenAI( + # defaults to os.environ.get("OPENAI_API_KEY") + api_key=openai_api_key, + base_url=openai_api_base, +) + +models = client.models.list() +model = models.data[0].id + +tools = [{ + "type": "function", + "function": { + "name": "get_current_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type": "object", + "properties": { + "city": { + "type": + "string", + "description": + "The city to find the weather for, e.g. 'San Francisco'" + }, + "state": { + "type": + "string", + "description": + "the two-letter abbreviation for the state that the city is" + " in, e.g. 'CA' which would mean 'California'" + }, + "unit": { + "type": "string", + "description": "The unit to fetch the temperature in", + "enum": ["celsius", "fahrenheit"] + } + }, + "required": ["city", "state", "unit"] + } + } +}] + +messages = [{ + "role": "user", + "content": "Hi! How are you doing today?" +}, { + "role": "assistant", + "content": "I'm doing well! How can I help you?" +}, { + "role": + "user", + "content": + "Can you tell me what the temperate will be in Dallas, in fahrenheit?" +}] + +chat_completion = client.chat.completions.create(messages=messages, + model=model, + tools=tools) + +print("Chat completion results:") +print(chat_completion) +print("\n\n") + +tool_calls_stream = client.chat.completions.create(messages=messages, + model=model, + tools=tools, + stream=True) + +chunks = [] +for chunk in tool_calls_stream: + chunks.append(chunk) + if chunk.choices[0].delta.tool_calls: + print(chunk.choices[0].delta.tool_calls[0]) + else: + print(chunk.choices[0].delta) + +arguments = [] +tool_call_idx = -1 +for chunk in chunks: + + if chunk.choices[0].delta.tool_calls: + tool_call = chunk.choices[0].delta.tool_calls[0] + + if tool_call.index != tool_call_idx: + if tool_call_idx >= 0: + print( + f"streamed tool call arguments: {arguments[tool_call_idx]}" + ) + tool_call_idx = chunk.choices[0].delta.tool_calls[0].index + arguments.append("") + if tool_call.id: + print(f"streamed tool call id: {tool_call.id} ") + + if tool_call.function: + if tool_call.function.name: + print(f"streamed tool call name: {tool_call.function.name}") + + if tool_call.function.arguments: + arguments[tool_call_idx] += tool_call.function.arguments + +if len(arguments): + print(f"streamed tool call arguments: {arguments[-1]}") + +print("\n\n") + +messages.append({ + "role": "assistant", + "tool_calls": chat_completion.choices[0].message.tool_calls +}) + + +# Now, simulate a tool call +def get_current_weather(city: str, state: str, unit: 'str'): + return ("The weather in Dallas, Texas is 85 degrees fahrenheit. It is " + "partly cloudly, with highs in the 90's.") + + +available_tools = {"get_current_weather": get_current_weather} + +completion_tool_calls = chat_completion.choices[0].message.tool_calls +for call in completion_tool_calls: + tool_to_call = available_tools[call.function.name] + args = json.loads(call.function.arguments) + result = tool_to_call(**args) + print(result) + messages.append({ + "role": "tool", + "content": result, + "tool_call_id": call.id, + "name": call.function.name + }) + +chat_completion_2 = client.chat.completions.create(messages=messages, + model=model, + tools=tools, + stream=False) +print("\n\n") +print(chat_completion_2) diff --git a/examples/tool_chat_template_hermes.jinja b/examples/tool_chat_template_hermes.jinja new file mode 100644 index 0000000000000..b18b463032d4f --- /dev/null +++ b/examples/tool_chat_template_hermes.jinja @@ -0,0 +1,129 @@ +{%- macro json_to_python_type(json_spec) %} + {%- set basic_type_map = { + "string": "str", + "number": "float", + "integer": "int", + "boolean": "bool" +} %} + + {%- if basic_type_map[json_spec.type] is defined %} + {{- basic_type_map[json_spec.type] }} + {%- elif json_spec.type == "array" %} + {{- "list[" + json_to_python_type(json_spec|items) + "]" }} + {%- elif json_spec.type == "object" %} + {%- if json_spec.additionalProperties is defined %} + {{- "dict[str, " + json_to_python_type(json_spec.additionalProperties) + ']' }} + {%- else %} + {{- "dict" }} + {%- endif %} + {%- elif json_spec.type is iterable %} + {{- "Union[" }} + {%- for t in json_spec.type %} + {{- json_to_python_type({"type": t}) }} + {%- if not loop.last %} + {{- "," }} + {%- endif %} + {%- endfor %} + {{- "]" }} + {%- else %} + {{- "Any" }} + {%- endif %} +{%- endmacro %} + + +{{- bos_token }} +{{- "<|im_start|>system\nYou are a function calling AI model. You are provided with function signatures within XML tags. You may call one or more functions to assist with the user query. Don't make assumptions about what values to plug into functions. Here are the available tools: " }} +{%- if tools is iterable and tools | length > 0 %} + {%- for tool in tools %} + {%- if tool.function is defined %} + {%- set tool = tool.function %} + {%- endif %} + {{- '{"type": "function", "function": ' }} + {{- '{"name": "' + tool.name + '", ' }} + {{- '"description": "' + tool.name + '(' }} + {%- for param_name, param_fields in tool.parameters.properties|items %} + {{- param_name + ": " + json_to_python_type(param_fields) }} + {%- if not loop.last %} + {{- ", " }} + {%- endif %} + {%- endfor %} + {{- ")" }} + {%- if tool.return is defined %} + {{- " -> " + json_to_python_type(tool.return) }} + {%- endif %} + {{- " - " + tool.description + "\n\n" }} + {%- for param_name, param_fields in tool.parameters.properties|items %} + {%- if loop.first %} + {{- " Args:\n" }} + {%- endif %} + {{- " " + param_name + "(" + json_to_python_type(param_fields) + "): " + param_fields.description|trim }} + {%- endfor %} + {%- if tool.return is defined and tool.return.description is defined %} + {{- "\n Returns:\n " + tool.return.description }} + {%- endif %} + {{- '"' }} + {{- ', "parameters": ' }} + {%- if tool.parameters.properties | length == 0 %} + {{- "{}" }} + {%- else %} + {{- tool.parameters|tojson }} + {%- endif %} + {{- "}" }} + {%- if not loop.last %} + {{- "\n" }} + {%- endif %} + {%- endfor %} +{%- endif %} +{{- " " }} +{{- 'Use the following pydantic model json schema for each tool call you will make: {"properties": {"name": {"title": "Name", "type": "string"}, "arguments": {"title": "Arguments", "type": "object"}}, "required": ["name", "arguments"], "title": "FunctionCall", "type": "object"}} +' }} +{{- "For each function call return a json object with function name and arguments within XML tags as follows: +" }} +{{- " +" }} +{{- '{"name": , "arguments": } +' }} +{{- '<|im_end|>' }} +{%- for message in messages %} + {%- if message.role == "user" or message.role == "system" or (message.role == "assistant" and message.tool_calls is not defined) %} + {{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>' + '\n' }} + {%- elif message.role == "assistant" and message.tool_calls is defined %} + {{- '<|im_start|>' + message.role }} + {%- for tool_call in message.tool_calls %} + {{- '\n\n' }} + {%- if tool_call.function is defined %} + {%- set tool_call = tool_call.function %} + {%- endif %} + {{- '{' }} + {{- '"name": "' }} + {{- tool_call.name }} + {{- '"}' }} + {{- ', ' }} + {%- if tool_call.arguments is defined %} + {{- '"arguments": ' }} + {{- tool_call.arguments|tojson }} + {%- endif %} + {{- '\n' }} + {%- endfor %} + {{- '<|im_end|>\n' }} + {%- elif message.role == "tool" %} + {%- if loop.previtem and loop.previtem.role != "tool" %} + {{- '<|im_start|>tool\n' }} + {%- endif %} + {{- '\n' }} + {{- message.content }} + {%- if not loop.last %} + {{- '\n\n' }} + {%- else %} + {{- '\n' }} + {%- endif %} + {%- if not loop.last and loop.nextitem.role != "tool" %} + {{- '<|im_end|>' }} + {%- elif loop.last %} + {{- '<|im_end|>' }} + {%- endif %} + {%- endif %} +{%- endfor %} +{%- if add_generation_prompt %} + {{- '<|im_start|>assistant\n' }} +{%- endif %} diff --git a/examples/tool_chat_template_mistral.jinja b/examples/tool_chat_template_mistral.jinja new file mode 100644 index 0000000000000..49691f59c2f2c --- /dev/null +++ b/examples/tool_chat_template_mistral.jinja @@ -0,0 +1,86 @@ +{%- if messages[0]["role"] == "system" %} + {%- set system_message = messages[0]["content"] %} + {%- set loop_messages = messages[1:] %} +{%- else %} + {%- set loop_messages = messages %} +{%- endif %} +{%- if not tools is defined %} + {%- set tools = none %} +{%- endif %} +{%- set user_messages = loop_messages | selectattr("role", "equalto", "user") | list %} + +{%- for message in loop_messages | rejectattr("role", "equalto", "tool") | rejectattr("role", "equalto", "tool_results") | selectattr("tool_calls", "undefined") %} + {%- if (message["role"] == "user") != (loop.index0 % 2 == 0) %} + {{- raise_exception("After the optional system message, conversation roles must alternate user/assistant/user/assistant/...") }} + {%- endif %} +{%- endfor %} + +{{- bos_token }} +{%- for message in loop_messages %} + {%- if message["role"] == "user" %} + {%- if tools is not none and (message == user_messages[-1]) %} + {{- "[AVAILABLE_TOOLS] [" }} + {%- for tool in tools %} + {%- set tool = tool.function %} + {{- '{"type": "function", "function": {' }} + {%- for key, val in tool.items() if key != "return" %} + {%- if val is string %} + {{- '"' + key + '": "' + val + '"' }} + {%- else %} + {{- '"' + key + '": ' + val|tojson }} + {%- endif %} + {%- if not loop.last %} + {{- ", " }} + {%- endif %} + {%- endfor %} + {{- "}}" }} + {%- if not loop.last %} + {{- ", " }} + {%- else %} + {{- "]" }} + {%- endif %} + {%- endfor %} + {{- "[/AVAILABLE_TOOLS]" }} + {%- endif %} + {%- if loop.last and system_message is defined %} + {{- "[INST] " + system_message + "\n\n" + message["content"] + "[/INST]" }} + {%- else %} + {{- "[INST] " + message["content"] + "[/INST]" }} + {%- endif %} + {%- elif message["role"] == "tool_calls" or message.tool_calls is defined %} + {%- if message.tool_calls is defined %} + {%- set tool_calls = message.tool_calls %} + {%- else %} + {%- set tool_calls = message.content %} + {%- endif %} + {{- "[TOOL_CALLS] [" }} + {%- for tool_call in tool_calls %} + {%- set out = tool_call.function|tojson %} + {{- out[:-1] }} + {%- if not tool_call.id is defined or tool_call.id|length < 9 %} + {{- raise_exception("Tool call IDs should be alphanumeric strings with length >= 9! (1)" + tool_call.id) }} + {%- endif %} + {{- ', "id": "' + tool_call.id[-9:] + '"}' }} + {%- if not loop.last %} + {{- ", " }} + {%- else %} + {{- "]" + eos_token }} + {%- endif %} + {%- endfor %} + {%- elif message["role"] == "assistant" %} + {{- " " + message["content"] + eos_token }} + {%- elif message["role"] == "tool_results" or message["role"] == "tool" %} + {%- if message.content is defined and message.content.content is defined %} + {%- set content = message.content.content %} + {%- else %} + {%- set content = message.content %} + {%- endif %} + {{- '[TOOL_RESULTS] {"content": ' + content|string + ", " }} + {%- if not message.tool_call_id is defined or message.tool_call_id|length < 9 %} + {{- raise_exception("Tool call IDs should be alphanumeric strings with length >= 9! (2)" + message.tool_call_id) }} + {%- endif %} + {{- '"call_id": "' + message.tool_call_id[-9:] + '"}[/TOOL_RESULTS]' }} + {%- else %} + {{- raise_exception("Only user and assistant roles are supported, with the exception of an initial optional system message!") }} + {%- endif %} +{%- endfor %} diff --git a/examples/tool_chat_template_mistral_parallel.jinja b/examples/tool_chat_template_mistral_parallel.jinja new file mode 100644 index 0000000000000..a294cbfd026be --- /dev/null +++ b/examples/tool_chat_template_mistral_parallel.jinja @@ -0,0 +1,94 @@ +{%- if messages[0]["role"] == "system" %} + {%- set system_message = messages[0]["content"] %} + {%- set loop_messages = messages[1:] %} +{%- else %} + {%- set loop_messages = messages %} +{%- endif %} +{%- if not tools is defined %} + {%- set tools = none %} +{%- endif %} +{%- if tools is defined %} + {%- set parallel_tool_prompt = "You are a helpful assistant that can call tools. If you call one or more tools, format them in a single JSON array or objects, where each object is a tool call, not as separate objects outside of an array or multiple arrays. Use the format [{\"name\": tool call name, \"arguments\": tool call arguments}, additional tool calls] if you call more than one tool. If you call tools, do not attempt to interpret them or otherwise provide a response until you receive a tool call result that you can interpret for the user." %} + {%- if system_message is defined %} + {%- set system_message = parallel_tool_prompt + "\n\n" + system_message %} + {%- else %} + {%- set system_message = parallel_tool_prompt %} + {%- endif %} +{%- endif %} +{%- set user_messages = loop_messages | selectattr("role", "equalto", "user") | list %} + +{%- for message in loop_messages | rejectattr("role", "equalto", "tool") | rejectattr("role", "equalto", "tool_results") | selectattr("tool_calls", "undefined") %} + {%- if (message["role"] == "user") != (loop.index0 % 2 == 0) %} + {{- raise_exception("After the optional system message, conversation roles must alternate user/assistant/user/assistant/...") }} + {%- endif %} +{%- endfor %} + +{{- bos_token }} +{%- for message in loop_messages %} + {%- if message["role"] == "user" %} + {%- if tools is not none and (message == user_messages[-1]) %} + {{- "[AVAILABLE_TOOLS] [" }} + {%- for tool in tools %} + {%- set tool = tool.function %} + {{- '{"type": "function", "function": {' }} + {%- for key, val in tool.items() if key != "return" %} + {%- if val is string %} + {{- '"' + key + '": "' + val + '"' }} + {%- else %} + {{- '"' + key + '": ' + val|tojson }} + {%- endif %} + {%- if not loop.last %} + {{- ", " }} + {%- endif %} + {%- endfor %} + {{- "}}" }} + {%- if not loop.last %} + {{- ", " }} + {%- else %} + {{- "]" }} + {%- endif %} + {%- endfor %} + {{- "[/AVAILABLE_TOOLS]" }} + {%- endif %} + {%- if loop.last and system_message is defined %} + {{- "[INST] " + system_message + "\n\n" + message["content"] + "[/INST]" }} + {%- else %} + {{- "[INST] " + message["content"] + "[/INST]" }} + {%- endif %} + {%- elif message["role"] == "tool_calls" or message.tool_calls is defined %} + {%- if message.tool_calls is defined %} + {%- set tool_calls = message.tool_calls %} + {%- else %} + {%- set tool_calls = message.content %} + {%- endif %} + {{- "[TOOL_CALLS] [" }} + {%- for tool_call in tool_calls %} + {%- set out = tool_call.function|tojson %} + {{- out[:-1] }} + {%- if not tool_call.id is defined or tool_call.id|length < 9 %} + {{- raise_exception("Tool call IDs should be alphanumeric strings with length >= 9! (1)" + tool_call.id) }} + {%- endif %} + {{- ', "id": "' + tool_call.id[-9:] + '"}' }} + {%- if not loop.last %} + {{- ", " }} + {%- else %} + {{- "]" + eos_token }} + {%- endif %} + {%- endfor %} + {%- elif message["role"] == "assistant" %} + {{- " " + message["content"] + eos_token }} + {%- elif message["role"] == "tool_results" or message["role"] == "tool" %} + {%- if message.content is defined and message.content.content is defined %} + {%- set content = message.content.content %} + {%- else %} + {%- set content = message.content %} + {%- endif %} + {{- '[TOOL_RESULTS] {"content": ' + content|string + ", " }} + {%- if not message.tool_call_id is defined or message.tool_call_id|length < 9 %} + {{- raise_exception("Tool call IDs should be alphanumeric strings with length >= 9! (2)" + message.tool_call_id) }} + {%- endif %} + {{- '"call_id": "' + message.tool_call_id[-9:] + '"}[/TOOL_RESULTS]' }} + {%- else %} + {{- raise_exception("Only user and assistant roles are supported, with the exception of an initial optional system message!") }} + {%- endif %} +{%- endfor %} diff --git a/requirements-common.txt b/requirements-common.txt index 4c5b681a0d5ab..447fd32311c09 100644 --- a/requirements-common.txt +++ b/requirements-common.txt @@ -20,6 +20,7 @@ lm-format-enforcer == 0.10.6 outlines >= 0.0.43, < 0.1 # Requires torch >= 2.1.0 typing_extensions >= 4.10 filelock >= 3.10.4 # filelock starts to support `mode` argument from 3.10.4 +partial-json-parser # used for parsing partial JSON outputs pyzmq msgspec gguf == 0.9.1 diff --git a/tests/tool_use/__init__.py b/tests/tool_use/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/tool_use/conftest.py b/tests/tool_use/conftest.py new file mode 100644 index 0000000000000..ab6a29eba1b3f --- /dev/null +++ b/tests/tool_use/conftest.py @@ -0,0 +1,32 @@ +import pytest +import pytest_asyncio +from huggingface_hub import snapshot_download + +from tests.utils import RemoteOpenAIServer + +from .utils import ARGS, CONFIGS, ServerConfig + + +# for each server config, download the model and return the config +@pytest.fixture(scope="session", params=CONFIGS.keys()) +def server_config(request): + config = CONFIGS[request.param] + # download model and tokenizer using transformers + snapshot_download(config["model"]) + yield CONFIGS[request.param] + + +# run this for each server config +@pytest.fixture(scope="session") +def server(request, server_config: ServerConfig): + model = server_config["model"] + args_for_model = server_config["arguments"] + with RemoteOpenAIServer(model, ARGS + args_for_model, + max_wait_seconds=480) as server: + yield server + + +@pytest_asyncio.fixture +async def client(server: RemoteOpenAIServer): + async with server.get_async_client() as async_client: + yield async_client diff --git a/tests/tool_use/test_chat_completions.py b/tests/tool_use/test_chat_completions.py new file mode 100644 index 0000000000000..038ff81d2b674 --- /dev/null +++ b/tests/tool_use/test_chat_completions.py @@ -0,0 +1,143 @@ +from typing import List + +import openai +import pytest + +from .utils import MESSAGES_WITHOUT_TOOLS, WEATHER_TOOL + + +# test: make sure chat completions without tools provided work even when tools +# are enabled. This makes sure tool call chat templates work, AND that the tool +# parser stream processing doesn't change the output of the model. +@pytest.mark.asyncio +async def test_chat_completion_without_tools(client: openai.AsyncOpenAI): + models = await client.models.list() + model_name: str = models.data[0].id + chat_completion = await client.chat.completions.create( + messages=MESSAGES_WITHOUT_TOOLS, + temperature=0, + max_tokens=150, + model=model_name, + logprobs=False) + choice = chat_completion.choices[0] + stop_reason = chat_completion.choices[0].finish_reason + output_text = chat_completion.choices[0].message.content + + # check to make sure we got text + assert output_text is not None + assert len(output_text) > 0 + assert stop_reason != "tool_calls" + + # check to make sure no tool calls were returned + assert (choice.message.tool_calls is None + or len(choice.message.tool_calls) == 0) + + # make the same request, streaming + stream = await client.chat.completions.create( + messages=MESSAGES_WITHOUT_TOOLS, + temperature=0, + max_tokens=150, + model=model_name, + logprobs=False, + stream=True, + ) + chunks: List[str] = [] + finish_reason_count = 0 + role_sent: bool = False + + # assemble streamed chunks + async for chunk in stream: + delta = chunk.choices[0].delta + + # make sure the role is assistant + if delta.role: + assert not role_sent + assert delta.role == 'assistant' + role_sent = True + + if delta.content: + chunks.append(delta.content) + + if chunk.choices[0].finish_reason is not None: + finish_reason_count += 1 + assert chunk.choices[0].finish_reason == choice.finish_reason + + # make sure tool call chunks aren't being streamed + assert not delta.tool_calls or len(delta.tool_calls) == 0 + + # make sure the role was sent, only 1 finish reason was sent, that chunks + # were in fact sent, and that the chunks match non-streaming + assert role_sent + assert finish_reason_count == 1 + assert len(chunks) + assert "".join(chunks) == output_text + + +# test: conversation with tools enabled and provided that should not invoke +# tools, to make sure we can still get normal chat completion responses +# and that they won't be parsed as tools +@pytest.mark.asyncio +async def test_chat_completion_with_tools(client: openai.AsyncOpenAI): + models = await client.models.list() + model_name: str = models.data[0].id + chat_completion = await client.chat.completions.create( + messages=MESSAGES_WITHOUT_TOOLS, + temperature=0, + max_tokens=150, + model=model_name, + tools=[WEATHER_TOOL], + logprobs=False) + choice = chat_completion.choices[0] + stop_reason = chat_completion.choices[0].finish_reason + output_text = chat_completion.choices[0].message.content + + # check to make sure we got text + assert output_text is not None + assert stop_reason != 'tool_calls' + assert len(output_text) > 0 + + # check to make sure no tool calls were returned + assert (choice.message.tool_calls is None + or len(choice.message.tool_calls) == 0) + + # make the same request, streaming + stream = await client.chat.completions.create( + messages=MESSAGES_WITHOUT_TOOLS, + temperature=0, + max_tokens=150, + model=model_name, + logprobs=False, + tools=[WEATHER_TOOL], + stream=True, + ) + + chunks: List[str] = [] + finish_reason_count = 0 + role_sent: bool = False + + # assemble streamed chunks + async for chunk in stream: + delta = chunk.choices[0].delta + + # make sure the role is assistant + if delta.role: + assert delta.role == 'assistant' + role_sent = True + + if delta.content: + chunks.append(delta.content) + + if chunk.choices[0].finish_reason is not None: + finish_reason_count += 1 + + # make sure tool call chunks aren't being streamed + assert not delta.tool_calls or len(delta.tool_calls) == 0 + + # make sure the role was sent, only 1 finish reason was sent, that chunks + # were in fact sent, and that the chunks match non-streaming + assert role_sent + assert finish_reason_count == 1 + assert chunk.choices[0].finish_reason == stop_reason + assert chunk.choices[0].finish_reason != 'tool_calls' + assert len(chunks) + assert "".join(chunks) == output_text diff --git a/tests/tool_use/test_parallel_tool_calls.py b/tests/tool_use/test_parallel_tool_calls.py new file mode 100644 index 0000000000000..b03b5a2075a6c --- /dev/null +++ b/tests/tool_use/test_parallel_tool_calls.py @@ -0,0 +1,193 @@ +import json +from typing import Dict, List, Optional + +import openai +import pytest + +from .utils import (MESSAGES_ASKING_FOR_PARALLEL_TOOLS, + MESSAGES_WITH_PARALLEL_TOOL_RESPONSE, SEARCH_TOOL, + WEATHER_TOOL) + + +# test: getting the model to generate parallel tool calls (streaming/not) +# when requested. NOTE that not all models may support this, so some exclusions +# may be added in the future. e.g. llama 3.1 models are not designed to support +# parallel tool calls. +@pytest.mark.asyncio +async def test_parallel_tool_calls(client: openai.AsyncOpenAI): + models = await client.models.list() + model_name: str = models.data[0].id + chat_completion = await client.chat.completions.create( + messages=MESSAGES_ASKING_FOR_PARALLEL_TOOLS, + temperature=0, + max_tokens=200, + model=model_name, + tools=[WEATHER_TOOL, SEARCH_TOOL], + logprobs=False) + + choice = chat_completion.choices[0] + stop_reason = chat_completion.choices[0].finish_reason + non_streamed_tool_calls = chat_completion.choices[0].message.tool_calls + + # make sure 2 tool calls are present + assert choice.message.role == "assistant" + assert non_streamed_tool_calls is not None + assert len(non_streamed_tool_calls) == 2 + + for tool_call in non_streamed_tool_calls: + # make sure the tool includes a function and ID + assert tool_call.type == "function" + assert tool_call.function is not None + assert isinstance(tool_call.id, str) + assert len(tool_call.id) > 16 + + # make sure the weather tool was called correctly + assert tool_call.function.name == WEATHER_TOOL["function"]["name"] + assert isinstance(tool_call.function.arguments, str) + + parsed_arguments = json.loads(tool_call.function.arguments) + assert isinstance(parsed_arguments, Dict) + assert isinstance(parsed_arguments.get("city"), str) + assert isinstance(parsed_arguments.get("state"), str) + + assert stop_reason == "tool_calls" + + # make the same request, streaming + stream = await client.chat.completions.create( + model=model_name, + messages=MESSAGES_ASKING_FOR_PARALLEL_TOOLS, + temperature=0, + max_tokens=200, + tools=[WEATHER_TOOL, SEARCH_TOOL], + logprobs=False, + stream=True) + + role_name: Optional[str] = None + finish_reason_count: int = 0 + + tool_call_names: List[str] = [] + tool_call_args: List[str] = [] + tool_call_idx: int = -1 + tool_call_id_count: int = 0 + + async for chunk in stream: + + # if there's a finish reason make sure it's tools + if chunk.choices[0].finish_reason: + finish_reason_count += 1 + assert chunk.choices[0].finish_reason == 'tool_calls' + + # if a role is being streamed make sure it wasn't already set to + # something else + if chunk.choices[0].delta.role: + assert not role_name or role_name == 'assistant' + role_name = 'assistant' + + # if a tool call is streamed make sure there's exactly one + # (based on the request parameters + streamed_tool_calls = chunk.choices[0].delta.tool_calls + + if streamed_tool_calls and len(streamed_tool_calls) > 0: + + # make sure only one diff is present - correct even for parallel + assert len(streamed_tool_calls) == 1 + tool_call = streamed_tool_calls[0] + + # if a new tool is being called, set up empty arguments + if tool_call.index != tool_call_idx: + tool_call_idx = tool_call.index + tool_call_args.append("") + + # if a tool call ID is streamed, make sure one hasn't been already + if tool_call.id: + tool_call_id_count += 1 + assert (isinstance(tool_call.id, str) + and (len(tool_call.id) > 16)) + + # if parts of the function start being streamed + if tool_call.function: + # if the function name is defined, set it. it should be streamed + # IN ENTIRETY, exactly one time. + if tool_call.function.name: + assert isinstance(tool_call.function.name, str) + tool_call_names.append(tool_call.function.name) + + if tool_call.function.arguments: + # make sure they're a string and then add them to the list + assert isinstance(tool_call.function.arguments, str) + + tool_call_args[ + tool_call.index] += tool_call.function.arguments + + assert finish_reason_count == 1 + assert role_name == 'assistant' + + assert (len(non_streamed_tool_calls) == len(tool_call_names) == + len(tool_call_args)) + + for i in range(2): + assert non_streamed_tool_calls[i].function.name == tool_call_names[i] + streamed_args = json.loads(tool_call_args[i]) + non_streamed_args = json.loads( + non_streamed_tool_calls[i].function.arguments) + assert streamed_args == non_streamed_args + + +# test: providing parallel tool calls back to the model to get a response +# (streaming/not) +@pytest.mark.asyncio +async def test_parallel_tool_calls_with_results(client: openai.AsyncOpenAI): + models = await client.models.list() + model_name: str = models.data[0].id + chat_completion = await client.chat.completions.create( + messages=MESSAGES_WITH_PARALLEL_TOOL_RESPONSE, + temperature=0, + max_tokens=200, + model=model_name, + tools=[WEATHER_TOOL, SEARCH_TOOL], + logprobs=False) + + choice = chat_completion.choices[0] + + assert choice.finish_reason != "tool_calls" # "stop" or "length" + assert choice.message.role == "assistant" + assert choice.message.tool_calls is None \ + or len(choice.message.tool_calls) == 0 + assert choice.message.content is not None + assert "98" in choice.message.content # Dallas temp in tool response + assert "78" in choice.message.content # Orlando temp in tool response + + stream = await client.chat.completions.create( + messages=MESSAGES_WITH_PARALLEL_TOOL_RESPONSE, + temperature=0, + max_tokens=200, + model=model_name, + tools=[WEATHER_TOOL, SEARCH_TOOL], + logprobs=False, + stream=True) + + chunks: List[str] = [] + finish_reason_count = 0 + role_sent: bool = False + + async for chunk in stream: + delta = chunk.choices[0].delta + + if delta.role: + assert not role_sent + assert delta.role == "assistant" + role_sent = True + + if delta.content: + chunks.append(delta.content) + + if chunk.choices[0].finish_reason is not None: + finish_reason_count += 1 + assert chunk.choices[0].finish_reason == choice.finish_reason + + assert not delta.tool_calls or len(delta.tool_calls) == 0 + + assert role_sent + assert finish_reason_count == 1 + assert len(chunks) + assert "".join(chunks) == choice.message.content diff --git a/tests/tool_use/test_tool_calls.py b/tests/tool_use/test_tool_calls.py new file mode 100644 index 0000000000000..c3abe9e1f5060 --- /dev/null +++ b/tests/tool_use/test_tool_calls.py @@ -0,0 +1,192 @@ +import json +from typing import Dict, List, Optional + +import openai +import pytest + +from .utils import (MESSAGES_ASKING_FOR_TOOLS, MESSAGES_WITH_TOOL_RESPONSE, + SEARCH_TOOL, WEATHER_TOOL) + + +# test: request a chat completion that should return tool calls, so we know they +# are parsable +@pytest.mark.asyncio +async def test_tool_call_and_choice(client: openai.AsyncOpenAI): + models = await client.models.list() + model_name: str = models.data[0].id + chat_completion = await client.chat.completions.create( + messages=MESSAGES_ASKING_FOR_TOOLS, + temperature=0, + max_tokens=100, + model=model_name, + tools=[WEATHER_TOOL, SEARCH_TOOL], + logprobs=False) + + choice = chat_completion.choices[0] + stop_reason = chat_completion.choices[0].finish_reason + tool_calls = chat_completion.choices[0].message.tool_calls + + # make sure a tool call is present + assert choice.message.role == 'assistant' + assert tool_calls is not None + assert len(tool_calls) == 1 + assert tool_calls[0].type == 'function' + assert tool_calls[0].function is not None + assert isinstance(tool_calls[0].id, str) + assert len(tool_calls[0].id) > 16 + + # make sure the weather tool was called (classic example) with arguments + assert tool_calls[0].function.name == WEATHER_TOOL["function"]["name"] + assert tool_calls[0].function.arguments is not None + assert isinstance(tool_calls[0].function.arguments, str) + + # make sure the arguments parse properly + parsed_arguments = json.loads(tool_calls[0].function.arguments) + assert isinstance(parsed_arguments, Dict) + assert isinstance(parsed_arguments.get("city"), str) + assert isinstance(parsed_arguments.get("state"), str) + assert parsed_arguments.get("city") == "Dallas" + assert parsed_arguments.get("state") == "TX" + + assert stop_reason == "tool_calls" + + function_name: Optional[str] = None + function_args_str: str = '' + tool_call_id: Optional[str] = None + role_name: Optional[str] = None + finish_reason_count: int = 0 + + # make the same request, streaming + stream = await client.chat.completions.create( + model=model_name, + messages=MESSAGES_ASKING_FOR_TOOLS, + temperature=0, + max_tokens=100, + tools=[WEATHER_TOOL, SEARCH_TOOL], + logprobs=False, + stream=True) + + async for chunk in stream: + assert chunk.choices[0].index == 0 + + if chunk.choices[0].finish_reason: + finish_reason_count += 1 + assert chunk.choices[0].finish_reason == 'tool_calls' + + # if a role is being streamed make sure it wasn't already set to + # something else + if chunk.choices[0].delta.role: + assert not role_name or role_name == 'assistant' + role_name = 'assistant' + + # if a tool call is streamed make sure there's exactly one + # (based on the request parameters + streamed_tool_calls = chunk.choices[0].delta.tool_calls + + if streamed_tool_calls and len(streamed_tool_calls) > 0: + assert len(streamed_tool_calls) == 1 + tool_call = streamed_tool_calls[0] + + # if a tool call ID is streamed, make sure one hasn't been already + if tool_call.id: + assert not tool_call_id + tool_call_id = tool_call.id + + # if parts of the function start being streamed + if tool_call.function: + # if the function name is defined, set it. it should be streamed + # IN ENTIRETY, exactly one time. + if tool_call.function.name: + assert function_name is None + assert isinstance(tool_call.function.name, str) + function_name = tool_call.function.name + if tool_call.function.arguments: + assert isinstance(tool_call.function.arguments, str) + function_args_str += tool_call.function.arguments + + assert finish_reason_count == 1 + assert role_name == 'assistant' + assert isinstance(tool_call_id, str) and (len(tool_call_id) > 16) + + # validate the name and arguments + assert function_name == WEATHER_TOOL["function"]["name"] + assert function_name == tool_calls[0].function.name + assert isinstance(function_args_str, str) + + # validate arguments + streamed_args = json.loads(function_args_str) + assert isinstance(streamed_args, Dict) + assert isinstance(streamed_args.get("city"), str) + assert isinstance(streamed_args.get("state"), str) + assert streamed_args.get("city") == "Dallas" + assert streamed_args.get("state") == "TX" + + # make sure everything matches non-streaming except for ID + assert function_name == tool_calls[0].function.name + assert choice.message.role == role_name + assert choice.message.tool_calls[0].function.name == function_name + + # compare streamed with non-streamed args Dict-wise, not string-wise + # because character-to-character comparison might not work e.g. the tool + # call parser adding extra spaces or something like that. we care about the + # dicts matching not byte-wise match + assert parsed_arguments == streamed_args + + +# test: providing tools and results back to model to get a non-tool response +# (streaming/not) +@pytest.mark.asyncio +async def test_tool_call_with_results(client: openai.AsyncOpenAI): + models = await client.models.list() + model_name: str = models.data[0].id + chat_completion = await client.chat.completions.create( + messages=MESSAGES_WITH_TOOL_RESPONSE, + temperature=0, + max_tokens=100, + model=model_name, + tools=[WEATHER_TOOL, SEARCH_TOOL], + logprobs=False) + + choice = chat_completion.choices[0] + + assert choice.finish_reason != "tool_calls" # "stop" or "length" + assert choice.message.role == "assistant" + assert choice.message.tool_calls is None \ + or len(choice.message.tool_calls) == 0 + assert choice.message.content is not None + assert "98" in choice.message.content # the temperature from the response + + stream = await client.chat.completions.create( + messages=MESSAGES_WITH_TOOL_RESPONSE, + temperature=0, + max_tokens=100, + model=model_name, + tools=[WEATHER_TOOL, SEARCH_TOOL], + logprobs=False, + stream=True) + + chunks: List[str] = [] + finish_reason_count = 0 + role_sent: bool = False + + async for chunk in stream: + delta = chunk.choices[0].delta + + if delta.role: + assert not role_sent + assert delta.role == "assistant" + role_sent = True + + if delta.content: + chunks.append(delta.content) + + if chunk.choices[0].finish_reason is not None: + finish_reason_count += 1 + assert chunk.choices[0].finish_reason == choice.finish_reason + + assert not delta.tool_calls or len(delta.tool_calls) == 0 + + assert role_sent + assert finish_reason_count == 1 + assert len(chunks) + assert "".join(chunks) == choice.message.content diff --git a/tests/tool_use/utils.py b/tests/tool_use/utils.py new file mode 100644 index 0000000000000..8ec9b05b2c521 --- /dev/null +++ b/tests/tool_use/utils.py @@ -0,0 +1,215 @@ +from typing import Dict, List + +from openai.types.chat import (ChatCompletionMessageParam, + ChatCompletionToolParam) +from typing_extensions import TypedDict + +from tests.utils import VLLM_PATH + + +class ServerConfig(TypedDict): + model: str + arguments: List[str] + + +# universal args for all models go here. also good if you need to test locally +# and change type or KV cache quantization or something. +ARGS: List[str] = ["--enable-auto-tool-choice", "--max-model-len", "8096"] + +CONFIGS: Dict[str, ServerConfig] = { + "hermes": { + "model": + "NousResearch/Hermes-2-Pro-Llama-3-8B", + "arguments": [ + "--tool-call-parser", "hermes", "--chat-template", + str(VLLM_PATH / "examples/tool_chat_template_hermes.jinja") + ] + }, + "mistral": { + "model": + "mistralai/Mistral-7B-Instruct-v0.3", + "arguments": [ + "--tool-call-parser", "mistral", "--chat-template", + str(VLLM_PATH / "examples/tool_chat_template_mistral.jinja"), + "--ignore-patterns=\"consolidated.safetensors\"" + ] + } +} + +WEATHER_TOOL: ChatCompletionToolParam = { + "type": "function", + "function": { + "name": "get_current_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type": "object", + "properties": { + "city": { + "type": + "string", + "description": + "The city to find the weather for, " + "e.g. 'San Francisco'" + }, + "state": { + "type": + "string", + "description": + "the two-letter abbreviation for the state " + "that the city is in, e.g. 'CA' which would " + "mean 'California'" + }, + "unit": { + "type": "string", + "description": "The unit to fetch the temperature in", + "enum": ["celsius", "fahrenheit"] + } + } + } + } +} + +SEARCH_TOOL: ChatCompletionToolParam = { + "type": "function", + "function": { + "name": + "web_search", + "description": + "Search the internet and get a summary of the top " + "10 webpages. Should only be used if you don't know " + "the answer to a user query, and the results are likely" + "to be able to be found with a web search", + "parameters": { + "type": "object", + "properties": { + "search_term": { + "type": + "string", + "description": + "The term to use in the search. This should" + "ideally be keywords to search for, not a" + "natural-language question" + } + }, + "required": ["search_term"] + } + } +} + +MESSAGES_WITHOUT_TOOLS: List[ChatCompletionMessageParam] = [{ + "role": + "system", + "content": + "You are a helpful assistant with access to tools. If a tool" + " that you have would be helpful to answer a user query, " + "call the tool. Otherwise, answer the user's query directly " + "without calling a tool. DO NOT CALL A TOOL THAT IS IRRELEVANT " + "to the user's question - just respond to it normally." +}, { + "role": + "user", + "content": + "Hi! How are you?" +}, { + "role": + "assistant", + "content": + "I'm doing great! How can I assist you?" +}, { + "role": + "user", + "content": + "Can you tell me a joke please?" +}] + +MESSAGES_ASKING_FOR_TOOLS: List[ChatCompletionMessageParam] = [{ + "role": + "user", + "content": + "What is the weather in Dallas, Texas in Fahrenheit?" +}] + +MESSAGES_WITH_TOOL_RESPONSE: List[ChatCompletionMessageParam] = [{ + "role": + "user", + "content": + "What is the weather in Dallas, Texas in Fahrenheit?" +}, { + "role": + "assistant", + "tool_calls": [{ + "id": "chatcmpl-tool-03e6481b146e408e9523d9c956696295", + "type": "function", + "function": { + "name": + WEATHER_TOOL["function"]["name"], + "arguments": + '{"city": "Dallas", "state": "TX", ' + '"unit": "fahrenheit"}' + } + }] +}, { + "role": + "tool", + "tool_call_id": + "chatcmpl-tool-03e6481b146e408e9523d9c956696295", + "content": + "The weather in Dallas is 98 degrees fahrenheit, with partly" + "cloudy skies and a low chance of rain." +}] + +MESSAGES_ASKING_FOR_PARALLEL_TOOLS: List[ChatCompletionMessageParam] = [{ + "role": + "user", + "content": + "What is the weather in Dallas, Texas and Orlando, Florida in " + "Fahrenheit?" +}] + +MESSAGES_WITH_PARALLEL_TOOL_RESPONSE: List[ChatCompletionMessageParam] = [{ + "role": + "user", + "content": + "What is the weather in Dallas, Texas and Orlando, Florida in " + "Fahrenheit?" +}, { + "role": + "assistant", + "tool_calls": [{ + "id": "chatcmpl-tool-03e6481b146e408e9523d9c956696295", + "type": "function", + "function": { + "name": + WEATHER_TOOL["function"]["name"], + "arguments": + '{"city": "Dallas", "state": "TX", ' + '"unit": "fahrenheit"}' + } + }, { + "id": "chatcmpl-tool-d027061e1bd21cda48bee7da829c1f5b", + "type": "function", + "function": { + "name": + WEATHER_TOOL["function"]["name"], + "arguments": + '{"city": "Orlando", "state": "Fl", ' + '"unit": "fahrenheit"}' + } + }] +}, { + "role": + "tool", + "tool_call_id": + "chatcmpl-tool-03e6481b146e408e9523d9c956696295", + "content": + "The weather in Dallas TX is 98 degrees fahrenheit with mostly " + "cloudy skies and a chance of rain in the evening." +}, { + "role": + "tool", + "tool_call_id": + "chatcmpl-tool-d027061e1bd21cda48bee7da829c1f5b", + "content": + "The weather in Orlando FL is 78 degrees fahrenheit with clear" + "skies." +}] diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py index f205a99920892..9a7493649c795 100644 --- a/vllm/entrypoints/chat_utils.py +++ b/vllm/entrypoints/chat_utils.py @@ -1,23 +1,28 @@ import asyncio import codecs +import json from abc import ABC, abstractmethod from collections import defaultdict -from functools import lru_cache +from functools import lru_cache, partial from pathlib import Path from typing import (Any, Awaitable, Dict, Generic, Iterable, List, Literal, - Mapping, Optional, Tuple, TypeVar, Union) + Mapping, Optional, Tuple, TypeVar, Union, cast) # yapf conflicts with isort for this block # yapf: disable -from openai.types.chat import ChatCompletionContentPartImageParam +from openai.types.chat import (ChatCompletionAssistantMessageParam, + ChatCompletionContentPartImageParam) from openai.types.chat import ( ChatCompletionContentPartParam as OpenAIChatCompletionContentPartParam) -from openai.types.chat import ChatCompletionContentPartTextParam +from openai.types.chat import (ChatCompletionContentPartRefusalParam, + ChatCompletionContentPartTextParam) from openai.types.chat import ( ChatCompletionMessageParam as OpenAIChatCompletionMessageParam) +from openai.types.chat import (ChatCompletionMessageToolCallParam, + ChatCompletionToolMessageParam) # yapf: enable # pydantic needs the TypedDict from typing_extensions -from pydantic import ConfigDict, TypeAdapter +from pydantic import ConfigDict from typing_extensions import Required, TypeAlias, TypedDict from vllm.config import ModelConfig @@ -54,7 +59,8 @@ class CustomChatCompletionContentPartParam(TypedDict, total=False): ChatCompletionContentPartParam: TypeAlias = Union[ OpenAIChatCompletionContentPartParam, ChatCompletionContentPartAudioParam, - CustomChatCompletionContentPartParam, ] + ChatCompletionContentPartRefusalParam, + CustomChatCompletionContentPartParam] class CustomChatCompletionMessageParam(TypedDict, total=False): @@ -72,15 +78,33 @@ class CustomChatCompletionMessageParam(TypedDict, total=False): same role. """ + tool_call_id: Optional[str] + """Tool call that this message is responding to.""" + + tool_calls: Optional[Iterable[ChatCompletionMessageToolCallParam]] + """The tool calls generated by the model, such as function calls.""" + ChatCompletionMessageParam = Union[OpenAIChatCompletionMessageParam, CustomChatCompletionMessageParam] # TODO: Make fields ReadOnly once mypy supports it -class ConversationMessage(TypedDict): - role: str - content: str +class ConversationMessage(TypedDict, total=False): + role: Required[str] + """The role of the message's author.""" + + content: Optional[str] + """The contents of the message""" + + tool_call_id: Optional[str] + """Tool call that this message is responding to.""" + + name: Optional[str] + """The name of the function to call""" + + tool_calls: Optional[Iterable[ChatCompletionMessageToolCallParam]] + """The tool calls generated by the model, such as function calls.""" ModalityStr = Literal["image", "audio"] @@ -319,9 +343,11 @@ def _get_full_multimodal_text_prompt(placeholder_counts: Dict[str, int], return "\n".join(missing_placeholders + [text_prompt]) -_TextParser = TypeAdapter(ChatCompletionContentPartTextParam) -_ImageParser = TypeAdapter(ChatCompletionContentPartImageParam) -_AudioParser = TypeAdapter(ChatCompletionContentPartAudioParam) +# No need to validate using Pydantic again +_TextParser = partial(cast, ChatCompletionContentPartTextParam) +_ImageParser = partial(cast, ChatCompletionContentPartImageParam) +_AudioParser = partial(cast, ChatCompletionContentPartAudioParam) +_RefusalParser = partial(cast, ChatCompletionContentPartRefusalParam) def _parse_chat_message_content_parts( @@ -336,10 +362,10 @@ def _parse_chat_message_content_parts( for part in parts: part_type = part["type"] if part_type == "text": - text = _TextParser.validate_python(part)["text"] + text = _TextParser(part)["text"] texts.append(text) elif part_type == "image_url": - image_url = _ImageParser.validate_python(part)["image_url"] + image_url = _ImageParser(part)["image_url"] if image_url.get("detail", "auto") != "auto": logger.warning( @@ -348,7 +374,7 @@ def _parse_chat_message_content_parts( mm_parser.parse_image(image_url["url"]) elif part_type == "audio_url": - audio_url = _AudioParser.validate_python(part)["audio_url"] + audio_url = _AudioParser(part)["audio_url"] mm_parser.parse_audio(audio_url["url"]) else: @@ -363,6 +389,11 @@ def _parse_chat_message_content_parts( return [ConversationMessage(role=role, content=text_prompt)] +# No need to validate using Pydantic again +_AssistantParser = partial(cast, ChatCompletionAssistantMessageParam) +_ToolParser = partial(cast, ChatCompletionToolMessageParam) + + def _parse_chat_message_content( message: ChatCompletionMessageParam, mm_tracker: BaseMultiModalItemTracker, @@ -371,16 +402,34 @@ def _parse_chat_message_content( content = message.get("content") if content is None: - return [] - if isinstance(content, str): - return [ConversationMessage(role=role, content=content)] + content = [] + elif isinstance(content, str): + content = [ + ChatCompletionContentPartTextParam(type="text", text=content) + ] - return _parse_chat_message_content_parts( + result = _parse_chat_message_content_parts( role, content, # type: ignore mm_tracker, ) + for result_msg in result: + if role == 'assistant': + parsed_msg = _AssistantParser(message) + + if "tool_calls" in parsed_msg: + result_msg["tool_calls"] = list(parsed_msg["tool_calls"]) + elif role == "tool": + parsed_msg = _ToolParser(message) + if "tool_call_id" in parsed_msg: + result_msg["tool_call_id"] = parsed_msg["tool_call_id"] + + if "name" in message and isinstance(message["name"], str): + result_msg["name"] = message["name"] + + return result + def parse_chat_messages( messages: List[ChatCompletionMessageParam], @@ -428,6 +477,20 @@ def apply_chat_template( "allowed, so you must provide a chat template if the tokenizer " "does not define one.") + # per the Transformers docs & maintainers, tool call arguments in + # assistant-role messages with tool_calls need to be dicts not JSON str - + # this is how tool-use chat templates will expect them moving forwards + # so, for messages that have tool_calls, parse the string (which we get + # from openAI format) to dict + for message in conversation: + if (message["role"] == "assistant" and "tool_calls" in message + and isinstance(message["tool_calls"], list)): + + for i in range(len(message["tool_calls"])): + args: str = message["tool_calls"][i]["function"]["arguments"] + parsed_args: Dict = json.loads(args) + message["tool_calls"][i]["function"]["arguments"] = parsed_args + prompt = tokenizer.apply_chat_template( conversation=conversation, chat_template=chat_template, diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 7632e8aa5e32e..728a2e5232d9b 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -233,7 +233,7 @@ def mount_metrics(app: FastAPI): metrics_route = Mount("/metrics", make_asgi_app()) # Workaround for 307 Redirect for /metrics - metrics_route.path_regex = re.compile('^/metrics(?P.*)$') + metrics_route.path_regex = re.compile("^/metrics(?P.*)$") app.routes.append(metrics_route) @@ -283,11 +283,14 @@ async def show_version(): @router.post("/v1/chat/completions") async def create_chat_completion(request: ChatCompletionRequest, raw_request: Request): + generator = await openai_serving_chat.create_chat_completion( request, raw_request) + if isinstance(generator, ErrorResponse): return JSONResponse(content=generator.model_dump(), status_code=generator.code) + elif isinstance(generator, ChatCompletionResponse): return JSONResponse(content=generator.model_dump()) @@ -422,7 +425,8 @@ async def init_app( request_logger=request_logger, chat_template=args.chat_template, return_tokens_as_token_ids=args.return_tokens_as_token_ids, - ) + enable_auto_tools=args.enable_auto_tool_choice, + tool_parser=args.tool_call_parser) openai_serving_completion = OpenAIServingCompletion( async_engine_client, model_config, diff --git a/vllm/entrypoints/openai/cli_args.py b/vllm/entrypoints/openai/cli_args.py index 94742838b421c..7ccee0b6b55b7 100644 --- a/vllm/entrypoints/openai/cli_args.py +++ b/vllm/entrypoints/openai/cli_args.py @@ -163,6 +163,24 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: help="If specified, will run the OpenAI frontend server in the same " "process as the model serving engine.") + parser.add_argument( + "--enable-auto-tool-choice", + action="store_true", + default=False, + help= + "Enable auto tool choice for supported models. Use --tool-call-parser" + "to specify which parser to use") + + parser.add_argument( + "--tool-call-parser", + type=str, + choices=["mistral", "hermes"], + default=None, + help= + "Select the tool call parser depending on the model that you're using." + " This is used to parse the model-generated tool call into OpenAI API " + "format. Required for --enable-auto-tool-choice.") + parser = AsyncEngineArgs.add_cli_args(parser) parser.add_argument('--max-log-len', diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 0954b81595ef5..ff9c3690672b6 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -5,8 +5,9 @@ from typing import Any, Dict, List, Literal, Optional, Union import torch +from openai.types.chat import ChatCompletionContentPartParam from pydantic import BaseModel, ConfigDict, Field, model_validator -from typing_extensions import Annotated +from typing_extensions import Annotated, Required, TypedDict from vllm.entrypoints.chat_utils import ChatCompletionMessageParam from vllm.entrypoints.openai.logits_processors import get_logits_processors @@ -35,6 +36,26 @@ assert _LONG_INFO.max == _MOCK_LONG_INFO.max +class CustomChatCompletionMessageParam(TypedDict, total=False): + """Enables custom roles in the Chat Completion API.""" + role: Required[str] + """The role of the message's author.""" + + content: Union[str, List[ChatCompletionContentPartParam]] + """The contents of the message.""" + + name: str + """An optional name for the participant. + + Provides the model information to differentiate between participants of the + same role. + """ + + tool_call_id: Optional[str] + + tool_calls: Optional[List[dict]] + + class OpenAIBaseModel(BaseModel): # OpenAI API does not allow extra fields model_config = ConfigDict(extra="forbid") @@ -145,8 +166,11 @@ class ChatCompletionRequest(OpenAIBaseModel): temperature: Optional[float] = 0.7 top_p: Optional[float] = 1.0 tools: Optional[List[ChatCompletionToolsParam]] = None - tool_choice: Optional[Union[Literal["none"], + tool_choice: Optional[Union[Literal["none"], Literal["auto"], ChatCompletionNamedToolChoiceParam]] = "none" + + # NOTE this will be ignored by VLLM -- the model determines the behavior + parallel_tool_calls: Optional[bool] = False user: Optional[str] = None # doc: begin-chat-completion-sampling-params @@ -328,6 +352,9 @@ def check_logprobs(cls, data): @model_validator(mode="before") @classmethod def check_guided_decoding_count(cls, data): + if isinstance(data, ValueError): + raise data + guide_count = sum([ "guided_json" in data and data["guided_json"] is not None, "guided_regex" in data and data["guided_regex"] is not None, @@ -339,21 +366,61 @@ def check_guided_decoding_count(cls, data): "You can only use one kind of guided decoding " "('guided_json', 'guided_regex' or 'guided_choice').") # you can only either use guided decoding or tools, not both - if guide_count > 1 and "tool_choice" in data and data[ - "tool_choice"] != "none": + if guide_count > 1 and data.get("tool_choice", + "none") not in ("none", "auto"): raise ValueError( "You can only either use guided decoding or tools, not both.") return data @model_validator(mode="before") @classmethod - def check_tool_choice(cls, data): - if "tool_choice" in data and data["tool_choice"] != "none": - if not isinstance(data["tool_choice"], dict): - raise ValueError("Currently only named tools are supported.") + def check_tool_usage(cls, data): + + # if "tool_choice" is not specified but tools are provided, + # default to "auto" tool_choice + if "tool_choice" not in data and "tools" in data: + data["tool_choice"] = "auto" + + # if "tool_choice" is specified -- validation + if "tool_choice" in data: + + # ensure that if "tool choice" is specified, tools are present if "tools" not in data or data["tools"] is None: raise ValueError( "When using `tool_choice`, `tools` must be set.") + + # make sure that tool choice is either a named tool + # OR that it's set to "auto" + if data["tool_choice"] != "auto" and not isinstance( + data["tool_choice"], dict): + raise ValueError( + "`tool_choice` must either be a named tool or \"auto\". " + "`tool_choice=\"none\" is not supported.") + + # ensure that if "tool_choice" is specified as an object, + # it matches a valid tool + if isinstance(data["tool_choice"], dict): + valid_tool = False + specified_function = data["tool_choice"]["function"] + if not specified_function: + raise ValueError( + "Incorrectly formatted `tool_choice`. Should be like " + "`{\"type\": \"function\"," + " \"function\": {\"name\": \"my_function\"}}`") + specified_function_name = specified_function["name"] + if not specified_function_name: + raise ValueError( + "Incorrectly formatted `tool_choice`. Should be like " + "`{\"type\": \"function\", " + "\"function\": {\"name\": \"my_function\"}}`") + for tool in data["tools"]: + if tool["function"]["name"] == specified_function_name: + valid_tool = True + break + if not valid_tool: + raise ValueError( + "The tool specified in `tool_choice` does not match any" + " of the specified `tools`") return data @@ -413,7 +480,7 @@ class CompletionRequest(OpenAIBaseModel): ) guided_json: Optional[Union[str, dict, BaseModel]] = Field( default=None, - description=("If specified, the output will follow the JSON schema."), + description="If specified, the output will follow the JSON schema.", ) guided_regex: Optional[str] = Field( default=None, @@ -633,9 +700,41 @@ class ToolCall(OpenAIBaseModel): function: FunctionCall +class DeltaFunctionCall(BaseModel): + name: Optional[str] = None + arguments: Optional[str] = None + + +# a tool call delta where everything is optional +class DeltaToolCall(OpenAIBaseModel): + id: str = Field(default_factory=lambda: f"chatcmpl-tool-{random_uuid()}") + type: Literal["function"] = "function" + index: int + function: Optional[DeltaFunctionCall] = None + + +# the initial delta that gets sent once a new tool call is started; +class InitialDeltaToolCall(DeltaToolCall): + id: str = Field(default_factory=lambda: f"chatcmpl-tool-{random_uuid()}") + type: Literal["function"] = "function" + index: int + + +class ExtractedToolCallInformation(BaseModel): + # indicate if tools were called + tools_called: bool + + # extracted tool calls + tool_calls: List[ToolCall] + + # content - per OpenAI spec, content AND tool calls can be returned rarely + # But some models will do this intentionally + content: Optional[str] = None + + class ChatMessage(OpenAIBaseModel): role: str - content: str + content: Optional[str] = None tool_calls: List[ToolCall] = Field(default_factory=list) @@ -657,7 +756,9 @@ class ChatCompletionResponseChoice(OpenAIBaseModel): index: int message: ChatMessage logprobs: Optional[ChatCompletionLogProbs] = None - finish_reason: Optional[str] = None + # per OpenAI spec this is the default + finish_reason: Optional[str] = "stop" + # not part of the OpenAI spec but included in vLLM for legacy reasons stop_reason: Optional[Union[int, str]] = None @@ -674,7 +775,7 @@ class ChatCompletionResponse(OpenAIBaseModel): class DeltaMessage(OpenAIBaseModel): role: Optional[str] = None content: Optional[str] = None - tool_calls: List[ToolCall] = Field(default_factory=list) + tool_calls: List[DeltaToolCall] = Field(default_factory=list) class ChatCompletionResponseStreamChoice(OpenAIBaseModel): diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index a3bc0bb7b3554..78f355228012f 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -1,6 +1,8 @@ import asyncio +import json import time -from typing import AsyncGenerator, AsyncIterator, Dict, Final, List, Optional +from typing import (AsyncGenerator, AsyncIterator, Callable, Dict, Final, List, + Optional) from typing import Sequence as GenericSequence from typing import Union @@ -18,15 +20,18 @@ ChatCompletionLogProbsContent, ChatCompletionNamedToolChoiceParam, ChatCompletionRequest, ChatCompletionResponse, ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice, - ChatCompletionStreamResponse, ChatMessage, DeltaMessage, ErrorResponse, - FunctionCall, ToolCall, UsageInfo) + ChatCompletionStreamResponse, ChatMessage, DeltaFunctionCall, DeltaMessage, + DeltaToolCall, ErrorResponse, FunctionCall, ToolCall, UsageInfo) from vllm.entrypoints.openai.serving_engine import (LoRAModulePath, OpenAIServing, PromptAdapterPath, TextTokensPrompt) +from vllm.entrypoints.openai.tool_parsers import (Hermes2ProToolParser, + MistralToolParser, + ToolParser) from vllm.inputs import TokensPrompt from vllm.logger import init_logger -from vllm.outputs import RequestOutput +from vllm.outputs import CompletionOutput, RequestOutput from vllm.sequence import Logprob from vllm.tracing import (contains_trace_headers, extract_trace_headers, log_tracing_disabled_warning) @@ -38,19 +43,19 @@ class OpenAIServingChat(OpenAIServing): - def __init__( - self, - async_engine_client: AsyncEngineClient, - model_config: ModelConfig, - served_model_names: List[str], - response_role: str, - *, - lora_modules: Optional[List[LoRAModulePath]], - prompt_adapters: Optional[List[PromptAdapterPath]], - request_logger: Optional[RequestLogger], - chat_template: Optional[str], - return_tokens_as_token_ids: bool = False, - ): + def __init__(self, + async_engine_client: AsyncEngineClient, + model_config: ModelConfig, + served_model_names: List[str], + response_role: str, + *, + lora_modules: Optional[List[LoRAModulePath]], + prompt_adapters: Optional[List[PromptAdapterPath]], + request_logger: Optional[RequestLogger], + chat_template: Optional[str], + return_tokens_as_token_ids: bool = False, + enable_auto_tools: bool = False, + tool_parser: Optional[str] = None): super().__init__(async_engine_client=async_engine_client, model_config=model_config, served_model_names=served_model_names, @@ -60,10 +65,27 @@ def __init__( return_tokens_as_token_ids=return_tokens_as_token_ids) self.response_role = response_role - - # If this is None we use the tokenizer's default chat template + self.use_tool_use_model_template = False self.chat_template = load_chat_template(chat_template) + # set up tool use + self.enable_auto_tools: bool = enable_auto_tools + if self.enable_auto_tools: + logger.info( + "\"auto\" tool choice has been enabled please note that while" + " the parallel_tool_calls client option is preset for " + "compatibility reasons, it will be ignored.") + + self.tool_parser: Optional[Callable[[AnyTokenizer], ToolParser]] = None + if self.enable_auto_tools: + if tool_parser == "mistral": + self.tool_parser = MistralToolParser + elif tool_parser == "hermes": + self.tool_parser = Hermes2ProToolParser + else: + raise TypeError("Error: --enable-auto-tool-choice requires " + "--tool-call-parser") + async def create_chat_completion( self, request: ChatCompletionRequest, @@ -76,11 +98,10 @@ async def create_chat_completion( for the API specification. This API mimics the OpenAI ChatCompletion API. - NOTE: Currently we do not support the following feature: - - function_call (Users should implement this by themselves) """ error_check_ret = await self._check_model(request) if error_check_ret is not None: + logger.error("Error with model %s", error_check_ret) return error_check_ret try: @@ -119,6 +140,20 @@ async def create_chat_completion( logger.error("Error in loading multi-modal data: %s", e) return self.create_error_response(str(e)) + # validation for OpenAI tools + # tool_choice = "required" is not supported + if request.tool_choice == "required": + return self.create_error_response( + "tool_choice = \"required\" is not supported!") + + # "auto" tools requires --enable-auto-tool-choice + # and --tool-call-parser + if request.tool_choice == "auto" and not ( + self.enable_auto_tools and self.tool_parser is not None): + return self.create_error_response( + "\"auto\" tool choice requires " + "--enable-auto-tool-choice and --tool-call-parser to be set") + request_id = f"chat-{random_uuid()}" try: guided_decode_logits_processor = ( @@ -187,6 +222,7 @@ async def create_chat_completion( if request.stream: return self.chat_completion_stream_generator( request, result_generator, request_id, conversation, tokenizer) + try: return await self.chat_completion_full_generator( request, result_generator, request_id, conversation, tokenizer) @@ -219,6 +255,9 @@ async def chat_completion_stream_generator( previous_num_tokens = [0] * num_choices finish_reason_sent = [False] * num_choices + tool_parser: Optional[ToolParser] = self.tool_parser( + tokenizer) if self.tool_parser else None + try: async for res in result_generator: # We need to do it here, because if there are exceptions in @@ -228,6 +267,9 @@ async def chat_completion_stream_generator( # Send first response for each request.n (index) with # the role role = self.get_chat_request_role(request) + + # NOTE num_choices defaults to 1 so this usually executes + # once per request for i in range(num_choices): choice_data = ChatCompletionResponseStreamChoice( index=i, @@ -240,14 +282,18 @@ async def chat_completion_stream_generator( created=created_time, choices=[choice_data], model=model_name) + + # if usage should be included if (request.stream_options and request.stream_options.include_usage): - if (request.stream_options.continuous_usage_stats): + # if continuous usage stats are requested, add it + if request.stream_options.continuous_usage_stats: prompt_tokens = len(res.prompt_token_ids) usage = UsageInfo(prompt_tokens=prompt_tokens, completion_tokens=0, total_tokens=prompt_tokens) chunk.usage = usage + # otherwise don't else: chunk.usage = None @@ -257,7 +303,7 @@ async def chat_completion_stream_generator( # Send response to echo the input portion of the # last message if request.echo: - last_msg_content = "" + last_msg_content: Optional[str] = "" if conversation and conversation[-1].get( "content") and conversation[-1].get( "role") == role: @@ -298,6 +344,7 @@ async def chat_completion_stream_generator( first_iteration = False for output in res.outputs: + i = output.index if finish_reason_sent[i]: @@ -320,20 +367,50 @@ async def chat_completion_stream_generator( logprobs = None delta_text = output.text[len(previous_texts[i]):] - previous_texts[i] = output.text - previous_num_tokens[i] = len(output.token_ids) + delta_message: Optional[DeltaMessage] = None - if request.tool_choice and type( - request.tool_choice - ) is ChatCompletionNamedToolChoiceParam: + # handle streaming deltas for tools with named tool_choice + if (request.tool_choice and type(request.tool_choice) is + ChatCompletionNamedToolChoiceParam): delta_message = DeltaMessage(tool_calls=[ - ToolCall(function=FunctionCall( + DeltaToolCall(function=DeltaFunctionCall( name=request.tool_choice.function.name, - arguments=delta_text)) + arguments=delta_text), + index=i) ]) + + # handle streaming deltas for tools with "auto" tool choice + elif (self._should_stream_with_auto_tool_parsing(request) + and tool_parser): + delta_message = ( + tool_parser.extract_tool_calls_streaming( + previous_text=previous_texts[i], + current_text=output.text, + delta_text=delta_text, + previous_token_ids= \ + output.token_ids[ + :-1 * len(delta_token_ids) + ], + current_token_ids=output.token_ids, + delta_token_ids=delta_token_ids + ) + ) + + # handle streaming just a content delta else: delta_message = DeltaMessage(content=delta_text) + # set the previous values for the next iteration + previous_texts[i] = output.text + previous_num_tokens[i] = len(output.token_ids) + + # if the message delta is None (e.g. because it was a + # "control token" for tool calls or the parser otherwise + # wasn't ready to send a token, then + # get the next token without streaming a chunk + if delta_message is None: + continue + if output.finish_reason is None: # Send token-by-token response for each request.n @@ -348,6 +425,8 @@ async def chat_completion_stream_generator( created=created_time, choices=[choice_data], model=model_name) + + # handle usage stats if requested & if continuous if (request.stream_options and request.stream_options.include_usage): if (request.stream_options.continuous_usage_stats): @@ -365,14 +444,55 @@ async def chat_completion_stream_generator( data = chunk.model_dump_json(exclude_unset=True) yield f"data: {data}\n\n" + + # if the model is finished generating else: + # check to make sure we haven't "forgotten" to stream + # any tokens that were generated but previously + # matched by partial json parsing + # only happens if we are NOT using guided decoding + if tool_parser: + index = len( + tool_parser.prev_tool_call_arr) - 1 if len( + tool_parser.prev_tool_call_arr) > 0 else 0 + else: + index = 0 + + if self._should_check_for_unstreamed_tool_arg_tokens( + delta_message, output) and tool_parser: + # get the expected call based on partial JSON + # parsing which "autocompletes" the JSON + expected_call = json.dumps( + tool_parser.prev_tool_call_arr[index].get( + "arguments", {})) + + # get what we've streamed so for for arguments + # for the current tool + actual_call = tool_parser.streamed_args_for_tool[ + index] + + # check to see if there's anything left to stream + remaining_call = expected_call.replace( + actual_call, "", 1) + + # set that as a delta message + delta_message = DeltaMessage(tool_calls=[ + DeltaToolCall(index=index, + function=DeltaFunctionCall( + arguments=remaining_call). + model_dump(exclude_none=True)) + ]) + # Send the finish response for each request.n only once prompt_tokens = len(res.prompt_token_ids) choice_data = ChatCompletionResponseStreamChoice( index=i, delta=delta_message, logprobs=logprobs, - finish_reason=output.finish_reason, + finish_reason=output.finish_reason + if not (tool_parser + and len(tool_parser.prev_tool_call_arr)) + else "tool_calls", stop_reason=output.stop_reason) chunk = ChatCompletionStreamResponse( id=request_id, @@ -398,6 +518,8 @@ async def chat_completion_stream_generator( yield f"data: {data}\n\n" finish_reason_sent[i] = True + # once the final token is handled, if stream_options.include_usage + # is sent, send the usage if (request.stream_options and request.stream_options.include_usage): final_usage = UsageInfo( @@ -419,6 +541,7 @@ async def chat_completion_stream_generator( except ValueError as e: # TODO: Use a vllm-specific Validation Error + logger.error("error in chat completion stream generator: %s", e) data = self.create_streaming_error_response(str(e)) yield f"data: {data}\n\n" # Send the final done message after all response.n are finished @@ -463,8 +586,21 @@ async def chat_completion_full_generator( else: logprobs = None - if request.tool_choice and type( + # by default, tools are not used. + tools_called = False + + # if auto tools are not enabled, and a named tool choice using + # outlines is not being used + if not (self.enable_auto_tools + or not self.tool_parser) and not isinstance( + request.tool_choice, + ChatCompletionNamedToolChoiceParam): + message = ChatMessage(role=role, content=output.text) + + # if the request uses tools and specified a tool choice + elif request.tool_choice and type( request.tool_choice) is ChatCompletionNamedToolChoiceParam: + message = ChatMessage( role=role, content="", @@ -473,14 +609,47 @@ async def chat_completion_full_generator( name=request.tool_choice.function.name, arguments=output.text)) ]) + tools_called = True + + # if the request doesn't use tool choice + # OR specifies to not use a tool elif not request.tool_choice or request.tool_choice == "none": + + message = ChatMessage(role=role, content=output.text) + + # handle when there are tools and tool choice is auto + elif request.tools and ( + request.tool_choice == "auto" + or request.tool_choice is None) and self.enable_auto_tools \ + and self.tool_parser: + + tool_parser = self.tool_parser(tokenizer) + tool_call_info = tool_parser.extract_tool_calls(output.text) + tools_called = tool_call_info.tools_called + if tool_call_info.tools_called: + message = ChatMessage(role=role, + content=tool_call_info.content, + tool_calls=tool_call_info.tool_calls) + + else: + # FOR NOW make it a chat message; we will have to detect + # the type to make it later. + message = ChatMessage(role=role, content=output.text) + + # undetermined case that is still important to handle + else: + logger.error( + "Error in chat_completion_full_generator - cannot determine" + " if tools should be extracted. Returning a standard chat " + "completion.") message = ChatMessage(role=role, content=output.text) choice_data = ChatCompletionResponseChoice( index=output.index, message=message, logprobs=logprobs, - finish_reason=output.finish_reason, + finish_reason="tool_calls" if tools_called else + output.finish_reason if output.finish_reason else "stop", stop_reason=output.stop_reason) choices.append(choice_data) @@ -488,10 +657,11 @@ async def chat_completion_full_generator( last_msg_content = "" if conversation and conversation[-1].get( "content") and conversation[-1].get("role") == role: - last_msg_content = conversation[-1]["content"] + last_msg_content = conversation[-1]["content"] or "" for choice in choices: - full_message = last_msg_content + choice.message.content + full_message = last_msg_content + (choice.message.content + or "") choice.message.content = full_message num_prompt_tokens = len(final_res.prompt_token_ids) @@ -574,3 +744,38 @@ def _create_chat_logprobs( )) return ChatCompletionLogProbs(content=logprobs_content) + + def _should_stream_with_auto_tool_parsing(self, + request: ChatCompletionRequest): + """ + Utility function to check if streamed tokens should go through the tool + call parser that was configured. + + We only want to do this IF user-provided tools are set, a tool parser + is configured, "auto" tool choice is enabled, and the request's tool + choice field indicates that "auto" tool choice should be used. + """ + return (request.tools and self.tool_parser and self.enable_auto_tools + and request.tool_choice in ['auto', None]) + + def _should_check_for_unstreamed_tool_arg_tokens( + self, + delta_message: Optional[DeltaMessage], + output: CompletionOutput, + ) -> bool: + """ + Check to see if we should check for unstreamed tool arguments tokens. + This is only applicable when auto tool parsing is enabled, the delta + is a tool call with arguments. + """ + + # yapf: disable + return bool( + # if there is a delta message that includes tool calls which + # include a function that has arguments + self.enable_auto_tools and self.tool_parser and delta_message + and delta_message.tool_calls and delta_message.tool_calls[0] + and delta_message.tool_calls[0].function + and delta_message.tool_calls[0].function.arguments is not None + and output.finish_reason is not None + ) diff --git a/vllm/entrypoints/openai/serving_tokenization.py b/vllm/entrypoints/openai/serving_tokenization.py index c3c0d52072cd3..69a5ad5b62cfa 100644 --- a/vllm/entrypoints/openai/serving_tokenization.py +++ b/vllm/entrypoints/openai/serving_tokenization.py @@ -43,7 +43,11 @@ def __init__( request_logger=request_logger) # If this is None we use the tokenizer's default chat template - self.chat_template = load_chat_template(chat_template) + # the list of commonly-used chat template names for HF named templates + hf_chat_templates: List[str] = ['default', 'tool_use'] + self.chat_template = chat_template \ + if chat_template in hf_chat_templates \ + else load_chat_template(chat_template) async def create_tokenize( self, diff --git a/vllm/entrypoints/openai/tool_parsers/__init__.py b/vllm/entrypoints/openai/tool_parsers/__init__.py new file mode 100644 index 0000000000000..5d5d53784fedf --- /dev/null +++ b/vllm/entrypoints/openai/tool_parsers/__init__.py @@ -0,0 +1,5 @@ +from .abstract_tool_parser import ToolParser +from .hermes_tool_parser import Hermes2ProToolParser +from .mistral_tool_parser import MistralToolParser + +__all__ = ["ToolParser", "Hermes2ProToolParser", "MistralToolParser"] \ No newline at end of file diff --git a/vllm/entrypoints/openai/tool_parsers/abstract_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/abstract_tool_parser.py new file mode 100644 index 0000000000000..b0807e6f1e782 --- /dev/null +++ b/vllm/entrypoints/openai/tool_parsers/abstract_tool_parser.py @@ -0,0 +1,58 @@ +from typing import Dict, List, Sequence, Union + +from vllm.entrypoints.openai.protocol import (DeltaMessage, + ExtractedToolCallInformation) +from vllm.logger import init_logger +from vllm.transformers_utils.tokenizer import AnyTokenizer + +logger = init_logger(__name__) + + +class ToolParser: + """ + Abstract ToolParser class that should not be used directly. Provided + properties and methods should be used in + derived classes. + """ + + def __init__(self, tokenizer: AnyTokenizer): + self.prev_tool_call_arr: List[Dict] = [] + # the index of the tool call that is currently being parsed + self.current_tool_id: int = -1 + self.current_tool_name_sent: bool = False + self.current_tool_initial_sent: bool = False + self.streamed_args_for_tool: List[str] = [] + + self.model_tokenizer = tokenizer + + def extract_tool_calls(self, + model_output: str) -> ExtractedToolCallInformation: + """ + Static method that should be implemented for extracting tool calls from + a complete model-generated string. + Used for non-streaming responses where we have the entire model response + available before sending to the client. + Static because it's stateless. + """ + raise NotImplementedError( + "AbstractToolParser.extract_tool_calls has not been implemented!") + + def extract_tool_calls_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: Sequence[int], + current_token_ids: Sequence[int], + delta_token_ids: Sequence[int], + ) -> Union[DeltaMessage, None]: + """ + Instance method that should be implemented for extracting tool calls + from an incomplete response; for use when handling tool calls and + streaming. Has to be an instance method because it requires state - + the current tokens/diffs, but also the information about what has + previously been parsed and extracted (see constructor) + """ + raise NotImplementedError( + "AbstractToolParser.extract_tool_calls_streaming has not been " + "implemented!") diff --git a/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py new file mode 100644 index 0000000000000..7afbca7162edf --- /dev/null +++ b/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py @@ -0,0 +1,344 @@ +import json +import re +from typing import Dict, List, Sequence, Union + +import partial_json_parser +from partial_json_parser.core.options import Allow + +from vllm.entrypoints.openai.protocol import (DeltaFunctionCall, DeltaMessage, + DeltaToolCall, + ExtractedToolCallInformation, + FunctionCall, + InitialDeltaToolCall, ToolCall) +from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( + ToolParser) +from vllm.entrypoints.openai.tool_parsers.utils import ( + extract_intermediate_diff) +from vllm.logger import init_logger +from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer + +logger = init_logger(__name__) + + +class Hermes2ProToolParser(ToolParser): + + def __init__(self, tokenizer: AnyTokenizer): + super().__init__(tokenizer) + + if isinstance(self.model_tokenizer, MistralTokenizer): + logger.error( + "Detected Mistral tokenizer when using a Hermes model") + self.model_tokenizer = self.model_tokenizer.tokenizer + + self.current_tool_name_sent: bool = False + self.prev_tool_call_arr: List[Dict] = [] + self.current_tool_id: int = -1 + self.current_tool_name_sent = False + self.current_tool_initial_sent: bool = False + self.streamed_args_for_tool: List[str] = [ + ] # map what has been streamed for each tool so far to a list + + self.tool_call_start_token: str = "" + self.tool_call_end_token: str = "" + + self.tool_call_regex = re.compile( + r"(.*?)|(.*)", re.DOTALL) + self.scratch_pad_regex = re.compile( + r"(.*?)", re.DOTALL) + + if not self.model_tokenizer: + raise ValueError( + "The model tokenizer must be passed to the ToolParser " + "constructor during construction.") + self.tool_call_start_token_id: int = self.model_tokenizer.vocab[ + self.tool_call_start_token] + self.tool_call_end_token_id: int = self.model_tokenizer.vocab[ + self.tool_call_end_token] + if not self.tool_call_start_token_id or not self.tool_call_end_token_id: + raise RuntimeError( + "Hermes 2 Pro Tool parser could not locate tool call start/end " + "tokens in the tokenizer!") + + def extract_tool_calls(self, + model_output: str) -> ExtractedToolCallInformation: + + # sanity check; avoid unnecessary processing + if self.tool_call_start_token not in model_output: + return ExtractedToolCallInformation(tools_called=False, + tool_calls=[], + content=model_output) + + else: + + try: + # there are two possible captures - between tags, or between a + # tag and end-of-string so the result of + # findall is an array of tuples where one is a function call and + # the other is None + function_call_tuples = ( + self.tool_call_regex.findall(model_output)) + + # load the JSON, and then use it to build the Function and + # Tool Call + raw_function_calls = [ + json.loads(match[0] if match[0] else match[1]) + for match in function_call_tuples + ] + tool_calls = [ + ToolCall( + type="function", + function=FunctionCall( + name=function_call["name"], + # function call args are JSON but as a string + arguments=json.dumps(function_call["arguments"]))) + for function_call in raw_function_calls + ] + + content = model_output[:model_output. + find(self.tool_call_start_token)] + return ExtractedToolCallInformation( + tools_called=True, + tool_calls=tool_calls, + content=content if content else None) + + except Exception as e: + logger.error("Error in extracting tool call from response %s", + e) + return ExtractedToolCallInformation(tools_called=False, + tool_calls=[], + content=model_output) + + def extract_tool_calls_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: Sequence[int], + current_token_ids: Sequence[int], + delta_token_ids: Sequence[int], + ) -> Union[DeltaMessage, None]: + + logger.debug("delta_text: %s", delta_text) + logger.debug("delta_token_ids: %s", delta_token_ids) + # check to see if we should be streaming a tool call - is there a + if self.tool_call_start_token_id not in current_token_ids: + logger.debug("No tool call tokens found!") + return DeltaMessage(content=delta_text) + + try: + + # figure out where we are in the parsing by counting tool call + # start & end tags + prev_tool_start_count = previous_token_ids.count( + self.tool_call_start_token_id) + prev_tool_end_count = previous_token_ids.count( + self.tool_call_end_token_id) + cur_tool_start_count = current_token_ids.count( + self.tool_call_start_token_id) + cur_tool_end_count = current_token_ids.count( + self.tool_call_end_token_id) + + # case: if we're generating text, OR rounding out a tool call + if (cur_tool_start_count == cur_tool_end_count + and prev_tool_end_count == cur_tool_end_count): + logger.debug("Generating text content! skipping tool parsing.") + if delta_text != self.tool_call_end_token: + return DeltaMessage(content=delta_text) + + # case: if tool open & close tag counts don't match, we're doing + # imaginary "else" block here + # something with tools with this diff. + # flags for partial JSON parting. exported constants from + # "Allow" are handled via BIT MASK + flags = Allow.ALL if self.current_tool_name_sent \ + else Allow.ALL & ~Allow.STR + + # case -- we're starting a new tool call + if (cur_tool_start_count > cur_tool_end_count + and cur_tool_start_count > prev_tool_start_count): + if len(delta_token_ids) > 1: + tool_call_portion = current_text.split( + self.tool_call_start_token)[-1] + else: + tool_call_portion = None + delta = None + + text_portion = None + + # set cursors and state appropriately + self.current_tool_id += 1 + self.current_tool_name_sent = False + self.current_tool_initial_sent = False + self.streamed_args_for_tool.append("") + logger.debug("Starting on a new tool %s", self.current_tool_id) + + # case -- we're updating an existing tool call + elif (cur_tool_start_count > cur_tool_end_count + and cur_tool_start_count == prev_tool_start_count): + + # get the portion of the text that's the tool call + tool_call_portion = current_text.split( + self.tool_call_start_token)[-1] + text_portion = None + + # case -- the current tool call is being closed. + elif (cur_tool_start_count == cur_tool_end_count + and cur_tool_end_count > prev_tool_end_count): + diff = self.prev_tool_call_arr[self.current_tool_id].get( + "arguments") + if diff: + diff = json.dumps(diff).replace( + self.streamed_args_for_tool[self.current_tool_id], "") + logger.debug( + "Finishing tool and found diff that had not " + "been streamed yet: %s", diff) + self.streamed_args_for_tool[self.current_tool_id] \ + += diff + return DeltaMessage(tool_calls=[ + DeltaToolCall(index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=diff).model_dump( + exclude_none=True)) + ]) + + # case -- otherwise we're just generating text + else: + text = delta_text.replace(self.tool_call_start_token, "") + text = text.replace(self.tool_call_end_token, "") + delta = DeltaMessage(tool_calls=[], content=text) + return delta + + try: + + current_tool_call = partial_json_parser.loads( + tool_call_portion or "{}", + flags) if tool_call_portion else None + logger.debug("Parsed tool call %s", current_tool_call) + except partial_json_parser.core.exceptions.MalformedJSON: + logger.debug('not enough tokens to parse into JSON yet') + return None + + # case - we haven't sent the initial delta with the tool call ID + # (it will be sent) + if not self.current_tool_initial_sent: + self.current_tool_initial_sent = True + return DeltaMessage(tool_calls=[ + InitialDeltaToolCall( + index=self.current_tool_id).model_dump( + exclude_none=True) + ]) + + # case - we haven't sent the tool name yet. If it's available, send + # it. otherwise, wait until it's available. + elif not self.current_tool_name_sent: + function_name: Union[str, None] = current_tool_call.get("name") + if function_name: + self.current_tool_name_sent = True + return DeltaMessage(tool_calls=[ + DeltaToolCall(index=self.current_tool_id, + function=DeltaFunctionCall( + name=function_name).model_dump( + exclude_none=True)) + ]) + else: + return None + # case -- otherwise, send the tool call delta + + # if the tool call portion is None, send the delta as text + if tool_call_portion is None: + # if there's text but not tool calls, send that - + # otherwise None to skip chunk + delta = DeltaMessage(content=delta_text) \ + if text_portion is not None else None + return delta + + # now, the nitty-gritty of tool calls + # now we have the portion to parse as tool call. + + logger.debug("Trying to parse current tool call with ID %s", + self.current_tool_id) + + # if we're starting a new tool call, push an empty object in as + # a placeholder for the arguments + if len(self.prev_tool_call_arr) <= self.current_tool_id: + self.prev_tool_call_arr.append({}) + + # main logic for tool parsing here - compare prev. partially-parsed + # JSON to the current partially-parsed JSON + prev_arguments = ( + self.prev_tool_call_arr[self.current_tool_id].get("arguments")) + cur_arguments = current_tool_call.get("arguments") + + logger.debug("diffing old arguments: %s", prev_arguments) + logger.debug("against new ones: %s", cur_arguments) + + # case -- no arguments have been created yet. skip sending a delta. + if not cur_arguments and not prev_arguments: + logger.debug("Skipping text %s - no arguments", delta_text) + delta = None + + # case -- prev arguments are defined, but non are now. + # probably impossible, but not a fatal error - just keep going + elif not cur_arguments and prev_arguments: + logger.error("should be impossible to have arguments reset " + "mid-call. skipping streaming anything.") + delta = None + + # case -- we now have the first info about arguments available from + # autocompleting the JSON + elif cur_arguments and not prev_arguments: + + cur_arguments_json = json.dumps(cur_arguments) + logger.debug("finding %s in %s", delta_text, + cur_arguments_json) + + # get the location where previous args differ from current + args_delta_start_loc = cur_arguments_json.index(delta_text) \ + + len(delta_text) + + # use that to find the actual delta + arguments_delta = cur_arguments_json[:args_delta_start_loc] + logger.debug("First tokens in arguments received: %s", + arguments_delta) + + delta = DeltaMessage(tool_calls=[ + DeltaToolCall(index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=arguments_delta).model_dump( + exclude_none=True)) + ]) + self.streamed_args_for_tool[self.current_tool_id] \ + += arguments_delta + + # last case -- we have an update to existing arguments. + elif cur_arguments and prev_arguments: + + cur_args_json = json.dumps(cur_arguments) + prev_args_json = json.dumps(prev_arguments) + logger.debug("Searching for diff between\n%s", cur_args_json) + logger.debug("and\n%s", prev_args_json) + argument_diff = extract_intermediate_diff( + cur_args_json, prev_args_json) + logger.debug("got argument diff %s", argument_diff) + delta = DeltaMessage(tool_calls=[ + DeltaToolCall(index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=argument_diff).model_dump( + exclude_none=True)) + ]) + self.streamed_args_for_tool[self.current_tool_id] \ + += argument_diff + + # handle saving the state for the current tool into + # the "prev" list for use in diffing for the next iteration + if self.current_tool_id == len(self.prev_tool_call_arr) - 1: + self.prev_tool_call_arr[self.current_tool_id] = \ + current_tool_call + else: + self.prev_tool_call_arr.append(current_tool_call) + + return delta + + except Exception as e: + logger.error("Error trying to handle streaming tool call: %s", e) + return None # do not stream a delta. skip this token ID. diff --git a/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py new file mode 100644 index 0000000000000..d48770c792e98 --- /dev/null +++ b/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py @@ -0,0 +1,293 @@ +import json +import re +from typing import Dict, List, Sequence, Union + +import partial_json_parser +from partial_json_parser.core.options import Allow + +from vllm.entrypoints.openai.protocol import (DeltaFunctionCall, DeltaMessage, + DeltaToolCall, + ExtractedToolCallInformation, + FunctionCall, + InitialDeltaToolCall, ToolCall) +from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( + ToolParser) +from vllm.entrypoints.openai.tool_parsers.utils import ( + extract_intermediate_diff) +from vllm.logger import init_logger +from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer + +logger = init_logger(__name__) + + +class MistralToolParser(ToolParser): + """ + Tool call parser for Mistral 7B Instruct v0.3, intended for use with the + examples/tool_chat_template_mistral.jinja template. + + Used when --enable-auto-tool-choice --tool-call-parser gmistral are all set + """ + + def __init__(self, tokenizer: AnyTokenizer): + super().__init__(tokenizer) + + if isinstance(self.model_tokenizer, MistralTokenizer): + self.model_tokenizer = self.model_tokenizer.tokenizer + else: + logger.info("Non-Mistral tokenizer detected when using a Mistral " + "model...") + + # initialize properties used for state when parsing tool calls in + # streaming mode + self.prev_tool_call_arr: List[Dict] = [] + self.current_tool_id: int = -1 + self.current_tool_name_sent: bool = False + self.current_tool_initial_sent: bool = False + self.streamed_args_for_tool: List[str] = [ + ] # map what has been streamed for each tool so far to a list + self.bot_token = "[TOOL_CALLS]" + self.bot_token_id = self.model_tokenizer.vocab[self.bot_token] + self.tool_call_regex = re.compile(r"\[{.*?}\]", re.DOTALL) + + def extract_tool_calls(self, + model_output: str) -> ExtractedToolCallInformation: + """ + Extract the tool calls from a complete model response. Requires + find-and-replacing single quotes with double quotes for JSON parsing, + make sure your tool call arguments don't ever include quotes! + """ + + # case -- if a tool call token is not present, return a text response + if self.bot_token not in model_output: + return ExtractedToolCallInformation(tools_called=False, + tool_calls=[], + content=model_output) + try: + + # use a regex to find the tool call. remove the BOT token + # and make sure to replace single quotes with double quotes + raw_tool_call = self.tool_call_regex.findall( + model_output.replace(self.bot_token, ""))[0] + + # load the JSON, and then use it to build the Function and + # Tool Call + function_call_arr = json.loads(raw_tool_call) + tool_calls: List[ToolCall] = [ + ToolCall( + type="function", + function=FunctionCall( + name=raw_function_call["name"], + # function call args are JSON but as a string + arguments=json.dumps(raw_function_call["arguments"]))) + for raw_function_call in function_call_arr + ] + + # get any content before the tool call + content = model_output.split(self.bot_token)[0] + return ExtractedToolCallInformation( + tools_called=True, + tool_calls=tool_calls, + content=content if len(content) > 0 else None) + + except Exception as e: + logger.error("Error in extracting tool call from response: %s", e) + print("ERROR", e) + # return information to just treat the tool call as regular JSON + return ExtractedToolCallInformation(tools_called=False, + tool_calls=[], + content=model_output) + + def extract_tool_calls_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: Sequence[int], + current_token_ids: Sequence[int], + delta_token_ids: Sequence[int], + ) -> Union[DeltaMessage, None]: + + # if the tool call token is not in the tokens generated so far, append + # output to contents since it's not a tool + if self.bot_token_id not in current_token_ids: + return DeltaMessage(content=delta_text) + + # if the tool call token ID IS in the tokens generated so far, that + # means we're parsing as tool calls now + + # handle if we detected the BOT token which means the start of tool + # calling + if (self.bot_token_id in delta_token_ids + and len(delta_token_ids) == 1): + # if it's the only token, return None, so we don't send a chat + # completion any don't send a control token + return None + + # bit mask flags for partial JSON parsing. If the name hasn't been + # sent yet, don't allow sending + # an incomplete string since OpenAI only ever (as far as I have + # seen) allows sending the entire tool/ function name at once. + flags = Allow.ALL if self.current_tool_name_sent \ + else Allow.ALL & ~Allow.STR + try: + + # replace BOT token with empty string, and convert single quotes + # to double to allow parsing as JSON since mistral uses single + # quotes instead of double for tool calls + parsable_arr = current_text.split(self.bot_token)[1] + + # tool calls are generated in an array, so do partial JSON + # parsing on the entire array + try: + tool_call_arr: List[Dict] = partial_json_parser.loads( + parsable_arr, flags) + except partial_json_parser.core.exceptions.MalformedJSON: + logger.debug('not enough tokens to parse into JSON yet') + return None + + # select as the current tool call the one we're on the state at + + current_tool_call: Dict = tool_call_arr[self.current_tool_id] \ + if len(tool_call_arr) > 0 else {} + + # case -- if no tokens have been streamed for the tool, e.g. + # only the array brackets, stream nothing + if len(tool_call_arr) == 0: + return None + + # case: we are starting a new tool in the array + # -> array has > 0 length AND length has moved past cursor + elif (len(tool_call_arr) > 0 + and len(tool_call_arr) > self.current_tool_id + 1): + + # if we're moving on to a new call, first make sure we + # haven't missed anything in the previous one that was + # auto-generated due to JSON completions, but wasn't + # streamed to the client yet. + if self.current_tool_id >= 0: + diff: Union[str, None] = current_tool_call.get("arguments") + + if diff: + diff = json.dumps(diff).replace( + self.streamed_args_for_tool[self.current_tool_id], + "") + delta = DeltaMessage(tool_calls=[ + DeltaToolCall(index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=diff).model_dump( + exclude_none=True)) + ]) + self.streamed_args_for_tool[ + self.current_tool_id] += diff + else: + delta = None + else: + delta = None + # re-set stuff pertaining to progress in the current tool + self.current_tool_id = len(tool_call_arr) - 1 + self.current_tool_name_sent = False + self.current_tool_initial_sent = False + self.streamed_args_for_tool.append("") + logger.debug("starting on new tool %d", self.current_tool_id) + return delta + + # case: update an existing tool - this is handled below + + # if the current tool initial data incl. the id, type=function + # and idx not sent, send that + if not self.current_tool_initial_sent: + self.current_tool_initial_sent = True + delta = DeltaMessage(tool_calls=[ + InitialDeltaToolCall( + index=self.current_tool_id).model_dump( + exclude_none=True) + ]) + + # if the current tool name hasn't been sent, send if available + # - otherwise send nothing + elif not self.current_tool_name_sent: + function_name = current_tool_call.get("name") + if function_name: + + delta = DeltaMessage(tool_calls=[ + DeltaToolCall(index=self.current_tool_id, + function=DeltaFunctionCall( + name=function_name).model_dump( + exclude_none=True)) + ]) + self.current_tool_name_sent = True + else: + delta = None + + # now we know we're on the same tool call and we're streaming + # arguments + else: + + prev_arguments = self.prev_tool_call_arr[ + self.current_tool_id].get("arguments") + cur_arguments = current_tool_call.get("arguments") + + new_text = delta_text.replace("\'", "\"") + + if not cur_arguments and not prev_arguments: + + delta = None + elif not cur_arguments and prev_arguments: + logger.error( + "INVARIANT - impossible to have arguments reset " + "mid-arguments") + delta = None + elif cur_arguments and not prev_arguments: + cur_arguments_json = json.dumps(cur_arguments) + logger.debug("finding %s in %s", new_text, + cur_arguments_json) + + arguments_delta = cur_arguments_json[:cur_arguments_json. + index(new_text) + + len(new_text)] + logger.debug("First tokens in arguments received: %s", + arguments_delta) + delta = DeltaMessage(tool_calls=[ + DeltaToolCall(index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=arguments_delta). + model_dump(exclude_none=True)) + ]) + self.streamed_args_for_tool[ + self.current_tool_id] += arguments_delta + + elif cur_arguments and prev_arguments: + cur_args_json = json.dumps(cur_arguments) + prev_args_json = json.dumps(prev_arguments) + logger.debug("Searching for diff between \n%s\n%s", + cur_args_json, prev_args_json) + + argument_diff = extract_intermediate_diff( + cur_args_json, prev_args_json) + logger.debug("got arguments diff: %s", argument_diff) + delta = DeltaMessage(tool_calls=[ + DeltaToolCall(index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=argument_diff).model_dump( + exclude_none=True)) + ]) + self.streamed_args_for_tool[ + self.current_tool_id] += argument_diff + else: + # try parsing it with regular JSON - if it works we're + # at the end, and we need to send the difference between + # tokens streamed so far and the valid JSON + delta = None + + # check to see if the name is defined and has been sent. if so, + # stream the name - otherwise keep waiting + # finish by setting old and returning None as base case + self.prev_tool_call_arr = tool_call_arr + return delta + + except Exception as e: + logger.error("Error trying to handle streaming tool call: %s", e) + logger.debug( + "Skipping chunk as a result of tool streaming extraction " + "error") + return None diff --git a/vllm/entrypoints/openai/tool_parsers/utils.py b/vllm/entrypoints/openai/tool_parsers/utils.py new file mode 100644 index 0000000000000..db7fc5259fc4e --- /dev/null +++ b/vllm/entrypoints/openai/tool_parsers/utils.py @@ -0,0 +1,87 @@ +def find_common_prefix(s1: str, s2: str) -> str: + """ + Finds a common prefix that is shared between two strings, if there is one. + Order of arguments is NOT important. + + This function is provided as a UTILITY for extracting information from JSON + generated by partial_json_parser, to help in ensuring that the right tokens + are returned in streaming, so that close-quotes, close-brackets and + close-braces are not returned prematurely. + + e.g. find_common_prefix('{"fruit": "ap"}', '{"fruit": "apple"}') -> + '{"fruit": "ap' + """ + prefix = '' + min_length = min(len(s1), len(s2)) + for i in range(0, min_length): + if s1[i] == s2[i]: + prefix += s1[i] + else: + break + return prefix + + +def find_common_suffix(s1: str, s2: str) -> str: + """ + Finds a common suffix shared between two strings, if there is one. Order of + arguments is NOT important. + Stops when the suffix ends OR it hits an alphanumeric character + + e.g. find_common_suffix('{"fruit": "ap"}', '{"fruit": "apple"}') -> '"}' + """ + suffix = '' + min_length = min(len(s1), len(s2)) + for i in range(1, min_length + 1): + if s1[-i] == s2[-i] and not s1[-i].isalnum(): + suffix = s1[-i] + suffix + else: + break + return suffix + + +def extract_intermediate_diff(curr: str, old: str) -> str: + """ + Given two strings, extract the difference in the middle between two strings + that are known to have a common prefix and/or suffix. + + This function is provided as a UTILITY for extracting information from JSON + generated by partial_json_parser, to help in ensuring that the right tokens + are returned in streaming, so that close-quotes, close-brackets and + close-braces are not returned prematurely. The order of arguments IS + important - the new version of the partially-parsed JSON must be the first + argument, and the secnod argument must be from the previous generation. + + What it returns, is tokens that should be streamed to the client. + + e.g. extract_intermediate_diff('{"fruit": "apple"}', '{"fruit": "ap"}') + -> 'ple' + + """ + suffix = find_common_suffix(curr, old) + + old = old[::-1].replace(suffix[::-1], '', 1)[::-1] + prefix = find_common_prefix(curr, old) + diff = curr + if len(suffix): + diff = diff[::-1].replace(suffix[::-1], '', 1)[::-1] + + if len(prefix): + # replace the prefix only once in case it's mirrored + diff = diff.replace(prefix, '', 1) + + return diff + + +def find_all_indices(string, substring): + """ + Find all (starting) indices of a substring in a given string. Useful for + tool call extraction + """ + indices = [] + index = -1 + while True: + index = string.find(substring, index + 1) + if index == -1: + break + indices.append(index) + return indices diff --git a/vllm/model_executor/guided_decoding/__init__.py b/vllm/model_executor/guided_decoding/__init__.py index f9fcdead980a2..7161e83952a3d 100644 --- a/vllm/model_executor/guided_decoding/__init__.py +++ b/vllm/model_executor/guided_decoding/__init__.py @@ -59,8 +59,9 @@ def _adapt_request_for_tool_use(request: Union[CompletionRequest, if type(request) is CompletionRequest: return request - # user has chosen to not use any tool - if request.tool_choice == "none": + # user has chosen to not use any tool, + # OR is allowing the model to choose a tool. + if request.tool_choice == "none" or request.tool_choice == "auto": return request # user has chosen to use a named tool diff --git a/vllm/model_executor/guided_decoding/outlines_decoding.py b/vllm/model_executor/guided_decoding/outlines_decoding.py index bfc658ef7d26b..e1f5b380120c5 100644 --- a/vllm/model_executor/guided_decoding/outlines_decoding.py +++ b/vllm/model_executor/guided_decoding/outlines_decoding.py @@ -8,8 +8,9 @@ from pydantic import BaseModel from transformers import PreTrainedTokenizerBase -from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, - CompletionRequest) +from vllm.entrypoints.openai.protocol import ( + ChatCompletionNamedToolChoiceParam, ChatCompletionRequest, + CompletionRequest) from vllm.model_executor.guided_decoding.guided_fields import ( GuidedDecodingRequest) from vllm.model_executor.guided_decoding.outlines_logits_processors import ( @@ -101,16 +102,30 @@ def _get_guide_and_mode( request: Union[CompletionRequest, ChatCompletionRequest, GuidedDecodingRequest] ) -> Union[Tuple[str, GuidedDecodingMode], Tuple[None, None]]: + # if the request is a chat completion request, AND the tool choice is a + # named tool choice, do guided decoding + # using that tool as the JSON schema + if isinstance(request, ChatCompletionRequest) and isinstance( + request.tool_choice, ChatCompletionNamedToolChoiceParam): + # Guided generation for tools/functions parameters + if request.tool_choice.type == "function": + for tool in request.tools: + if (tool.type == "function" and tool.function.name + == request.tool_choice.function.name): + json = json_dumps(tool.function.parameters, sort_keys=True) + return json, GuidedDecodingMode.JSON + return None, None - if request.guided_json: - json = request.guided_json - if isinstance(json, dict): + elif request.guided_json: + if isinstance(request.guided_json, dict): # turn dict into hashable string - json = json_dumps(json) - elif isinstance(json, BaseModel): + json = json_dumps(request.guided_json) + elif isinstance(request.guided_json, BaseModel): # use pydantic signature so that different model classes # with the same fields will get hashed the same - json = str(json.__signature__) + json = str(request.guided_json.__signature__) + else: + json = request.guided_json return json, GuidedDecodingMode.JSON elif request.guided_regex: return request.guided_regex, GuidedDecodingMode.REGEX From 77d9e514a2284d5d0bd34b1518b9483ae7d8a05a Mon Sep 17 00:00:00 2001 From: Cody Yu Date: Wed, 4 Sep 2024 13:23:22 -0700 Subject: [PATCH 11/77] [MISC] Replace input token throughput with total token throughput (#8164) Co-authored-by: Michael Goin --- benchmarks/benchmark_serving.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/benchmarks/benchmark_serving.py b/benchmarks/benchmark_serving.py index e38ceaa222956..84f366bdba387 100644 --- a/benchmarks/benchmark_serving.py +++ b/benchmarks/benchmark_serving.py @@ -56,8 +56,8 @@ class BenchmarkMetrics: total_input: int total_output: int request_throughput: float - input_throughput: float output_throughput: float + total_token_throughput: float mean_ttft_ms: float median_ttft_ms: float std_ttft_ms: float @@ -283,8 +283,8 @@ def calculate_metrics( total_input=total_input, total_output=sum(actual_output_lens), request_throughput=completed / dur_s, - input_throughput=total_input / dur_s, output_throughput=sum(actual_output_lens) / dur_s, + total_token_throughput=(total_input + sum(actual_output_lens)) / dur_s, mean_ttft_ms=np.mean(ttfts or 0) * 1000, # ttfts is empty if streaming is not supported by backend std_ttft_ms=np.std(ttfts or 0) * 1000, @@ -426,10 +426,10 @@ async def benchmark( metrics.total_output)) print("{:<40} {:<10.2f}".format("Request throughput (req/s):", metrics.request_throughput)) - print("{:<40} {:<10.2f}".format("Input token throughput (tok/s):", - metrics.input_throughput)) print("{:<40} {:<10.2f}".format("Output token throughput (tok/s):", metrics.output_throughput)) + print("{:<40} {:<10.2f}".format("Total Token throughput (tok/s):", + metrics.total_token_throughput)) result = { "duration": benchmark_duration, @@ -437,8 +437,8 @@ async def benchmark( "total_input_tokens": metrics.total_input, "total_output_tokens": metrics.total_output, "request_throughput": metrics.request_throughput, - "input_throughput": metrics.input_throughput, "output_throughput": metrics.output_throughput, + "total_token_throughput": metrics.total_token_throughput, "input_lens": [output.prompt_len for output in outputs], "output_lens": actual_output_lens, "ttfts": [output.ttft for output in outputs], From 008cf886c9361e696f70a15a282d72b58686468a Mon Sep 17 00:00:00 2001 From: Harsha vardhan manoj Bikki <39381063+hbikki@users.noreply.github.com> Date: Wed, 4 Sep 2024 16:33:43 -0700 Subject: [PATCH 12/77] =?UTF-8?q?[Neuron]=20Adding=20support=20for=20addin?= =?UTF-8?q?g/=20overriding=20neuron=20configuration=20a=E2=80=A6=20(#8062)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Harsha Bikki --- ...line_inference_neuron_int8_quantization.py | 50 ++++++++++++++ vllm/config.py | 69 +++++++++++-------- vllm/engine/arg_utils.py | 17 ++++- vllm/engine/llm_engine.py | 2 + .../layers/quantization/__init__.py | 3 + .../layers/quantization/neuron_quant.py | 67 ++++++++++++++++++ vllm/model_executor/model_loader/neuron.py | 65 ++++++++++++++--- vllm/worker/neuron_model_runner.py | 12 +++- 8 files changed, 243 insertions(+), 42 deletions(-) create mode 100644 examples/offline_inference_neuron_int8_quantization.py create mode 100644 vllm/model_executor/layers/quantization/neuron_quant.py diff --git a/examples/offline_inference_neuron_int8_quantization.py b/examples/offline_inference_neuron_int8_quantization.py new file mode 100644 index 0000000000000..8ec17e3400953 --- /dev/null +++ b/examples/offline_inference_neuron_int8_quantization.py @@ -0,0 +1,50 @@ +import os + +from vllm import LLM, SamplingParams + +# creates XLA hlo graphs for all the context length buckets. +os.environ['NEURON_CONTEXT_LENGTH_BUCKETS'] = "128,512,1024,2048" +# creates XLA hlo graphs for all the token gen buckets. +os.environ['NEURON_TOKEN_GEN_BUCKETS'] = "128,512,1024,2048" +# Quantizes neuron model weight to int8 , +# The default config for quantization is int8 dtype. +os.environ['NEURON_QUANT_DTYPE'] = "s8" + +# Sample prompts. +prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", +] +# Create a sampling params object. +sampling_params = SamplingParams(temperature=0.8, top_p=0.95) + +# Create an LLM. +llm = LLM( + model="TinyLlama/TinyLlama-1.1B-Chat-v1.0", + max_num_seqs=8, + # The max_model_len and block_size arguments are required to be same as + # max sequence length when targeting neuron device. + # Currently, this is a known limitation in continuous batching support + # in transformers-neuronx. + # TODO(liangfu): Support paged-attention in transformers-neuronx. + max_model_len=2048, + block_size=2048, + # The device can be automatically detected when AWS Neuron SDK is installed. + # The device argument can be either unspecified for automated detection, + # or explicitly assigned. + device="neuron", + quantization="neuron_quant", + override_neuron_config={ + "cast_logits_dtype": "bfloat16", + }, + tensor_parallel_size=2) +# Generate texts from the prompts. The output is a list of RequestOutput objects +# that contain the prompt, generated text, and other information. +outputs = llm.generate(prompts, sampling_params) +# Print the outputs. +for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") diff --git a/vllm/config.py b/vllm/config.py index b84d91d402370..9b3f4f9206300 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1,8 +1,8 @@ import enum import json from dataclasses import dataclass, field, fields -from typing import (TYPE_CHECKING, ClassVar, List, Mapping, Optional, Tuple, - Type, Union) +from typing import (TYPE_CHECKING, Any, ClassVar, Dict, List, Mapping, + Optional, Tuple, Type, Union) import torch from transformers import PretrainedConfig @@ -115,35 +115,39 @@ class ModelConfig: the model name will be the same as `model`. limit_mm_per_prompt: Maximum number of data instances per modality per prompt. Only applicable for multimodal models. + override_neuron_config: Initialize non default neuron config or + override default neuron config that are specific to Neuron devices, + this argument will be used to configure the neuron config that + can not be gathered from the vllm arguments. """ def __init__( - self, - model: str, - tokenizer: str, - tokenizer_mode: str, - trust_remote_code: bool, - dtype: Union[str, torch.dtype], - seed: int, - revision: Optional[str] = None, - code_revision: Optional[str] = None, - rope_scaling: Optional[dict] = None, - rope_theta: Optional[float] = None, - tokenizer_revision: Optional[str] = None, - max_model_len: Optional[int] = None, - spec_target_max_model_len: Optional[int] = None, - quantization: Optional[str] = None, - quantization_param_path: Optional[str] = None, - enforce_eager: Optional[bool] = None, - max_context_len_to_capture: Optional[int] = None, - max_seq_len_to_capture: Optional[int] = None, - max_logprobs: int = 20, - disable_sliding_window: bool = False, - skip_tokenizer_init: bool = False, - served_model_name: Optional[Union[str, List[str]]] = None, - limit_mm_per_prompt: Optional[Mapping[str, int]] = None, - use_async_output_proc: bool = True, - ) -> None: + self, + model: str, + tokenizer: str, + tokenizer_mode: str, + trust_remote_code: bool, + dtype: Union[str, torch.dtype], + seed: int, + revision: Optional[str] = None, + code_revision: Optional[str] = None, + rope_scaling: Optional[dict] = None, + rope_theta: Optional[float] = None, + tokenizer_revision: Optional[str] = None, + max_model_len: Optional[int] = None, + spec_target_max_model_len: Optional[int] = None, + quantization: Optional[str] = None, + quantization_param_path: Optional[str] = None, + enforce_eager: Optional[bool] = None, + max_context_len_to_capture: Optional[int] = None, + max_seq_len_to_capture: Optional[int] = None, + max_logprobs: int = 20, + disable_sliding_window: bool = False, + skip_tokenizer_init: bool = False, + served_model_name: Optional[Union[str, List[str]]] = None, + limit_mm_per_prompt: Optional[Mapping[str, int]] = None, + use_async_output_proc: bool = True, + override_neuron_config: Optional[Dict[str, Any]] = None) -> None: self.model = model self.tokenizer = tokenizer self.tokenizer_mode = tokenizer_mode @@ -227,6 +231,9 @@ def __init__( limit_mm_per_prompt) if not self.skip_tokenizer_init: self._verify_tokenizer_mode() + + self.override_neuron_config = override_neuron_config if is_neuron( + ) else None self._verify_embedding_mode() self._verify_quantization() self._verify_cuda_graph() @@ -275,6 +282,7 @@ def _verify_quantization(self) -> None: "experts_int8" ] tpu_supported_quantization = ["tpu_int8"] + neuron_supported_quantization = ["neuron_quant"] if self.quantization is not None: self.quantization = self.quantization.lower() @@ -329,6 +337,11 @@ def _verify_quantization(self) -> None: "Using AWQ quantization with ROCm, but VLLM_USE_TRITON_AWQ" " is not set, enabling VLLM_USE_TRITON_AWQ.") envs.VLLM_USE_TRITON_AWQ = True + if is_neuron( + ) and self.quantization not in neuron_supported_quantization: + raise ValueError( + f"{self.quantization} quantization is currently not " + f"supported in Neuron Backend.") def _verify_cuda_graph(self) -> None: if self.max_seq_len_to_capture is None: diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 8dbe6504d21bd..f0b866db64324 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -2,8 +2,8 @@ import dataclasses import json from dataclasses import dataclass -from typing import (TYPE_CHECKING, Dict, List, Mapping, Optional, Tuple, Type, - Union) +from typing import (TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Tuple, + Type, Union) import torch @@ -149,6 +149,7 @@ class EngineArgs: otlp_traces_endpoint: Optional[str] = None collect_detailed_traces: Optional[str] = None disable_async_output_proc: bool = False + override_neuron_config: Optional[Dict[str, Any]] = None def __post_init__(self): if self.tokenizer is None: @@ -742,6 +743,16 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: default=EngineArgs.disable_async_output_proc, help="Disable async output processing. This may result in " "lower performance.") + parser.add_argument( + '--override-neuron-config', + type=lambda configs: { + str(key): value + for key, value in + (config.split(':') for config in configs.split(',')) + }, + default=None, + help="override or set neuron device configuration.") + return parser @classmethod @@ -802,7 +813,7 @@ def create_engine_config(self) -> EngineConfig: served_model_name=self.served_model_name, limit_mm_per_prompt=self.limit_mm_per_prompt, use_async_output_proc=not self.disable_async_output_proc, - ) + override_neuron_config=self.override_neuron_config) cache_config = CacheConfig( block_size=self.block_size if self.device != "neuron" else self.max_model_len, # neuron needs block_size = max_model_len diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 7da4f7b25db9e..50dcb6937eb6f 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -214,6 +214,7 @@ def __init__( "Initializing an LLM engine (v%s) with config: " "model=%r, speculative_config=%r, tokenizer=%r, " "skip_tokenizer_init=%s, tokenizer_mode=%s, revision=%s, " + "override_neuron_config=%s, " "rope_scaling=%r, rope_theta=%r, tokenizer_revision=%s, " "trust_remote_code=%s, dtype=%s, max_seq_len=%d, " "download_dir=%r, load_format=%s, tensor_parallel_size=%d, " @@ -232,6 +233,7 @@ def __init__( model_config.skip_tokenizer_init, model_config.tokenizer_mode, model_config.revision, + model_config.override_neuron_config, model_config.rope_scaling, model_config.rope_theta, model_config.tokenizer_revision, diff --git a/vllm/model_executor/layers/quantization/__init__.py b/vllm/model_executor/layers/quantization/__init__.py index 95b160f4287f9..c6fb6ca0d2e01 100644 --- a/vllm/model_executor/layers/quantization/__init__.py +++ b/vllm/model_executor/layers/quantization/__init__.py @@ -22,6 +22,8 @@ from vllm.model_executor.layers.quantization.gptq_marlin_24 import ( GPTQMarlin24Config) from vllm.model_executor.layers.quantization.marlin import MarlinConfig +from vllm.model_executor.layers.quantization.neuron_quant import ( + NeuronQuantConfig) from vllm.model_executor.layers.quantization.qqq import QQQConfig from vllm.model_executor.layers.quantization.squeezellm import SqueezeLLMConfig from vllm.model_executor.layers.quantization.tpu_int8 import Int8TpuConfig @@ -46,6 +48,7 @@ "bitsandbytes": BitsAndBytesConfig, "qqq": QQQConfig, "experts_int8": ExpertsInt8Config, + "neuron_quant": NeuronQuantConfig, } diff --git a/vllm/model_executor/layers/quantization/neuron_quant.py b/vllm/model_executor/layers/quantization/neuron_quant.py new file mode 100644 index 0000000000000..2624981f6a614 --- /dev/null +++ b/vllm/model_executor/layers/quantization/neuron_quant.py @@ -0,0 +1,67 @@ +import os +from importlib.util import find_spec +from typing import Any, Dict, List, Optional + +from torch.nn import Module + +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) + +SUPPORTED_QUANT_DTYPE_LIST = ['s8', 'f8e4m3fn'] + + +class NeuronQuantConfig(QuantizationConfig): + """Int8 Quantization Config class for Neuron Backend.""" + + def __init__( + self, + dequant_dtype: str = "f16", + quantize_method: str = "vector_dynamic", + ) -> None: + self.quant_dtype = os.getenv("NEURON_QUANT_DTYPE", "s8") + if self.quant_dtype not in SUPPORTED_QUANT_DTYPE_LIST: + raise ValueError( + f"Neuron quantization datatype {self.quant_dtype} is not valid," + f"the quantization datatype should match one of the below types" + f"{SUPPORTED_QUANT_DTYPE_LIST}") + self.dequant_dtype = dequant_dtype + self.quantize_method = quantize_method + + def get_name(self) -> str: + return "neuron_quant" + + def get_supported_act_dtypes(self) -> List[str]: + return SUPPORTED_QUANT_DTYPE_LIST + + @classmethod + def get_min_capability(cls) -> int: + raise NotImplementedError( + "This function should not be called with Neuron Backend") + + @staticmethod + def get_config_filenames() -> List[str]: + return [] + + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "NeuronQuantConfig": + quantize_method = cls.get_from_keys(config, ["quantize_method"]) + dequant_dtype = cls.get_from_keys(config, ["dequant_dtype"]) + return cls(dequant_dtype=dequant_dtype, + quantize_method=quantize_method) + + def get_quant_method(self, layer: Module, prefix: str) -> Optional[Any]: + if find_spec("transformers_neuronx") is not None: + return self.get_quantization_config() + else: + raise NotImplementedError( + "Neuron Quantization is only supported through" + " transformers_neuronx.") + + def get_scaled_act_names(self) -> List[str]: + return [] + + def get_quantization_config(self): + from transformers_neuronx.config import QuantizationConfig + return QuantizationConfig(quant_dtype=self.quant_dtype, + dequant_dtype=self.dequant_dtype, + quantize_method=self.quantize_method) diff --git a/vllm/model_executor/model_loader/neuron.py b/vllm/model_executor/model_loader/neuron.py index 7396ac833e782..594ae442ef328 100644 --- a/vllm/model_executor/model_loader/neuron.py +++ b/vllm/model_executor/model_loader/neuron.py @@ -10,6 +10,7 @@ from vllm.config import ModelConfig, ParallelConfig, SchedulerConfig from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import get_quantization_config from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.sampling_metadata import SamplingMetadata @@ -81,8 +82,7 @@ def load_weights(self, model_name_or_path: str, **kwargs): neuronx_model_cls = getattr(neuronx_module, neuronx_model_cls_name) split_model_dir = f"{model_name_or_path}-split" - if os.path.isdir(os.path.join(model_name_or_path, - "pytorch_model.bin")): + if _is_pretrained_neuron_checkpoint(model_name_or_path): split_model_dir = model_name_or_path elif not os.path.exists(f"{model_name_or_path}-split"): hf_model_cls = getattr(transformers, hf_model_cls_name) @@ -97,6 +97,23 @@ def load_weights(self, model_name_or_path: str, **kwargs): self.model.to_neuron() +def _is_pretrained_neuron_checkpoint(model_name_or_path: str) -> bool: + # Checking if the neuron checkpoint is saved in the old format. + if os.path.isdir(os.path.join(model_name_or_path, "pytorch_model.bin")): + return True + # Checking if the neuron checkpoint is saved in the new format. + pretrained_split_files = ["config.json", "generation_config.json"] + pretrained_split_format = ".safetensors" + for file in pretrained_split_files: + file_path = os.path.join(model_name_or_path, file) + if not os.path.isfile(file_path): + return False + for file in os.listdir(model_name_or_path): + if file.endswith(pretrained_split_format): + return True + return False + + def _get_model_architecture(config: PretrainedConfig) -> str: architectures = getattr(config, "architectures", []) for arch in architectures: @@ -119,19 +136,51 @@ def _get_buckets(env: str, default_value: List[int]) -> List[int]: return buckets_list +def _get_default_neuron_config(model_config: ModelConfig, + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig): + from transformers_neuronx.config import ContinuousBatchingConfig + from transformers_neuronx.constants import LAYOUT_BSH + + continuous_batching_config = ContinuousBatchingConfig( + batch_size_for_shared_caches=scheduler_config.max_num_seqs) + quant_config = dict( + dequant_dtype=TORCH_DTYPE_TO_NEURON_AMP[model_config.dtype], + quantize_method="vector_dynamic") + neuron_quantization_config_builder = lambda quant: get_quantization_config( + quant).from_config(quant_config).get_quant_method(None, "") + # TODO: Add Paged attention config to the default neuron arguments. + default_neuron_args = dict( + collectives_layout=LAYOUT_BSH, + attention_layout=LAYOUT_BSH, + fuse_qkv=True, + quant=neuron_quantization_config_builder(model_config.quantization) + if model_config.quantization else None, + continuous_batching=continuous_batching_config, + weight_tiling=bool(model_config.quantization)) + return default_neuron_args + + +def _get_neuron_config_after_override(default_neuron_config, + overridden_neuron_config): + from transformers_neuronx.config import NeuronConfig + overridden_neuron_config = overridden_neuron_config or {} + default_neuron_config.update(overridden_neuron_config) + return NeuronConfig(**default_neuron_config) + + def get_neuron_model(model_config: ModelConfig, parallel_config: ParallelConfig, scheduler_config: SchedulerConfig) -> nn.Module: - from transformers_neuronx.config import (ContinuousBatchingConfig, - NeuronConfig) # Create a model instance. model = NeuronCasualLM(model_config.hf_config) - continuous_batching_config = ContinuousBatchingConfig( - batch_size_for_shared_caches=scheduler_config.max_num_seqs) - neuron_config = NeuronConfig( - continuous_batching=continuous_batching_config) + default_neuron_config_args = _get_default_neuron_config( + model_config, parallel_config, scheduler_config) + + neuron_config = _get_neuron_config_after_override( + default_neuron_config_args, model_config.override_neuron_config) context_length_estimates = _get_buckets("NEURON_CONTEXT_LENGTH_BUCKETS", [scheduler_config.max_model_len]) diff --git a/vllm/worker/neuron_model_runner.py b/vllm/worker/neuron_model_runner.py index f3defffdfa520..0cf7445d4388d 100644 --- a/vllm/worker/neuron_model_runner.py +++ b/vllm/worker/neuron_model_runner.py @@ -1,4 +1,5 @@ from dataclasses import dataclass +from importlib.util import find_spec from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union import torch @@ -76,9 +77,14 @@ def __init__( self.model: nn.Module # initialize after load_model. def load_model(self) -> None: - self.model = get_neuron_model(self.model_config, - parallel_config=self.parallel_config, - scheduler_config=self.scheduler_config) + if find_spec("transformers_neuronx") is not None: + self.model = get_neuron_model( + self.model_config, + parallel_config=self.parallel_config, + scheduler_config=self.scheduler_config) + else: + raise NotImplementedError( + "Supports only Transformer-NeuronX based models.") def _prepare_prompt( self, From 32e7db25365415841ebc7c4215851743fbb1bad1 Mon Sep 17 00:00:00 2001 From: Simon Mo Date: Wed, 4 Sep 2024 16:34:27 -0700 Subject: [PATCH 13/77] Bump version to v0.6.0 (#8166) --- vllm/version.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/version.py b/vllm/version.py index 052eb76b5873c..039f6369b8ed5 100644 --- a/vllm/version.py +++ b/vllm/version.py @@ -9,4 +9,4 @@ stacklevel=2) __commit__ = "COMMIT_HASH_PLACEHOLDER" -__version__ = "0.5.5" +__version__ = "0.6.0" From e01c2beb7d1df1f388051f083a20ae9c0d552027 Mon Sep 17 00:00:00 2001 From: Maureen McElaney Date: Wed, 4 Sep 2024 19:50:13 -0400 Subject: [PATCH 14/77] [Doc] [Misc] Create CODE_OF_CONDUCT.md (#8161) --- CODE_OF_CONDUCT.md | 128 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 128 insertions(+) create mode 100644 CODE_OF_CONDUCT.md diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md new file mode 100644 index 0000000000000..f801b5f8f5513 --- /dev/null +++ b/CODE_OF_CONDUCT.md @@ -0,0 +1,128 @@ + +# vLLM Code of Conduct + +## Our Pledge + +We as members, contributors, and leaders pledge to make participation in our +community a harassment-free experience for everyone, regardless of age, body +size, visible or invisible disability, ethnicity, sex characteristics, gender +identity and expression, level of experience, education, socioeconomic status, +nationality, personal appearance, race, caste, color, religion, or sexual +identity and orientation. + +We pledge to act and interact in ways that contribute to an open, welcoming, +diverse, inclusive, and healthy community. + +## Our Standards + +Examples of behavior that contributes to a positive environment for our +community include: + +* Demonstrating empathy and kindness toward other people +* Being respectful of differing opinions, viewpoints, and experiences +* Giving and gracefully accepting constructive feedback +* Accepting responsibility and apologizing to those affected by our mistakes, + and learning from the experience +* Focusing on what is best not just for us as individuals, but for the overall + community + +Examples of unacceptable behavior include: + +* The use of sexualized language or imagery, and sexual attention or advances of + any kind +* Trolling, insulting or derogatory comments, and personal or political attacks +* Public or private harassment +* Publishing others' private information, such as a physical or email address, + without their explicit permission +* Other conduct which could reasonably be considered inappropriate in a + professional setting + +## Enforcement Responsibilities + +Community leaders are responsible for clarifying and enforcing our standards of +acceptable behavior and will take appropriate and fair corrective action in +response to any behavior that they deem inappropriate, threatening, offensive, +or harmful. + +Community leaders have the right and responsibility to remove, edit, or reject +comments, commits, code, wiki edits, issues, and other contributions that are +not aligned to this Code of Conduct, and will communicate reasons for moderation +decisions when appropriate. + +## Scope + +This Code of Conduct applies within all community spaces, and also applies when +an individual is officially representing the community in public spaces. +Examples of representing our community include using an official email address, +posting via an official social media account, or acting as an appointed +representative at an online or offline/IRL event. + +## Enforcement + +Instances of abusive, harassing, or otherwise unacceptable behavior may be +reported to the community leaders responsible for enforcement in the #code-of-conduct +channel in the [vLLM Discord](https://discord.com/invite/jz7wjKhh6g). +All complaints will be reviewed and investigated promptly and fairly. + +All community leaders are obligated to respect the privacy and security of the +reporter of any incident. + +## Enforcement Guidelines + +Community leaders will follow these Community Impact Guidelines in determining +the consequences for any action they deem in violation of this Code of Conduct: + +### 1. Correction + +**Community Impact**: Use of inappropriate language or other behavior deemed +unprofessional or unwelcome in the community. + +**Consequence**: A private, written warning from community leaders, providing +clarity around the nature of the violation and an explanation of why the +behavior was inappropriate. A public apology may be requested. + +### 2. Warning + +**Community Impact**: A violation through a single incident or series of +actions. + +**Consequence**: A warning with consequences for continued behavior. No +interaction with the people involved, including unsolicited interaction with +those enforcing the Code of Conduct, for a specified period of time. This +includes avoiding interactions in community spaces as well as external channels +like social media. Violating these terms may lead to a temporary or permanent +ban. + +### 3. Temporary Ban + +**Community Impact**: A serious violation of community standards, including +sustained inappropriate behavior. + +**Consequence**: A temporary ban from any sort of interaction or public +communication with the community for a specified period of time. No public or +private interaction with the people involved, including unsolicited interaction +with those enforcing the Code of Conduct, is allowed during this period. +Violating these terms may lead to a permanent ban. + +### 4. Permanent Ban + +**Community Impact**: Demonstrating a pattern of violation of community +standards, including sustained inappropriate behavior, harassment of an +individual, or aggression toward or disparagement of classes of individuals. + +**Consequence**: A permanent ban from any sort of public interaction within the +community. + +## Attribution + +This Code of Conduct is adapted from the [Contributor Covenant](https://www.contributor-covenant.org/), +version 2.1, available at +[v2.1](https://www.contributor-covenant.org/version/2/1/code_of_conduct.html). + +Community Impact Guidelines were inspired by +[Mozilla's code of conduct enforcement ladder](https://github.com/mozilla/inclusion). + +For answers to common questions about this code of conduct, see the +[Contributor Covenant FAQ](https://www.contributor-covenant.org/faq). Translations are available at +[Contributor Covenant translations](https://www.contributor-covenant.org/translations). + From 1afc931987d0c0e12bb3fde7908e768222916385 Mon Sep 17 00:00:00 2001 From: William Lin Date: Wed, 4 Sep 2024 17:35:36 -0700 Subject: [PATCH 15/77] [bugfix] >1.43 constraint for openai (#8169) Co-authored-by: Michael Goin --- requirements-common.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements-common.txt b/requirements-common.txt index 447fd32311c09..e430753357ca0 100644 --- a/requirements-common.txt +++ b/requirements-common.txt @@ -9,7 +9,7 @@ tokenizers >= 0.19.1 # Required for Llama 3. protobuf # Required by LlamaTokenizer. fastapi aiohttp -openai >= 1.0 # Ensure modern openai package (ensure types module present) +openai >= 1.40.0 # Ensure modern openai package (ensure types module present) uvicorn[standard] pydantic >= 2.8 # Required for OpenAI server. pillow # Required for image processing From 4624d98dbdd6f29a3d8ba7a86d93bde730ef5f7d Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Wed, 4 Sep 2024 20:31:48 -0700 Subject: [PATCH 16/77] [Misc] Clean up RoPE forward_native (#8076) --- .../model_executor/layers/rotary_embedding.py | 95 ++++--------------- 1 file changed, 19 insertions(+), 76 deletions(-) diff --git a/vllm/model_executor/layers/rotary_embedding.py b/vllm/model_executor/layers/rotary_embedding.py index c5a0278e485d4..d323f6cc432a2 100644 --- a/vllm/model_executor/layers/rotary_embedding.py +++ b/vllm/model_executor/layers/rotary_embedding.py @@ -28,7 +28,6 @@ import torch.nn as nn from vllm.model_executor.custom_op import CustomOp -from vllm.platforms import current_platform def _rotate_neox(x: torch.Tensor) -> torch.Tensor: @@ -48,21 +47,29 @@ def _apply_rotary_emb( x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, + is_neox_style: bool, ) -> torch.Tensor: """ Args: x: [num_tokens, num_heads, head_size] cos: [num_tokens, head_size // 2] sin: [num_tokens, head_size // 2] + is_neox_style: Whether to use the Neox-style or GPT-J-style rotary + positional embeddings. """ - orig_dtype = x.dtype - x = x.float() - x1, x2 = torch.chunk(x, 2, dim=-1) - cos = cos.unsqueeze(-2) - sin = sin.unsqueeze(-2) + cos = cos.unsqueeze(-2).to(x.dtype) + sin = sin.unsqueeze(-2).to(x.dtype) + if is_neox_style: + x1, x2 = torch.chunk(x, 2, dim=-1) + else: + x1 = x[..., ::2] + x2 = x[..., 1::2] o1 = x1 * cos - x2 * sin o2 = x2 * cos + x1 * sin - return torch.cat((o1, o2), dim=-1).to(orig_dtype) + if is_neox_style: + return torch.cat((o1, o2), dim=-1) + else: + return torch.stack((o1, o2), dim=-1).flatten(-2) class RotaryEmbedding(CustomOp): @@ -87,10 +94,9 @@ def __init__( cache = self._compute_cos_sin_cache() cache = cache.to(dtype) + self.cos_sin_cache: torch.Tensor self.register_buffer("cos_sin_cache", cache, persistent=False) - self.use_native2 = current_platform.is_tpu() and is_neox_style - def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor: """Compute the inverse frequency.""" # NOTE(woosuk): To exactly match the HF implementation, we need to @@ -119,59 +125,7 @@ def forward_native( key: torch.Tensor, offsets: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: - """A PyTorch-native implementation equivalent to forward(). - - This method mimics the implementation of the custom CUDA kernel - used in `forward_cuda()`. - """ - query = query.view(*query.shape[:-1], -1, self.head_size) - key = key.view(*key.shape[:-1], -1, self.head_size) - - query_rot = query[..., :self.rotary_dim] - key_rot = key[..., :self.rotary_dim] - if self.rotary_dim < self.head_size: - query_pass = query[..., self.rotary_dim:] - key_pass = key[..., self.rotary_dim:] - - self.cos_sin_cache: torch.Tensor = self.cos_sin_cache.to( - positions.device, dtype=query.dtype) - cos_sin = self.cos_sin_cache[torch.add(positions, offsets) - if offsets is not None else positions] - cos, sin = cos_sin.chunk(2, dim=-1) - if self.is_neox_style: - # NOTE(woosuk): Here we assume that the positions tensor has the - # shape [batch_size, seq_len]. - cos = cos.repeat(1, 1, 2).unsqueeze(-2) - sin = sin.repeat(1, 1, 2).unsqueeze(-2) - else: - cos = cos.repeat_interleave(2, dim=-1).unsqueeze(-2) - sin = sin.repeat_interleave(2, dim=-1).unsqueeze(-2) - - rotate_fn = _rotate_neox if self.is_neox_style else _rotate_gptj - query_rot = query_rot * cos + rotate_fn(query_rot) * sin - key_rot = key_rot * cos + rotate_fn(key_rot) * sin - - if self.rotary_dim < self.head_size: - query = torch.cat((query_rot, query_pass), dim=-1) - key = torch.cat((key_rot, key_pass), dim=-1) - else: - query = query_rot - key = key_rot - query = query.flatten(-2) - key = key.flatten(-2) - return query, key - - def forward_native2( - self, - positions: torch.Tensor, - query: torch.Tensor, - key: torch.Tensor, - offsets: Optional[torch.Tensor] = None, - ) -> Tuple[torch.Tensor, torch.Tensor]: - """Another PyTorch-native implementation of forward(). - - This method might perform better than `forward_native()` when compiled. - """ + """A PyTorch-native implementation of forward().""" if offsets is not None: positions = positions + offsets positions = positions.flatten() @@ -183,14 +137,14 @@ def forward_native2( query = query.view(num_tokens, -1, self.head_size) query_rot = query[..., :self.rotary_dim] query_pass = query[..., self.rotary_dim:] - query_rot = _apply_rotary_emb(query_rot, cos, sin) + query_rot = _apply_rotary_emb(query_rot, cos, sin, self.is_neox_style) query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape) key_shape = key.shape key = key.view(num_tokens, -1, self.head_size) key_rot = key[..., :self.rotary_dim] key_pass = key[..., self.rotary_dim:] - key_rot = _apply_rotary_emb(key_rot, cos, sin) + key_rot = _apply_rotary_emb(key_rot, cos, sin, self.is_neox_style) key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape) return query, key @@ -203,7 +157,7 @@ def forward_cuda( ) -> Tuple[torch.Tensor, torch.Tensor]: from vllm import _custom_ops as ops - self.cos_sin_cache = self.cos_sin_cache.to(positions.device, + self.cos_sin_cache = self.cos_sin_cache.to(query.device, dtype=query.dtype) # ops.rotary_embedding()/batched_rotary_embedding() # are in-place operations that update the query and key tensors. @@ -240,17 +194,6 @@ def forward_xpu( self.cos_sin_cache, self.is_neox_style) return query, key - def forward_tpu( - self, - positions: torch.Tensor, - query: torch.Tensor, - key: torch.Tensor, - offsets: Optional[torch.Tensor] = None, - ) -> Tuple[torch.Tensor, torch.Tensor]: - forward_fn = (self.forward_native2 - if self.use_native2 else self.forward_native) - return forward_fn(positions, query, key, offsets) - def extra_repr(self) -> str: s = f"head_size={self.head_size}, rotary_dim={self.rotary_dim}" s += f", max_position_embeddings={self.max_position_embeddings}" From ba262c4e5aa9fa753c8cedfaea5c42941184a0db Mon Sep 17 00:00:00 2001 From: "Kevin H. Luu" Date: Wed, 4 Sep 2024 20:33:12 -0700 Subject: [PATCH 17/77] [ci] Mark LoRA test as soft-fail (#8160) Signed-off-by: kevin --- .buildkite/test-pipeline.yaml | 1 + 1 file changed, 1 insertion(+) diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index d50d8f32a816d..b2874750a777e 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -368,6 +368,7 @@ steps: - label: LoRA Long Context (Distributed) # 11min # This test runs llama 13B, so it is required to run on 4 GPUs. num_gpus: 4 + soft_fail: true source_file_dependencies: - vllm/lora - tests/lora/test_long_context From e39ebf5cf5ec8f7449d633b6428333a99a206a1c Mon Sep 17 00:00:00 2001 From: Elfie Guo <164945471+elfiegg@users.noreply.github.com> Date: Wed, 4 Sep 2024 22:12:26 -0700 Subject: [PATCH 18/77] [Core/Bugfix] Add query dtype as per FlashInfer API requirements. (#8173) --- tests/kernels/test_flashinfer.py | 3 ++- vllm/attention/backends/flashinfer.py | 9 ++++++++- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/tests/kernels/test_flashinfer.py b/tests/kernels/test_flashinfer.py index 67f12cf1ee08e..696cc0c6cdf10 100644 --- a/tests/kernels/test_flashinfer.py +++ b/tests/kernels/test_flashinfer.py @@ -445,7 +445,8 @@ def test_flashinfer_decode_with_paged_fp8_kv( head_size, block_size, "NONE", - data_type=dtype) + data_type=dtype, + q_data_type=dtype) output = wrapper.forward(query, kv_cache_fp8, logits_soft_cap=soft_cap, diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index aa9d4a71dbf87..7aec8203eb1e5 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -224,6 +224,7 @@ def graph_capture_get_metadata_for_batch(self, batch_size: int): query_start_loc=query_start_loc_host, device=self.runner.device, data_type=kv_cache_dtype, + q_data_type=self.runner.model_config.dtype, use_cuda_graph=True, decode_wrapper=self._graph_decode_wrapper, prefill_wrapper=None) @@ -292,6 +293,8 @@ class FlashInferMetadata(AttentionMetadata): page_size: Optional[int] = None # The data type of the paged kv cache data_type: torch.dtype = None + # The data type of the query + q_data_type: torch.dtype = None device: torch.device = torch.device("cuda") is_profile_run: bool = False @@ -353,7 +356,10 @@ def begin_forward(self): self.page_size, # Disable flashinfer's pos encoding and use vllm's rope. pos_encoding_mode="NONE", - data_type=self.data_type) + # kv-cache data type. + data_type=self.data_type, + # query data type. + q_data_type=self.q_data_type) def asdict_zerocopy(self, skip_fields: Optional[Set[str]] = None @@ -617,6 +623,7 @@ def build(self, seq_lens: List[int], query_lens: List[int], query_start_loc=query_start_loc, device=device, data_type=kv_cache_dtype, + q_data_type=self.runner.model_config.dtype, use_cuda_graph=use_captured_graph, is_profile_run=self.is_profile_run) From 288a938872cc3c6150a486aaa15a3b5dcadf42cc Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Thu, 5 Sep 2024 18:51:53 +0800 Subject: [PATCH 19/77] [Doc] Indicate more information about supported modalities (#8181) --- .buildkite/test-pipeline.yaml | 1 + docs/source/getting_started/debugging.rst | 2 +- docs/source/getting_started/quickstart.rst | 6 +- docs/source/models/supported_models.rst | 21 +-- docs/source/models/vlm.rst | 123 +++++++++++++----- ...e_inference_vision_language_multi_image.py | 95 ++++++++++++++ examples/openai_vision_api_client.py | 9 +- 7 files changed, 206 insertions(+), 51 deletions(-) create mode 100644 examples/offline_inference_vision_language_multi_image.py diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index b2874750a777e..d0317b2fc48c9 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -158,6 +158,7 @@ steps: - python3 offline_inference_with_prefix.py - python3 llm_engine_example.py - python3 offline_inference_vision_language.py + - python3 offline_inference_vision_language_multi_image.py - python3 tensorize_vllm_model.py --model facebook/opt-125m serialize --serialized-directory /tmp/ --suffix v1 && python3 tensorize_vllm_model.py --model facebook/opt-125m deserialize --path-to-tensors /tmp/vllm/facebook/opt-125m/v1/model.tensors - python3 offline_inference_encoder_decoder.py diff --git a/docs/source/getting_started/debugging.rst b/docs/source/getting_started/debugging.rst index 117a9dd666481..31ecca1332e5d 100644 --- a/docs/source/getting_started/debugging.rst +++ b/docs/source/getting_started/debugging.rst @@ -21,7 +21,7 @@ If you have already taken care of the above issues, but the vLLM instance still With more logging, hopefully you can find the root cause of the issue. -If it crashes, and the error trace shows somewhere around ``self.graph.replay()`` in ``vllm/worker/model_runner.py``, it is a cuda error inside cudagraph. To know the particular cuda operation that causes the error, you can add ``--enforce-eager`` to the command line, or ``enforce_eager=True`` to the ``LLM`` class, to disable the cudagraph optimization. This way, you can locate the exact cuda operation that causes the error. +If it crashes, and the error trace shows somewhere around ``self.graph.replay()`` in ``vllm/worker/model_runner.py``, it is a cuda error inside cudagraph. To know the particular cuda operation that causes the error, you can add ``--enforce-eager`` to the command line, or ``enforce_eager=True`` to the :class:`~vllm.LLM` class, to disable the cudagraph optimization. This way, you can locate the exact cuda operation that causes the error. Here are some common issues that can cause hangs: diff --git a/docs/source/getting_started/quickstart.rst b/docs/source/getting_started/quickstart.rst index 89bdc247c5e8e..80b19ac672936 100644 --- a/docs/source/getting_started/quickstart.rst +++ b/docs/source/getting_started/quickstart.rst @@ -24,7 +24,9 @@ Offline Batched Inference We first show an example of using vLLM for offline batched inference on a dataset. In other words, we use vLLM to generate texts for a list of input prompts. -Import ``LLM`` and ``SamplingParams`` from vLLM. The ``LLM`` class is the main class for running offline inference with vLLM engine. The ``SamplingParams`` class specifies the parameters for the sampling process. +Import :class:`~vllm.LLM` and :class:`~vllm.SamplingParams` from vLLM. +The :class:`~vllm.LLM` class is the main class for running offline inference with vLLM engine. +The :class:`~vllm.SamplingParams` class specifies the parameters for the sampling process. .. code-block:: python @@ -42,7 +44,7 @@ Define the list of input prompts and the sampling parameters for generation. The ] sampling_params = SamplingParams(temperature=0.8, top_p=0.95) -Initialize vLLM's engine for offline inference with the ``LLM`` class and the `OPT-125M model `_. The list of supported models can be found at :ref:`supported models `. +Initialize vLLM's engine for offline inference with the :class:`~vllm.LLM` class and the `OPT-125M model `_. The list of supported models can be found at :ref:`supported models `. .. code-block:: python diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst index 2c20b6e48407d..084be1e2a4f8e 100644 --- a/docs/source/models/supported_models.rst +++ b/docs/source/models/supported_models.rst @@ -194,12 +194,12 @@ Multimodal Language Models * - Architecture - Models - - Supported Modalities + - Modalities - Example HuggingFace Models - :ref:`LoRA ` * - :code:`Blip2ForConditionalGeneration` - BLIP-2 - - Image + - Image\ :sup:`E` - :code:`Salesforce/blip2-opt-2.7b`, :code:`Salesforce/blip2-opt-6.7b`, etc. - * - :code:`ChameleonForConditionalGeneration` @@ -214,40 +214,43 @@ Multimodal Language Models - * - :code:`InternVLChatModel` - InternVL2 - - Image + - Image\ :sup:`E` - :code:`OpenGVLab/InternVL2-4B`, :code:`OpenGVLab/InternVL2-8B`, etc. - * - :code:`LlavaForConditionalGeneration` - LLaVA-1.5 - - Image + - Image\ :sup:`E` - :code:`llava-hf/llava-1.5-7b-hf`, :code:`llava-hf/llava-1.5-13b-hf`, etc. - * - :code:`LlavaNextForConditionalGeneration` - LLaVA-NeXT - - Image + - Image\ :sup:`E+` - :code:`llava-hf/llava-v1.6-mistral-7b-hf`, :code:`llava-hf/llava-v1.6-vicuna-7b-hf`, etc. - * - :code:`PaliGemmaForConditionalGeneration` - PaliGemma - - Image + - Image\ :sup:`E` - :code:`google/paligemma-3b-pt-224`, :code:`google/paligemma-3b-mix-224`, etc. - * - :code:`Phi3VForCausalLM` - Phi-3-Vision, Phi-3.5-Vision - - Image + - Image\ :sup:`E+` - :code:`microsoft/Phi-3-vision-128k-instruct`, :code:`microsoft/Phi-3.5-vision-instruct` etc. - * - :code:`MiniCPMV` - MiniCPM-V - - Image + - Image\ :sup:`+` - :code:`openbmb/MiniCPM-V-2` (see note), :code:`openbmb/MiniCPM-Llama3-V-2_5`, :code:`openbmb/MiniCPM-V-2_6`, etc. - * - :code:`UltravoxModel` - Ultravox - - Audio + - Audio\ :sup:`E+` - :code:`fixie-ai/ultravox-v0_3` - +| :sup:`E` Pre-computed embeddings can be inputted for this modality. +| :sup:`+` Multiple items can be inputted per text prompt for this modality. + .. note:: For :code:`openbmb/MiniCPM-V-2`, the official repo doesn't work yet, so we need to use a fork (:code:`HwwwH/MiniCPM-V-2`) for now. For more details, please see: https://github.com/vllm-project/vllm/pull/4087#issuecomment-2250397630 diff --git a/docs/source/models/vlm.rst b/docs/source/models/vlm.rst index 236e37b51d470..08db891665044 100644 --- a/docs/source/models/vlm.rst +++ b/docs/source/models/vlm.rst @@ -9,26 +9,23 @@ This document shows you how to run and serve these models using vLLM. .. important:: We are actively iterating on VLM support. Expect breaking changes to VLM usage and development in upcoming releases without prior deprecation. - Currently, the support for vision language models on vLLM has the following limitations: - - * Only single image input is supported per text prompt. - We are continuously improving user & developer experience for VLMs. Please `open an issue on GitHub `_ if you have any feedback or feature requests. -Offline Batched Inference -------------------------- +Offline Inference +----------------- + +Single-image input +^^^^^^^^^^^^^^^^^^ -To initialize a VLM, the aforementioned arguments must be passed to the ``LLM`` class for instantiating the engine. +The :class:`~vllm.LLM` class can be instantiated in much the same way as language-only models. .. code-block:: python llm = LLM(model="llava-hf/llava-1.5-7b-hf") -.. important:: +.. note:: We have removed all vision language related CLI args in the ``0.5.1`` release. **This is a breaking change**, so please update your code to follow - the above snippet. Specifically, ``image_feature_size`` is no longer required to be specified as we now calculate that - internally for each model. - + the above snippet. Specifically, ``image_feature_size`` can no longer be specified as we now calculate that internally for each model. To pass an image to the model, note the following in :class:`vllm.inputs.PromptInputs`: @@ -86,61 +83,117 @@ To pass an image to the model, note the following in :class:`vllm.inputs.PromptI A code example can be found in `examples/offline_inference_vision_language.py `_. +Multi-image input +^^^^^^^^^^^^^^^^^ -Online OpenAI Vision API Compatible Inference ----------------------------------------------- +Multi-image input is only supported for a subset of VLMs, as shown :ref:`here `. -You can serve vision language models with vLLM's HTTP server that is compatible with `OpenAI Vision API `_. +To enable multiple multi-modal items per text prompt, you have to set ``limit_mm_per_prompt`` for the :class:`~vllm.LLM` class. -.. note:: - Currently, vLLM supports only **single** ``image_url`` input per ``messages``. Support for multi-image inputs will be - added in the future. +.. code-block:: python -Below is an example on how to launch the same ``llava-hf/llava-1.5-7b-hf`` with vLLM API server. + llm = LLM( + model="microsoft/Phi-3.5-vision-instruct", + trust_remote_code=True, # Required to load Phi-3.5-vision + max_model_len=4096, # Otherwise, it may not fit in smaller GPUs + limit_mm_per_prompt={"image": 2}, # The maximum number to accept + ) -.. important:: - Since OpenAI Vision API is based on `Chat `_ API, a chat template - is **required** to launch the API server if the model's tokenizer does not come with one. In this example, we use the - HuggingFace Llava chat template that you can find in the example folder `here `_. +Instead of passing in a single image, you can pass in a list of images. + +.. code-block:: python + + # Refer to the HuggingFace repo for the correct format to use + prompt = "<|user|>\n\n\nWhat is the content of each image?<|end|>\n<|assistant|>\n" + + # Load the images using PIL.Image + image1 = PIL.Image.open(...) + image2 = PIL.Image.open(...) + + outputs = llm.generate({ + "prompt": prompt, + "multi_modal_data": { + "image": [image1, image2] + }, + }) + + for o in outputs: + generated_text = o.outputs[0].text + print(generated_text) + +A code example can be found in `examples/offline_inference_vision_language_multi_image.py `_. + +Online Inference +---------------- + +OpenAI Vision API +^^^^^^^^^^^^^^^^^ + +You can serve vision language models with vLLM's HTTP server that is compatible with `OpenAI Vision API `_. + +Below is an example on how to launch the same ``microsoft/Phi-3.5-vision-instruct`` with vLLM's OpenAI-compatible API server. .. code-block:: bash - vllm serve llava-hf/llava-1.5-7b-hf --chat-template template_llava.jinja + vllm serve microsoft/Phi-3.5-vision-instruct --max-model-len 4096 \ + --trust-remote-code --limit-mm-per-prompt image=2 .. important:: - We have removed all vision language related CLI args in the ``0.5.1`` release. **This is a breaking change**, so please update your code to follow - the above snippet. Specifically, ``image_feature_size`` is no longer required to be specified as we now calculate that - internally for each model. + Since OpenAI Vision API is based on `Chat Completions `_ API, + a chat template is **required** to launch the API server. + + Although Phi-3.5-Vision comes with a chat template, for other models you may have to provide one if the model's tokenizer does not come with it. + The chat template can be inferred based on the documentation on the model's HuggingFace repo. + For example, LLaVA-1.5 (``llava-hf/llava-1.5-7b-hf``) requires a chat template that can be found `here `_. To consume the server, you can use the OpenAI client like in the example below: .. code-block:: python from openai import OpenAI + openai_api_key = "EMPTY" openai_api_base = "http://localhost:8000/v1" + client = OpenAI( api_key=openai_api_key, base_url=openai_api_base, ) + + # Single-image input inference + image_url = "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg" + chat_response = client.chat.completions.create( - model="llava-hf/llava-1.5-7b-hf", + model="microsoft/Phi-3.5-vision-instruct", messages=[{ "role": "user", "content": [ # NOTE: The prompt formatting with the image token `` is not needed # since the prompt will be processed automatically by the API server. - {"type": "text", "text": "What's in this image?"}, - { - "type": "image_url", - "image_url": { - "url": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg", - }, - }, + {"type": "text", "text": "What’s in this image?"}, + {"type": "image_url", "image_url": {"url": image_url}}, ], }], ) - print("Chat response:", chat_response) + print("Chat completion output:", chat_response.choices[0].message.content) + + # Multi-image input inference + image_url_duck = "https://upload.wikimedia.org/wikipedia/commons/d/da/2015_Kaczka_krzy%C5%BCowka_w_wodzie_%28samiec%29.jpg" + image_url_lion = "https://upload.wikimedia.org/wikipedia/commons/7/77/002_The_lion_king_Snyggve_in_the_Serengeti_National_Park_Photo_by_Giles_Laurent.jpg" + + chat_response = client.chat.completions.create( + model="microsoft/Phi-3.5-vision-instruct", + messages=[{ + "role": "user", + "content": [ + {"type": "text", "text": "What are the animals in these images?"}, + {"type": "image_url", "image_url": {"url": image_url_duck}}, + {"type": "image_url", "image_url": {"url": image_url_lion}}, + ], + }], + ) + print("Chat completion output:", chat_response.choices[0].message.content) + A full code example can be found in `examples/openai_vision_api_client.py `_. diff --git a/examples/offline_inference_vision_language_multi_image.py b/examples/offline_inference_vision_language_multi_image.py new file mode 100644 index 0000000000000..73543ab5da2b4 --- /dev/null +++ b/examples/offline_inference_vision_language_multi_image.py @@ -0,0 +1,95 @@ +""" +This example shows how to use vLLM for running offline inference with +multi-image input on vision language models, using the chat template defined +by the model. +""" +from argparse import Namespace +from typing import List + +from vllm import LLM +from vllm.multimodal.utils import fetch_image +from vllm.utils import FlexibleArgumentParser + +QUESTION = "What is the content of each image?" +IMAGE_URLS = [ + "https://upload.wikimedia.org/wikipedia/commons/d/da/2015_Kaczka_krzy%C5%BCowka_w_wodzie_%28samiec%29.jpg", + "https://upload.wikimedia.org/wikipedia/commons/7/77/002_The_lion_king_Snyggve_in_the_Serengeti_National_Park_Photo_by_Giles_Laurent.jpg", +] + + +def _load_phi3v(image_urls: List[str]): + return LLM( + model="microsoft/Phi-3.5-vision-instruct", + trust_remote_code=True, + max_model_len=4096, + limit_mm_per_prompt={"image": len(image_urls)}, + ) + + +def run_phi3v_generate(question: str, image_urls: List[str]): + llm = _load_phi3v(image_urls) + + placeholders = "\n".join(f"<|image_{i}|>" + for i, _ in enumerate(image_urls, start=1)) + prompt = f"<|user|>\n{placeholders}\n{question}<|end|>\n<|assistant|>\n" + + outputs = llm.generate({ + "prompt": prompt, + "multi_modal_data": { + "image": [fetch_image(url) for url in image_urls] + }, + }) + + for o in outputs: + generated_text = o.outputs[0].text + print(generated_text) + + +def run_phi3v_chat(question: str, image_urls: List[str]): + llm = _load_phi3v(image_urls) + + outputs = llm.chat([{ + "role": + "user", + "content": [ + { + "type": "text", + "text": question, + }, + *({ + "type": "image_url", + "image_url": { + "url": image_url + }, + } for image_url in image_urls), + ], + }]) + + for o in outputs: + generated_text = o.outputs[0].text + print(generated_text) + + +def main(args: Namespace): + method = args.method + + if method == "generate": + run_phi3v_generate(QUESTION, IMAGE_URLS) + elif method == "chat": + run_phi3v_chat(QUESTION, IMAGE_URLS) + else: + raise ValueError(f"Invalid method: {method}") + + +if __name__ == "__main__": + parser = FlexibleArgumentParser( + description='Demo on using vLLM for offline inference with ' + 'vision language models that support multi-image input') + parser.add_argument("--method", + type=str, + default="generate", + choices=["generate", "chat"], + help="The method to run in `vllm.LLM`.") + + args = parser.parse_args() + main(args) diff --git a/examples/openai_vision_api_client.py b/examples/openai_vision_api_client.py index e1d4055763e5f..1ba702ef019e4 100644 --- a/examples/openai_vision_api_client.py +++ b/examples/openai_vision_api_client.py @@ -27,9 +27,10 @@ models = client.models.list() model = models.data[0].id +# Single-image input inference image_url = "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg" -# Use image url in the payload +## Use image url in the payload chat_completion_from_url = client.chat.completions.create( messages=[{ "role": @@ -52,10 +53,10 @@ ) result = chat_completion_from_url.choices[0].message.content -print(f"Chat completion output:{result}") +print("Chat completion output:", result) -# Use base64 encoded image in the payload +## Use base64 encoded image in the payload def encode_image_base64_from_url(image_url: str) -> str: """Encode an image retrieved from a remote url to base64 format.""" @@ -122,4 +123,4 @@ def encode_image_base64_from_url(image_url: str) -> str: ) result = chat_completion_from_url.choices[0].message.content -print(f"Chat completion output:{result}") +print("Chat completion output:", result) From 8685ba1a1ec08d2c14df924b6e2b499be14405e7 Mon Sep 17 00:00:00 2001 From: "manikandan.tm@zucisystems.com" <94887255+Manikandan-Thangaraj-ZS0321@users.noreply.github.com> Date: Thu, 5 Sep 2024 17:03:37 +0530 Subject: [PATCH 20/77] Inclusion of InternVLChatModel In PP_SUPPORTED_MODELS(Pipeline Parallelism) (#7860) --- tests/distributed/test_pipeline_parallel.py | 38 ++++++++------- tests/utils.py | 7 ++- vllm/config.py | 8 ++-- vllm/model_executor/models/internlm2.py | 52 +++++++++++++++------ vllm/model_executor/models/internvl.py | 4 +- vllm/model_executor/models/utils.py | 16 +++++++ 6 files changed, 90 insertions(+), 35 deletions(-) diff --git a/tests/distributed/test_pipeline_parallel.py b/tests/distributed/test_pipeline_parallel.py index 4d54e43d5788c..637d2b30f6b1f 100644 --- a/tests/distributed/test_pipeline_parallel.py +++ b/tests/distributed/test_pipeline_parallel.py @@ -18,23 +18,26 @@ VLLM_MULTI_NODE = os.getenv("VLLM_MULTI_NODE", "0") == "1" -@pytest.mark.parametrize(("TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, " - "MODEL_NAME, DIST_BACKEND"), - [ - (2, 2, 0, 1, "meta-llama/Meta-Llama-3-8B", "mp"), - (2, 2, 1, 0, "meta-llama/Meta-Llama-3-8B", "mp"), - (1, 3, 0, 0, "meta-llama/Meta-Llama-3-8B", "mp"), - (1, 4, 0, 1, "meta-llama/Meta-Llama-3-8B", "mp"), - (1, 4, 1, 0, "meta-llama/Meta-Llama-3-8B", "mp"), - (1, 3, 0, 0, "meta-llama/Meta-Llama-3-8B", "ray"), - (1, 4, 0, 1, "meta-llama/Meta-Llama-3-8B", "ray"), - (1, 4, 1, 0, "meta-llama/Meta-Llama-3-8B", "ray"), - (2, 2, 1, 0, "meta-llama/Meta-Llama-3-8B", "ray"), - (2, 2, 0, 1, "meta-llama/Meta-Llama-3-8B", "ray"), - ]) +@pytest.mark.parametrize( + ("TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, TRUST_REMOTE_CODE, " + "MODEL_NAME, DIST_BACKEND"), + [ + (2, 2, 0, 1, 0, "meta-llama/Meta-Llama-3-8B", "mp"), + (2, 2, 1, 0, 0, "meta-llama/Meta-Llama-3-8B", "mp"), + (1, 3, 0, 0, 0, "meta-llama/Meta-Llama-3-8B", "mp"), + (1, 4, 0, 1, 0, "meta-llama/Meta-Llama-3-8B", "mp"), + (1, 4, 1, 0, 0, "meta-llama/Meta-Llama-3-8B", "mp"), + (1, 3, 0, 0, 0, "meta-llama/Meta-Llama-3-8B", "ray"), + (1, 4, 0, 1, 0, "meta-llama/Meta-Llama-3-8B", "ray"), + (1, 4, 1, 0, 0, "meta-llama/Meta-Llama-3-8B", "ray"), + (2, 2, 1, 0, 0, "meta-llama/Meta-Llama-3-8B", "ray"), + (2, 2, 0, 1, 0, "meta-llama/Meta-Llama-3-8B", "ray"), + (2, 2, 1, 1, 1, "internlm/internlm2_5-7b-chat", "ray"), + ], +) @fork_new_process_for_each_test -def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, MODEL_NAME, - DIST_BACKEND): +def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, + TRUST_REMOTE_CODE, MODEL_NAME, DIST_BACKEND): if VLLM_MULTI_NODE and DIST_BACKEND == "mp": pytest.skip("Skipping multi-node pipeline parallel test for " "multiprocessing distributed backend") @@ -71,6 +74,9 @@ def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, MODEL_NAME, if EAGER_MODE: pp_args.append("--enforce-eager") tp_args.append("--enforce-eager") + if TRUST_REMOTE_CODE: + pp_args.append("--trust-remote-code") + tp_args.append("--trust-remote-code") pp_env = None if (DIST_BACKEND == "ray" and TP_SIZE == 2 and PP_SIZE == 2 and CHUNKED_PREFILL): diff --git a/tests/utils.py b/tests/utils.py index cd8d7b1f25905..04067ef372ac2 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -178,7 +178,12 @@ def compare_two_settings(model: str, env2: The second set of environment variables to pass to the API server. """ - tokenizer = AutoTokenizer.from_pretrained(model) + trust_remote_code = "--trust-remote-code" + if trust_remote_code in arg1 or trust_remote_code in arg2: + tokenizer = AutoTokenizer.from_pretrained(model, + trust_remote_code=True) + else: + tokenizer = AutoTokenizer.from_pretrained(model) prompt = "Hello, my name is" token_ids = tokenizer(prompt)["input_ids"] diff --git a/vllm/config.py b/vllm/config.py index 9b3f4f9206300..e513608eca9f8 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -35,18 +35,20 @@ _MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS = 4096 _PP_SUPPORTED_MODELS = [ - "AquilaModel", "AquilaForCausalLM", + "AquilaModel", "DeepseekV2ForCausalLM", + "GPT2LMHeadModel", + "InternLM2ForCausalLM", "InternLMForCausalLM", + "InternVLChatModel", "JAISLMHeadModel", "LlamaForCausalLM", "LLaMAForCausalLM", "MistralForCausalLM", - "Phi3ForCausalLM", - "GPT2LMHeadModel", "MixtralForCausalLM", "NemotronForCausalLM", + "Phi3ForCausalLM", "Qwen2ForCausalLM", "Qwen2MoeForCausalLM", "QWenLMHeadModel", diff --git a/vllm/model_executor/models/internlm2.py b/vllm/model_executor/models/internlm2.py index 23669b540f561..11a8431a5e7f7 100644 --- a/vllm/model_executor/models/internlm2.py +++ b/vllm/model_executor/models/internlm2.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- from functools import partial -from typing import Any, Dict, Iterable, List, Optional, Tuple +from typing import Any, Dict, Iterable, List, Optional, Tuple, Union import torch from torch import nn @@ -8,7 +8,7 @@ from vllm.attention import Attention, AttentionMetadata from vllm.config import CacheConfig -from vllm.distributed import (get_tensor_model_parallel_rank, +from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, split_tensor_along_last_dim, tensor_model_parallel_all_gather) @@ -28,6 +28,9 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors +from .utils import (is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers) + class InternLM2MLP(nn.Module): @@ -234,6 +237,7 @@ def __init__( config: PretrainedConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() self.config = config @@ -243,11 +247,15 @@ def __init__( config.vocab_size, config.hidden_size, ) - self.layers = nn.ModuleList([ - InternLMDecoderLayer(config, cache_config, quant_config) - for _ in range(config.num_hidden_layers) - ]) + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: InternLMDecoderLayer(config, cache_config, + quant_config), + prefix=f"{prefix}.layers") self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.tok_embeddings(input_ids) @@ -260,21 +268,31 @@ def forward( attn_metadata: AttentionMetadata, intermediate_tensors: IntermediateTensors = None, inputs_embeds: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - if inputs_embeds is not None: - hidden_states = inputs_embeds + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.tok_embeddings(input_ids) + residual = None else: - hidden_states = self.tok_embeddings(input_ids) - residual = None - for i in range(len(self.layers)): + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + for i in range(self.start_layer, self.end_layer): layer = self.layers[i] hidden_states, residual = layer( positions, hidden_states, - kv_caches[i], + kv_caches[i - self.start_layer], attn_metadata, residual, ) + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) hidden_states, _ = self.norm(hidden_states, residual) return hidden_states @@ -298,6 +316,8 @@ def __init__( self.output.weight = self.model.tok_embeddings.weight self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) def forward( self, @@ -308,7 +328,7 @@ def forward( intermediate_tensors: IntermediateTensors, ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata) + attn_metadata, intermediate_tensors) return hidden_states def compute_logits( @@ -345,6 +365,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue + if is_pp_missing_parameter(name, self): + continue param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) @@ -353,6 +375,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue + if is_pp_missing_parameter(name, self): + continue param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) diff --git a/vllm/model_executor/models/internvl.py b/vllm/model_executor/models/internvl.py index 5ca8d0b6a2922..d317fdce3ba68 100644 --- a/vllm/model_executor/models/internvl.py +++ b/vllm/model_executor/models/internvl.py @@ -341,6 +341,8 @@ def __init__(self, nn.Linear(llm_hidden_size, llm_hidden_size)) self.img_context_token_id = None + self.make_empty_intermediate_tensors = ( + self.language_model.make_empty_intermediate_tensors) def pixel_shuffle(self, x, scale_factor=0.5): n, w, h, c = x.size() @@ -461,7 +463,7 @@ def forward( positions, kv_caches, attn_metadata, - None, + intermediate_tensors, inputs_embeds=inputs_embeds) return hidden_states diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py index 16565e1467e8f..8b80dda96db49 100644 --- a/vllm/model_executor/models/utils.py +++ b/vllm/model_executor/models/utils.py @@ -12,6 +12,7 @@ from vllm.model_executor.model_loader.loader import build_model from vllm.model_executor.models import ModelRegistry from vllm.multimodal.base import NestedTensors +from vllm.sequence import IntermediateTensors from vllm.utils import is_pin_memory_available @@ -279,3 +280,18 @@ def is_pp_missing_parameter(name: str, model: torch.nn.Module) -> bool: if name.startswith(missing_layer_name): return True return False + + +def make_empty_intermediate_tensors_factory(keys: List[str], hidden_size: int): + + def make_empty_intermediate_tensors( + batch_size: int, dtype: torch.dtype, + device: torch.device) -> IntermediateTensors: + return IntermediateTensors({ + key: torch.zeros((batch_size, hidden_size), + dtype=dtype, + device=device) + for key in keys + }) + + return make_empty_intermediate_tensors From 9da25a88aa35da4b5ad7da545e6189e08c5f52f4 Mon Sep 17 00:00:00 2001 From: Alex Brooks Date: Thu, 5 Sep 2024 06:48:10 -0600 Subject: [PATCH 21/77] [MODEL] Qwen Multimodal Support (Qwen-VL / Qwen-VL-Chat) (#8029) Signed-off-by: Alex-Brooks Co-authored-by: DarkLight1337 --- docs/source/models/supported_models.rst | 5 + examples/offline_inference_vision_language.py | 15 + tests/models/test_qwen.py | 167 ++++- vllm/entrypoints/chat_utils.py | 2 + vllm/model_executor/layers/resampler.py | 273 +++++++ vllm/model_executor/models/__init__.py | 2 +- vllm/model_executor/models/minicpmv.py | 160 +--- vllm/model_executor/models/qwen.py | 694 +++++++++++++++++- 8 files changed, 1110 insertions(+), 208 deletions(-) create mode 100644 vllm/model_executor/layers/resampler.py diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst index 084be1e2a4f8e..0c0a54281e3f3 100644 --- a/docs/source/models/supported_models.rst +++ b/docs/source/models/supported_models.rst @@ -242,6 +242,11 @@ Multimodal Language Models - Image\ :sup:`+` - :code:`openbmb/MiniCPM-V-2` (see note), :code:`openbmb/MiniCPM-Llama3-V-2_5`, :code:`openbmb/MiniCPM-V-2_6`, etc. - + * - :code:`QWenLMHeadModel` + - Qwen + - Image + - :code:`Qwen/Qwen-VL`, :code:`Qwen/Qwen-VL-Chat`, etc. + - * - :code:`UltravoxModel` - Ultravox - Audio\ :sup:`E+` diff --git a/examples/offline_inference_vision_language.py b/examples/offline_inference_vision_language.py index 9a0e9d4bc5362..aa1580343aee7 100644 --- a/examples/offline_inference_vision_language.py +++ b/examples/offline_inference_vision_language.py @@ -159,6 +159,20 @@ def run_blip2(question): return llm, prompt, stop_token_ids +# Qwen +def run_qwen_vl(question): + + llm = LLM( + model="Qwen/Qwen-VL", + trust_remote_code=True, + max_num_seqs=5, + ) + + prompt = f"{question}Picture 1: \n" + stop_token_ids = None + return llm, prompt, stop_token_ids + + model_example_map = { "llava": run_llava, "llava-next": run_llava_next, @@ -169,6 +183,7 @@ def run_blip2(question): "minicpmv": run_minicpmv, "blip-2": run_blip2, "internvl_chat": run_internvl, + "qwen_vl": run_qwen_vl, } diff --git a/tests/models/test_qwen.py b/tests/models/test_qwen.py index 0f974fcc1885c..05f5cbf8c3435 100644 --- a/tests/models/test_qwen.py +++ b/tests/models/test_qwen.py @@ -1,48 +1,165 @@ -from typing import Type +import pathlib +from typing import List, Optional, Type import pytest -from ..conftest import HfRunner, VllmRunner +from vllm.multimodal.utils import rescale_image_size + +from ..conftest import IMAGE_ASSETS, HfRunner, VllmRunner, _ImageAssets from .utils import check_logprobs_close -models = ["qwen/qwen-vl"] +pytestmark = pytest.mark.vlm +text_only_models = [ + "Qwen/Qwen-7B-Chat" # Has no visual component +] -@pytest.mark.parametrize("dtype", ["half"]) -@pytest.mark.parametrize("max_tokens", [32]) -@pytest.mark.parametrize("num_logprobs", [5]) -@pytest.mark.parametrize("model", models) -def test_text_only_qwen_model( +multimodal_models = ["Qwen/Qwen-VL"] + +HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({ + "stop_sign": + "Picture 1: \nWhat's the content of the image?: ", + "cherry_blossom": + "Picture 1: \nWhat is the season?: ", +}) + + +### Tests for multimodal Qwen models +def run_test( + tmp_path: pathlib.PosixPath, hf_runner: Type[HfRunner], vllm_runner: Type[VllmRunner], - example_prompts, + image_assets: _ImageAssets, model: str, *, + size_factors: List[float], dtype: str, max_tokens: int, num_logprobs: int, + tensor_parallel_size: int, + distributed_executor_backend: Optional[str] = None, ): - # This test checks language inputs only, since the visual component - # for qwen-vl is still unsupported in VLLM. In the near-future, the - # implementation and this test will be extended to consider - # visual inputs as well. + """Inference result should be the same between hf and vllm. + + All the image fixtures for the test is under tests/images. + For huggingface runner, we provide the PIL images as input. + For vllm runner, we provide MultiModalDataDict objects + and corresponding MultiModalConfig as input. + Note, the text input is also adjusted to abide by vllm contract. + The text output is sanitized to be able to compare with hf. + """ + images = [asset.pil_image for asset in image_assets] + + # Export the images to a tempdir and substitute it into the hf prompt; + # the contents between / will be ignored by VLLM, but the + # transformers implementation for the visual transformer parses this to + # reload it in the forward call; the contents are treated as a URL or a + # local path. + for idx, asset in enumerate(image_assets): + image_tmp_path = tmp_path / f"{asset.name}.jpg" + asset.pil_image.save(image_tmp_path) + HF_IMAGE_PROMPTS[idx] = HF_IMAGE_PROMPTS[idx].replace( + "", f"{image_tmp_path}") + + inputs_per_image = [( + [prompt for _ in size_factors], + [rescale_image_size(image, factor) for factor in size_factors], + ) for image, prompt in zip(images, HF_IMAGE_PROMPTS)] + + # NOTE: take care of the order. run vLLM first, and then run HF. + # vLLM needs a fresh new process without cuda initialization. + # if we run HF first, the cuda initialization will be done and it + # will hurt multiprocessing backend with fork method (the default method). + + # max_model_len should be greater than image_feature_size + # Qwen encodes images into a fixed content size of 256 + with vllm_runner(model, + max_model_len=300, + max_num_seqs=1, + dtype=dtype, + tensor_parallel_size=tensor_parallel_size, + distributed_executor_backend=distributed_executor_backend, + enforce_eager=True) as vllm_model: + vllm_outputs_per_image = [ + vllm_model.generate_greedy_logprobs(prompts, + max_tokens, + num_logprobs=num_logprobs, + images=images) + for prompts, images in inputs_per_image + ] + with hf_runner(model, dtype=dtype) as hf_model: - hf_outputs = hf_model.generate_greedy_logprobs_limit( - example_prompts, - max_tokens, - num_logprobs=num_logprobs, + hf_outputs_per_image = [ + hf_model.generate_greedy_logprobs_limit(prompts, + max_tokens, + num_logprobs=num_logprobs, + images=images) + for prompts, images in inputs_per_image + ] + + for hf_outputs, vllm_outputs in zip(hf_outputs_per_image, + vllm_outputs_per_image): + + check_logprobs_close( + outputs_0_lst=hf_outputs, + outputs_1_lst=vllm_outputs, + name_0="hf", + name_1="vllm", ) + +@pytest.mark.parametrize("model", multimodal_models) +@pytest.mark.parametrize( + "size_factors", + [ + # No image + [], + # Single-scale + [1.0], + # Single-scale, batched + [1.0, 1.0, 1.0], + # Multi-scale + [0.25, 0.5, 1.0], + ], +) +@pytest.mark.parametrize("dtype", ["bfloat16"]) +@pytest.mark.parametrize("max_tokens", [8]) +@pytest.mark.parametrize("num_logprobs", [5]) +def test_multimodal_models(tmp_path, hf_runner, vllm_runner, image_assets, + model, size_factors, dtype, max_tokens, + num_logprobs) -> None: + run_test( + tmp_path, + hf_runner, + vllm_runner, + image_assets, + model, + size_factors=size_factors, + dtype=dtype, + max_tokens=max_tokens, + num_logprobs=num_logprobs, + tensor_parallel_size=1, + ) + + +# Ensure that a text-only Qwen model can still be loaded and +# used for inference in VLLM without throwing. +@pytest.mark.parametrize("model", text_only_models) +@pytest.mark.parametrize("dtype", ["bfloat16"]) +@pytest.mark.parametrize("max_tokens", [32]) +@pytest.mark.parametrize("num_logprobs", [5]) +def test_text_only_qwen_model_can_be_loaded_and_run( + vllm_runner: Type[VllmRunner], + example_prompts, + model: str, + *, + dtype: str, + max_tokens: int, + num_logprobs: int, +): with vllm_runner(model, dtype=dtype) as vllm_model: - vllm_outputs = vllm_model.generate_greedy_logprobs( + vllm_model.generate_greedy_logprobs( example_prompts, max_tokens, num_logprobs=num_logprobs, ) - - check_logprobs_close( - outputs_0_lst=hf_outputs, - outputs_1_lst=vllm_outputs, - name_0="hf", - name_1="vllm", - ) diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py index 9a7493649c795..f9f9536a7c160 100644 --- a/vllm/entrypoints/chat_utils.py +++ b/vllm/entrypoints/chat_utils.py @@ -150,6 +150,8 @@ def _placeholder_str(self, modality: ModalityStr, if model_type in ("blip-2", "chatglm", "fuyu", "paligemma"): # These models do not use image tokens in the prompt return None + if model_type == "qwen": + return f"Picture {current_count}: " if model_type.startswith("llava"): return self._cached_token_str(self._tokenizer, hf_config.image_token_index) diff --git a/vllm/model_executor/layers/resampler.py b/vllm/model_executor/layers/resampler.py new file mode 100644 index 0000000000000..8cd938fc85fb2 --- /dev/null +++ b/vllm/model_executor/layers/resampler.py @@ -0,0 +1,273 @@ +# coding=utf-8 +# Adapted from +# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py +# https://huggingface.co/Qwen/Qwen-7B/blob/main/modeling_qwen.py +# https://github.com/facebookresearch/mae/blob/efb2a8062c206524e35e47d04501ed4f544c0ae8/util/pos_embed.py#L20 +# +# Copyright 2023 The Qwen team. +# Copyright 2023 The vLLM team. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Shared resampler perceiver network used in multimodal models and +related helpers for sincos positional embeddings. + +Example models: Qwen (Qwen-VL), Minicpmv2.0 +""" +import math +from functools import partial +from typing import Callable, Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn.functional as F +from torch import nn +from torch.nn.init import trunc_normal_ + +from vllm.model_executor.layers.linear import ReplicatedLinear + +DEFAULT_LN = partial(nn.LayerNorm, eps=1e-6) + + +def get_abs_pos(abs_pos: torch.Tensor, tgt_size: Union[torch.Tensor, + int]) -> torch.Tensor: + # abs_pos: L, C + # tgt_size: (H, W) + # return: M, C + src_size = int(math.sqrt(abs_pos.size(0))) + dtype = abs_pos.dtype + if isinstance(tgt_size, int): + tgt_size = (tgt_size, tgt_size) + if (src_size == tgt_size[0] and src_size == tgt_size[1]): + return abs_pos + return (F.interpolate( + abs_pos.float().reshape(1, src_size, src_size, -1).permute(0, 3, 1, 2), + size=(tgt_size[0], tgt_size[1]), + mode="bicubic", + align_corners=False, + ).permute(0, 2, 3, 1).flatten(0, 2).to(dtype=dtype)) + + +# sin/cos positional embedding helpers are adapted from: +# https://github.com/facebookresearch/mae/blob/efb2a8062c206524e35e47d04501ed4f544c0ae8/util/pos_embed.py#L20 +def get_1d_sincos_pos_embed_from_grid( + embed_dim: int, pos: np.ndarray, + version: Tuple[int, int] = (2, 0)) -> torch.Tensor: + """ + embed_dim: output dimension for each position + pos: a list of positions to be encoded: size (M,) / (H, W) + out: (M, D) / (H, W, D) + """ + assert embed_dim % 2 == 0 + omega = np.arange(embed_dim // 2, dtype=np.float32) + omega /= embed_dim / 2.0 + omega = 1.0 / 10000**omega # (D/2,) + + if version == (2, 0): + pos = pos.reshape(-1) # (M,) + out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product + emb_sin = np.sin(out) # (M, D/2) + emb_cos = np.cos(out) # (M, D/2) + emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) + else: + out = np.einsum("hw,d->hwd", pos, omega) # (H, W, D/2), outer product + emb_sin = np.sin(out) # (H, W, D/2) + emb_cos = np.cos(out) # (H, W, D/2) + emb = np.concatenate([emb_sin, emb_cos], axis=-1) # (H, W, D) + return emb + + +def get_2d_sincos_pos_embed_from_grid( + embed_dim: int, grid: np.ndarray, + version: Tuple[int, int] = (2, 0)) -> torch.Tensor: + assert embed_dim % 2 == 0 + + # use half of dimensions to encode grid_h + emb_h = get_1d_sincos_pos_embed_from_grid( + embed_dim // 2, grid[0], version) # (H*W, D/2) or (H, W, D/2) + emb_w = get_1d_sincos_pos_embed_from_grid( + embed_dim // 2, grid[1], version) # (H*W, D/2) or (H, W, D/2) + + if version == (2, 0): + emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) + else: + emb = np.concatenate([emb_h, emb_w], axis=-1) # (H, W, D) + return emb + + +def get_2d_sincos_pos_embed( + embed_dim: int, + grid_size: Union[int, Tuple[int, int]], + cls_token: bool = False, + version: Tuple[int, int] = (2, 0), +) -> torch.Tensor: + """ + grid_size: int of the grid height and width + return: + pos_embed: [grid_size*grid_size, embed_dim] or + [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) + """ + if isinstance(grid_size, int): + grid_h_size, grid_w_size = grid_size, grid_size + else: + grid_h_size, grid_w_size = grid_size[0], grid_size[1] + + grid_h = np.arange(grid_h_size, dtype=np.float32) + grid_w = np.arange(grid_w_size, dtype=np.float32) + grid = np.meshgrid(grid_w, grid_h) # here w goes first + grid = np.stack(grid, axis=0) + assert isinstance(grid, np.ndarray) and \ + grid.shape == (2, grid_h_size, grid_w_size) + + if version == (2, 0): + grid = grid.reshape([2, 1, grid_h_size, grid_w_size]) + pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid, version) + if cls_token: + pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], + axis=0) + else: + pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid, version) + return pos_embed + + +class BaseResampler(nn.Module): + """ + A 2D perceiver-resampler network with one cross attention layers by + (grid_size**2) learnable queries and 2d sincos pos_emb. + Outputs: + A tensor with the shape of (grid_size**2, embed_dim) + """ + + def __init__( + self, + num_queries: int, + embed_dim: int, + num_heads: int, + kv_dim: Optional[int] = None, + norm_layer: Callable[[int], nn.LayerNorm] = DEFAULT_LN, + do_post_projection: bool = True, + ) -> None: + super().__init__() + + self.num_queries = num_queries + self.embed_dim = embed_dim + self.num_heads = num_heads + + self.query = nn.Parameter(torch.zeros(self.num_queries, embed_dim)) + trunc_normal_(self.query, std=0.02) + if kv_dim is not None and kv_dim != embed_dim: + self.kv_proj = ReplicatedLinear(kv_dim, embed_dim, bias=False) + else: + # Maintain the same return value with ReplicatedLinear.forward + self.kv_proj = lambda *args, **kwargs: ( # type: ignore # noqa + nn.Identity()(*args, **kwargs), + None, + ) + self.attn = nn.MultiheadAttention(embed_dim, num_heads) + self.ln_q = norm_layer(embed_dim) + self.ln_kv = norm_layer(embed_dim) + self.do_post_projection = do_post_projection + self.ln_post = norm_layer(embed_dim) if do_post_projection else None + self.proj = nn.Parameter( + (embed_dim**-0.5) * + torch.randn(embed_dim, embed_dim)) if do_post_projection else None + + def _init_weights(self, m: nn.Module) -> None: + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=0.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def _repeat(self, query, N: int): + return query.unsqueeze(1).repeat(1, N, 1) + + +class Resampler2(BaseResampler): + """Resampler-perceiver network to be used for a variety of model types, + e.g., Qwen-vl / Minicpmv 2.0. The main difference is the addition of the + do_post_projection arg, which indicates whether or not there should be + a post layer normalization and projector after the attention. This is + present in minicpmv2.0, but not qwen-vl. + """ + + def __init__( + self, + grid_size: int, + embed_dim: int, + num_heads: int, + kv_dim: Optional[int] = None, + norm_layer: Callable[[int], nn.LayerNorm] = DEFAULT_LN, + adaptive: bool = False, + do_post_projection: bool = True, + ) -> None: + super().__init__(grid_size**2, + embed_dim, + num_heads, + kv_dim, + norm_layer, + do_post_projection=do_post_projection) + + self.adaptive = adaptive + pos_embed_arr = get_2d_sincos_pos_embed(embed_dim, + grid_size, + version=(2, 0)) + + self.pos_embed = nn.Parameter( + torch.from_numpy(pos_embed_arr).requires_grad_(False)) + + self.apply(self._init_weights) + + def forward( + self, + x: torch.Tensor, + tgt_sizes: Optional[torch.Tensor] = None, + attn_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if tgt_sizes is None: + tgt_sizes = int(math.sqrt(x.size(1))) + if self.adaptive: + pos_embed_arr = get_2d_sincos_pos_embed(self.embed_dim, + tgt_sizes, + version=(2, 0)) + pos_embed = torch.from_numpy(pos_embed_arr).to(device=x.device, + dtype=x.dtype) + else: + pos_embed = get_abs_pos(self.pos_embed, + tgt_sizes).to(device=x.device, + dtype=x.dtype) + + x, _ = self.kv_proj(x) + x = self.ln_kv(x).permute(1, 0, 2) + + N = x.shape[1] + q = self.ln_q(self.query) + out = self.attn( + self._repeat(q, N) + self.pos_embed.unsqueeze(1), + x + pos_embed.unsqueeze(1), + x, + attn_mask=attn_mask, + )[0] + x = out.permute(1, 0, 2) + if self.do_post_projection: + x = self.ln_post(x) + x = x @ self.proj + return x diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py index e30370596496a..4db847029566f 100644 --- a/vllm/model_executor/models/__init__.py +++ b/vllm/model_executor/models/__init__.py @@ -51,7 +51,6 @@ "PhiForCausalLM": ("phi", "PhiForCausalLM"), "Phi3ForCausalLM": ("llama", "LlamaForCausalLM"), "PhiMoEForCausalLM": ("phimoe", "PhiMoEForCausalLM"), - "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"), "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"), "Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"), "RWForCausalLM": ("falcon", "FalconForCausalLM"), @@ -88,6 +87,7 @@ "PaliGemmaForConditionalGeneration"), "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"), "UltravoxModel": ("ultravox", "UltravoxModel"), + "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"), } _CONDITIONAL_GENERATION_MODELS = { "BartModel": ("bart", "BartForConditionalGeneration"), diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py index dd10729b9ffb5..f8be9490ee55d 100644 --- a/vllm/model_executor/models/minicpmv.py +++ b/vllm/model_executor/models/minicpmv.py @@ -26,11 +26,9 @@ from array import array from functools import partial from typing import (Any, Callable, Iterable, List, Mapping, Optional, Tuple, - TypedDict, Union) + TypedDict) -import numpy as np import torch -import torch.nn.functional as F import torch.types from PIL import Image from torch import nn @@ -44,6 +42,8 @@ from vllm.model_executor.layers.linear import ReplicatedLinear from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.resampler import (Resampler2, + get_2d_sincos_pos_embed) from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.model_loader.utils import set_default_torch_dtype @@ -98,101 +98,6 @@ class MiniCPMVImagePixelInputs(TypedDict): DEFAULT_LN = partial(nn.LayerNorm, eps=1e-6) -def get_abs_pos(abs_pos: torch.Tensor, tgt_size: torch.Tensor): - # abs_pos: L, C - # tgt_size: (H, W) - # return: M, C - src_size = int(math.sqrt(abs_pos.size(0))) - # tgt_size = int(math.sqrt(tgt_size)) - dtype = abs_pos.dtype - - return (F.interpolate( - abs_pos.float().reshape(1, src_size, src_size, -1).permute(0, 3, 1, 2), - size=(tgt_size[0], tgt_size[1]), - mode="bicubic", - align_corners=False, - ).permute(0, 2, 3, 1).flatten(0, 2).to(dtype=dtype)) - - -# https://github.com/facebookresearch/mae/blob/efb2a8062c206524e35e47d04501ed4f544c0ae8/util/pos_embed.py#L20 -def get_2d_sincos_pos_embed( - embed_dim: int, - grid_size: Union[int, Tuple[int, int]], - cls_token: bool = False, - version: Tuple[int, int] = (2, 0), -): - """ - grid_size: int of the grid height and width - return: - pos_embed: [grid_size*grid_size, embed_dim] or - [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) - """ - if isinstance(grid_size, int): - grid_h_size, grid_w_size = grid_size, grid_size - else: - grid_h_size, grid_w_size = grid_size[0], grid_size[1] - - grid_h = np.arange(grid_h_size, dtype=np.float32) - grid_w = np.arange(grid_w_size, dtype=np.float32) - grid = np.meshgrid(grid_w, grid_h) # here w goes first - grid = np.stack(grid, axis=0) - - if version == (2, 0): - grid = grid.reshape([2, 1, grid_h_size, grid_w_size]) - pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid, version) - if cls_token: - pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], - axis=0) - else: - pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid, version) - return pos_embed - - -def get_2d_sincos_pos_embed_from_grid(embed_dim: int, - grid: np.ndarray, - version: Tuple[int, int] = (2, 0)): - assert embed_dim % 2 == 0 - - # use half of dimensions to encode grid_h - emb_h = get_1d_sincos_pos_embed_from_grid( - embed_dim // 2, grid[0], version) # (H*W, D/2) or (H, W, D/2) - emb_w = get_1d_sincos_pos_embed_from_grid( - embed_dim // 2, grid[1], version) # (H*W, D/2) or (H, W, D/2) - - if version == (2, 0): - emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) - else: - emb = np.concatenate([emb_h, emb_w], axis=-1) # (H, W, D) - return emb - - -def get_1d_sincos_pos_embed_from_grid(embed_dim: int, - pos: np.ndarray, - version: Tuple[int, int] = (2, 0)): - """ - embed_dim: output dimension for each position - pos: a list of positions to be encoded: size (M,) / (H, W) - out: (M, D) / (H, W, D) - """ - assert embed_dim % 2 == 0 - omega = np.arange(embed_dim // 2, dtype=np.float32) - omega /= embed_dim / 2.0 - omega = 1.0 / 10000**omega # (D/2,) - - if version == (2, 0): - pos = pos.reshape(-1) # (M,) - out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product - emb_sin = np.sin(out) # (M, D/2) - emb_cos = np.cos(out) # (M, D/2) - emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) - else: - out = np.einsum("hw,d->hwd", pos, omega) # (H, W, D/2), outer product - emb_sin = np.sin(out) # (H, W, D/2) - emb_cos = np.cos(out) # (H, W, D/2) - emb = np.concatenate([emb_sin, emb_cos], axis=-1) # (H, W, D) - return emb - - class BaseResampler(nn.Module): """ A 2D perceiver-resampler network with one cross attention layers by @@ -245,62 +150,6 @@ def _repeat(self, query, N: int): return query.unsqueeze(1).repeat(1, N, 1) -class Resampler2(BaseResampler): - - def __init__( - self, - grid_size: int, - embed_dim: int, - num_heads: int, - kv_dim: Optional[int] = None, - norm_layer: Callable[[int], nn.LayerNorm] = DEFAULT_LN, - adaptive: bool = False, - ) -> None: - super().__init__(grid_size**2, embed_dim, num_heads, kv_dim, - norm_layer) - - self.adaptive = adaptive - pos_embed_arr = get_2d_sincos_pos_embed(embed_dim, - grid_size, - version=(2, 0)) - self.pos_embed = nn.Parameter( - torch.from_numpy(pos_embed_arr).float()).requires_grad_(False) - - self.apply(self._init_weights) - - def forward( - self, - x: torch.Tensor, - tgt_sizes: torch.Tensor, - attn_mask: Optional[torch.Tensor] = None, - ): - if self.adaptive: - pos_embed_arr = get_2d_sincos_pos_embed(self.embed_dim, - tgt_sizes, - version=(2, 0)) - pos_embed = torch.from_numpy(pos_embed_arr).to(device=x.device, - dtype=x.dtype) - else: - pos_embed = get_abs_pos(self.pos_embed, tgt_sizes) - - x, _ = self.kv_proj(x) - x = self.ln_kv(x).permute(1, 0, 2) - - N = x.shape[1] - q = self.ln_q(self.query) - out = self.attn( - self._repeat(q, N) + self.pos_embed.unsqueeze(1), - x + pos_embed.unsqueeze(1), - x, - attn_mask=attn_mask, - )[0] - x = out.permute(1, 0, 2) - - x = self.ln_post(x) - x = x @ self.proj - return x - - class Resampler2_5(BaseResampler): def __init__( @@ -782,7 +631,8 @@ def init_resampler(self, embed_dim: int, vision_dim: int) -> nn.Module: num_heads=embed_dim // 128, grid_size=int(math.sqrt(self.config.query_num)), kv_dim=vision_dim, - adaptive=True, + adaptive=False, + do_post_projection=True, ) return resampler diff --git a/vllm/model_executor/models/qwen.py b/vllm/model_executor/models/qwen.py index 8298e3bac4465..a726ec10984c0 100644 --- a/vllm/model_executor/models/qwen.py +++ b/vllm/model_executor/models/qwen.py @@ -4,36 +4,402 @@ # Copyright (c) Alibaba Cloud. # LICENSE: https://huggingface.co/Qwen/Qwen-7B/blob/main/LICENSE """Inference-only QWen model compatible with HuggingFace weights.""" -from typing import Any, Dict, Iterable, List, Optional, Tuple +import math +import re +from array import array +from functools import partial +from typing import (Any, Callable, Dict, Iterable, List, Literal, Mapping, + Optional, Tuple, TypedDict, Union) + +import numpy as np import torch +from PIL import Image from torch import nn +from torchvision import transforms +from torchvision.transforms import InterpolationMode from transformers import PretrainedConfig from vllm.attention import Attention, AttentionMetadata -from vllm.config import CacheConfig +from vllm.config import CacheConfig, MultiModalConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size -from vllm.model_executor.layers.activation import SiluAndMul +from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs +from vllm.logger import init_logger +from vllm.model_executor.layers.activation import SiluAndMul, get_act_fn from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + MergedColumnParallelLinear, QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) +from vllm.model_executor.layers.resampler import Resampler2, get_abs_pos from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.models.interfaces import SupportsMultiModal from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import IntermediateTensors -from vllm.utils import print_warning_once +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.base import MultiModalInputs +from vllm.multimodal.utils import cached_get_tokenizer +from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors, + SequenceData) + +from .utils import flatten_bn, is_pp_missing_parameter, make_layers + +logger = init_logger(__name__) + +# NOTE: Qwen models have a few other special tags, e.g., ref, bbox, quad; +# for the time being, these tags are not considered as special at encoding +# time. This may change as VLLMs multimodal API changes in the future. +IMG_START = "" +IMG_END = "" +IMG_PAD = "" +# Image context is fixed at 256 for all images +MAX_QWEN_IMG_TOKENS = 256 +# Image normalization params +CLIP_MEAN = (0.48145466, 0.4578275, 0.40821073) +CLIP_STD = (0.26862954, 0.26130258, 0.27577711) + + +class QwenImagePixelInputs(TypedDict): + type: Literal["pixel_values"] + data: torch.Tensor + """ + Shape: `(batch_size * num_images, 3, image_size, image_size)` + + Note that image_size is the value in the vision config to which we resize + the image to in the normalization transform. Currently multi-image support + can only be leveraged by passing image embeddings directly. + """ + + +class QwenImageEmbeddingInputs(TypedDict): + type: Literal["image_embeds"] + data: torch.Tensor + """Shape: `(batch_size * num_images, 256, hidden_size)` + + `hidden_size` must match the hidden size of the language model backbone + and is stored in the visual config of the model if we have one. + """ + + +QwenImageInputs = Union[QwenImagePixelInputs, QwenImageEmbeddingInputs] + + +class VisualAttention(nn.Module): + """self-attention layer class. + Self-attention layer takes input with size [s, b, h] + and returns output of the same size. + """ + + def __init__( + self, + embed_dim: int, + num_heads: int, + bias: bool = True, + kdim: Optional[int] = None, + vdim: Optional[int] = None, + ): + super().__init__() + self.embed_dim = embed_dim + self.kdim = kdim if kdim is not None else embed_dim + self.vdim = vdim if vdim is not None else embed_dim + self._qkv_same_embed_dim = self.kdim == embed_dim \ + and self.vdim == embed_dim + + self.num_heads = num_heads + + # Per attention head and per partition values. + assert embed_dim % num_heads == 0 + self.hidden_size_per_attention_head = embed_dim // num_heads + self.num_attention_heads_per_partition = num_heads + self.hidden_size_per_partition = embed_dim + + # Strided linear layer. + assert self._qkv_same_embed_dim, \ + 'Visual Attention implementation only supports self-attention' + self.in_proj = nn.Linear(embed_dim, 3 * embed_dim) + self.out_proj = nn.Linear(embed_dim, embed_dim) + self.norm_factor = math.sqrt(self.hidden_size_per_attention_head) + + def forward( + self, + x: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + # query/key/value: [sq, b, h] + sq, b, _ = x.size() + mixed_x_layer = self.in_proj(x) + + # [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn] + new_tensor_shape = mixed_x_layer.size()[:-1] + \ + (self.num_attention_heads_per_partition, + 3 * self.hidden_size_per_attention_head) + mixed_x_layer = mixed_x_layer.view(*new_tensor_shape) + + # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn] + query_layer, key_layer, value_layer = mixed_x_layer.split( + self.hidden_size_per_attention_head, dim=-1) + + # [sq, b, np, hn] -> [sq, b * np, hn] + query_layer = query_layer.view( + sq, b * self.num_attention_heads_per_partition, + self.hidden_size_per_attention_head).transpose(0, 1) + # [sk, b, np, hn] -> [sk, b * np, hn] + key_layer = key_layer.view( + sq, b * self.num_attention_heads_per_partition, + self.hidden_size_per_attention_head).transpose(0, 1) + + q_scaled = query_layer / self.norm_factor + if attn_mask is not None: + attention_probs = torch.baddbmm(attn_mask, q_scaled, + key_layer.transpose(-2, -1)) + else: + attention_probs = torch.bmm(q_scaled, key_layer.transpose(-2, -1)) + attention_probs = attention_probs.softmax(dim=-1) + + value_layer = value_layer.view( + sq, b * self.num_attention_heads_per_partition, + self.hidden_size_per_attention_head).transpose(0, 1) + + # matmul: [b * np, sq, hn] + context_layer = torch.bmm(attention_probs, value_layer) -from .utils import is_pp_missing_parameter, make_layers + # change view [b, np, sq, hn] + context_layer = context_layer.view( + b, self.num_attention_heads_per_partition, sq, + self.hidden_size_per_attention_head) + + # [b, np, sq, hn] --> [sq, b, np, hn] + context_layer = context_layer.permute(2, 0, 1, 3).contiguous() + + # [sq, b, np, hn] --> [sq, b, hp] + new_context_layer_shape = context_layer.size()[:-2] + \ + (self.hidden_size_per_partition,) + context_layer = context_layer.view(*new_context_layer_shape) + + output = self.out_proj(context_layer) + + return output + + +class QwenVMLP(nn.Module): + """MLP for the visual component of the Qwen model.""" + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.c_fc = ColumnParallelLinear(hidden_size, + intermediate_size, + bias=True, + quant_config=quant_config) + self.act_fn = get_act_fn("gelu", quant_config, intermediate_size) + self.c_proj = RowParallelLinear( + intermediate_size, + hidden_size, + bias=True, + quant_config=quant_config, + ) + + def forward(self, x): + x, _ = self.c_fc(x) + x = self.act_fn(x) + x, _ = self.c_proj(x) + return x + + +class VisualAttentionBlock(nn.Module): + + def __init__( + self, + d_model: int, + n_head: int, + mlp_ratio: float = 4.0, + norm_layer: Callable = nn.LayerNorm, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + + self.ln_1 = norm_layer(d_model) + self.ln_2 = norm_layer(d_model) + mlp_width = int(d_model * mlp_ratio) + self.attn = VisualAttention(d_model, n_head) + self.mlp = QwenVMLP( + hidden_size=d_model, + intermediate_size=mlp_width, + quant_config=quant_config, + ) + + def attention( + self, + x: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + attn_mask = attn_mask.to(x.dtype) if attn_mask is not None else None + return self.attn(x, attn_mask=attn_mask) + + def forward( + self, + x: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + x = x + self.attention(self.ln_1(x), attn_mask=attn_mask) + x = x + self.mlp(self.ln_2(x)) + return x + + +class TransformerBlock(nn.Module): + + def __init__( + self, + width: int, + layers: int, + heads: int, + mlp_ratio: float = 4.0, + norm_layer: Callable = nn.LayerNorm, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.width = width + self.layers = layers + + self.resblocks = nn.ModuleList([ + VisualAttentionBlock(width, + heads, + mlp_ratio, + norm_layer=norm_layer, + quant_config=quant_config) + for _ in range(layers) + ]) + + def get_cast_dtype(self) -> torch.dtype: + return self.resblocks[0].mlp.c_fc.weight.dtype + + def get_cast_device(self) -> torch.device: + return self.resblocks[0].mlp.c_fc.weight.device + + def forward(self, + x: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None) -> torch.Tensor: + for r in self.resblocks: + x = r(x, attn_mask=attn_mask) + return x + + +class VisionTransformer(nn.Module): + + def __init__(self, + image_size: int, + patch_size: int, + width: int, + layers: int, + heads: int, + mlp_ratio: float, + n_queries: int = 256, + output_dim: int = 512, + image_start_id: int = 151857, + quant_config: Optional[QuantizationConfig] = None, + **kwargs): + super().__init__() + image_height, image_width = self.image_size = (image_size, image_size) + patch_height, patch_width = self.patch_size = (patch_size, patch_size) + self.grid_size = (image_height // patch_height, + image_width // patch_width) + self.output_dim = output_dim + self.conv1 = nn.Conv2d(in_channels=3, + out_channels=width, + kernel_size=patch_size, + stride=patch_size, + bias=False) + + # class embeddings and positional embeddings + scale = width**-0.5 + self.positional_embedding = nn.Parameter(scale * + torch.randn(256, width)) + + norm_layer = partial(nn.LayerNorm, eps=1e-6) + + self.ln_pre = norm_layer(width) + self.transformer = TransformerBlock(width, + layers, + heads, + mlp_ratio, + norm_layer=norm_layer, + quant_config=quant_config) + + self.attn_pool = Resampler2( + grid_size=int(math.sqrt(n_queries)), + embed_dim=output_dim, + num_heads=output_dim // 128, + kv_dim=width, + norm_layer=norm_layer, + adaptive=False, + do_post_projection=False, + ).to( + device=self.positional_embedding.device, + dtype=self.positional_embedding.dtype, + ) + + self.ln_post = norm_layer(output_dim) + self.proj = nn.Parameter( + (output_dim**-0.5) * torch.randn(output_dim, output_dim)) + self.image_start_id = image_start_id + self.image_end_id = image_start_id + 1 + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = x.to( + dtype=self.transformer.get_cast_dtype(), + device=self.transformer.get_cast_device(), + ) + + # to patches + x = self.conv1(x) # shape = [*, width, grid, grid] + x = x.reshape(x.shape[0], x.shape[1], + -1) # shape = [*, width, grid ** 2] + x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] + + x = x + get_abs_pos(self.positional_embedding, int(math.sqrt( + x.size(1)))) + + x = self.ln_pre(x) + + x = x.permute(1, 0, 2) # NLD -> LND + x = self.transformer(x) + x = x.permute(1, 0, 2) # LND -> NLD + + x = self.attn_pool(x) + x = self.ln_post(x) + x = x @ self.proj + + return x + + def get_image_positions(self, + input_ids: torch.Tensor) -> Optional[torch.Tensor]: + """Given the input IDs, extracts start/stop points corresponding to + images. + + args: + Returns: + Optional torch tensor corresponding to start/stop pairs of images. + """ + if torch.any(input_ids == self.image_start_id): + bos_pos = torch.where(input_ids == self.image_start_id) + eos_pos = torch.where(input_ids == self.image_end_id) + return torch.stack((bos_pos[0], eos_pos[0]), dim=1) + return None class QWenMLP(nn.Module): + """MLP for the language component of the Qwen model, which contains a + MergedColumnParallelLinear merging 2 outputs via silu activation.""" def __init__( self, @@ -56,7 +422,7 @@ def __init__( "Only silu is supported for now.") self.act_fn = SiluAndMul() - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: gate_up, _ = self.gate_up_proj(x) x = self.act_fn(gate_up) x, _ = self.c_proj(x) @@ -203,6 +569,9 @@ def __init__( lambda prefix: QWenBlock(config, cache_config, quant_config), prefix=f"{prefix}.h") self.ln_f = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) + self.visual = VisionTransformer(**config.visual, + quant_config=quant_config) if hasattr( + config, "visual") else None def forward( self, @@ -211,9 +580,33 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], + pixel_values: Optional[QwenImageInputs], ) -> torch.Tensor: + img_pos = None + # If pixel / visual embeddings are provided, this is a visual model + if pixel_values is not None and self.visual is not None: + if pixel_values["type"] != "image_embeds": + image_embeds = self.visual(pixel_values["data"]) + else: + image_embeds = pixel_values["data"] + + # features should be of shape (# images, 256, hidden_dim) + img_pos = self.visual.get_image_positions(input_ids) + if isinstance( + img_pos, + np.ndarray) and img_pos.shape[0] != image_embeds.shape[0]: + raise ValueError( + f"Number of placeholders: {img_pos.shape[0]} " + f"does not match number of images {image_embeds.shape[0]}." + ) + if get_pp_group().is_first_rank: hidden_states = self.wte(input_ids) + # Merge the image embeddings into the hidden states if actually have + # visual features and the corresponding image tokens + if img_pos is not None: + for idx, (img_bos, img_eos) in enumerate(img_pos): + hidden_states[img_bos + 1:img_eos] = image_embeds[idx] residual = None else: assert intermediate_tensors is not None @@ -237,16 +630,241 @@ def forward( return hidden_states -class QWenLMHeadModel(nn.Module): +def get_image_text(image_num: int, padding: bool) -> str: + """Retrieves a placeholder text that when tokenized, will be expanded with + image pads. + + Args: + image_num: The number of the image that we want a text prompt for. + Images should be indexed starting at 1. + padding: Whether or not padding should be manually added. + + Returns: + Text placeholder prompt for the image being considered. + """ + image_start = f"Picture {image_num}: {IMG_START}" + image_end = f"{IMG_END}\n" + if not padding: + return f"{image_start}{image_end}" + return f"{image_start}{MAX_QWEN_IMG_TOKENS * IMG_PAD}{image_end}" + + +def input_processor_for_qwen(ctx: InputContext, + llm_inputs: LLMInputs) -> LLMInputs: + """Processes the inputs, which may or may not be multimodal. + Multimodal inputs will only be processed if the model has a "visual" + component in its model config, otherwise they'll be ignored. + + Args: + ctx: Context of the loaded model. + llm_inputs: LLM inputs which may have a multi_modal_data attribute. + + Returns: + If the model is language only or not multimodal inputs were provided, + returns llm_inputs unmodified. Otherwise, processes the multimodal + images / image embeddings and adds the fixed-length image placeholders. + """ + multi_modal_data = llm_inputs.get("multi_modal_data") + + # Only process images if we have multimodal data and a visual config + hf_config = ctx.get_hf_config() + if (multi_modal_data is None or "image" not in multi_modal_data + or not hasattr(hf_config, "visual")): + return llm_inputs + + prompt = llm_inputs.get("prompt") + prompt_token_ids = llm_inputs["prompt_token_ids"] + model_config = ctx.model_config + tokenizer = cached_get_tokenizer(model_config.tokenizer, + trust_remote_code=True) + image_data = multi_modal_data["image"] + if isinstance(image_data, torch.Tensor): + num_dims = len(image_data.shape) + if num_dims < 2 or num_dims > 3: + raise ValueError( + f"Expected img embeds to be have 3 dimensions, got {num_dims}") + num_images = 1 if num_dims == 2 else image_data.shape[0] + else: + # TODO - handle multiple image inputs once the API is solidified + num_images = 1 + + if prompt is None: + prompt = tokenizer.decode(prompt_token_ids) + + # Drops anything between / tags; encoding with the tokenizer + # will automatically add the image pads for the context. + new_prompt, num_matched_images = re.subn( + r"(Picture \d*: ).*?(<\/img>\n)", + r"\1\2", + prompt, + ) + + if num_matched_images != num_images: + logger.warning( + "Number of matched image placeholders %s doesn't match the number " + "of expected images %s; check your placeholder formatting.", + num_matched_images, num_images) + + new_prompt_token_ids = tokenizer.encode(new_prompt) + + return LLMInputs(prompt=new_prompt, + prompt_token_ids=new_prompt_token_ids, + multi_modal_data=multi_modal_data) + + +def input_mapper_for_qwen(ctx: InputContext, data: object) -> MultiModalInputs: + """Maps the input data to its MultiModalInputs (if any). + + Args: + ctx: Context of the loaded model. + data: data potentially containing image/image embeddings to be mapped + to pixel_values in .forward() for a visual QWenLMHeadModel model. + + Returns: + MultiModalInputs containing the stacked normalized images tensor or + image embeddings. + """ + # Early exit if we have provided an image to a language only Qwen model + hf_config = ctx.get_hf_config() + if not hasattr(hf_config, "visual"): + logger.warning( + "Images were provided but this model has no visual config; " + "multimodal inputs will not be forwarded to the model.") + return MultiModalInputs() + + model_config = ctx.model_config + tokenizer = cached_get_tokenizer(model_config.tokenizer, + trust_remote_code=True) + + image_pair_tok = tokenizer.encode(IMG_START + IMG_END, + add_special_tokens=False, + return_tensors="pt").squeeze() + image_start_id = image_pair_tok[0] + image_end_id = image_pair_tok[-1] + if (image_start_id + 1) != image_end_id: + raise ValueError( + f"Found image end ID {image_end_id}, but expected {IMG_START} + 1") + if len(image_pair_tok) != (MAX_QWEN_IMG_TOKENS + 2): + raise ValueError( + f"Expected image context length of {MAX_QWEN_IMG_TOKENS}, " + f"but got {image_pair_tok - 2}") + + hf_config = ctx.get_hf_config() + image_size = hf_config.visual["image_size"] + img_emb_size = hf_config.visual["output_dim"] + + if isinstance(data, torch.Tensor): + # It's expected that our values have already been processed + # by the visual transformer; shape is expected to be: + # (# images, 256, hidden_size) + if len(data.shape) == 2: + # Assume only one image embed was provided; unsqueeze the extra dim + data = data.unsqueeze(0) + if len(data.shape) != 3 or data.shape[ + 1] != MAX_QWEN_IMG_TOKENS or data.shape[2] != img_emb_size: + raise ValueError( + "Expected image embeds to be a tensor of shape" + f"[# images, {MAX_QWEN_IMG_TOKENS}, {img_emb_size}], but " + f"received shape [{data.shape}]") + pixel_values = data + + else: + transform = build_normalization_transform(image_size) + # TODO - handle multiple image inputs once the API is solidified + transformed_images = [transform(data)] + pixel_values = torch.stack(transformed_images, dim=0) + return MultiModalInputs({"pixel_values": pixel_values}) + + +def build_normalization_transform(image_size: int) -> transforms.Compose: + """Builds a normalization transform which can be applied to one or + more input images from which we want to extract visual features. + + Args: + image_size: size of the image to be processed for visual embeddings. + + Returns: + Callable transform for normalizing and resizing one RGB image. + """ + return transforms.Compose([ + transforms.Resize((image_size, image_size), + interpolation=InterpolationMode.BICUBIC), + transforms.ToTensor(), + transforms.Normalize(mean=CLIP_MEAN, std=CLIP_STD), + ]) + + +def dummy_data_for_qwen( + ctx: InputContext, + seq_len: int, + mm_counts: Mapping[str, int], +) -> Tuple[SequenceData, Optional[Dict]]: + """Build dummy data for warming up Qwen models; this will only contain text + matching the defaults for VLLM unless the model has a visual config. + + Args: + ctx: Context of the loaded model. + seq_len: Number of tokens in the text sequence. + mm_counts: multimodal data counts. + + Returns: + Tuple containing sequential and multimodal data. + """ + hf_config = ctx.get_hf_config() + + # The presence of a visual config indicates this is a multimodal model. + # If we don't have it, the model is considered an LLM for warmup purposes. + if not hasattr(hf_config, "visual"): + seq_data = SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, [0] * seq_len)) + mm_data = None + return seq_data, mm_data + + # We have a visual component - use images to warm up + num_images = mm_counts["image"] + model_config = ctx.model_config + tokenizer = cached_get_tokenizer(model_config.tokenizer, + trust_remote_code=True) + + # Build the image prompts with no imgpads; the tokenizer will add img pads + image_prompt = ''.join( + [get_image_text(idx, False) for idx in range(1, num_images + 1)]) + toks = tokenizer.encode(image_prompt, add_special_tokens=False) + + # Make sure we actually get the fixed context size per tok padding + num_pads = toks.count(tokenizer.encode(IMG_PAD)[0]) + if num_pads != (num_images * MAX_QWEN_IMG_TOKENS): + raise ValueError( + f"Tokenized dummy data should encode {MAX_QWEN_IMG_TOKENS} pads" + f" per image, but got {num_pads} pads for {num_images} image(s)" + " in total. Are you using a qwen tokenizer?") + + # Ensure the number of tokens is at minimum the sequence length provided + if len(toks) < seq_len: + toks += [0] * (seq_len - len(toks)) + + # Build the input images; width/height doesn't actually matter here since + # the data will get resized and the # of tokens per image is constant + image = Image.new("RGB", (224, 224), color=0) + mm_data = {"image": image if num_images == 1 else [image] * num_images} + return SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, toks)), mm_data + + +@MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_qwen) +@MULTIMODAL_REGISTRY.register_max_image_tokens(MAX_QWEN_IMG_TOKENS) +@INPUT_REGISTRY.register_dummy_data(dummy_data_for_qwen) +@INPUT_REGISTRY.register_input_processor(input_processor_for_qwen) +class QWenLMHeadModel(nn.Module, SupportsMultiModal): def __init__( self, config: PretrainedConfig, + multimodal_config: MultiModalConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.config = config + self.multimodal_config = multimodal_config self.quant_config = quant_config self.transformer = QWenModel(config, cache_config, quant_config) self.lm_head = ParallelLMHead(config.vocab_size, @@ -257,16 +875,47 @@ def __init__( self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, - intermediate_tensors: Optional[IntermediateTensors] = None, - ) -> torch.Tensor: + def _get_image_input_type( + self, + pixel_values: Optional[torch.Tensor]) -> Optional[QwenImageInputs]: + """Determines if the provided pixel_values are normalized pixel values + or image embeddings. + + Args: + pixel_values: Optional data to processed into visual embeddings. + + Returns: + None of the QwenImageInputs type used to determine whether or not + the visual transformer needs to process the pixel_values. + """ + if pixel_values is not None and self.transformer.visual is not None: + pixel_values = flatten_bn(pixel_values) + if len(pixel_values.shape) == 3 and pixel_values.shape[ + 1] == MAX_QWEN_IMG_TOKENS and pixel_values.shape[ + 2] == self.config.visual["output_dim"]: + return QwenImageEmbeddingInputs( + type="image_embeds", + data=pixel_values, + ) + else: + # If we have the wrong shape, assume we still need to process + return QwenImagePixelInputs( + type="pixel_values", + data=pixel_values, + ) + return None + + def forward(self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + pixel_values: Optional[torch.Tensor] = None) -> torch.Tensor: + pixel_values = self._get_image_input_type(pixel_values) hidden_states = self.transformer(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors) + attn_metadata, intermediate_tensors, + pixel_values) return hidden_states def make_empty_intermediate_tensors( @@ -328,15 +977,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue - # Skip loading visual weights to support Qwen-VL models - # in cases with text-only inputs - # TODO: add support for Qwen-VL - if (name not in params_dict - and name.startswith("transformer.visual.")): - print_warning_once( - "Only text inputs are allowed. Images won't be handled " - "until Qwen-VL models are fully supported.") - continue # Skip layers on other devices. if is_pp_missing_parameter(name, self): continue From 2ee45281a5012072f41573eb09e1f82985adc761 Mon Sep 17 00:00:00 2001 From: Michael Goin Date: Thu, 5 Sep 2024 11:09:46 -0400 Subject: [PATCH 22/77] Move verify_marlin_supported to GPTQMarlinLinearMethod (#8165) --- vllm/model_executor/layers/quantization/gptq_marlin.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py index 94eb3f301541a..b06ff7bd2bace 100644 --- a/vllm/model_executor/layers/quantization/gptq_marlin.py +++ b/vllm/model_executor/layers/quantization/gptq_marlin.py @@ -51,10 +51,6 @@ def __init__(self, weight_bits: int, group_size: int, desc_act: bool, self.quant_type = self.TYPE_MAP[(weight_bits, is_sym)] - # Verify supported on platform. - verify_marlin_supported(quant_type=self.quant_type, - group_size=self.group_size) - def __repr__(self) -> str: return (f"GPTQMarlinConfig(quant_type={self.quant_type}, " f"group_size={self.group_size}, " @@ -153,6 +149,10 @@ class GPTQMarlinLinearMethod(LinearMethodBase): def __init__(self, quant_config: GPTQMarlinConfig) -> None: self.quant_config = quant_config + # Verify supported on platform. + verify_marlin_supported(quant_type=self.quant_config.quant_type, + group_size=self.quant_config.group_size) + def create_weights( self, layer: torch.nn.Module, From 2febcf2777c77de576ceb5c39cba1dbc2033d04d Mon Sep 17 00:00:00 2001 From: sroy745 <142070531+sroy745@users.noreply.github.com> Date: Thu, 5 Sep 2024 13:25:29 -0700 Subject: [PATCH 23/77] [Documentation][Spec Decode] Add documentation about lossless guarantees in Speculative Decoding in vLLM (#7962) --- docs/source/models/spec_decode.rst | 40 ++++++++++++++++++++++++++++++ docs/source/serving/faq.rst | 19 ++++++++++++++ 2 files changed, 59 insertions(+) diff --git a/docs/source/models/spec_decode.rst b/docs/source/models/spec_decode.rst index d3c196faff25d..50468f25b922a 100644 --- a/docs/source/models/spec_decode.rst +++ b/docs/source/models/spec_decode.rst @@ -161,6 +161,46 @@ A variety of speculative models of this type are available on HF hub: * `granite-7b-instruct-accelerator `_ * `granite-20b-code-instruct-accelerator `_ +Lossless guarantees of Speculative Decoding +------------------------------------------- +In vLLM, speculative decoding aims to enhance inference efficiency while maintaining accuracy. This section addresses the lossless guarantees of +speculative decoding, breaking down the guarantees into three key areas: + +1. **Theoretical Losslessness** + - Speculative decoding sampling is theoretically lossless up to the precision limits of hardware numerics. Floating-point errors might + cause slight variations in output distributions, as discussed + in `Accelerating Large Language Model Decoding with Speculative Sampling `_ + +2. **Algorithmic Losslessness** + - vLLM’s implementation of speculative decoding is algorithmically validated to be lossless. Key validation tests include: + + - **Rejection Sampler Convergence**: Ensures that samples from vLLM’s rejection sampler align with the target + distribution. `View Test Code `_ + + - **Greedy Sampling Equality**: Confirms that greedy sampling with speculative decoding matches greedy sampling + without it. This verifies that vLLM's speculative decoding framework, when integrated with the vLLM forward pass and the vLLM rejection sampler, + provides a lossless guarantee. Almost all of the tests in `this directory `_ + verify this property using `this assertion implementation `_ + +3. **vLLM Logprob Stability** + - vLLM does not currently guarantee stable token log probabilities (logprobs). This can result in different outputs for the + same request across runs. For more details, see the FAQ section + titled *Can the output of a prompt vary across runs in vLLM?* in the `FAQs <../serving/faq.rst>`_. + + +**Conclusion** + +While vLLM strives to ensure losslessness in speculative decoding, variations in generated outputs with and without speculative decoding +can occur due to following factors: + +- **Floating-Point Precision**: Differences in hardware numerical precision may lead to slight discrepancies in the output distribution. + +- **Batch Size and Numerical Stability**: Changes in batch size may cause variations in logprobs and output probabilities, potentially + due to non-deterministic behavior in batched operations or numerical instability. + +**Mitigation Strategies** + +For mitigation strategies, please refer to the FAQ entry *Can the output of a prompt vary across runs in vLLM?* in the `FAQs <../serving/faq.rst>`_. Resources for vLLM contributors ------------------------------- diff --git a/docs/source/serving/faq.rst b/docs/source/serving/faq.rst index 7b0374be8adff..9e858e612c8bf 100644 --- a/docs/source/serving/faq.rst +++ b/docs/source/serving/faq.rst @@ -10,3 +10,22 @@ A: Assuming that you're referring to using OpenAI compatible server to serve mul Q: Which model to use for offline inference embedding? A: If you want to use an embedding model, try: https://huggingface.co/intfloat/e5-mistral-7b-instruct. Instead models, such as Llama-3-8b, Mistral-7B-Instruct-v0.3, are generation models rather than an embedding model + +---------------------------------------- + + Q: Can the output of a prompt vary across runs in vLLM? + +A: Yes, it can. vLLM does not guarantee stable log probabilities (logprobs) for the output tokens. Variations in logprobs may occur due to +numerical instability in Torch operations or non-deterministic behavior in batched Torch operations when batching changes. For more details, +see the `Numerical Accuracy section `_. + +In vLLM, the same requests might be batched differently due to factors such as other concurrent requests, +changes in batch size, or batch expansion in speculative decoding. These batching variations, combined with numerical instability of Torch operations, +can lead to slightly different logit/logprob values at each step. Such differences can accumulate, potentially resulting in +different tokens being sampled. Once a different token is sampled, further divergence is likely. + +**Mitigation Strategies** + +- For improved stability and reduced variance, use `float32`. Note that this will require more memory. +- If using `bfloat16`, switching to `float16` can also help. +- Using request seeds can aid in achieving more stable generation for temperature > 0, but discrepancies due to precision differences may still occur. From 9f97b3b08a9f29df4518c3141cb43f807ca89911 Mon Sep 17 00:00:00 2001 From: Dipika Date: Thu, 5 Sep 2024 21:07:45 +0000 Subject: [PATCH 24/77] update/fix weight loading to support tp --- vllm/model_executor/layers/fused_moe/layer.py | 80 ++++++++++--------- .../layers/quantization/gptq_marlin.py | 11 ++- 2 files changed, 53 insertions(+), 38 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index b0d7d4b538df3..f4621e5c4ccc4 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -306,10 +306,28 @@ def _load_single_value(self, param: torch.nn.Parameter, # Input scales can be loaded directly and should be equal. param_data[expert_id] = loaded_weight + def _load_g_idx(self, shard_id: str, expert_data: torch.Tensor, + shard_dim: int, loaded_weight: torch.tensor, tp_rank: int): + + if shard_id == "w2": + self._load_w2(shard_id=shard_id, + shard_dim=shard_dim, + loaded_weight=loaded_weight, + expert_data=expert_data, + tp_rank=tp_rank) + else: + assert shard_id in ("w1", "w3") + expert_data.copy_(loaded_weight) + def weight_loader(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor, weight_name: str, shard_id: str, expert_id: int) -> None: + # llm-compressor returns weights on disk which are flipped + loaded_weight = loaded_weight.t().contiguous() if ( + self.quant_method.__class__.__name__ + == "CompressedTensorsMoEMethod") else loaded_weight + if shard_id not in ("w1", "w2", "w3"): raise ValueError(f"shard_id must be ['w1','w2','w3'] but " f"got {shard_id}.") @@ -325,38 +343,41 @@ def weight_loader(self, param: torch.nn.Parameter, expert_data = param.data[expert_id] tp_rank = get_tensor_model_parallel_rank() - # is_transposed: whether or not the parameter is transposed on disk - # If transposed, the loaded weight will be transposed and the dim - # to shard the loaded weight will be flipped. + # is_transposed: if the dim to shard the weight + # should be flipped. Required by GPTQ, compressed-tensors + # should be whatever dimension intermediate_size is is_transposed = getattr(param, "is_transposed", False) shard_dim = SHARD_ID_TO_SHARDED_DIM[shard_id] if is_transposed: - loaded_weight = loaded_weight.t().contiguous() shard_dim = ~shard_dim - # GPTQ Values - if ("scales" in weight_name or "qweight" in weight_name - or "qzeros" in weight_name): - if (shard_id == "w1" or shard_id == "w3"): - shard_dim = 1 - shard_dim - self._load_model_weight_or_group_weight_scale( - shard_id=shard_id, - shard_dim=shard_dim, - loaded_weight=loaded_weight, - expert_data=expert_data, - tp_rank=tp_rank) - return + # Case input scale: input_scale loading is only supported for fp8 + if "input_scale" in weight_name: + if param.data[expert_id] != 1 and (param.data[expert_id] - + loaded_weight).abs() > 1e-5: + raise ValueError( + "input_scales of w1 and w3 of a layer " + f"must be equal. But got {param.data[expert_id]} " + f"vs. {loaded_weight}") - if "g_idx" in weight_name: self._load_single_value(param=param, loaded_weight=loaded_weight, expert_id=expert_id) return - # Case weight_scales - if "weight_scale" in weight_name: - # load the weight scaling based on the quantization scheme - # supported weight scales can be found in + # Case g_idx + if "g_idx" in weight_name: + self._load_g_idx(shard_dim=0, + shard_id=shard_id, + loaded_weight=loaded_weight, + expert_data=expert_data, + tp_rank=tp_rank) + return + + # Case weight scales and zero_points + if ("scale" in weight_name or "zero" in weight_name): + # load the weight scales and zp based on the quantization scheme + # supported weight scales/zp can be found in # FusedMoeWeightScaleSupported # TODO @dsikka: once hardened, refactor to use vLLM Parameters # specific to each case @@ -385,22 +406,9 @@ def weight_loader(self, param: torch.nn.Parameter, f"quant method must be one of {WEIGHT_SCALE_SUPPORTED}") return + # Case weight_shape if "weight_shape" in weight_name: - self._load_single_value(param=param, - loaded_weight=loaded_weight, - expert_id=expert_id) - return - - # Case input scale - if "input_scale" in weight_name: - # Note: input_scale loading is only supported for fp8 - if param.data[expert_id] != 1 and (param.data[expert_id] - - loaded_weight).abs() > 1e-5: - raise ValueError( - "input_scales of w1 and w3 of a layer " - f"must be equal. But got {param.data[expert_id]} " - f"vs. {loaded_weight}") - + # only required by compressed-tensors self._load_single_value(param=param, loaded_weight=loaded_weight, expert_id=expert_id) diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py index 11012a326b045..c3b9adb1d1982 100644 --- a/vllm/model_executor/layers/quantization/gptq_marlin.py +++ b/vllm/model_executor/layers/quantization/gptq_marlin.py @@ -7,8 +7,8 @@ from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe.fused_marlin_moe import ( fused_marlin_moe) -from vllm.model_executor.layers.fused_moe.layer import (FusedMoE, - FusedMoEMethodBase) +from vllm.model_executor.layers.fused_moe.layer import ( + FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported) from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, set_weight_attrs) from vllm.model_executor.layers.quantization.base_config import ( @@ -372,9 +372,16 @@ def create_weights( if self.quant_config.group_size != -1: scales_size13 = hidden_size // self.quant_config.group_size scales_size2 = intermediate_size // self.quant_config.group_size + strategy = FusedMoeWeightScaleSupported.GROUP.value else: scales_size13 = 1 scales_size2 = 1 + strategy = FusedMoeWeightScaleSupported.CHANNEL.value + + extra_weight_attrs.update({ + "quant_method": strategy, + "is_transposed": True + }) # Fused gate_up_proj (column parallel) w13_qweight = torch.nn.Parameter( torch.empty( From db3bf7c991cd1a0297d1a8ba501e59cfa226c337 Mon Sep 17 00:00:00 2001 From: Jiaxin Shan Date: Thu, 5 Sep 2024 18:10:33 -0700 Subject: [PATCH 25/77] [Core] Support load and unload LoRA in api server (#6566) Co-authored-by: Jee Jee Li --- docs/requirements-docs.txt | 1 - docs/source/models/lora.rst | 52 +++++++++ .../llm/test_generate_multiple_loras.py | 2 +- .../entrypoints/openai/test_serving_engine.py | 107 ++++++++++++++++++ vllm/entrypoints/openai/api_server.py | 40 ++++++- vllm/entrypoints/openai/protocol.py | 10 ++ vllm/entrypoints/openai/serving_engine.py | 79 ++++++++++++- vllm/envs.py | 7 ++ vllm/lora/request.py | 19 +++- vllm/utils.py | 25 ++++ 10 files changed, 336 insertions(+), 6 deletions(-) create mode 100644 tests/entrypoints/openai/test_serving_engine.py diff --git a/docs/requirements-docs.txt b/docs/requirements-docs.txt index c358e23b6a37a..6687929c0bebe 100644 --- a/docs/requirements-docs.txt +++ b/docs/requirements-docs.txt @@ -11,6 +11,5 @@ pydantic >= 2.8 torch py-cpuinfo transformers -openai # Required by docs/source/serving/openai_compatible_server.md's vllm.entrypoints.openai.cli_args mistral_common >= 1.3.4 openai # Required by docs/source/serving/openai_compatible_server.md's vllm.entrypoints.openai.cli_args \ No newline at end of file diff --git a/docs/source/models/lora.rst b/docs/source/models/lora.rst index f08773fe59d92..b3821ebdfceca 100644 --- a/docs/source/models/lora.rst +++ b/docs/source/models/lora.rst @@ -107,3 +107,55 @@ The following is an example request "max_tokens": 7, "temperature": 0 }' | jq + + +Dynamically serving LoRA Adapters +--------------------------------- + +In addition to serving LoRA adapters at server startup, the vLLM server now supports dynamically loading and unloading +LoRA adapters at runtime through dedicated API endpoints. This feature can be particularly useful when the flexibility +to change models on-the-fly is needed. + +Note: Enabling this feature in production environments is risky as user may participate model adapter management. + +To enable dynamic LoRA loading and unloading, ensure that the environment variable `VLLM_ALLOW_RUNTIME_LORA_UPDATING` +is set to `True`. When this option is enabled, the API server will log a warning to indicate that dynamic loading is active. + +.. code-block:: bash + + export VLLM_ALLOW_RUNTIME_LORA_UPDATING=True + + +Loading a LoRA Adapter: + +To dynamically load a LoRA adapter, send a POST request to the `/v1/load_lora_adapter` endpoint with the necessary +details of the adapter to be loaded. The request payload should include the name and path to the LoRA adapter. + +Example request to load a LoRA adapter: + +.. code-block:: bash + + curl -X POST http://localhost:8000/v1/load_lora_adapter \ + -H "Content-Type: application/json" \ + -d '{ + "lora_name": "sql_adapter", + "lora_path": "/path/to/sql-lora-adapter" + }' + +Upon a successful request, the API will respond with a 200 OK status code. If an error occurs, such as if the adapter +cannot be found or loaded, an appropriate error message will be returned. + +Unloading a LoRA Adapter: + +To unload a LoRA adapter that has been previously loaded, send a POST request to the `/v1/unload_lora_adapter` endpoint +with the name or ID of the adapter to be unloaded. + +Example request to unload a LoRA adapter: + +.. code-block:: bash + + curl -X POST http://localhost:8000/v1/unload_lora_adapter \ + -H "Content-Type: application/json" \ + -d '{ + "lora_name": "sql_adapter" + }' diff --git a/tests/entrypoints/llm/test_generate_multiple_loras.py b/tests/entrypoints/llm/test_generate_multiple_loras.py index 35eabf079964a..9f5727ecd0406 100644 --- a/tests/entrypoints/llm/test_generate_multiple_loras.py +++ b/tests/entrypoints/llm/test_generate_multiple_loras.py @@ -50,7 +50,7 @@ def zephyr_lora_files(): @pytest.mark.skip_global_cleanup def test_multiple_lora_requests(llm: LLM, zephyr_lora_files): lora_request = [ - LoRARequest(LORA_NAME, idx + 1, zephyr_lora_files) + LoRARequest(LORA_NAME + str(idx), idx + 1, zephyr_lora_files) for idx in range(len(PROMPTS)) ] # Multiple SamplingParams should be matched with each prompt diff --git a/tests/entrypoints/openai/test_serving_engine.py b/tests/entrypoints/openai/test_serving_engine.py new file mode 100644 index 0000000000000..325bc03434287 --- /dev/null +++ b/tests/entrypoints/openai/test_serving_engine.py @@ -0,0 +1,107 @@ +from http import HTTPStatus +from unittest.mock import MagicMock + +import pytest + +from vllm.config import ModelConfig +from vllm.engine.protocol import AsyncEngineClient +from vllm.entrypoints.openai.protocol import (ErrorResponse, + LoadLoraAdapterRequest, + UnloadLoraAdapterRequest) +from vllm.entrypoints.openai.serving_engine import OpenAIServing + +MODEL_NAME = "meta-llama/Llama-2-7b" +LORA_LOADING_SUCCESS_MESSAGE = ( + "Success: LoRA adapter '{lora_name}' added successfully.") +LORA_UNLOADING_SUCCESS_MESSAGE = ( + "Success: LoRA adapter '{lora_name}' removed successfully.") + + +async def _async_serving_engine_init(): + mock_engine_client = MagicMock(spec=AsyncEngineClient) + mock_model_config = MagicMock(spec=ModelConfig) + # Set the max_model_len attribute to avoid missing attribute + mock_model_config.max_model_len = 2048 + + serving_engine = OpenAIServing(mock_engine_client, + mock_model_config, + served_model_names=[MODEL_NAME], + lora_modules=None, + prompt_adapters=None, + request_logger=None) + return serving_engine + + +@pytest.mark.asyncio +async def test_load_lora_adapter_success(): + serving_engine = await _async_serving_engine_init() + request = LoadLoraAdapterRequest(lora_name="adapter", + lora_path="/path/to/adapter2") + response = await serving_engine.load_lora_adapter(request) + assert response == LORA_LOADING_SUCCESS_MESSAGE.format(lora_name='adapter') + assert len(serving_engine.lora_requests) == 1 + assert serving_engine.lora_requests[0].lora_name == "adapter" + + +@pytest.mark.asyncio +async def test_load_lora_adapter_missing_fields(): + serving_engine = await _async_serving_engine_init() + request = LoadLoraAdapterRequest(lora_name="", lora_path="") + response = await serving_engine.load_lora_adapter(request) + assert isinstance(response, ErrorResponse) + assert response.type == "InvalidUserInput" + assert response.code == HTTPStatus.BAD_REQUEST + + +@pytest.mark.asyncio +async def test_load_lora_adapter_duplicate(): + serving_engine = await _async_serving_engine_init() + request = LoadLoraAdapterRequest(lora_name="adapter1", + lora_path="/path/to/adapter1") + response = await serving_engine.load_lora_adapter(request) + assert response == LORA_LOADING_SUCCESS_MESSAGE.format( + lora_name='adapter1') + assert len(serving_engine.lora_requests) == 1 + + request = LoadLoraAdapterRequest(lora_name="adapter1", + lora_path="/path/to/adapter1") + response = await serving_engine.load_lora_adapter(request) + assert isinstance(response, ErrorResponse) + assert response.type == "InvalidUserInput" + assert response.code == HTTPStatus.BAD_REQUEST + assert len(serving_engine.lora_requests) == 1 + + +@pytest.mark.asyncio +async def test_unload_lora_adapter_success(): + serving_engine = await _async_serving_engine_init() + request = LoadLoraAdapterRequest(lora_name="adapter1", + lora_path="/path/to/adapter1") + response = await serving_engine.load_lora_adapter(request) + assert len(serving_engine.lora_requests) == 1 + + request = UnloadLoraAdapterRequest(lora_name="adapter1") + response = await serving_engine.unload_lora_adapter(request) + assert response == LORA_UNLOADING_SUCCESS_MESSAGE.format( + lora_name='adapter1') + assert len(serving_engine.lora_requests) == 0 + + +@pytest.mark.asyncio +async def test_unload_lora_adapter_missing_fields(): + serving_engine = await _async_serving_engine_init() + request = UnloadLoraAdapterRequest(lora_name="", lora_int_id=None) + response = await serving_engine.unload_lora_adapter(request) + assert isinstance(response, ErrorResponse) + assert response.type == "InvalidUserInput" + assert response.code == HTTPStatus.BAD_REQUEST + + +@pytest.mark.asyncio +async def test_unload_lora_adapter_not_found(): + serving_engine = await _async_serving_engine_init() + request = UnloadLoraAdapterRequest(lora_name="nonexistent_adapter") + response = await serving_engine.unload_lora_adapter(request) + assert isinstance(response, ErrorResponse) + assert response.type == "InvalidUserInput" + assert response.code == HTTPStatus.BAD_REQUEST diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 728a2e5232d9b..d8704d5e24964 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -35,11 +35,13 @@ DetokenizeResponse, EmbeddingRequest, EmbeddingResponse, ErrorResponse, + LoadLoraAdapterRequest, TokenizeRequest, - TokenizeResponse) -# yapf: enable + TokenizeResponse, + UnloadLoraAdapterRequest) from vllm.entrypoints.openai.rpc.client import AsyncEngineRPCClient from vllm.entrypoints.openai.rpc.server import run_rpc_server +# yapf: enable from vllm.entrypoints.openai.serving_chat import OpenAIServingChat from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding @@ -343,6 +345,40 @@ async def stop_profile(): return Response(status_code=200) +if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING: + logger.warning( + "Lora dynamic loading & unloading is enabled in the API server. " + "This should ONLY be used for local development!") + + @router.post("/v1/load_lora_adapter") + async def load_lora_adapter(request: LoadLoraAdapterRequest): + response = await openai_serving_chat.load_lora_adapter(request) + if isinstance(response, ErrorResponse): + return JSONResponse(content=response.model_dump(), + status_code=response.code) + + response = await openai_serving_completion.load_lora_adapter(request) + if isinstance(response, ErrorResponse): + return JSONResponse(content=response.model_dump(), + status_code=response.code) + + return Response(status_code=200, content=response) + + @router.post("/v1/unload_lora_adapter") + async def unload_lora_adapter(request: UnloadLoraAdapterRequest): + response = await openai_serving_chat.unload_lora_adapter(request) + if isinstance(response, ErrorResponse): + return JSONResponse(content=response.model_dump(), + status_code=response.code) + + response = await openai_serving_completion.unload_lora_adapter(request) + if isinstance(response, ErrorResponse): + return JSONResponse(content=response.model_dump(), + status_code=response.code) + + return Response(status_code=200, content=response) + + def build_app(args: Namespace) -> FastAPI: app = FastAPI(lifespan=lifespan) app.include_router(router) diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index ff9c3690672b6..970262a4bd358 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -878,3 +878,13 @@ class DetokenizeRequest(OpenAIBaseModel): class DetokenizeResponse(OpenAIBaseModel): prompt: str + + +class LoadLoraAdapterRequest(BaseModel): + lora_name: str + lora_path: str + + +class UnloadLoraAdapterRequest(BaseModel): + lora_name: str + lora_int_id: Optional[int] = Field(default=None) diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index 26e91e7cc94dd..ac74527441cd9 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -16,11 +16,13 @@ CompletionRequest, DetokenizeRequest, EmbeddingRequest, ErrorResponse, + LoadLoraAdapterRequest, ModelCard, ModelList, ModelPermission, TokenizeChatRequest, TokenizeCompletionRequest, - TokenizeRequest) + TokenizeRequest, + UnloadLoraAdapterRequest) # yapf: enable from vllm.inputs.parse import parse_and_batch_prompt from vllm.logger import init_logger @@ -32,6 +34,7 @@ from vllm.sampling_params import LogitsProcessor, SamplingParams from vllm.sequence import Logprob from vllm.transformers_utils.tokenizer import AnyTokenizer +from vllm.utils import AtomicCounter logger = init_logger(__name__) @@ -78,6 +81,7 @@ def __init__( self.served_model_names = served_model_names + self.lora_id_counter = AtomicCounter(0) self.lora_requests = [] if lora_modules is not None: self.lora_requests = [ @@ -403,3 +407,76 @@ def _get_decoded_token(logprob: Logprob, if logprob.decoded_token is not None: return logprob.decoded_token return tokenizer.decode(token_id) + + async def _check_load_lora_adapter_request( + self, request: LoadLoraAdapterRequest) -> Optional[ErrorResponse]: + # Check if both 'lora_name' and 'lora_path' are provided + if not request.lora_name or not request.lora_path: + return self.create_error_response( + message="Both 'lora_name' and 'lora_path' must be provided.", + err_type="InvalidUserInput", + status_code=HTTPStatus.BAD_REQUEST) + + # Check if the lora adapter with the given name already exists + if any(lora_request.lora_name == request.lora_name + for lora_request in self.lora_requests): + return self.create_error_response( + message= + f"The lora adapter '{request.lora_name}' has already been" + "loaded.", + err_type="InvalidUserInput", + status_code=HTTPStatus.BAD_REQUEST) + + return None + + async def _check_unload_lora_adapter_request( + self, + request: UnloadLoraAdapterRequest) -> Optional[ErrorResponse]: + # Check if either 'lora_name' or 'lora_int_id' is provided + if not request.lora_name and not request.lora_int_id: + return self.create_error_response( + message= + "either 'lora_name' and 'lora_int_id' needs to be provided.", + err_type="InvalidUserInput", + status_code=HTTPStatus.BAD_REQUEST) + + # Check if the lora adapter with the given name exists + if not any(lora_request.lora_name == request.lora_name + for lora_request in self.lora_requests): + return self.create_error_response( + message= + f"The lora adapter '{request.lora_name}' cannot be found.", + err_type="InvalidUserInput", + status_code=HTTPStatus.BAD_REQUEST) + + return None + + async def load_lora_adapter( + self, + request: LoadLoraAdapterRequest) -> Union[ErrorResponse, str]: + error_check_ret = await self._check_load_lora_adapter_request(request) + if error_check_ret is not None: + return error_check_ret + + lora_name, lora_path = request.lora_name, request.lora_path + unique_id = self.lora_id_counter.inc(1) + self.lora_requests.append( + LoRARequest(lora_name=lora_name, + lora_int_id=unique_id, + lora_path=lora_path)) + return f"Success: LoRA adapter '{lora_name}' added successfully." + + async def unload_lora_adapter( + self, + request: UnloadLoraAdapterRequest) -> Union[ErrorResponse, str]: + error_check_ret = await self._check_unload_lora_adapter_request(request + ) + if error_check_ret is not None: + return error_check_ret + + lora_name = request.lora_name + self.lora_requests = [ + lora_request for lora_request in self.lora_requests + if lora_request.lora_name != lora_name + ] + return f"Success: LoRA adapter '{lora_name}' removed successfully." diff --git a/vllm/envs.py b/vllm/envs.py index 3c6b6adff82fc..ed45047e9f8fc 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -61,6 +61,7 @@ VLLM_ALLOW_ENGINE_USE_RAY: bool = False VLLM_PLUGINS: Optional[List[str]] = None VLLM_TORCH_PROFILER_DIR: Optional[str] = None + VLLM_ALLOW_RUNTIME_LORA_UPDATING: bool = False def get_default_cache_root(): @@ -409,6 +410,12 @@ def get_default_config_root(): # If set, vLLM will use Triton implementations of AWQ. "VLLM_USE_TRITON_AWQ": lambda: bool(int(os.getenv("VLLM_USE_TRITON_AWQ", "0"))), + + # If set, allow loading or unloading lora adapters in runtime, + "VLLM_ALLOW_RUNTIME_LORA_UPDATING": + lambda: + (os.environ.get("VLLM_ALLOW_RUNTIME_LORA_UPDATING", "0").strip().lower() in + ("1", "true")), } # end-env-vars-definition diff --git a/vllm/lora/request.py b/vllm/lora/request.py index d770da4f2407d..47a59d80d3a45 100644 --- a/vllm/lora/request.py +++ b/vllm/lora/request.py @@ -28,7 +28,6 @@ class LoRARequest( lora_path: str = "" lora_local_path: Optional[str] = msgspec.field(default=None) long_lora_max_len: Optional[int] = None - __hash__ = AdapterRequest.__hash__ def __post_init__(self): if 'lora_local_path' in self.__struct_fields__: @@ -75,3 +74,21 @@ def local_path(self, value): DeprecationWarning, stacklevel=2) self.lora_path = value + + def __eq__(self, value: object) -> bool: + """ + Overrides the equality method to compare LoRARequest + instances based on lora_name. This allows for identification + and comparison lora adapter across engines. + """ + return isinstance(value, + self.__class__) and self.lora_name == value.lora_name + + def __hash__(self) -> int: + """ + Overrides the hash method to hash LoRARequest instances + based on lora_name. This ensures that LoRARequest instances + can be used in hash-based collections such as sets and dictionaries, + identified by their names across engines. + """ + return hash(self.lora_name) diff --git a/vllm/utils.py b/vllm/utils.py index 657a3ecef696d..a22081ebe8df0 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -1224,3 +1224,28 @@ async def _run_task_with_lock(task: Callable, lock: asyncio.Lock, *args, def supports_dynamo() -> bool: base_torch_version = Version(Version(torch.__version__).base_version) return base_torch_version >= Version("2.4.0") + + +class AtomicCounter: + """An atomic, thread-safe counter""" + + def __init__(self, initial=0): + """Initialize a new atomic counter to given initial value""" + self._value = initial + self._lock = threading.Lock() + + def inc(self, num=1): + """Atomically increment the counter by num and return the new value""" + with self._lock: + self._value += num + return self._value + + def dec(self, num=1): + """Atomically decrement the counter by num and return the new value""" + with self._lock: + self._value -= num + return self._value + + @property + def value(self): + return self._value From baa5467547a758af35f442af6edfbc0fb73c83ce Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Thu, 5 Sep 2024 20:39:29 -0700 Subject: [PATCH 26/77] [BugFix] Fix Granite model configuration (#8216) --- vllm/transformers_utils/config.py | 62 +++++++++++++-------- vllm/transformers_utils/configs/__init__.py | 4 ++ 2 files changed, 42 insertions(+), 24 deletions(-) diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index dfe83ddb731d4..4f4e79d10a677 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -10,12 +10,16 @@ from vllm.envs import VLLM_USE_MODELSCOPE from vllm.logger import init_logger +# yapf conflicts with isort for this block +# yapf: disable from vllm.transformers_utils.configs import (ChatGLMConfig, DbrxConfig, EAGLEConfig, ExaoneConfig, - InternVLChatConfig, JAISConfig, - MedusaConfig, MLPSpeculatorConfig, - MPTConfig, NemotronConfig, - RWConfig, UltravoxConfig) + GraniteConfig, InternVLChatConfig, + JAISConfig, MedusaConfig, + MLPSpeculatorConfig, MPTConfig, + NemotronConfig, RWConfig, + UltravoxConfig) +# yapf: enable from vllm.transformers_utils.utils import check_gguf_file if VLLM_USE_MODELSCOPE: @@ -39,6 +43,9 @@ "internvl_chat": InternVLChatConfig, "nemotron": NemotronConfig, "ultravox": UltravoxConfig, + # Granite can be removed from here once we have upgraded to + # transformers 4.45+ + "granite": GraniteConfig, } for name, cls in _CONFIG_REGISTRY.items(): @@ -62,29 +69,36 @@ def get_config( kwargs["gguf_file"] = Path(model).name model = Path(model).parent - try: - config = AutoConfig.from_pretrained( - model, - trust_remote_code=trust_remote_code, - revision=revision, - code_revision=code_revision, - **kwargs) - except ValueError as e: - if (not trust_remote_code and - "requires you to execute the configuration file" in str(e)): - err_msg = ( - "Failed to load the model config. If the model is a custom " - "model not yet available in the HuggingFace transformers " - "library, consider setting `trust_remote_code=True` in LLM " - "or using the `--trust-remote-code` flag in the CLI.") - raise RuntimeError(err_msg) from e - else: - raise e - if config.model_type in _CONFIG_REGISTRY: - config_class = _CONFIG_REGISTRY[config.model_type] + config_dict, _ = PretrainedConfig.get_config_dict( + model, revision=revision, code_revision=code_revision, **kwargs) + + # Use custom model class if it's in our registry + model_type = config_dict.get("model_type") + if model_type in _CONFIG_REGISTRY: + config_class = _CONFIG_REGISTRY[model_type] config = config_class.from_pretrained(model, revision=revision, code_revision=code_revision) + else: + try: + config = AutoConfig.from_pretrained( + model, + trust_remote_code=trust_remote_code, + revision=revision, + code_revision=code_revision, + **kwargs) + except ValueError as e: + if (not trust_remote_code + and "requires you to execute the configuration file" + in str(e)): + err_msg = ( + "Failed to load the model config. If the model is a custom " + "model not yet available in the HuggingFace transformers " + "library, consider setting `trust_remote_code=True` in LLM " + "or using the `--trust-remote-code` flag in the CLI.") + raise RuntimeError(err_msg) from e + else: + raise e # Special architecture mapping check for GGUF models if is_gguf: diff --git a/vllm/transformers_utils/configs/__init__.py b/vllm/transformers_utils/configs/__init__.py index 736878b35ad49..8381c5227584e 100644 --- a/vllm/transformers_utils/configs/__init__.py +++ b/vllm/transformers_utils/configs/__init__.py @@ -6,6 +6,7 @@ # tiiuae/falcon-7b(-instruct) models. Newer Falcon models will use the # `FalconConfig` class from the official HuggingFace transformers library. from vllm.transformers_utils.configs.falcon import RWConfig +from vllm.transformers_utils.configs.granite import GraniteConfig from vllm.transformers_utils.configs.internvl import InternVLChatConfig from vllm.transformers_utils.configs.jais import JAISConfig from vllm.transformers_utils.configs.medusa import MedusaConfig @@ -27,4 +28,7 @@ "MLPSpeculatorConfig", "NemotronConfig", "UltravoxConfig", + # Granite can be removed from here once we have upgraded to + # transformers 4.45+ + "GraniteConfig", ] From b841ac498f3441e0532a456c1872b24051072696 Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Fri, 6 Sep 2024 03:12:10 -0400 Subject: [PATCH 27/77] remove 8-bit stuff for now --- csrc/moe/marlin_moe_ops.cu | 303 ++++++------------ csrc/moe/marlin_moe_ops.h | 7 +- csrc/moe/torch_bindings.cpp | 8 +- tests/kernels/test_moe.py | 14 +- vllm/_custom_ops.py | 2 +- .../layers/fused_moe/__init__.py | 8 +- .../layers/fused_moe/fused_marlin_moe.py | 52 +-- .../compressed_tensors_moe.py | 1 - .../schemes/compressed_tensors_wNa16.py | 1 - .../layers/quantization/gptq_marlin.py | 1 - 10 files changed, 120 insertions(+), 277 deletions(-) diff --git a/csrc/moe/marlin_moe_ops.cu b/csrc/moe/marlin_moe_ops.cu index f6d475a56851f..92184f43c9eb0 100644 --- a/csrc/moe/marlin_moe_ops.cu +++ b/csrc/moe/marlin_moe_ops.cu @@ -25,8 +25,6 @@ #include -#include "core/scalar_type.hpp" - template inline std::string str(T x) { return std::to_string(x); @@ -133,26 +131,11 @@ __device__ inline int lop3(int a, int b, int c) { return res; } -// Constructs destination register by taking bytes from 2 sources (based on -// mask) -template -__device__ inline uint32_t prmt(uint32_t a) { - uint32_t res; - asm volatile("prmt.b32 %0, %1, %2, %3;\n" - : "=r"(res) - : "r"(a), "n"(start_byte), "n"(mask)); - return res; -} - -template -__device__ inline FragB dequant(int q); - -// Efficiently dequantize 4bit values packed in an int32 value into a full -// B-fragment of 4 fp16 values. We mostly follow the strategy in the link below, -// with some small changes: -// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L215-L287 -template <> -__device__ inline FragB dequant(int q) { +// Efficiently dequantize an int32 value into a full B-fragment of 4 fp16 +// values. We mostly follow the strategy in the link below, with some small +// changes: +// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h +__device__ inline FragB dequant(int q) { const int LO = 0x000f000f; const int HI = 0x00f000f0; const int EX = 0x64006400; @@ -173,28 +156,6 @@ __device__ inline FragB dequant(int q) { return frag_b; } -// Fast Int8ToFp16: Efficiently dequantize 8bit int values to fp16 -// Reference: -// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L53-L85 -template <> -__device__ inline FragB dequant(int q) { - static constexpr uint32_t mask_for_elt_01 = 0x5250; - static constexpr uint32_t mask_for_elt_23 = 0x5351; - static constexpr uint32_t start_byte_for_fp16 = 0x64646464; - - uint32_t lo = prmt(q); - uint32_t hi = prmt(q); - - static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480; - - FragB frag_b; - frag_b[0] = __hsub2(*reinterpret_cast(&lo), - *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); - frag_b[1] = __hsub2(*reinterpret_cast(&hi), - *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); - return frag_b; -} - // Multiply dequantized values by the corresponding quantization scale; used // only for grouped quantization. __device__ inline void scale(FragB& frag_b, FragS& frag_s, int i) { @@ -335,8 +296,7 @@ __global__ void compute_expert_offsets(int const* __restrict__ topk_ids, __syncthreads(); } -template ( - &sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd + i]); - } + frag_b_quant[k % 2] = *reinterpret_cast( + &sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd]); }; bool is_same_group[stages]; @@ -893,19 +840,10 @@ __device__ inline void MarlinMoESingle( // dequantization and matmul operations. #pragma unroll for (int j = 0; j < 4; j++) { - int b_quant_0, b_quant_1; - if constexpr (w_type.size_bits() == 4) { - b_quant_0 = frag_b_quant[k % 2][0][j]; - b_quant_1 = b_quant_0 >> 8; - } else { - static_assert(w_type.size_bits() == 8); - int* frag_b_quant_ptr = reinterpret_cast(frag_b_quant[k % 2]); - b_quant_0 = frag_b_quant_ptr[j * 2 + 0]; - b_quant_1 = frag_b_quant_ptr[j * 2 + 1]; - } + int b_quant = frag_b_quant[k % 2][j]; + int b_quant_shift = b_quant >> 8; - FragB frag_b0 = dequant(b_quant_0); - FragB frag_b1 = dequant(b_quant_1); + FragB frag_b0 = dequant(b_quant); // Apply scale to frag_b0 if constexpr (has_act_order) { @@ -917,6 +855,8 @@ __device__ inline void MarlinMoESingle( } } + FragB frag_b1 = dequant(b_quant_shift); + // Apply scale to frag_b1 if constexpr (has_act_order) { scale4(frag_b1, act_frag_s[k % 2][0][j], act_frag_s[k % 2][1][j], @@ -941,13 +881,13 @@ __device__ inline void MarlinMoESingle( // multiple warps that accumulate their partial sums of the same output // location; which we have to reduce over in the end. We do in shared memory. auto thread_block_reduce = [&]() { - constexpr int red_off = threads / b_sh_stride_threads / 2; + constexpr int red_off = threads / b_sh_stride / 2; if (red_off >= 1) { - int red_idx = threadIdx.x / b_sh_stride_threads; - constexpr int red_sh_stride = b_sh_stride_threads * 4 * 2; - constexpr int red_sh_delta = b_sh_stride_threads; - int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride_threads) + - (threadIdx.x % b_sh_stride_threads); + int red_idx = threadIdx.x / b_sh_stride; + constexpr int red_sh_stride = b_sh_stride * 4 * 2; + constexpr int red_sh_delta = b_sh_stride; + int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride) + + (threadIdx.x % b_sh_stride); // Parallel logarithmic shared memory reduction. We make sure to avoid any // unnecessary read or write iterations, e.g., for two warps we write only @@ -1095,10 +1035,8 @@ __device__ inline void MarlinMoESingle( auto write = [&](int idx, float c0, float c1, FragS& s) { half2 res = __halves2half2(__float2half(c0), __float2half(c1)); - // For per-column quantization we finally apply the scale here (only for - // 4-bit) - if constexpr (!has_act_order && group_blocks == -1 && - w_type.size_bits() == 4) { + // For per-column quantization we finally apply the scale here + if constexpr (!has_act_order && group_blocks == -1) { res = __hmul2(res, s[0]); } @@ -1228,70 +1166,28 @@ __device__ inline void MarlinMoESingle( if (slice_iters == 0) { cp_async_wait<0>(); bool last = slice_idx == slice_count - 1; + // For per-column scales, we only fetch them here in the final step before + // write-out if constexpr (!has_act_order && group_blocks == -1) { - if constexpr (w_type.size_bits() == 8) { + if (last) { if (s_sh_wr_pred) { cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]); } cp_async_fence(); - } else { - // For 4-bit per-column scales, we only fetch them here in the - // final step before write-out - if (last) { - if (s_sh_wr_pred) { - cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]); - } - cp_async_fence(); - } } } thread_block_reduce(); if constexpr (!has_act_order && group_blocks == -1) { - if constexpr (w_type.size_bits() == 8) { + if (last) { cp_async_wait<0>(); __syncthreads(); if (threadIdx.x / 32 < thread_n_blocks / 4) { reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd + 0]; reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; } - - } else { - if (last) { - cp_async_wait<0>(); - __syncthreads(); - if (threadIdx.x / 32 < thread_n_blocks / 4) { - reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd + 0]; - reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; - } - } } } - - // For 8-bit channelwise, we apply the scale before the global reduction - // that converts the fp32 results to fp16 (so that we avoid possible - // overflow in fp16) - if constexpr (!has_act_order && group_blocks == -1 && - w_type.size_bits() == 8) { - if (threadIdx.x / 32 < thread_n_blocks / 4) { - #pragma unroll - for (int i = 0; i < thread_m_blocks; i++) { - #pragma unroll - for (int j = 0; j < 4; j++) { - scale_float(reinterpret_cast(&frag_c[i][j][0][0]), - frag_s[j / 2][2 * (j % 2) + 0]); - scale_float(reinterpret_cast(&frag_c[i][j][0][2]), - frag_s[j / 2][2 * (j % 2) + 0]); - - scale_float(reinterpret_cast(&frag_c[i][j][1][0]), - frag_s[j / 2][2 * (j % 2) + 1]); - scale_float(reinterpret_cast(&frag_c[i][j][1][2]), - frag_s[j / 2][2 * (j % 2) + 1]); - } - } - } - } - if (slice_count > 1) { // only globally reduce if there is more than one // block in a slice barrier_acquire(&locks[slice_col], slice_idx); @@ -1331,8 +1227,7 @@ __device__ inline void MarlinMoESingle( } } -template ( + MarlinMoESingle( A, B, C, sorted_ids_expert, topk_weights, scales_ptr, g_idx, expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m, prob_n, prob_k, tot_m, locks, replicate_input, apply_weights, current_m_block); } else if (max_block == 2) { - MarlinMoESingle( + MarlinMoESingle( A, B, C, sorted_ids_expert, topk_weights, scales_ptr, g_idx, expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m, prob_n, prob_k, tot_m, locks, replicate_input, apply_weights, current_m_block); } else if (max_block == 3) { - MarlinMoESingle( + MarlinMoESingle( A, B, C, sorted_ids_expert, topk_weights, scales_ptr, g_idx, expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m, prob_n, prob_k, tot_m, locks, replicate_input, apply_weights, current_m_block); } else { - MarlinMoESingle( + MarlinMoESingle( A, B, C, sorted_ids_expert, topk_weights, scales_ptr, g_idx, expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m, prob_n, prob_k, tot_m, locks, replicate_input, apply_weights, @@ -1447,8 +1342,7 @@ __global__ void compute_expert_offsets(int const* __restrict__ topk_ids, return; } -template , \ cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \ - MarlinMoE \ + MarlinMoE \ <<>>( \ A_ptr, B_ptr, C_ptr, sorted_ids_ptr, topk_weights_ptr, s_ptr, \ g_idx_ptr, expert_offsets_ptr, num_groups, expert_idx, \ @@ -1601,43 +1494,42 @@ thread_config_t determine_thread_config(int prob_m, int prob_n, int prob_k) { return thread_config_t{-1, -1, -1}; } -#define CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, 2, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, 3, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, 4, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ - \ - __CALL_IF_MOE(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \ - \ - __CALL_IF_MOE(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \ - \ - __CALL_IF_MOE(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \ - \ - __CALL_IF_MOE(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) +#define CALL_IF_MOE(N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + __CALL_IF_MOE(1, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ + __CALL_IF_MOE(2, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ + __CALL_IF_MOE(3, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ + __CALL_IF_MOE(4, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ + \ + __CALL_IF_MOE(1, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ + __CALL_IF_MOE(1, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ + __CALL_IF_MOE(1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ + __CALL_IF_MOE(1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \ + \ + __CALL_IF_MOE(2, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ + __CALL_IF_MOE(2, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ + __CALL_IF_MOE(2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ + __CALL_IF_MOE(2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \ + \ + __CALL_IF_MOE(3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ + __CALL_IF_MOE(3, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ + __CALL_IF_MOE(3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ + __CALL_IF_MOE(3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \ + \ + __CALL_IF_MOE(4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ + __CALL_IF_MOE(4, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ + __CALL_IF_MOE(4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ + __CALL_IF_MOE(4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) void marlin_mm_moe_f16i4(const void* A, const void* B, void* C, const void* sorted_ids, const void* topk_weights, const void* topk_ids, const void* s, const void* g_idx, const void* perm, void* a_tmp, void* expert_offsets, int prob_m, int prob_n, int prob_k, void* workspace, - vllm::ScalarType const& q_type, bool has_act_order, - bool is_k_full, int num_groups, int group_size, - int num_experts, int topk, int moe_block_size, int dev, - cudaStream_t stream, int thread_k, int thread_n, - int sms, int max_par, bool replicate_input, - bool apply_weights) { + bool has_act_order, bool is_k_full, int num_groups, + int group_size, int num_experts, int topk, + int moe_block_size, int dev, cudaStream_t stream, + int thread_k, int thread_n, int sms, int max_par, + bool replicate_input, bool apply_weights) { TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m, ", ", prob_n, ", ", prob_k, "]"); @@ -1719,13 +1611,10 @@ void marlin_mm_moe_f16i4(const void* A, const void* B, void* C, has_act_order = false; } - int pack_factor = 32 / q_type.size_bits(); - for (int expert_idx = 0; expert_idx < num_experts; ++expert_idx) { const int4* A_ptr = (const int4*)A; int4* a_tmp_ptr = (int4*)a_tmp; - const int4* B_ptr = - (const int4*)B + (prob_n * prob_k / (pack_factor * 4)) * expert_idx; + const int4* B_ptr = (const int4*)B + (prob_n * prob_k / 32) * expert_idx; int4* C_ptr = (int4*)C; const float* topk_weights_ptr = (const float*)topk_weights; const int* sorted_ids_ptr = (const int*)sorted_ids; @@ -1756,14 +1645,10 @@ void marlin_mm_moe_f16i4(const void* A, const void* B, void* C, if (false) { } - CALL_IF_MOE(vllm::kU4B8, 16, 4, 256) - CALL_IF_MOE(vllm::kU4B8, 8, 8, 256) - CALL_IF_MOE(vllm::kU4B8, 8, 4, 128) - CALL_IF_MOE(vllm::kU4B8, 4, 8, 128) - CALL_IF_MOE(vllm::kU8B128, 16, 4, 256) - CALL_IF_MOE(vllm::kU8B128, 8, 8, 256) - CALL_IF_MOE(vllm::kU8B128, 8, 4, 128) - CALL_IF_MOE(vllm::kU8B128, 4, 8, 128) + CALL_IF_MOE(16, 4, 256) + CALL_IF_MOE(8, 8, 256) + CALL_IF_MOE(8, 4, 128) + CALL_IF_MOE(4, 8, 128) else { TORCH_CHECK(false, "Unsupported shapes: MNK = [" + str(prob_m) + ", " + str(prob_n) + ", " + str(prob_k) + "]" + @@ -1785,15 +1670,9 @@ torch::Tensor marlin_gemm_moe( const torch::Tensor& sorted_ids, const torch::Tensor& topk_weights, const torch::Tensor& topk_ids, const torch::Tensor& b_scales, const torch::Tensor& g_idx, const torch::Tensor& perm, - torch::Tensor& workspace, vllm::ScalarTypeTorchPtr const& b_q_type, - int64_t size_m, int64_t size_n, int64_t size_k, bool is_k_full, - int64_t num_experts, int64_t topk, int64_t moe_block_size, + torch::Tensor& workspace, int64_t size_m, int64_t size_n, int64_t size_k, + bool is_k_full, int64_t num_experts, int64_t topk, int64_t moe_block_size, bool replicate_input, bool apply_weights) { - TORCH_CHECK(*b_q_type == vllm::kU4B8 || *b_q_type == vllm::kU8B128, - "b_q_type must be uint4b8 or uint8b128. Got = ", b_q_type->str()); - - int pack_factor = 32 / b_q_type->size_bits(); - int max_par = 4; int dev = a.get_device(); @@ -1854,8 +1733,8 @@ torch::Tensor marlin_gemm_moe( topk_weights.data_ptr(), topk_ids.data_ptr(), b_scales.data_ptr(), g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr(), expert_offsets.data_ptr(), size_m, size_n, size_k, workspace.data_ptr(), - *b_q_type, has_act_order, is_k_full, num_groups, group_size, num_experts, - topk, moe_block_size, dev, at::cuda::getCurrentCUDAStream(dev), thread_k, + has_act_order, is_k_full, num_groups, group_size, num_experts, topk, + moe_block_size, dev, at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms, max_par, replicate_input, apply_weights); return c; } diff --git a/csrc/moe/marlin_moe_ops.h b/csrc/moe/marlin_moe_ops.h index adee8399a4d6f..43d264e0770d6 100644 --- a/csrc/moe/marlin_moe_ops.h +++ b/csrc/moe/marlin_moe_ops.h @@ -2,14 +2,11 @@ #include -#include "core/scalar_type.hpp" - torch::Tensor marlin_gemm_moe( const torch::Tensor& a, const torch::Tensor& b_q_weights, const torch::Tensor& sorted_ids, const torch::Tensor& topk_weights, const torch::Tensor& topk_ids, const torch::Tensor& b_scales, const torch::Tensor& g_idx, const torch::Tensor& perm, - torch::Tensor& workspace, vllm::ScalarTypeTorchPtr const& b_q_type, - int64_t size_m, int64_t size_n, int64_t size_k, bool is_k_full, - int64_t num_experts, int64_t topk, int64_t moe_block_size, + torch::Tensor& workspace, int64_t size_m, int64_t size_n, int64_t size_k, + bool is_k_full, int64_t num_experts, int64_t topk, int64_t moe_block_size, bool replicate_input, bool apply_weights); diff --git a/csrc/moe/torch_bindings.cpp b/csrc/moe/torch_bindings.cpp index cd65a8ee92b94..8a0e625b43fa1 100644 --- a/csrc/moe/torch_bindings.cpp +++ b/csrc/moe/torch_bindings.cpp @@ -13,11 +13,9 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { m.def( "marlin_gemm_moe(Tensor! a, Tensor! b_q_weights, Tensor! sorted_ids, " "Tensor! topk_weights, Tensor! topk_ids, Tensor! b_scales, Tensor! " - "g_idx, Tensor! perm, Tensor! workspace, " - "__torch__.torch.classes._core_C.ScalarType b_q_type, int size_m, " - "int size_n, int size_k, bool is_k_full, int num_experts, int topk, " - "int moe_block_size, bool replicate_input, bool apply_weights)" - " -> Tensor"); + "g_idx, Tensor! perm, Tensor! workspace, int size_m, int size_n, int " + "size_k, bool is_k_full, int num_experts, int topk, int moe_block_size, " + "bool replicate_input, bool apply_weights) -> Tensor"); m.impl("marlin_gemm_moe", torch::kCUDA, &marlin_gemm_moe); #endif } diff --git a/tests/kernels/test_moe.py b/tests/kernels/test_moe.py index 7e359ff08088c..2250cf1598b8b 100644 --- a/tests/kernels/test_moe.py +++ b/tests/kernels/test_moe.py @@ -140,7 +140,6 @@ def compute_max_diff(output, output_ref): @pytest.mark.parametrize("topk", [2, 6]) @pytest.mark.parametrize("group_size", [-1, 32, 64, 128]) @pytest.mark.parametrize("act_order", [True, False]) -@pytest.mark.parametrize("num_bits", [4, 8]) def test_fused_marlin_moe( m: int, n: int, @@ -149,7 +148,6 @@ def test_fused_marlin_moe( topk: int, group_size: int, act_order: bool, - num_bits: int, ): torch.manual_seed(7) @@ -163,8 +161,7 @@ def test_fused_marlin_moe( if group_size in (k, n): return - quant_type = (scalar_types.uint4b8 - if num_bits == 4 else scalar_types.uint8b128) + quant_type = scalar_types.uint4b8 dtype = torch.float16 a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 @@ -243,7 +240,6 @@ def test_fused_marlin_moe( topk_ids, w1_scale=scales1, w2_scale=scales2, - num_bits=num_bits, ) assert compute_max_diff(marlin_output, triton_output) < 4e-2 @@ -258,7 +254,6 @@ def test_fused_marlin_moe( @pytest.mark.parametrize("topk", [2, 6]) @pytest.mark.parametrize("group_size", [-1, 32, 64, 128]) @pytest.mark.parametrize("act_order", [True, False]) -@pytest.mark.parametrize("num_bits", [4, 8]) def test_marlin_moe_mmm( m: int, n: int, @@ -267,7 +262,6 @@ def test_marlin_moe_mmm( topk: int, group_size: int, act_order: bool, - num_bits: int, ): if topk > e: return @@ -279,8 +273,7 @@ def test_marlin_moe_mmm( if group_size == k: return - quant_type = (scalar_types.uint4b8 - if num_bits == 4 else scalar_types.uint8b128) + quant_type = scalar_types.uint4b8 dtype = torch.float16 a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 w = torch.randn((e, n, k), device="cuda", dtype=dtype) / 10 @@ -315,8 +308,7 @@ def test_marlin_moe_mmm( g_idx, sort_indices, topk, - renormalize=False, - num_bits=num_bits) + renormalize=False) torch_output = torch_moe_single(a, w_ref.transpose(1, 2), score, topk) assert compute_max_diff(marlin_output, torch_output) < 1e-2 diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 51db8b34e2914..fe254732e7309 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -314,7 +314,7 @@ def gptq_marlin_moe_repack(b_q_weight: torch.Tensor, perm: torch.Tensor, num_bits: int) -> torch.Tensor: num_experts = b_q_weight.shape[0] assert size_k % 16 == 0 - output = torch.empty((num_experts, size_k // 16, size_n * (num_bits // 2)), + output = torch.empty((num_experts, size_k // 16, size_n * 2), device=b_q_weight.device, dtype=b_q_weight.dtype) for e in range(num_experts): diff --git a/vllm/model_executor/layers/fused_moe/__init__.py b/vllm/model_executor/layers/fused_moe/__init__.py index dea4a32aec4f8..e9b5703ca28be 100644 --- a/vllm/model_executor/layers/fused_moe/__init__.py +++ b/vllm/model_executor/layers/fused_moe/__init__.py @@ -1,5 +1,3 @@ -from vllm.model_executor.layers.fused_moe.fused_marlin_moe import ( - fused_marlin_moe, single_marlin_moe) from vllm.model_executor.layers.fused_moe.layer import ( FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported) from vllm.triton_utils import HAS_TRITON @@ -8,16 +6,18 @@ "FusedMoE", "FusedMoEMethodBase", "FusedMoeWeightScaleSupported", - "fused_marlin_moe", - "single_marlin_moe", ] if HAS_TRITON: + from vllm.model_executor.layers.fused_moe.fused_marlin_moe import ( + fused_marlin_moe, single_marlin_moe) from vllm.model_executor.layers.fused_moe.fused_moe import ( fused_experts, fused_moe, fused_topk, get_config_file_name, grouped_topk) __all__ += [ + "fused_marlin_moe", + "single_marlin_moe", "fused_moe", "fused_topk", "fused_experts", diff --git a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py index c7906205760ff..6b01ec0a623aa 100644 --- a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py @@ -7,21 +7,18 @@ from vllm import _custom_ops as ops from vllm.model_executor.layers.fused_moe.fused_moe import ( fused_topk, moe_align_block_size, try_get_optimal_moe_config) -from vllm.scalar_type import scalar_types def single_marlin_moe( - hidden_states: torch.Tensor, - w: torch.Tensor, - scales: torch.Tensor, - gating_output: torch.Tensor, - g_idx: torch.Tensor, - perm: torch.Tensor, - topk: int, - renormalize: bool, - override_config: Optional[Dict[str, Any]] = None, - num_bits: int = 8, -) -> torch.Tensor: + hidden_states: torch.Tensor, + w: torch.Tensor, + scales: torch.Tensor, + gating_output: torch.Tensor, + g_idx: torch.Tensor, + perm: torch.Tensor, + topk: int, + renormalize: bool, + override_config: Optional[Dict[str, Any]] = None) -> torch.Tensor: """ This function computes a Marlin MoE MMM using weights w and top-k gating mechanism. It is meant for testing and debugging. @@ -38,7 +35,6 @@ def single_marlin_moe( - renormalize (bool): If True, renormalize the top-k weights to sum to 1. - override_config (Optional[Dict[str, Any]]): Optional override for the kernel configuration. - - num_bits (bool): The number of bits in expert weights quantization. Returns: - torch.Tensor: The output tensor after applying the MoE layer. @@ -50,14 +46,11 @@ def single_marlin_moe( assert gating_output.shape[1] == w.shape[0], "Number of experts mismatch" assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" assert w.is_contiguous(), "Expert weights must be contiguous" - assert hidden_states.dtype in [ - torch.float32, torch.float16, torch.bfloat16 - ] - assert num_bits in [4, 8] + assert hidden_states.dtype == torch.float16 M, K = hidden_states.shape E = w.shape[0] - N = w.shape[2] // (num_bits // 2) + N = w.shape[2] // 2 topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk, renormalize) @@ -82,13 +75,10 @@ def single_marlin_moe( device="cuda", requires_grad=False) - scalar_type = (scalar_types.uint4b8 - if num_bits == 4 else scalar_types.uint8b128) - intermediate_cache = torch.ops._moe_C.marlin_gemm_moe( hidden_states, w, sorted_token_ids, topk_weights, topk_ids, scales, - g_idx, perm, workspace, scalar_type, M, N, K, True, E, topk, - block_size_m, True, False) + g_idx, perm, workspace, M, N, K, True, E, topk, block_size_m, True, + False) return torch.sum(intermediate_cache.view(*intermediate_cache.shape), dim=1) @@ -107,7 +97,6 @@ def fused_marlin_moe( override_config: Optional[Dict[str, Any]] = None, w1_scale: Optional[torch.Tensor] = None, w2_scale: Optional[torch.Tensor] = None, - num_bits: int = 8, ) -> torch.Tensor: """ This function computes a Mixture of Experts (MoE) layer using two sets of @@ -132,7 +121,6 @@ def fused_marlin_moe( w1. - w2_scale (Optional[torch.Tensor]): Optional scale to be used for w2. - - num_bits (bool): The number of bits in expert weights quantization. Returns: - torch.Tensor: The output tensor after applying the MoE layer. @@ -142,16 +130,13 @@ def fused_marlin_moe( 0], "Number of tokens mismatch" assert hidden_states.shape[ 1] == w1.shape[1] * 16, "Hidden size mismatch w1" - assert hidden_states.shape[1] == w2.shape[2] // ( - num_bits // 2), "Hidden size mismatch w2" + assert hidden_states.shape[ + 1] == w2.shape[2] // 2, "Hidden size mismatch w2" assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch" assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" assert w1.is_contiguous(), "Expert weights1 must be contiguous" assert w2.is_contiguous(), "Expert weights2 must be contiguous" - assert hidden_states.dtype in [ - torch.float32, torch.float16, torch.bfloat16 - ] - assert num_bits in [4, 8] + assert hidden_states.dtype == torch.float16 M, K = hidden_states.shape E = w1.shape[0] @@ -179,9 +164,6 @@ def fused_marlin_moe( device="cuda", requires_grad=False) - scalar_type = (scalar_types.uint4b8 - if num_bits == 4 else scalar_types.uint8b128) - intermediate_cache2 = torch.empty( (M * topk_ids.shape[1], N), device=hidden_states.device, @@ -198,7 +180,6 @@ def fused_marlin_moe( g_idx1, perm1, workspace, - scalar_type, M, 2 * N, K, @@ -222,7 +203,6 @@ def fused_marlin_moe( g_idx2, perm2, workspace, - scalar_type, M, K, N, diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index 7dee2fca81153..f8a41dfd08d73 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -295,5 +295,4 @@ def apply( topk_ids, w1_scale=layer.w13_weight_scale, w2_scale=layer.w2_weight_scale, - num_bits=self.num_bits, ) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py index 7ca8eecb9283e..e3b74e8712903 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py @@ -18,7 +18,6 @@ __all__ = ["CompressedTensorsWNA16"] WNA16_SUPPORTED_TYPES_MAP = { 4: scalar_types.uint4b8, - 8: scalar_types.uint8b128, } WNA16_SUPPORTED_BITS = list(WNA16_SUPPORTED_TYPES_MAP.keys()) diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py index 11012a326b045..d114d52812849 100644 --- a/vllm/model_executor/layers/quantization/gptq_marlin.py +++ b/vllm/model_executor/layers/quantization/gptq_marlin.py @@ -600,5 +600,4 @@ def apply( topk_ids, w1_scale=layer.w13_scales, w2_scale=layer.w2_scales, - num_bits=self.quant_config.quant_type.size_bits, ) From 9d8a80cc9c07a4361279c3f890bbfcea65c33df7 Mon Sep 17 00:00:00 2001 From: Dipika Date: Fri, 6 Sep 2024 15:13:54 +0000 Subject: [PATCH 28/77] fix; update large model testing cases --- .buildkite/test-pipeline.yaml | 13 ++++++++++++- tests/weight_loading/models-large.txt | 3 +++ tests/weight_loading/models.txt | 2 -- .../compressed_tensors/compressed_tensors_moe.py | 7 ++----- .../schemes/compressed_tensors_wNa16.py | 1 + 5 files changed, 18 insertions(+), 8 deletions(-) create mode 100644 tests/weight_loading/models-large.txt diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 86eddb576c42a..900dc72e74466 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -375,7 +375,18 @@ steps: - vllm/ - tests/weight_loading commands: - - bash weight_loading/run_model_weight_loading_test.sh + - bash weight_loading/run_model_weight_loading_test.sh -c weight_loading/models.txt + +- label: Weight Loading Multiple GPU Test - Large Models # optional + working_dir: "/vllm-workspace/tests" + num_gpus: 2 + gpu: a100 + optional: true + source_file_dependencies: + - vllm/ + - tests/weight_loading + commands: + - bash weight_loading/run_model_weight_loading_test.sh -c weight_loading/models-large.txt ##### multi gpus test ##### diff --git a/tests/weight_loading/models-large.txt b/tests/weight_loading/models-large.txt new file mode 100644 index 0000000000000..fe76705746766 --- /dev/null +++ b/tests/weight_loading/models-large.txt @@ -0,0 +1,3 @@ +compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W4A16-quantized, main +compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W4A16-channel-quantized, main +gptq_marlin, TheBloke/Mixtral-8x7B-v0.1-GPTQ, main \ No newline at end of file diff --git a/tests/weight_loading/models.txt b/tests/weight_loading/models.txt index 1dc529037a98e..a3e382acf56b3 100644 --- a/tests/weight_loading/models.txt +++ b/tests/weight_loading/models.txt @@ -19,8 +19,6 @@ compressed-tensors, nm-testing/tinyllama-oneshot-w8a16-per-channel, main compressed-tensors, nm-testing/Meta-Llama-3-8B-FP8-compressed-tensors-test, main compressed-tensors, nm-testing/Phi-3-mini-128k-instruct-FP8, main compressed-tensors, neuralmagic/Phi-3-medium-128k-instruct-quantized.w4a16, main -compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W4A16-quantized, main -compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W4A16-channel-quantized, main awq, casperhansen/mixtral-instruct-awq, main awq_marlin, casperhansen/mixtral-instruct-awq, main fp8, neuralmagic/Meta-Llama-3-8B-Instruct-FP8-KV, main diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index f8a41dfd08d73..49c29c2775cb6 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -6,8 +6,6 @@ from vllm import _custom_ops as ops from vllm.model_executor.layers.fused_moe import FusedMoE, FusedMoEMethodBase -from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( - WNA16_SUPPORTED_BITS) from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( CompressionFormat) from vllm.model_executor.utils import set_weight_attrs @@ -40,11 +38,10 @@ def __init__( if not (self.quant_config.quant_format == CompressionFormat.pack_quantized.value - and self.num_bits in WNA16_SUPPORTED_BITS): + and self.num_bits == 4): raise ValueError("For Fused MoE layers, only ", f"{CompressionFormat.pack_quantized.value} ", - "is supported for the following bits: ", - f"{WNA16_SUPPORTED_BITS}") + "is supported for 4 bits") def create_weights(self, layer: torch.nn.Module, num_experts: int, hidden_size: int, intermediate_size: int, diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py index e3b74e8712903..cae6ffad53df1 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py @@ -18,6 +18,7 @@ __all__ = ["CompressedTensorsWNA16"] WNA16_SUPPORTED_TYPES_MAP = { 4: scalar_types.uint4b8, + 8: scalar_types.uint8b128 } WNA16_SUPPORTED_BITS = list(WNA16_SUPPORTED_TYPES_MAP.keys()) From e5cab71531360345e5b30b98dfcfec8087d6cddf Mon Sep 17 00:00:00 2001 From: afeldman-nm <156691304+afeldman-nm@users.noreply.github.com> Date: Fri, 6 Sep 2024 12:01:14 -0400 Subject: [PATCH 29/77] [Frontend] Add --logprobs argument to `benchmark_serving.py` (#8191) --- benchmarks/backend_request_func.py | 2 ++ benchmarks/benchmark_serving.py | 16 ++++++++++++++++ tests/multi_step/test_correctness_llm.py | 2 +- 3 files changed, 19 insertions(+), 1 deletion(-) diff --git a/benchmarks/backend_request_func.py b/benchmarks/backend_request_func.py index f7d67692f697b..3243bb94f787c 100644 --- a/benchmarks/backend_request_func.py +++ b/benchmarks/backend_request_func.py @@ -24,6 +24,7 @@ class RequestFuncInput: model: str best_of: int = 1 use_beam_search: bool = False + logprobs: Optional[int] = None @dataclass @@ -236,6 +237,7 @@ async def async_request_openai_completions( "temperature": 0.0, "best_of": request_func_input.best_of, "max_tokens": request_func_input.output_len, + "logprobs": request_func_input.logprobs, "stream": True, } headers = { diff --git a/benchmarks/benchmark_serving.py b/benchmarks/benchmark_serving.py index 84f366bdba387..bdfa81be4208e 100644 --- a/benchmarks/benchmark_serving.py +++ b/benchmarks/benchmark_serving.py @@ -318,6 +318,7 @@ async def benchmark( model_id: str, tokenizer: PreTrainedTokenizerBase, input_requests: List[Tuple[str, int, int]], + logprobs: Optional[int], best_of: int, use_beam_search: bool, request_rate: float, @@ -339,6 +340,7 @@ async def benchmark( api_url=api_url, prompt_len=test_prompt_len, output_len=test_output_len, + logprobs=logprobs, best_of=best_of, use_beam_search=use_beam_search, ) @@ -358,6 +360,7 @@ async def benchmark( api_url=base_url + "/start_profile", prompt_len=test_prompt_len, output_len=test_output_len, + logprobs=logprobs, best_of=best_of, use_beam_search=use_beam_search, ) @@ -379,6 +382,7 @@ async def benchmark( api_url=api_url, prompt_len=prompt_len, output_len=output_len, + logprobs=logprobs, best_of=best_of, use_beam_search=use_beam_search, ) @@ -396,6 +400,7 @@ async def benchmark( api_url=base_url + "/stop_profile", prompt_len=test_prompt_len, output_len=test_output_len, + logprobs=logprobs, best_of=best_of, use_beam_search=use_beam_search, ) @@ -580,6 +585,7 @@ def main(args: argparse.Namespace): model_id=model_id, tokenizer=tokenizer, input_requests=input_requests, + logprobs=args.logprobs, best_of=args.best_of, use_beam_search=args.use_beam_search, request_rate=args.request_rate, @@ -721,6 +727,16 @@ def main(args: argparse.Namespace): help= "Number of output tokens per request, used only for sonnet dataset.", ) + parser.add_argument( + "--logprobs", + type=int, + default=None, + help=("Number of logprobs-per-token to compute & return as part of " + "the request. If unspecified, then either (1) if beam search " + "is disabled, no logprobs are computed & a single dummy " + "logprob is returned for each token; or (2) if beam search " + "is enabled 1 logprob per token is computed"), + ) parser.add_argument( "--sonnet-prefix-len", type=int, diff --git a/tests/multi_step/test_correctness_llm.py b/tests/multi_step/test_correctness_llm.py index 50c85df932e25..24ebb60a9cbfd 100644 --- a/tests/multi_step/test_correctness_llm.py +++ b/tests/multi_step/test_correctness_llm.py @@ -57,7 +57,7 @@ def test_multi_step_llm( GPU -> CPU output transfer num_prompts: number of example prompts under test num_logprobs: corresponds to the `logprobs` argument to the OpenAI - completions endpoint; `None` -> no logprobs + completions endpoint; `None` -> 1 logprob returned. """ prompts = example_prompts From 315e22f7f86ad7f213a266b10938c5876587b61a Mon Sep 17 00:00:00 2001 From: Dipika Date: Fri, 6 Sep 2024 16:13:55 +0000 Subject: [PATCH 30/77] add hack to support unfused mixtral pathway for int8 --- vllm/model_executor/model_loader/utils.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/model_loader/utils.py b/vllm/model_executor/model_loader/utils.py index d247e4cf3f07b..0052489d99dc4 100644 --- a/vllm/model_executor/model_loader/utils.py +++ b/vllm/model_executor/model_loader/utils.py @@ -23,11 +23,19 @@ def get_model_architecture( architectures = getattr(model_config.hf_config, "architectures", []) # Special handling for quantized Mixtral. # FIXME(woosuk): This is a temporary hack. - mixtral_supported = ["fp8", "compressed-tensors", "gptq_marlin"] + mixtral_supported = ["fp8", "compressed-tensors"] + # for gptq_marlin, only run fused MoE for int4 + if model_config.quantization == "gptq_marlin": + hf_quant_config = getattr(model_config.hf_config, + "quantization_config", None) + if hf_quant_config and hf_quant_config.get("bits") == 4: + mixtral_supported.append("gptq_marlin") + if (model_config.quantization is not None and model_config.quantization not in mixtral_supported and "MixtralForCausalLM" in architectures): architectures = ["QuantMixtralForCausalLM"] + return ModelRegistry.resolve_model_cls(architectures) From de80783b6907eb084493a76ef9ec3e3941cc2087 Mon Sep 17 00:00:00 2001 From: Rui Qiao <161574667+ruisearch42@users.noreply.github.com> Date: Fri, 6 Sep 2024 09:18:35 -0700 Subject: [PATCH 31/77] [Misc] Use ray[adag] dependency instead of cuda (#7938) --- Dockerfile | 2 -- MANIFEST.in | 1 - requirements-adag.txt | 3 --- requirements-test.txt | 5 +---- vllm/executor/ray_gpu_executor.py | 20 ++++++++++++++++++-- 5 files changed, 19 insertions(+), 12 deletions(-) delete mode 100644 requirements-adag.txt diff --git a/Dockerfile b/Dockerfile index 7f255e1d6e93e..2375e3f4d7387 100644 --- a/Dockerfile +++ b/Dockerfile @@ -37,7 +37,6 @@ WORKDIR /workspace # install build and runtime dependencies COPY requirements-common.txt requirements-common.txt -COPY requirements-adag.txt requirements-adag.txt COPY requirements-cuda.txt requirements-cuda.txt RUN --mount=type=cache,target=/root/.cache/pip \ python3 -m pip install -r requirements-cuda.txt @@ -66,7 +65,6 @@ COPY setup.py setup.py COPY cmake cmake COPY CMakeLists.txt CMakeLists.txt COPY requirements-common.txt requirements-common.txt -COPY requirements-adag.txt requirements-adag.txt COPY requirements-cuda.txt requirements-cuda.txt COPY pyproject.toml pyproject.toml COPY vllm vllm diff --git a/MANIFEST.in b/MANIFEST.in index 5a41e5e714184..82be639ef4d73 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,5 +1,4 @@ include LICENSE -include requirements-adag.txt include requirements-common.txt include requirements-cuda.txt include requirements-rocm.txt diff --git a/requirements-adag.txt b/requirements-adag.txt deleted file mode 100644 index e77f90fb8f85d..0000000000000 --- a/requirements-adag.txt +++ /dev/null @@ -1,3 +0,0 @@ -# Dependencies for Ray accelerated DAG -cupy-cuda12x -ray >= 2.32 \ No newline at end of file diff --git a/requirements-test.txt b/requirements-test.txt index 58cf1716b45ce..44ba99fe84bd4 100644 --- a/requirements-test.txt +++ b/requirements-test.txt @@ -1,6 +1,3 @@ -# Needed for Ray accelerated DAG tests --r requirements-adag.txt - # testing pytest tensorizer>=2.9.0 @@ -16,7 +13,7 @@ httpx librosa # required for audio test peft requests -ray +ray[adag]>=2.35 sentence-transformers # required for embedding soundfile # required for audio test compressed-tensors==0.4.0 # required for compressed-tensors diff --git a/vllm/executor/ray_gpu_executor.py b/vllm/executor/ray_gpu_executor.py index ab8844bcdafec..1359a0d310a70 100644 --- a/vllm/executor/ray_gpu_executor.py +++ b/vllm/executor/ray_gpu_executor.py @@ -427,18 +427,34 @@ def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None: async_run_remote_workers_only to complete.""" ray.get(parallel_worker_tasks) - def _compiled_ray_dag(self, enable_asyncio: bool): + def _check_ray_adag_installation(self): import pkg_resources from packaging import version - required_version = version.parse("2.32") + required_version = version.parse("2.35") current_version = version.parse( pkg_resources.get_distribution("ray").version) if current_version < required_version: raise ValueError(f"Ray version {required_version} or greater is " f"required, but found {current_version}") + import importlib.util + adag_spec = importlib.util.find_spec( + "ray.experimental.compiled_dag_ref") + if adag_spec is None: + raise ValueError("Ray accelerated DAG is not installed. " + "Run `pip install ray[adag]` to install it.") + + cupy_spec = importlib.util.find_spec("cupy") + if cupy_spec is None and envs.VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL: + raise ValueError( + "cupy is not installed but required since " + "VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL is set." + "Run `pip install ray[adag]` and check cupy installation.") + + def _compiled_ray_dag(self, enable_asyncio: bool): assert self.parallel_config.use_ray + self._check_ray_adag_installation() from ray.dag import InputNode, MultiOutputNode from ray.experimental.channel.torch_tensor_type import TorchTensorType From 565cc4334d7ad9a2bc9d87cb8b0ae6db189eb1a9 Mon Sep 17 00:00:00 2001 From: Dipika Date: Fri, 6 Sep 2024 18:29:36 +0000 Subject: [PATCH 32/77] fix install for tpu test --- vllm/model_executor/layers/quantization/gptq_marlin.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py index 72e4149e31287..a73c462c148c2 100644 --- a/vllm/model_executor/layers/quantization/gptq_marlin.py +++ b/vllm/model_executor/layers/quantization/gptq_marlin.py @@ -5,8 +5,6 @@ from vllm import _custom_ops as ops from vllm.logger import init_logger -from vllm.model_executor.layers.fused_moe.fused_marlin_moe import ( - fused_marlin_moe) from vllm.model_executor.layers.fused_moe.layer import ( FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported) from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, @@ -583,6 +581,8 @@ def apply( topk_group: Optional[int] = None, custom_routing_function: Optional[Callable] = None, ) -> torch.Tensor: + from vllm.model_executor.layers.fused_moe.fused_marlin_moe import ( + fused_marlin_moe) topk_weights, topk_ids = FusedMoE.select_experts( hidden_states=x, From 1447c97e753919709b613590d7267c93d07d9382 Mon Sep 17 00:00:00 2001 From: "Alexey Kondratiev(AMD)" <143633163+alexeykondrat@users.noreply.github.com> Date: Fri, 6 Sep 2024 14:51:03 -0400 Subject: [PATCH 33/77] [CI/Build] Increasing timeout for multiproc worker tests (#8203) --- tests/engine/test_multiproc_workers.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/engine/test_multiproc_workers.py b/tests/engine/test_multiproc_workers.py index 610ad9732fb91..e07dd6deef5bf 100644 --- a/tests/engine/test_multiproc_workers.py +++ b/tests/engine/test_multiproc_workers.py @@ -83,7 +83,7 @@ def execute_workers(worker_input: str) -> None: workers[3].process.kill() # Other workers should get shut down here - worker_monitor.join(2) + worker_monitor.join(20) # Ensure everything is stopped assert not worker_monitor.is_alive() @@ -108,7 +108,7 @@ def test_local_workers_clean_shutdown() -> None: # Clean shutdown worker_monitor.close() - worker_monitor.join(5) + worker_monitor.join(20) # Ensure everything is stopped assert not worker_monitor.is_alive() @@ -161,7 +161,7 @@ async def execute_workers(worker_input: str) -> None: workers[3].process.kill() # Other workers should get shut down here - worker_monitor.join(2) + worker_monitor.join(20) # Ensure everything is stopped assert not worker_monitor.is_alive() From 9db52eab3dc0b7b2cf30fa4399d569131e90c2d4 Mon Sep 17 00:00:00 2001 From: rasmith Date: Fri, 6 Sep 2024 17:26:09 -0500 Subject: [PATCH 34/77] [Kernel] [Triton] Memory optimization for awq_gemm and awq_dequantize, 2x throughput (#8248) --- .../layers/quantization/awq_triton.py | 34 +++++++++++++------ 1 file changed, 23 insertions(+), 11 deletions(-) diff --git a/vllm/model_executor/layers/quantization/awq_triton.py b/vllm/model_executor/layers/quantization/awq_triton.py index ad706f28a742b..d0b210c3a2747 100644 --- a/vllm/model_executor/layers/quantization/awq_triton.py +++ b/vllm/model_executor/layers/quantization/awq_triton.py @@ -22,7 +22,7 @@ def awq_dequantize_kernel( # Compute offsets and masks for qweight_ptr. offsets_y = pid_y * BLOCK_SIZE_Y + tl.arange(0, BLOCK_SIZE_Y) - offsets_x = pid_x * BLOCK_SIZE_X + tl.arange(0, BLOCK_SIZE_X * 8) // 8 + offsets_x = pid_x * BLOCK_SIZE_X + tl.arange(0, BLOCK_SIZE_X) offsets = num_cols * offsets_y[:, None] + offsets_x[None, :] masks_y = offsets_y < num_rows @@ -43,6 +43,9 @@ def awq_dequantize_kernel( # Load the weights. iweights = tl.load(qweight_ptr + offsets, masks) + iweights = tl.interleave(iweights, iweights) + iweights = tl.interleave(iweights, iweights) + iweights = tl.interleave(iweights, iweights) # Create reverse AWQ order as tensor: [0, 4, 1, 5, 2, 6, 3, 7] # that will map given indices to the correct order. @@ -59,9 +62,8 @@ def awq_dequantize_kernel( iweights = (iweights >> shifts) & 0xF # Compute zero offsets and masks. - zero_offsets_y = (pid_y * BLOCK_SIZE_Y // group_size + - tl.arange(0, BLOCK_SIZE_Y) // group_size) - zero_offsets_x = pid_x * BLOCK_SIZE_X + tl.arange(0, BLOCK_SIZE_X * 8) // 8 + zero_offsets_y = pid_y * BLOCK_SIZE_Y // group_size + tl.arange(0, 1) + zero_offsets_x = pid_x * BLOCK_SIZE_X + tl.arange(0, BLOCK_SIZE_X) zero_offsets = num_cols * zero_offsets_y[:, None] + zero_offsets_x[None, :] zero_masks_y = zero_offsets_y < num_rows // group_size @@ -70,13 +72,16 @@ def awq_dequantize_kernel( # Load the zeros. zeros = tl.load(zeros_ptr + zero_offsets, zero_masks) + zeros = tl.interleave(zeros, zeros) + zeros = tl.interleave(zeros, zeros) + zeros = tl.interleave(zeros, zeros) + zeros = tl.broadcast_to(zeros, (BLOCK_SIZE_Y, BLOCK_SIZE_X * 8)) # Unpack and reorder: shift out the correct 4-bit value and mask. zeros = (zeros >> shifts) & 0xF # Compute scale offsets and masks. - scale_offsets_y = (pid_y * BLOCK_SIZE_Y // group_size + - tl.arange(0, BLOCK_SIZE_Y) // group_size) + scale_offsets_y = pid_y * BLOCK_SIZE_Y // group_size + tl.arange(0, 1) scale_offsets_x = (pid_x * BLOCK_SIZE_X * 8 + tl.arange(0, BLOCK_SIZE_X * 8)) scale_offsets = (num_cols * 8 * scale_offsets_y[:, None] + @@ -87,6 +92,7 @@ def awq_dequantize_kernel( # Load the scales. scales = tl.load(scales_ptr + scale_offsets, scale_masks) + scales = tl.broadcast_to(scales, (BLOCK_SIZE_Y, BLOCK_SIZE_X * 8)) # Dequantize. iweights = (iweights - zeros) * scales @@ -137,12 +143,10 @@ def awq_gemm_kernel(a_ptr, b_ptr, c_ptr, zeros_ptr, scales_ptr, M, N, K, offsets_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) masks_am = offsets_am < M - offsets_bn = (pid_n * (BLOCK_SIZE_N // 8) + - tl.arange(0, BLOCK_SIZE_N) // 8) + offsets_bn = pid_n * (BLOCK_SIZE_N // 8) + tl.arange(0, BLOCK_SIZE_N // 8) masks_bn = offsets_bn < N // 8 - offsets_zn = (pid_n * (BLOCK_SIZE_N // 8) + - tl.arange(0, BLOCK_SIZE_N) // 8) + offsets_zn = pid_n * (BLOCK_SIZE_N // 8) + tl.arange(0, BLOCK_SIZE_N // 8) masks_zn = offsets_zn < N // 8 offsets_sn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) @@ -165,22 +169,30 @@ def awq_gemm_kernel(a_ptr, b_ptr, c_ptr, zeros_ptr, scales_ptr, M, N, K, masks_b = masks_k[:, None] & masks_bn[None, :] b = tl.load(b_ptrs, mask=masks_b) + b = tl.interleave(b, b) + b = tl.interleave(b, b) + b = tl.interleave(b, b) # Dequantize b. offsets_szk = ( (BLOCK_SIZE_K * SPLIT_K * k + pid_z * BLOCK_SIZE_K) // group_size + - tl.arange(0, BLOCK_SIZE_K) // group_size) + tl.arange(0, 1)) offsets_z = (N // 8) * offsets_szk[:, None] + offsets_zn[None, :] masks_zk = offsets_szk < K // group_size masks_z = masks_zk[:, None] & masks_zn[None, :] zeros_ptrs = zeros_ptr + offsets_z zeros = tl.load(zeros_ptrs, mask=masks_z) + zeros = tl.interleave(zeros, zeros) + zeros = tl.interleave(zeros, zeros) + zeros = tl.interleave(zeros, zeros) + zeros = tl.broadcast_to(zeros, (BLOCK_SIZE_K, BLOCK_SIZE_N)) offsets_s = N * offsets_szk[:, None] + offsets_sn[None, :] masks_sk = offsets_szk < K // group_size masks_s = masks_sk[:, None] & masks_sn[None, :] scales_ptrs = scales_ptr + offsets_s scales = tl.load(scales_ptrs, mask=masks_s) + scales = tl.broadcast_to(scales, (BLOCK_SIZE_K, BLOCK_SIZE_N)) b = (b >> shifts) & 0xF zeros = (zeros >> shifts) & 0xF From 23f322297f33a50dd1fe0870665d0c4414fd78ab Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: Fri, 6 Sep 2024 18:29:03 -0400 Subject: [PATCH 35/77] [Misc] Remove `SqueezeLLM` (#8220) --- CMakeLists.txt | 1 - csrc/ops.h | 3 - .../squeezellm/quant_cuda_kernel.cu | 216 ------------------ csrc/torch_bindings.cpp | 6 - .../quantization/supported_hardware.rst | 11 - examples/fp8/README.md | 4 +- vllm/_custom_ops.py | 6 - vllm/config.py | 4 +- vllm/entrypoints/llm.py | 2 +- vllm/lora/layers.py | 2 +- .../layers/quantization/__init__.py | 2 - .../layers/quantization/squeezellm.py | 138 ----------- 12 files changed, 6 insertions(+), 389 deletions(-) delete mode 100644 csrc/quantization/squeezellm/quant_cuda_kernel.cu delete mode 100644 vllm/model_executor/layers/quantization/squeezellm.py diff --git a/CMakeLists.txt b/CMakeLists.txt index 923ed084ffd9e..9c88c31c83da1 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -181,7 +181,6 @@ set(VLLM_EXT_SRC "csrc/pos_encoding_kernels.cu" "csrc/activation_kernels.cu" "csrc/layernorm_kernels.cu" - "csrc/quantization/squeezellm/quant_cuda_kernel.cu" "csrc/quantization/gptq/q_gemm.cu" "csrc/quantization/compressed_tensors/int8_quant_kernels.cu" "csrc/quantization/fp8/common.cu" diff --git a/csrc/ops.h b/csrc/ops.h index 8d24545de898d..45a3868395d12 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -170,9 +170,6 @@ void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input, void dynamic_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input, torch::Tensor& scales); -void squeezellm_gemm(torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, - torch::Tensor lookup_table); - torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight, torch::Tensor b_gptq_qzeros, torch::Tensor b_gptq_scales, torch::Tensor b_g_idx, diff --git a/csrc/quantization/squeezellm/quant_cuda_kernel.cu b/csrc/quantization/squeezellm/quant_cuda_kernel.cu deleted file mode 100644 index 8ed918b3d7c27..0000000000000 --- a/csrc/quantization/squeezellm/quant_cuda_kernel.cu +++ /dev/null @@ -1,216 +0,0 @@ -#include -#include -#include -#include - -// half-tensor -#include -#include -#include - -#define BLOCKWIDTH 128 -#define BLOCKHEIGHT4 16 - -namespace vllm { -namespace squeezellm { - -__device__ inline unsigned int as_unsigned(int i) { - return *reinterpret_cast(&i); -} - -// 4-bit matvec kernel (LUT-based) -__global__ void NUQ4MatMulKernel( -#ifndef USE_ROCM - const half2* __restrict__ vec, -#else - const __half2* __restrict__ vec, -#endif - const int* __restrict__ mat, -#ifndef USE_ROCM - half2* __restrict__ mul, -#else - float2* __restrict__ mul, -#endif - const __half* __restrict__ lookup_table, int height, int width, int batch, - int vec_height) { - - const int blockwidth2 = BLOCKWIDTH / 2; - - int row = BLOCKHEIGHT4 * blockIdx.x; - int col = BLOCKWIDTH * blockIdx.y + threadIdx.x; - -#ifndef USE_ROCM - __shared__ half2 blockvec[blockwidth2]; -#else - __shared__ __half2 blockvec[blockwidth2]; -#endif - - __shared__ __half deq2[16][BLOCKWIDTH]; - int off = threadIdx.x; - int column_offset = col * 16; - for (int val = 0; val < 16; val += 1) { - int lut_index = column_offset + val; - deq2[val][off] = lookup_table[lut_index]; - } - - __half res; -#ifndef USE_ROCM - half2 res2; - half2 tmp2; -#else - __half2 res2; - __half2 tmp2; -#endif - - int i; - int k; - - unsigned int tmp1; - unsigned int lut_index1, lut_index2; - - for (int b = 0; b < batch; ++b) { - i = width * row + col; - res = __int2half_rd(0); - k = 0; - - __syncthreads(); - if (threadIdx.x < blockwidth2) - blockvec[threadIdx.x] = - vec[b * vec_height / 2 + (row / BLOCKHEIGHT4) * blockwidth2 + - threadIdx.x]; - __syncthreads(); - - while (k < blockwidth2) { - tmp1 = as_unsigned(mat[i]); - -#ifndef USE_ROCM - res2 = {}; - tmp2 = {}; -#else - res2.x = __half_as_ushort(__float2half(0)); - res2.y = __half_as_ushort(__float2half(0)); - tmp2.x = __half_as_ushort(__float2half(0)); - tmp2.y = __half_as_ushort(__float2half(0)); -#endif - - lut_index1 = tmp1 & 0xF; - lut_index2 = (tmp1 >> 4) & 0xF; -#ifndef USE_ROCM - tmp2.x = deq2[lut_index1][off]; - tmp2.y = deq2[lut_index2][off]; -#else - tmp2.x = __half_as_ushort(deq2[lut_index1][off]); - tmp2.y = __half_as_ushort(deq2[lut_index2][off]); -#endif - res2 = __hfma2(tmp2, blockvec[k + 0], res2); - - lut_index1 = (tmp1 >> 8) & 0xF; - lut_index2 = (tmp1 >> 12) & 0xF; -#ifndef USE_ROCM - tmp2.x = deq2[lut_index1][off]; - tmp2.y = deq2[lut_index2][off]; -#else - tmp2.x = __half_as_ushort(deq2[lut_index1][off]); - tmp2.y = __half_as_ushort(deq2[lut_index2][off]); -#endif - res2 = __hfma2(tmp2, blockvec[k + 1], res2); - - lut_index1 = (tmp1 >> 16) & 0xF; - lut_index2 = (tmp1 >> 20) & 0xF; -#ifndef USE_ROCM - tmp2.x = deq2[lut_index1][off]; - tmp2.y = deq2[lut_index2][off]; -#else - tmp2.x = __half_as_ushort(deq2[lut_index1][off]); - tmp2.y = __half_as_ushort(deq2[lut_index2][off]); -#endif - res2 = __hfma2(tmp2, blockvec[k + 2], res2); - - lut_index1 = (tmp1 >> 24) & 0xF; - lut_index2 = (tmp1 >> 28) & 0xF; -#ifndef USE_ROCM - tmp2.x = deq2[lut_index1][off]; - tmp2.y = deq2[lut_index2][off]; -#else - tmp2.x = __half_as_ushort(deq2[lut_index1][off]); - tmp2.y = __half_as_ushort(deq2[lut_index2][off]); -#endif - res2 = __hfma2(tmp2, blockvec[k + 3], res2); - -#ifndef USE_ROCM - res = __hadd(__hadd(res2.x, res2.y), res); -#else - res = __hadd(__hadd(__ushort_as_half(res2.x), __ushort_as_half(res2.y)), - res); -#endif - - i += width; - k += 4; - } - - // col%2 -> only set one of the two values -#ifndef USE_ROCM - half2 res3 = {}; - if (col % 2 == 0) { - res3.x = res; - } else { - res3.y = res; - } -#else - __half2 res3; - res3.x = __half_as_ushort(__float2half(0)); - res3.y = __half_as_ushort(__float2half(0)); - if (col % 2 == 0) { - res3.x = __half_as_ushort(res); - } else { - res3.y = __half_as_ushort(res); - } -#endif - -#ifndef USE_ROCM - atomicAdd(&mul[b * width / 2 + col / 2], res3); -#else - int tmp_addr = b * width / 2 + col / 2; - atomicAdd(&(mul[tmp_addr].x), __half2float(__ushort_as_half(res3.x))); - atomicAdd(&(mul[tmp_addr].y), __half2float(__ushort_as_half(res3.y))); -#endif - } -} - -} // namespace squeezellm -} // namespace vllm - -// 4-bit matvec kernel (LUT-based) -void squeezellm_gemm(torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, - torch::Tensor lookup_table) { - int height = mat.size(0); - int width = mat.size(1); - - int batch = vec.size(0); - int vec_height = vec.size(1); - - dim3 blocks((height + BLOCKHEIGHT4 - 1) / BLOCKHEIGHT4, - (width + BLOCKWIDTH - 1) / BLOCKWIDTH); - dim3 threads(BLOCKWIDTH); - - const at::cuda::OptionalCUDAGuard device_guard(device_of(vec)); - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - vllm::squeezellm::NUQ4MatMulKernel<<>>( -#ifndef USE_ROCM - (half2*)vec.data_ptr(), -#else - (__half2*)vec.data_ptr(), -#endif - mat.data_ptr(), -#ifndef USE_ROCM - (half2*)mul.data_ptr(), - (__half*)lookup_table.data_ptr(), -#else - (float2*)mul.data_ptr(), - (__half*)lookup_table.data_ptr(), -#endif - height, width, batch, vec_height); -} - -#undef BLOCKWIDTH -#undef BLOCKHEIGHT4 diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index 7783acd741f5f..07b14e7a6ff63 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -237,12 +237,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ops.def("gptq_shuffle(Tensor! q_weight, Tensor q_perm, int bit) -> ()"); ops.impl("gptq_shuffle", torch::kCUDA, &gptq_shuffle); - // Quantized GEMM for SqueezeLLM. - ops.def( - "squeezellm_gemm(Tensor vec, Tensor mat, Tensor! mul, Tensor " - "lookup_table) -> ()"); - ops.impl("squeezellm_gemm", torch::kCUDA, &squeezellm_gemm); - // Compute FP8 quantized tensor for given scaling factor. ops.def( "static_scaled_fp8_quant(Tensor! out, Tensor input, Tensor scale) -> ()"); diff --git a/docs/source/quantization/supported_hardware.rst b/docs/source/quantization/supported_hardware.rst index 6341b583f0cfe..ea587e0525a74 100644 --- a/docs/source/quantization/supported_hardware.rst +++ b/docs/source/quantization/supported_hardware.rst @@ -119,17 +119,6 @@ The table below shows the compatibility of various quantization implementations - ✗ - ✗ - ✗ - * - SqueezeLLM - - ✅︎ - - ✅︎ - - ✅︎ - - ✅︎ - - ✅︎ - - ✗ - - ✗ - - ✗ - - ✗ - - ✗ Notes: ^^^^^^ diff --git a/examples/fp8/README.md b/examples/fp8/README.md index 84ad76c71862e..181c36558fcff 100644 --- a/examples/fp8/README.md +++ b/examples/fp8/README.md @@ -62,7 +62,7 @@ This script evaluates the inference throughput of language models using various python3 benchmarks/benchmark_throughput.py --help usage: benchmark_throughput.py [-h] [--backend {vllm,hf,mii}] [--dataset DATASET] [--input-len INPUT_LEN] [--output-len OUTPUT_LEN] [--model MODEL] - [--tokenizer TOKENIZER] [--quantization {awq,gptq,squeezellm,None}] [--tensor-parallel-size TENSOR_PARALLEL_SIZE] [--n N] + [--tokenizer TOKENIZER] [--quantization {awq,gptq,None}] [--tensor-parallel-size TENSOR_PARALLEL_SIZE] [--n N] [--use-beam-search] [--num-prompts NUM_PROMPTS] [--seed SEED] [--hf-max-batch-size HF_MAX_BATCH_SIZE] [--trust-remote-code] [--max-model-len MAX_MODEL_LEN] [--dtype {auto,half,float16,bfloat16,float,float32}] [--enforce-eager] [--kv-cache-dtype {auto,fp8}] [--quantization-param-path KV_CACHE_quantization_param_path] @@ -76,7 +76,7 @@ optional arguments: --output-len OUTPUT_LEN Output length for each request. Overrides the output length from the dataset. --model MODEL --tokenizer TOKENIZER - --quantization {awq,gptq,squeezellm,None}, -q {awq,gptq,squeezellm,None} + --quantization {awq,gptq,None}, -q {awq,gptq,None} --tensor-parallel-size TENSOR_PARALLEL_SIZE, -tp TENSOR_PARALLEL_SIZE --n N Number of generated sequences per prompt. --use-beam-search diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index fe254732e7309..151cdbee8eb04 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -209,12 +209,6 @@ def gptq_shuffle(q_weight: torch.Tensor, q_perm: torch.Tensor, torch.ops._C.gptq_shuffle(q_weight, q_perm, bit) -# squeezellm -def squeezellm_gemm(vec: torch.Tensor, mat: torch.Tensor, mul: torch.Tensor, - lookup_table: torch.Tensor) -> None: - torch.ops._C.squeezellm_gemm(vec, mat, mul, lookup_table) - - # marlin def marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, b_scales: torch.Tensor, workspace: torch.Tensor, size_m: int, diff --git a/vllm/config.py b/vllm/config.py index e513608eca9f8..1c9e30b2682b9 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -277,7 +277,7 @@ def _parse_quant_hf_config(self): def _verify_quantization(self) -> None: supported_quantization = [*QUANTIZATION_METHODS] - rocm_supported_quantization = ["awq", "gptq", "squeezellm", "fp8"] + rocm_supported_quantization = ["awq", "gptq", "fp8"] optimized_quantization_methods = [ "fp8", "marlin", "gptq_marlin_24", "gptq_marlin", "awq_marlin", "fbgemm_fp8", "compressed_tensors", "compressed-tensors", @@ -1537,7 +1537,7 @@ def verify_with_model_config(self, model_config: ModelConfig): if model_config.quantization and model_config.quantization not in [ "awq", "gptq" ]: - # TODO support marlin and squeezellm + # TODO support marlin logger.warning("%s quantization is not tested with LoRA yet.", model_config.quantization) diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index b32c90a4df1aa..f587ec3003141 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -55,7 +55,7 @@ class LLM: However, if the `torch_dtype` in the config is `float32`, we will use `float16` instead. quantization: The method used to quantize the model weights. Currently, - we support "awq", "gptq", "squeezellm", and "fp8" (experimental). + we support "awq", "gptq", and "fp8" (experimental). If None, we first check the `quantization_config` attribute in the model config file. If that is None, we assume the model weights are not quantized and use `dtype` to determine the data type of diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py index a8ea67991a375..b9ac498b23a7b 100644 --- a/vllm/lora/layers.py +++ b/vllm/lora/layers.py @@ -39,7 +39,7 @@ def _get_lora_device(base_layer: nn.Module) -> torch.device: # unquantizedLinear if hasattr(base_layer, "weight"): return base_layer.weight.device - # GPTQ/AWQ/SqueezeLLM + # GPTQ/AWQ elif hasattr(base_layer, "qweight"): return base_layer.qweight.device # marlin diff --git a/vllm/model_executor/layers/quantization/__init__.py b/vllm/model_executor/layers/quantization/__init__.py index c6fb6ca0d2e01..aa5c288962d91 100644 --- a/vllm/model_executor/layers/quantization/__init__.py +++ b/vllm/model_executor/layers/quantization/__init__.py @@ -25,7 +25,6 @@ from vllm.model_executor.layers.quantization.neuron_quant import ( NeuronQuantConfig) from vllm.model_executor.layers.quantization.qqq import QQQConfig -from vllm.model_executor.layers.quantization.squeezellm import SqueezeLLMConfig from vllm.model_executor.layers.quantization.tpu_int8 import Int8TpuConfig QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = { @@ -43,7 +42,6 @@ "gptq_marlin": GPTQMarlinConfig, "awq_marlin": AWQMarlinConfig, "gptq": GPTQConfig, - "squeezellm": SqueezeLLMConfig, "compressed-tensors": CompressedTensorsConfig, "bitsandbytes": BitsAndBytesConfig, "qqq": QQQConfig, diff --git a/vllm/model_executor/layers/quantization/squeezellm.py b/vllm/model_executor/layers/quantization/squeezellm.py deleted file mode 100644 index afb3c04976737..0000000000000 --- a/vllm/model_executor/layers/quantization/squeezellm.py +++ /dev/null @@ -1,138 +0,0 @@ -from typing import Any, Dict, List, Optional - -import torch -from torch.nn.parameter import Parameter - -from vllm import _custom_ops as ops -from vllm.model_executor.layers.linear import LinearBase -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig, QuantizeMethodBase) -from vllm.model_executor.utils import set_weight_attrs -from vllm.utils import is_hip - - -class SqueezeLLMConfig(QuantizationConfig): - """Config class for SqueezeLLM. - - Reference: https://arxiv.org/pdf/2306.07629 - """ - - def __init__( - self, - weight_bits: int, - ) -> None: - self.weight_bits = weight_bits - - if self.weight_bits != 4: - raise ValueError( - "Currently, only 4-bit weight quantization is supported for " - f"SqueezeLLM, but got {self.weight_bits} bits.") - - self.pack_factor = 32 // self.weight_bits - - def __repr__(self) -> str: - return f"SqueezeLLMConfig(weight_bits={self.weight_bits})" - - def get_name(self) -> str: - return "squeezellm" - - def get_supported_act_dtypes(self) -> List[torch.dtype]: - return [torch.half] - - @classmethod - def get_min_capability(cls) -> int: - return 70 - - @staticmethod - def get_config_filenames() -> List[str]: - return ["quant_config.json"] - - @classmethod - def from_config(cls, config: Dict[str, Any]) -> "SqueezeLLMConfig": - weight_bits = cls.get_from_keys(config, ["wbits"]) - return cls(weight_bits) - - def get_quant_method(self, layer: torch.nn.Module, - prefix: str) -> Optional[QuantizeMethodBase]: - if isinstance(layer, LinearBase): - return SqueezeLLMLinearMethod(self) - return None - - def get_scaled_act_names(self) -> List[str]: - return [] - - -class SqueezeLLMLinearMethod(QuantizeMethodBase): - """Linear method for SqueezeLLM. - - Args: - quant_config: The SqueezeLLM quantization config. - """ - - def __init__(self, quant_config: SqueezeLLMConfig): - self.quant_config = quant_config - - def create_weights(self, layer: torch.nn.Module, - input_size_per_partition: int, - output_partition_sizes: List[int], input_size: int, - output_size: int, params_dtype: torch.dtype, - **extra_weight_attrs): - if input_size_per_partition % self.quant_config.pack_factor != 0: - raise ValueError( - "The input size is not aligned with the quantized " - "weight shape. This can be caused by too large " - "tensor parallel size.") - - output_size_per_partition = sum(output_partition_sizes) - qweight = Parameter( - torch.empty( - input_size_per_partition // self.quant_config.pack_factor, - output_size_per_partition, - dtype=torch.int32, - ), - requires_grad=False, - ) - set_weight_attrs( - qweight, { - "input_dim": 0, - "output_dim": 1, - "packed_dim": 0, - "pack_factor": self.quant_config.pack_factor, - }) - lookup_table = Parameter( - torch.empty( - output_size, - self.quant_config.weight_bits**2, - dtype=params_dtype, - ), - requires_grad=False, - ) - set_weight_attrs(lookup_table, { - "output_dim": 0, - }) - - layer.register_parameter("qweight", qweight) - set_weight_attrs(qweight, extra_weight_attrs) - layer.register_parameter("lookup_table", lookup_table) - set_weight_attrs(lookup_table, extra_weight_attrs) - - def apply(self, - layer: torch.nn.Module, - x: torch.Tensor, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: - qweight = layer.qweight - lookup_table = layer.lookup_table - out_shape = x.shape[:-1] + (qweight.shape[-1], ) - reshaped_x = x.reshape(-1, x.shape[-1]) - if is_hip(): - out_f = torch.zeros(out_shape, dtype=torch.float) - ops.squeezellm_gemm(reshaped_x, qweight, out_f, lookup_table) - out = out_f.to(dtype=torch.float16) - else: - # NOTE: The output tensor should be zero-initialized. - out = torch.zeros(out_shape, dtype=torch.float16) - ops.squeezellm_gemm(reshaped_x, qweight, out, lookup_table) - - if bias is not None: - out.add_(bias) - return out.reshape(out_shape) From 29f49cd6e3d3c5658b92ea3e97138c1ab5cb6b30 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Sat, 7 Sep 2024 01:02:05 +0200 Subject: [PATCH 36/77] [Model] Allow loading from original Mistral format (#8168) Co-authored-by: Michael Goin --- tests/models/test_mistral.py | 40 +++++ vllm/config.py | 62 ++++--- vllm/engine/arg_utils.py | 21 ++- vllm/model_executor/model_loader/loader.py | 12 +- .../model_loader/weight_utils.py | 21 +-- vllm/model_executor/models/llama.py | 51 ++++++ vllm/transformers_utils/config.py | 165 ++++++++++++++---- 7 files changed, 291 insertions(+), 81 deletions(-) diff --git a/tests/models/test_mistral.py b/tests/models/test_mistral.py index 4965354c0016b..0741174497e32 100644 --- a/tests/models/test_mistral.py +++ b/tests/models/test_mistral.py @@ -41,3 +41,43 @@ def test_models( name_0="hf", name_1="vllm", ) + + +@pytest.mark.parametrize("model", MODELS[1:]) +@pytest.mark.parametrize("dtype", ["bfloat16"]) +@pytest.mark.parametrize("max_tokens", [64]) +@pytest.mark.parametrize("num_logprobs", [5]) +def test_mistral_format( + vllm_runner, + example_prompts, + model: str, + dtype: str, + max_tokens: int, + num_logprobs: int, +) -> None: + with vllm_runner( + model, + dtype=dtype, + tokenizer_mode="auto", + load_format="safetensors", + config_format="hf", + ) as hf_format_model: + hf_format_outputs = hf_format_model.generate_greedy_logprobs( + example_prompts, max_tokens, num_logprobs) + + with vllm_runner( + model, + dtype=dtype, + tokenizer_mode="mistral", + load_format="mistral", + config_format="mistral", + ) as mistral_format_model: + mistral_format_outputs = mistral_format_model.generate_greedy_logprobs( + example_prompts, max_tokens, num_logprobs) + + check_logprobs_close( + outputs_0_lst=hf_format_outputs, + outputs_1_lst=mistral_format_outputs, + name_0="hf", + name_1="mistral", + ) diff --git a/vllm/config.py b/vllm/config.py index 1c9e30b2682b9..8f5e02e35f28d 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -13,7 +13,7 @@ from vllm.model_executor.models import ModelRegistry from vllm.platforms import current_platform from vllm.tracing import is_otel_available, otel_import_error_traceback -from vllm.transformers_utils.config import (get_config, +from vllm.transformers_utils.config import (ConfigFormat, get_config, get_hf_image_processor_config, get_hf_text_config) from vllm.utils import (STR_NOT_IMPL_ENC_DEC_CUDAGRAPH, GiB_bytes, @@ -121,35 +121,37 @@ class ModelConfig: override default neuron config that are specific to Neuron devices, this argument will be used to configure the neuron config that can not be gathered from the vllm arguments. + config_format: The config format which shall be loaded. + Defaults to 'auto' which defaults to 'hf'. """ - def __init__( - self, - model: str, - tokenizer: str, - tokenizer_mode: str, - trust_remote_code: bool, - dtype: Union[str, torch.dtype], - seed: int, - revision: Optional[str] = None, - code_revision: Optional[str] = None, - rope_scaling: Optional[dict] = None, - rope_theta: Optional[float] = None, - tokenizer_revision: Optional[str] = None, - max_model_len: Optional[int] = None, - spec_target_max_model_len: Optional[int] = None, - quantization: Optional[str] = None, - quantization_param_path: Optional[str] = None, - enforce_eager: Optional[bool] = None, - max_context_len_to_capture: Optional[int] = None, - max_seq_len_to_capture: Optional[int] = None, - max_logprobs: int = 20, - disable_sliding_window: bool = False, - skip_tokenizer_init: bool = False, - served_model_name: Optional[Union[str, List[str]]] = None, - limit_mm_per_prompt: Optional[Mapping[str, int]] = None, - use_async_output_proc: bool = True, - override_neuron_config: Optional[Dict[str, Any]] = None) -> None: + def __init__(self, + model: str, + tokenizer: str, + tokenizer_mode: str, + trust_remote_code: bool, + dtype: Union[str, torch.dtype], + seed: int, + revision: Optional[str] = None, + code_revision: Optional[str] = None, + rope_scaling: Optional[dict] = None, + rope_theta: Optional[float] = None, + tokenizer_revision: Optional[str] = None, + max_model_len: Optional[int] = None, + spec_target_max_model_len: Optional[int] = None, + quantization: Optional[str] = None, + quantization_param_path: Optional[str] = None, + enforce_eager: Optional[bool] = None, + max_context_len_to_capture: Optional[int] = None, + max_seq_len_to_capture: Optional[int] = None, + max_logprobs: int = 20, + disable_sliding_window: bool = False, + skip_tokenizer_init: bool = False, + served_model_name: Optional[Union[str, List[str]]] = None, + limit_mm_per_prompt: Optional[Mapping[str, int]] = None, + use_async_output_proc: bool = True, + override_neuron_config: Optional[Dict[str, Any]] = None, + config_format: ConfigFormat = ConfigFormat.AUTO) -> None: self.model = model self.tokenizer = tokenizer self.tokenizer_mode = tokenizer_mode @@ -176,7 +178,8 @@ def __init__( self.skip_tokenizer_init = skip_tokenizer_init self.hf_config = get_config(self.model, trust_remote_code, revision, - code_revision, rope_scaling, rope_theta) + code_revision, rope_scaling, rope_theta, + config_format) self.hf_text_config = get_hf_text_config(self.hf_config) self.hf_image_processor_config = get_hf_image_processor_config( self.model, revision) @@ -746,6 +749,7 @@ class LoadFormat(str, enum.Enum): SHARDED_STATE = "sharded_state" GGUF = "gguf" BITSANDBYTES = "bitsandbytes" + MISTRAL = "mistral" @dataclass diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index f0b866db64324..7620093660b43 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -8,10 +8,10 @@ import torch import vllm.envs as envs -from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, - EngineConfig, LoadConfig, LoadFormat, LoRAConfig, - ModelConfig, ObservabilityConfig, ParallelConfig, - PromptAdapterConfig, SchedulerConfig, +from vllm.config import (CacheConfig, ConfigFormat, DecodingConfig, + DeviceConfig, EngineConfig, LoadConfig, LoadFormat, + LoRAConfig, ModelConfig, ObservabilityConfig, + ParallelConfig, PromptAdapterConfig, SchedulerConfig, SpeculativeConfig, TokenizerPoolConfig) from vllm.executor.executor_base import ExecutorBase from vllm.logger import init_logger @@ -65,6 +65,7 @@ class EngineArgs: trust_remote_code: bool = False download_dir: Optional[str] = None load_format: str = 'auto' + config_format: str = 'auto' dtype: str = 'auto' kv_cache_dtype: str = 'auto' quantization_param_path: Optional[str] = None @@ -234,6 +235,13 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: 'section for more information.\n' '* "bitsandbytes" will load the weights using bitsandbytes ' 'quantization.\n') + parser.add_argument( + '--config-format', + default=EngineArgs.config_format, + choices=[f.value for f in ConfigFormat], + help='The format of the model config to load.\n\n' + '* "auto" will try to load the config in hf format ' + 'if available else it will try to load in mistral format ') parser.add_argument( '--dtype', type=str, @@ -813,7 +821,10 @@ def create_engine_config(self) -> EngineConfig: served_model_name=self.served_model_name, limit_mm_per_prompt=self.limit_mm_per_prompt, use_async_output_proc=not self.disable_async_output_proc, - override_neuron_config=self.override_neuron_config) + override_neuron_config=self.override_neuron_config, + config_format=self.config_format, + ) + cache_config = CacheConfig( block_size=self.block_size if self.device != "neuron" else self.max_model_len, # neuron needs block_size = max_model_len diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py index 553fa848489b2..bcc866a194637 100644 --- a/vllm/model_executor/model_loader/loader.py +++ b/vllm/model_executor/model_loader/loader.py @@ -17,6 +17,7 @@ from huggingface_hub import HfApi, hf_hub_download from torch import nn from transformers import AutoModelForCausalLM, PretrainedConfig +from transformers.utils import SAFE_WEIGHTS_INDEX_NAME from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoadFormat, LoRAConfig, ModelConfig, MultiModalConfig, @@ -241,12 +242,17 @@ def _prepare_weights(self, model_name_or_path: str, is_local = os.path.isdir(model_name_or_path) load_format = self.load_config.load_format use_safetensors = False + index_file = SAFE_WEIGHTS_INDEX_NAME # Some quantized models use .pt files for storing the weights. if load_format == LoadFormat.AUTO: allow_patterns = ["*.safetensors", "*.bin"] elif load_format == LoadFormat.SAFETENSORS: use_safetensors = True allow_patterns = ["*.safetensors"] + elif load_format == LoadFormat.MISTRAL: + use_safetensors = True + allow_patterns = ["consolidated*.safetensors"] + index_file = "consolidated.safetensors.index.json" elif load_format == LoadFormat.PT: allow_patterns = ["*.pt"] elif load_format == LoadFormat.NPCACHE: @@ -284,10 +290,10 @@ def _prepare_weights(self, model_name_or_path: str, # any files not found in the index. if not is_local: download_safetensors_index_file_from_hf( - model_name_or_path, self.load_config.download_dir, - revision) + model_name_or_path, index_file, + self.load_config.download_dir, revision) hf_weights_files = filter_duplicate_safetensors_files( - hf_weights_files, hf_folder) + hf_weights_files, hf_folder, index_file) else: hf_weights_files = filter_files_not_needed_for_inference( hf_weights_files) diff --git a/vllm/model_executor/model_loader/weight_utils.py b/vllm/model_executor/model_loader/weight_utils.py index 0666457756b02..075451292a8e4 100644 --- a/vllm/model_executor/model_loader/weight_utils.py +++ b/vllm/model_executor/model_loader/weight_utils.py @@ -16,7 +16,6 @@ from huggingface_hub import HfFileSystem, hf_hub_download, snapshot_download from safetensors.torch import load_file, safe_open, save_file from tqdm.auto import tqdm -from transformers.utils import SAFE_WEIGHTS_INDEX_NAME from vllm.config import LoadConfig, ModelConfig from vllm.distributed import get_tensor_model_parallel_rank @@ -251,6 +250,7 @@ def download_weights_from_hf( def download_safetensors_index_file_from_hf( model_name_or_path: str, + index_file: str, cache_dir: Optional[str], revision: Optional[str] = None, ) -> None: @@ -269,36 +269,37 @@ def download_safetensors_index_file_from_hf( # Download the safetensors index file. hf_hub_download( repo_id=model_name_or_path, - filename=SAFE_WEIGHTS_INDEX_NAME, + filename=index_file, cache_dir=cache_dir, revision=revision, local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE, ) # If file not found on remote or locally, we should not fail since - # only some models will have SAFE_WEIGHTS_INDEX_NAME. + # only some models will have index_file. except huggingface_hub.utils.EntryNotFoundError: - logger.info("No %s found in remote.", SAFE_WEIGHTS_INDEX_NAME) + logger.info("No %s found in remote.", index_file) except huggingface_hub.utils.LocalEntryNotFoundError: - logger.info("No %s found in local cache.", SAFE_WEIGHTS_INDEX_NAME) + logger.info("No %s found in local cache.", index_file) # For models like Mistral-7B-v0.3, there are both sharded # safetensors files and a consolidated safetensors file. # Passing both of these to the weight loader functionality breaks. -# So, we use the SAFE_WEIGHTS_INDEX_NAME to +# So, we use the index_file to # look up which safetensors files should be used. def filter_duplicate_safetensors_files(hf_weights_files: List[str], - hf_folder: str) -> List[str]: + hf_folder: str, + index_file: str) -> List[str]: # model.safetensors.index.json is a mapping from keys in the # torch state_dict to safetensors file holding that weight. - index_file_name = os.path.join(hf_folder, SAFE_WEIGHTS_INDEX_NAME) + index_file_name = os.path.join(hf_folder, index_file) if not os.path.isfile(index_file_name): return hf_weights_files # Iterate through the weight_map (weight_name: safetensors files) # to identify weights that we should use. - with open(index_file_name) as index_file: - weight_map = json.load(index_file)["weight_map"] + with open(index_file_name, "r") as f: + weight_map = json.load(f)["weight_map"] weight_files_in_index = set() for weight_name in weight_map: weight_files_in_index.add( diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index e55c01316087c..5ff31e3833ec9 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -375,6 +375,25 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA): "gate_proj": ("gate_up_proj", 0), "up_proj": ("gate_up_proj", 1), } + # Mistral/Llama models can also be loaded with --load-format mistral + # from consolidated.safetensors checkpoints + mistral_mapping = { + "layers": "model.layers", + "attention": "self_attn", + "wq": "q_proj", + "wk": "k_proj", + "wv": "v_proj", + "wo": "o_proj", + "attention_norm": "input_layernorm", + "feed_forward": "mlp", + "w1": "gate_proj", + "w2": "down_proj", + "w3": "up_proj", + "ffn_norm": "post_attention_layernorm", + "tok_embeddings": "model.embed_tokens", + "output": "lm_head", + "norm": "model.norm" + } def __init__( self, @@ -472,6 +491,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): ] params_dict = dict(self.named_parameters()) for name, loaded_weight in weights: + name, loaded_weight = self.maybe_remap_mistral(name, loaded_weight) + if "rotary_emb.inv_freq" in name: continue if ("rotary_emb.cos_cached" in name @@ -549,3 +570,33 @@ def load_kv_cache_scales(self, quantization_param_path: str) -> None: else: raise RuntimeError("Self attention has no KV cache scaling " "factor attribute!") + + # This function is used to remap the mistral format as + # used by Mistral and Llama <=2 + def maybe_remap_mistral( + self, name: str, + loaded_weight: torch.Tensor) -> Tuple[str, torch.Tensor]: + + def permute(w, n_heads): + attn_in = self.config.head_dim * n_heads + attn_out = self.config.hidden_size + + return w.view(n_heads, attn_in // n_heads // 2, 2, + attn_out).transpose(1, 2).reshape(attn_in, attn_out) + + mapping = self.mistral_mapping + modules = name.split(".") + + # rotary embeds should be sliced + if "wk" in modules: + loaded_weight = permute(loaded_weight, + self.config.num_key_value_heads) + elif "wq" in modules: + loaded_weight = permute(loaded_weight, + self.config.num_attention_heads) + + for item in modules: + if item in mapping and mapping[item] not in name: + name = name.replace(item, mapping[item]) + + return name, loaded_weight diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index 4f4e79d10a677..13fcf6b918603 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -1,12 +1,16 @@ import contextlib +import enum +import json from pathlib import Path from typing import Any, Dict, Optional, Type, Union +from huggingface_hub import file_exists, hf_hub_download from transformers import GenerationConfig, PretrainedConfig from transformers.models.auto.image_processing_auto import ( get_image_processor_config) from transformers.models.auto.modeling_auto import ( MODEL_FOR_CAUSAL_LM_MAPPING_NAMES) +from transformers.utils import CONFIG_NAME as HF_CONFIG_NAME from vllm.envs import VLLM_USE_MODELSCOPE from vllm.logger import init_logger @@ -27,6 +31,8 @@ else: from transformers import AutoConfig +MISTRAL_CONFIG_NAME = "params.json" + logger = init_logger(__name__) _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = { @@ -53,6 +59,20 @@ AutoConfig.register(name, cls) +class ConfigFormat(str, enum.Enum): + AUTO = "auto" + HF = "hf" + MISTRAL = "mistral" + + +def file_or_path_exists(model: Union[str, Path], config_name, revision, + token) -> bool: + if Path(model).exists(): + return (Path(model) / config_name).is_file() + + return file_exists(model, HF_CONFIG_NAME, revision=revision, token=token) + + def get_config( model: Union[str, Path], trust_remote_code: bool, @@ -60,45 +80,68 @@ def get_config( code_revision: Optional[str] = None, rope_scaling: Optional[dict] = None, rope_theta: Optional[float] = None, + config_format: ConfigFormat = ConfigFormat.AUTO, **kwargs, ) -> PretrainedConfig: - # Separate model folder from file path for GGUF models + is_gguf = check_gguf_file(model) if is_gguf: kwargs["gguf_file"] = Path(model).name model = Path(model).parent - config_dict, _ = PretrainedConfig.get_config_dict( - model, revision=revision, code_revision=code_revision, **kwargs) + if config_format == ConfigFormat.AUTO: + if is_gguf or file_or_path_exists(model, + HF_CONFIG_NAME, + revision=revision, + token=kwargs.get("token")): + config_format = ConfigFormat.HF + elif file_or_path_exists(model, + MISTRAL_CONFIG_NAME, + revision=revision, + token=kwargs.get("token")): + config_format = ConfigFormat.MISTRAL + else: + raise ValueError(f"No supported config format found in {model}") + + if config_format == ConfigFormat.HF: + config_dict, _ = PretrainedConfig.get_config_dict( + model, revision=revision, code_revision=code_revision, **kwargs) + + # Use custom model class if it's in our registry + model_type = config_dict.get("model_type") + if model_type in _CONFIG_REGISTRY: + config_class = _CONFIG_REGISTRY[model_type] + config = config_class.from_pretrained(model, + revision=revision, + code_revision=code_revision) + else: + try: + config = AutoConfig.from_pretrained( + model, + trust_remote_code=trust_remote_code, + revision=revision, + code_revision=code_revision, + **kwargs, + ) + except ValueError as e: + if (not trust_remote_code + and "requires you to execute the configuration file" + in str(e)): + err_msg = ( + "Failed to load the model config. If the model " + "is a custom model not yet available in the " + "HuggingFace transformers library, consider setting " + "`trust_remote_code=True` in LLM or using the " + "`--trust-remote-code` flag in the CLI.") + raise RuntimeError(err_msg) from e + else: + raise e - # Use custom model class if it's in our registry - model_type = config_dict.get("model_type") - if model_type in _CONFIG_REGISTRY: - config_class = _CONFIG_REGISTRY[model_type] - config = config_class.from_pretrained(model, - revision=revision, - code_revision=code_revision) + elif config_format == ConfigFormat.MISTRAL: + config = load_params_config(model, revision) else: - try: - config = AutoConfig.from_pretrained( - model, - trust_remote_code=trust_remote_code, - revision=revision, - code_revision=code_revision, - **kwargs) - except ValueError as e: - if (not trust_remote_code - and "requires you to execute the configuration file" - in str(e)): - err_msg = ( - "Failed to load the model config. If the model is a custom " - "model not yet available in the HuggingFace transformers " - "library, consider setting `trust_remote_code=True` in LLM " - "or using the `--trust-remote-code` flag in the CLI.") - raise RuntimeError(err_msg) from e - else: - raise e + raise ValueError(f"Unsupported config format: {config_format}") # Special architecture mapping check for GGUF models if is_gguf: @@ -108,16 +151,70 @@ def get_config( model_type = MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[config.model_type] config.update({"architectures": [model_type]}) - for key, value in [("rope_scaling", rope_scaling), - ("rope_theta", rope_theta)]: + for key, value in [ + ("rope_scaling", rope_scaling), + ("rope_theta", rope_theta), + ]: if value is not None: - logger.info("Updating %s from %r to %r", key, - getattr(config, key, None), value) + logger.info( + "Updating %s from %r to %r", + key, + getattr(config, key, None), + value, + ) config.update({key: value}) return config +def load_params_config(model, revision) -> PretrainedConfig: + # This function loads a params.json config which + # should be used when loading models in mistral format + + config_file_name = "params.json" + + config_path = Path(model) / config_file_name + + if not config_path.is_file(): + config_path = Path( + hf_hub_download(model, config_file_name, revision=revision)) + + with open(config_path, "r") as file: + config_dict = json.load(file) + + config_mapping = { + "dim": "hidden_size", + "norm_eps": "rms_norm_eps", + "n_kv_heads": "num_key_value_heads", + "n_layers": "num_hidden_layers", + "n_heads": "num_attention_heads", + "hidden_dim": "intermediate_size", + } + + def recurse_elems(elem: Any): + if isinstance(elem, dict): + config_dict = {} + for key, value in elem.items(): + key = config_mapping.get(key, key) + config_dict[key] = recurse_elems(value) + return PretrainedConfig(**config_dict) + else: + return elem + + config_dict["model_type"] = config_dict.get("model_type", "transformer") + config_dict["hidden_act"] = config_dict.get("activation", "silu") + config_dict["tie_word_embeddings"] = config_dict.get( + "tie_embeddings", False) + + if config_dict["model_type"] == "transformer": + if "moe" in config_dict: + config_dict["architectures"] = ["MixtralForCausalLM"] + else: + config_dict["architectures"] = ["MistralForCausalLM"] + + return recurse_elems(config_dict) + + def get_hf_image_processor_config( model: Union[str, Path], revision: Optional[str] = None, @@ -134,7 +231,7 @@ def get_hf_image_processor_config( def get_hf_text_config(config: PretrainedConfig): """Get the "sub" config relevant to llm for multi modal models. - No op for pure text models. + No op for pure text models. """ if hasattr(config, "text_config"): # The code operates under the assumption that text_config should have From 12dd715807ccbd7fafbb64d42571792db1cc6497 Mon Sep 17 00:00:00 2001 From: William Lin Date: Fri, 6 Sep 2024 17:48:48 -0700 Subject: [PATCH 37/77] [misc] [doc] [frontend] LLM torch profiler support (#7943) --- docs/source/dev/profiling/profiling_index.rst | 20 +++++++++-- examples/offline_inference_with_profiler.py | 33 +++++++++++++++++++ vllm/engine/llm_engine.py | 6 ++++ vllm/entrypoints/llm.py | 6 ++++ vllm/executor/cpu_executor.py | 6 ++++ vllm/executor/gpu_executor.py | 6 ++++ 6 files changed, 74 insertions(+), 3 deletions(-) create mode 100644 examples/offline_inference_with_profiler.py diff --git a/docs/source/dev/profiling/profiling_index.rst b/docs/source/dev/profiling/profiling_index.rst index af3c78c3b5a55..e22d547293445 100644 --- a/docs/source/dev/profiling/profiling_index.rst +++ b/docs/source/dev/profiling/profiling_index.rst @@ -17,14 +17,28 @@ Traces can be visualized using https://ui.perfetto.dev/. .. tip:: Only send a few requests through vLLM when profiling, as the traces can get quite large. Also, no need to untar the traces, they can be viewed directly. - -Example commands: + +.. tip:: + + To stop the profiler - it flushes out all the profile trace files to the directory. This takes time, for example for about 100 requests worth of data for a llama 70b, it takes about 10 minutes to flush out on a H100. + Set the env variable VLLM_RPC_GET_DATA_TIMEOUT_MS to a big number before you start the server. Say something like 30 minutes. + ``export VLLM_RPC_GET_DATA_TIMEOUT_MS=1800000`` + +Example commands and usage: +=========================== + +Offline Inference: +------------------ + +Refer to `examples/offline_inference_with_profiler.py `_ for an example. + OpenAI Server: +-------------- .. code-block:: bash - VLLM_TORCH_PROFILER_DIR=/mnt/traces/ python -m vllm.entrypoints.openai.api_server --model meta-llama/Meta-Llama-3-70B + VLLM_TORCH_PROFILER_DIR=./vllm_profile python -m vllm.entrypoints.openai.api_server --model meta-llama/Meta-Llama-3-70B benchmark_serving.py: diff --git a/examples/offline_inference_with_profiler.py b/examples/offline_inference_with_profiler.py new file mode 100644 index 0000000000000..906c9502800d8 --- /dev/null +++ b/examples/offline_inference_with_profiler.py @@ -0,0 +1,33 @@ +import os + +from vllm import LLM, SamplingParams + +# enable torch profiler, can also be set on cmd line +os.environ["VLLM_TORCH_PROFILER_DIR"] = "./vllm_profile" + +# Sample prompts. +prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", +] +# Create a sampling params object. +sampling_params = SamplingParams(temperature=0.8, top_p=0.95) + +# Create an LLM. +llm = LLM(model="facebook/opt-125m") + +llm.start_profile() + +# Generate texts from the prompts. The output is a list of RequestOutput objects +# that contain the prompt, generated text, and other information. +outputs = llm.generate(prompts, sampling_params) + +llm.stop_profile() + +# Print the outputs. +for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 50dcb6937eb6f..78ddcd1daaf69 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -1914,6 +1914,12 @@ def check_health(self) -> None: self.tokenizer.check_health() self.model_executor.check_health() + def start_profile(self) -> None: + self.model_executor.start_profile() + + def stop_profile(self) -> None: + self.model_executor.stop_profile() + def is_tracing_enabled(self) -> bool: return self.tracer is not None diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index f587ec3003141..1e4432eaaa665 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -560,6 +560,12 @@ def encode( outputs = self._run_engine(use_tqdm=use_tqdm) return LLMEngine.validate_outputs(outputs, EmbeddingRequestOutput) + def start_profile(self) -> None: + self.llm_engine.start_profile() + + def stop_profile(self) -> None: + self.llm_engine.stop_profile() + # LEGACY def _convert_v1_inputs( self, diff --git a/vllm/executor/cpu_executor.py b/vllm/executor/cpu_executor.py index 21ad43f641685..ec9b24ce1318f 100644 --- a/vllm/executor/cpu_executor.py +++ b/vllm/executor/cpu_executor.py @@ -296,6 +296,12 @@ def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None: for result in parallel_worker_tasks: result.get() + def start_profile(self) -> None: + self.driver_method_invoker(self.driver_worker, "start_profile") + + def stop_profile(self) -> None: + self.driver_method_invoker(self.driver_worker, "stop_profile") + class CPUExecutorAsync(CPUExecutor, ExecutorAsyncBase): diff --git a/vllm/executor/gpu_executor.py b/vllm/executor/gpu_executor.py index 947776e5d6ef4..2185c9cf6cead 100644 --- a/vllm/executor/gpu_executor.py +++ b/vllm/executor/gpu_executor.py @@ -169,6 +169,12 @@ def check_health(self) -> None: # it's running. return + def start_profile(self) -> None: + self.driver_worker.start_profile() + + def stop_profile(self) -> None: + self.driver_worker.stop_profile() + class GPUExecutorAsync(GPUExecutor, ExecutorAsyncBase): From 41e95c5247c9703c3e11f3b563d8bba70ed31aca Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Fri, 6 Sep 2024 21:49:01 -0500 Subject: [PATCH 38/77] [Bugfix] Fix Hermes tool call chat template bug (#8256) Co-authored-by: Kyle Mistele --- examples/tool_chat_template_hermes.jinja | 31 ++++++++++++------------ 1 file changed, 16 insertions(+), 15 deletions(-) diff --git a/examples/tool_chat_template_hermes.jinja b/examples/tool_chat_template_hermes.jinja index b18b463032d4f..0b0902c8e7497 100644 --- a/examples/tool_chat_template_hermes.jinja +++ b/examples/tool_chat_template_hermes.jinja @@ -89,22 +89,23 @@ {{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>' + '\n' }} {%- elif message.role == "assistant" and message.tool_calls is defined %} {{- '<|im_start|>' + message.role }} - {%- for tool_call in message.tool_calls %} - {{- '\n\n' }} - {%- if tool_call.function is defined %} - {%- set tool_call = tool_call.function %} - {%- endif %} - {{- '{' }} - {{- '"name": "' }} - {{- tool_call.name }} - {{- '"}' }} + {%- for tool_call in message.tool_calls %} + {{- '\n\n' }} + {%- if tool_call.function is defined %} + {%- set tool_call = tool_call.function %} + {%- endif %} + {{- '{' }} + {{- '"name": "' }} + {{- tool_call.name }} + {{- '"' }} + {%- if tool_call.arguments is defined %} {{- ', ' }} - {%- if tool_call.arguments is defined %} - {{- '"arguments": ' }} - {{- tool_call.arguments|tojson }} - {%- endif %} - {{- '\n' }} - {%- endfor %} + {{- '"arguments": ' }} + {{- tool_call.arguments|tojson }} + {%- endif %} + {{- '}' }} + {{- '\n' }} + {%- endfor %} {{- '<|im_end|>\n' }} {%- elif message.role == "tool" %} {%- if loop.previtem and loop.previtem.role != "tool" %} From 2f707fcb35c5bc4b9164cf2bbce0254a72f7348b Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Sat, 7 Sep 2024 10:57:24 +0800 Subject: [PATCH 39/77] [Model] Multi-input support for LLaVA (#8238) --- docs/source/models/supported_models.rst | 16 +- tests/conftest.py | 12 +- .../distributed/test_multimodal_broadcast.py | 6 +- tests/models/test_llava.py | 141 ++++++++++++++++-- vllm/model_executor/models/clip.py | 2 +- vllm/model_executor/models/internvl.py | 2 +- vllm/model_executor/models/llava.py | 32 ++-- vllm/model_executor/models/llava_next.py | 4 +- vllm/model_executor/models/phi3v.py | 4 +- vllm/model_executor/models/siglip.py | 2 +- 10 files changed, 176 insertions(+), 45 deletions(-) diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst index 0c0a54281e3f3..fe01e1681353e 100644 --- a/docs/source/models/supported_models.rst +++ b/docs/source/models/supported_models.rst @@ -219,7 +219,7 @@ Multimodal Language Models - * - :code:`LlavaForConditionalGeneration` - LLaVA-1.5 - - Image\ :sup:`E` + - Image\ :sup:`E+` - :code:`llava-hf/llava-1.5-7b-hf`, :code:`llava-hf/llava-1.5-13b-hf`, etc. - * - :code:`LlavaNextForConditionalGeneration` @@ -227,6 +227,11 @@ Multimodal Language Models - Image\ :sup:`E+` - :code:`llava-hf/llava-v1.6-mistral-7b-hf`, :code:`llava-hf/llava-v1.6-vicuna-7b-hf`, etc. - + * - :code:`MiniCPMV` + - MiniCPM-V + - Image\ :sup:`+` + - :code:`openbmb/MiniCPM-V-2` (see note), :code:`openbmb/MiniCPM-Llama3-V-2_5`, :code:`openbmb/MiniCPM-V-2_6`, etc. + - * - :code:`PaliGemmaForConditionalGeneration` - PaliGemma - Image\ :sup:`E` @@ -237,14 +242,9 @@ Multimodal Language Models - Image\ :sup:`E+` - :code:`microsoft/Phi-3-vision-128k-instruct`, :code:`microsoft/Phi-3.5-vision-instruct` etc. - - * - :code:`MiniCPMV` - - MiniCPM-V - - Image\ :sup:`+` - - :code:`openbmb/MiniCPM-V-2` (see note), :code:`openbmb/MiniCPM-Llama3-V-2_5`, :code:`openbmb/MiniCPM-V-2_6`, etc. - - * - :code:`QWenLMHeadModel` - - Qwen - - Image + - Qwen-VL + - Image\ :sup:`E` - :code:`Qwen/Qwen-VL`, :code:`Qwen/Qwen-VL-Chat`, etc. - * - :code:`UltravoxModel` diff --git a/tests/conftest.py b/tests/conftest.py index e66a14598c343..cd0091b7cba68 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -278,7 +278,7 @@ def __init__( def generate( self, prompts: List[str], - images: Optional[List[Image.Image]] = None, + images: Optional[PromptImageInput] = None, **kwargs: Any, ) -> List[Tuple[List[List[int]], List[str]]]: if images: @@ -314,7 +314,7 @@ def generate_greedy( self, prompts: List[str], max_tokens: int, - images: Optional[List[Image.Image]] = None, + images: Optional[PromptImageInput] = None, **kwargs: Any, ) -> List[Tuple[List[int], str]]: outputs = self.generate(prompts, @@ -351,7 +351,7 @@ def generate_greedy_logprobs( self, prompts: List[str], max_tokens: int, - images: Optional[List[Image.Image]] = None, + images: Optional[PromptImageInput] = None, **kwargs: Any, ) -> List[List[torch.Tensor]]: all_logprobs: List[List[torch.Tensor]] = [] @@ -433,8 +433,8 @@ def generate_greedy_logprobs_limit( prompts: List[str], max_tokens: int, num_logprobs: int, - images: Optional[List[Image.Image]] = None, - audios: Optional[List[Tuple[np.ndarray, int]]] = None, + images: Optional[PromptImageInput] = None, + audios: Optional[PromptAudioInput] = None, **kwargs: Any, ) -> List[Tuple[List[int], str, List[Dict[int, float]]]]: all_logprobs: List[List[Dict[int, float]]] = [] @@ -671,7 +671,7 @@ def generate_greedy( self, prompts: List[str], max_tokens: int, - images: Optional[List[Image.Image]] = None, + images: Optional[PromptImageInput] = None, ) -> List[Tuple[List[int], str]]: greedy_params = SamplingParams(temperature=0.0, max_tokens=max_tokens) outputs = self.generate(prompts, greedy_params, images=images) diff --git a/tests/distributed/test_multimodal_broadcast.py b/tests/distributed/test_multimodal_broadcast.py index e7723a7ae2480..73ef863c2f193 100644 --- a/tests/distributed/test_multimodal_broadcast.py +++ b/tests/distributed/test_multimodal_broadcast.py @@ -35,9 +35,11 @@ def test_models(hf_runner, vllm_runner, image_assets, model: str, if model.startswith("llava-hf/llava-1.5"): from ..models.test_llava import models, run_test elif model.startswith("llava-hf/llava-v1.6"): - from ..models.test_llava_next import models, run_test + from ..models.test_llava_next import run_test # type: ignore[no-redef] + from ..models.test_llava_next import models elif model.startswith("facebook/chameleon"): - from ..models.test_chameleon import models, run_test + from ..models.test_chameleon import run_test # type: ignore[no-redef] + from ..models.test_chameleon import models else: raise NotImplementedError(f"Unsupported model: {model}") diff --git a/tests/models/test_llava.py b/tests/models/test_llava.py index 9d7da5f803ea4..84ca23f6222a9 100644 --- a/tests/models/test_llava.py +++ b/tests/models/test_llava.py @@ -1,4 +1,4 @@ -from typing import List, Optional, Tuple, Type +from typing import List, Optional, Tuple, Type, overload import pytest from transformers import (AutoConfig, AutoModelForVision2Seq, AutoTokenizer, @@ -8,11 +8,14 @@ from vllm.sequence import SampleLogprobs from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE -from ..conftest import IMAGE_ASSETS, HfRunner, VllmRunner, _ImageAssets +from ..conftest import (IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner, + _ImageAssets) from .utils import check_logprobs_close pytestmark = pytest.mark.vlm +_LIMIT_IMAGE_PER_PROMPT = 4 + HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({ "stop_sign": "USER: \nWhat's the content of the image?\nASSISTANT:", @@ -52,6 +55,7 @@ def vllm_to_hf_output(vllm_output: Tuple[List[int], str, return hf_output_ids, hf_output_str, out_logprobs +@overload def run_test( hf_runner: Type[HfRunner], vllm_runner: Type[VllmRunner], @@ -64,6 +68,78 @@ def run_test( num_logprobs: int, tensor_parallel_size: int, distributed_executor_backend: Optional[str] = None, +): + ... + + +@overload +def run_test( + hf_runner: Type[HfRunner], + vllm_runner: Type[VllmRunner], + image_assets: _ImageAssets, + model: str, + *, + sizes: List[Tuple[int, int]], + dtype: str, + max_tokens: int, + num_logprobs: int, + tensor_parallel_size: int, + distributed_executor_backend: Optional[str] = None, +): + ... + + +def run_test( + hf_runner: Type[HfRunner], + vllm_runner: Type[VllmRunner], + image_assets: _ImageAssets, + model: str, + *, + size_factors: Optional[List[float]] = None, + sizes: Optional[List[Tuple[int, int]]] = None, + dtype: str, + max_tokens: int, + num_logprobs: int, + tensor_parallel_size: int, + distributed_executor_backend: Optional[str] = None, +): + images = [asset.pil_image for asset in image_assets] + + if size_factors is not None: + inputs_per_image = [( + [prompt for _ in size_factors], + [rescale_image_size(image, factor) for factor in size_factors], + ) for image, prompt in zip(images, HF_IMAGE_PROMPTS)] + elif sizes is not None: + inputs_per_image = [( + [prompt for _ in sizes], + [image.resize(size) for size in sizes], + ) for image, prompt in zip(images, HF_IMAGE_PROMPTS)] + else: + raise ValueError("You must provide either `size_factors` or `sizes`") + + _run_test(hf_runner, + vllm_runner, + inputs_per_image, + model, + dtype=dtype, + max_tokens=max_tokens, + num_logprobs=num_logprobs, + tensor_parallel_size=tensor_parallel_size, + distributed_executor_backend=distributed_executor_backend) + + +def _run_test( + hf_runner: Type[HfRunner], + vllm_runner: Type[VllmRunner], + inputs: List[Tuple[List[str], PromptImageInput]], + model: str, + *, + dtype: str, + max_tokens: int, + num_logprobs: int, + tensor_parallel_size: int, + distributed_executor_backend: Optional[str] = None, ): """Inference result should be the same between hf and vllm. @@ -85,13 +161,6 @@ def run_test( else: mantis_processor = None - images = [asset.pil_image for asset in image_assets] - - inputs_per_image = [( - [prompt for _ in size_factors], - [rescale_image_size(image, factor) for factor in size_factors], - ) for image, prompt in zip(images, HF_IMAGE_PROMPTS)] - # NOTE: take care of the order. run vLLM first, and then run HF. # vLLM needs a fresh new process without cuda initialization. # if we run HF first, the cuda initialization will be done and it @@ -100,15 +169,18 @@ def run_test( # max_model_len should be greater than image_feature_size with vllm_runner(model, dtype=dtype, + max_model_len=4096, tensor_parallel_size=tensor_parallel_size, distributed_executor_backend=distributed_executor_backend, - enforce_eager=True) as vllm_model: + enforce_eager=True, + limit_mm_per_prompt={"image": _LIMIT_IMAGE_PER_PROMPT + }) as vllm_model: vllm_outputs_per_image = [ vllm_model.generate_greedy_logprobs(prompts, max_tokens, num_logprobs=num_logprobs, images=images) - for prompts, images in inputs_per_image + for prompts, images in inputs ] if mantis_processor is not None: @@ -131,7 +203,7 @@ def process(hf_inputs: BatchEncoding): max_tokens, num_logprobs=num_logprobs, images=images) - for prompts, images in inputs_per_image + for prompts, images in inputs ] for hf_outputs, vllm_outputs in zip(hf_outputs_per_image, @@ -181,6 +253,51 @@ def test_models(hf_runner, vllm_runner, image_assets, model, size_factors, ) +@pytest.mark.parametrize("model", models) +@pytest.mark.parametrize("dtype", ["half"]) +@pytest.mark.parametrize("max_tokens", [128]) +@pytest.mark.parametrize("num_logprobs", [5]) +def test_models_multiple_image_inputs(hf_runner, vllm_runner, image_assets, + model, dtype, max_tokens, + num_logprobs) -> None: + stop_sign = image_assets[0].pil_image + cherry_blossom = image_assets[1].pil_image + + inputs = [( + [ + "USER: \nDescribe 2 images.\nASSISTANT:", + "USER: \nDescribe 2 images.\nASSISTANT:", + "USER: \nDescribe 4 images.\nASSISTANT:", # noqa: E501 + "USER: \nWhat is the season?\nASSISTANT:", + ], + [ + [stop_sign, cherry_blossom], + # Images with different sizes and aspect-ratios + [ + rescale_image_size(stop_sign, 0.1), + stop_sign, + ], + [ + stop_sign, + rescale_image_size(stop_sign, 0.25), + cherry_blossom.resize((183, 488)), + cherry_blossom.resize((488, 183)) + ], + cherry_blossom, + ])] + + _run_test( + hf_runner, + vllm_runner, + inputs, + model, + dtype=dtype, + max_tokens=max_tokens, + num_logprobs=num_logprobs, + tensor_parallel_size=1, + ) + + @pytest.mark.parametrize("model", models) def test_context_length_too_short(vllm_runner, image_assets, model): images = [asset.pil_image for asset in image_assets] diff --git a/vllm/model_executor/models/clip.py b/vllm/model_executor/models/clip.py index b581a501e3333..70f1522ae2524 100644 --- a/vllm/model_executor/models/clip.py +++ b/vllm/model_executor/models/clip.py @@ -105,7 +105,7 @@ def input_processor_for_clip( if isinstance(image_data, Image.Image): image_feature_size = get_clip_image_feature_size(hf_config) elif isinstance(image_data, torch.Tensor): - image_feature_size = image_data.shape[0] + num_images, image_feature_size, hidden_size = image_data.shape else: raise TypeError(f"Invalid image type: {type(image_data)}") else: diff --git a/vllm/model_executor/models/internvl.py b/vllm/model_executor/models/internvl.py index d317fdce3ba68..10fbb5663d274 100644 --- a/vllm/model_executor/models/internvl.py +++ b/vllm/model_executor/models/internvl.py @@ -209,7 +209,7 @@ def input_processor_for_internvl(ctx: InputContext, llm_inputs: LLMInputs): image_feature_size = num_blocks * num_patches elif isinstance(image_data, torch.Tensor): - image_feature_size = image_data.shape[0] + num_images, image_feature_size, hidden_size = image_data.shape else: raise TypeError(f"Invalid image type: {type(image_data)}") diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py index 43c485bdf3668..7a6c991fb133a 100644 --- a/vllm/model_executor/models/llava.py +++ b/vllm/model_executor/models/llava.py @@ -4,6 +4,7 @@ import torch import torch.nn as nn +from PIL import Image from transformers import CLIPVisionConfig, LlavaConfig, SiglipVisionConfig from vllm.attention import AttentionMetadata @@ -16,6 +17,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.sequence import IntermediateTensors +from vllm.utils import is_list_of from .clip import (CLIPVisionModel, dummy_image_for_clip, dummy_seq_data_for_clip, get_max_clip_image_tokens, @@ -24,7 +26,7 @@ from .siglip import (SiglipVisionModel, dummy_image_for_siglip, dummy_seq_data_for_siglip, get_max_siglip_image_tokens, input_processor_for_siglip) -from .utils import (filter_weights, init_vllm_registered_model, +from .utils import (filter_weights, flatten_bn, init_vllm_registered_model, merge_multimodal_embeddings) @@ -133,7 +135,18 @@ def input_processor_for_llava(ctx: InputContext, llm_inputs: LLMInputs): hf_config = ctx.get_hf_config(LlavaConfig) vision_config = hf_config.vision_config - image_feature_size = get_max_llava_image_tokens(ctx) + image_data = multi_modal_data["image"] + if isinstance(image_data, Image.Image): + image_feature_size = get_max_llava_image_tokens(ctx) + elif is_list_of(image_data, Image.Image): + image_feature_size = [get_max_llava_image_tokens(ctx) + ] * len(image_data) + elif isinstance(image_data, torch.Tensor): + num_images, image_feature_size, hidden_size = image_data.shape + elif is_list_of(image_data, torch.Tensor): + image_feature_size = [item.shape[1] for item in image_data] + else: + raise TypeError(f"Invalid image type: {type(image_data)}") if isinstance(vision_config, CLIPVisionConfig): return input_processor_for_clip( @@ -230,29 +243,24 @@ def _parse_and_validate_image_input( return None if pixel_values is not None: - if not isinstance(pixel_values, torch.Tensor): + if not isinstance(pixel_values, (torch.Tensor, list)): raise ValueError("Incorrect type of pixel values. " f"Got type: {type(pixel_values)}") - # Remove the N dimension until multiple images are supported. - pixel_values = pixel_values.squeeze(1) - return LlavaImagePixelInputs( type="pixel_values", - data=self._validate_pixel_values(pixel_values), + data=self._validate_pixel_values( + flatten_bn(pixel_values, concat=True)), ) if image_embeds is not None: - if not isinstance(image_embeds, torch.Tensor): + if not isinstance(image_embeds, (torch.Tensor, list)): raise ValueError("Incorrect type of image embeddings. " f"Got type: {type(image_embeds)}") - # Remove the N dimension until multiple images are supported. - image_embeds = image_embeds.squeeze(1) - return LlavaImageEmbeddingInputs( type="image_embeds", - data=image_embeds, + data=flatten_bn(image_embeds, concat=True), ) raise AssertionError("This line should be unreachable.") diff --git a/vllm/model_executor/models/llava_next.py b/vllm/model_executor/models/llava_next.py index 5a179e9603710..c6bd46dd7eda9 100644 --- a/vllm/model_executor/models/llava_next.py +++ b/vllm/model_executor/models/llava_next.py @@ -234,7 +234,9 @@ def input_processor_for_llava_next(ctx: InputContext, llm_inputs: LLMInputs): for img in image_data ] elif isinstance(image_data, torch.Tensor): - image_feature_size = image_data.shape[0] + num_images, image_feature_size, hidden_size = image_data.shape + elif is_list_of(image_data, torch.Tensor): + image_feature_size = [item.shape[1] for item in image_data] else: raise TypeError(f"Invalid image type: {type(image_data)}") diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py index c449e0fc759a3..6f17f571ccaea 100644 --- a/vllm/model_executor/models/phi3v.py +++ b/vllm/model_executor/models/phi3v.py @@ -424,7 +424,9 @@ def input_processor_for_phi3v(ctx: InputContext, llm_inputs: LLMInputs): input_width=w, input_height=h)) elif isinstance(image_data, torch.Tensor): - image_feature_size = image_data.shape[0] + num_images, image_feature_size, hidden_size = image_data.shape + elif is_list_of(image_data, torch.Tensor): + image_feature_size = [item.shape[1] for item in image_data] else: raise TypeError(f"Invalid image type: {type(image_data)}") diff --git a/vllm/model_executor/models/siglip.py b/vllm/model_executor/models/siglip.py index 0bee75e2f0cbb..fb4c30c1a13f9 100644 --- a/vllm/model_executor/models/siglip.py +++ b/vllm/model_executor/models/siglip.py @@ -110,7 +110,7 @@ def input_processor_for_siglip( if isinstance(image_data, Image.Image): image_feature_size = get_siglip_image_feature_size(hf_config) elif isinstance(image_data, torch.Tensor): - image_feature_size = image_data.shape[0] + num_images, image_feature_size, hidden_size = image_data.shape else: raise TypeError(f"Invalid image type: {type(image_data)}") else: From 795b662cffe79fa0fa9a3f13a65113abdb4f96a9 Mon Sep 17 00:00:00 2001 From: Wei-Sheng Chin Date: Fri, 6 Sep 2024 20:18:16 -0700 Subject: [PATCH 40/77] Enable Random Prefix Caching in Serving Profiling Tool (benchmark_serving.py) (#8241) --- benchmarks/benchmark_serving.py | 27 +++++++++++++++++++++++---- 1 file changed, 23 insertions(+), 4 deletions(-) diff --git a/benchmarks/benchmark_serving.py b/benchmarks/benchmark_serving.py index bdfa81be4208e..9ba3f649810b7 100644 --- a/benchmarks/benchmark_serving.py +++ b/benchmarks/benchmark_serving.py @@ -195,8 +195,16 @@ def sample_sonnet_requests( def sample_random_requests( - input_len: int, output_len: int, num_prompts: int, range_ratio: float, - tokenizer: PreTrainedTokenizerBase) -> List[Tuple[str, int, int]]: + prefix_len: int, + input_len: int, + output_len: int, + num_prompts: int, + range_ratio: float, + tokenizer: PreTrainedTokenizerBase, +) -> List[Tuple[str, int, int]]: + prefix_token_ids = np.random.randint(0, + tokenizer.vocab_size, + size=prefix_len).tolist() input_lens = np.random.randint( int(input_len * range_ratio), @@ -211,10 +219,12 @@ def sample_random_requests( offsets = np.random.randint(0, tokenizer.vocab_size, size=num_prompts) input_requests = [] for i in range(num_prompts): - prompt = tokenizer.decode([(offsets[i] + i + j) % tokenizer.vocab_size + prompt = tokenizer.decode(prefix_token_ids + + [(offsets[i] + i + j) % tokenizer.vocab_size for j in range(input_lens[i])]) + input_requests.append( - (prompt, int(input_lens[i]), int(output_lens[i]))) + (prompt, int(prefix_len + input_lens[i]), int(output_lens[i]))) return input_requests @@ -567,6 +577,7 @@ def main(args: argparse.Namespace): elif args.dataset_name == "random": input_requests = sample_random_requests( + prefix_len=args.random_prefix_len, input_len=args.random_input_len, output_len=args.random_output_len, num_prompts=args.num_prompts, @@ -765,6 +776,14 @@ def main(args: argparse.Namespace): help="Range of sampled ratio of input/output length, " "used only for random sampling.", ) + parser.add_argument( + "--random-prefix-len", + type=int, + default=0, + help="Number of fixed prefix tokens before random " + " context. The length range of context in a random " + " request is [random-prefix-len, " + " random-prefix-len + random-prefix-len * random-range-ratio).") parser.add_argument( "--request-rate", type=float, From ce2702a92356b69ec1ea35ecd46263ddf98e8e2c Mon Sep 17 00:00:00 2001 From: youkaichao Date: Fri, 6 Sep 2024 22:40:46 -0700 Subject: [PATCH 41/77] [tpu][misc] fix typo (#8260) --- tests/compile/test_wrapper.py | 4 ++-- vllm/compilation/wrapper.py | 2 +- vllm/worker/tpu_model_runner.py | 4 ++-- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/compile/test_wrapper.py b/tests/compile/test_wrapper.py index cef516ade27eb..3668c1fab6b89 100644 --- a/tests/compile/test_wrapper.py +++ b/tests/compile/test_wrapper.py @@ -2,7 +2,7 @@ import torch -from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispacther +from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher class MyMod(torch.nn.Module): @@ -13,7 +13,7 @@ def forward(self, x: torch.Tensor, cache: Optional[torch.Tensor] = None): return x * 2 -class MyWrapper(TorchCompileWrapperWithCustomDispacther): +class MyWrapper(TorchCompileWrapperWithCustomDispatcher): def __init__(self, model): self.model = model diff --git a/vllm/compilation/wrapper.py b/vllm/compilation/wrapper.py index c3d863299dd06..e923bd36ccc08 100644 --- a/vllm/compilation/wrapper.py +++ b/vllm/compilation/wrapper.py @@ -10,7 +10,7 @@ import vllm.envs as envs -class TorchCompileWrapperWithCustomDispacther: +class TorchCompileWrapperWithCustomDispatcher: """ A wrapper class for torch.compile, with a custom dispatch logic. Subclasses should: diff --git a/vllm/worker/tpu_model_runner.py b/vllm/worker/tpu_model_runner.py index 684c54b7d8139..db306bc743d3a 100644 --- a/vllm/worker/tpu_model_runner.py +++ b/vllm/worker/tpu_model_runner.py @@ -11,7 +11,7 @@ import torch_xla.runtime as xr from vllm.attention import AttentionMetadata, get_attn_backend -from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispacther +from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, ModelConfig, ParallelConfig, SchedulerConfig) from vllm.logger import init_logger @@ -611,7 +611,7 @@ def _execute_model(*args): return [SamplerOutput(sampler_outputs)] -class ModelWrapper(TorchCompileWrapperWithCustomDispacther): +class ModelWrapper(TorchCompileWrapperWithCustomDispatcher): def __init__(self, model: nn.Module): self.model = model From 9f68e00d27b0f8252549be3adbb47c5b735a8103 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Sat, 7 Sep 2024 16:02:39 +0800 Subject: [PATCH 42/77] [Bugfix] Fix broken OpenAI tensorizer test (#8258) --- tests/utils.py | 12 ++-- vllm/engine/arg_utils.py | 72 ++++++++++--------- vllm/model_executor/model_loader/loader.py | 30 +++++++- .../model_executor/model_loader/tensorizer.py | 7 ++ 4 files changed, 81 insertions(+), 40 deletions(-) diff --git a/tests/utils.py b/tests/utils.py index 04067ef372ac2..6e5bc05b3901a 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -20,7 +20,7 @@ init_distributed_environment) from vllm.engine.arg_utils import AsyncEngineArgs from vllm.entrypoints.openai.cli_args import make_arg_parser -from vllm.model_executor.model_loader.loader import DefaultModelLoader +from vllm.model_executor.model_loader.loader import get_model_loader from vllm.platforms import current_platform from vllm.utils import FlexibleArgumentParser, get_open_port, is_hip @@ -89,11 +89,11 @@ def __init__(self, is_local = os.path.isdir(model) if not is_local: engine_args = AsyncEngineArgs.from_cli_args(args) - engine_config = engine_args.create_engine_config() - dummy_loader = DefaultModelLoader(engine_config.load_config) - dummy_loader._prepare_weights(engine_config.model_config.model, - engine_config.model_config.revision, - fall_back_to_pt=True) + model_config = engine_args.create_model_config() + load_config = engine_args.create_load_config() + + model_loader = get_model_loader(load_config) + model_loader.download_model(model_config) env = os.environ.copy() # the current process might initialize cuda, diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 7620093660b43..9bc03948d3845 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -771,33 +771,8 @@ def from_cli_args(cls, args: argparse.Namespace): engine_args = cls(**{attr: getattr(args, attr) for attr in attrs}) return engine_args - def create_engine_config(self) -> EngineConfig: - # gguf file needs a specific model loader and doesn't use hf_repo - if check_gguf_file(self.model): - self.quantization = self.load_format = "gguf" - - # bitsandbytes quantization needs a specific model loader - # so we make sure the quant method and the load format are consistent - if (self.quantization == "bitsandbytes" or - self.qlora_adapter_name_or_path is not None) and \ - self.load_format != "bitsandbytes": - raise ValueError( - "BitsAndBytes quantization and QLoRA adapter only support " - f"'bitsandbytes' load format, but got {self.load_format}") - - if (self.load_format == "bitsandbytes" or - self.qlora_adapter_name_or_path is not None) and \ - self.quantization != "bitsandbytes": - raise ValueError( - "BitsAndBytes load format and QLoRA adapter only support " - f"'bitsandbytes' quantization, but got {self.quantization}") - - assert self.cpu_offload_gb >= 0, ( - "CPU offload space must be non-negative" - f", but got {self.cpu_offload_gb}") - - device_config = DeviceConfig(device=self.device) - model_config = ModelConfig( + def create_model_config(self) -> ModelConfig: + return ModelConfig( model=self.model, tokenizer=self.tokenizer, tokenizer_mode=self.tokenizer_mode, @@ -825,6 +800,42 @@ def create_engine_config(self) -> EngineConfig: config_format=self.config_format, ) + def create_load_config(self) -> LoadConfig: + return LoadConfig( + load_format=self.load_format, + download_dir=self.download_dir, + model_loader_extra_config=self.model_loader_extra_config, + ignore_patterns=self.ignore_patterns, + ) + + def create_engine_config(self) -> EngineConfig: + # gguf file needs a specific model loader and doesn't use hf_repo + if check_gguf_file(self.model): + self.quantization = self.load_format = "gguf" + + # bitsandbytes quantization needs a specific model loader + # so we make sure the quant method and the load format are consistent + if (self.quantization == "bitsandbytes" or + self.qlora_adapter_name_or_path is not None) and \ + self.load_format != "bitsandbytes": + raise ValueError( + "BitsAndBytes quantization and QLoRA adapter only support " + f"'bitsandbytes' load format, but got {self.load_format}") + + if (self.load_format == "bitsandbytes" or + self.qlora_adapter_name_or_path is not None) and \ + self.quantization != "bitsandbytes": + raise ValueError( + "BitsAndBytes load format and QLoRA adapter only support " + f"'bitsandbytes' quantization, but got {self.quantization}") + + assert self.cpu_offload_gb >= 0, ( + "CPU offload space must be non-negative" + f", but got {self.cpu_offload_gb}") + + device_config = DeviceConfig(device=self.device) + model_config = self.create_model_config() + cache_config = CacheConfig( block_size=self.block_size if self.device != "neuron" else self.max_model_len, # neuron needs block_size = max_model_len @@ -967,12 +978,7 @@ def create_engine_config(self) -> EngineConfig: self.model_loader_extra_config[ "qlora_adapter_name_or_path"] = self.qlora_adapter_name_or_path - load_config = LoadConfig( - load_format=self.load_format, - download_dir=self.download_dir, - model_loader_extra_config=self.model_loader_extra_config, - ignore_patterns=self.ignore_patterns, - ) + load_config = self.create_load_config() prompt_adapter_config = PromptAdapterConfig( max_prompt_adapters=self.max_prompt_adapters, diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py index bcc866a194637..f59eb805ea907 100644 --- a/vllm/model_executor/model_loader/loader.py +++ b/vllm/model_executor/model_loader/loader.py @@ -185,6 +185,11 @@ class BaseModelLoader(ABC): def __init__(self, load_config: LoadConfig): self.load_config = load_config + @abstractmethod + def download_model(self, model_config: ModelConfig) -> None: + """Download a model so that it can be immediately loaded.""" + raise NotImplementedError + @abstractmethod def load_model(self, *, model_config: ModelConfig, device_config: DeviceConfig, @@ -193,7 +198,7 @@ def load_model(self, *, model_config: ModelConfig, scheduler_config: SchedulerConfig, cache_config: CacheConfig) -> nn.Module: """Load a model with the given configurations.""" - ... + raise NotImplementedError class DefaultModelLoader(BaseModelLoader): @@ -335,6 +340,11 @@ def _xla_weights_iterator(iterator: Generator): weights_iterator = _xla_weights_iterator(weights_iterator) return weights_iterator + def download_model(self, model_config: ModelConfig) -> None: + self._prepare_weights(model_config.model, + model_config.revision, + fall_back_to_pt=True) + def load_model(self, *, model_config: ModelConfig, device_config: DeviceConfig, lora_config: Optional[LoRAConfig], @@ -377,6 +387,9 @@ def __init__(self, load_config: LoadConfig): raise ValueError(f"Model loader extra config is not supported for " f"load format {load_config.load_format}") + def download_model(self, model_config: ModelConfig) -> None: + pass # Nothing to download + def load_model(self, *, model_config: ModelConfig, device_config: DeviceConfig, lora_config: Optional[LoRAConfig], @@ -467,6 +480,12 @@ def _load_model_serialized( model = load_with_tensorizer(tensorizer_config, **extra_kwargs) return model.eval() + def download_model(self, model_config: ModelConfig) -> None: + self.tensorizer_config.verify_with_model_config(model_config) + + with self.tensorizer_config.open_stream(): + pass + def load_model(self, *, model_config: ModelConfig, device_config: DeviceConfig, lora_config: Optional[LoRAConfig], @@ -568,6 +587,9 @@ def _prepare_weights(self, model_name_or_path: str, ignore_patterns=self.load_config.ignore_patterns, ) + def download_model(self, model_config: ModelConfig) -> None: + self._prepare_weights(model_config.model, model_config.revision) + def load_model(self, *, model_config: ModelConfig, device_config: DeviceConfig, lora_config: Optional[LoRAConfig], @@ -995,6 +1017,9 @@ def _load_weights(self, model_config: ModelConfig, set_weight_attrs( param, {"matmul_state": [None] * len(quant_states)}) + def download_model(self, model_config: ModelConfig) -> None: + self._prepare_weights(model_config.model, model_config.revision) + def load_model(self, *, model_config: ModelConfig, device_config: DeviceConfig, lora_config: Optional[LoRAConfig], @@ -1070,6 +1095,9 @@ def _get_weights_iterator( return gguf_quant_weights_iterator(model_name_or_path, gguf_to_hf_name_map) + def download_model(self, model_config: ModelConfig) -> None: + self._prepare_weights(model_config.model) + def load_model(self, *, model_config: ModelConfig, device_config: DeviceConfig, lora_config: Optional[LoRAConfig], diff --git a/vllm/model_executor/model_loader/tensorizer.py b/vllm/model_executor/model_loader/tensorizer.py index b009ad8c882d4..3aac5cd2b43a5 100644 --- a/vllm/model_executor/model_loader/tensorizer.py +++ b/vllm/model_executor/model_loader/tensorizer.py @@ -99,6 +99,13 @@ def verify_with_model_config(self, model_config: "ModelConfig") -> None: "Loading a model using Tensorizer with quantization on vLLM" " is unstable and may lead to errors.") + def open_stream(self, tensorizer_args: Optional["TensorizerArgs"] = None): + if tensorizer_args is None: + tensorizer_args = self._construct_tensorizer_args() + + return open_stream(self.tensorizer_uri, + **tensorizer_args.stream_params) + def load_with_tensorizer(tensorizer_config: TensorizerConfig, **extra_kwargs) -> nn.Module: From e807125936a9db796746b67ba72c222b5c26582e Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Sat, 7 Sep 2024 16:38:23 +0800 Subject: [PATCH 43/77] [Model][VLM] Support multi-images inputs for InternVL2 models (#8201) --- docs/source/models/supported_models.rst | 2 +- ...e_inference_vision_language_multi_image.py | 94 +++++++++++++++---- tests/models/test_internvl.py | 92 ++++++++++++++---- tests/models/test_phi3v.py | 8 +- vllm/model_executor/models/internvl.py | 60 +++++++++--- 5 files changed, 199 insertions(+), 57 deletions(-) diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst index fe01e1681353e..1bb3a448f2c92 100644 --- a/docs/source/models/supported_models.rst +++ b/docs/source/models/supported_models.rst @@ -214,7 +214,7 @@ Multimodal Language Models - * - :code:`InternVLChatModel` - InternVL2 - - Image\ :sup:`E` + - Image\ :sup:`E+` - :code:`OpenGVLab/InternVL2-4B`, :code:`OpenGVLab/InternVL2-8B`, etc. - * - :code:`LlavaForConditionalGeneration` diff --git a/examples/offline_inference_vision_language_multi_image.py b/examples/offline_inference_vision_language_multi_image.py index 73543ab5da2b4..dd84627b9dc58 100644 --- a/examples/offline_inference_vision_language_multi_image.py +++ b/examples/offline_inference_vision_language_multi_image.py @@ -6,7 +6,9 @@ from argparse import Namespace from typing import List -from vllm import LLM +from transformers import AutoTokenizer + +from vllm import LLM, SamplingParams from vllm.multimodal.utils import fetch_image from vllm.utils import FlexibleArgumentParser @@ -17,36 +19,84 @@ ] -def _load_phi3v(image_urls: List[str]): - return LLM( +def load_phi3v(question, image_urls: List[str]): + llm = LLM( model="microsoft/Phi-3.5-vision-instruct", trust_remote_code=True, max_model_len=4096, limit_mm_per_prompt={"image": len(image_urls)}, ) - - -def run_phi3v_generate(question: str, image_urls: List[str]): - llm = _load_phi3v(image_urls) - placeholders = "\n".join(f"<|image_{i}|>" for i, _ in enumerate(image_urls, start=1)) prompt = f"<|user|>\n{placeholders}\n{question}<|end|>\n<|assistant|>\n" + stop_token_ids = None + return llm, prompt, stop_token_ids - outputs = llm.generate({ - "prompt": prompt, - "multi_modal_data": { - "image": [fetch_image(url) for url in image_urls] + +def load_internvl(question, image_urls: List[str]): + model_name = "OpenGVLab/InternVL2-2B" + + llm = LLM( + model=model_name, + trust_remote_code=True, + max_num_seqs=5, + max_model_len=4096, + limit_mm_per_prompt={"image": len(image_urls)}, + ) + + placeholders = "\n".join(f"Image-{i}: \n" + for i, _ in enumerate(image_urls, start=1)) + messages = [{'role': 'user', 'content': f"{placeholders}\n{question}"}] + + tokenizer = AutoTokenizer.from_pretrained(model_name, + trust_remote_code=True) + prompt = tokenizer.apply_chat_template(messages, + tokenize=False, + add_generation_prompt=True) + + # Stop tokens for InternVL + # models variants may have different stop tokens + # please refer to the model card for the correct "stop words": + # https://huggingface.co/OpenGVLab/InternVL2-2B#service + stop_tokens = ["<|endoftext|>", "<|im_start|>", "<|im_end|>", "<|end|>"] + stop_token_ids = [tokenizer.convert_tokens_to_ids(i) for i in stop_tokens] + return llm, prompt, stop_token_ids + + +model_example_map = { + "phi3_v": load_phi3v, + "internvl_chat": load_internvl, +} + + +def run_generate(model, question: str, image_urls: List[str]): + llm, prompt, stop_token_ids = model_example_map[model](question, + image_urls) + + sampling_params = SamplingParams(temperature=0.0, + max_tokens=128, + stop_token_ids=stop_token_ids) + + outputs = llm.generate( + { + "prompt": prompt, + "multi_modal_data": { + "image": [fetch_image(url) for url in image_urls] + }, }, - }) + sampling_params=sampling_params) for o in outputs: generated_text = o.outputs[0].text print(generated_text) -def run_phi3v_chat(question: str, image_urls: List[str]): - llm = _load_phi3v(image_urls) +def run_chat(model: str, question: str, image_urls: List[str]): + llm, _, stop_token_ids = model_example_map[model](question, image_urls) + + sampling_params = SamplingParams(temperature=0.0, + max_tokens=128, + stop_token_ids=stop_token_ids) outputs = llm.chat([{ "role": @@ -63,7 +113,8 @@ def run_phi3v_chat(question: str, image_urls: List[str]): }, } for image_url in image_urls), ], - }]) + }], + sampling_params=sampling_params) for o in outputs: generated_text = o.outputs[0].text @@ -71,12 +122,13 @@ def run_phi3v_chat(question: str, image_urls: List[str]): def main(args: Namespace): + model = args.model_type method = args.method if method == "generate": - run_phi3v_generate(QUESTION, IMAGE_URLS) + run_generate(model, QUESTION, IMAGE_URLS) elif method == "chat": - run_phi3v_chat(QUESTION, IMAGE_URLS) + run_chat(model, QUESTION, IMAGE_URLS) else: raise ValueError(f"Invalid method: {method}") @@ -85,6 +137,12 @@ def main(args: Namespace): parser = FlexibleArgumentParser( description='Demo on using vLLM for offline inference with ' 'vision language models that support multi-image input') + parser.add_argument('--model-type', + '-m', + type=str, + default="phi3_v", + choices=model_example_map.keys(), + help='Huggingface "model_type".') parser.add_argument("--method", type=str, default="generate", diff --git a/tests/models/test_internvl.py b/tests/models/test_internvl.py index 42732cebc6567..fa3369dc53345 100644 --- a/tests/models/test_internvl.py +++ b/tests/models/test_internvl.py @@ -1,5 +1,5 @@ import types -from typing import List, Optional, Tuple, Type +from typing import List, Optional, Tuple, Type, Union import pytest import torch @@ -9,7 +9,8 @@ from vllm.multimodal.utils import rescale_image_size from vllm.utils import is_cpu -from ..conftest import IMAGE_ASSETS, HfRunner, VllmRunner, _ImageAssets +from ..conftest import (IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner, + _ImageAssets) from .utils import check_logprobs_close pytestmark = pytest.mark.vlm @@ -20,6 +21,7 @@ "cherry_blossom": "<|im_start|>User\n\nWhat is the season?<|im_end|>\n<|im_start|>Assistant\n", # noqa: E501 }) +HF_MULTIIMAGE_IMAGE_PROMPT = "<|im_start|>User\nImage-1: \nImage-2: \nDescribe the two images in detail.<|im_end|>\n<|im_start|>Assistant\n" # noqa: E501 models = [ "OpenGVLab/InternVL2-1B", @@ -64,13 +66,13 @@ def generate( def run_test( hf_runner: Type[HfRunner], vllm_runner: Type[VllmRunner], - image_assets: _ImageAssets, + inputs: List[Tuple[List[str], PromptImageInput]], model: str, *, - size_factors: List[float], dtype: str, max_tokens: int, num_logprobs: int, + mm_limit: int, tensor_parallel_size: int, distributed_executor_backend: Optional[str] = None, ): @@ -83,12 +85,6 @@ def run_test( Note, the text input is also adjusted to abide by vllm contract. The text output is sanitized to be able to compare with hf. """ - images = [asset.pil_image for asset in image_assets] - - inputs_per_image = [( - [prompt for _ in size_factors], - [rescale_image_size(image, factor) for factor in size_factors], - ) for image, prompt in zip(images, HF_IMAGE_PROMPTS)] # NOTE: take care of the order. run vLLM first, and then run HF. # vLLM needs a fresh new process without cuda initialization. @@ -110,13 +106,21 @@ def __init__(self, hf_runner: HfRunner): self.max_num = self.config.max_dynamic_patch self.image_size = self.vision_config.image_size - def __call__(self, text: str, images: Image, **kwargs): + def __call__(self, text: str, images: Union[Image, List[Image]], + **kwargs): from vllm.model_executor.models.internvl import ( IMG_CONTEXT, IMG_END, IMG_START, image_to_pixel_values) - pixel_values = image_to_pixel_values( - images, self.image_size, self.min_num, self.max_num, - self.use_thumbnail).to(self.dtype) - num_patches_list = [pixel_values.shape[0]] + images = [images] if isinstance(images, Image) else images + pixel_values = [ + image_to_pixel_values(image, self.image_size, self.min_num, + self.max_num, + self.use_thumbnail).to(self.dtype) + for image in images + ] + num_patches_list = [ + pixel_value.shape[0] for pixel_value in pixel_values + ] + pixel_values = torch.cat(pixel_values, dim=0) for num_patches in num_patches_list: context_tokens = IMG_CONTEXT * self.num_image_token \ * num_patches @@ -130,6 +134,7 @@ def __call__(self, text: str, images: Image, **kwargs): with vllm_runner(model, max_model_len=4096, dtype=dtype, + limit_mm_per_prompt={"image": mm_limit}, tensor_parallel_size=tensor_parallel_size, distributed_executor_backend=distributed_executor_backend, enforce_eager=True) as vllm_model: @@ -138,7 +143,7 @@ def __call__(self, text: str, images: Image, **kwargs): max_tokens, num_logprobs=num_logprobs, images=images) - for prompts, images in inputs_per_image + for prompts, images in inputs ] with hf_runner(model, dtype=dtype) as hf_model: @@ -156,7 +161,7 @@ def __call__(self, text: str, images: Image, **kwargs): num_logprobs=num_logprobs, images=hf_images, eos_token_id=eos_token_id) - for prompts, hf_images in inputs_per_image + for prompts, hf_images in inputs ] for hf_outputs, vllm_outputs in zip(hf_outputs_per_image, @@ -264,15 +269,64 @@ def run_awq_test( @torch.inference_mode() def test_models(hf_runner, vllm_runner, image_assets, model, size_factors, dtype: str, max_tokens: int, num_logprobs: int) -> None: + images = [asset.pil_image for asset in image_assets] + + inputs_per_image = [( + [prompt for _ in size_factors], + [rescale_image_size(image, factor) for factor in size_factors], + ) for image, prompt in zip(images, HF_IMAGE_PROMPTS)] + run_test( hf_runner, vllm_runner, - image_assets, + inputs_per_image, + model, + dtype=dtype, + max_tokens=max_tokens, + num_logprobs=num_logprobs, + mm_limit=1, + tensor_parallel_size=1, + ) + + +@pytest.mark.parametrize("model", models) +@pytest.mark.parametrize( + "size_factors", + [ + # No image + [], + # Single-scale + [1.0], + # Single-scale, batched + [1.0, 1.0, 1.0], + # Multi-scale + [0.5, 0.75, 1.0], + ], +) +@pytest.mark.parametrize("dtype", [target_dtype]) +@pytest.mark.parametrize("max_tokens", [128]) +@pytest.mark.parametrize("num_logprobs", [5]) +@torch.inference_mode() +def test_multi_images_models(hf_runner, vllm_runner, image_assets, model, + size_factors, dtype: str, max_tokens: int, + num_logprobs: int) -> None: + images = [asset.pil_image for asset in image_assets] + + inputs_per_case = [ + ([HF_MULTIIMAGE_IMAGE_PROMPT for _ in size_factors], + [[rescale_image_size(image, factor) for image in images] + for factor in size_factors]) + ] + + run_test( + hf_runner, + vllm_runner, + inputs_per_case, model, - size_factors=size_factors, dtype=dtype, max_tokens=max_tokens, num_logprobs=num_logprobs, + mm_limit=2, tensor_parallel_size=1, ) diff --git a/tests/models/test_phi3v.py b/tests/models/test_phi3v.py index e416a85b8962a..6ecbf07a08b7c 100644 --- a/tests/models/test_phi3v.py +++ b/tests/models/test_phi3v.py @@ -1,16 +1,15 @@ import os import re -from typing import List, Optional, Tuple, Type, Union +from typing import List, Optional, Tuple, Type import pytest -from PIL import Image from transformers import AutoTokenizer from vllm.multimodal.utils import rescale_image_size from vllm.sequence import SampleLogprobs from vllm.utils import is_cpu, is_hip -from ..conftest import IMAGE_ASSETS, HfRunner, VllmRunner +from ..conftest import IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner from .utils import check_logprobs_close pytestmark = pytest.mark.vlm @@ -60,8 +59,7 @@ def vllm_to_hf_output(vllm_output: Tuple[List[int], str, def run_test( hf_runner: Type[HfRunner], vllm_runner: Type[VllmRunner], - inputs: List[Tuple[List[str], Union[List[Image.Image], - List[List[Image.Image]]]]], + inputs: List[Tuple[List[str], PromptImageInput]], model: str, *, dtype: str, diff --git a/vllm/model_executor/models/internvl.py b/vllm/model_executor/models/internvl.py index 10fbb5663d274..0cf63d9e1fb22 100644 --- a/vllm/model_executor/models/internvl.py +++ b/vllm/model_executor/models/internvl.py @@ -5,6 +5,7 @@ # Licensed under The MIT License [see LICENSE for details] # -------------------------------------------------------- import itertools +import re from typing import (Iterable, List, Literal, Mapping, Optional, Tuple, TypedDict, Union) @@ -26,6 +27,7 @@ from vllm.multimodal.base import MultiModalInputs from vllm.multimodal.utils import cached_get_tokenizer from vllm.sequence import IntermediateTensors +from vllm.utils import is_list_of from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip, get_clip_num_patches) @@ -95,8 +97,8 @@ def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, def calculate_num_blocks(orig_width: int, orig_height: int, min_num: int, - max_num: int, - image_size: int) -> Tuple[int, int, int]: + max_num: int, image_size: int, + use_thumbnail: bool) -> Tuple[int, int, int]: aspect_ratio = orig_width / orig_height # calculate the existing image aspect ratio @@ -114,17 +116,26 @@ def calculate_num_blocks(orig_width: int, orig_height: int, min_num: int, target_width = image_size * target_aspect_ratio[0] target_height = image_size * target_aspect_ratio[1] blocks = target_aspect_ratio[0] * target_aspect_ratio[1] + # add thumbnail image if num_blocks > 1 + if use_thumbnail and blocks > 1: + blocks += 1 return blocks, target_width, target_height # adapted from https://huggingface.co/OpenGVLab/InternVL2-1B def dynamic_preprocess(image: Image.Image, min_num: int, max_num: int, image_size: int, - use_thumbnail: int) -> List[Image.Image]: + use_thumbnail: bool) -> List[Image.Image]: orig_width, orig_height = image.size + # calculate the number of blocks without thumbnail blocks, target_width, target_height = calculate_num_blocks( - orig_width, orig_height, min_num, max_num, image_size) + orig_width, + orig_height, + min_num, + max_num, + image_size, + use_thumbnail=False) # resize the image resized_img = image.resize((target_width, target_height)) processed_images = [] @@ -197,17 +208,23 @@ def input_processor_for_internvl(ctx: InputContext, llm_inputs: LLMInputs): downsample_ratio) image_data = multi_modal_data["image"] + min_num = hf_config.min_dynamic_patch + max_num = hf_config.max_dynamic_patch + use_thumbnail = hf_config.use_thumbnail if isinstance(image_data, Image.Image): width, height = image_data.size - min_num = hf_config.min_dynamic_patch - max_num = hf_config.max_dynamic_patch num_blocks, _, _ = calculate_num_blocks(width, height, min_num, - max_num, image_size) - # add thumbnail image if num_blocks > 1 - if hf_config.use_thumbnail and num_blocks > 1: - num_blocks += 1 - image_feature_size = num_blocks * num_patches - + max_num, image_size, + use_thumbnail) + image_feature_size = [num_blocks * num_patches] + elif is_list_of(image_data, Image.Image): + image_feature_size = [] + for image in image_data: + width, height = image.size + num_blocks, _, _ = calculate_num_blocks(width, height, min_num, + max_num, image_size, + use_thumbnail) + image_feature_size.append(num_blocks * num_patches) elif isinstance(image_data, torch.Tensor): num_images, image_feature_size, hidden_size = image_data.shape else: @@ -220,8 +237,14 @@ def input_processor_for_internvl(ctx: InputContext, llm_inputs: LLMInputs): prompt_token_ids = llm_inputs["prompt_token_ids"] if prompt is None: prompt = tokenizer.decode(prompt_token_ids) - image_prompt = IMG_START + IMG_CONTEXT * image_feature_size + IMG_END - new_prompt = prompt.replace('', image_prompt, 1) + + new_prompt = prompt + image_idx = sorted(map(int, re.findall(r"Image-(\d+): \n", prompt))) + for idx, feature_size in enumerate(image_feature_size, start=1): + image_prompt = IMG_START + IMG_CONTEXT * feature_size + IMG_END + if not image_idx: + image_prompt = f"Image-{idx}: {image_prompt}" + new_prompt = new_prompt.replace('', image_prompt, 1) new_prompt_token_ids = tokenizer.encode(new_prompt) return LLMInputs(prompt=prompt, @@ -245,6 +268,15 @@ def input_mapper_for_internvl(ctx: InputContext, data: object): use_thumbnail=use_thumbnail) # Add an N dimension for number of images per prompt (currently 1). data = data.unsqueeze(0) + elif is_list_of(data, Image.Image): + data = [ + image_to_pixel_values(img, + image_size, + min_num, + max_num, + use_thumbnail=use_thumbnail) for img in data + ] + data = torch.stack(data) model_config = ctx.model_config tokenizer = cached_get_tokenizer(model_config.tokenizer, trust_remote_code=True) From 8886423085ba84db2cea64dd24502620f7904009 Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Sat, 7 Sep 2024 06:30:32 -0400 Subject: [PATCH 44/77] Move float16 typecast hack to gptq marlin moe method --- vllm/model_executor/layers/quantization/gptq_marlin.py | 3 +++ vllm/model_executor/models/mixtral.py | 2 +- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py index a73c462c148c2..1691139bedab6 100644 --- a/vllm/model_executor/layers/quantization/gptq_marlin.py +++ b/vllm/model_executor/layers/quantization/gptq_marlin.py @@ -584,6 +584,9 @@ def apply( from vllm.model_executor.layers.fused_moe.fused_marlin_moe import ( fused_marlin_moe) + # The input must currently be float16 + x = x.half() + topk_weights, topk_ids = FusedMoE.select_experts( hidden_states=x, router_logits=router_logits, diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index 148ef393277e4..df7f39097bdc6 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -99,7 +99,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: hidden_states = hidden_states.view(-1, self.hidden_size) # router_logits: (num_tokens, n_experts) router_logits, _ = self.gate(hidden_states) - final_hidden_states = self.experts(hidden_states.half(), router_logits) + final_hidden_states = self.experts(hidden_states, router_logits) return final_hidden_states.view(orig_shape).to(orig_dtype) From ab274976a52486b6bf41c93b36d2e8fd62af91c2 Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Sat, 7 Sep 2024 06:58:44 -0400 Subject: [PATCH 45/77] Move output type conversion to gptq method as well --- vllm/model_executor/layers/quantization/gptq_marlin.py | 3 ++- vllm/model_executor/models/mixtral.py | 3 +-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py index 1691139bedab6..33899f1fb6716 100644 --- a/vllm/model_executor/layers/quantization/gptq_marlin.py +++ b/vllm/model_executor/layers/quantization/gptq_marlin.py @@ -585,6 +585,7 @@ def apply( fused_marlin_moe) # The input must currently be float16 + orig_dtype = x.dtype x = x.half() topk_weights, topk_ids = FusedMoE.select_experts( @@ -610,4 +611,4 @@ def apply( topk_ids, w1_scale=layer.w13_scales, w2_scale=layer.w2_scales, - ) + ).to(orig_dtype) diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index df7f39097bdc6..6413b56605ecf 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -95,12 +95,11 @@ def __init__(self, def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: # NOTE: hidden_states can have either 1D or 2D shape. orig_shape = hidden_states.shape - orig_dtype = hidden_states.dtype hidden_states = hidden_states.view(-1, self.hidden_size) # router_logits: (num_tokens, n_experts) router_logits, _ = self.gate(hidden_states) final_hidden_states = self.experts(hidden_states, router_logits) - return final_hidden_states.view(orig_shape).to(orig_dtype) + return final_hidden_states.view(orig_shape) class MixtralAttention(nn.Module): From 36bf8150cc3a048d69d9d2196128462014b9599d Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Sun, 8 Sep 2024 01:45:44 +0800 Subject: [PATCH 46/77] [Model][VLM] Decouple weight loading logic for `Paligemma` (#8269) --- vllm/model_executor/models/paligemma.py | 112 ++++++++---------------- vllm/model_executor/models/siglip.py | 23 ++++- 2 files changed, 54 insertions(+), 81 deletions(-) diff --git a/vllm/model_executor/models/paligemma.py b/vllm/model_executor/models/paligemma.py index b6f4275fbc948..5fd39b5e35be6 100644 --- a/vllm/model_executor/models/paligemma.py +++ b/vllm/model_executor/models/paligemma.py @@ -1,3 +1,4 @@ +import itertools from typing import (Iterable, List, Literal, Mapping, Optional, Tuple, TypedDict, Union) @@ -13,7 +14,7 @@ from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.models.gemma import GemmaModel +from vllm.model_executor.models.gemma import GemmaForCausalLM from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.utils import cached_get_tokenizer @@ -22,14 +23,10 @@ from .interfaces import SupportsMultiModal from .siglip import (SiglipVisionModel, dummy_image_for_siglip, dummy_seq_data_for_siglip, get_max_siglip_image_tokens) -from .utils import merge_multimodal_embeddings +from .utils import filter_weights, merge_multimodal_embeddings logger = init_logger(__name__) -_KEYS_TO_MODIFY_MAPPING = { - "language_model.model": "language_model", -} - class PaliGemmaImagePixelInputs(TypedDict): type: Literal["pixel_values"] @@ -151,8 +148,8 @@ def __init__(self, projection_dim=config.vision_config.projection_dim) self.quant_config = quant_config - self.language_model = GemmaModel(config.text_config, cache_config, - quant_config) + self.language_model = GemmaForCausalLM(config.text_config, + cache_config, quant_config) self.unpadded_vocab_size = config.text_config.vocab_size logit_scale = getattr(config, "logit_scale", 1.0) self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, @@ -252,7 +249,8 @@ def forward(self, vision_embeddings = vision_embeddings * (self.config.hidden_size** -0.5) - inputs_embeds = self.language_model.get_input_embeddings(input_ids) + inputs_embeds = self.language_model.model.get_input_embeddings( + input_ids) inputs_embeds = merge_multimodal_embeddings( input_ids, inputs_embeds, vision_embeddings, @@ -262,87 +260,47 @@ def forward(self, else: inputs_embeds = None - hidden_states = self.language_model(input_ids, - positions, - kv_caches, - attn_metadata, - None, - inputs_embeds=inputs_embeds) + hidden_states = self.language_model.model(input_ids, + positions, + kv_caches, + attn_metadata, + None, + inputs_embeds=inputs_embeds) return hidden_states - # Copied from vllm/model_executor/models/gemma.py def compute_logits( self, hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.language_model.embed_tokens, - hidden_states, sampling_metadata) - return logits + return self.language_model.compute_logits(hidden_states, + sampling_metadata) - # Copied from vllm/model_executor/models/gemma.py def sample( self, logits: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens + return self.language_model.sample(logits, sampling_metadata) - # Adapted from vllm/model_executor/models/gemma.py def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): - stacked_params_mapping = [ - # (param_name, shard_name, shard_id) - ("qkv_proj", "q_proj", "q"), - ("qkv_proj", "k_proj", "k"), - ("qkv_proj", "v_proj", "v"), - ("gate_up_proj", "gate_proj", 0), - ("gate_up_proj", "up_proj", 1), - ] - params_dict = dict(self.named_parameters()) - loaded_params = set() - for name, loaded_weight in weights: - for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items(): - if key_to_modify in name: - name = name.replace(key_to_modify, new_key) - use_default_weight_loading = False - if "vision" not in name or self.vision_tower.shard_weight: - for (param_name, shard_name, - shard_id) in stacked_params_mapping: - if shard_name not in name: - continue - name = name.replace(shard_name, param_name) - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - param = params_dict[name] - weight_loader = param.weight_loader - weight_loader(param, loaded_weight, shard_id) - break - else: - # lm_head is not used in vllm as it is tied with - # embed_token. To prevent errors, skip loading - # lm_head.weight. - if "lm_head.weight" in name: - continue - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - use_default_weight_loading = True - else: - use_default_weight_loading = True - - if use_default_weight_loading: - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, loaded_weight) - - loaded_params.add(name) - - unloaded_params = params_dict.keys() - loaded_params - if unloaded_params: - logger.warning( - "Some weights are not initialized from checkpoints: %s", - unloaded_params) + # prepare weight iterators for components + vit_weights, mlp_weights, llm_weights = itertools.tee(weights, 3) + + # load vision tower + vit_weights = filter_weights(vit_weights, "vision_tower") + self.vision_tower.load_weights(vit_weights) + + # load mlp projector + mlp_weights = filter_weights(mlp_weights, "multi_modal_projector") + mlp_params_dict = dict(self.multi_modal_projector.named_parameters()) + for name, loaded_weight in mlp_weights: + param = mlp_params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + + # load llm backbone + llm_weights = filter_weights(llm_weights, "language_model") + self.language_model.load_weights(llm_weights) diff --git a/vllm/model_executor/models/siglip.py b/vllm/model_executor/models/siglip.py index fb4c30c1a13f9..13d09e4cd4c23 100644 --- a/vllm/model_executor/models/siglip.py +++ b/vllm/model_executor/models/siglip.py @@ -529,6 +529,12 @@ def forward( ) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ] if self.shard_weight else [] params_dict = dict(self.named_parameters()) layer_count = len(self.vision_model.encoder.layers) @@ -544,7 +550,16 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): if layer_idx >= layer_count: continue - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, loaded_weight) + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + + param = params_dict[name.replace(weight_name, param_name)] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) From b962ee1470a019a72a1c17eddcf3a0471658a123 Mon Sep 17 00:00:00 2001 From: sumitd2 <91451282+sumitd2@users.noreply.github.com> Date: Sat, 7 Sep 2024 23:48:40 +0530 Subject: [PATCH 47/77] ppc64le: Dockerfile fixed, and a script for buildkite (#8026) --- .buildkite/run-cpu-test-ppc64le.sh | 32 ++++++++++++++++++++++++++++++ Dockerfile.ppc64le | 16 ++++++++++----- 2 files changed, 43 insertions(+), 5 deletions(-) create mode 100755 .buildkite/run-cpu-test-ppc64le.sh diff --git a/.buildkite/run-cpu-test-ppc64le.sh b/.buildkite/run-cpu-test-ppc64le.sh new file mode 100755 index 0000000000000..a01cf3fe67489 --- /dev/null +++ b/.buildkite/run-cpu-test-ppc64le.sh @@ -0,0 +1,32 @@ +# This script build the CPU docker image and run the offline inference inside the container. +# It serves a sanity check for compilation and basic model usage. +set -ex + +# Try building the docker image +docker build -t cpu-test -f Dockerfile.ppc64le . + +# Setup cleanup +remove_docker_container() { docker rm -f cpu-test || true; } +trap remove_docker_container EXIT +remove_docker_container + +# Run the image, setting --shm-size=4g for tensor parallel. +#docker run -itd --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface --privileged=true --network host -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --shm-size=4g --name cpu-test cpu-test +docker run -itd --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface --privileged=true --network host -e HF_TOKEN --name cpu-test cpu-test + +# Run basic model test +docker exec cpu-test bash -c " + pip install pytest matplotlib einops transformers_stream_generator + pytest -v -s tests/models -m \"not vlm\" --ignore=tests/models/test_embedding.py --ignore=tests/models/test_oot_registration.py --ignore=tests/models/test_registry.py --ignore=tests/models/test_jamba.py --ignore=tests/models/test_danube3_4b.py" # Mamba and Danube3-4B on CPU is not supported + +# online inference +docker exec cpu-test bash -c " + python3 -m vllm.entrypoints.openai.api_server --model facebook/opt-125m & + timeout 600 bash -c 'until curl localhost:8000/v1/models; do sleep 1; done' || exit 1 + python3 benchmarks/benchmark_serving.py \ + --backend vllm \ + --dataset-name random \ + --model facebook/opt-125m \ + --num-prompts 20 \ + --endpoint /v1/completions \ + --tokenizer facebook/opt-125m" diff --git a/Dockerfile.ppc64le b/Dockerfile.ppc64le index d4e4c483cada8..16780f8ab950c 100644 --- a/Dockerfile.ppc64le +++ b/Dockerfile.ppc64le @@ -2,21 +2,27 @@ FROM mambaorg/micromamba ARG MAMBA_DOCKERFILE_ACTIVATE=1 USER root -RUN apt-get update -y && apt-get install -y git wget vim numactl gcc-12 g++-12 protobuf-compiler libprotobuf-dev && update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-12 10 --slave /usr/bin/g++ g++ /usr/bin/g++-12 +ENV PATH="/usr/local/cargo/bin:$PATH:/opt/conda/bin/" + +RUN apt-get update -y && apt-get install -y git wget vim libnuma-dev libsndfile-dev libprotobuf-dev build-essential # Some packages in requirements-cpu are installed here # IBM provides optimized packages for ppc64le processors in the open-ce project for mamba # Currently these may not be available for venv or pip directly -RUN micromamba install -y -n base -c https://ftp.osuosl.org/pub/open-ce/1.11.0-p10/ -c defaults python=3.10 pytorch-cpu=2.1.2 torchvision-cpu=0.16.2 && micromamba clean --all --yes +RUN micromamba install -y -n base -c https://ftp.osuosl.org/pub/open-ce/1.11.0-p10/ -c defaults python=3.10 torchvision-cpu=0.16.2 rust && micromamba clean --all --yes COPY ./ /workspace/vllm WORKDIR /workspace/vllm # These packages will be in rocketce eventually -RUN pip install -v -r requirements-cpu.txt --prefer-binary --extra-index-url https://repo.fury.io/mgiessing +RUN pip install -v cmake torch==2.3.1 uvloop==0.20.0 -r requirements-cpu.txt --prefer-binary --extra-index-url https://repo.fury.io/mgiessing RUN VLLM_TARGET_DEVICE=cpu python3 setup.py install -WORKDIR /vllm-workspace -ENTRYPOINT ["/opt/conda/bin/python3", "-m", "vllm.entrypoints.openai.api_server"] +WORKDIR /workspace/ + +RUN ln -s /workspace/vllm/tests && ln -s /workspace/vllm/examples && ln -s /workspace/vllm/benchmarks + +ENTRYPOINT ["python3", "-m", "vllm.entrypoints.openai.api_server"] + From cfe712bf1aedbee4f26105737710ff80ae9d624e Mon Sep 17 00:00:00 2001 From: Joe Runde Date: Sat, 7 Sep 2024 14:03:16 -0600 Subject: [PATCH 48/77] [CI/Build] Use python 3.12 in cuda image (#8133) Signed-off-by: Joe Runde --- Dockerfile | 8 ++++++-- requirements-common.txt | 1 + tests/test_logger.py | 6 +++--- 3 files changed, 10 insertions(+), 5 deletions(-) diff --git a/Dockerfile b/Dockerfile index 2375e3f4d7387..0ec6655ed449e 100644 --- a/Dockerfile +++ b/Dockerfile @@ -10,7 +10,7 @@ ARG CUDA_VERSION=12.4.1 # prepare basic build environment FROM nvidia/cuda:${CUDA_VERSION}-devel-ubuntu20.04 AS base ARG CUDA_VERSION=12.4.1 -ARG PYTHON_VERSION=3.10 +ARG PYTHON_VERSION=3.12 ENV DEBIAN_FRONTEND=noninteractive # Install Python and other dependencies @@ -133,7 +133,7 @@ RUN --mount=type=cache,target=/root/.cache/pip \ # image with vLLM installed FROM nvidia/cuda:${CUDA_VERSION}-base-ubuntu20.04 AS vllm-base ARG CUDA_VERSION=12.4.1 -ARG PYTHON_VERSION=3.10 +ARG PYTHON_VERSION=3.12 WORKDIR /vllm-workspace ENV DEBIAN_FRONTEND=noninteractive @@ -179,6 +179,10 @@ FROM vllm-base AS test ADD . /vllm-workspace/ # install development dependencies (for testing) +# A newer setuptools is required for installing some test dependencies from source that do not publish python 3.12 wheels +# This installation must complete before the test dependencies are collected and installed. +RUN --mount=type=cache,target=/root/.cache/pip \ + python3 -m pip install "setuptools>=74.1.1" RUN --mount=type=cache,target=/root/.cache/pip \ python3 -m pip install -r requirements-dev.txt diff --git a/requirements-common.txt b/requirements-common.txt index e430753357ca0..49a290317f818 100644 --- a/requirements-common.txt +++ b/requirements-common.txt @@ -27,3 +27,4 @@ gguf == 0.9.1 importlib_metadata mistral_common >= 1.3.4 pyyaml +six>=1.16.0; python_version > '3.11' # transitive dependency of pandas that needs to be the latest version for python 3.12 diff --git a/tests/test_logger.py b/tests/test_logger.py index 29346cd0878b8..8f3d218416870 100644 --- a/tests/test_logger.py +++ b/tests/test_logger.py @@ -95,7 +95,7 @@ def test_logger_configuring_can_be_disabled(): config behavior, however mocks are used to ensure no changes in behavior or configuration occur.""" - with patch("logging.config.dictConfig") as dict_config_mock: + with patch("vllm.logger.dictConfig") as dict_config_mock: _configure_vllm_root_logger() dict_config_mock.assert_not_called() @@ -175,9 +175,9 @@ def test_custom_logging_config_is_parsed_and_used_when_provided(): logging_config_file.flush() with patch("vllm.logger.VLLM_LOGGING_CONFIG_PATH", logging_config_file.name), patch( - "logging.config.dictConfig") as dict_config_mock: + "vllm.logger.dictConfig") as dict_config_mock: _configure_vllm_root_logger() - assert dict_config_mock.called_with(valid_logging_config) + dict_config_mock.assert_called_with(valid_logging_config) @patch("vllm.logger.VLLM_CONFIGURE_LOGGING", 0) From 4ef41b84766670c1bd8079f58d35bf32b5bcb3ab Mon Sep 17 00:00:00 2001 From: Alexander Matveev <59768536+alexm-neuralmagic@users.noreply.github.com> Date: Sun, 8 Sep 2024 00:01:51 -0400 Subject: [PATCH 49/77] [Bugfix] Fix async postprocessor in case of preemption (#8267) --- vllm/core/scheduler.py | 87 ++++++++------- vllm/engine/async_llm_engine.py | 24 ++-- vllm/engine/llm_engine.py | 149 ++++++++++++++++--------- vllm/worker/multi_step_model_runner.py | 26 +++-- 4 files changed, 172 insertions(+), 114 deletions(-) diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 81c78bda3b505..c3fa95f57b737 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -537,13 +537,6 @@ def _schedule_running( preempted: List[SequenceGroup] = ret.preempted swapped_out: List[SequenceGroup] = ret.swapped_out - # NOTE(woosuk): Preemption happens only when there is no available slot - # to keep all the sequence groups in the RUNNING state. - - # Store original running requests for the case of async + preemption - if self.use_async_output_proc: - orig_running = self.running.copy() - running_queue = self.running assert len(self._async_stopped) == 0 while running_queue: @@ -552,6 +545,7 @@ def _schedule_running( seq_group, SequenceStatus.RUNNING, enable_chunking, budget) if num_running_tokens == 0: + # No budget => Stop break running_queue.popleft() @@ -565,18 +559,8 @@ def _schedule_running( self._async_stopped.append(seq_group) continue - # With async postprocessor, when preemption kicks in, we need - # first to drain the async postprocessor, so that all async - # block_table freeing is applied before the preemption freeing - # is applied. - if self.use_async_output_proc and not self._can_append_slots( - seq_group): - tmp = self.running - self.running = orig_running - assert self.output_proc_callback is not None - self.output_proc_callback() - self.running = tmp - + # NOTE(woosuk): Preemption happens only when there is no available + # slot to keep all the sequence groups in the RUNNING state. while not self._can_append_slots(seq_group): budget.subtract_num_batched_tokens(seq_group.request_id, num_running_tokens) @@ -588,24 +572,43 @@ def _schedule_running( and seq_group.lora_int_id in curr_loras): curr_loras.remove(seq_group.lora_int_id) + # Determine victim sequence + cont_loop = True if running_queue: - # Preempt the lowest-priority sequence groups. + # Preempt the lowest-priority sequence group. victim_seq_group = running_queue.pop() + else: + # No other sequence group can be preempted. + # Preempt the current sequence group. + # Note: This is also where we stop this loop + # (since there is nothing else to preempt) + victim_seq_group = seq_group + cont_loop = False + + # With async postprocessor, before preempting a sequence + # we need to ensure it has no pending async postprocessor + do_preempt = True + if self.use_async_output_proc: + assert self.output_proc_callback is not None + self.output_proc_callback( + request_id=victim_seq_group.request_id) + + # It may be that the async pending "victim_seq_group" + # becomes finished, in which case we simply free it. + if victim_seq_group.is_finished(): + self._free_finished_seq_group(victim_seq_group) + do_preempt = False + + # Do preemption + if do_preempt: preempted_mode = self._preempt(victim_seq_group, blocks_to_swap_out) if preempted_mode == PreemptionMode.RECOMPUTE: preempted.append(victim_seq_group) else: swapped_out.append(victim_seq_group) - else: - # No other sequence groups can be preempted. - # Preempt the current sequence group. - preempted_mode = self._preempt(seq_group, - blocks_to_swap_out) - if preempted_mode == PreemptionMode.RECOMPUTE: - preempted.append(seq_group) - else: - swapped_out.append(seq_group) + + if not cont_loop: break else: self._append_slots(seq_group, blocks_to_copy) @@ -1264,22 +1267,26 @@ def _free_finished_seqs(self, seq_group: SequenceGroup) -> None: if seq.is_finished(): self.free_seq(seq) + def _free_finished_seq_group(self, seq_group: SequenceGroup) -> None: + if seq_group.is_finished(): + # Free cross-attention block table, if it exists + self._free_seq_group_cross_attn_blocks(seq_group) + + # Add the finished requests to the finished requests list. + # This list will be used to update the Mamba cache in the + # next step. + self._finished_requests_ids.append(seq_group.request_id) + + # Free finished seqs + self._free_finished_seqs(seq_group) + def free_finished_seq_groups(self) -> None: remaining: Deque[SequenceGroup] = deque() for seq_group in self.running: - if seq_group.is_finished(): - # Free cross-attention block table, if it exists - self._free_seq_group_cross_attn_blocks(seq_group) - # Add the finished requests to the finished requests list. - # This list will be used to update the Mamba cache in the - # next step. - self._finished_requests_ids.append(seq_group.request_id) - else: + self._free_finished_seq_group(seq_group) + if not seq_group.is_finished(): remaining.append(seq_group) - # Free finished seqs - self._free_finished_seqs(seq_group) - self.running = remaining # Handle async stopped sequence groups diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 7fe8053fffb7b..6ed1a6bba08ea 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -342,17 +342,17 @@ async def step_async( virtual_engine] # Execute the model. - output = await self.model_executor.execute_model_async( + outputs = await self.model_executor.execute_model_async( execute_model_req) # we need to do this here so that last step's sampled_token_ids can # be passed to the next iteration for PP. if self.scheduler_config.is_multi_step: - self._update_cached_scheduler_output(virtual_engine, output) + self._update_cached_scheduler_output(virtual_engine, outputs) else: if len(ctx.output_queue) > 0: self._process_model_outputs(ctx=ctx) - output = [] + outputs = [] # Finish the current step for all the sequence groups. if self.scheduler_config.is_multi_step: @@ -365,25 +365,25 @@ async def step_async( self.cached_scheduler_outputs[ virtual_engine] = SchedulerOutputState() - is_async = allow_async_output_proc - is_last_step = True - ctx.output_queue.append( - (output, seq_group_metadata_list, scheduler_outputs, is_async, - is_last_step)) + ctx.append_output(outputs=outputs, + seq_group_metadata_list=seq_group_metadata_list, + scheduler_outputs=scheduler_outputs, + is_async=allow_async_output_proc, + is_last_step=True) - if output and allow_async_output_proc: + if outputs and allow_async_output_proc: assert len( - output + outputs ) == 1, "Async postprocessor expects only a single output set" self._advance_to_next_step( - output[0], seq_group_metadata_list, + outputs[0], seq_group_metadata_list, scheduler_outputs.scheduled_seq_groups) if not allow_async_output_proc: self._process_model_outputs(ctx=ctx) # Log stats. - self.do_log_stats(scheduler_outputs, output) + self.do_log_stats(scheduler_outputs, outputs) # Tracing self.do_tracing(scheduler_outputs) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 78ddcd1daaf69..94271c4a93151 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -2,9 +2,9 @@ import time from collections import deque from contextlib import contextmanager -from dataclasses import dataclass, field +from dataclasses import dataclass from typing import (TYPE_CHECKING, Any, ClassVar, Deque, Dict, Iterable, List, - Mapping, Optional) + Mapping, NamedTuple, Optional) from typing import Sequence as GenericSequence from typing import Set, Tuple, Type, Union @@ -90,17 +90,36 @@ class SchedulerOutputState: last_output: Optional[SamplerOutput] = None -@dataclass +class OutputData(NamedTuple): + outputs: List[SamplerOutput] + seq_group_metadata_list: List[SequenceGroupMetadata] + scheduler_outputs: SchedulerOutputs + is_async: bool + is_last_step: bool + skip: List[int] + + class SchedulerContext: - output_queue: Deque[Tuple[Optional[List[SamplerOutput]], - List[SequenceGroupMetadata], SchedulerOutputs, - bool, - bool]] = field(default_factory=lambda: deque()) - request_outputs: List[Union[RequestOutput, - EmbeddingRequestOutput]] = field( - default_factory=lambda: []) - seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None - scheduler_outputs: Optional[SchedulerOutputs] = None + + def __init__(self): + self.output_queue: Deque[OutputData] = deque() + self.request_outputs: List[Union[RequestOutput, + EmbeddingRequestOutput]] = [] + self.seq_group_metadata_list: Optional[ + List[SequenceGroupMetadata]] = None + self.scheduler_outputs: Optional[SchedulerOutputs] = None + + def append_output(self, outputs: List[SamplerOutput], + seq_group_metadata_list: List[SequenceGroupMetadata], + scheduler_outputs: SchedulerOutputs, is_async: bool, + is_last_step: bool): + self.output_queue.append( + OutputData(outputs=outputs, + seq_group_metadata_list=seq_group_metadata_list, + scheduler_outputs=scheduler_outputs, + is_async=is_async, + is_last_step=is_last_step, + skip=[])) class LLMEngine: @@ -1246,23 +1265,15 @@ def _process_sequence_group_outputs( return - def _process_model_outputs(self, ctx: SchedulerContext) -> None: - """Apply the model output to the sequences in the scheduled seq groups. + def _process_model_outputs(self, + ctx: SchedulerContext, + request_id: Optional[str] = None) -> None: + """Apply the model output to the sequences in the scheduled seq groups + and return responses. - virtual_engine: The engine id to operate on + ctx: The virtual engine context to work on + request_id: If provided, then only this request is going to be processed - is_async: Indicates whether this postprocessor runs in - parallel with the GPU forward pass and is processing - tokens from the previous step. If this is true, then - no tokens need to be appended since it is already done - externally (before the next schedule() call) - - sampler_output: Used with multi-step execution to provide - sampler_output of each step - is_last_output: Used with multi-step execution to indicate - the last step (of each multi-step group) - - Returns RequestOutputs that can be returned to the client. """ now = time.time() @@ -1270,9 +1281,14 @@ def _process_model_outputs(self, ctx: SchedulerContext) -> None: return None # Get pending async postprocessor - (outputs, seq_group_metadata_list, scheduler_outputs, is_async, - is_last_step) = ctx.output_queue.popleft() - assert outputs is not None + if request_id: + # When we process only one request, no pop is required + # (since later we will process all of the rest) + (outputs, seq_group_metadata_list, scheduler_outputs, is_async, + is_last_step, skip) = ctx.output_queue[0] + else: + (outputs, seq_group_metadata_list, scheduler_outputs, is_async, + is_last_step, skip) = ctx.output_queue.popleft() # Sanity check assert len(seq_group_metadata_list) == len( @@ -1286,9 +1302,30 @@ def _process_model_outputs(self, ctx: SchedulerContext) -> None: else: outputs_by_sequence_group = outputs + # Determine the requests we need to operate on + if request_id: + indices = [] + for i, seq_group_meta in enumerate(seq_group_metadata_list): + if seq_group_meta.request_id == request_id: + assert i not in skip # Cannot be called twice + indices.append(i) + break + + # If the request_id was not found, then it means that + # this is a new request that has no pending async + # postprocessor + if not indices: + return + else: + indices = range(len(seq_group_metadata_list)) # type: ignore + finished_before: List[int] = [] finished_now: List[int] = [] - for i, seq_group_meta in enumerate(seq_group_metadata_list): + for i in indices: + if i in skip: + continue + + seq_group_meta = seq_group_metadata_list[i] scheduled_seq_group = scheduler_outputs.scheduled_seq_groups[i] seq_group = scheduled_seq_group.seq_group @@ -1343,6 +1380,18 @@ def _process_model_outputs(self, ctx: SchedulerContext) -> None: request_output = RequestOutputFactory.create(seq_group) ctx.request_outputs.append(request_output) + # When we process a single request, we skip it for the next time, + # and invoke the request output callback (if there was final output) + if request_id: + assert len(indices) == 1 + skip.append(indices[0]) + + if (finished_now + and self.process_request_outputs_callback is not None): + self.process_request_outputs_callback(ctx.request_outputs) + ctx.request_outputs.clear() + return + # Free currently finished requests if finished_now: for scheduler in self.scheduler: @@ -1354,17 +1403,16 @@ def _process_model_outputs(self, ctx: SchedulerContext) -> None: if (finished_now and self.process_request_outputs_callback is not None): self.process_request_outputs_callback(ctx.request_outputs) + ctx.request_outputs.clear() return # Create the outputs - # Note: scheduled_seq_groups and seq_group_metadata_list - # must match with the indices - for i, scheduled_seq_group in enumerate( - scheduler_outputs.scheduled_seq_groups): - - if i in finished_before or i in finished_now: + for i in indices: + if i in skip or i in finished_before or i in finished_now: continue # Avoids double processing + scheduled_seq_group = scheduler_outputs.scheduled_seq_groups[i] + seq_group = scheduled_seq_group.seq_group seq_group.maybe_set_first_token_time(now) if (seq_group.is_finished() @@ -1380,6 +1428,7 @@ def _process_model_outputs(self, ctx: SchedulerContext) -> None: if (ctx.request_outputs and self.process_request_outputs_callback is not None): self.process_request_outputs_callback(ctx.request_outputs) + ctx.request_outputs.clear() # For async case, we need to record the stats here. # For non-async case, the stats are done in the @@ -1548,20 +1597,20 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: execute_model_req.async_callback = self.async_callbacks[ virtual_engine] - output = self.model_executor.execute_model( + outputs = self.model_executor.execute_model( execute_model_req=execute_model_req) # We need to do this here so that last step's sampled_token_ids can # be passed to the next iteration for PP. if self.scheduler_config.is_multi_step: - self._update_cached_scheduler_output(virtual_engine, output) + self._update_cached_scheduler_output(virtual_engine, outputs) else: # Nothing scheduled => If there is pending async postprocessor, # then finish it here. if len(ctx.output_queue) > 0: self._process_model_outputs(ctx=ctx) # No outputs in this case - output = [] + outputs = [] # Finish the current step for all the sequence groups. if self.scheduler_config.is_multi_step: @@ -1574,18 +1623,18 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: self.cached_scheduler_outputs[0] = SchedulerOutputState() # Add results to the output_queue - is_async = allow_async_output_proc - is_last_step = True - ctx.output_queue.append( - (output, seq_group_metadata_list, scheduler_outputs, is_async, - is_last_step)) - - if output and allow_async_output_proc: - assert len(output) == 1, ( + ctx.append_output(outputs=outputs, + seq_group_metadata_list=seq_group_metadata_list, + scheduler_outputs=scheduler_outputs, + is_async=allow_async_output_proc, + is_last_step=True) + + if outputs and allow_async_output_proc: + assert len(outputs) == 1, ( "Async postprocessor expects only a single output set") self._advance_to_next_step( - output[0], seq_group_metadata_list, + outputs[0], seq_group_metadata_list, scheduler_outputs.scheduled_seq_groups) # Check if need to run the usual non-async path @@ -1593,7 +1642,7 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: self._process_model_outputs(ctx=ctx) # Log stats. - self.do_log_stats(scheduler_outputs, output) + self.do_log_stats(scheduler_outputs, outputs) # Tracing self.do_tracing(scheduler_outputs) diff --git a/vllm/worker/multi_step_model_runner.py b/vllm/worker/multi_step_model_runner.py index b52f2a07e344e..b13cf39bd846e 100644 --- a/vllm/worker/multi_step_model_runner.py +++ b/vllm/worker/multi_step_model_runner.py @@ -274,12 +274,13 @@ def _async_process_outputs(self, model_input: StatefulModelInput, self.pinned_sampled_token_ids) if model_output.pythonized: ctx = output_proc_callback.keywords["ctx"] - is_async = False - is_last_step = False - ctx.output_queue.append( - ([model_output.sampler_output - ], ctx.seq_group_metadata_list, - ctx.scheduler_outputs, is_async, is_last_step)) + ctx.append_output( + outputs=[model_output.sampler_output], + seq_group_metadata_list=ctx.seq_group_metadata_list, + scheduler_outputs=ctx.scheduler_outputs, + is_async=False, + is_last_step=False) + output_proc_callback() else: cont = False @@ -319,12 +320,13 @@ def _final_process_outputs(self, model_input: StatefulModelInput, if not is_last_step: ctx = output_proc_callback.keywords[ # type: ignore "ctx"] # type: ignore - is_async = False - is_last_step = False - ctx.output_queue.append( - ([output.sampler_output - ], ctx.seq_group_metadata_list, - ctx.scheduler_outputs, is_async, is_last_step)) + ctx.append_output( + outputs=[output.sampler_output], + seq_group_metadata_list=ctx. + seq_group_metadata_list, + scheduler_outputs=ctx.scheduler_outputs, + is_async=False, + is_last_step=False) else: outputs.append(output.sampler_output) else: From 847e8602334de1f8202cdae240cb139518b0f478 Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Fri, 30 Aug 2024 09:07:15 -0400 Subject: [PATCH 50/77] Enable 8-bit weights in Fused Marlin MoE --- csrc/moe/marlin_moe_ops.cu | 301 ++++++++++++------ csrc/moe/marlin_moe_ops.h | 9 +- csrc/moe/torch_bindings.cpp | 11 +- tests/kernels/test_moe.py | 225 ++++++++++++- vllm/_custom_ops.py | 2 +- .../layers/fused_moe/__init__.py | 16 +- .../layers/fused_moe/fused_moe.py | 138 ++------ .../layers/fused_moe/fused_moe_marlin.py | 245 ++++++++++++++ .../compressed_tensors_moe.py | 34 +- .../layers/quantization/utils/marlin_utils.py | 17 + .../quantization/utils/marlin_utils_test.py | 11 +- .../layers/quantization/utils/quant_utils.py | 19 +- 12 files changed, 775 insertions(+), 253 deletions(-) create mode 100644 vllm/model_executor/layers/fused_moe/fused_moe_marlin.py diff --git a/csrc/moe/marlin_moe_ops.cu b/csrc/moe/marlin_moe_ops.cu index 1e170e80d2f70..e3c18ce5a50b8 100644 --- a/csrc/moe/marlin_moe_ops.cu +++ b/csrc/moe/marlin_moe_ops.cu @@ -25,6 +25,8 @@ #include +#include "core/scalar_type.hpp" + template inline std::string str(T x) { return std::to_string(x); @@ -131,11 +133,26 @@ __device__ inline int lop3(int a, int b, int c) { return res; } -// Efficiently dequantize an int32 value into a full B-fragment of 4 fp16 -// values. We mostly follow the strategy in the link below, with some small -// changes: -// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h -__device__ inline FragB dequant(int q) { +// Constructs destination register by taking bytes from 2 sources (based on +// mask) +template +__device__ inline uint32_t prmt(uint32_t a) { + uint32_t res; + asm volatile("prmt.b32 %0, %1, %2, %3;\n" + : "=r"(res) + : "r"(a), "n"(start_byte), "n"(mask)); + return res; +} + +template +__device__ inline FragB dequant(int q); + +// Efficiently dequantize 4bit values packed in an int32 value into a full +// B-fragment of 4 fp16 values. We mostly follow the strategy in the link below, +// with some small changes: +// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L215-L287 +template <> +__device__ inline FragB dequant(int q) { const int LO = 0x000f000f; const int HI = 0x00f000f0; const int EX = 0x64006400; @@ -156,6 +173,28 @@ __device__ inline FragB dequant(int q) { return frag_b; } +// Fast Int8ToFp16: Efficiently dequantize 8bit int values to fp16 +// Reference: +// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L53-L85 +template <> +__device__ inline FragB dequant(int q) { + static constexpr uint32_t mask_for_elt_01 = 0x5250; + static constexpr uint32_t mask_for_elt_23 = 0x5351; + static constexpr uint32_t start_byte_for_fp16 = 0x64646464; + + uint32_t lo = prmt(q); + uint32_t hi = prmt(q); + + static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480; + + FragB frag_b; + frag_b[0] = __hsub2(*reinterpret_cast(&lo), + *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); + frag_b[1] = __hsub2(*reinterpret_cast(&hi), + *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); + return frag_b; +} + // Multiply dequantized values by the corresponding quantization scale; used // only for grouped quantization. __device__ inline void scale(FragB& frag_b, FragS& frag_s, int i) { @@ -296,7 +335,8 @@ __global__ void compute_expert_offsets(int const* __restrict__ topk_ids, __syncthreads(); } -template ( - &sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd]); + + #pragma unroll + for (int i = 0; i < b_thread_vecs; i++) { + frag_b_quant[k % 2][i] = *reinterpret_cast( + &sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd + i]); + } }; bool is_same_group[stages]; @@ -840,10 +893,19 @@ __device__ inline void MarlinMoESingle( // dequantization and matmul operations. #pragma unroll for (int j = 0; j < 4; j++) { - int b_quant = frag_b_quant[k % 2][j]; - int b_quant_shift = b_quant >> 8; + int b_quant_0, b_quant_1; + if constexpr (w_type.size_bits() == 4) { + b_quant_0 = frag_b_quant[k % 2][0][j]; + b_quant_1 = b_quant_0 >> 8; + } else { + static_assert(w_type.size_bits() == 8); + int* frag_b_quant_ptr = reinterpret_cast(frag_b_quant[k % 2]); + b_quant_0 = frag_b_quant_ptr[j * 2 + 0]; + b_quant_1 = frag_b_quant_ptr[j * 2 + 1]; + } - FragB frag_b0 = dequant(b_quant); + FragB frag_b0 = dequant(b_quant_0); + FragB frag_b1 = dequant(b_quant_1); // Apply scale to frag_b0 if constexpr (has_act_order) { @@ -855,8 +917,6 @@ __device__ inline void MarlinMoESingle( } } - FragB frag_b1 = dequant(b_quant_shift); - // Apply scale to frag_b1 if constexpr (has_act_order) { scale4(frag_b1, act_frag_s[k % 2][0][j], act_frag_s[k % 2][1][j], @@ -881,13 +941,13 @@ __device__ inline void MarlinMoESingle( // multiple warps that accumulate their partial sums of the same output // location; which we have to reduce over in the end. We do in shared memory. auto thread_block_reduce = [&]() { - constexpr int red_off = threads / b_sh_stride / 2; + constexpr int red_off = threads / b_sh_stride_threads / 2; if (red_off >= 1) { - int red_idx = threadIdx.x / b_sh_stride; - constexpr int red_sh_stride = b_sh_stride * 4 * 2; - constexpr int red_sh_delta = b_sh_stride; - int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride) + - (threadIdx.x % b_sh_stride); + int red_idx = threadIdx.x / b_sh_stride_threads; + constexpr int red_sh_stride = b_sh_stride_threads * 4 * 2; + constexpr int red_sh_delta = b_sh_stride_threads; + int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride_threads) + + (threadIdx.x % b_sh_stride_threads); // Parallel logarithmic shared memory reduction. We make sure to avoid any // unnecessary read or write iterations, e.g., for two warps we write only @@ -1035,8 +1095,10 @@ __device__ inline void MarlinMoESingle( auto write = [&](int idx, float c0, float c1, FragS& s) { half2 res = __halves2half2(__float2half(c0), __float2half(c1)); - // For per-column quantization we finally apply the scale here - if constexpr (!has_act_order && group_blocks == -1) { + // For per-column quantization we finally apply the scale here (only for + // 4-bit) + if constexpr (!has_act_order && group_blocks == -1 && + w_type.size_bits() == 4) { res = __hmul2(res, s[0]); } @@ -1169,25 +1231,67 @@ __device__ inline void MarlinMoESingle( // For per-column scales, we only fetch them here in the final step before // write-out if constexpr (!has_act_order && group_blocks == -1) { - if (last) { + if constexpr (w_type.size_bits() == 8) { if (s_sh_wr_pred) { cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]); } cp_async_fence(); + } else { + if (last) { + if (s_sh_wr_pred) { + cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]); + } + cp_async_fence(); + } } } thread_block_reduce(); if constexpr (!has_act_order && group_blocks == -1) { - if (last) { + if constexpr (w_type.size_bits() == 8) { cp_async_wait<0>(); __syncthreads(); if (threadIdx.x / 32 < thread_n_blocks / 4) { reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd + 0]; reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; } + + } else { + if (last) { + cp_async_wait<0>(); + __syncthreads(); + if (threadIdx.x / 32 < thread_n_blocks / 4) { + reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd + 0]; + reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; + } + } + } + } + + // For 8-bit channelwise, we apply the scale before the global reduction + // that converts the fp32 results to fp16 (so that we avoid possible + // overflow in fp16) + if constexpr (!has_act_order && group_blocks == -1 && + w_type.size_bits() == 8) { + if (threadIdx.x / 32 < thread_n_blocks / 4) { + #pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { + #pragma unroll + for (int j = 0; j < 4; j++) { + scale_float(reinterpret_cast(&frag_c[i][j][0][0]), + frag_s[j / 2][2 * (j % 2) + 0]); + scale_float(reinterpret_cast(&frag_c[i][j][0][2]), + frag_s[j / 2][2 * (j % 2) + 0]); + + scale_float(reinterpret_cast(&frag_c[i][j][1][0]), + frag_s[j / 2][2 * (j % 2) + 1]); + scale_float(reinterpret_cast(&frag_c[i][j][1][2]), + frag_s[j / 2][2 * (j % 2) + 1]); + } + } } } + if (slice_count > 1) { // only globally reduce if there is more than one // block in a slice barrier_acquire(&locks[slice_col], slice_idx); @@ -1227,7 +1331,8 @@ __device__ inline void MarlinMoESingle( } } -template ( + MarlinMoESingle( A, B, C, sorted_ids_expert, topk_weights, scales_ptr, g_idx, expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m, prob_n, prob_k, tot_m, locks, replicate_input, apply_weights, current_m_block); } else if (max_block == 2) { - MarlinMoESingle( + MarlinMoESingle( A, B, C, sorted_ids_expert, topk_weights, scales_ptr, g_idx, expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m, prob_n, prob_k, tot_m, locks, replicate_input, apply_weights, current_m_block); } else if (max_block == 3) { - MarlinMoESingle( + MarlinMoESingle( A, B, C, sorted_ids_expert, topk_weights, scales_ptr, g_idx, expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m, prob_n, prob_k, tot_m, locks, replicate_input, apply_weights, current_m_block); } else { - MarlinMoESingle( + MarlinMoESingle( A, B, C, sorted_ids_expert, topk_weights, scales_ptr, g_idx, expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m, prob_n, prob_k, tot_m, locks, replicate_input, apply_weights, @@ -1342,7 +1447,8 @@ __global__ void compute_expert_offsets(int const* __restrict__ topk_ids, return; } -template , \ cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \ - MarlinMoE \ + MarlinMoE \ <<>>( \ A_ptr, B_ptr, C_ptr, sorted_ids_ptr, topk_weights_ptr, s_ptr, \ g_idx_ptr, expert_offsets_ptr, num_groups, expert_idx, \ @@ -1494,42 +1601,43 @@ thread_config_t determine_thread_config(int prob_m, int prob_n, int prob_k) { return thread_config_t{-1, -1, -1}; } -#define CALL_IF_MOE(N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - __CALL_IF_MOE(1, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ - __CALL_IF_MOE(2, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ - __CALL_IF_MOE(3, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ - __CALL_IF_MOE(4, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ - \ - __CALL_IF_MOE(1, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ - __CALL_IF_MOE(1, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ - __CALL_IF_MOE(1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ - __CALL_IF_MOE(1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \ - \ - __CALL_IF_MOE(2, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ - __CALL_IF_MOE(2, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ - __CALL_IF_MOE(2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ - __CALL_IF_MOE(2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \ - \ - __CALL_IF_MOE(3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ - __CALL_IF_MOE(3, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ - __CALL_IF_MOE(3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ - __CALL_IF_MOE(3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \ - \ - __CALL_IF_MOE(4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ - __CALL_IF_MOE(4, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ - __CALL_IF_MOE(4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ - __CALL_IF_MOE(4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) +#define CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, 2, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, 3, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, 4, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ + \ + __CALL_IF_MOE(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \ + \ + __CALL_IF_MOE(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \ + \ + __CALL_IF_MOE(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \ + \ + __CALL_IF_MOE(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) void marlin_mm_moe_f16i4(const void* A, const void* B, void* C, const void* sorted_ids, const void* topk_weights, const void* topk_ids, const void* s, const void* g_idx, const void* perm, void* a_tmp, void* expert_offsets, int prob_m, int prob_n, int prob_k, void* workspace, - bool has_act_order, bool is_k_full, int num_groups, - int group_size, int num_experts, int topk, - int moe_block_size, int dev, cudaStream_t stream, - int thread_k, int thread_n, int sms, int max_par, - bool replicate_input, bool apply_weights) { + vllm::ScalarType const& q_type, bool has_act_order, + bool is_k_full, int num_groups, int group_size, + int num_experts, int topk, int moe_block_size, int dev, + cudaStream_t stream, int thread_k, int thread_n, + int sms, int max_par, bool replicate_input, + bool apply_weights) { TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m, ", ", prob_n, ", ", prob_k, "]"); @@ -1611,10 +1719,13 @@ void marlin_mm_moe_f16i4(const void* A, const void* B, void* C, has_act_order = false; } + int pack_factor = 32 / q_type.size_bits(); + for (int expert_idx = 0; expert_idx < num_experts; ++expert_idx) { const int4* A_ptr = (const int4*)A; int4* a_tmp_ptr = (int4*)a_tmp; - const int4* B_ptr = (const int4*)B + (prob_n * prob_k / 32) * expert_idx; + const int4* B_ptr = + (const int4*)B + (prob_n * prob_k / (pack_factor * 4)) * expert_idx; int4* C_ptr = (int4*)C; const float* topk_weights_ptr = (const float*)topk_weights; const int* sorted_ids_ptr = (const int*)sorted_ids; @@ -1645,10 +1756,14 @@ void marlin_mm_moe_f16i4(const void* A, const void* B, void* C, if (false) { } - CALL_IF_MOE(16, 4, 256) - CALL_IF_MOE(8, 8, 256) - CALL_IF_MOE(8, 4, 128) - CALL_IF_MOE(4, 8, 128) + CALL_IF_MOE(vllm::kU4B8, 16, 4, 256) + CALL_IF_MOE(vllm::kU4B8, 8, 8, 256) + CALL_IF_MOE(vllm::kU4B8, 8, 4, 128) + CALL_IF_MOE(vllm::kU4B8, 4, 8, 128) + CALL_IF_MOE(vllm::kU8B128, 16, 4, 256) + CALL_IF_MOE(vllm::kU8B128, 8, 8, 256) + CALL_IF_MOE(vllm::kU8B128, 8, 4, 128) + CALL_IF_MOE(vllm::kU8B128, 4, 8, 128) else { TORCH_CHECK(false, "Unsupported shapes: MNK = [" + str(prob_m) + ", " + str(prob_n) + ", " + str(prob_k) + "]" + @@ -1670,9 +1785,15 @@ torch::Tensor marlin_gemm_moe( const torch::Tensor& sorted_ids, const torch::Tensor& topk_weights, const torch::Tensor& topk_ids, const torch::Tensor& b_scales, const torch::Tensor& g_idx, const torch::Tensor& perm, - torch::Tensor& workspace, int64_t size_m, int64_t size_n, int64_t size_k, - bool is_k_full, int64_t num_experts, int64_t topk, int64_t moe_block_size, + torch::Tensor& workspace, vllm::ScalarTypeTorchPtr const& b_q_type, + int64_t size_m, int64_t size_n, int64_t size_k, bool is_k_full, + int64_t num_experts, int64_t topk, int64_t moe_block_size, bool replicate_input, bool apply_weights) { + TORCH_CHECK(*b_q_type == vllm::kU4B8 || *b_q_type == vllm::kU8B128, + "b_q_type must be uint4b8 or uint8b128. Got = ", b_q_type->str()); + + int pack_factor = 32 / b_q_type->size_bits(); + int max_par = 4; int dev = a.get_device(); @@ -1733,8 +1854,8 @@ torch::Tensor marlin_gemm_moe( topk_weights.data_ptr(), topk_ids.data_ptr(), b_scales.data_ptr(), g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr(), expert_offsets.data_ptr(), size_m, size_n, size_k, workspace.data_ptr(), - has_act_order, is_k_full, num_groups, group_size, num_experts, topk, - moe_block_size, dev, at::cuda::getCurrentCUDAStream(dev), thread_k, + *b_q_type, has_act_order, is_k_full, num_groups, group_size, num_experts, + topk, moe_block_size, dev, at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms, max_par, replicate_input, apply_weights); return c; -} \ No newline at end of file +} diff --git a/csrc/moe/marlin_moe_ops.h b/csrc/moe/marlin_moe_ops.h index 01ba8ff69850d..adee8399a4d6f 100644 --- a/csrc/moe/marlin_moe_ops.h +++ b/csrc/moe/marlin_moe_ops.h @@ -2,11 +2,14 @@ #include +#include "core/scalar_type.hpp" + torch::Tensor marlin_gemm_moe( const torch::Tensor& a, const torch::Tensor& b_q_weights, const torch::Tensor& sorted_ids, const torch::Tensor& topk_weights, const torch::Tensor& topk_ids, const torch::Tensor& b_scales, const torch::Tensor& g_idx, const torch::Tensor& perm, - torch::Tensor& workspace, int64_t size_m, int64_t size_n, int64_t size_k, - bool is_k_full, int64_t num_experts, int64_t topk, int64_t moe_block_size, - bool replicate_input, bool apply_weights); \ No newline at end of file + torch::Tensor& workspace, vllm::ScalarTypeTorchPtr const& b_q_type, + int64_t size_m, int64_t size_n, int64_t size_k, bool is_k_full, + int64_t num_experts, int64_t topk, int64_t moe_block_size, + bool replicate_input, bool apply_weights); diff --git a/csrc/moe/torch_bindings.cpp b/csrc/moe/torch_bindings.cpp index d4d43e2c601b5..d2352375de33c 100644 --- a/csrc/moe/torch_bindings.cpp +++ b/csrc/moe/torch_bindings.cpp @@ -9,16 +9,15 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { "token_expert_indices, Tensor gating_output) -> ()"); m.impl("topk_softmax", torch::kCUDA, &topk_softmax); -#ifndef USE_ROCM m.def( "marlin_gemm_moe(Tensor! a, Tensor! b_q_weights, Tensor! sorted_ids, " "Tensor! topk_weights, Tensor! topk_ids, Tensor! b_scales, Tensor! " - "g_idx, Tensor! perm, Tensor! workspace, int size_m, int size_n, int " - "size_k, bool is_k_full, int num_experts, int topk, int moe_block_size, " - "bool replicate_input, bool apply_weights) -> Tensor"); - + "g_idx, Tensor! perm, Tensor! workspace, " + "__torch__.torch.classes._core_C.ScalarType b_q_type, int size_m, " + "int size_n, int size_k, bool is_k_full, int num_experts, int topk, " + "int moe_block_size, bool replicate_input, bool apply_weights)" + " -> Tensor"); m.impl("marlin_gemm_moe", torch::kCUDA, &marlin_gemm_moe); -#endif } REGISTER_EXTENSION(TORCH_EXTENSION_NAME) diff --git a/tests/kernels/test_moe.py b/tests/kernels/test_moe.py index f526c381b3339..f7642bf02b05a 100644 --- a/tests/kernels/test_moe.py +++ b/tests/kernels/test_moe.py @@ -2,6 +2,8 @@ Run `pytest tests/kernels/test_moe.py`. """ +from typing import List + import pytest import torch from transformers import MixtralConfig @@ -9,7 +11,12 @@ from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import fused_moe +from vllm.model_executor.layers.fused_moe.fused_moe_marlin import ( + fused_moe_marlin, single_moe_marlin) +from vllm.model_executor.layers.quantization.utils.marlin_utils_test import ( + marlin_quantize) from vllm.model_executor.models.mixtral import MixtralMoE +from vllm.scalar_type import scalar_types def torch_moe(a, w1, w2, score, topk): @@ -29,6 +36,20 @@ def torch_moe(a, w1, w2, score, topk): topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1) +def torch_moe_single(a, w, score, topk): + B, D = a.shape + a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D) + out = torch.zeros(B * topk, w.shape[1], dtype=a.dtype, device=a.device) + score = torch.softmax(score, dim=-1, dtype=torch.float32) + _, topk_ids = torch.topk(score, topk) + topk_ids = topk_ids.view(-1) + for i in range(w.shape[0]): + mask = topk_ids == i + if mask.sum(): + out[mask] = a[mask] @ w[i].transpose(0, 1) + return (out.view(B, -1, w.shape[1])).sum(dim=1) + + @pytest.mark.parametrize("m", [1024 * 128, 512, 222, 33, 1]) @pytest.mark.parametrize("n", [2048, 256, 1024]) @pytest.mark.parametrize("k", [128, 511, 1024]) @@ -43,11 +64,11 @@ def test_fused_moe( topk: int, dtype: torch.dtype, ): - a = torch.randn((m, k), device='cuda', dtype=dtype) / 10 - w1 = torch.randn((e, 2 * n, k), device='cuda', dtype=dtype) / 10 - w2 = torch.randn((e, k, n), device='cuda', dtype=dtype) / 10 + a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 + w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 + w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 - score = torch.randn((m, e), device='cuda', dtype=dtype) + score = torch.randn((m, e), device="cuda", dtype=dtype) triton_output = fused_moe(a, w1, w2, score, topk, renormalize=False) torch_output = torch_moe(a, w1, w2, score, topk) torch.testing.assert_close(triton_output, torch_output, atol=1e-2, rtol=0) @@ -99,3 +120,199 @@ def test_mixtral_moe(dtype: torch.dtype): vllm_states, rtol=mixtral_moe_tol[dtype], atol=mixtral_moe_tol[dtype]) + + +def stack_and_dev(tensors: List[torch.Tensor]): + dev = tensors[0].device + return torch.stack(tensors, dim=0).to(dev) + + +def compute_max_diff(output, output_ref): + return torch.mean(torch.abs(output - output_ref)) / torch.mean( + torch.abs(output_ref)) + + +@pytest.mark.parametrize("m", [64, 512, 222, 33, 1]) +@pytest.mark.parametrize("n", [128, 2048, 256, 1024]) +@pytest.mark.parametrize("k", [128, 1024, 512]) +@pytest.mark.parametrize("e", [4, 8, 64]) +@pytest.mark.parametrize("topk", [2, 6]) +@pytest.mark.parametrize("group_size", [-1, 32, 64, 128]) +@pytest.mark.parametrize("act_order", [True, False]) +@pytest.mark.parametrize("num_bits", [4, 8]) +def test_fused_marlin_moe( + m: int, + n: int, + k: int, + e: int, + topk: int, + group_size: int, + act_order: bool, + num_bits: int, +): + torch.manual_seed(7) + + if topk > e: + return + + # Filter act_order + if act_order: + if group_size == -1: + return + if group_size in (k, n): + return + + quant_type = (scalar_types.uint4b8 + if num_bits == 4 else scalar_types.uint8b128) + dtype = torch.float16 + a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 + w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 + w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 + for i in range(w2.shape[0]): + w2[0] = torch.eye(k, n, device="cuda", dtype=dtype) + + w_ref1_l = [] + qweight1_l = [] + scales1_l = [] + g_idx1_l = [] + sort_indices1_l = [] + + for i in range(w1.shape[0]): + test_perm = torch.randperm(k) + w_ref1, qweight1, scales1, g_idx1, sort_indices1, _ = marlin_quantize( + w1[i].transpose(1, 0), quant_type, group_size, act_order, + test_perm) + w_ref1_l.append(w_ref1) + qweight1_l.append(qweight1) + scales1_l.append(scales1) + g_idx1_l.append(g_idx1) + sort_indices1_l.append(sort_indices1) + + w_ref1 = stack_and_dev(w_ref1_l) + qweight1 = stack_and_dev(qweight1_l).contiguous() + scales1 = stack_and_dev(scales1_l) + g_idx1 = stack_and_dev(g_idx1_l) + sort_indices1 = stack_and_dev(sort_indices1_l) + + w_ref2_l = [] + qweight2_l = [] + scales2_l = [] + g_idx2_l = [] + sort_indices2_l = [] + + for i in range(w2.shape[0]): + test_perm = torch.randperm(n) + w_ref2, qweight2, scales2, g_idx2, sort_indices2, _ = marlin_quantize( + w2[i].transpose(1, 0), quant_type, group_size, act_order, + test_perm) + w_ref2_l.append(w_ref2) + qweight2_l.append(qweight2) + scales2_l.append(scales2) + g_idx2_l.append(g_idx2) + sort_indices2_l.append(sort_indices2) + + w_ref2 = stack_and_dev(w_ref2_l) + qweight2 = stack_and_dev(qweight2_l).contiguous() + scales2 = stack_and_dev(scales2_l) + g_idx2 = stack_and_dev(g_idx2_l) + sort_indices2 = stack_and_dev(sort_indices2_l) + + score = torch.randn((m, e), device="cuda", dtype=dtype) + triton_output = fused_moe( + a, + w_ref1.transpose(1, 2).contiguous(), + w_ref2.transpose(1, 2).contiguous(), + score, + topk, + renormalize=False, + ) + marlin_output = fused_moe_marlin( + a, + qweight1, + qweight2, + score, + g_idx1, + g_idx2, + sort_indices1, + sort_indices2, + topk, + renormalize=False, + w1_scale=scales1, + w2_scale=scales2, + num_bits=num_bits, + ) + + assert compute_max_diff(marlin_output, triton_output) < 4e-2 + + +@pytest.mark.skip("This test is here for the sake of debugging, " + "don't run it in automated tests.") +@pytest.mark.parametrize("m", [64, 512, 222, 33, 1]) +@pytest.mark.parametrize("n", [128, 2048, 256, 1024]) +@pytest.mark.parametrize("k", [128, 1024, 512]) +@pytest.mark.parametrize("e", [4, 8, 64]) +@pytest.mark.parametrize("topk", [2, 6]) +@pytest.mark.parametrize("group_size", [-1, 32, 64, 128]) +@pytest.mark.parametrize("act_order", [True, False]) +@pytest.mark.parametrize("num_bits", [4, 8]) +def test_marlin_moe_mmm( + m: int, + n: int, + k: int, + e: int, + topk: int, + group_size: int, + act_order: bool, + num_bits: int, +): + if topk > e: + return + + # Filter act_order + if act_order: + if group_size == -1: + return + if group_size == k: + return + + quant_type = (scalar_types.uint4b8 + if num_bits == 4 else scalar_types.uint8b128) + dtype = torch.float16 + a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 + w = torch.randn((e, n, k), device="cuda", dtype=dtype) / 10 + + w_ref_l = [] + qweights_l = [] + scales_l = [] + g_idx_l = [] + sort_indices_l = [] + + for i in range(w.shape[0]): + test_perm = torch.randperm(k) + w_ref, qweight, scales, g_idx, sort_indices, _ = marlin_quantize( + w[i].transpose(1, 0), quant_type, group_size, act_order, test_perm) + w_ref_l.append(w_ref) + qweights_l.append(qweight) + scales_l.append(scales) + g_idx_l.append(g_idx) + sort_indices_l.append(sort_indices) + + w_ref = stack_and_dev(w_ref_l) + qweight = stack_and_dev(qweights_l).contiguous() + scales = stack_and_dev(scales_l) + g_idx = stack_and_dev(g_idx_l) + sort_indices = stack_and_dev(sort_indices_l) + + score = torch.randn((m, e), device="cuda", dtype=dtype) + marlin_output = single_moe_marlin(a, + qweight, + scales, + score, + g_idx, + sort_indices, + topk, + renormalize=False, + num_bits=num_bits) + torch_output = torch_moe_single(a, w_ref.transpose(1, 2), score, topk) + + assert compute_max_diff(marlin_output, torch_output) < 1e-2 diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 151cdbee8eb04..77c46584ef530 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -308,7 +308,7 @@ def gptq_marlin_moe_repack(b_q_weight: torch.Tensor, perm: torch.Tensor, num_bits: int) -> torch.Tensor: num_experts = b_q_weight.shape[0] assert size_k % 16 == 0 - output = torch.empty((num_experts, size_k // 16, size_n * 2), + output = torch.empty((num_experts, size_k // 16, size_n * (num_bits // 2)), device=b_q_weight.device, dtype=b_q_weight.dtype) for e in range(num_experts): diff --git a/vllm/model_executor/layers/fused_moe/__init__.py b/vllm/model_executor/layers/fused_moe/__init__.py index fd6f41b90042e..65a9b78a118c3 100644 --- a/vllm/model_executor/layers/fused_moe/__init__.py +++ b/vllm/model_executor/layers/fused_moe/__init__.py @@ -1,17 +1,23 @@ +from vllm.model_executor.layers.fused_moe.fused_moe_marlin import ( + fused_moe_marlin, single_moe_marlin) from vllm.model_executor.layers.fused_moe.layer import ( FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported) from vllm.triton_utils import HAS_TRITON -__all__ = ["FusedMoE", "FusedMoEMethodBase", "FusedMoeWeightScaleSupported"] +__all__ = [ + "FusedMoE", + "FusedMoEMethodBase", + "FusedMoeWeightScaleSupported", + "fused_moe_marlin", + "single_moe_marlin", +] if HAS_TRITON: - from vllm.model_executor.layers.fused_moe.fused_moe import ( - fused_experts, fused_marlin_moe, fused_moe, fused_topk, - get_config_file_name, grouped_topk) + fused_experts, fused_moe, fused_topk, get_config_file_name, + grouped_topk) __all__ += [ - "fused_marlin_moe", "fused_moe", "fused_topk", "fused_experts", diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 05169eaddb256..bd13d8fecbb96 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -323,15 +323,22 @@ def get_moe_configs(E: int, N: int, return None -def get_default_config(M: int, E: int, N: int, K: int, topk: int, - dtype: Optional[str], - is_marlin: bool) -> Dict[str, int]: +def get_default_config( + M: int, + E: int, + N: int, + K: int, + topk: int, + dtype: Optional[str], + is_marlin: bool, +) -> Dict[str, int]: config = { 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8 } + # A heuristic: fused marlin works faster with this config for small M if M <= E or (is_marlin and M <= 32): config = { 'BLOCK_SIZE_M': 16, @@ -342,14 +349,15 @@ def get_default_config(M: int, E: int, N: int, K: int, topk: int, return config -def try_get_optimal_moe_config(w1_shape: Tuple[int, ...], - w2_shape: Tuple[int, ...], - top_k: int, - dtype: Optional[str], - M: int, - override_config: Optional[Dict[str, - Any]] = None, - is_marlin: bool = False): +def try_get_optimal_moe_config( + w1_shape: Tuple[int, ...], + w2_shape: Tuple[int, ...], + top_k: int, + dtype: Optional[str], + M: int, + override_config: Optional[Dict[str, Any]] = None, + is_marlin: bool = False, +): if override_config: config = override_config else: @@ -391,6 +399,7 @@ def fused_topk( topk, dtype=torch.int32, device=hidden_states.device) + ops.topk_softmax( topk_weights, topk_ids, @@ -437,113 +446,6 @@ def grouped_topk(hidden_states: torch.Tensor, return topk_weights, topk_ids -def fused_marlin_moe(hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - gating_output: torch.Tensor, - g_idx1: torch.Tensor, - g_idx2: torch.Tensor, - rand_perm1: torch.Tensor, - rand_perm2: torch.Tensor, - topk: int, - custom_routing_function: Optional[Callable] = None, - renormalize: bool = True, - override_config: Optional[Dict[str, Any]] = None, - use_fp8: bool = False, - w1_scale: Optional[torch.Tensor] = None, - w2_scale: Optional[torch.Tensor] = None) -> torch.Tensor: - """ - This function computes a Mixture of Experts (MoE) layer using two sets of - weights, w1 and w2, and top-k gating mechanism. - Parameters: - - hidden_states (torch.Tensor): The input tensor to the MoE layer. - - w1 (torch.Tensor): The first set of expert weights. - - w2 (torch.Tensor): The second set of expert weights. - - gating_output (torch.Tensor): The output of the gating operation - (before softmax). - - topk (int): The number of top-k experts to select. - - renormalize (bool): If True, renormalize the top-k weights to sum to 1. - - inplace (bool): If True, perform the operation in-place. - Defaults to False. - - override_config (Optional[Dict[str, Any]]): Optional override - for the kernel configuration. - - use_fp8 (bool): If True, use fp8 arithmetic to compute the inner - products for w1 and w2. Defaults to False. - - w1_scale (Optional[torch.Tensor]): Optional scale to be used for - w1. - - w2_scale (Optional[torch.Tensor]): Optional scale to be used for - w2. - Returns: - - torch.Tensor: The output tensor after applying the MoE layer. - """ - # Check constraints. - assert hidden_states.shape[0] == gating_output.shape[0], ( - "Number of tokens mismatch") - assert hidden_states.shape[ - 1] == w1.shape[1] * 16, "Hidden size mismatch w1" - assert hidden_states.shape[ - 1] == w2.shape[2] // 2, "Hidden size mismatch w2" - assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch" - assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" - assert w1.is_contiguous(), "Expert weights1 must be contiguous" - assert w2.is_contiguous(), "Expert weights2 must be contiguous" - assert hidden_states.dtype in [ - torch.float32, torch.float16, torch.bfloat16 - ] - - #TODO fp8 is not implemented yet - assert not use_fp8 - - M, K = hidden_states.shape - E = w1.shape[0] - N = w2.shape[1] * 16 - - if custom_routing_function is None: - topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk, - renormalize) - else: - topk_weights, topk_ids = custom_routing_function( - hidden_states, gating_output, topk, renormalize) - - get_config_func = functools.partial(try_get_optimal_moe_config, - w1.shape, - w2.shape, - topk_ids.shape[1], - "float8" if use_fp8 else None, - override_config=override_config, - is_marlin=True) - config = get_config_func(M) - - block_size_m = config['BLOCK_SIZE_M'] - - sorted_token_ids, _, _ = moe_align_block_size(topk_ids, block_size_m, E) - - max_workspace_size = ((M + 255) // 256) * (max(2 * N, K) // 64) * 16 - workspace = torch.zeros(max_workspace_size, - dtype=torch.int, - device="cuda", - requires_grad=False) - - intermediate_cache2 = torch.empty((M * topk_ids.shape[1], N), - device=hidden_states.device, - dtype=hidden_states.dtype) - - intermediate_cache1 = torch.ops._moe_C.marlin_gemm_moe( - hidden_states, w1, sorted_token_ids, topk_weights, topk_ids, w1_scale, - g_idx1, rand_perm1, workspace, M, 2 * N, K, True, E, topk, - block_size_m, True, False) - - ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, 2 * N)) - - intermediate_cache3 = torch.ops._moe_C.marlin_gemm_moe( - intermediate_cache2, w2, sorted_token_ids, topk_weights, topk_ids, - w2_scale, g_idx2, rand_perm2, workspace, M, K, N, True, E, topk, - block_size_m, False, True) - - return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), - dim=1) - - def get_config_dtype_str(dtype: torch.dtype, use_int8_w8a16: Optional[bool] = False, use_fp8_w8a8: Optional[bool] = False): diff --git a/vllm/model_executor/layers/fused_moe/fused_moe_marlin.py b/vllm/model_executor/layers/fused_moe/fused_moe_marlin.py new file mode 100644 index 0000000000000..40f9f66f1706b --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/fused_moe_marlin.py @@ -0,0 +1,245 @@ +"""Fused MoE utilities for GPTQ.""" +import functools +from typing import Any, Dict, Optional + +import torch + +from vllm import _custom_ops as ops +from vllm.scalar_type import scalar_types + +from .fused_moe import (fused_topk, moe_align_block_size, + try_get_optimal_moe_config) + + +def single_moe_marlin( + hidden_states: torch.Tensor, + w: torch.Tensor, + scales: torch.Tensor, + gating_output: torch.Tensor, + g_idx: torch.Tensor, + rand_perm: torch.Tensor, + topk: int, + renormalize: bool, + override_config: Optional[Dict[str, Any]] = None, + use_fp8: bool = False, + num_bits: int = 8, +) -> torch.Tensor: + """ + This function computes a Marlin MoE MMM using weights w + and top-k gating mechanism. It is meant for testing and debugging. + + Parameters: + - hidden_states (torch.Tensor): The input tensor to the MoE layer. + - w (torch.Tensor): The first set of expert weights. + - gating_output (torch.Tensor): The output of the gating operation + (before softmax). + - topk (int): The number of top-k experts to select. + - renormalize (bool): If True, renormalize the top-k weights to sum to 1. + - inplace (bool): If True, perform the operation in-place. + Defaults to False. + - override_config (Optional[Dict[str, Any]]): Optional override + for the kernel configuration. + - use_fp8 (bool): If True, use fp8 arithmetic to compute the inner + product for w. Defaults to False. + + Returns: + - torch.Tensor: The output tensor after applying the MoE layer. + """ + # Check constraints. + assert hidden_states.shape[0] == gating_output.shape[0], ( + "Number of tokens mismatch") + assert hidden_states.shape[1] == w.shape[1] * 16, "Hidden size mismatch" + assert gating_output.shape[1] == w.shape[0], "Number of experts mismatch" + assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" + assert w.is_contiguous(), "Expert weights must be contiguous" + assert hidden_states.dtype in [ + torch.float32, torch.float16, torch.bfloat16 + ] + assert num_bits in [4, 8] + # TODO support this + assert not use_fp8 + + M, K = hidden_states.shape + E = w.shape[0] + N = w.shape[2] // (num_bits // 2) + + topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk, + renormalize) + + # This might not be an optimal config for a single MMM + get_config_func = functools.partial(try_get_optimal_moe_config, + w.shape, + w.shape, + topk_ids.shape[1], + "float8" if use_fp8 else None, + override_config=override_config, + is_marlin=True) + config = get_config_func(M) + + block_size_m = config['BLOCK_SIZE_M'] + + sorted_token_ids, _, _ = moe_align_block_size(topk_ids, block_size_m, E) + + max_workspace_size = (N // 64) * 16 + workspace = torch.zeros(max_workspace_size, + dtype=torch.int, + device="cuda", + requires_grad=False) + + scalar_type = (scalar_types.uint4b8 + if num_bits == 4 else scalar_types.uint8b128) + + intermediate_cache = torch.ops._moe_C.marlin_gemm_moe( + hidden_states, w, sorted_token_ids, topk_weights, topk_ids, scales, + g_idx, rand_perm, workspace, scalar_type, M, N, K, True, E, topk, + block_size_m, True, False) + + return torch.sum(intermediate_cache.view(*intermediate_cache.shape), dim=1) + + +def fused_moe_marlin( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + gating_output: torch.Tensor, + g_idx1: torch.Tensor, + g_idx2: torch.Tensor, + rand_perm1: torch.Tensor, + rand_perm2: torch.Tensor, + topk: int, + renormalize: bool, + override_config: Optional[Dict[str, Any]] = None, + use_fp8: bool = False, + w1_scale: Optional[torch.Tensor] = None, + w2_scale: Optional[torch.Tensor] = None, + num_bits: int = 8, +) -> torch.Tensor: + """ + This function computes a Mixture of Experts (MoE) layer using two sets of + weights, w1 and w2, and top-k gating mechanism. + + Parameters: + - hidden_states (torch.Tensor): The input tensor to the MoE layer. + - w1 (torch.Tensor): The first set of expert weights. + - w2 (torch.Tensor): The second set of expert weights. + - gating_output (torch.Tensor): The output of the gating operation + (before softmax). + - topk (int): The number of top-k experts to select. + - renormalize (bool): If True, renormalize the top-k weights to sum to 1. + - inplace (bool): If True, perform the operation in-place. + Defaults to False. + - override_config (Optional[Dict[str, Any]]): Optional override + for the kernel configuration. + - use_fp8 (bool): If True, use fp8 arithmetic to compute the inner + products for w1 and w2. Defaults to False. + - w1_scale (Optional[torch.Tensor]): Optional scale to be used for + w1. + - w2_scale (Optional[torch.Tensor]): Optional scale to be used for + w2. + + Returns: + - torch.Tensor: The output tensor after applying the MoE layer. + """ + # Check constraints. + assert hidden_states.shape[0] == gating_output.shape[ + 0], "Number of tokens mismatch" + assert hidden_states.shape[ + 1] == w1.shape[1] * 16, "Hidden size mismatch w1" + assert hidden_states.shape[1] == w2.shape[2] // ( + num_bits // 2), "Hidden size mismatch w2" + assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch" + assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" + assert w1.is_contiguous(), "Expert weights1 must be contiguous" + assert w2.is_contiguous(), "Expert weights2 must be contiguous" + assert hidden_states.dtype in [ + torch.float32, torch.float16, torch.bfloat16 + ] + assert num_bits in [4, 8] + # TODO support this + assert not use_fp8 + + M, K = hidden_states.shape + E = w1.shape[0] + N = w2.shape[1] * 16 + + topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk, + renormalize) + + get_config_func = functools.partial( + try_get_optimal_moe_config, + w1.shape, + w2.shape, + topk_ids.shape[1], + "float8" if use_fp8 else None, + override_config=override_config, + is_marlin=True, + ) + config = get_config_func(M) + + block_size_m = config["BLOCK_SIZE_M"] + + sorted_token_ids, _, _ = moe_align_block_size(topk_ids, block_size_m, E) + + max_workspace_size = ((M + 255) // 256) * (max(2 * N, K) // 64) * 16 + workspace = torch.zeros(max_workspace_size, + dtype=torch.int, + device="cuda", + requires_grad=False) + + scalar_type = (scalar_types.uint4b8 + if num_bits == 4 else scalar_types.uint8b128) + + intermediate_cache2 = torch.empty( + (M * topk_ids.shape[1], N), + device=hidden_states.device, + dtype=hidden_states.dtype, + ) + + intermediate_cache1 = torch.ops._moe_C.marlin_gemm_moe( + hidden_states, + w1, + sorted_token_ids, + topk_weights, + topk_ids, + w1_scale, + g_idx1, + rand_perm1, + workspace, + scalar_type, + M, + 2 * N, + K, + True, + E, + topk, + block_size_m, + True, + False, + ) + + ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, 2 * N)) + + intermediate_cache3 = torch.ops._moe_C.marlin_gemm_moe( + intermediate_cache2, + w2, + sorted_token_ids, + topk_weights, + topk_ids, + w2_scale, + g_idx2, + rand_perm2, + workspace, + scalar_type, + M, + K, + N, + True, + E, + topk, + block_size_m, + False, + True, + ) + + return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), + dim=1) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index 36323493d601e..abdc28bfebcc7 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -269,19 +269,21 @@ def apply( custom_routing_function: Optional[Callable] = None, ) -> torch.Tensor: - from vllm.model_executor.layers.fused_moe.fused_moe import ( - fused_marlin_moe) - - return fused_marlin_moe(x, - layer.w13_weight_packed, - layer.w2_weight_packed, - router_logits, - layer.w13_g_idx, - layer.w2_g_idx, - layer.w13_g_idx_sort_indices, - layer.w2_g_idx_sort_indices, - top_k, - custom_routing_function, - renormalize=renormalize, - w1_scale=layer.w13_weight_scale, - w2_scale=layer.w2_weight_scale) + from vllm.model_executor.layers.fused_moe.fused_moe_marlin import ( + fused_moe_marlin) + + return fused_moe_marlin( + x, + layer.w13_weight_packed, + layer.w2_weight_packed, + router_logits, + layer.w13_g_idx, + layer.w2_g_idx, + layer.w13_g_idx_sort_indices, + layer.w2_g_idx_sort_indices, + top_k, + renormalize=renormalize, + w1_scale=layer.w13_weight_scale, + w2_scale=layer.w2_weight_scale, + num_bits=self.num_bits, + ) diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils.py b/vllm/model_executor/layers/quantization/utils/marlin_utils.py index 0ec68ac5b0f21..699d5f1844146 100644 --- a/vllm/model_executor/layers/quantization/utils/marlin_utils.py +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils.py @@ -176,6 +176,23 @@ def marlin_permute_scales(s: torch.Tensor, size_k: int, size_n: int, return s +def marlin_moe_permute_scales( + s: torch.Tensor, + size_k: int, + size_n: int, + group_size: int, +): + num_experts = s.shape[0] + output = torch.empty( + (num_experts, s.shape[1], s.shape[2]), + device=s.device, + dtype=s.dtype, + ) + for e in range(num_experts): + output[e] = marlin_permute_scales(s[e], size_k, size_n, group_size) + return output + + def marlin_zero_points(zp: torch.Tensor, size_k: int, size_n: int, num_bits: int) -> torch.Tensor: # Permute zero-points in a similar way to scales, but do not use the diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils_test.py b/vllm/model_executor/layers/quantization/utils/marlin_utils_test.py index 7d08ac6f87469..4a06c5d63d52d 100644 --- a/vllm/model_executor/layers/quantization/utils/marlin_utils_test.py +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils_test.py @@ -1,6 +1,6 @@ """Utility functions used for tests and benchmarks""" -from typing import List +from typing import List, Optional import numpy as np import torch @@ -92,8 +92,11 @@ def get_weight_perm(num_bits: int): return perm -def marlin_quantize(w: torch.Tensor, quant_type: ScalarType, group_size: int, - act_order: bool): +def marlin_quantize(w: torch.Tensor, + quant_type: ScalarType, + group_size: int, + act_order: bool, + test_perm: Optional[torch.Tensor] = None): size_k, size_n = w.shape num_bits = quant_type.size_bits @@ -104,7 +107,7 @@ def marlin_quantize(w: torch.Tensor, quant_type: ScalarType, group_size: int, # Quantize (and apply act_order if provided) w_ref, q_w, s, g_idx, rand_perm = gptq_quantize_weights( - w, quant_type, group_size, act_order) + w, quant_type, group_size, act_order, test_perm) # For act_order, sort the "weights" and "g_idx" so that group ids are # increasing diff --git a/vllm/model_executor/layers/quantization/utils/quant_utils.py b/vllm/model_executor/layers/quantization/utils/quant_utils.py index 33f24ff5d54d3..bdfda31de852b 100644 --- a/vllm/model_executor/layers/quantization/utils/quant_utils.py +++ b/vllm/model_executor/layers/quantization/utils/quant_utils.py @@ -1,5 +1,5 @@ """This file is used for /tests and /benchmarks""" -from typing import List +from typing import List, Optional import numpy import torch @@ -53,7 +53,10 @@ def get_pack_factor(num_bits): return 32 // num_bits -def permute_rows(q_w: torch.Tensor, w_ref: torch.Tensor, group_size: int): +def permute_rows(q_w: torch.Tensor, + w_ref: torch.Tensor, + group_size: int, + test_perm: Optional[torch.Tensor] = None): assert q_w.shape == w_ref.shape orig_device = q_w.device @@ -64,7 +67,7 @@ def permute_rows(q_w: torch.Tensor, w_ref: torch.Tensor, group_size: int): g_idx[i] = i // group_size # Simulate act_order by doing a random permutation on K - rand_perm = torch.randperm(k_size) + rand_perm = test_perm if test_perm is not None else torch.randperm(k_size) g_idx = g_idx[rand_perm].contiguous() q_w = q_w[rand_perm, :].contiguous() @@ -164,8 +167,11 @@ def reshape_w(w): ) -def gptq_quantize_weights(w: torch.Tensor, quant_type: ScalarType, - group_size: int, act_order: bool): +def gptq_quantize_weights(w: torch.Tensor, + quant_type: ScalarType, + group_size: int, + act_order: bool, + test_perm: Optional[torch.Tensor] = None): size_k, _ = w.shape assert w.is_floating_point(), "w must be float" @@ -186,7 +192,8 @@ def gptq_quantize_weights(w: torch.Tensor, quant_type: ScalarType, ), "For act_order, groupsize = {} must be less than size_k = {}".format( group_size, size_k) - w_ref, w_q, g_idx, rand_perm = permute_rows(w_q, w_ref, group_size) + w_ref, w_q, g_idx, rand_perm = permute_rows(w_q, w_ref, group_size, + test_perm) return w_ref, w_q, w_s, g_idx, rand_perm From 430a9cb0f3c61702fbfeb8c59a7fdaac44344ae8 Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Fri, 30 Aug 2024 09:36:33 -0400 Subject: [PATCH 51/77] fix rocm --- csrc/moe/torch_bindings.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/csrc/moe/torch_bindings.cpp b/csrc/moe/torch_bindings.cpp index d2352375de33c..e4fce091d24a3 100644 --- a/csrc/moe/torch_bindings.cpp +++ b/csrc/moe/torch_bindings.cpp @@ -9,6 +9,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { "token_expert_indices, Tensor gating_output) -> ()"); m.impl("topk_softmax", torch::kCUDA, &topk_softmax); +#ifndef USE_ROCM m.def( "marlin_gemm_moe(Tensor! a, Tensor! b_q_weights, Tensor! sorted_ids, " "Tensor! topk_weights, Tensor! topk_ids, Tensor! b_scales, Tensor! " @@ -19,5 +20,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { " -> Tensor"); m.impl("marlin_gemm_moe", torch::kCUDA, &marlin_gemm_moe); } +#endif REGISTER_EXTENSION(TORCH_EXTENSION_NAME) From 48047aae2510b6e5de588032797c4cc4059650fc Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Fri, 30 Aug 2024 09:45:52 -0400 Subject: [PATCH 52/77] bad paste --- csrc/moe/torch_bindings.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/moe/torch_bindings.cpp b/csrc/moe/torch_bindings.cpp index e4fce091d24a3..cd65a8ee92b94 100644 --- a/csrc/moe/torch_bindings.cpp +++ b/csrc/moe/torch_bindings.cpp @@ -19,7 +19,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { "int moe_block_size, bool replicate_input, bool apply_weights)" " -> Tensor"); m.impl("marlin_gemm_moe", torch::kCUDA, &marlin_gemm_moe); -} #endif +} REGISTER_EXTENSION(TORCH_EXTENSION_NAME) From bfc4faed9562603fcc71c92d2c9fc293d9cc2130 Mon Sep 17 00:00:00 2001 From: Dipika Date: Fri, 30 Aug 2024 17:29:42 +0000 Subject: [PATCH 53/77] add test case; fix imports for tests --- tests/weight_loading/models.txt | 1 + vllm/model_executor/layers/fused_moe/__init__.py | 8 ++++---- vllm/model_executor/layers/fused_moe/fused_moe_marlin.py | 5 ++--- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/tests/weight_loading/models.txt b/tests/weight_loading/models.txt index 1dc529037a98e..5eee2cc534445 100644 --- a/tests/weight_loading/models.txt +++ b/tests/weight_loading/models.txt @@ -21,6 +21,7 @@ compressed-tensors, nm-testing/Phi-3-mini-128k-instruct-FP8, main compressed-tensors, neuralmagic/Phi-3-medium-128k-instruct-quantized.w4a16, main compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W4A16-quantized, main compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W4A16-channel-quantized, main +compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W8A16-quantized, main awq, casperhansen/mixtral-instruct-awq, main awq_marlin, casperhansen/mixtral-instruct-awq, main fp8, neuralmagic/Meta-Llama-3-8B-Instruct-FP8-KV, main diff --git a/vllm/model_executor/layers/fused_moe/__init__.py b/vllm/model_executor/layers/fused_moe/__init__.py index 65a9b78a118c3..06bd2706d7e4c 100644 --- a/vllm/model_executor/layers/fused_moe/__init__.py +++ b/vllm/model_executor/layers/fused_moe/__init__.py @@ -1,5 +1,3 @@ -from vllm.model_executor.layers.fused_moe.fused_moe_marlin import ( - fused_moe_marlin, single_moe_marlin) from vllm.model_executor.layers.fused_moe.layer import ( FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported) from vllm.triton_utils import HAS_TRITON @@ -8,16 +6,18 @@ "FusedMoE", "FusedMoEMethodBase", "FusedMoeWeightScaleSupported", - "fused_moe_marlin", - "single_moe_marlin", ] if HAS_TRITON: from vllm.model_executor.layers.fused_moe.fused_moe import ( fused_experts, fused_moe, fused_topk, get_config_file_name, grouped_topk) + from vllm.model_executor.layers.fused_moe.fused_moe_marlin import ( + fused_moe_marlin, single_moe_marlin) __all__ += [ + "fused_moe_marlin", + "single_moe_marlin", "fused_moe", "fused_topk", "fused_experts", diff --git a/vllm/model_executor/layers/fused_moe/fused_moe_marlin.py b/vllm/model_executor/layers/fused_moe/fused_moe_marlin.py index 40f9f66f1706b..40b409ebeb349 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe_marlin.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe_marlin.py @@ -5,11 +5,10 @@ import torch from vllm import _custom_ops as ops +from vllm.model_executor.layers.fused_moe.fused_moe import ( + fused_topk, moe_align_block_size, try_get_optimal_moe_config) from vllm.scalar_type import scalar_types -from .fused_moe import (fused_topk, moe_align_block_size, - try_get_optimal_moe_config) - def single_moe_marlin( hidden_states: torch.Tensor, From c5a2f6282cd60fa158f23536399dbbc98896bc63 Mon Sep 17 00:00:00 2001 From: Dipika Date: Fri, 30 Aug 2024 20:12:47 +0000 Subject: [PATCH 54/77] fix to adapt custom_routin_function --- .../layers/fused_moe/fused_moe_marlin.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe_marlin.py b/vllm/model_executor/layers/fused_moe/fused_moe_marlin.py index 40b409ebeb349..8c49333f7c845 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe_marlin.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe_marlin.py @@ -1,6 +1,6 @@ """Fused MoE utilities for GPTQ.""" import functools -from typing import Any, Dict, Optional +from typing import Any, Callable, Dict, Optional import torch @@ -106,7 +106,8 @@ def fused_moe_marlin( rand_perm1: torch.Tensor, rand_perm2: torch.Tensor, topk: int, - renormalize: bool, + custom_routing_function: Optional[Callable] = None, + renormalize: bool = True, override_config: Optional[Dict[str, Any]] = None, use_fp8: bool = False, w1_scale: Optional[torch.Tensor] = None, @@ -161,8 +162,12 @@ def fused_moe_marlin( E = w1.shape[0] N = w2.shape[1] * 16 - topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk, - renormalize) + if custom_routing_function is None: + topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk, + renormalize) + else: + topk_weights, topk_ids = custom_routing_function( + hidden_states, gating_output, topk, renormalize) get_config_func = functools.partial( try_get_optimal_moe_config, From 2b308c469a446aca61aa225867012fdef1513168 Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Mon, 2 Sep 2024 03:04:07 -0400 Subject: [PATCH 55/77] Use select_experts to compute top_k tensors in fused moe --- tests/kernels/test_moe.py | 7 ++++++- .../layers/fused_moe/fused_moe_marlin.py | 11 +++-------- .../compressed_tensors/compressed_tensors_moe.py | 15 +++++++++++++-- 3 files changed, 22 insertions(+), 11 deletions(-) diff --git a/tests/kernels/test_moe.py b/tests/kernels/test_moe.py index f7642bf02b05a..2cfd76d1c780e 100644 --- a/tests/kernels/test_moe.py +++ b/tests/kernels/test_moe.py @@ -11,6 +11,7 @@ from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import fused_moe +from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk from vllm.model_executor.layers.fused_moe.fused_moe_marlin import ( fused_moe_marlin, single_moe_marlin) from vllm.model_executor.layers.quantization.utils.marlin_utils_test import ( @@ -218,6 +219,9 @@ def test_fused_marlin_moe( sort_indices2 = stack_and_dev(sort_indices2_l) score = torch.randn((m, e), device="cuda", dtype=dtype) + + topk_weights, topk_ids = fused_topk(a, score, topk, False) + triton_output = fused_moe( a, w_ref1.transpose(1, 2).contiguous(), @@ -235,7 +239,8 @@ def test_fused_marlin_moe( g_idx2, sort_indices1, sort_indices2, - topk, + topk_weights, + topk_ids, renormalize=False, w1_scale=scales1, w2_scale=scales2, diff --git a/vllm/model_executor/layers/fused_moe/fused_moe_marlin.py b/vllm/model_executor/layers/fused_moe/fused_moe_marlin.py index 8c49333f7c845..45dead9740f40 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe_marlin.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe_marlin.py @@ -105,7 +105,8 @@ def fused_moe_marlin( g_idx2: torch.Tensor, rand_perm1: torch.Tensor, rand_perm2: torch.Tensor, - topk: int, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, custom_routing_function: Optional[Callable] = None, renormalize: bool = True, override_config: Optional[Dict[str, Any]] = None, @@ -161,13 +162,7 @@ def fused_moe_marlin( M, K = hidden_states.shape E = w1.shape[0] N = w2.shape[1] * 16 - - if custom_routing_function is None: - topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk, - renormalize) - else: - topk_weights, topk_ids = custom_routing_function( - hidden_states, gating_output, topk, renormalize) + topk = topk_ids.shape[1] get_config_func = functools.partial( try_get_optimal_moe_config, diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index abdc28bfebcc7..53769cb73153b 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -5,7 +5,7 @@ import torch from vllm import _custom_ops as ops -from vllm.model_executor.layers.fused_moe import FusedMoEMethodBase +from vllm.model_executor.layers.fused_moe import FusedMoE, FusedMoEMethodBase from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( WNA16_SUPPORTED_BITS) from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( @@ -272,6 +272,16 @@ def apply( from vllm.model_executor.layers.fused_moe.fused_moe_marlin import ( fused_moe_marlin) + topk_weights, topk_ids = FusedMoE.select_experts( + hidden_states=x, + router_logits=router_logits, + use_grouped_topk=use_grouped_topk, + top_k=top_k, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function) + return fused_moe_marlin( x, layer.w13_weight_packed, @@ -281,7 +291,8 @@ def apply( layer.w2_g_idx, layer.w13_g_idx_sort_indices, layer.w2_g_idx_sort_indices, - top_k, + topk_weights, + topk_ids, renormalize=renormalize, w1_scale=layer.w13_weight_scale, w2_scale=layer.w2_weight_scale, From 71256d45a491b896699416e74df87751ae1cdfc3 Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Tue, 3 Sep 2024 10:42:10 -0400 Subject: [PATCH 56/77] bring back fused_moe_marlin -> fused_marlin_moe --- tests/kernels/test_moe.py | 8 ++++---- vllm/model_executor/layers/fused_moe/__init__.py | 8 ++++---- .../{fused_moe_marlin.py => fused_marlin_moe.py} | 4 ++-- .../compressed_tensors/compressed_tensors_moe.py | 6 +++--- 4 files changed, 13 insertions(+), 13 deletions(-) rename vllm/model_executor/layers/fused_moe/{fused_moe_marlin.py => fused_marlin_moe.py} (99%) diff --git a/tests/kernels/test_moe.py b/tests/kernels/test_moe.py index 2cfd76d1c780e..6069978439824 100644 --- a/tests/kernels/test_moe.py +++ b/tests/kernels/test_moe.py @@ -11,9 +11,9 @@ from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import fused_moe +from vllm.model_executor.layers.fused_moe.fused_marlin_moe import ( + fused_marlin_moe, single_marlin_moe) from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk -from vllm.model_executor.layers.fused_moe.fused_moe_marlin import ( - fused_moe_marlin, single_moe_marlin) from vllm.model_executor.layers.quantization.utils.marlin_utils_test import ( marlin_quantize) from vllm.model_executor.models.mixtral import MixtralMoE @@ -230,7 +230,7 @@ def test_fused_marlin_moe( topk, renormalize=False, ) - marlin_output = fused_moe_marlin( + marlin_output = fused_marlin_moe( a, qweight1, qweight2, @@ -309,7 +309,7 @@ def test_marlin_moe_mmm( sort_indices = stack_and_dev(sort_indices_l) score = torch.randn((m, e), device="cuda", dtype=dtype) - marlin_output = single_moe_marlin(a, + marlin_output = single_marlin_moe(a, qweight, scales, score, diff --git a/vllm/model_executor/layers/fused_moe/__init__.py b/vllm/model_executor/layers/fused_moe/__init__.py index 06bd2706d7e4c..e9b5703ca28be 100644 --- a/vllm/model_executor/layers/fused_moe/__init__.py +++ b/vllm/model_executor/layers/fused_moe/__init__.py @@ -9,15 +9,15 @@ ] if HAS_TRITON: + from vllm.model_executor.layers.fused_moe.fused_marlin_moe import ( + fused_marlin_moe, single_marlin_moe) from vllm.model_executor.layers.fused_moe.fused_moe import ( fused_experts, fused_moe, fused_topk, get_config_file_name, grouped_topk) - from vllm.model_executor.layers.fused_moe.fused_moe_marlin import ( - fused_moe_marlin, single_moe_marlin) __all__ += [ - "fused_moe_marlin", - "single_moe_marlin", + "fused_marlin_moe", + "single_marlin_moe", "fused_moe", "fused_topk", "fused_experts", diff --git a/vllm/model_executor/layers/fused_moe/fused_moe_marlin.py b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py similarity index 99% rename from vllm/model_executor/layers/fused_moe/fused_moe_marlin.py rename to vllm/model_executor/layers/fused_moe/fused_marlin_moe.py index 45dead9740f40..5866c83cd9c8c 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe_marlin.py +++ b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py @@ -10,7 +10,7 @@ from vllm.scalar_type import scalar_types -def single_moe_marlin( +def single_marlin_moe( hidden_states: torch.Tensor, w: torch.Tensor, scales: torch.Tensor, @@ -96,7 +96,7 @@ def single_moe_marlin( return torch.sum(intermediate_cache.view(*intermediate_cache.shape), dim=1) -def fused_moe_marlin( +def fused_marlin_moe( hidden_states: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index 53769cb73153b..b14ef433d539c 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -269,8 +269,8 @@ def apply( custom_routing_function: Optional[Callable] = None, ) -> torch.Tensor: - from vllm.model_executor.layers.fused_moe.fused_moe_marlin import ( - fused_moe_marlin) + from vllm.model_executor.layers.fused_moe.fused_marlin_moe import ( + fused_marlin_moe) topk_weights, topk_ids = FusedMoE.select_experts( hidden_states=x, @@ -282,7 +282,7 @@ def apply( num_expert_group=num_expert_group, custom_routing_function=custom_routing_function) - return fused_moe_marlin( + return fused_marlin_moe( x, layer.w13_weight_packed, layer.w2_weight_packed, From 7aa844c8561768190443ebf84ff29021e5d70a9a Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Tue, 3 Sep 2024 11:03:53 -0400 Subject: [PATCH 57/77] GPTQ Fused MoE class --- .../layers/fused_moe/__init__.py | 3 +- vllm/model_executor/layers/fused_moe/layer.py | 155 +++++++++++++++++- 2 files changed, 156 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/__init__.py b/vllm/model_executor/layers/fused_moe/__init__.py index e9b5703ca28be..7f27e2660db65 100644 --- a/vllm/model_executor/layers/fused_moe/__init__.py +++ b/vllm/model_executor/layers/fused_moe/__init__.py @@ -1,11 +1,12 @@ from vllm.model_executor.layers.fused_moe.layer import ( - FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported) + FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported, GPTQFusedMoE) from vllm.triton_utils import HAS_TRITON __all__ = [ "FusedMoE", "FusedMoEMethodBase", "FusedMoeWeightScaleSupported", + "GPTQFusedMoE", ] if HAS_TRITON: diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 3df0b61a9ebe4..9643642b9b53e 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -498,4 +498,157 @@ def _load_fp8_scale(self, param: torch.nn.Parameter, param_data[expert_id][idx] = loaded_weight # If we are in the row parallel case (down_proj) else: - param_data[expert_id] = loaded_weight \ No newline at end of file + param_data[expert_id] = loaded_weight + + +class GPTQFusedMoE(torch.nn.Module): + """GPTQFusedMoE layer for GPTQ MoE models. + This layer contains both MergedColumnParallel weights (gate_up_proj / + w13) and RowParallelLinear weights (down_proj/ w2). + Note: Mixtral uses w1, w2, and w3 for gate, up, and down_proj. We + copy that naming convention here and handle any remapping in the + load_weights function in each model implementation. + Args: + num_experts: Number of experts in the model + top_k: Number of experts selected for each token + hidden_size: Input hidden state size of the transformer + intermediate_size: Intermediate size of the experts + params_dtype: Data type for the parameters. + reduce_results: Whether to all all_reduce on the output of the layer + renomalize: Whether to renormalize the logits in the fused_moe kernel + quant_config: Quantization configure. + """ + + def __init__( + self, + num_experts: int, + top_k: int, + hidden_size: int, + intermediate_size: int, + params_dtype: Optional[torch.dtype] = None, + reduce_results: bool = False, + renormalize: bool = True, + use_grouped_topk: bool = False, + num_expert_group: Optional[int] = None, + topk_group: Optional[int] = None, + quant_config: Optional[QuantizationConfig] = None, + tp_size: Optional[int] = None, + prefix: str = "", + ): + super().__init__() + + if params_dtype is None: + params_dtype = torch.get_default_dtype() + + self.tp_size = (tp_size if tp_size is not None else + get_tensor_model_parallel_world_size()) + self.top_k = top_k + self.num_experts = num_experts + self.intermediate_size = intermediate_size + self.intermediate_size_per_partition = intermediate_size // self.tp_size + self.reduce_results = reduce_results + self.renormalize = renormalize + assert (not use_grouped_topk and num_expert_group is None + and topk_group is None) + + if quant_config is None: + self.quant_method: Optional[ + QuantizeMethodBase] = UnquantizedFusedMoEMethod() + else: + self.quant_method = quant_config.get_quant_method(self, prefix) + assert self.quant_method is not None + + self.quant_method.create_weights( + layer=self, + num_experts=num_experts, + hidden_size=hidden_size, + intermediate_size=self.intermediate_size_per_partition, + params_dtype=params_dtype, + weight_loader=self.weight_loader, + ) + + def weight_loader(self, param: torch.nn.Parameter, + loaded_weight: torch.Tensor, weight_name: str, + shard_id: str, expert_id: int) -> None: + + if ("_qweight" in weight_name or "_scales" in weight_name + or "_qzeros" in weight_name): + if "w13" in weight_name: + shard_size = loaded_weight.size()[-1] + if shard_id == "w1": + param.data[expert_id, :, :shard_size] = loaded_weight + elif shard_id == "w2" or shard_id == "w3": + param.data[expert_id, :, shard_size:] = loaded_weight + else: + raise ValueError(f"Invalid shard_id: {shard_id}: " + "must be w1, w2, or w3.") + elif "w2" in weight_name: + param.data[expert_id][:] = loaded_weight + else: + raise ValueError(f"Invalid weight name: {weight_name}: " + "must contain 'w13' or 'w2'.") + elif "_g_idx" in weight_name: + if "w13" not in weight_name and "w2" not in weight_name: + raise ValueError(f"Invalid weight name: {weight_name}: " + "must contain 'w13' or 'w2'.") + param.data[expert_id] = loaded_weight + else: + raise ValueError(f"Invalid weight name: {weight_name}.") + + @staticmethod + def select_experts(hidden_states: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + use_grouped_topk: bool, + renormalize: bool, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None): + assert (not use_grouped_topk and topk_group is None + and num_expert_group is None) + from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk + + topk_weights, topk_ids = fused_topk(hidden_states=hidden_states, + gating_output=router_logits, + topk=top_k, + renormalize=renormalize) + + return topk_weights, topk_ids + + def forward(self, hidden_states: torch.Tensor, + router_logits: torch.Tensor): + assert self.quant_method is not None + + # Matrix multiply. + final_hidden_states = self.quant_method.apply( + layer=self, + x=hidden_states, + router_logits=router_logits, + top_k=self.top_k, + renormalize=self.renormalize, + use_grouped_topk=False, + topk_group=False, + num_expert_group=False) + + if self.reduce_results and self.tp_size > 1: + final_hidden_states = tensor_model_parallel_all_reduce( + final_hidden_states) + + return final_hidden_states + + @classmethod + def make_expert_params_mapping( + cls, ckpt_gate_proj_name: str, ckpt_down_proj_name: str, + ckpt_up_proj_name: str, + num_experts: int) -> List[Tuple[str, str, int, str]]: + + return [ + # (param_name, weight_name, expert_id, shard_id) + ("experts.w13_" if weight_name + in [ckpt_gate_proj_name, ckpt_up_proj_name] else "experts.w2_", + f"experts.{expert_id}.{weight_name}.", expert_id, shard_id) + for expert_id in range(num_experts) for shard_id, weight_name in [ + ("w1", ckpt_gate_proj_name), + ("w2", ckpt_down_proj_name), + ("w3", ckpt_up_proj_name), + ] + ] From 0f7bec3f03f9b7157f237f4dc9e7550ee5487f5f Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Tue, 3 Sep 2024 12:41:52 -0400 Subject: [PATCH 58/77] Add GPTQMarlinMoEMethod to gptq_marlin.py --- .../layers/quantization/gptq_marlin.py | 304 +++++++++++++++++- 1 file changed, 289 insertions(+), 15 deletions(-) diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py index b06ff7bd2bace..aac84b4586a83 100644 --- a/vllm/model_executor/layers/quantization/gptq_marlin.py +++ b/vllm/model_executor/layers/quantization/gptq_marlin.py @@ -1,18 +1,25 @@ -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Union import torch from torch.nn import Parameter from vllm import _custom_ops as ops from vllm.logger import init_logger -from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase +from vllm.model_executor.layers.fused_moe.fused_marlin_moe import ( + fused_marlin_moe) +from vllm.model_executor.layers.fused_moe.layer import (FusedMoE, + FusedMoEMethodBase, + GPTQFusedMoE) +from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, + set_weight_attrs) from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.layers.quantization.utils.marlin_utils import ( apply_gptq_marlin_linear, check_marlin_supported, marlin_is_k_full, - marlin_make_empty_g_idx, marlin_make_workspace, marlin_permute_scales, - marlin_repeat_scales_on_all_ranks, marlin_sort_g_idx, replace_tensor, - verify_marlin_supported, verify_marlin_supports_shape) + marlin_make_empty_g_idx, marlin_make_workspace, marlin_moe_permute_scales, + marlin_permute_scales, marlin_repeat_scales_on_all_ranks, + marlin_sort_g_idx, replace_tensor, verify_marlin_supported, + verify_marlin_supports_shape) from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.parameter import (ChannelQuantScaleParameter, GroupQuantScaleParameter, @@ -33,8 +40,14 @@ class GPTQMarlinConfig(QuantizationConfig): (8, True): scalar_types.uint8b128, } - def __init__(self, weight_bits: int, group_size: int, desc_act: bool, - is_sym: bool, lm_head_quantized: bool) -> None: + def __init__( + self, + weight_bits: int, + group_size: int, + desc_act: bool, + is_sym: bool, + lm_head_quantized: bool, + ) -> None: if desc_act and group_size == -1: # In this case, act_order == True is the same as act_order == False # (since we have only one group per output channel) @@ -105,11 +118,14 @@ def override_quantization_method(cls, hf_quant_cfg, " faster inference") return None - def get_quant_method(self, layer: torch.nn.Module, - prefix: str) -> Optional["GPTQMarlinLinearMethod"]: - if (isinstance(layer, LinearBase) or - (isinstance(layer, ParallelLMHead) and self.lm_head_quantized)): + def get_quant_method( + self, layer: torch.nn.Module, prefix: str + ) -> Optional[Union["GPTQMarlinLinearMethod", "GPTQMarlinMoEMethod"]]: + if isinstance(layer, LinearBase) or (isinstance(layer, ParallelLMHead) + and self.lm_head_quantized): return GPTQMarlinLinearMethod(self) + elif isinstance(layer, GPTQFusedMoE): + return GPTQMarlinMoEMethod(self) return None def get_scaled_act_names(self) -> List[str]: @@ -179,7 +195,8 @@ def create_weights( output_size_per_partition=output_size_per_partition, input_size_per_partition=input_size_per_partition, input_size=input_size, - group_size=group_size) + group_size=group_size, + ) # Determine sharding if marlin_repeat_scales_on_all_ranks(self.quant_config.desc_act, @@ -299,7 +316,8 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: perm=layer.g_idx_sort_indices, size_k=layer.input_size_per_partition, size_n=layer.output_size_per_partition, - num_bits=self.quant_config.quant_type.size_bits) + num_bits=self.quant_config.quant_type.size_bits, + ) replace_tensor(layer, "qweight", marlin_qweight) # Permute scales from autogptq format to marlin format. @@ -308,7 +326,8 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: size_k=(layer.input_size if self.quant_config.desc_act else layer.input_size_per_partition), size_n=layer.output_size_per_partition, - group_size=self.quant_config.group_size) + group_size=self.quant_config.group_size, + ) replace_tensor(layer, "scales", marlin_scales) def apply( @@ -329,4 +348,259 @@ def apply( output_size_per_partition=layer.output_size_per_partition, input_size_per_partition=layer.input_size_per_partition, is_k_full=layer.is_k_full, - bias=bias) + bias=bias, + ) + + +class GPTQMarlinMoEMethod(FusedMoEMethodBase): + """MoE Marlin method with quantization.""" + + def __init__(self, quant_config: GPTQMarlinConfig) -> None: + self.quant_config = quant_config + + def create_weights( + self, + layer: torch.nn.Module, + num_experts: int, + hidden_size: int, + intermediate_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + # Currently assuming is_k_full is always True + # (input size per partition is the same as full input size) + # Supports only sym for now (no zp) + if self.quant_config.group_size != -1: + scales_size13 = hidden_size // self.quant_config.group_size + scales_size2 = intermediate_size // self.quant_config.group_size + else: + scales_size13 = 1 + scales_size2 = 1 + # Fused gate_up_proj (column parallel) + w13_qweight = torch.nn.Parameter( + torch.empty( + num_experts, + hidden_size // self.quant_config.pack_factor, + 2 * intermediate_size, + dtype=torch.int32, + ), + requires_grad=False, + ) + layer.register_parameter("w13_qweight", w13_qweight) + set_weight_attrs(w13_qweight, extra_weight_attrs) + # down_proj (row parallel) + w2_qweight = torch.nn.Parameter( + torch.empty( + num_experts, + intermediate_size // self.quant_config.pack_factor, + hidden_size, + dtype=torch.int32, + ), + requires_grad=False, + ) + layer.register_parameter("w2_qweight", w2_qweight) + set_weight_attrs(w2_qweight, extra_weight_attrs) + # up_proj scales + w13_scales = torch.nn.Parameter( + torch.empty(num_experts, + scales_size13, + 2 * intermediate_size, + dtype=params_dtype), + requires_grad=False, + ) + layer.register_parameter("w13_scales", w13_scales) + set_weight_attrs(w13_scales, extra_weight_attrs) + # down_proj scales + w2_scales = torch.nn.Parameter( + torch.empty(num_experts, + scales_size2, + hidden_size, + dtype=params_dtype), + requires_grad=False, + ) + layer.register_parameter("w2_scales", w2_scales) + set_weight_attrs(w2_scales, extra_weight_attrs) + # up_proj scales + w13_qzeros = torch.nn.Parameter( + torch.empty(num_experts, + scales_size13, + 2 * intermediate_size // self.quant_config.pack_factor, + dtype=params_dtype), + requires_grad=False, + ) + layer.register_parameter("w13_qzeros", w13_qzeros) + set_weight_attrs(w13_qzeros, extra_weight_attrs) + # down_proj scales + w2_qzeros = torch.nn.Parameter( + torch.empty(num_experts, + scales_size2, + hidden_size // self.quant_config.pack_factor, + dtype=params_dtype), + requires_grad=False, + ) + layer.register_parameter("w2_qzeros", w2_qzeros) + set_weight_attrs(w2_qzeros, extra_weight_attrs) + w13_g_idx = torch.nn.Parameter( + torch.empty( + num_experts, + hidden_size, + dtype=torch.int32, + ), + requires_grad=False, + ) + layer.register_parameter("w13_g_idx", w13_g_idx) + set_weight_attrs(w13_g_idx, extra_weight_attrs) + w2_g_idx = torch.nn.Parameter( + torch.empty( + num_experts, + intermediate_size, + dtype=torch.int32, + ), + requires_grad=False, + ) + layer.register_parameter("w2_g_idx", w2_g_idx) + set_weight_attrs(w2_g_idx, extra_weight_attrs) + w13_g_idx_sort_indices = torch.nn.Parameter( + torch.empty( + num_experts, + hidden_size, + dtype=torch.int32, + ), + requires_grad=False, + ) + layer.register_parameter("w13_g_idx_sort_indices", + w13_g_idx_sort_indices) + set_weight_attrs(w13_g_idx_sort_indices, extra_weight_attrs) + w2_g_idx_sort_indices = torch.nn.Parameter( + torch.empty( + num_experts, + intermediate_size, + dtype=torch.int32, + ), + requires_grad=False, + ) + layer.register_parameter("w2_g_idx_sort_indices", + w2_g_idx_sort_indices) + set_weight_attrs(w2_g_idx_sort_indices, extra_weight_attrs) + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + + # Process act_order + if self.quant_config.desc_act: + # Get sorting based on g_idx + num_experts = layer.w13_g_idx.shape[0] + w13_g_idx_sort_indices = torch.empty_like(layer.w13_g_idx) + w2_g_idx_sort_indices = torch.empty_like(layer.w2_g_idx) + w13_sorted_g_idx = torch.empty_like(layer.w13_g_idx) + w2_sorted_g_idx = torch.empty_like(layer.w2_g_idx) + for e in range(num_experts): + w13_g_idx_sort_indices[e] = torch.argsort( + layer.w13_g_idx[e]).to(torch.int32) + w2_g_idx_sort_indices[e] = torch.argsort(layer.w2_g_idx[e]).to( + torch.int32) + w13_sorted_g_idx[e] = layer.w13_g_idx[e][ + w13_g_idx_sort_indices[e]] + w2_sorted_g_idx[e] = layer.w2_g_idx[e][ + w2_g_idx_sort_indices[e]] + replace_tensor(layer, "w13_g_idx", w13_sorted_g_idx) + replace_tensor(layer, "w2_g_idx", w2_sorted_g_idx) + replace_tensor(layer, "w13_g_idx_sort_indices", + w13_g_idx_sort_indices) + replace_tensor(layer, "w2_g_idx_sort_indices", + w2_g_idx_sort_indices) + else: + # Reset g_idx related tensors + num_experts = layer.w13_g_idx.shape[0] + device = layer.w13_g_idx.device + layer.w13_g_idx = torch.nn.Parameter( + torch.empty((num_experts, 0), dtype=torch.int32, + device=device), + requires_grad=False, + ) + layer.w2_g_idx = torch.nn.Parameter( + torch.empty((num_experts, 0), dtype=torch.int32, + device=device), + requires_grad=False, + ) + layer.w13_g_idx_sort_indices = torch.nn.Parameter( + torch.empty((num_experts, 0), dtype=torch.int32, + device=device), + requires_grad=False, + ) + layer.w2_g_idx_sort_indices = torch.nn.Parameter( + torch.empty((num_experts, 0), dtype=torch.int32, + device=device), + requires_grad=False, + ) + # Repack weights + marlin_w13_qweight = ops.gptq_marlin_moe_repack( + layer.w13_qweight, + layer.w13_g_idx_sort_indices, + layer.w13_qweight.shape[1] * self.quant_config.pack_factor, + layer.w13_qweight.shape[2], + self.quant_config.quant_type.size_bits, + ) + replace_tensor(layer, "w13_qweight", marlin_w13_qweight) + marlin_w2_qweight = ops.gptq_marlin_moe_repack( + layer.w2_qweight, + layer.w2_g_idx_sort_indices, + layer.w2_qweight.shape[1] * self.quant_config.pack_factor, + layer.w2_qweight.shape[2], + self.quant_config.quant_type.size_bits, + ) + replace_tensor(layer, "w2_qweight", marlin_w2_qweight) + # Repack scales + marlin_w13_scales = marlin_moe_permute_scales( + s=layer.w13_scales, + size_k=(layer.intermediate_size if self.quant_config.desc_act else + layer.intermediate_size_per_partition), + size_n=layer.w13_scales.shape[2], + group_size=self.quant_config.group_size, + ) + replace_tensor(layer, "w13_scales", marlin_w13_scales) + marlin_w2_scales = marlin_moe_permute_scales( + s=layer.w2_scales, + size_k=layer.w2_scales.shape[1] * self.quant_config.pack_factor, + size_n=layer.w2_scales.shape[2], + group_size=self.quant_config.group_size, + ) + replace_tensor(layer, "w2_scales", marlin_w2_scales) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool = True, + use_grouped_topk: bool = False, + num_expert_group: Optional[int] = None, + topk_group: Optional[int] = None, + ) -> torch.Tensor: + + topk_weights, topk_ids = FusedMoE.select_experts( + hidden_states=x, + router_logits=router_logits, + use_grouped_topk=use_grouped_topk, + top_k=top_k, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + custom_routing_function=None) + + return fused_marlin_moe( + x, + layer.w13_qweight, + layer.w2_qweight, + router_logits, + layer.w13_g_idx, + layer.w2_g_idx, + layer.w13_g_idx_sort_indices, + layer.w2_g_idx_sort_indices, + topk_weights, + topk_ids, + renormalize=renormalize, + w1_scale=layer.w13_scales, + w2_scale=layer.w2_scales, + num_bits=self.quant_config.quant_type.size_bits, + ) From cb0001e1ca3f4637f6629925dbca15d361e048bb Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Wed, 4 Sep 2024 09:00:15 -0400 Subject: [PATCH 59/77] Use FusedMoE layer for all loads --- .../layers/fused_moe/__init__.py | 3 +- vllm/model_executor/layers/fused_moe/layer.py | 172 ++---------------- .../layers/quantization/gptq_marlin.py | 5 +- 3 files changed, 22 insertions(+), 158 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/__init__.py b/vllm/model_executor/layers/fused_moe/__init__.py index 7f27e2660db65..e9b5703ca28be 100644 --- a/vllm/model_executor/layers/fused_moe/__init__.py +++ b/vllm/model_executor/layers/fused_moe/__init__.py @@ -1,12 +1,11 @@ from vllm.model_executor.layers.fused_moe.layer import ( - FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported, GPTQFusedMoE) + FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported) from vllm.triton_utils import HAS_TRITON __all__ = [ "FusedMoE", "FusedMoEMethodBase", "FusedMoeWeightScaleSupported", - "GPTQFusedMoE", ] if HAS_TRITON: diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 9643642b9b53e..b0d7d4b538df3 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -334,6 +334,25 @@ def weight_loader(self, param: torch.nn.Parameter, loaded_weight = loaded_weight.t().contiguous() shard_dim = ~shard_dim + # GPTQ Values + if ("scales" in weight_name or "qweight" in weight_name + or "qzeros" in weight_name): + if (shard_id == "w1" or shard_id == "w3"): + shard_dim = 1 - shard_dim + self._load_model_weight_or_group_weight_scale( + shard_id=shard_id, + shard_dim=shard_dim, + loaded_weight=loaded_weight, + expert_data=expert_data, + tp_rank=tp_rank) + return + + if "g_idx" in weight_name: + self._load_single_value(param=param, + loaded_weight=loaded_weight, + expert_id=expert_id) + return + # Case weight_scales if "weight_scale" in weight_name: # load the weight scaling based on the quantization scheme @@ -499,156 +518,3 @@ def _load_fp8_scale(self, param: torch.nn.Parameter, # If we are in the row parallel case (down_proj) else: param_data[expert_id] = loaded_weight - - -class GPTQFusedMoE(torch.nn.Module): - """GPTQFusedMoE layer for GPTQ MoE models. - This layer contains both MergedColumnParallel weights (gate_up_proj / - w13) and RowParallelLinear weights (down_proj/ w2). - Note: Mixtral uses w1, w2, and w3 for gate, up, and down_proj. We - copy that naming convention here and handle any remapping in the - load_weights function in each model implementation. - Args: - num_experts: Number of experts in the model - top_k: Number of experts selected for each token - hidden_size: Input hidden state size of the transformer - intermediate_size: Intermediate size of the experts - params_dtype: Data type for the parameters. - reduce_results: Whether to all all_reduce on the output of the layer - renomalize: Whether to renormalize the logits in the fused_moe kernel - quant_config: Quantization configure. - """ - - def __init__( - self, - num_experts: int, - top_k: int, - hidden_size: int, - intermediate_size: int, - params_dtype: Optional[torch.dtype] = None, - reduce_results: bool = False, - renormalize: bool = True, - use_grouped_topk: bool = False, - num_expert_group: Optional[int] = None, - topk_group: Optional[int] = None, - quant_config: Optional[QuantizationConfig] = None, - tp_size: Optional[int] = None, - prefix: str = "", - ): - super().__init__() - - if params_dtype is None: - params_dtype = torch.get_default_dtype() - - self.tp_size = (tp_size if tp_size is not None else - get_tensor_model_parallel_world_size()) - self.top_k = top_k - self.num_experts = num_experts - self.intermediate_size = intermediate_size - self.intermediate_size_per_partition = intermediate_size // self.tp_size - self.reduce_results = reduce_results - self.renormalize = renormalize - assert (not use_grouped_topk and num_expert_group is None - and topk_group is None) - - if quant_config is None: - self.quant_method: Optional[ - QuantizeMethodBase] = UnquantizedFusedMoEMethod() - else: - self.quant_method = quant_config.get_quant_method(self, prefix) - assert self.quant_method is not None - - self.quant_method.create_weights( - layer=self, - num_experts=num_experts, - hidden_size=hidden_size, - intermediate_size=self.intermediate_size_per_partition, - params_dtype=params_dtype, - weight_loader=self.weight_loader, - ) - - def weight_loader(self, param: torch.nn.Parameter, - loaded_weight: torch.Tensor, weight_name: str, - shard_id: str, expert_id: int) -> None: - - if ("_qweight" in weight_name or "_scales" in weight_name - or "_qzeros" in weight_name): - if "w13" in weight_name: - shard_size = loaded_weight.size()[-1] - if shard_id == "w1": - param.data[expert_id, :, :shard_size] = loaded_weight - elif shard_id == "w2" or shard_id == "w3": - param.data[expert_id, :, shard_size:] = loaded_weight - else: - raise ValueError(f"Invalid shard_id: {shard_id}: " - "must be w1, w2, or w3.") - elif "w2" in weight_name: - param.data[expert_id][:] = loaded_weight - else: - raise ValueError(f"Invalid weight name: {weight_name}: " - "must contain 'w13' or 'w2'.") - elif "_g_idx" in weight_name: - if "w13" not in weight_name and "w2" not in weight_name: - raise ValueError(f"Invalid weight name: {weight_name}: " - "must contain 'w13' or 'w2'.") - param.data[expert_id] = loaded_weight - else: - raise ValueError(f"Invalid weight name: {weight_name}.") - - @staticmethod - def select_experts(hidden_states: torch.Tensor, - router_logits: torch.Tensor, - top_k: int, - use_grouped_topk: bool, - renormalize: bool, - topk_group: Optional[int] = None, - num_expert_group: Optional[int] = None): - assert (not use_grouped_topk and topk_group is None - and num_expert_group is None) - from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk - - topk_weights, topk_ids = fused_topk(hidden_states=hidden_states, - gating_output=router_logits, - topk=top_k, - renormalize=renormalize) - - return topk_weights, topk_ids - - def forward(self, hidden_states: torch.Tensor, - router_logits: torch.Tensor): - assert self.quant_method is not None - - # Matrix multiply. - final_hidden_states = self.quant_method.apply( - layer=self, - x=hidden_states, - router_logits=router_logits, - top_k=self.top_k, - renormalize=self.renormalize, - use_grouped_topk=False, - topk_group=False, - num_expert_group=False) - - if self.reduce_results and self.tp_size > 1: - final_hidden_states = tensor_model_parallel_all_reduce( - final_hidden_states) - - return final_hidden_states - - @classmethod - def make_expert_params_mapping( - cls, ckpt_gate_proj_name: str, ckpt_down_proj_name: str, - ckpt_up_proj_name: str, - num_experts: int) -> List[Tuple[str, str, int, str]]: - - return [ - # (param_name, weight_name, expert_id, shard_id) - ("experts.w13_" if weight_name - in [ckpt_gate_proj_name, ckpt_up_proj_name] else "experts.w2_", - f"experts.{expert_id}.{weight_name}.", expert_id, shard_id) - for expert_id in range(num_experts) for shard_id, weight_name in [ - ("w1", ckpt_gate_proj_name), - ("w2", ckpt_down_proj_name), - ("w3", ckpt_up_proj_name), - ] - ] diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py index aac84b4586a83..698a4c29d7a01 100644 --- a/vllm/model_executor/layers/quantization/gptq_marlin.py +++ b/vllm/model_executor/layers/quantization/gptq_marlin.py @@ -8,8 +8,7 @@ from vllm.model_executor.layers.fused_moe.fused_marlin_moe import ( fused_marlin_moe) from vllm.model_executor.layers.fused_moe.layer import (FusedMoE, - FusedMoEMethodBase, - GPTQFusedMoE) + FusedMoEMethodBase) from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, set_weight_attrs) from vllm.model_executor.layers.quantization.base_config import ( @@ -124,7 +123,7 @@ def get_quant_method( if isinstance(layer, LinearBase) or (isinstance(layer, ParallelLMHead) and self.lm_head_quantized): return GPTQMarlinLinearMethod(self) - elif isinstance(layer, GPTQFusedMoE): + elif isinstance(layer, FusedMoE): return GPTQMarlinMoEMethod(self) return None From 33090a3f93c07e302cc6ef5960f4cad723f808c1 Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Wed, 4 Sep 2024 11:25:08 -0400 Subject: [PATCH 60/77] Make sure that GPTQ runs through mixtral.py --- vllm/model_executor/layers/quantization/gptq_marlin.py | 6 +++--- vllm/model_executor/model_loader/utils.py | 2 +- vllm/model_executor/models/mixtral.py | 6 ++++-- 3 files changed, 8 insertions(+), 6 deletions(-) diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py index 698a4c29d7a01..b0f972182d598 100644 --- a/vllm/model_executor/layers/quantization/gptq_marlin.py +++ b/vllm/model_executor/layers/quantization/gptq_marlin.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Optional, Union +from typing import Any, Callable, Dict, List, Optional, Union import torch from torch.nn import Parameter @@ -551,8 +551,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: # Repack scales marlin_w13_scales = marlin_moe_permute_scales( s=layer.w13_scales, - size_k=(layer.intermediate_size if self.quant_config.desc_act else - layer.intermediate_size_per_partition), + size_k=layer.intermediate_size_per_partition, size_n=layer.w13_scales.shape[2], group_size=self.quant_config.group_size, ) @@ -575,6 +574,7 @@ def apply( use_grouped_topk: bool = False, num_expert_group: Optional[int] = None, topk_group: Optional[int] = None, + custom_routing_function: Optional[Callable] = None, ) -> torch.Tensor: topk_weights, topk_ids = FusedMoE.select_experts( diff --git a/vllm/model_executor/model_loader/utils.py b/vllm/model_executor/model_loader/utils.py index 4bb943ab3afe4..d247e4cf3f07b 100644 --- a/vllm/model_executor/model_loader/utils.py +++ b/vllm/model_executor/model_loader/utils.py @@ -23,7 +23,7 @@ def get_model_architecture( architectures = getattr(model_config.hf_config, "architectures", []) # Special handling for quantized Mixtral. # FIXME(woosuk): This is a temporary hack. - mixtral_supported = ["fp8", "compressed-tensors"] + mixtral_supported = ["fp8", "compressed-tensors", "gptq_marlin"] if (model_config.quantization is not None and model_config.quantization not in mixtral_supported and "MixtralForCausalLM" in architectures): diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index e744e36ac08bf..6413b56605ecf 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -435,7 +435,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): continue name = name.replace(weight_name, param_name) # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: + if name.endswith("bias") and name not in params_dict: continue # Skip layers on other devices. if is_pp_missing_parameter(name, self): @@ -454,6 +454,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # Skip layers on other devices. if is_pp_missing_parameter(name, self): continue + if name.endswith("bias") and name not in params_dict: + continue param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, @@ -464,7 +466,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): break else: # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: + if name.endswith("bias") and name not in params_dict: continue # Skip layers on other devices. if is_pp_missing_parameter(name, self): From d4798373c1b861aee79d665fbe8a56d945da9a42 Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Wed, 4 Sep 2024 12:40:52 -0400 Subject: [PATCH 61/77] enforce float16A/scales for marlin moe --- vllm/model_executor/layers/quantization/gptq_marlin.py | 4 ++-- vllm/model_executor/models/mixtral.py | 5 +++-- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py index b0f972182d598..b53267c0bd06e 100644 --- a/vllm/model_executor/layers/quantization/gptq_marlin.py +++ b/vllm/model_executor/layers/quantization/gptq_marlin.py @@ -404,7 +404,7 @@ def create_weights( torch.empty(num_experts, scales_size13, 2 * intermediate_size, - dtype=params_dtype), + dtype=torch.half), requires_grad=False, ) layer.register_parameter("w13_scales", w13_scales) @@ -414,7 +414,7 @@ def create_weights( torch.empty(num_experts, scales_size2, hidden_size, - dtype=params_dtype), + dtype=torch.half), requires_grad=False, ) layer.register_parameter("w2_scales", w2_scales) diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index 6413b56605ecf..148ef393277e4 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -95,11 +95,12 @@ def __init__(self, def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: # NOTE: hidden_states can have either 1D or 2D shape. orig_shape = hidden_states.shape + orig_dtype = hidden_states.dtype hidden_states = hidden_states.view(-1, self.hidden_size) # router_logits: (num_tokens, n_experts) router_logits, _ = self.gate(hidden_states) - final_hidden_states = self.experts(hidden_states, router_logits) - return final_hidden_states.view(orig_shape) + final_hidden_states = self.experts(hidden_states.half(), router_logits) + return final_hidden_states.view(orig_shape).to(orig_dtype) class MixtralAttention(nn.Module): From 8baaec644b2468e263f14022b01c8b55d3893ad6 Mon Sep 17 00:00:00 2001 From: Dipika Date: Wed, 4 Sep 2024 15:28:23 +0000 Subject: [PATCH 62/77] remove large model --- tests/weight_loading/models.txt | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/weight_loading/models.txt b/tests/weight_loading/models.txt index 5eee2cc534445..1dc529037a98e 100644 --- a/tests/weight_loading/models.txt +++ b/tests/weight_loading/models.txt @@ -21,7 +21,6 @@ compressed-tensors, nm-testing/Phi-3-mini-128k-instruct-FP8, main compressed-tensors, neuralmagic/Phi-3-medium-128k-instruct-quantized.w4a16, main compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W4A16-quantized, main compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W4A16-channel-quantized, main -compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W8A16-quantized, main awq, casperhansen/mixtral-instruct-awq, main awq_marlin, casperhansen/mixtral-instruct-awq, main fp8, neuralmagic/Meta-Llama-3-8B-Instruct-FP8-KV, main From 8fbc181dfa2747a8d5dbf03ef207c9b163a68c75 Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Wed, 4 Sep 2024 13:10:02 -0400 Subject: [PATCH 63/77] Cleanup, comments --- csrc/moe/marlin_moe_ops.cu | 4 +- tests/kernels/test_moe.py | 1 - .../layers/fused_moe/__init__.py | 8 +-- .../layers/fused_moe/fused_marlin_moe.py | 50 ++++++++----------- .../compressed_tensors_moe.py | 1 - 5 files changed, 28 insertions(+), 36 deletions(-) diff --git a/csrc/moe/marlin_moe_ops.cu b/csrc/moe/marlin_moe_ops.cu index e3c18ce5a50b8..f6d475a56851f 100644 --- a/csrc/moe/marlin_moe_ops.cu +++ b/csrc/moe/marlin_moe_ops.cu @@ -1228,8 +1228,6 @@ __device__ inline void MarlinMoESingle( if (slice_iters == 0) { cp_async_wait<0>(); bool last = slice_idx == slice_count - 1; - // For per-column scales, we only fetch them here in the final step before - // write-out if constexpr (!has_act_order && group_blocks == -1) { if constexpr (w_type.size_bits() == 8) { if (s_sh_wr_pred) { @@ -1237,6 +1235,8 @@ __device__ inline void MarlinMoESingle( } cp_async_fence(); } else { + // For 4-bit per-column scales, we only fetch them here in the + // final step before write-out if (last) { if (s_sh_wr_pred) { cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]); diff --git a/tests/kernels/test_moe.py b/tests/kernels/test_moe.py index 6069978439824..7e359ff08088c 100644 --- a/tests/kernels/test_moe.py +++ b/tests/kernels/test_moe.py @@ -241,7 +241,6 @@ def test_fused_marlin_moe( sort_indices2, topk_weights, topk_ids, - renormalize=False, w1_scale=scales1, w2_scale=scales2, num_bits=num_bits, diff --git a/vllm/model_executor/layers/fused_moe/__init__.py b/vllm/model_executor/layers/fused_moe/__init__.py index e9b5703ca28be..dea4a32aec4f8 100644 --- a/vllm/model_executor/layers/fused_moe/__init__.py +++ b/vllm/model_executor/layers/fused_moe/__init__.py @@ -1,3 +1,5 @@ +from vllm.model_executor.layers.fused_moe.fused_marlin_moe import ( + fused_marlin_moe, single_marlin_moe) from vllm.model_executor.layers.fused_moe.layer import ( FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported) from vllm.triton_utils import HAS_TRITON @@ -6,18 +8,16 @@ "FusedMoE", "FusedMoEMethodBase", "FusedMoeWeightScaleSupported", + "fused_marlin_moe", + "single_marlin_moe", ] if HAS_TRITON: - from vllm.model_executor.layers.fused_moe.fused_marlin_moe import ( - fused_marlin_moe, single_marlin_moe) from vllm.model_executor.layers.fused_moe.fused_moe import ( fused_experts, fused_moe, fused_topk, get_config_file_name, grouped_topk) __all__ += [ - "fused_marlin_moe", - "single_marlin_moe", "fused_moe", "fused_topk", "fused_experts", diff --git a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py index 5866c83cd9c8c..c7906205760ff 100644 --- a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py @@ -1,6 +1,6 @@ """Fused MoE utilities for GPTQ.""" import functools -from typing import Any, Callable, Dict, Optional +from typing import Any, Dict, Optional import torch @@ -16,11 +16,10 @@ def single_marlin_moe( scales: torch.Tensor, gating_output: torch.Tensor, g_idx: torch.Tensor, - rand_perm: torch.Tensor, + perm: torch.Tensor, topk: int, renormalize: bool, override_config: Optional[Dict[str, Any]] = None, - use_fp8: bool = False, num_bits: int = 8, ) -> torch.Tensor: """ @@ -28,18 +27,18 @@ def single_marlin_moe( and top-k gating mechanism. It is meant for testing and debugging. Parameters: - - hidden_states (torch.Tensor): The input tensor to the MoE layer. - - w (torch.Tensor): The first set of expert weights. + - hidden_states (torch.Tensor): The input tensor to the Marlin Mul. + - w (torch.Tensor): The set of expert weights. + - scales (torch.Tensor): The quantization scales. - gating_output (torch.Tensor): The output of the gating operation (before softmax). + - g_idx (torch.Tensor): The act_order indices. + - perm (torch.Tensor): The act_order input permutation. - topk (int): The number of top-k experts to select. - renormalize (bool): If True, renormalize the top-k weights to sum to 1. - - inplace (bool): If True, perform the operation in-place. - Defaults to False. - override_config (Optional[Dict[str, Any]]): Optional override for the kernel configuration. - - use_fp8 (bool): If True, use fp8 arithmetic to compute the inner - product for w. Defaults to False. + - num_bits (bool): The number of bits in expert weights quantization. Returns: - torch.Tensor: The output tensor after applying the MoE layer. @@ -55,8 +54,6 @@ def single_marlin_moe( torch.float32, torch.float16, torch.bfloat16 ] assert num_bits in [4, 8] - # TODO support this - assert not use_fp8 M, K = hidden_states.shape E = w.shape[0] @@ -70,7 +67,7 @@ def single_marlin_moe( w.shape, w.shape, topk_ids.shape[1], - "float8" if use_fp8 else None, + None, override_config=override_config, is_marlin=True) config = get_config_func(M) @@ -90,7 +87,7 @@ def single_marlin_moe( intermediate_cache = torch.ops._moe_C.marlin_gemm_moe( hidden_states, w, sorted_token_ids, topk_weights, topk_ids, scales, - g_idx, rand_perm, workspace, scalar_type, M, N, K, True, E, topk, + g_idx, perm, workspace, scalar_type, M, N, K, True, E, topk, block_size_m, True, False) return torch.sum(intermediate_cache.view(*intermediate_cache.shape), dim=1) @@ -103,14 +100,11 @@ def fused_marlin_moe( gating_output: torch.Tensor, g_idx1: torch.Tensor, g_idx2: torch.Tensor, - rand_perm1: torch.Tensor, - rand_perm2: torch.Tensor, + perm1: torch.Tensor, + perm2: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, - custom_routing_function: Optional[Callable] = None, - renormalize: bool = True, override_config: Optional[Dict[str, Any]] = None, - use_fp8: bool = False, w1_scale: Optional[torch.Tensor] = None, w2_scale: Optional[torch.Tensor] = None, num_bits: int = 8, @@ -125,18 +119,20 @@ def fused_marlin_moe( - w2 (torch.Tensor): The second set of expert weights. - gating_output (torch.Tensor): The output of the gating operation (before softmax). - - topk (int): The number of top-k experts to select. + - g_idx1 (torch.Tensor): The fist set of act_order indices. + - g_idx2 (torch.Tensor): The second set of act_order indices. + - perm1 (torch.Tensor): The first act_order input permutation. + - perm2 (torch.Tensor): The second act_order input permutation. + - topk_weights (torch.Tensor): Top-k weights. + - topk_ids (torch.Tensor): Indices of topk-k elements. - renormalize (bool): If True, renormalize the top-k weights to sum to 1. - - inplace (bool): If True, perform the operation in-place. - Defaults to False. - override_config (Optional[Dict[str, Any]]): Optional override for the kernel configuration. - - use_fp8 (bool): If True, use fp8 arithmetic to compute the inner - products for w1 and w2. Defaults to False. - w1_scale (Optional[torch.Tensor]): Optional scale to be used for w1. - w2_scale (Optional[torch.Tensor]): Optional scale to be used for w2. + - num_bits (bool): The number of bits in expert weights quantization. Returns: - torch.Tensor: The output tensor after applying the MoE layer. @@ -156,8 +152,6 @@ def fused_marlin_moe( torch.float32, torch.float16, torch.bfloat16 ] assert num_bits in [4, 8] - # TODO support this - assert not use_fp8 M, K = hidden_states.shape E = w1.shape[0] @@ -169,7 +163,7 @@ def fused_marlin_moe( w1.shape, w2.shape, topk_ids.shape[1], - "float8" if use_fp8 else None, + None, override_config=override_config, is_marlin=True, ) @@ -202,7 +196,7 @@ def fused_marlin_moe( topk_ids, w1_scale, g_idx1, - rand_perm1, + perm1, workspace, scalar_type, M, @@ -226,7 +220,7 @@ def fused_marlin_moe( topk_ids, w2_scale, g_idx2, - rand_perm2, + perm2, workspace, scalar_type, M, diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index b14ef433d539c..7dee2fca81153 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -293,7 +293,6 @@ def apply( layer.w2_g_idx_sort_indices, topk_weights, topk_ids, - renormalize=renormalize, w1_scale=layer.w13_weight_scale, w2_scale=layer.w2_weight_scale, num_bits=self.num_bits, From 839915f285fcc09dff376b11735e1e828da3e924 Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Wed, 4 Sep 2024 13:13:32 -0400 Subject: [PATCH 64/77] cleanup --- vllm/model_executor/layers/quantization/gptq_marlin.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py index b53267c0bd06e..d593298cf2f12 100644 --- a/vllm/model_executor/layers/quantization/gptq_marlin.py +++ b/vllm/model_executor/layers/quantization/gptq_marlin.py @@ -598,7 +598,6 @@ def apply( layer.w2_g_idx_sort_indices, topk_weights, topk_ids, - renormalize=renormalize, w1_scale=layer.w13_scales, w2_scale=layer.w2_scales, num_bits=self.quant_config.quant_type.size_bits, From a5bc626e59fd755baf96a65cd6b68b136fd7e2f0 Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Fri, 6 Sep 2024 03:12:10 -0400 Subject: [PATCH 65/77] remove 8-bit stuff for now --- csrc/moe/marlin_moe_ops.cu | 303 ++++++------------ csrc/moe/marlin_moe_ops.h | 7 +- csrc/moe/torch_bindings.cpp | 8 +- tests/kernels/test_moe.py | 14 +- vllm/_custom_ops.py | 2 +- .../layers/fused_moe/__init__.py | 8 +- .../layers/fused_moe/fused_marlin_moe.py | 52 +-- .../compressed_tensors_moe.py | 1 - .../schemes/compressed_tensors_wNa16.py | 1 - .../layers/quantization/gptq_marlin.py | 1 - 10 files changed, 120 insertions(+), 277 deletions(-) diff --git a/csrc/moe/marlin_moe_ops.cu b/csrc/moe/marlin_moe_ops.cu index f6d475a56851f..92184f43c9eb0 100644 --- a/csrc/moe/marlin_moe_ops.cu +++ b/csrc/moe/marlin_moe_ops.cu @@ -25,8 +25,6 @@ #include -#include "core/scalar_type.hpp" - template inline std::string str(T x) { return std::to_string(x); @@ -133,26 +131,11 @@ __device__ inline int lop3(int a, int b, int c) { return res; } -// Constructs destination register by taking bytes from 2 sources (based on -// mask) -template -__device__ inline uint32_t prmt(uint32_t a) { - uint32_t res; - asm volatile("prmt.b32 %0, %1, %2, %3;\n" - : "=r"(res) - : "r"(a), "n"(start_byte), "n"(mask)); - return res; -} - -template -__device__ inline FragB dequant(int q); - -// Efficiently dequantize 4bit values packed in an int32 value into a full -// B-fragment of 4 fp16 values. We mostly follow the strategy in the link below, -// with some small changes: -// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L215-L287 -template <> -__device__ inline FragB dequant(int q) { +// Efficiently dequantize an int32 value into a full B-fragment of 4 fp16 +// values. We mostly follow the strategy in the link below, with some small +// changes: +// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h +__device__ inline FragB dequant(int q) { const int LO = 0x000f000f; const int HI = 0x00f000f0; const int EX = 0x64006400; @@ -173,28 +156,6 @@ __device__ inline FragB dequant(int q) { return frag_b; } -// Fast Int8ToFp16: Efficiently dequantize 8bit int values to fp16 -// Reference: -// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L53-L85 -template <> -__device__ inline FragB dequant(int q) { - static constexpr uint32_t mask_for_elt_01 = 0x5250; - static constexpr uint32_t mask_for_elt_23 = 0x5351; - static constexpr uint32_t start_byte_for_fp16 = 0x64646464; - - uint32_t lo = prmt(q); - uint32_t hi = prmt(q); - - static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480; - - FragB frag_b; - frag_b[0] = __hsub2(*reinterpret_cast(&lo), - *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); - frag_b[1] = __hsub2(*reinterpret_cast(&hi), - *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); - return frag_b; -} - // Multiply dequantized values by the corresponding quantization scale; used // only for grouped quantization. __device__ inline void scale(FragB& frag_b, FragS& frag_s, int i) { @@ -335,8 +296,7 @@ __global__ void compute_expert_offsets(int const* __restrict__ topk_ids, __syncthreads(); } -template ( - &sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd + i]); - } + frag_b_quant[k % 2] = *reinterpret_cast( + &sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd]); }; bool is_same_group[stages]; @@ -893,19 +840,10 @@ __device__ inline void MarlinMoESingle( // dequantization and matmul operations. #pragma unroll for (int j = 0; j < 4; j++) { - int b_quant_0, b_quant_1; - if constexpr (w_type.size_bits() == 4) { - b_quant_0 = frag_b_quant[k % 2][0][j]; - b_quant_1 = b_quant_0 >> 8; - } else { - static_assert(w_type.size_bits() == 8); - int* frag_b_quant_ptr = reinterpret_cast(frag_b_quant[k % 2]); - b_quant_0 = frag_b_quant_ptr[j * 2 + 0]; - b_quant_1 = frag_b_quant_ptr[j * 2 + 1]; - } + int b_quant = frag_b_quant[k % 2][j]; + int b_quant_shift = b_quant >> 8; - FragB frag_b0 = dequant(b_quant_0); - FragB frag_b1 = dequant(b_quant_1); + FragB frag_b0 = dequant(b_quant); // Apply scale to frag_b0 if constexpr (has_act_order) { @@ -917,6 +855,8 @@ __device__ inline void MarlinMoESingle( } } + FragB frag_b1 = dequant(b_quant_shift); + // Apply scale to frag_b1 if constexpr (has_act_order) { scale4(frag_b1, act_frag_s[k % 2][0][j], act_frag_s[k % 2][1][j], @@ -941,13 +881,13 @@ __device__ inline void MarlinMoESingle( // multiple warps that accumulate their partial sums of the same output // location; which we have to reduce over in the end. We do in shared memory. auto thread_block_reduce = [&]() { - constexpr int red_off = threads / b_sh_stride_threads / 2; + constexpr int red_off = threads / b_sh_stride / 2; if (red_off >= 1) { - int red_idx = threadIdx.x / b_sh_stride_threads; - constexpr int red_sh_stride = b_sh_stride_threads * 4 * 2; - constexpr int red_sh_delta = b_sh_stride_threads; - int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride_threads) + - (threadIdx.x % b_sh_stride_threads); + int red_idx = threadIdx.x / b_sh_stride; + constexpr int red_sh_stride = b_sh_stride * 4 * 2; + constexpr int red_sh_delta = b_sh_stride; + int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride) + + (threadIdx.x % b_sh_stride); // Parallel logarithmic shared memory reduction. We make sure to avoid any // unnecessary read or write iterations, e.g., for two warps we write only @@ -1095,10 +1035,8 @@ __device__ inline void MarlinMoESingle( auto write = [&](int idx, float c0, float c1, FragS& s) { half2 res = __halves2half2(__float2half(c0), __float2half(c1)); - // For per-column quantization we finally apply the scale here (only for - // 4-bit) - if constexpr (!has_act_order && group_blocks == -1 && - w_type.size_bits() == 4) { + // For per-column quantization we finally apply the scale here + if constexpr (!has_act_order && group_blocks == -1) { res = __hmul2(res, s[0]); } @@ -1228,70 +1166,28 @@ __device__ inline void MarlinMoESingle( if (slice_iters == 0) { cp_async_wait<0>(); bool last = slice_idx == slice_count - 1; + // For per-column scales, we only fetch them here in the final step before + // write-out if constexpr (!has_act_order && group_blocks == -1) { - if constexpr (w_type.size_bits() == 8) { + if (last) { if (s_sh_wr_pred) { cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]); } cp_async_fence(); - } else { - // For 4-bit per-column scales, we only fetch them here in the - // final step before write-out - if (last) { - if (s_sh_wr_pred) { - cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]); - } - cp_async_fence(); - } } } thread_block_reduce(); if constexpr (!has_act_order && group_blocks == -1) { - if constexpr (w_type.size_bits() == 8) { + if (last) { cp_async_wait<0>(); __syncthreads(); if (threadIdx.x / 32 < thread_n_blocks / 4) { reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd + 0]; reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; } - - } else { - if (last) { - cp_async_wait<0>(); - __syncthreads(); - if (threadIdx.x / 32 < thread_n_blocks / 4) { - reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd + 0]; - reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; - } - } } } - - // For 8-bit channelwise, we apply the scale before the global reduction - // that converts the fp32 results to fp16 (so that we avoid possible - // overflow in fp16) - if constexpr (!has_act_order && group_blocks == -1 && - w_type.size_bits() == 8) { - if (threadIdx.x / 32 < thread_n_blocks / 4) { - #pragma unroll - for (int i = 0; i < thread_m_blocks; i++) { - #pragma unroll - for (int j = 0; j < 4; j++) { - scale_float(reinterpret_cast(&frag_c[i][j][0][0]), - frag_s[j / 2][2 * (j % 2) + 0]); - scale_float(reinterpret_cast(&frag_c[i][j][0][2]), - frag_s[j / 2][2 * (j % 2) + 0]); - - scale_float(reinterpret_cast(&frag_c[i][j][1][0]), - frag_s[j / 2][2 * (j % 2) + 1]); - scale_float(reinterpret_cast(&frag_c[i][j][1][2]), - frag_s[j / 2][2 * (j % 2) + 1]); - } - } - } - } - if (slice_count > 1) { // only globally reduce if there is more than one // block in a slice barrier_acquire(&locks[slice_col], slice_idx); @@ -1331,8 +1227,7 @@ __device__ inline void MarlinMoESingle( } } -template ( + MarlinMoESingle( A, B, C, sorted_ids_expert, topk_weights, scales_ptr, g_idx, expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m, prob_n, prob_k, tot_m, locks, replicate_input, apply_weights, current_m_block); } else if (max_block == 2) { - MarlinMoESingle( + MarlinMoESingle( A, B, C, sorted_ids_expert, topk_weights, scales_ptr, g_idx, expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m, prob_n, prob_k, tot_m, locks, replicate_input, apply_weights, current_m_block); } else if (max_block == 3) { - MarlinMoESingle( + MarlinMoESingle( A, B, C, sorted_ids_expert, topk_weights, scales_ptr, g_idx, expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m, prob_n, prob_k, tot_m, locks, replicate_input, apply_weights, current_m_block); } else { - MarlinMoESingle( + MarlinMoESingle( A, B, C, sorted_ids_expert, topk_weights, scales_ptr, g_idx, expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m, prob_n, prob_k, tot_m, locks, replicate_input, apply_weights, @@ -1447,8 +1342,7 @@ __global__ void compute_expert_offsets(int const* __restrict__ topk_ids, return; } -template , \ cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \ - MarlinMoE \ + MarlinMoE \ <<>>( \ A_ptr, B_ptr, C_ptr, sorted_ids_ptr, topk_weights_ptr, s_ptr, \ g_idx_ptr, expert_offsets_ptr, num_groups, expert_idx, \ @@ -1601,43 +1494,42 @@ thread_config_t determine_thread_config(int prob_m, int prob_n, int prob_k) { return thread_config_t{-1, -1, -1}; } -#define CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, 2, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, 3, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, 4, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ - \ - __CALL_IF_MOE(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \ - \ - __CALL_IF_MOE(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \ - \ - __CALL_IF_MOE(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \ - \ - __CALL_IF_MOE(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) +#define CALL_IF_MOE(N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + __CALL_IF_MOE(1, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ + __CALL_IF_MOE(2, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ + __CALL_IF_MOE(3, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ + __CALL_IF_MOE(4, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ + \ + __CALL_IF_MOE(1, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ + __CALL_IF_MOE(1, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ + __CALL_IF_MOE(1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ + __CALL_IF_MOE(1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \ + \ + __CALL_IF_MOE(2, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ + __CALL_IF_MOE(2, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ + __CALL_IF_MOE(2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ + __CALL_IF_MOE(2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \ + \ + __CALL_IF_MOE(3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ + __CALL_IF_MOE(3, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ + __CALL_IF_MOE(3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ + __CALL_IF_MOE(3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \ + \ + __CALL_IF_MOE(4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ + __CALL_IF_MOE(4, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ + __CALL_IF_MOE(4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ + __CALL_IF_MOE(4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) void marlin_mm_moe_f16i4(const void* A, const void* B, void* C, const void* sorted_ids, const void* topk_weights, const void* topk_ids, const void* s, const void* g_idx, const void* perm, void* a_tmp, void* expert_offsets, int prob_m, int prob_n, int prob_k, void* workspace, - vllm::ScalarType const& q_type, bool has_act_order, - bool is_k_full, int num_groups, int group_size, - int num_experts, int topk, int moe_block_size, int dev, - cudaStream_t stream, int thread_k, int thread_n, - int sms, int max_par, bool replicate_input, - bool apply_weights) { + bool has_act_order, bool is_k_full, int num_groups, + int group_size, int num_experts, int topk, + int moe_block_size, int dev, cudaStream_t stream, + int thread_k, int thread_n, int sms, int max_par, + bool replicate_input, bool apply_weights) { TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m, ", ", prob_n, ", ", prob_k, "]"); @@ -1719,13 +1611,10 @@ void marlin_mm_moe_f16i4(const void* A, const void* B, void* C, has_act_order = false; } - int pack_factor = 32 / q_type.size_bits(); - for (int expert_idx = 0; expert_idx < num_experts; ++expert_idx) { const int4* A_ptr = (const int4*)A; int4* a_tmp_ptr = (int4*)a_tmp; - const int4* B_ptr = - (const int4*)B + (prob_n * prob_k / (pack_factor * 4)) * expert_idx; + const int4* B_ptr = (const int4*)B + (prob_n * prob_k / 32) * expert_idx; int4* C_ptr = (int4*)C; const float* topk_weights_ptr = (const float*)topk_weights; const int* sorted_ids_ptr = (const int*)sorted_ids; @@ -1756,14 +1645,10 @@ void marlin_mm_moe_f16i4(const void* A, const void* B, void* C, if (false) { } - CALL_IF_MOE(vllm::kU4B8, 16, 4, 256) - CALL_IF_MOE(vllm::kU4B8, 8, 8, 256) - CALL_IF_MOE(vllm::kU4B8, 8, 4, 128) - CALL_IF_MOE(vllm::kU4B8, 4, 8, 128) - CALL_IF_MOE(vllm::kU8B128, 16, 4, 256) - CALL_IF_MOE(vllm::kU8B128, 8, 8, 256) - CALL_IF_MOE(vllm::kU8B128, 8, 4, 128) - CALL_IF_MOE(vllm::kU8B128, 4, 8, 128) + CALL_IF_MOE(16, 4, 256) + CALL_IF_MOE(8, 8, 256) + CALL_IF_MOE(8, 4, 128) + CALL_IF_MOE(4, 8, 128) else { TORCH_CHECK(false, "Unsupported shapes: MNK = [" + str(prob_m) + ", " + str(prob_n) + ", " + str(prob_k) + "]" + @@ -1785,15 +1670,9 @@ torch::Tensor marlin_gemm_moe( const torch::Tensor& sorted_ids, const torch::Tensor& topk_weights, const torch::Tensor& topk_ids, const torch::Tensor& b_scales, const torch::Tensor& g_idx, const torch::Tensor& perm, - torch::Tensor& workspace, vllm::ScalarTypeTorchPtr const& b_q_type, - int64_t size_m, int64_t size_n, int64_t size_k, bool is_k_full, - int64_t num_experts, int64_t topk, int64_t moe_block_size, + torch::Tensor& workspace, int64_t size_m, int64_t size_n, int64_t size_k, + bool is_k_full, int64_t num_experts, int64_t topk, int64_t moe_block_size, bool replicate_input, bool apply_weights) { - TORCH_CHECK(*b_q_type == vllm::kU4B8 || *b_q_type == vllm::kU8B128, - "b_q_type must be uint4b8 or uint8b128. Got = ", b_q_type->str()); - - int pack_factor = 32 / b_q_type->size_bits(); - int max_par = 4; int dev = a.get_device(); @@ -1854,8 +1733,8 @@ torch::Tensor marlin_gemm_moe( topk_weights.data_ptr(), topk_ids.data_ptr(), b_scales.data_ptr(), g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr(), expert_offsets.data_ptr(), size_m, size_n, size_k, workspace.data_ptr(), - *b_q_type, has_act_order, is_k_full, num_groups, group_size, num_experts, - topk, moe_block_size, dev, at::cuda::getCurrentCUDAStream(dev), thread_k, + has_act_order, is_k_full, num_groups, group_size, num_experts, topk, + moe_block_size, dev, at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms, max_par, replicate_input, apply_weights); return c; } diff --git a/csrc/moe/marlin_moe_ops.h b/csrc/moe/marlin_moe_ops.h index adee8399a4d6f..43d264e0770d6 100644 --- a/csrc/moe/marlin_moe_ops.h +++ b/csrc/moe/marlin_moe_ops.h @@ -2,14 +2,11 @@ #include -#include "core/scalar_type.hpp" - torch::Tensor marlin_gemm_moe( const torch::Tensor& a, const torch::Tensor& b_q_weights, const torch::Tensor& sorted_ids, const torch::Tensor& topk_weights, const torch::Tensor& topk_ids, const torch::Tensor& b_scales, const torch::Tensor& g_idx, const torch::Tensor& perm, - torch::Tensor& workspace, vllm::ScalarTypeTorchPtr const& b_q_type, - int64_t size_m, int64_t size_n, int64_t size_k, bool is_k_full, - int64_t num_experts, int64_t topk, int64_t moe_block_size, + torch::Tensor& workspace, int64_t size_m, int64_t size_n, int64_t size_k, + bool is_k_full, int64_t num_experts, int64_t topk, int64_t moe_block_size, bool replicate_input, bool apply_weights); diff --git a/csrc/moe/torch_bindings.cpp b/csrc/moe/torch_bindings.cpp index cd65a8ee92b94..8a0e625b43fa1 100644 --- a/csrc/moe/torch_bindings.cpp +++ b/csrc/moe/torch_bindings.cpp @@ -13,11 +13,9 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { m.def( "marlin_gemm_moe(Tensor! a, Tensor! b_q_weights, Tensor! sorted_ids, " "Tensor! topk_weights, Tensor! topk_ids, Tensor! b_scales, Tensor! " - "g_idx, Tensor! perm, Tensor! workspace, " - "__torch__.torch.classes._core_C.ScalarType b_q_type, int size_m, " - "int size_n, int size_k, bool is_k_full, int num_experts, int topk, " - "int moe_block_size, bool replicate_input, bool apply_weights)" - " -> Tensor"); + "g_idx, Tensor! perm, Tensor! workspace, int size_m, int size_n, int " + "size_k, bool is_k_full, int num_experts, int topk, int moe_block_size, " + "bool replicate_input, bool apply_weights) -> Tensor"); m.impl("marlin_gemm_moe", torch::kCUDA, &marlin_gemm_moe); #endif } diff --git a/tests/kernels/test_moe.py b/tests/kernels/test_moe.py index 7e359ff08088c..2250cf1598b8b 100644 --- a/tests/kernels/test_moe.py +++ b/tests/kernels/test_moe.py @@ -140,7 +140,6 @@ def compute_max_diff(output, output_ref): @pytest.mark.parametrize("topk", [2, 6]) @pytest.mark.parametrize("group_size", [-1, 32, 64, 128]) @pytest.mark.parametrize("act_order", [True, False]) -@pytest.mark.parametrize("num_bits", [4, 8]) def test_fused_marlin_moe( m: int, n: int, @@ -149,7 +148,6 @@ def test_fused_marlin_moe( topk: int, group_size: int, act_order: bool, - num_bits: int, ): torch.manual_seed(7) @@ -163,8 +161,7 @@ def test_fused_marlin_moe( if group_size in (k, n): return - quant_type = (scalar_types.uint4b8 - if num_bits == 4 else scalar_types.uint8b128) + quant_type = scalar_types.uint4b8 dtype = torch.float16 a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 @@ -243,7 +240,6 @@ def test_fused_marlin_moe( topk_ids, w1_scale=scales1, w2_scale=scales2, - num_bits=num_bits, ) assert compute_max_diff(marlin_output, triton_output) < 4e-2 @@ -258,7 +254,6 @@ def test_fused_marlin_moe( @pytest.mark.parametrize("topk", [2, 6]) @pytest.mark.parametrize("group_size", [-1, 32, 64, 128]) @pytest.mark.parametrize("act_order", [True, False]) -@pytest.mark.parametrize("num_bits", [4, 8]) def test_marlin_moe_mmm( m: int, n: int, @@ -267,7 +262,6 @@ def test_marlin_moe_mmm( topk: int, group_size: int, act_order: bool, - num_bits: int, ): if topk > e: return @@ -279,8 +273,7 @@ def test_marlin_moe_mmm( if group_size == k: return - quant_type = (scalar_types.uint4b8 - if num_bits == 4 else scalar_types.uint8b128) + quant_type = scalar_types.uint4b8 dtype = torch.float16 a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 w = torch.randn((e, n, k), device="cuda", dtype=dtype) / 10 @@ -315,8 +308,7 @@ def test_marlin_moe_mmm( g_idx, sort_indices, topk, - renormalize=False, - num_bits=num_bits) + renormalize=False) torch_output = torch_moe_single(a, w_ref.transpose(1, 2), score, topk) assert compute_max_diff(marlin_output, torch_output) < 1e-2 diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 77c46584ef530..151cdbee8eb04 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -308,7 +308,7 @@ def gptq_marlin_moe_repack(b_q_weight: torch.Tensor, perm: torch.Tensor, num_bits: int) -> torch.Tensor: num_experts = b_q_weight.shape[0] assert size_k % 16 == 0 - output = torch.empty((num_experts, size_k // 16, size_n * (num_bits // 2)), + output = torch.empty((num_experts, size_k // 16, size_n * 2), device=b_q_weight.device, dtype=b_q_weight.dtype) for e in range(num_experts): diff --git a/vllm/model_executor/layers/fused_moe/__init__.py b/vllm/model_executor/layers/fused_moe/__init__.py index dea4a32aec4f8..e9b5703ca28be 100644 --- a/vllm/model_executor/layers/fused_moe/__init__.py +++ b/vllm/model_executor/layers/fused_moe/__init__.py @@ -1,5 +1,3 @@ -from vllm.model_executor.layers.fused_moe.fused_marlin_moe import ( - fused_marlin_moe, single_marlin_moe) from vllm.model_executor.layers.fused_moe.layer import ( FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported) from vllm.triton_utils import HAS_TRITON @@ -8,16 +6,18 @@ "FusedMoE", "FusedMoEMethodBase", "FusedMoeWeightScaleSupported", - "fused_marlin_moe", - "single_marlin_moe", ] if HAS_TRITON: + from vllm.model_executor.layers.fused_moe.fused_marlin_moe import ( + fused_marlin_moe, single_marlin_moe) from vllm.model_executor.layers.fused_moe.fused_moe import ( fused_experts, fused_moe, fused_topk, get_config_file_name, grouped_topk) __all__ += [ + "fused_marlin_moe", + "single_marlin_moe", "fused_moe", "fused_topk", "fused_experts", diff --git a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py index c7906205760ff..6b01ec0a623aa 100644 --- a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py @@ -7,21 +7,18 @@ from vllm import _custom_ops as ops from vllm.model_executor.layers.fused_moe.fused_moe import ( fused_topk, moe_align_block_size, try_get_optimal_moe_config) -from vllm.scalar_type import scalar_types def single_marlin_moe( - hidden_states: torch.Tensor, - w: torch.Tensor, - scales: torch.Tensor, - gating_output: torch.Tensor, - g_idx: torch.Tensor, - perm: torch.Tensor, - topk: int, - renormalize: bool, - override_config: Optional[Dict[str, Any]] = None, - num_bits: int = 8, -) -> torch.Tensor: + hidden_states: torch.Tensor, + w: torch.Tensor, + scales: torch.Tensor, + gating_output: torch.Tensor, + g_idx: torch.Tensor, + perm: torch.Tensor, + topk: int, + renormalize: bool, + override_config: Optional[Dict[str, Any]] = None) -> torch.Tensor: """ This function computes a Marlin MoE MMM using weights w and top-k gating mechanism. It is meant for testing and debugging. @@ -38,7 +35,6 @@ def single_marlin_moe( - renormalize (bool): If True, renormalize the top-k weights to sum to 1. - override_config (Optional[Dict[str, Any]]): Optional override for the kernel configuration. - - num_bits (bool): The number of bits in expert weights quantization. Returns: - torch.Tensor: The output tensor after applying the MoE layer. @@ -50,14 +46,11 @@ def single_marlin_moe( assert gating_output.shape[1] == w.shape[0], "Number of experts mismatch" assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" assert w.is_contiguous(), "Expert weights must be contiguous" - assert hidden_states.dtype in [ - torch.float32, torch.float16, torch.bfloat16 - ] - assert num_bits in [4, 8] + assert hidden_states.dtype == torch.float16 M, K = hidden_states.shape E = w.shape[0] - N = w.shape[2] // (num_bits // 2) + N = w.shape[2] // 2 topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk, renormalize) @@ -82,13 +75,10 @@ def single_marlin_moe( device="cuda", requires_grad=False) - scalar_type = (scalar_types.uint4b8 - if num_bits == 4 else scalar_types.uint8b128) - intermediate_cache = torch.ops._moe_C.marlin_gemm_moe( hidden_states, w, sorted_token_ids, topk_weights, topk_ids, scales, - g_idx, perm, workspace, scalar_type, M, N, K, True, E, topk, - block_size_m, True, False) + g_idx, perm, workspace, M, N, K, True, E, topk, block_size_m, True, + False) return torch.sum(intermediate_cache.view(*intermediate_cache.shape), dim=1) @@ -107,7 +97,6 @@ def fused_marlin_moe( override_config: Optional[Dict[str, Any]] = None, w1_scale: Optional[torch.Tensor] = None, w2_scale: Optional[torch.Tensor] = None, - num_bits: int = 8, ) -> torch.Tensor: """ This function computes a Mixture of Experts (MoE) layer using two sets of @@ -132,7 +121,6 @@ def fused_marlin_moe( w1. - w2_scale (Optional[torch.Tensor]): Optional scale to be used for w2. - - num_bits (bool): The number of bits in expert weights quantization. Returns: - torch.Tensor: The output tensor after applying the MoE layer. @@ -142,16 +130,13 @@ def fused_marlin_moe( 0], "Number of tokens mismatch" assert hidden_states.shape[ 1] == w1.shape[1] * 16, "Hidden size mismatch w1" - assert hidden_states.shape[1] == w2.shape[2] // ( - num_bits // 2), "Hidden size mismatch w2" + assert hidden_states.shape[ + 1] == w2.shape[2] // 2, "Hidden size mismatch w2" assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch" assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" assert w1.is_contiguous(), "Expert weights1 must be contiguous" assert w2.is_contiguous(), "Expert weights2 must be contiguous" - assert hidden_states.dtype in [ - torch.float32, torch.float16, torch.bfloat16 - ] - assert num_bits in [4, 8] + assert hidden_states.dtype == torch.float16 M, K = hidden_states.shape E = w1.shape[0] @@ -179,9 +164,6 @@ def fused_marlin_moe( device="cuda", requires_grad=False) - scalar_type = (scalar_types.uint4b8 - if num_bits == 4 else scalar_types.uint8b128) - intermediate_cache2 = torch.empty( (M * topk_ids.shape[1], N), device=hidden_states.device, @@ -198,7 +180,6 @@ def fused_marlin_moe( g_idx1, perm1, workspace, - scalar_type, M, 2 * N, K, @@ -222,7 +203,6 @@ def fused_marlin_moe( g_idx2, perm2, workspace, - scalar_type, M, K, N, diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index 7dee2fca81153..f8a41dfd08d73 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -295,5 +295,4 @@ def apply( topk_ids, w1_scale=layer.w13_weight_scale, w2_scale=layer.w2_weight_scale, - num_bits=self.num_bits, ) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py index 7ca8eecb9283e..e3b74e8712903 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py @@ -18,7 +18,6 @@ __all__ = ["CompressedTensorsWNA16"] WNA16_SUPPORTED_TYPES_MAP = { 4: scalar_types.uint4b8, - 8: scalar_types.uint8b128, } WNA16_SUPPORTED_BITS = list(WNA16_SUPPORTED_TYPES_MAP.keys()) diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py index d593298cf2f12..15c0a570c4caf 100644 --- a/vllm/model_executor/layers/quantization/gptq_marlin.py +++ b/vllm/model_executor/layers/quantization/gptq_marlin.py @@ -600,5 +600,4 @@ def apply( topk_ids, w1_scale=layer.w13_scales, w2_scale=layer.w2_scales, - num_bits=self.quant_config.quant_type.size_bits, ) From c573fa1b084e789dd4821fa020efac84f5574a17 Mon Sep 17 00:00:00 2001 From: Dipika Date: Thu, 5 Sep 2024 21:07:45 +0000 Subject: [PATCH 66/77] update/fix weight loading to support tp --- vllm/model_executor/layers/fused_moe/layer.py | 80 ++++++++++--------- .../layers/quantization/gptq_marlin.py | 11 ++- 2 files changed, 53 insertions(+), 38 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index b0d7d4b538df3..f4621e5c4ccc4 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -306,10 +306,28 @@ def _load_single_value(self, param: torch.nn.Parameter, # Input scales can be loaded directly and should be equal. param_data[expert_id] = loaded_weight + def _load_g_idx(self, shard_id: str, expert_data: torch.Tensor, + shard_dim: int, loaded_weight: torch.tensor, tp_rank: int): + + if shard_id == "w2": + self._load_w2(shard_id=shard_id, + shard_dim=shard_dim, + loaded_weight=loaded_weight, + expert_data=expert_data, + tp_rank=tp_rank) + else: + assert shard_id in ("w1", "w3") + expert_data.copy_(loaded_weight) + def weight_loader(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor, weight_name: str, shard_id: str, expert_id: int) -> None: + # llm-compressor returns weights on disk which are flipped + loaded_weight = loaded_weight.t().contiguous() if ( + self.quant_method.__class__.__name__ + == "CompressedTensorsMoEMethod") else loaded_weight + if shard_id not in ("w1", "w2", "w3"): raise ValueError(f"shard_id must be ['w1','w2','w3'] but " f"got {shard_id}.") @@ -325,38 +343,41 @@ def weight_loader(self, param: torch.nn.Parameter, expert_data = param.data[expert_id] tp_rank = get_tensor_model_parallel_rank() - # is_transposed: whether or not the parameter is transposed on disk - # If transposed, the loaded weight will be transposed and the dim - # to shard the loaded weight will be flipped. + # is_transposed: if the dim to shard the weight + # should be flipped. Required by GPTQ, compressed-tensors + # should be whatever dimension intermediate_size is is_transposed = getattr(param, "is_transposed", False) shard_dim = SHARD_ID_TO_SHARDED_DIM[shard_id] if is_transposed: - loaded_weight = loaded_weight.t().contiguous() shard_dim = ~shard_dim - # GPTQ Values - if ("scales" in weight_name or "qweight" in weight_name - or "qzeros" in weight_name): - if (shard_id == "w1" or shard_id == "w3"): - shard_dim = 1 - shard_dim - self._load_model_weight_or_group_weight_scale( - shard_id=shard_id, - shard_dim=shard_dim, - loaded_weight=loaded_weight, - expert_data=expert_data, - tp_rank=tp_rank) - return + # Case input scale: input_scale loading is only supported for fp8 + if "input_scale" in weight_name: + if param.data[expert_id] != 1 and (param.data[expert_id] - + loaded_weight).abs() > 1e-5: + raise ValueError( + "input_scales of w1 and w3 of a layer " + f"must be equal. But got {param.data[expert_id]} " + f"vs. {loaded_weight}") - if "g_idx" in weight_name: self._load_single_value(param=param, loaded_weight=loaded_weight, expert_id=expert_id) return - # Case weight_scales - if "weight_scale" in weight_name: - # load the weight scaling based on the quantization scheme - # supported weight scales can be found in + # Case g_idx + if "g_idx" in weight_name: + self._load_g_idx(shard_dim=0, + shard_id=shard_id, + loaded_weight=loaded_weight, + expert_data=expert_data, + tp_rank=tp_rank) + return + + # Case weight scales and zero_points + if ("scale" in weight_name or "zero" in weight_name): + # load the weight scales and zp based on the quantization scheme + # supported weight scales/zp can be found in # FusedMoeWeightScaleSupported # TODO @dsikka: once hardened, refactor to use vLLM Parameters # specific to each case @@ -385,22 +406,9 @@ def weight_loader(self, param: torch.nn.Parameter, f"quant method must be one of {WEIGHT_SCALE_SUPPORTED}") return + # Case weight_shape if "weight_shape" in weight_name: - self._load_single_value(param=param, - loaded_weight=loaded_weight, - expert_id=expert_id) - return - - # Case input scale - if "input_scale" in weight_name: - # Note: input_scale loading is only supported for fp8 - if param.data[expert_id] != 1 and (param.data[expert_id] - - loaded_weight).abs() > 1e-5: - raise ValueError( - "input_scales of w1 and w3 of a layer " - f"must be equal. But got {param.data[expert_id]} " - f"vs. {loaded_weight}") - + # only required by compressed-tensors self._load_single_value(param=param, loaded_weight=loaded_weight, expert_id=expert_id) diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py index 15c0a570c4caf..0a470f311c746 100644 --- a/vllm/model_executor/layers/quantization/gptq_marlin.py +++ b/vllm/model_executor/layers/quantization/gptq_marlin.py @@ -7,8 +7,8 @@ from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe.fused_marlin_moe import ( fused_marlin_moe) -from vllm.model_executor.layers.fused_moe.layer import (FusedMoE, - FusedMoEMethodBase) +from vllm.model_executor.layers.fused_moe.layer import ( + FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported) from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, set_weight_attrs) from vllm.model_executor.layers.quantization.base_config import ( @@ -372,9 +372,16 @@ def create_weights( if self.quant_config.group_size != -1: scales_size13 = hidden_size // self.quant_config.group_size scales_size2 = intermediate_size // self.quant_config.group_size + strategy = FusedMoeWeightScaleSupported.GROUP.value else: scales_size13 = 1 scales_size2 = 1 + strategy = FusedMoeWeightScaleSupported.CHANNEL.value + + extra_weight_attrs.update({ + "quant_method": strategy, + "is_transposed": True + }) # Fused gate_up_proj (column parallel) w13_qweight = torch.nn.Parameter( torch.empty( From a991d828a6c688eeb3d87db4cf7651510c447e65 Mon Sep 17 00:00:00 2001 From: Dipika Date: Fri, 6 Sep 2024 15:13:54 +0000 Subject: [PATCH 67/77] fix; update large model testing cases --- .buildkite/test-pipeline.yaml | 13 ++++++++++++- tests/weight_loading/models-large.txt | 3 +++ tests/weight_loading/models.txt | 2 -- .../compressed_tensors/compressed_tensors_moe.py | 7 ++----- .../schemes/compressed_tensors_wNa16.py | 1 + 5 files changed, 18 insertions(+), 8 deletions(-) create mode 100644 tests/weight_loading/models-large.txt diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index d0317b2fc48c9..a0c7b7442b3b3 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -386,7 +386,18 @@ steps: - vllm/ - tests/weight_loading commands: - - bash weight_loading/run_model_weight_loading_test.sh + - bash weight_loading/run_model_weight_loading_test.sh -c weight_loading/models.txt + +- label: Weight Loading Multiple GPU Test - Large Models # optional + working_dir: "/vllm-workspace/tests" + num_gpus: 2 + gpu: a100 + optional: true + source_file_dependencies: + - vllm/ + - tests/weight_loading + commands: + - bash weight_loading/run_model_weight_loading_test.sh -c weight_loading/models-large.txt ##### multi gpus test ##### diff --git a/tests/weight_loading/models-large.txt b/tests/weight_loading/models-large.txt new file mode 100644 index 0000000000000..fe76705746766 --- /dev/null +++ b/tests/weight_loading/models-large.txt @@ -0,0 +1,3 @@ +compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W4A16-quantized, main +compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W4A16-channel-quantized, main +gptq_marlin, TheBloke/Mixtral-8x7B-v0.1-GPTQ, main \ No newline at end of file diff --git a/tests/weight_loading/models.txt b/tests/weight_loading/models.txt index 1dc529037a98e..a3e382acf56b3 100644 --- a/tests/weight_loading/models.txt +++ b/tests/weight_loading/models.txt @@ -19,8 +19,6 @@ compressed-tensors, nm-testing/tinyllama-oneshot-w8a16-per-channel, main compressed-tensors, nm-testing/Meta-Llama-3-8B-FP8-compressed-tensors-test, main compressed-tensors, nm-testing/Phi-3-mini-128k-instruct-FP8, main compressed-tensors, neuralmagic/Phi-3-medium-128k-instruct-quantized.w4a16, main -compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W4A16-quantized, main -compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W4A16-channel-quantized, main awq, casperhansen/mixtral-instruct-awq, main awq_marlin, casperhansen/mixtral-instruct-awq, main fp8, neuralmagic/Meta-Llama-3-8B-Instruct-FP8-KV, main diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index f8a41dfd08d73..49c29c2775cb6 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -6,8 +6,6 @@ from vllm import _custom_ops as ops from vllm.model_executor.layers.fused_moe import FusedMoE, FusedMoEMethodBase -from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( - WNA16_SUPPORTED_BITS) from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( CompressionFormat) from vllm.model_executor.utils import set_weight_attrs @@ -40,11 +38,10 @@ def __init__( if not (self.quant_config.quant_format == CompressionFormat.pack_quantized.value - and self.num_bits in WNA16_SUPPORTED_BITS): + and self.num_bits == 4): raise ValueError("For Fused MoE layers, only ", f"{CompressionFormat.pack_quantized.value} ", - "is supported for the following bits: ", - f"{WNA16_SUPPORTED_BITS}") + "is supported for 4 bits") def create_weights(self, layer: torch.nn.Module, num_experts: int, hidden_size: int, intermediate_size: int, diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py index e3b74e8712903..cae6ffad53df1 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py @@ -18,6 +18,7 @@ __all__ = ["CompressedTensorsWNA16"] WNA16_SUPPORTED_TYPES_MAP = { 4: scalar_types.uint4b8, + 8: scalar_types.uint8b128 } WNA16_SUPPORTED_BITS = list(WNA16_SUPPORTED_TYPES_MAP.keys()) From d57804d96b036a2916cfd872d9e8ea3889442051 Mon Sep 17 00:00:00 2001 From: Dipika Date: Fri, 6 Sep 2024 16:13:55 +0000 Subject: [PATCH 68/77] add hack to support unfused mixtral pathway for int8 --- vllm/model_executor/model_loader/utils.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/model_loader/utils.py b/vllm/model_executor/model_loader/utils.py index d247e4cf3f07b..0052489d99dc4 100644 --- a/vllm/model_executor/model_loader/utils.py +++ b/vllm/model_executor/model_loader/utils.py @@ -23,11 +23,19 @@ def get_model_architecture( architectures = getattr(model_config.hf_config, "architectures", []) # Special handling for quantized Mixtral. # FIXME(woosuk): This is a temporary hack. - mixtral_supported = ["fp8", "compressed-tensors", "gptq_marlin"] + mixtral_supported = ["fp8", "compressed-tensors"] + # for gptq_marlin, only run fused MoE for int4 + if model_config.quantization == "gptq_marlin": + hf_quant_config = getattr(model_config.hf_config, + "quantization_config", None) + if hf_quant_config and hf_quant_config.get("bits") == 4: + mixtral_supported.append("gptq_marlin") + if (model_config.quantization is not None and model_config.quantization not in mixtral_supported and "MixtralForCausalLM" in architectures): architectures = ["QuantMixtralForCausalLM"] + return ModelRegistry.resolve_model_cls(architectures) From 96fa486336d28e741c392277d06232ec6a0eed17 Mon Sep 17 00:00:00 2001 From: Dipika Date: Fri, 6 Sep 2024 18:29:36 +0000 Subject: [PATCH 69/77] fix install for tpu test --- vllm/model_executor/layers/quantization/gptq_marlin.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py index 0a470f311c746..3bc35dca5d03d 100644 --- a/vllm/model_executor/layers/quantization/gptq_marlin.py +++ b/vllm/model_executor/layers/quantization/gptq_marlin.py @@ -5,8 +5,6 @@ from vllm import _custom_ops as ops from vllm.logger import init_logger -from vllm.model_executor.layers.fused_moe.fused_marlin_moe import ( - fused_marlin_moe) from vllm.model_executor.layers.fused_moe.layer import ( FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported) from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, @@ -583,6 +581,8 @@ def apply( topk_group: Optional[int] = None, custom_routing_function: Optional[Callable] = None, ) -> torch.Tensor: + from vllm.model_executor.layers.fused_moe.fused_marlin_moe import ( + fused_marlin_moe) topk_weights, topk_ids = FusedMoE.select_experts( hidden_states=x, From 1faab903a378361275738c04b2dd394067153f20 Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Sat, 7 Sep 2024 06:30:32 -0400 Subject: [PATCH 70/77] Move float16 typecast hack to gptq marlin moe method --- vllm/model_executor/layers/quantization/gptq_marlin.py | 3 +++ vllm/model_executor/models/mixtral.py | 2 +- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py index 3bc35dca5d03d..a01d5fe65538a 100644 --- a/vllm/model_executor/layers/quantization/gptq_marlin.py +++ b/vllm/model_executor/layers/quantization/gptq_marlin.py @@ -584,6 +584,9 @@ def apply( from vllm.model_executor.layers.fused_moe.fused_marlin_moe import ( fused_marlin_moe) + # The input must currently be float16 + x = x.half() + topk_weights, topk_ids = FusedMoE.select_experts( hidden_states=x, router_logits=router_logits, diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index 148ef393277e4..df7f39097bdc6 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -99,7 +99,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: hidden_states = hidden_states.view(-1, self.hidden_size) # router_logits: (num_tokens, n_experts) router_logits, _ = self.gate(hidden_states) - final_hidden_states = self.experts(hidden_states.half(), router_logits) + final_hidden_states = self.experts(hidden_states, router_logits) return final_hidden_states.view(orig_shape).to(orig_dtype) From 970e06a77a02953e43a59c5683891dcd968f4f14 Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Sat, 7 Sep 2024 06:58:44 -0400 Subject: [PATCH 71/77] Move output type conversion to gptq method as well --- vllm/model_executor/layers/quantization/gptq_marlin.py | 3 ++- vllm/model_executor/models/mixtral.py | 3 +-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py index a01d5fe65538a..3617a32f80fc1 100644 --- a/vllm/model_executor/layers/quantization/gptq_marlin.py +++ b/vllm/model_executor/layers/quantization/gptq_marlin.py @@ -585,6 +585,7 @@ def apply( fused_marlin_moe) # The input must currently be float16 + orig_dtype = x.dtype x = x.half() topk_weights, topk_ids = FusedMoE.select_experts( @@ -610,4 +611,4 @@ def apply( topk_ids, w1_scale=layer.w13_scales, w2_scale=layer.w2_scales, - ) + ).to(orig_dtype) diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index df7f39097bdc6..6413b56605ecf 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -95,12 +95,11 @@ def __init__(self, def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: # NOTE: hidden_states can have either 1D or 2D shape. orig_shape = hidden_states.shape - orig_dtype = hidden_states.dtype hidden_states = hidden_states.view(-1, self.hidden_size) # router_logits: (num_tokens, n_experts) router_logits, _ = self.gate(hidden_states) final_hidden_states = self.experts(hidden_states, router_logits) - return final_hidden_states.view(orig_shape).to(orig_dtype) + return final_hidden_states.view(orig_shape) class MixtralAttention(nn.Module): From fd0a4f2b2f1627330641645e2891ad4655e1f0b5 Mon Sep 17 00:00:00 2001 From: Dipika Date: Mon, 9 Sep 2024 01:48:38 +0000 Subject: [PATCH 72/77] typo fix; fix comment --- vllm/model_executor/layers/fused_moe/fused_marlin_moe.py | 2 +- vllm/model_executor/layers/fused_moe/layer.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py index 6b01ec0a623aa..3639350d850e5 100644 --- a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py @@ -108,7 +108,7 @@ def fused_marlin_moe( - w2 (torch.Tensor): The second set of expert weights. - gating_output (torch.Tensor): The output of the gating operation (before softmax). - - g_idx1 (torch.Tensor): The fist set of act_order indices. + - g_idx1 (torch.Tensor): The first set of act_order indices. - g_idx2 (torch.Tensor): The second set of act_order indices. - perm1 (torch.Tensor): The first act_order input permutation. - perm2 (torch.Tensor): The second act_order input permutation. diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index f4621e5c4ccc4..f6c6f5f529408 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -323,7 +323,7 @@ def weight_loader(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor, weight_name: str, shard_id: str, expert_id: int) -> None: - # llm-compressor returns weights on disk which are flipped + # compressed-tensors represents weights on disk which are flipped loaded_weight = loaded_weight.t().contiguous() if ( self.quant_method.__class__.__name__ == "CompressedTensorsMoEMethod") else loaded_weight From d51a2f43b1a63fb592bfea90f224a10727c3dca3 Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Mon, 9 Sep 2024 06:56:20 -0400 Subject: [PATCH 73/77] Clarify comment, change how we process bias --- vllm/model_executor/layers/fused_moe/fused_marlin_moe.py | 5 +++-- vllm/model_executor/models/mixtral.py | 9 ++++++--- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py index 3639350d850e5..200a6148978aa 100644 --- a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py @@ -20,8 +20,9 @@ def single_marlin_moe( renormalize: bool, override_config: Optional[Dict[str, Any]] = None) -> torch.Tensor: """ - This function computes a Marlin MoE MMM using weights w - and top-k gating mechanism. It is meant for testing and debugging. + This function computes the multiplication of hidden_states with expert + weights used in Marlin MoE, using weights w and top-k gating mechanism. + Its purpose is testing and debugging the fused MoE kernel. Parameters: - hidden_states (torch.Tensor): The input tensor to the Marlin Mul. diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index 6413b56605ecf..10cbfcf6432b3 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -435,7 +435,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): continue name = name.replace(weight_name, param_name) # Skip loading extra bias for GPTQ models. - if name.endswith("bias") and name not in params_dict: + if ((name.endswith(".bias") or name.endswith("_bias")) + and name not in params_dict): continue # Skip layers on other devices. if is_pp_missing_parameter(name, self): @@ -454,7 +455,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # Skip layers on other devices. if is_pp_missing_parameter(name, self): continue - if name.endswith("bias") and name not in params_dict: + if ((name.endswith(".bias") or name.endswith("_bias")) + and name not in params_dict): continue param = params_dict[name] weight_loader = param.weight_loader @@ -466,7 +468,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): break else: # Skip loading extra bias for GPTQ models. - if name.endswith("bias") and name not in params_dict: + if ((name.endswith(".bias") or name.endswith("_bias")) + and name not in params_dict): continue # Skip layers on other devices. if is_pp_missing_parameter(name, self): From 08287ef6751e79a89bf4f060f5f9545560a6de12 Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Mon, 9 Sep 2024 09:45:11 -0500 Subject: [PATCH 74/77] [Bugfix] Streamed tool calls now more strictly follow OpenAI's format; ensures Vercel AI SDK compatibility (#8272) --- tests/tool_use/utils.py | 2 +- vllm/entrypoints/openai/protocol.py | 7 ----- vllm/entrypoints/openai/serving_chat.py | 6 ++++- .../tool_parsers/abstract_tool_parser.py | 1 - .../openai/tool_parsers/hermes_tool_parser.py | 20 ++++---------- .../tool_parsers/mistral_tool_parser.py | 27 ++++++------------- 6 files changed, 19 insertions(+), 44 deletions(-) diff --git a/tests/tool_use/utils.py b/tests/tool_use/utils.py index 8ec9b05b2c521..e447469e33410 100644 --- a/tests/tool_use/utils.py +++ b/tests/tool_use/utils.py @@ -19,7 +19,7 @@ class ServerConfig(TypedDict): CONFIGS: Dict[str, ServerConfig] = { "hermes": { "model": - "NousResearch/Hermes-2-Pro-Llama-3-8B", + "NousResearch/Hermes-3-Llama-3.1-8B", "arguments": [ "--tool-call-parser", "hermes", "--chat-template", str(VLLM_PATH / "examples/tool_chat_template_hermes.jinja") diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 970262a4bd358..374196044b7e8 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -713,13 +713,6 @@ class DeltaToolCall(OpenAIBaseModel): function: Optional[DeltaFunctionCall] = None -# the initial delta that gets sent once a new tool call is started; -class InitialDeltaToolCall(DeltaToolCall): - id: str = Field(default_factory=lambda: f"chatcmpl-tool-{random_uuid()}") - type: Literal["function"] = "function" - index: int - - class ExtractedToolCallInformation(BaseModel): # indicate if tools were called tools_called: bool diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 78f355228012f..8ed81e9c88cb2 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -271,9 +271,13 @@ async def chat_completion_stream_generator( # NOTE num_choices defaults to 1 so this usually executes # once per request for i in range(num_choices): + choice_data = ChatCompletionResponseStreamChoice( index=i, - delta=DeltaMessage(role=role), + delta=DeltaMessage( + role=role, + content="", + ), logprobs=None, finish_reason=None) chunk = ChatCompletionStreamResponse( diff --git a/vllm/entrypoints/openai/tool_parsers/abstract_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/abstract_tool_parser.py index b0807e6f1e782..873f615d43257 100644 --- a/vllm/entrypoints/openai/tool_parsers/abstract_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/abstract_tool_parser.py @@ -20,7 +20,6 @@ def __init__(self, tokenizer: AnyTokenizer): # the index of the tool call that is currently being parsed self.current_tool_id: int = -1 self.current_tool_name_sent: bool = False - self.current_tool_initial_sent: bool = False self.streamed_args_for_tool: List[str] = [] self.model_tokenizer = tokenizer diff --git a/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py index 7afbca7162edf..bde9b47ce60d5 100644 --- a/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py @@ -8,14 +8,14 @@ from vllm.entrypoints.openai.protocol import (DeltaFunctionCall, DeltaMessage, DeltaToolCall, ExtractedToolCallInformation, - FunctionCall, - InitialDeltaToolCall, ToolCall) + FunctionCall, ToolCall) from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( ToolParser) from vllm.entrypoints.openai.tool_parsers.utils import ( extract_intermediate_diff) from vllm.logger import init_logger from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer +from vllm.utils import random_uuid logger = init_logger(__name__) @@ -34,7 +34,6 @@ def __init__(self, tokenizer: AnyTokenizer): self.prev_tool_call_arr: List[Dict] = [] self.current_tool_id: int = -1 self.current_tool_name_sent = False - self.current_tool_initial_sent: bool = False self.streamed_args_for_tool: List[str] = [ ] # map what has been streamed for each tool so far to a list @@ -168,7 +167,6 @@ def extract_tool_calls_streaming( # set cursors and state appropriately self.current_tool_id += 1 self.current_tool_name_sent = False - self.current_tool_initial_sent = False self.streamed_args_for_tool.append("") logger.debug("Starting on a new tool %s", self.current_tool_id) @@ -218,24 +216,16 @@ def extract_tool_calls_streaming( logger.debug('not enough tokens to parse into JSON yet') return None - # case - we haven't sent the initial delta with the tool call ID - # (it will be sent) - if not self.current_tool_initial_sent: - self.current_tool_initial_sent = True - return DeltaMessage(tool_calls=[ - InitialDeltaToolCall( - index=self.current_tool_id).model_dump( - exclude_none=True) - ]) - # case - we haven't sent the tool name yet. If it's available, send # it. otherwise, wait until it's available. - elif not self.current_tool_name_sent: + if not self.current_tool_name_sent: function_name: Union[str, None] = current_tool_call.get("name") if function_name: self.current_tool_name_sent = True return DeltaMessage(tool_calls=[ DeltaToolCall(index=self.current_tool_id, + type="function", + id=f"chatcmpl-tool-{random_uuid()}", function=DeltaFunctionCall( name=function_name).model_dump( exclude_none=True)) diff --git a/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py index d48770c792e98..4b0e1c91df97c 100644 --- a/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py @@ -8,14 +8,14 @@ from vllm.entrypoints.openai.protocol import (DeltaFunctionCall, DeltaMessage, DeltaToolCall, ExtractedToolCallInformation, - FunctionCall, - InitialDeltaToolCall, ToolCall) + FunctionCall, ToolCall) from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( ToolParser) from vllm.entrypoints.openai.tool_parsers.utils import ( extract_intermediate_diff) from vllm.logger import init_logger from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer +from vllm.utils import random_uuid logger = init_logger(__name__) @@ -25,7 +25,7 @@ class MistralToolParser(ToolParser): Tool call parser for Mistral 7B Instruct v0.3, intended for use with the examples/tool_chat_template_mistral.jinja template. - Used when --enable-auto-tool-choice --tool-call-parser gmistral are all set + Used when --enable-auto-tool-choice --tool-call-parser mistral are all set """ def __init__(self, tokenizer: AnyTokenizer): @@ -42,7 +42,6 @@ def __init__(self, tokenizer: AnyTokenizer): self.prev_tool_call_arr: List[Dict] = [] self.current_tool_id: int = -1 self.current_tool_name_sent: bool = False - self.current_tool_initial_sent: bool = False self.streamed_args_for_tool: List[str] = [ ] # map what has been streamed for each tool so far to a list self.bot_token = "[TOOL_CALLS]" @@ -91,7 +90,6 @@ def extract_tool_calls(self, except Exception as e: logger.error("Error in extracting tool call from response: %s", e) - print("ERROR", e) # return information to just treat the tool call as regular JSON return ExtractedToolCallInformation(tools_called=False, tool_calls=[], @@ -109,7 +107,7 @@ def extract_tool_calls_streaming( # if the tool call token is not in the tokens generated so far, append # output to contents since it's not a tool - if self.bot_token_id not in current_token_ids: + if self.bot_token not in current_text: return DeltaMessage(content=delta_text) # if the tool call token ID IS in the tokens generated so far, that @@ -134,7 +132,7 @@ def extract_tool_calls_streaming( # replace BOT token with empty string, and convert single quotes # to double to allow parsing as JSON since mistral uses single # quotes instead of double for tool calls - parsable_arr = current_text.split(self.bot_token)[1] + parsable_arr = current_text.split(self.bot_token)[-1] # tool calls are generated in an array, so do partial JSON # parsing on the entire array @@ -186,31 +184,22 @@ def extract_tool_calls_streaming( # re-set stuff pertaining to progress in the current tool self.current_tool_id = len(tool_call_arr) - 1 self.current_tool_name_sent = False - self.current_tool_initial_sent = False self.streamed_args_for_tool.append("") logger.debug("starting on new tool %d", self.current_tool_id) return delta # case: update an existing tool - this is handled below - # if the current tool initial data incl. the id, type=function - # and idx not sent, send that - if not self.current_tool_initial_sent: - self.current_tool_initial_sent = True - delta = DeltaMessage(tool_calls=[ - InitialDeltaToolCall( - index=self.current_tool_id).model_dump( - exclude_none=True) - ]) - # if the current tool name hasn't been sent, send if available # - otherwise send nothing - elif not self.current_tool_name_sent: + if not self.current_tool_name_sent: function_name = current_tool_call.get("name") if function_name: delta = DeltaMessage(tool_calls=[ DeltaToolCall(index=self.current_tool_id, + type="function", + id=f"chatcmpl-tool-{random_uuid()}", function=DeltaFunctionCall( name=function_name).model_dump( exclude_none=True)) From 58fcc8545a149c9c5b1f91f417a68f5ba1fdabf3 Mon Sep 17 00:00:00 2001 From: Adam Lugowski Date: Mon, 9 Sep 2024 11:16:37 -0700 Subject: [PATCH 75/77] [Frontend] Add progress reporting to run_batch.py (#8060) Co-authored-by: Adam Lugowski --- vllm/entrypoints/openai/run_batch.py | 54 ++++++++++++++++++++++++---- 1 file changed, 48 insertions(+), 6 deletions(-) diff --git a/vllm/entrypoints/openai/run_batch.py b/vllm/entrypoints/openai/run_batch.py index 32bbade256973..278be8cd11a12 100644 --- a/vllm/entrypoints/openai/run_batch.py +++ b/vllm/entrypoints/openai/run_batch.py @@ -1,9 +1,11 @@ import asyncio from io import StringIO -from typing import Awaitable, Callable, List +from typing import Awaitable, Callable, List, Optional import aiohttp +import torch from prometheus_client import start_http_server +from tqdm import tqdm from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str from vllm.engine.async_llm_engine import AsyncLLMEngine @@ -78,6 +80,38 @@ def parse_args(): return parser.parse_args() +# explicitly use pure text format, with a newline at the end +# this makes it impossible to see the animation in the progress bar +# but will avoid messing up with ray or multiprocessing, which wraps +# each line of output with some prefix. +_BAR_FORMAT = "{desc}: {percentage:3.0f}% Completed | {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]\n" # noqa: E501 + + +class BatchProgressTracker: + + def __init__(self): + self._total = 0 + self._pbar: Optional[tqdm] = None + + def submitted(self): + self._total += 1 + + def completed(self): + if self._pbar: + self._pbar.update() + + def pbar(self) -> tqdm: + enable_tqdm = not torch.distributed.is_initialized( + ) or torch.distributed.get_rank() == 0 + self._pbar = tqdm(total=self._total, + unit="req", + desc="Running batch", + mininterval=5, + disable=not enable_tqdm, + bar_format=_BAR_FORMAT) + return self._pbar + + async def read_file(path_or_url: str) -> str: if path_or_url.startswith("http://") or path_or_url.startswith("https://"): async with aiohttp.ClientSession() as session, \ @@ -102,7 +136,8 @@ async def write_file(path_or_url: str, data: str) -> None: async def run_request(serving_engine_func: Callable, - request: BatchRequestInput) -> BatchRequestOutput: + request: BatchRequestInput, + tracker: BatchProgressTracker) -> BatchRequestOutput: response = await serving_engine_func(request.body) if isinstance(response, (ChatCompletionResponse, EmbeddingResponse)): @@ -125,6 +160,7 @@ async def run_request(serving_engine_func: Callable, else: raise ValueError("Request must not be sent in stream mode") + tracker.completed() return batch_output @@ -164,6 +200,9 @@ async def main(args): request_logger=request_logger, ) + tracker = BatchProgressTracker() + logger.info("Reading batch from %s...", args.input_file) + # Submit all requests in the file to the engine "concurrently". response_futures: List[Awaitable[BatchRequestOutput]] = [] for request_json in (await read_file(args.input_file)).strip().split("\n"): @@ -178,16 +217,19 @@ async def main(args): if request.url == "/v1/chat/completions": response_futures.append( run_request(openai_serving_chat.create_chat_completion, - request)) + request, tracker)) + tracker.submitted() elif request.url == "/v1/embeddings": response_futures.append( - run_request(openai_serving_embedding.create_embedding, - request)) + run_request(openai_serving_embedding.create_embedding, request, + tracker)) + tracker.submitted() else: raise ValueError("Only /v1/chat/completions and /v1/embeddings are" "supported in the batch endpoint.") - responses = await asyncio.gather(*response_futures) + with tracker.pbar(): + responses = await asyncio.gather(*response_futures) output_buffer = StringIO() for response in responses: From f9b4a2d41587da0692d32797221df55a02d890a6 Mon Sep 17 00:00:00 2001 From: Vladislav Kruglikov Date: Mon, 9 Sep 2024 21:20:46 +0300 Subject: [PATCH 76/77] [Bugfix] Correct adapter usage for cohere and jamba (#8292) --- vllm/model_executor/models/commandr.py | 5 +++-- vllm/model_executor/models/jamba.py | 4 +++- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/models/commandr.py b/vllm/model_executor/models/commandr.py index be7f19d15b623..649dc798d22dc 100644 --- a/vllm/model_executor/models/commandr.py +++ b/vllm/model_executor/models/commandr.py @@ -47,6 +47,8 @@ from vllm.model_executor.utils import set_weight_attrs from vllm.sequence import IntermediateTensors +from .interfaces import SupportsLoRA + @torch.compile def layer_norm_func(hidden_states, weight, variance_epsilon): @@ -292,8 +294,7 @@ def forward( return hidden_states -class CohereForCausalLM(nn.Module): - +class CohereForCausalLM(nn.Module, SupportsLoRA): packed_modules_mapping = { "qkv_proj": [ "q_proj", diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py index 73be7ffed0f89..29dd09afac5ad 100644 --- a/vllm/model_executor/models/jamba.py +++ b/vllm/model_executor/models/jamba.py @@ -38,6 +38,8 @@ from vllm.worker.model_runner import (_BATCH_SIZES_TO_CAPTURE, _get_graph_batch_size) +from .interfaces import SupportsLoRA + KVCache = Tuple[torch.Tensor, torch.Tensor] @@ -539,7 +541,7 @@ def forward( return hidden_states -class JambaForCausalLM(nn.Module, HasInnerState): +class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA): packed_modules_mapping = { "qkv_proj": [ "q_proj", From c7cb5c333564cb00fc4f6a99d32c35e9ebc0f1ed Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Mon, 9 Sep 2024 16:27:26 -0400 Subject: [PATCH 77/77] [Misc] GPTQ Activation Ordering (#8135) --- tests/weight_loading/models.txt | 1 + .../compressed_tensors/compressed_tensors.py | 3 +- .../schemes/compressed_tensors_wNa16.py | 45 ++++++++++++++----- .../quantization/compressed_tensors/utils.py | 30 ++++++++++++- 4 files changed, 64 insertions(+), 15 deletions(-) diff --git a/tests/weight_loading/models.txt b/tests/weight_loading/models.txt index 1dc529037a98e..c708e6d5eb897 100644 --- a/tests/weight_loading/models.txt +++ b/tests/weight_loading/models.txt @@ -21,6 +21,7 @@ compressed-tensors, nm-testing/Phi-3-mini-128k-instruct-FP8, main compressed-tensors, neuralmagic/Phi-3-medium-128k-instruct-quantized.w4a16, main compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W4A16-quantized, main compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W4A16-channel-quantized, main +compressed-tensors, nm-testing/TinyLlama-1.1B-Chat-v1.0-actorder-group, main awq, casperhansen/mixtral-instruct-awq, main awq_marlin, casperhansen/mixtral-instruct-awq, main fp8, neuralmagic/Meta-Llama-3-8B-Instruct-FP8-KV, main diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py index 0768b37044aac..1170d55f31993 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py @@ -232,7 +232,8 @@ def _get_scheme_from_parts( return CompressedTensorsWNA16( num_bits=weight_quant.num_bits, strategy=weight_quant.strategy, - group_size=weight_quant.group_size) + group_size=weight_quant.group_size, + actorder=weight_quant.actorder) # Detect If Activation Quantization. # TODO @dsikka: clean-up conditions diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py index 7ca8eecb9283e..8897737c1c55a 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py @@ -5,14 +5,18 @@ from vllm import _custom_ops as ops from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( CompressedTensorsScheme) +from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( + ActivationOrdering) from vllm.model_executor.layers.quantization.utils.marlin_utils import ( apply_gptq_marlin_linear, marlin_make_empty_g_idx, marlin_make_workspace, - marlin_permute_scales, replace_tensor, verify_marlin_supported, + marlin_permute_scales, marlin_repeat_scales_on_all_ranks, + marlin_sort_g_idx, replace_tensor, verify_marlin_supported, verify_marlin_supports_shape) from vllm.model_executor.parameter import (BasevLLMParameter, ChannelQuantScaleParameter, GroupQuantScaleParameter, - PackedvLLMParameter) + PackedvLLMParameter, + RowvLLMParameter) from vllm.scalar_type import scalar_types __all__ = ["CompressedTensorsWNA16"] @@ -28,11 +32,13 @@ class CompressedTensorsWNA16(CompressedTensorsScheme): def __init__(self, strategy: str, num_bits: int, - group_size: Optional[int] = None): + group_size: Optional[int] = None, + actorder: Optional[ActivationOrdering] = None): self.pack_factor = 32 // num_bits self.strategy = strategy self.group_size = -1 if group_size is None else group_size + self.has_g_idx = actorder == ActivationOrdering.GROUP if self.group_size == -1 and self.strategy != "channel": raise ValueError("Marlin kernels require group quantization or " @@ -64,12 +70,10 @@ def create_weights(self, layer: torch.nn.Module, input_size: int, output_size_per_partition = sum(output_partition_sizes) # If group_size is -1, we are in channelwise case. - channelwise = (self.group_size == -1) group_size = self.group_size if self.group_size != -1 else input_size row_parallel = (input_size != input_size_per_partition) - # In the case of channelwise quantization, we need to replicate the - # scales across all gpus. - partition_scales = (row_parallel and not channelwise) + partition_scales = not marlin_repeat_scales_on_all_ranks( + self.has_g_idx, self.group_size, row_parallel) verify_marlin_supports_shape( output_size_per_partition=output_size_per_partition, @@ -123,6 +127,16 @@ def create_weights(self, layer: torch.nn.Module, input_size: int, layer.register_parameter("weight_scale", weight_scale) layer.register_parameter("weight_shape", weight_shape) + # group index (for activation reordering) + if self.has_g_idx: + weight_g_idx = RowvLLMParameter(data=torch.empty( + input_size_per_partition, + dtype=torch.int32, + ), + input_dim=0, + weight_loader=weight_loader) + layer.register_parameter("weight_g_idx", weight_g_idx) + layer.input_size_per_partition = input_size_per_partition layer.output_size_per_partition = output_size_per_partition layer.input_size = input_size @@ -137,9 +151,14 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: layer.workspace = marlin_make_workspace( layer.output_size_per_partition, device) - # Act-order not supported in compressed-tensors yet, so set to empty. - layer.g_idx = marlin_make_empty_g_idx(device) - layer.g_idx_sort_indices = marlin_make_empty_g_idx(device) + # Handle sorting for activation reordering if needed. + if self.has_g_idx: + g_idx, g_idx_sort_indices = marlin_sort_g_idx(layer.weight_g_idx) + layer.g_idx_sort_indices = g_idx_sort_indices + replace_tensor(layer, "weight_g_idx", g_idx) + else: + layer.weight_g_idx = marlin_make_empty_g_idx(device) + layer.g_idx_sort_indices = marlin_make_empty_g_idx(device) # No zero-point layer.weight_zp = marlin_make_empty_g_idx(device) @@ -159,9 +178,11 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: replace_tensor(layer, "weight_packed", marlin_qweight) # Permute scales from compressed-tensors format to marlin format. + # scale is required on all partitions if activation reordering marlin_scales = marlin_permute_scales( layer.weight_scale, - size_k=layer.input_size_per_partition, + size_k=(layer.input_size + if self.has_g_idx else layer.input_size_per_partition), size_n=layer.output_size_per_partition, group_size=layer.group_size) replace_tensor(layer, "weight_scale", marlin_scales) @@ -174,7 +195,7 @@ def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor, weight=layer.weight_packed, weight_scale=layer.weight_scale, weight_zp=layer.weight_zp, - g_idx=layer.g_idx, + g_idx=layer.weight_g_idx, g_idx_sort_indices=layer.g_idx_sort_indices, workspace=layer.workspace, wtype=self.quant_type, diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/utils.py b/vllm/model_executor/layers/quantization/compressed_tensors/utils.py index 7912cbde5721f..fc531b9d666e3 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/utils.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/utils.py @@ -1,8 +1,8 @@ import re from enum import Enum -from typing import Any, Dict, Iterable, Optional +from typing import Any, Dict, Iterable, Optional, Union -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, field_validator from torch.nn import Module from vllm.model_executor.layers.quantization.utils.quant_utils import ( @@ -40,6 +40,19 @@ class QuantizationStrategy(str, Enum): TOKEN = "token" +class ActivationOrdering(str, Enum): + """ + Enum storing strategies for activation ordering + + Group: reorder groups and weight\n + Weight: only reorder weight, not groups. Slightly lower latency and + accuracy compared to group actorder\n + """ + + GROUP = "group" + WEIGHT = "weight" + + class QuantizationArgs(BaseModel): """ User facing arguments used to define a quantization config @@ -58,6 +71,8 @@ class QuantizationArgs(BaseModel): observed with every sample. Defaults to False for static quantization. Note that enabling dynamic quantization will change the default observer to a memoryless one + :param actorder: whether to apply group quantization in decreasing order of + activation. Defaults to None for arbitrary ordering """ num_bits: int = 8 @@ -67,6 +82,7 @@ class QuantizationArgs(BaseModel): strategy: Optional[QuantizationStrategy] = None block_structure: Optional[str] = None dynamic: bool = False + actorder: Union[ActivationOrdering, bool, None] = None observer: str = Field( default="minmax", description=("The class to use to compute the quantization param - " @@ -79,6 +95,16 @@ class QuantizationArgs(BaseModel): "Observers constructor excluding quantization range or symmetry"), ) + @field_validator("actorder", mode="before") + def validate_actorder(cls, value) -> Optional[ActivationOrdering]: + if isinstance(value, bool): + return ActivationOrdering.GROUP if value else None + + if isinstance(value, str): + return ActivationOrdering(value.lower()) + + return value + def is_activation_quantization_format(format: str) -> bool: _ACTIVATION_QUANTIZATION_FORMATS = [