From e5c1a8131c970fbb42540b518c8e37d3d0b150e8 Mon Sep 17 00:00:00 2001 From: DhruvaBansal00 Date: Tue, 6 Aug 2024 17:20:26 -0700 Subject: [PATCH 01/96] Refactoring for maintainability --- .../layers/fused_moe/__init__.py | 18 +- .../layers/fused_moe/fused_moe.py | 102 +--- .../layers/fused_moe/fused_moe_gptq.py | 138 +++++ vllm/model_executor/layers/fused_moe/layer.py | 482 ++++++------------ .../layers/quantization/gptq_marlin.py | 356 ++++++++++++- vllm/model_executor/models/mixtral_quant.py | 144 ++---- 6 files changed, 665 insertions(+), 575 deletions(-) create mode 100644 vllm/model_executor/layers/fused_moe/fused_moe_gptq.py diff --git a/vllm/model_executor/layers/fused_moe/__init__.py b/vllm/model_executor/layers/fused_moe/__init__.py index 080ecb5cfe0ba..2b982b7ab9f86 100644 --- a/vllm/model_executor/layers/fused_moe/__init__.py +++ b/vllm/model_executor/layers/fused_moe/__init__.py @@ -1,21 +1,23 @@ -from vllm.model_executor.layers.fused_moe.fused_moe import (fused_marlin_moe, - single_marlin_moe) -from vllm.model_executor.layers.fused_moe.layer import (FusedMoE, - FusedMoEMethodBase) +from vllm.model_executor.layers.fused_moe.fused_moe_gptq import fused_moe_gptq +from vllm.model_executor.layers.fused_moe.fused_moe import single_marlin_moe +from vllm.model_executor.layers.fused_moe.layer import FusedMoE, FusedMoEMethodBase from vllm.triton_utils import HAS_TRITON __all__ = [ "FusedMoE", "FusedMoEMethodBase", - "fused_marlin_moe", + "fused_moe_gptq", "single_marlin_moe", ] if HAS_TRITON: - from vllm.model_executor.layers.fused_moe.fused_moe import ( - fused_experts, fused_moe, fused_topk, get_config_file_name, - grouped_topk) + fused_experts, + fused_moe, + fused_topk, + get_config_file_name, + grouped_topk, + ) __all__ += [ "fused_moe", diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 64e47ad803232..9ae5859c4da0c 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -704,104 +704,4 @@ def single_marlin_moe( g_idx, rand_perm, workspace, M, N, K, True, E, topk, block_size_m, True, False) - return torch.sum(intermediate_cache.view(*intermediate_cache.shape), dim=1) - - -def fused_marlin_moe(hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - gating_output: torch.Tensor, - g_idx1: torch.Tensor, - g_idx2: torch.Tensor, - rand_perm1: torch.Tensor, - rand_perm2: torch.Tensor, - topk: int, - renormalize: bool, - override_config: Optional[Dict[str, Any]] = None, - use_fp8: bool = False, - w1_scale: Optional[torch.Tensor] = None, - w2_scale: Optional[torch.Tensor] = None) -> torch.Tensor: - """ - This function computes a Mixture of Experts (MoE) layer using two sets of - weights, w1 and w2, and top-k gating mechanism. - - Parameters: - - hidden_states (torch.Tensor): The input tensor to the MoE layer. - - w1 (torch.Tensor): The first set of expert weights. - - w2 (torch.Tensor): The second set of expert weights. - - gating_output (torch.Tensor): The output of the gating operation - (before softmax). - - topk (int): The number of top-k experts to select. - - renormalize (bool): If True, renormalize the top-k weights to sum to 1. - - inplace (bool): If True, perform the operation in-place. - Defaults to False. - - override_config (Optional[Dict[str, Any]]): Optional override - for the kernel configuration. - - use_fp8 (bool): If True, use fp8 arithmetic to compute the inner - products for w1 and w2. Defaults to False. 
- - w1_scale (Optional[torch.Tensor]): Optional scale to be used for - w1. - - w2_scale (Optional[torch.Tensor]): Optional scale to be used for - w2. - - Returns: - - torch.Tensor: The output tensor after applying the MoE layer. - """ - # Check constraints. - assert hidden_states.shape[0] == gating_output.shape[0], ( - "Number of tokens mismatch") - assert hidden_states.shape[ - 1] == w1.shape[1] * 16, "Hidden size mismatch w1" - assert hidden_states.shape[ - 1] == w2.shape[2] // 2, "Hidden size mismatch w2" - assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch" - assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" - assert w1.is_contiguous(), "Expert weights1 must be contiguous" - assert w2.is_contiguous(), "Expert weights2 must be contiguous" - assert hidden_states.dtype in [ - torch.float32, torch.float16, torch.bfloat16 - ] - M, K = hidden_states.shape - E = w1.shape[0] - N = w2.shape[1] * 16 - - topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk, - renormalize) - - get_config_func = functools.partial(try_get_optimal_moe_config, - w1.shape, - w2.shape, - topk_ids.shape[1], - "float8" if use_fp8 else None, - override_config=override_config, - is_marlin=True) - config = get_config_func(M) - - block_size_m = config['BLOCK_SIZE_M'] - - sorted_token_ids, _, _ = moe_align_block_size(topk_ids, block_size_m, E) - - max_workspace_size = ((M + 255) // 256) * (max(2 * N, K) // 64) * 16 - workspace = torch.zeros(max_workspace_size, - dtype=torch.int, - device="cuda", - requires_grad=False) - - intermediate_cache2 = torch.empty((M * topk_ids.shape[1], N), - device=hidden_states.device, - dtype=hidden_states.dtype) - - intermediate_cache1 = torch.ops._moe_C.marlin_gemm_moe( - hidden_states, w1, sorted_token_ids, topk_weights, topk_ids, w1_scale, - g_idx1, rand_perm1, workspace, M, 2 * N, K, True, E, topk, - block_size_m, True, False) - - ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, 2 * N)) - - intermediate_cache3 = torch.ops._moe_C.marlin_gemm_moe( - intermediate_cache2, w2, sorted_token_ids, topk_weights, topk_ids, - w2_scale, g_idx2, rand_perm2, workspace, M, K, N, True, E, topk, - block_size_m, False, True) - - return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), - dim=1) + return torch.sum(intermediate_cache.view(*intermediate_cache.shape), dim=1) \ No newline at end of file diff --git a/vllm/model_executor/layers/fused_moe/fused_moe_gptq.py b/vllm/model_executor/layers/fused_moe/fused_moe_gptq.py new file mode 100644 index 0000000000000..15c11fc0b668e --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/fused_moe_gptq.py @@ -0,0 +1,138 @@ +"""Fused MoE utilities for GPTQ.""" +import functools +import torch + +from typing import Any, Dict, Optional +from vllm import _custom_ops as ops + +from .fused_moe import fused_topk, moe_align_block_size, try_get_optimal_moe_config + + +def fused_moe_gptq( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + gating_output: torch.Tensor, + g_idx1: torch.Tensor, + g_idx2: torch.Tensor, + rand_perm1: torch.Tensor, + rand_perm2: torch.Tensor, + topk: int, + renormalize: bool, + override_config: Optional[Dict[str, Any]] = None, + use_fp8: bool = False, + w1_scale: Optional[torch.Tensor] = None, + w2_scale: Optional[torch.Tensor] = None, +) -> torch.Tensor: + """ + This function computes a Mixture of Experts (MoE) layer using two sets of + weights, w1 and w2, and top-k gating mechanism. 
+ + Parameters: + - hidden_states (torch.Tensor): The input tensor to the MoE layer. + - w1 (torch.Tensor): The first set of expert weights. + - w2 (torch.Tensor): The second set of expert weights. + - gating_output (torch.Tensor): The output of the gating operation + (before softmax). + - topk (int): The number of top-k experts to select. + - renormalize (bool): If True, renormalize the top-k weights to sum to 1. + - inplace (bool): If True, perform the operation in-place. + Defaults to False. + - override_config (Optional[Dict[str, Any]]): Optional override + for the kernel configuration. + - use_fp8 (bool): If True, use fp8 arithmetic to compute the inner + products for w1 and w2. Defaults to False. + - w1_scale (Optional[torch.Tensor]): Optional scale to be used for + w1. + - w2_scale (Optional[torch.Tensor]): Optional scale to be used for + w2. + + Returns: + - torch.Tensor: The output tensor after applying the MoE layer. + """ + # Check constraints. + assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch" + assert hidden_states.shape[1] == w1.shape[1] * 16, "Hidden size mismatch w1" + assert hidden_states.shape[1] == w2.shape[2] // 2, "Hidden size mismatch w2" + assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch" + assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" + assert w1.is_contiguous(), "Expert weights1 must be contiguous" + assert w2.is_contiguous(), "Expert weights2 must be contiguous" + assert hidden_states.dtype in [torch.float32, torch.float16, torch.bfloat16] + M, K = hidden_states.shape + E = w1.shape[0] + N = w2.shape[1] * 16 + + topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk, renormalize) + + get_config_func = functools.partial( + try_get_optimal_moe_config, + w1.shape, + w2.shape, + topk_ids.shape[1], + "float8" if use_fp8 else None, + override_config=override_config, + is_marlin=True, + ) + config = get_config_func(M) + + block_size_m = config["BLOCK_SIZE_M"] + + sorted_token_ids, _, _ = moe_align_block_size(topk_ids, block_size_m, E) + + max_workspace_size = ((M + 255) // 256) * (max(2 * N, K) // 64) * 16 + workspace = torch.zeros( + max_workspace_size, dtype=torch.int, device="cuda", requires_grad=False + ) + + intermediate_cache2 = torch.empty( + (M * topk_ids.shape[1], N), + device=hidden_states.device, + dtype=hidden_states.dtype, + ) + + intermediate_cache1 = torch.ops._moe_C.marlin_gemm_moe( + hidden_states, + w1, + sorted_token_ids, + topk_weights, + topk_ids, + w1_scale, + g_idx1, + rand_perm1, + workspace, + M, + 2 * N, + K, + True, + E, + topk, + block_size_m, + True, + False, + ) + + ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, 2 * N)) + + intermediate_cache3 = torch.ops._moe_C.marlin_gemm_moe( + intermediate_cache2, + w2, + sorted_token_ids, + topk_weights, + topk_ids, + w2_scale, + g_idx2, + rand_perm2, + workspace, + M, + K, + N, + True, + E, + topk, + block_size_m, + False, + True, + ) + + return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), dim=1) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 564a316b4894a..913d6a93b0cd5 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -6,16 +6,17 @@ import torch from vllm import _custom_ops as ops -from vllm.distributed import (get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size, - tensor_model_parallel_all_reduce) +from 
vllm.distributed import ( + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_reduce, +) from vllm.logger import init_logger from vllm.model_executor.custom_op import CustomOp -from vllm.model_executor.layers.fused_moe.fused_moe import fused_marlin_moe from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig, QuantizeMethodBase) -from vllm.model_executor.layers.quantization.gptq_marlin import ( - GPTQMarlinConfig) + QuantizationConfig, + QuantizeMethodBase, +) from vllm.model_executor.utils import set_weight_attrs logger = init_logger(__name__) @@ -24,300 +25,63 @@ class FusedMoEMethodBase(QuantizeMethodBase): @abstractmethod - def create_weights(self, layer: torch.nn.Module, num_experts: int, - hidden_size: int, intermediate_size: int, - params_dtype: torch.dtype, **extra_weight_attrs): + def create_weights( + self, + layer: torch.nn.Module, + num_experts: int, + hidden_size: int, + intermediate_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): raise NotImplementedError @abstractmethod - def apply(self, - layer: torch.nn.Module, - x: torch.Tensor, - router_logits: torch.Tensor, - top_k: int, - renormalize: bool = True, - use_grouped_topk: bool = False, - num_expert_group: Optional[int] = None, - topk_group: Optional[int] = None) -> torch.Tensor: + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool = True, + use_grouped_topk: bool = False, + num_expert_group: Optional[int] = None, + topk_group: Optional[int] = None, + ) -> torch.Tensor: raise NotImplementedError -class GPTQMarlinState(Enum): - REPACK = enum.auto() - READY = enum.auto() - - -class MarlinFusedMoEMethod(FusedMoEMethodBase): - """MoE Marlin method with quantization.""" - - def __init__(self, quant_config: GPTQMarlinConfig) -> None: - self.quant_config = quant_config - - def create_weights(self, layer: torch.nn.Module, num_experts: int, - hidden_size: int, intermediate_size: int, - params_dtype: torch.dtype, **extra_weight_attrs): - # Currently assuming is_k_full is always True - # (input size per partition is the same as full input size) - # Supports only sym for now (no zp) - if self.quant_config.group_size != -1: - scales_size13 = hidden_size // self.quant_config.group_size - scales_size2 = intermediate_size // self.quant_config.group_size - else: - scales_size13 = 1 - scales_size2 = 1 - # Fused gate_up_proj (column parallel) - w13_qweight = torch.nn.Parameter(torch.empty( - num_experts, - hidden_size // self.quant_config.pack_factor, - 2 * intermediate_size, - dtype=torch.int32), - requires_grad=False) - layer.register_parameter("w13_qweight", w13_qweight) - set_weight_attrs(w13_qweight, extra_weight_attrs) - # down_proj (row parallel) - w2_qweight = torch.nn.Parameter(torch.empty( - num_experts, - intermediate_size // self.quant_config.pack_factor, - hidden_size, - dtype=torch.int32), - requires_grad=False) - layer.register_parameter("w2_qweight", w2_qweight) - set_weight_attrs(w2_qweight, extra_weight_attrs) - # up_proj scales - w13_scales = torch.nn.Parameter(torch.empty(num_experts, - scales_size13, - 2 * intermediate_size, - dtype=params_dtype), - requires_grad=False) - layer.register_parameter("w13_scales", w13_scales) - set_weight_attrs(w13_scales, extra_weight_attrs) - # down_proj scales - w2_scales = torch.nn.Parameter(torch.empty(num_experts, - scales_size2, - hidden_size, - dtype=params_dtype), - requires_grad=False) - 
layer.register_parameter("w2_scales", w2_scales) - set_weight_attrs(w2_scales, extra_weight_attrs) - w13_g_idx = torch.nn.Parameter( - torch.empty( - num_experts, - hidden_size, - dtype=torch.int32, - ), - requires_grad=False, - ) - layer.register_parameter("w13_g_idx", w13_g_idx) - set_weight_attrs(w13_g_idx, extra_weight_attrs) - w2_g_idx = torch.nn.Parameter( - torch.empty( - num_experts, - intermediate_size, - dtype=torch.int32, - ), - requires_grad=False, - ) - layer.register_parameter("w2_g_idx", w2_g_idx) - set_weight_attrs(w2_g_idx, extra_weight_attrs) - w13_g_idx_sort_indices = torch.nn.Parameter( - torch.empty( - num_experts, - hidden_size, - dtype=torch.int32, - ), - requires_grad=False, - ) - layer.register_parameter("w13_g_idx_sort_indices", - w13_g_idx_sort_indices) - set_weight_attrs(w13_g_idx_sort_indices, extra_weight_attrs) - w2_g_idx_sort_indices = torch.nn.Parameter( - torch.empty( - num_experts, - intermediate_size, - dtype=torch.int32, - ), - requires_grad=False, - ) - layer.register_parameter("w2_g_idx_sort_indices", - w2_g_idx_sort_indices) - set_weight_attrs(w2_g_idx_sort_indices, extra_weight_attrs) - layer.marlin_state = GPTQMarlinState.REPACK - - def apply(self, - layer: torch.nn.Module, - x: torch.Tensor, - router_logits: torch.Tensor, - top_k: int, - renormalize: bool = True, - use_grouped_topk: bool = False, - num_expert_group: Optional[int] = None, - topk_group: Optional[int] = None) -> torch.Tensor: - if layer.marlin_state == GPTQMarlinState.REPACK: - layer.marlin_state = GPTQMarlinState.READY - - # Newly generated tensors need to replace existing tensors that are - # already registered as parameters by vLLM (and won't be freed) - def replace_tensor(name, new_t): - # It is important to use resize_() here since it ensures - # the same buffer is reused - getattr(layer, name).resize_(new_t.shape) - getattr(layer, name).copy_(new_t) - del new_t - - def get_scale_perms(num_bits: int): - scale_perm: List[int] = [] - for i in range(8): - scale_perm.extend([i + 8 * j for j in range(8)]) - scale_perm_single: List[int] = [] - for i in range(4): - scale_perm_single.extend( - [2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]]) - return scale_perm, scale_perm_single - - def marlin_permute_scales(s: torch.Tensor, size_k: int, - size_n: int, group_size: int, - num_bits: int): - scale_perm, scale_perm_single = get_scale_perms(num_bits) - if group_size < size_k and group_size != -1: - s = s.reshape((-1, len(scale_perm)))[:, scale_perm] - else: - s = s.reshape( - (-1, len(scale_perm_single)))[:, scale_perm_single] - s = s.reshape((-1, size_n)).contiguous() - return s - - def marlin_moe_permute_scales(s: torch.Tensor, size_k: int, - size_n: int, group_size: int, - num_bits: int): - num_experts = s.shape[0] - output = torch.empty((num_experts, s.shape[1], s.shape[2]), - device=s.device, - dtype=s.dtype) - for e in range(num_experts): - output[e] = marlin_permute_scales(s[e], size_k, size_n, - group_size, num_bits) - return output - - # Process act_order - if self.quant_config.desc_act: - # Get sorting based on g_idx - num_experts = layer.w13_g_idx.shape[0] - w13_g_idx_sort_indices = torch.empty_like(layer.w13_g_idx) - w2_g_idx_sort_indices = torch.empty_like(layer.w2_g_idx) - w13_sorted_g_idx = torch.empty_like(layer.w13_g_idx) - w2_sorted_g_idx = torch.empty_like(layer.w2_g_idx) - for e in range(num_experts): - w13_g_idx_sort_indices[e] = torch.argsort( - layer.w13_g_idx[e]).to(torch.int32) - w2_g_idx_sort_indices[e] = torch.argsort( - layer.w2_g_idx[e]).to(torch.int32) - 
w13_sorted_g_idx[e] = layer.w13_g_idx[e][ - w13_g_idx_sort_indices[e]] - w2_sorted_g_idx[e] = layer.w2_g_idx[e][ - w2_g_idx_sort_indices[e]] - replace_tensor("w13_g_idx", w13_sorted_g_idx) - replace_tensor("w2_g_idx", w2_sorted_g_idx) - replace_tensor("w13_g_idx_sort_indices", - w13_g_idx_sort_indices) - replace_tensor("w2_g_idx_sort_indices", w2_g_idx_sort_indices) - else: - # Reset g_idx related tensors - num_experts = layer.w13_g_idx.shape[0] - device = layer.w13_g_idx.device - layer.w13_g_idx = torch.nn.Parameter( - torch.empty((num_experts, 0), - dtype=torch.int32, - device=device), - requires_grad=False, - ) - layer.w2_g_idx = torch.nn.Parameter( - torch.empty((num_experts, 0), - dtype=torch.int32, - device=device), - requires_grad=False, - ) - layer.w13_g_idx_sort_indices = torch.nn.Parameter( - torch.empty((num_experts, 0), - dtype=torch.int32, - device=device), - requires_grad=False, - ) - layer.w2_g_idx_sort_indices = torch.nn.Parameter( - torch.empty((num_experts, 0), - dtype=torch.int32, - device=device), - requires_grad=False, - ) - # Repack weights - marlin_w13_qweight = ops.gptq_marlin_moe_repack( - layer.w13_qweight, - layer.w13_g_idx_sort_indices, - layer.w13_qweight.shape[1] * self.quant_config.pack_factor, - layer.w13_qweight.shape[2], - self.quant_config.weight_bits, - ) - replace_tensor("w13_qweight", marlin_w13_qweight) - marlin_w2_qweight = ops.gptq_marlin_moe_repack( - layer.w2_qweight, - layer.w2_g_idx_sort_indices, - layer.w2_qweight.shape[1] * self.quant_config.pack_factor, - layer.w2_qweight.shape[2], - self.quant_config.weight_bits, - ) - replace_tensor("w2_qweight", marlin_w2_qweight) - # Repack scales - marlin_w13_scales = marlin_moe_permute_scales( - layer.w13_scales, - x.shape[1], - layer.w13_scales.shape[2], - self.quant_config.group_size, - self.quant_config.weight_bits, - ) - replace_tensor("w13_scales", marlin_w13_scales) - marlin_w2_scales = marlin_moe_permute_scales( - layer.w2_scales, - layer.w2_scales.shape[1] * self.quant_config.pack_factor, - x.shape[1], - self.quant_config.group_size, - self.quant_config.weight_bits, - ) - replace_tensor("w2_scales", marlin_w2_scales) - return fused_marlin_moe(x, - layer.w13_qweight, - layer.w2_qweight, - router_logits, - layer.w13_g_idx, - layer.w2_g_idx, - layer.w13_g_idx_sort_indices, - layer.w2_g_idx_sort_indices, - top_k, - renormalize=renormalize, - w1_scale=layer.w13_scales, - w2_scale=layer.w2_scales) - - class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): """MoE method without quantization.""" - def create_weights(self, layer: torch.nn.Module, num_experts: int, - hidden_size: int, intermediate_size: int, - params_dtype: torch.dtype, **extra_weight_attrs): - + def create_weights( + self, + layer: torch.nn.Module, + num_experts: int, + hidden_size: int, + intermediate_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): # Fused gate_up_proj (column parallel) - w13_weight = torch.nn.Parameter(torch.empty(num_experts, - 2 * intermediate_size, - hidden_size, - dtype=params_dtype), - requires_grad=False) + w13_weight = torch.nn.Parameter( + torch.empty(num_experts, + 2 * intermediate_size, + hidden_size, + dtype=params_dtype), + requires_grad=False, + ) layer.register_parameter("w13_weight", w13_weight) set_weight_attrs(w13_weight, extra_weight_attrs) # down_proj (row parallel) - w2_weight = torch.nn.Parameter(torch.empty(num_experts, - hidden_size, - intermediate_size, - dtype=params_dtype), - requires_grad=False) + w2_weight = torch.nn.Parameter( + torch.empty(num_experts, + 
hidden_size, + intermediate_size, + dtype=params_dtype), + requires_grad=False, + ) layer.register_parameter("w2_weight", w2_weight) set_weight_attrs(w2_weight, extra_weight_attrs) @@ -332,9 +96,17 @@ def apply( num_expert_group: Optional[int] = None, topk_group: Optional[int] = None, ) -> torch.Tensor: - return self.forward(x, layer.w13_weight, layer.w2_weight, - router_logits, top_k, renormalize, - use_grouped_topk, num_expert_group, topk_group) + return self.forward( + x, + layer.w13_weight, + layer.w2_weight, + router_logits, + top_k, + renormalize, + use_grouped_topk, + num_expert_group, + topk_group, + ) def forward_cuda( self, @@ -349,16 +121,19 @@ def forward_cuda( topk_group: Optional[int], ) -> torch.Tensor: from vllm.model_executor.layers.fused_moe.fused_moe import fused_moe - return fused_moe(x, - w1, - w2, - router_logits, - top_k, - renormalize=renormalize, - inplace=True, - use_grouped_topk=use_grouped_topk, - num_expert_group=num_expert_group, - topk_group=topk_group) + + return fused_moe( + x, + w1, + w2, + router_logits, + top_k, + renormalize=renormalize, + inplace=True, + use_grouped_topk=use_grouped_topk, + num_expert_group=num_expert_group, + topk_group=topk_group, + ) def forward_cpu(self, *args, **kwargs): raise NotImplementedError( @@ -377,6 +152,7 @@ def forward_tpu( topk_group: Optional[int], ) -> torch.Tensor: from vllm.model_executor.layers.fused_moe.moe_pallas import fused_moe + assert not use_grouped_topk assert num_expert_group is None assert topk_group is None @@ -386,7 +162,7 @@ def forward_tpu( class FusedMoE(torch.nn.Module): """FusedMoE layer for MoE models. - This layer contains both MergedColumnParallel weights (gate_up_proj / + This layer contains both MergedColumnParallel weights (gate_up_proj / w13) and RowParallelLinear weights (down_proj/ w2). Note: Mixtral uses w1, w2, and w3 for gate, up, and down_proj. We @@ -438,12 +214,9 @@ def __init__( self.num_expert_group = num_expert_group self.topk_group = topk_group - self.quant_method: Optional[QuantizeMethodBase] = None - if quant_config is None: - self.quant_method = UnquantizedFusedMoEMethod() - elif isinstance(quant_config, GPTQMarlinConfig): - self.quant_method = MarlinFusedMoEMethod(quant_config) + self.quant_method: Optional[ + QuantizeMethodBase] = UnquantizedFusedMoEMethod() else: self.quant_method = quant_config.get_quant_method(self, prefix) assert self.quant_method is not None @@ -454,15 +227,18 @@ def __init__( hidden_size=hidden_size, intermediate_size=self.intermediate_size_per_partition, params_dtype=params_dtype, - weight_loader=self.weight_loader) - - def weight_loader(self, - param: torch.nn.Parameter, - loaded_weight: torch.Tensor, - weight_name: str, - shard_id: int, - expert_id: int, - is_quantized: bool = False): + weight_loader=self.weight_loader, + ) + + def weight_loader( + self, + param: torch.nn.Parameter, + loaded_weight: torch.Tensor, + weight_name: str, + shard_id: int, + expert_id: int, + is_quantized: bool = False, + ): param_data = param.data if is_quantized: @@ -491,8 +267,8 @@ def weight_loader(self, else: # Input scales can be loaded directly and should be equal. if "input_scale" in weight_name: - if param_data[expert_id] != 1 and (param_data[expert_id] - - loaded_weight).abs() > 1e-5: + if (param_data[expert_id] != 1 and + (param_data[expert_id] - loaded_weight).abs() > 1e-5): raise ValueError( "input_scales of w1 and w3 of a layer " f"must be equal. 
But got {param_data[expert_id]} " @@ -546,7 +322,8 @@ def forward(self, hidden_states: torch.Tensor, renormalize=self.renormalize, use_grouped_topk=self.use_grouped_topk, num_expert_group=self.num_expert_group, - topk_group=self.topk_group) + topk_group=self.topk_group, + ) if self.reduce_results and self.tp_size > 1: final_hidden_states = tensor_model_parallel_all_reduce( @@ -556,37 +333,70 @@ def forward(self, hidden_states: torch.Tensor, @classmethod def make_expert_params_mapping( - cls, ckpt_gate_proj_name: str, ckpt_down_proj_name: str, - ckpt_up_proj_name: str, - num_experts: int) -> List[Tuple[str, str, int, int]]: - + cls, + ckpt_gate_proj_name: str, + ckpt_down_proj_name: str, + ckpt_up_proj_name: str, + num_experts: int, + ) -> List[Tuple[str, str, int, int]]: gate_up = [ckpt_gate_proj_name, ckpt_up_proj_name] gate_down_up = [ ckpt_gate_proj_name, ckpt_down_proj_name, ckpt_up_proj_name ] - return [ + return ([ # These are the weight scales for the experts # (param_name, weight_name, expert_id, shard_id) - ("experts.w13_scale" - if weight_name in gate_up else "experts.w2_scale", - f"experts.{expert_id}.{weight_name}.weight_scale", expert_id, - shard_id) for expert_id in range(num_experts) + ( + "experts.w13_scale" + if weight_name in gate_up else "experts.w2_scale", + f"experts.{expert_id}.{weight_name}.weight_scale", + expert_id, + shard_id, + ) for expert_id in range(num_experts) for shard_id, weight_name in enumerate(gate_down_up) ] + [ # These are the weights for the experts # (param_name, weight_name, expert_id, shard_id) - ("experts.w13_weight" - if weight_name in gate_up else "experts.w2_weight", - f"experts.{expert_id}.{weight_name}.weight", expert_id, shard_id) - for expert_id in range(num_experts) + ( + "experts.w13_weight" + if weight_name in gate_up else "experts.w2_weight", + f"experts.{expert_id}.{weight_name}.weight", + expert_id, + shard_id, + ) for expert_id in range(num_experts) for shard_id, weight_name in enumerate(gate_down_up) ] + [ # These are the weight scales for the experts # (param_name, weight_name, expert_id, shard_id) - ("experts.a13_scale" - if weight_name in gate_up else "experts.a2_scale", - f"experts.{expert_id}.{weight_name}.input_scale", expert_id, - shard_id) for expert_id in range(num_experts) + ( + "experts.a13_scale" + if weight_name in gate_up else "experts.a2_scale", + f"experts.{expert_id}.{weight_name}.input_scale", + expert_id, + shard_id, + ) for expert_id in range(num_experts) for shard_id, weight_name in enumerate(gate_down_up) - ] + ] + [ + # These are the qweights for the experts + # (param_name, weight_name, expert_id, shard_id) + ( + "experts.w13_qweight" + if weight_name in gate_up else "experts.w2_qweight", + f"experts.{expert_id}.{weight_name}.qweight", + expert_id, + shard_id, + ) for expert_id in range(num_experts) + for shard_id, weight_name in enumerate(gate_down_up) + ] + [ + # These are the g_idx and g_idx_sort_indices scales for the experts + # (param_name, weight_name, expert_id, shard_id) + ( + "experts.w13_g_idx" + if weight_name in gate_up else "experts.w2_g_idx", + f"experts.{expert_id}.{weight_name}.g_idx", + expert_id, + shard_id, + ) for expert_id in range(num_experts) + for shard_id, weight_name in enumerate(gate_down_up) + ]) diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py index bdcc9c3b4f0c5..f58a89c8e4bb9 100644 --- a/vllm/model_executor/layers/quantization/gptq_marlin.py +++ b/vllm/model_executor/layers/quantization/gptq_marlin.py 
@@ -1,29 +1,53 @@ -from typing import Any, Dict, List, Optional - +from typing import Any, Dict, List, Optional, Union +import enum +from enum import Enum import torch from torch.nn.parameter import Parameter from vllm import _custom_ops as ops from vllm.logger import init_logger -from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, - set_weight_attrs) -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) +from vllm.model_executor.layers.linear import ( + LinearBase, + LinearMethodBase, + set_weight_attrs, +) +from vllm.model_executor.layers.fused_moe.layer import FusedMoE, FusedMoEMethodBase +from vllm.model_executor.layers.fused_moe.fused_moe_gptq import fused_moe_gptq +from vllm.model_executor.layers.quantization.base_config import QuantizationConfig from vllm.model_executor.layers.quantization.utils.marlin_utils import ( - apply_gptq_marlin_linear, check_gptq_marlin_supported, marlin_is_k_full, - marlin_make_empty_g_idx, marlin_make_workspace, marlin_permute_scales, - marlin_repeat_scales_on_all_ranks, marlin_sort_g_idx, replace_tensor, - verify_gptq_marlin_supported, verify_marlin_supports_shape) + apply_gptq_marlin_linear, + check_gptq_marlin_supported, + marlin_is_k_full, + marlin_make_empty_g_idx, + marlin_make_workspace, + marlin_permute_scales, + marlin_repeat_scales_on_all_ranks, + marlin_sort_g_idx, + replace_tensor, + verify_gptq_marlin_supported, + verify_marlin_supports_shape, +) from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead logger = init_logger(__name__) +class GPTQMarlinState(Enum): + REPACK = enum.auto() + READY = enum.auto() + + class GPTQMarlinConfig(QuantizationConfig): """Config class for GPTQ Marlin""" - def __init__(self, weight_bits: int, group_size: int, desc_act: bool, - is_sym: bool, lm_head_quantized: bool) -> None: + def __init__( + self, + weight_bits: int, + group_size: int, + desc_act: bool, + is_sym: bool, + lm_head_quantized: bool, + ) -> None: if desc_act and group_size == -1: # In this case, act_order == True is the same as act_order == False # (since we have only one group per output channel) @@ -95,11 +119,14 @@ def override_quantization_method(cls, hf_quant_cfg, " faster inference") return None - def get_quant_method(self, layer: torch.nn.Module, - prefix: str) -> Optional["GPTQMarlinLinearMethod"]: - if (isinstance(layer, LinearBase) or - (isinstance(layer, ParallelLMHead) and self.lm_head_quantized)): + def get_quant_method( + self, layer: torch.nn.Module, prefix: str + ) -> Optional[Union["GPTQMarlinLinearMethod", "GPTQMarlinMoEMethod"]]: + if isinstance(layer, LinearBase) or (isinstance(layer, ParallelLMHead) + and self.lm_head_quantized): return GPTQMarlinLinearMethod(self) + elif isinstance(layer, FusedMoE): + return GPTQMarlinMoEMethod(self) return None def get_scaled_act_names(self) -> List[str]: @@ -118,15 +145,15 @@ def is_gptq_marlin_compatible(cls, quant_config: Dict[str, Any]): return False # If we cannot find the info needed in the config, cannot convert. 
- if (num_bits is None or group_size is None or sym is None - or desc_act is None): + if num_bits is None or group_size is None or sym is None or desc_act is None: return False return check_gptq_marlin_supported( num_bits=num_bits, group_size=group_size, is_sym=sym, - min_capability=cls.get_min_capability()) + min_capability=cls.get_min_capability(), + ) class GPTQMarlinLinearMethod(LinearMethodBase): @@ -163,7 +190,8 @@ def create_weights( output_size_per_partition=output_size_per_partition, input_size_per_partition=input_size_per_partition, input_size=input_size, - group_size=group_size) + group_size=group_size, + ) # Determine sharding if marlin_repeat_scales_on_all_ranks(self.quant_config.desc_act, @@ -293,7 +321,8 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: perm=layer.g_idx_sort_indices, size_k=layer.input_size_per_partition, size_n=layer.output_size_per_partition, - num_bits=self.quant_config.weight_bits) + num_bits=self.quant_config.weight_bits, + ) replace_tensor(layer, "qweight", marlin_qweight) # Permute scales from autogptq format to marlin format. @@ -302,7 +331,8 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: size_k=(layer.input_size if self.quant_config.desc_act else layer.input_size_per_partition), size_n=layer.output_size_per_partition, - group_size=self.quant_config.group_size) + group_size=self.quant_config.group_size, + ) replace_tensor(layer, "scales", marlin_scales) def apply( @@ -323,4 +353,284 @@ def apply( output_size_per_partition=layer.output_size_per_partition, input_size_per_partition=layer.input_size_per_partition, is_k_full=layer.is_k_full, - bias=bias) + bias=bias, + ) + + +class GPTQMarlinMoEMethod(FusedMoEMethodBase): + """MoE Marlin method with quantization.""" + + def __init__(self, quant_config: GPTQMarlinConfig) -> None: + self.quant_config = quant_config + + def create_weights( + self, + layer: torch.nn.Module, + num_experts: int, + hidden_size: int, + intermediate_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + # Currently assuming is_k_full is always True + # (input size per partition is the same as full input size) + # Supports only sym for now (no zp) + if self.quant_config.group_size != -1: + scales_size13 = hidden_size // self.quant_config.group_size + scales_size2 = intermediate_size // self.quant_config.group_size + else: + scales_size13 = 1 + scales_size2 = 1 + # Fused gate_up_proj (column parallel) + w13_qweight = torch.nn.Parameter( + torch.empty( + num_experts, + hidden_size // self.quant_config.pack_factor, + 2 * intermediate_size, + dtype=torch.int32, + ), + requires_grad=False, + ) + layer.register_parameter("w13_qweight", w13_qweight) + set_weight_attrs(w13_qweight, extra_weight_attrs) + # down_proj (row parallel) + w2_qweight = torch.nn.Parameter( + torch.empty( + num_experts, + intermediate_size // self.quant_config.pack_factor, + hidden_size, + dtype=torch.int32, + ), + requires_grad=False, + ) + layer.register_parameter("w2_qweight", w2_qweight) + set_weight_attrs(w2_qweight, extra_weight_attrs) + # up_proj scales + w13_scales = torch.nn.Parameter( + torch.empty(num_experts, + scales_size13, + 2 * intermediate_size, + dtype=params_dtype), + requires_grad=False, + ) + layer.register_parameter("w13_scales", w13_scales) + set_weight_attrs(w13_scales, extra_weight_attrs) + # down_proj scales + w2_scales = torch.nn.Parameter( + torch.empty(num_experts, + scales_size2, + hidden_size, + dtype=params_dtype), + requires_grad=False, + ) + 
layer.register_parameter("w2_scales", w2_scales) + set_weight_attrs(w2_scales, extra_weight_attrs) + w13_g_idx = torch.nn.Parameter( + torch.empty( + num_experts, + hidden_size, + dtype=torch.int32, + ), + requires_grad=False, + ) + layer.register_parameter("w13_g_idx", w13_g_idx) + set_weight_attrs(w13_g_idx, extra_weight_attrs) + w2_g_idx = torch.nn.Parameter( + torch.empty( + num_experts, + intermediate_size, + dtype=torch.int32, + ), + requires_grad=False, + ) + layer.register_parameter("w2_g_idx", w2_g_idx) + set_weight_attrs(w2_g_idx, extra_weight_attrs) + w13_g_idx_sort_indices = torch.nn.Parameter( + torch.empty( + num_experts, + hidden_size, + dtype=torch.int32, + ), + requires_grad=False, + ) + layer.register_parameter("w13_g_idx_sort_indices", + w13_g_idx_sort_indices) + set_weight_attrs(w13_g_idx_sort_indices, extra_weight_attrs) + w2_g_idx_sort_indices = torch.nn.Parameter( + torch.empty( + num_experts, + intermediate_size, + dtype=torch.int32, + ), + requires_grad=False, + ) + layer.register_parameter("w2_g_idx_sort_indices", + w2_g_idx_sort_indices) + set_weight_attrs(w2_g_idx_sort_indices, extra_weight_attrs) + layer.marlin_state = GPTQMarlinState.REPACK + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool = True, + use_grouped_topk: bool = False, + num_expert_group: Optional[int] = None, + topk_group: Optional[int] = None, + ) -> torch.Tensor: + if layer.marlin_state == GPTQMarlinState.REPACK: + layer.marlin_state = GPTQMarlinState.READY + + # Newly generated tensors need to replace existing tensors that are + # already registered as parameters by vLLM (and won't be freed) + def replace_tensor(name, new_t): + # It is important to use resize_() here since it ensures + # the same buffer is reused + getattr(layer, name).resize_(new_t.shape) + getattr(layer, name).copy_(new_t) + del new_t + + def get_scale_perms(num_bits: int): + scale_perm: List[int] = [] + for i in range(8): + scale_perm.extend([i + 8 * j for j in range(8)]) + scale_perm_single: List[int] = [] + for i in range(4): + scale_perm_single.extend( + [2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]]) + return scale_perm, scale_perm_single + + def marlin_permute_scales( + s: torch.Tensor, + size_k: int, + size_n: int, + group_size: int, + num_bits: int, + ): + scale_perm, scale_perm_single = get_scale_perms(num_bits) + if group_size < size_k and group_size != -1: + s = s.reshape((-1, len(scale_perm)))[:, scale_perm] + else: + s = s.reshape( + (-1, len(scale_perm_single)))[:, scale_perm_single] + s = s.reshape((-1, size_n)).contiguous() + return s + + def marlin_moe_permute_scales( + s: torch.Tensor, + size_k: int, + size_n: int, + group_size: int, + num_bits: int, + ): + num_experts = s.shape[0] + output = torch.empty( + (num_experts, s.shape[1], s.shape[2]), + device=s.device, + dtype=s.dtype, + ) + for e in range(num_experts): + output[e] = marlin_permute_scales(s[e], size_k, size_n, + group_size, num_bits) + return output + + # Process act_order + if self.quant_config.desc_act: + # Get sorting based on g_idx + num_experts = layer.w13_g_idx.shape[0] + w13_g_idx_sort_indices = torch.empty_like(layer.w13_g_idx) + w2_g_idx_sort_indices = torch.empty_like(layer.w2_g_idx) + w13_sorted_g_idx = torch.empty_like(layer.w13_g_idx) + w2_sorted_g_idx = torch.empty_like(layer.w2_g_idx) + for e in range(num_experts): + w13_g_idx_sort_indices[e] = torch.argsort( + layer.w13_g_idx[e]).to(torch.int32) + w2_g_idx_sort_indices[e] = torch.argsort( + 
layer.w2_g_idx[e]).to(torch.int32) + w13_sorted_g_idx[e] = layer.w13_g_idx[e][ + w13_g_idx_sort_indices[e]] + w2_sorted_g_idx[e] = layer.w2_g_idx[e][ + w2_g_idx_sort_indices[e]] + replace_tensor("w13_g_idx", w13_sorted_g_idx) + replace_tensor("w2_g_idx", w2_sorted_g_idx) + replace_tensor("w13_g_idx_sort_indices", + w13_g_idx_sort_indices) + replace_tensor("w2_g_idx_sort_indices", w2_g_idx_sort_indices) + else: + # Reset g_idx related tensors + num_experts = layer.w13_g_idx.shape[0] + device = layer.w13_g_idx.device + layer.w13_g_idx = torch.nn.Parameter( + torch.empty((num_experts, 0), + dtype=torch.int32, + device=device), + requires_grad=False, + ) + layer.w2_g_idx = torch.nn.Parameter( + torch.empty((num_experts, 0), + dtype=torch.int32, + device=device), + requires_grad=False, + ) + layer.w13_g_idx_sort_indices = torch.nn.Parameter( + torch.empty((num_experts, 0), + dtype=torch.int32, + device=device), + requires_grad=False, + ) + layer.w2_g_idx_sort_indices = torch.nn.Parameter( + torch.empty((num_experts, 0), + dtype=torch.int32, + device=device), + requires_grad=False, + ) + # Repack weights + marlin_w13_qweight = ops.gptq_marlin_moe_repack( + layer.w13_qweight, + layer.w13_g_idx_sort_indices, + layer.w13_qweight.shape[1] * self.quant_config.pack_factor, + layer.w13_qweight.shape[2], + self.quant_config.weight_bits, + ) + replace_tensor("w13_qweight", marlin_w13_qweight) + marlin_w2_qweight = ops.gptq_marlin_moe_repack( + layer.w2_qweight, + layer.w2_g_idx_sort_indices, + layer.w2_qweight.shape[1] * self.quant_config.pack_factor, + layer.w2_qweight.shape[2], + self.quant_config.weight_bits, + ) + replace_tensor("w2_qweight", marlin_w2_qweight) + # Repack scales + marlin_w13_scales = marlin_moe_permute_scales( + layer.w13_scales, + x.shape[1], + layer.w13_scales.shape[2], + self.quant_config.group_size, + self.quant_config.weight_bits, + ) + replace_tensor("w13_scales", marlin_w13_scales) + marlin_w2_scales = marlin_moe_permute_scales( + layer.w2_scales, + layer.w2_scales.shape[1] * self.quant_config.pack_factor, + x.shape[1], + self.quant_config.group_size, + self.quant_config.weight_bits, + ) + replace_tensor("w2_scales", marlin_w2_scales) + return fused_moe_gptq( + x, + layer.w13_qweight, + layer.w2_qweight, + router_logits, + layer.w13_g_idx, + layer.w2_g_idx, + layer.w13_g_idx_sort_indices, + layer.w2_g_idx_sort_indices, + top_k, + renormalize=renormalize, + w1_scale=layer.w13_scales, + w2_scale=layer.w2_scales, + ) diff --git a/vllm/model_executor/models/mixtral_quant.py b/vllm/model_executor/models/mixtral_quant.py index 85dafd55bbcf8..cdfd24874b974 100644 --- a/vllm/model_executor/models/mixtral_quant.py +++ b/vllm/model_executor/models/mixtral_quant.py @@ -21,7 +21,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""Inference-only Mixtral model.""" -import re from typing import Iterable, List, Optional, Tuple import numpy as np @@ -35,7 +34,6 @@ from vllm.distributed import (get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce) -from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (QKVParallelLinear, ReplicatedLinear, @@ -96,13 +94,10 @@ class MixtralMoE(nn.Module): def __init__( self, config: MixtralConfig, - use_fused_moe: bool, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.config = config - self.use_fused_moe = use_fused_moe - self.quant_config = quant_config self.rank = get_tensor_model_parallel_rank() self.tp_size = get_tensor_model_parallel_world_size() self.num_total_experts = config.num_local_experts @@ -118,26 +113,14 @@ def __init__( raise ValueError( f"Rank {self.rank} has no experts assigned to it.") - if self.use_fused_moe: - params_dtype = torch.float16 - self.experts = FusedMoE(num_experts=self.num_total_experts, - top_k=self.top_k, - hidden_size=config.hidden_size, - intermediate_size=config.intermediate_size, - params_dtype=params_dtype, - reduce_results=True, - renormalize=True, - quant_config=quant_config, - tp_size=self.tp_size) - else: - self.experts = nn.ModuleList([ - MixtralMLP(self.num_total_experts, - config.hidden_size, - config.intermediate_size, - quant_config=quant_config) - if idx in self.expert_indicies else None - for idx in range(self.num_total_experts) - ]) + self.experts = nn.ModuleList([ + MixtralMLP(self.num_total_experts, + config.hidden_size, + config.intermediate_size, + quant_config=quant_config) + if idx in self.expert_indicies else None + for idx in range(self.num_total_experts) + ]) self.gate = ReplicatedLinear(config.hidden_size, self.num_total_experts, @@ -149,34 +132,28 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: hidden_states = hidden_states.view(-1, hidden_dim) router_logits, _ = self.gate(hidden_states) - if self.use_fused_moe: - ret = self.experts(hidden_states.half(), router_logits) - return ret.bfloat16() - else: - routing_weights = F.softmax(router_logits, - dim=1, - dtype=torch.float) - routing_weights, selected_experts = torch.topk(routing_weights, - self.top_k, - dim=-1) - routing_weights /= routing_weights.sum(dim=-1, keepdim=True) - - final_hidden_states = None - for expert_idx in self.expert_indicies: - expert_layer = self.experts[expert_idx] - expert_mask = (selected_experts == expert_idx) - expert_weights = (routing_weights * expert_mask).sum( - dim=-1, keepdim=True) - - current_hidden_states = expert_layer(hidden_states).mul_( - expert_weights) - if final_hidden_states is None: - final_hidden_states = current_hidden_states - else: - final_hidden_states.add_(current_hidden_states) - - return tensor_model_parallel_all_reduce(final_hidden_states).view( - num_tokens, hidden_dim) + routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) + routing_weights, selected_experts = torch.topk(routing_weights, + self.top_k, + dim=-1) + routing_weights /= routing_weights.sum(dim=-1, keepdim=True) + + final_hidden_states = None + for expert_idx in self.expert_indicies: + expert_layer = self.experts[expert_idx] + expert_mask = (selected_experts == expert_idx) + expert_weights = (routing_weights * expert_mask).sum(dim=-1, + keepdim=True) + + current_hidden_states = expert_layer(hidden_states).mul_( + expert_weights) + if 
final_hidden_states is None: + final_hidden_states = current_hidden_states + else: + final_hidden_states.add_(current_hidden_states) + + return tensor_model_parallel_all_reduce(final_hidden_states).view( + num_tokens, hidden_dim) class MixtralAttention(nn.Module): @@ -261,7 +238,6 @@ class MixtralDecoderLayer(nn.Module): def __init__( self, config: MixtralConfig, - use_fused_moe: bool, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ) -> None: @@ -278,7 +254,6 @@ def __init__( cache_config=cache_config, quant_config=quant_config) self.block_sparse_moe = MixtralMoE(config=config, - use_fused_moe=use_fused_moe, quant_config=quant_config) self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -319,7 +294,6 @@ class MixtralModel(nn.Module): def __init__( self, config: MixtralConfig, - use_fused_moe: bool, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ) -> None: @@ -333,7 +307,6 @@ def __init__( ) self.layers = nn.ModuleList([ MixtralDecoderLayer(config, - use_fused_moe, cache_config, quant_config=quant_config) for _ in range(config.num_hidden_layers) @@ -370,12 +343,10 @@ def __init__( super().__init__() # TODO check runs with dtype=float16 - self.use_fused_moe = (config.torch_dtype != torch.float8_e4m3fn) self.config = config self.quant_config = quant_config - self.model = MixtralModel(config, self.use_fused_moe, cache_config, - quant_config) + self.model = MixtralModel(config, cache_config, quant_config) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size, quant_config=quant_config) @@ -436,50 +407,9 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): if name.endswith(".bias") and name not in params_dict: continue - if self.use_fused_moe: - if ("block_sparse_moe.experts." in name - and ".w1." not in name and ".w2." not in name - and ".w3." not in name - and name not in params_dict): - continue - - if (".qzeros" in name): - continue - - shard_id = None - expert_id = 0 - - has_any_numbered = (".qweight" in name or ".scales" in name - or ".g_idx" in name) - if (has_any_numbered and (".w1." in name)): - name = name.replace(".w1.", ".w13_") - shard_id = 0 - if (has_any_numbered and (".w2." in name)): - name = name.replace(".w2.", ".w2_") - shard_id = 0 - if (has_any_numbered and (".w3." in name)): - name = name.replace(".w3.", ".w13_") - shard_id = 1 - - exp_string = re.search(r"\.experts\.\d+.", name) - if exp_string: - exp_string = exp_string.group(0) - expert_id = int(exp_string.split(".")[2]) - name = name.replace(exp_string, ".experts.") - - else: - if ("block_sparse_moe.experts." in name - and name not in params_dict): - continue - - param = params_dict[name] - - if self.use_fused_moe and shard_id is not None: - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, loaded_weight, name, shard_id, - expert_id, True) - else: - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, loaded_weight) + if ("block_sparse_moe.experts." 
in name + and name not in params_dict): + continue + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) From 7da678eb61bcd50c0b51d30d899e319d9e255d5a Mon Sep 17 00:00:00 2001 From: DhruvaBansal00 Date: Tue, 6 Aug 2024 18:18:26 -0700 Subject: [PATCH 02/96] Fixing tests --- tests/kernels/test_moe.py | 86 ++++++++++--------- .../layers/fused_moe/fused_moe_gptq.py | 26 ++++-- 2 files changed, 63 insertions(+), 49 deletions(-) diff --git a/tests/kernels/test_moe.py b/tests/kernels/test_moe.py index e73e5a518ef1a..d9480c8cf882e 100644 --- a/tests/kernels/test_moe.py +++ b/tests/kernels/test_moe.py @@ -10,10 +10,10 @@ from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock from vllm.model_executor.layers.activation import SiluAndMul -from vllm.model_executor.layers.fused_moe import (fused_marlin_moe, fused_moe, - single_marlin_moe) +from vllm.model_executor.layers.fused_moe import fused_moe, single_marlin_moe +from vllm.model_executor.layers.fused_moe.fused_moe_gptq import fused_moe_gptq from vllm.model_executor.layers.quantization.utils.marlin_utils_test import ( - marlin_quantize) + marlin_quantize, ) from vllm.model_executor.models.mixtral import MixtralMoE @@ -62,11 +62,11 @@ def test_fused_moe( topk: int, dtype: torch.dtype, ): - a = torch.randn((m, k), device='cuda', dtype=dtype) / 10 - w1 = torch.randn((e, 2 * n, k), device='cuda', dtype=dtype) / 10 - w2 = torch.randn((e, k, n), device='cuda', dtype=dtype) / 10 + a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 + w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 + w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 - score = torch.randn((m, e), device='cuda', dtype=dtype) + score = torch.randn((m, e), device="cuda", dtype=dtype) triton_output = fused_moe(a, w1, w2, score, topk, renormalize=False) torch_output = torch_moe(a, w1, w2, score, topk) assert torch.allclose(triton_output, torch_output, atol=1e-2, rtol=0) @@ -114,10 +114,12 @@ def test_mixtral_moe(dtype: torch.dtype): torch.bfloat16: 1e-2, } - assert torch.allclose(hf_states.flatten(0, 1), - vllm_states, - rtol=mixtral_moe_tol[dtype], - atol=mixtral_moe_tol[dtype]) + assert torch.allclose( + hf_states.flatten(0, 1), + vllm_states, + rtol=mixtral_moe_tol[dtype], + atol=mixtral_moe_tol[dtype], + ) def stack_and_dev(tensors: List[torch.Tensor]): @@ -165,11 +167,11 @@ def test_fused_marlin_moe( num_bits = 4 dtype = torch.float16 - a = torch.randn((m, k), device='cuda', dtype=dtype) / 10 - w1 = torch.randn((e, 2 * n, k), device='cuda', dtype=dtype) / 10 - w2 = torch.randn((e, k, n), device='cuda', dtype=dtype) / 10 + a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 + w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 + w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 for i in range(w2.shape[0]): - w2[0] = torch.eye(k, n, device='cuda', dtype=dtype) + w2[0] = torch.eye(k, n, device="cuda", dtype=dtype) w_ref1_l = [] qweight1_l = [] @@ -215,27 +217,31 @@ def test_fused_marlin_moe( g_idx2 = stack_and_dev(g_idx2_l) sort_indices2 = stack_and_dev(sort_indices2_l) - score = torch.randn((m, e), device='cuda', dtype=dtype) - triton_output = fused_moe(a, - w_ref1.transpose(1, 2).contiguous(), - w_ref2.transpose(1, 2).contiguous(), - score, - topk, - renormalize=False) - marlin_output = fused_marlin_moe(a, - qweight1, - qweight2, - score, - g_idx1, - g_idx2, - sort_indices1, - sort_indices2, - topk, - renormalize=False, - w1_scale=scales1, - 
w2_scale=scales2) - - assert (compute_max_diff(marlin_output, triton_output) < 4e-2) + score = torch.randn((m, e), device="cuda", dtype=dtype) + triton_output = fused_moe( + a, + w_ref1.transpose(1, 2).contiguous(), + w_ref2.transpose(1, 2).contiguous(), + score, + topk, + renormalize=False, + ) + marlin_output = fused_moe_gptq( + a, + qweight1, + qweight2, + score, + g_idx1, + g_idx2, + sort_indices1, + sort_indices2, + topk, + renormalize=False, + w1_scale=scales1, + w2_scale=scales2, + ) + + assert compute_max_diff(marlin_output, triton_output) < 4e-2 # TODO: make sure this test works @@ -272,8 +278,8 @@ def test_single_marlin_moe( num_bits = 4 dtype = torch.float16 - a = torch.randn((m, k), device='cuda', dtype=dtype) / 10 - w = torch.randn((e, n, k), device='cuda', dtype=dtype) / 10 + a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 + w = torch.randn((e, n, k), device="cuda", dtype=dtype) / 10 w_ref_l = [] qweights_l = [] @@ -297,7 +303,7 @@ def test_single_marlin_moe( g_idx = stack_and_dev(g_idx_l) sort_indices = stack_and_dev(sort_indices_l) - score = torch.randn((m, e), device='cuda', dtype=dtype) + score = torch.randn((m, e), device="cuda", dtype=dtype) marlin_output = single_marlin_moe(a, qweight, scales, @@ -308,4 +314,4 @@ def test_single_marlin_moe( renormalize=False) torch_output = torch_moe_single(a, w_ref.transpose(1, 2), score, topk) - assert (compute_max_diff(marlin_output, torch_output) < 1e-2) + assert compute_max_diff(marlin_output, torch_output) < 1e-2 diff --git a/vllm/model_executor/layers/fused_moe/fused_moe_gptq.py b/vllm/model_executor/layers/fused_moe/fused_moe_gptq.py index 15c11fc0b668e..e7c47f14f85d4 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe_gptq.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe_gptq.py @@ -51,19 +51,25 @@ def fused_moe_gptq( - torch.Tensor: The output tensor after applying the MoE layer. """ # Check constraints. 
- assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch" - assert hidden_states.shape[1] == w1.shape[1] * 16, "Hidden size mismatch w1" - assert hidden_states.shape[1] == w2.shape[2] // 2, "Hidden size mismatch w2" + assert hidden_states.shape[0] == gating_output.shape[ + 0], "Number of tokens mismatch" + assert hidden_states.shape[ + 1] == w1.shape[1] * 16, "Hidden size mismatch w1" + assert hidden_states.shape[ + 1] == w2.shape[2] // 2, "Hidden size mismatch w2" assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch" assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" assert w1.is_contiguous(), "Expert weights1 must be contiguous" assert w2.is_contiguous(), "Expert weights2 must be contiguous" - assert hidden_states.dtype in [torch.float32, torch.float16, torch.bfloat16] + assert hidden_states.dtype in [ + torch.float32, torch.float16, torch.bfloat16 + ] M, K = hidden_states.shape E = w1.shape[0] N = w2.shape[1] * 16 - topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk, renormalize) + topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk, + renormalize) get_config_func = functools.partial( try_get_optimal_moe_config, @@ -81,9 +87,10 @@ def fused_moe_gptq( sorted_token_ids, _, _ = moe_align_block_size(topk_ids, block_size_m, E) max_workspace_size = ((M + 255) // 256) * (max(2 * N, K) // 64) * 16 - workspace = torch.zeros( - max_workspace_size, dtype=torch.int, device="cuda", requires_grad=False - ) + workspace = torch.zeros(max_workspace_size, + dtype=torch.int, + device="cuda", + requires_grad=False) intermediate_cache2 = torch.empty( (M * topk_ids.shape[1], N), @@ -135,4 +142,5 @@ def fused_moe_gptq( True, ) - return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), dim=1) + return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), + dim=1) From 641696b8608843e12c5852dd33c7c6322ba5d297 Mon Sep 17 00:00:00 2001 From: DhruvaBansal00 Date: Thu, 8 Aug 2024 10:16:19 -0700 Subject: [PATCH 03/96] Addressing repacking comment --- .../layers/quantization/gptq_marlin.py | 221 +++++++----------- .../layers/quantization/utils/marlin_utils.py | 17 ++ 2 files changed, 98 insertions(+), 140 deletions(-) diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py index f58a89c8e4bb9..90ffbd4360f9b 100644 --- a/vllm/model_executor/layers/quantization/gptq_marlin.py +++ b/vllm/model_executor/layers/quantization/gptq_marlin.py @@ -21,6 +21,7 @@ marlin_make_empty_g_idx, marlin_make_workspace, marlin_permute_scales, + marlin_moe_permute_scales, marlin_repeat_scales_on_all_ranks, marlin_sort_g_idx, replace_tensor, @@ -469,6 +470,86 @@ def create_weights( set_weight_attrs(w2_g_idx_sort_indices, extra_weight_attrs) layer.marlin_state = GPTQMarlinState.REPACK + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + layer.marlin_state = GPTQMarlinState.READY + + # Process act_order + if self.quant_config.desc_act: + # Get sorting based on g_idx + num_experts = layer.w13_g_idx.shape[0] + w13_g_idx_sort_indices = torch.empty_like(layer.w13_g_idx) + w2_g_idx_sort_indices = torch.empty_like(layer.w2_g_idx) + w13_sorted_g_idx = torch.empty_like(layer.w13_g_idx) + w2_sorted_g_idx = torch.empty_like(layer.w2_g_idx) + for e in range(num_experts): + w13_g_idx_sort_indices[e] = torch.argsort( + layer.w13_g_idx[e]).to(torch.int32) + w2_g_idx_sort_indices[e] = torch.argsort(layer.w2_g_idx[e]).to( + 
torch.int32) + w13_sorted_g_idx[e] = layer.w13_g_idx[e][ + w13_g_idx_sort_indices[e]] + w2_sorted_g_idx[e] = layer.w2_g_idx[e][ + w2_g_idx_sort_indices[e]] + replace_tensor(layer, "w13_g_idx", w13_sorted_g_idx) + replace_tensor(layer, "w2_g_idx", w2_sorted_g_idx) + replace_tensor(layer, "w13_g_idx_sort_indices", + w13_g_idx_sort_indices) + replace_tensor(layer, "w2_g_idx_sort_indices", + w2_g_idx_sort_indices) + else: + # Reset g_idx related tensors + num_experts = layer.w13_g_idx.shape[0] + device = layer.w13_g_idx.device + layer.w13_g_idx = torch.nn.Parameter( + torch.empty((num_experts, 0), dtype=torch.int32, + device=device), + requires_grad=False, + ) + layer.w2_g_idx = torch.nn.Parameter( + torch.empty((num_experts, 0), dtype=torch.int32, + device=device), + requires_grad=False, + ) + layer.w13_g_idx_sort_indices = torch.nn.Parameter( + torch.empty((num_experts, 0), dtype=torch.int32, + device=device), + requires_grad=False, + ) + layer.w2_g_idx_sort_indices = torch.nn.Parameter( + torch.empty((num_experts, 0), dtype=torch.int32, + device=device), + requires_grad=False, + ) + # Repack weights + marlin_w13_qweight = ops.gptq_marlin_moe_repack( + layer.w13_qweight, + layer.w13_g_idx_sort_indices, + layer.w13_qweight.shape[1] * self.quant_config.pack_factor, + layer.w13_qweight.shape[2], + self.quant_config.weight_bits, + ) + replace_tensor(layer, "w13_qweight", marlin_w13_qweight) + marlin_w2_qweight = ops.gptq_marlin_moe_repack( + layer.w2_qweight, layer.w2_g_idx_sort_indices, + layer.w2_qweight.shape[1] * self.quant_config.pack_factor, + layer.w2_qweight.shape[2]) + replace_tensor(layer, "w2_qweight", marlin_w2_qweight) + # Repack scales + marlin_w13_scales = marlin_moe_permute_scales( + s=layer.w13_scales, + size_k=(layer.input_size if self.quant_config.desc_act else + layer.input_size_per_partition), + size_n=layer.w13_scales.shape[2], + group_size=self.quant_config.group_size) + replace_tensor(layer, "w13_scales", marlin_w13_scales) + marlin_w2_scales = marlin_moe_permute_scales( + s=layer.w2_scales, + size_k=layer.w2_scales.shape[1] * self.quant_config.pack_factor, + size_n=(layer.input_size if self.quant_config.desc_act else + layer.input_size_per_partition), + group_size=self.quant_config.group_size) + replace_tensor(layer, "w2_scales", marlin_w2_scales) + def apply( self, layer: torch.nn.Module, @@ -480,146 +561,6 @@ def apply( num_expert_group: Optional[int] = None, topk_group: Optional[int] = None, ) -> torch.Tensor: - if layer.marlin_state == GPTQMarlinState.REPACK: - layer.marlin_state = GPTQMarlinState.READY - - # Newly generated tensors need to replace existing tensors that are - # already registered as parameters by vLLM (and won't be freed) - def replace_tensor(name, new_t): - # It is important to use resize_() here since it ensures - # the same buffer is reused - getattr(layer, name).resize_(new_t.shape) - getattr(layer, name).copy_(new_t) - del new_t - - def get_scale_perms(num_bits: int): - scale_perm: List[int] = [] - for i in range(8): - scale_perm.extend([i + 8 * j for j in range(8)]) - scale_perm_single: List[int] = [] - for i in range(4): - scale_perm_single.extend( - [2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]]) - return scale_perm, scale_perm_single - - def marlin_permute_scales( - s: torch.Tensor, - size_k: int, - size_n: int, - group_size: int, - num_bits: int, - ): - scale_perm, scale_perm_single = get_scale_perms(num_bits) - if group_size < size_k and group_size != -1: - s = s.reshape((-1, len(scale_perm)))[:, scale_perm] - else: - s = s.reshape( - 
(-1, len(scale_perm_single)))[:, scale_perm_single] - s = s.reshape((-1, size_n)).contiguous() - return s - - def marlin_moe_permute_scales( - s: torch.Tensor, - size_k: int, - size_n: int, - group_size: int, - num_bits: int, - ): - num_experts = s.shape[0] - output = torch.empty( - (num_experts, s.shape[1], s.shape[2]), - device=s.device, - dtype=s.dtype, - ) - for e in range(num_experts): - output[e] = marlin_permute_scales(s[e], size_k, size_n, - group_size, num_bits) - return output - - # Process act_order - if self.quant_config.desc_act: - # Get sorting based on g_idx - num_experts = layer.w13_g_idx.shape[0] - w13_g_idx_sort_indices = torch.empty_like(layer.w13_g_idx) - w2_g_idx_sort_indices = torch.empty_like(layer.w2_g_idx) - w13_sorted_g_idx = torch.empty_like(layer.w13_g_idx) - w2_sorted_g_idx = torch.empty_like(layer.w2_g_idx) - for e in range(num_experts): - w13_g_idx_sort_indices[e] = torch.argsort( - layer.w13_g_idx[e]).to(torch.int32) - w2_g_idx_sort_indices[e] = torch.argsort( - layer.w2_g_idx[e]).to(torch.int32) - w13_sorted_g_idx[e] = layer.w13_g_idx[e][ - w13_g_idx_sort_indices[e]] - w2_sorted_g_idx[e] = layer.w2_g_idx[e][ - w2_g_idx_sort_indices[e]] - replace_tensor("w13_g_idx", w13_sorted_g_idx) - replace_tensor("w2_g_idx", w2_sorted_g_idx) - replace_tensor("w13_g_idx_sort_indices", - w13_g_idx_sort_indices) - replace_tensor("w2_g_idx_sort_indices", w2_g_idx_sort_indices) - else: - # Reset g_idx related tensors - num_experts = layer.w13_g_idx.shape[0] - device = layer.w13_g_idx.device - layer.w13_g_idx = torch.nn.Parameter( - torch.empty((num_experts, 0), - dtype=torch.int32, - device=device), - requires_grad=False, - ) - layer.w2_g_idx = torch.nn.Parameter( - torch.empty((num_experts, 0), - dtype=torch.int32, - device=device), - requires_grad=False, - ) - layer.w13_g_idx_sort_indices = torch.nn.Parameter( - torch.empty((num_experts, 0), - dtype=torch.int32, - device=device), - requires_grad=False, - ) - layer.w2_g_idx_sort_indices = torch.nn.Parameter( - torch.empty((num_experts, 0), - dtype=torch.int32, - device=device), - requires_grad=False, - ) - # Repack weights - marlin_w13_qweight = ops.gptq_marlin_moe_repack( - layer.w13_qweight, - layer.w13_g_idx_sort_indices, - layer.w13_qweight.shape[1] * self.quant_config.pack_factor, - layer.w13_qweight.shape[2], - self.quant_config.weight_bits, - ) - replace_tensor("w13_qweight", marlin_w13_qweight) - marlin_w2_qweight = ops.gptq_marlin_moe_repack( - layer.w2_qweight, - layer.w2_g_idx_sort_indices, - layer.w2_qweight.shape[1] * self.quant_config.pack_factor, - layer.w2_qweight.shape[2], - self.quant_config.weight_bits, - ) - replace_tensor("w2_qweight", marlin_w2_qweight) - # Repack scales - marlin_w13_scales = marlin_moe_permute_scales( - layer.w13_scales, - x.shape[1], - layer.w13_scales.shape[2], - self.quant_config.group_size, - self.quant_config.weight_bits, - ) - replace_tensor("w13_scales", marlin_w13_scales) - marlin_w2_scales = marlin_moe_permute_scales( - layer.w2_scales, - layer.w2_scales.shape[1] * self.quant_config.pack_factor, - x.shape[1], - self.quant_config.group_size, - self.quant_config.weight_bits, - ) - replace_tensor("w2_scales", marlin_w2_scales) return fused_moe_gptq( x, layer.w13_qweight, diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils.py b/vllm/model_executor/layers/quantization/utils/marlin_utils.py index b789ca20cadb3..610650a744986 100644 --- a/vllm/model_executor/layers/quantization/utils/marlin_utils.py +++ 
b/vllm/model_executor/layers/quantization/utils/marlin_utils.py @@ -181,6 +181,23 @@ def marlin_permute_scales(s: torch.Tensor, size_k: int, size_n: int, return s +def marlin_moe_permute_scales( + s: torch.Tensor, + size_k: int, + size_n: int, + group_size: int, +): + num_experts = s.shape[0] + output = torch.empty( + (num_experts, s.shape[1], s.shape[2]), + device=s.device, + dtype=s.dtype, + ) + for e in range(num_experts): + output[e] = marlin_permute_scales(s[e], size_k, size_n, group_size) + return output + + def marlin_zero_points(zp: torch.Tensor, size_k: int, size_n: int, num_bits: int) -> torch.Tensor: # Permute zero-points in a similar way to scales, but do not use the From 3cef6678e7d0ee54d05cd95b9e91b8f691bed8a8 Mon Sep 17 00:00:00 2001 From: DhruvaBansal00 Date: Thu, 8 Aug 2024 10:20:13 -0700 Subject: [PATCH 04/96] gptq -> marlin renaming --- tests/kernels/test_moe.py | 46 +++--- .../layers/fused_moe/__init__.py | 4 +- ...{fused_moe_gptq.py => fused_moe_marlin.py} | 28 ++-- .../layers/quantization/gptq_marlin.py | 155 +++++++++--------- 4 files changed, 113 insertions(+), 120 deletions(-) rename vllm/model_executor/layers/fused_moe/{fused_moe_gptq.py => fused_moe_marlin.py} (84%) diff --git a/tests/kernels/test_moe.py b/tests/kernels/test_moe.py index d9480c8cf882e..856ee7c56e598 100644 --- a/tests/kernels/test_moe.py +++ b/tests/kernels/test_moe.py @@ -11,9 +11,10 @@ from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import fused_moe, single_marlin_moe -from vllm.model_executor.layers.fused_moe.fused_moe_gptq import fused_moe_gptq +from vllm.model_executor.layers.fused_moe.fused_moe_marlin import fused_moe_marlin from vllm.model_executor.layers.quantization.utils.marlin_utils_test import ( - marlin_quantize, ) + marlin_quantize, +) from vllm.model_executor.models.mixtral import MixtralMoE @@ -28,10 +29,12 @@ def torch_moe(a, w1, w2, score, topk): for i in range(w1.shape[0]): mask = topk_ids == i if mask.sum(): - out[mask] = SiluAndMul()( - a[mask] @ w1[i].transpose(0, 1)) @ w2[i].transpose(0, 1) - return (out.view(B, -1, w2.shape[1]) * - topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1) + out[mask] = SiluAndMul()(a[mask] @ w1[i].transpose(0, 1)) @ w2[i].transpose( + 0, 1 + ) + return ( + out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1).to(out.dtype) + ).sum(dim=1) def torch_moe_single(a, w, score, topk): @@ -72,8 +75,7 @@ def test_fused_moe( assert torch.allclose(triton_output, torch_output, atol=1e-2, rtol=0) -@pytest.mark.parametrize("dtype", - [torch.float32, torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16]) @torch.inference_mode() def test_mixtral_moe(dtype: torch.dtype): """Make sure our Mixtral MoE implementation agrees with the one from @@ -94,8 +96,7 @@ def test_mixtral_moe(dtype: torch.dtype): # Load the weights vllm_moe.gate.weight.data[:] = hf_moe.gate.weight.data for i in range(config.num_local_experts): - weights = (hf_moe.experts[i].w1.weight.data, - hf_moe.experts[i].w3.weight.data) + weights = (hf_moe.experts[i].w1.weight.data, hf_moe.experts[i].w3.weight.data) vllm_moe.experts.w13_weight[i][:] = torch.cat(weights, dim=0) vllm_moe.experts.w2_weight[i][:] = hf_moe.experts[i].w2.weight.data @@ -129,7 +130,8 @@ def stack_and_dev(tensors: List[torch.Tensor]): def compute_max_diff(output, output_ref): return torch.mean(torch.abs(output - output_ref)) / torch.mean( - torch.abs(output_ref)) + torch.abs(output_ref) + ) # TODO: make sure this 
test works @@ -182,7 +184,8 @@ def test_fused_marlin_moe( for i in range(w1.shape[0]): test_perm = torch.randperm(k) w_ref1, qweight1, scales1, g_idx1, sort_indices1, _ = marlin_quantize( - w1[i].transpose(1, 0), num_bits, group_size, act_order, test_perm) + w1[i].transpose(1, 0), num_bits, group_size, act_order, test_perm + ) w_ref1_l.append(w_ref1) qweight1_l.append(qweight1) scales1_l.append(scales1) @@ -204,7 +207,8 @@ def test_fused_marlin_moe( for i in range(w2.shape[0]): test_perm = torch.randperm(n) w_ref2, qweight2, scales2, g_idx2, sort_indices2, _ = marlin_quantize( - w2[i].transpose(1, 0), num_bits, group_size, act_order, test_perm) + w2[i].transpose(1, 0), num_bits, group_size, act_order, test_perm + ) w_ref2_l.append(w_ref2) qweight2_l.append(qweight2) scales2_l.append(scales2) @@ -226,7 +230,7 @@ def test_fused_marlin_moe( topk, renormalize=False, ) - marlin_output = fused_moe_gptq( + marlin_output = fused_moe_marlin( a, qweight1, qweight2, @@ -290,7 +294,8 @@ def test_single_marlin_moe( for i in range(w.shape[0]): test_perm = torch.randperm(k) w_ref, qweight, scales, g_idx, sort_indices, _ = marlin_quantize( - w[i].transpose(1, 0), num_bits, group_size, act_order, test_perm) + w[i].transpose(1, 0), num_bits, group_size, act_order, test_perm + ) w_ref_l.append(w_ref) qweights_l.append(qweight) scales_l.append(scales) @@ -304,14 +309,9 @@ def test_single_marlin_moe( sort_indices = stack_and_dev(sort_indices_l) score = torch.randn((m, e), device="cuda", dtype=dtype) - marlin_output = single_marlin_moe(a, - qweight, - scales, - score, - g_idx, - sort_indices, - topk, - renormalize=False) + marlin_output = single_marlin_moe( + a, qweight, scales, score, g_idx, sort_indices, topk, renormalize=False + ) torch_output = torch_moe_single(a, w_ref.transpose(1, 2), score, topk) assert compute_max_diff(marlin_output, torch_output) < 1e-2 diff --git a/vllm/model_executor/layers/fused_moe/__init__.py b/vllm/model_executor/layers/fused_moe/__init__.py index 2b982b7ab9f86..beb94f10a557e 100644 --- a/vllm/model_executor/layers/fused_moe/__init__.py +++ b/vllm/model_executor/layers/fused_moe/__init__.py @@ -1,4 +1,4 @@ -from vllm.model_executor.layers.fused_moe.fused_moe_gptq import fused_moe_gptq +from vllm.model_executor.layers.fused_moe.fused_moe_marlin import fused_moe_marlin from vllm.model_executor.layers.fused_moe.fused_moe import single_marlin_moe from vllm.model_executor.layers.fused_moe.layer import FusedMoE, FusedMoEMethodBase from vllm.triton_utils import HAS_TRITON @@ -6,7 +6,7 @@ __all__ = [ "FusedMoE", "FusedMoEMethodBase", - "fused_moe_gptq", + "fused_moe_marlin", "single_marlin_moe", ] diff --git a/vllm/model_executor/layers/fused_moe/fused_moe_gptq.py b/vllm/model_executor/layers/fused_moe/fused_moe_marlin.py similarity index 84% rename from vllm/model_executor/layers/fused_moe/fused_moe_gptq.py rename to vllm/model_executor/layers/fused_moe/fused_moe_marlin.py index e7c47f14f85d4..4ffcda6f85d5e 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe_gptq.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe_marlin.py @@ -8,7 +8,7 @@ from .fused_moe import fused_topk, moe_align_block_size, try_get_optimal_moe_config -def fused_moe_gptq( +def fused_moe_marlin( hidden_states: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, @@ -51,25 +51,19 @@ def fused_moe_gptq( - torch.Tensor: The output tensor after applying the MoE layer. """ # Check constraints. 
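For context on the renamed kernel: the assertions that follow spell out fused_moe_marlin's shape contract, namely K = w1.shape[1] * 16, N = w2.shape[1] * 16, w2.shape[2] == 2 * K, and a gating_output of shape (num_tokens, num_experts). The CPU-only sketch below only restates that bookkeeping; the helper name, the dummy sizes, and the last dimension of w1 (which the sketch does not check) are illustrative assumptions, not part of this patch.

    import torch

    def check_marlin_moe_shapes(hidden_states, w1, w2, gating_output):
        # Derive the sizes the same way fused_moe_marlin does.
        M, K = hidden_states.shape              # tokens, hidden size
        E = w1.shape[0]                         # number of experts
        N = w2.shape[1] * 16                    # intermediate size
        assert K == w1.shape[1] * 16, "Hidden size mismatch w1"
        assert K == w2.shape[2] // 2, "Hidden size mismatch w2"
        assert gating_output.shape == (M, E), "Token/expert count mismatch"
        return M, K, E, N

    # Metadata-only dummies with consistent shapes; the contents are never read.
    M, K, N, E = 4, 64, 96, 8
    hidden_states = torch.zeros(M, K, dtype=torch.float16)
    w1 = torch.zeros(E, K // 16, 1, dtype=torch.int32)    # last dim is a placeholder
    w2 = torch.zeros(E, N // 16, 2 * K, dtype=torch.int32)
    gating_output = torch.zeros(M, E, dtype=torch.float16)
    print(check_marlin_moe_shapes(hidden_states, w1, w2, gating_output))  # (4, 64, 8, 96)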
- assert hidden_states.shape[0] == gating_output.shape[ - 0], "Number of tokens mismatch" - assert hidden_states.shape[ - 1] == w1.shape[1] * 16, "Hidden size mismatch w1" - assert hidden_states.shape[ - 1] == w2.shape[2] // 2, "Hidden size mismatch w2" + assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch" + assert hidden_states.shape[1] == w1.shape[1] * 16, "Hidden size mismatch w1" + assert hidden_states.shape[1] == w2.shape[2] // 2, "Hidden size mismatch w2" assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch" assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" assert w1.is_contiguous(), "Expert weights1 must be contiguous" assert w2.is_contiguous(), "Expert weights2 must be contiguous" - assert hidden_states.dtype in [ - torch.float32, torch.float16, torch.bfloat16 - ] + assert hidden_states.dtype in [torch.float32, torch.float16, torch.bfloat16] M, K = hidden_states.shape E = w1.shape[0] N = w2.shape[1] * 16 - topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk, - renormalize) + topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk, renormalize) get_config_func = functools.partial( try_get_optimal_moe_config, @@ -87,10 +81,9 @@ def fused_moe_gptq( sorted_token_ids, _, _ = moe_align_block_size(topk_ids, block_size_m, E) max_workspace_size = ((M + 255) // 256) * (max(2 * N, K) // 64) * 16 - workspace = torch.zeros(max_workspace_size, - dtype=torch.int, - device="cuda", - requires_grad=False) + workspace = torch.zeros( + max_workspace_size, dtype=torch.int, device="cuda", requires_grad=False + ) intermediate_cache2 = torch.empty( (M * topk_ids.shape[1], N), @@ -142,5 +135,4 @@ def fused_moe_gptq( True, ) - return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), - dim=1) + return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), dim=1) diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py index 90ffbd4360f9b..c81517823cbf2 100644 --- a/vllm/model_executor/layers/quantization/gptq_marlin.py +++ b/vllm/model_executor/layers/quantization/gptq_marlin.py @@ -12,7 +12,7 @@ set_weight_attrs, ) from vllm.model_executor.layers.fused_moe.layer import FusedMoE, FusedMoEMethodBase -from vllm.model_executor.layers.fused_moe.fused_moe_gptq import fused_moe_gptq +from vllm.model_executor.layers.fused_moe.fused_moe_marlin import fused_moe_marlin from vllm.model_executor.layers.quantization.base_config import QuantizationConfig from vllm.model_executor.layers.quantization.utils.marlin_utils import ( apply_gptq_marlin_linear, @@ -62,15 +62,17 @@ def __init__( self.lm_head_quantized = lm_head_quantized # Verify supported on platform. 
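In practice this constructor is reached through GPTQMarlinConfig.from_config, shown a little further down, which reads the checkpoint's quantize_config. A minimal sketch of such a dict follows; the "bits" key name is an assumption (it is not visible in this hunk), while group_size, desc_act, sym, and the optional lm_head key mirror the from_config body.

    # Hypothetical AutoGPTQ-style quantize_config; the "bits" key name is assumed.
    quant_cfg = {
        "bits": 4,          # -> weight_bits
        "group_size": 128,  # -> group_size
        "desc_act": True,   # -> desc_act (activation reordering)
        "sym": True,        # -> is_sym
        "lm_head": False,   # optional; defaults to False when absent
    }

    from vllm.model_executor.layers.quantization.gptq_marlin import GPTQMarlinConfig

    config = GPTQMarlinConfig.from_config(quant_cfg)
    print(config)  # GPTQMarlinConfig(weight_bits=4, group_size=128, desc_act=True, ...)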
- verify_gptq_marlin_supported(num_bits=self.weight_bits, - group_size=self.group_size, - is_sym=self.is_sym) + verify_gptq_marlin_supported( + num_bits=self.weight_bits, group_size=self.group_size, is_sym=self.is_sym + ) def __repr__(self) -> str: - return (f"GPTQMarlinConfig(weight_bits={self.weight_bits}, " - f"group_size={self.group_size}, " - f"desc_act={self.desc_act}, " - f"lm_head_quantized={self.lm_head_quantized})") + return ( + f"GPTQMarlinConfig(weight_bits={self.weight_bits}, " + f"group_size={self.group_size}, " + f"desc_act={self.desc_act}, " + f"lm_head_quantized={self.lm_head_quantized})" + ) @classmethod def get_name(cls) -> str: @@ -94,37 +96,40 @@ def from_config(cls, config: Dict[str, Any]) -> "GPTQMarlinConfig": group_size = cls.get_from_keys(config, ["group_size"]) desc_act = cls.get_from_keys(config, ["desc_act"]) is_sym = cls.get_from_keys(config, ["sym"]) - lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"], - default=False) - return cls(weight_bits, group_size, desc_act, is_sym, - lm_head_quantized) + lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"], default=False) + return cls(weight_bits, group_size, desc_act, is_sym, lm_head_quantized) @classmethod - def override_quantization_method(cls, hf_quant_cfg, - user_quant) -> Optional[str]: + def override_quantization_method(cls, hf_quant_cfg, user_quant) -> Optional[str]: can_convert = cls.is_gptq_marlin_compatible(hf_quant_cfg) - is_valid_user_quant = (user_quant is None or user_quant == "marlin" - or user_quant == "gptq_marlin") + is_valid_user_quant = ( + user_quant is None or user_quant == "marlin" or user_quant == "gptq_marlin" + ) if can_convert and is_valid_user_quant: - msg = ("The model is convertible to {} during runtime." - " Using {} kernel.".format(cls.get_name(), cls.get_name())) + msg = ( + "The model is convertible to {} during runtime." + " Using {} kernel.".format(cls.get_name(), cls.get_name()) + ) logger.info(msg) return cls.get_name() if can_convert and user_quant == "gptq": - logger.info("Detected that the model can run with gptq_marlin" - ", however you specified quantization=gptq explicitly," - " so forcing gptq. Use quantization=gptq_marlin for" - " faster inference") + logger.info( + "Detected that the model can run with gptq_marlin" + ", however you specified quantization=gptq explicitly," + " so forcing gptq. Use quantization=gptq_marlin for" + " faster inference" + ) return None def get_quant_method( self, layer: torch.nn.Module, prefix: str ) -> Optional[Union["GPTQMarlinLinearMethod", "GPTQMarlinMoEMethod"]]: - if isinstance(layer, LinearBase) or (isinstance(layer, ParallelLMHead) - and self.lm_head_quantized): + if isinstance(layer, LinearBase) or ( + isinstance(layer, ParallelLMHead) and self.lm_head_quantized + ): return GPTQMarlinLinearMethod(self) elif isinstance(layer, FusedMoE): return GPTQMarlinMoEMethod(self) @@ -195,9 +200,9 @@ def create_weights( ) # Determine sharding - if marlin_repeat_scales_on_all_ranks(self.quant_config.desc_act, - self.quant_config.group_size, - is_row_parallel): + if marlin_repeat_scales_on_all_ranks( + self.quant_config.desc_act, self.quant_config.group_size, is_row_parallel + ): # By setting scale_dim == None, weight_loader will # repeat the scales on each GPU in TP>1 case. scales_and_zp_input_dim = None @@ -239,10 +244,7 @@ def create_weights( # Ignore warning from fused linear layers such as QKVParallelLinear. 
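A quick aside on the set_weight_attrs calls used throughout create_weights: as I read the loader contract, the attached attributes (input_dim, output_dim, packed_dim, pack_factor, ignore_warning, ...) are metadata the weight loader consults when sharding and copying checkpoint tensors. The stand-in below is a simplified sketch, not the vLLM implementation, and the placeholder extra_weight_attrs dict is likewise assumed (in vLLM it typically carries the weight_loader hook).

    import torch

    def set_weight_attrs_sketch(param: torch.nn.Parameter, attrs: dict) -> None:
        # Simplified stand-in: attach metadata to the parameter object so the
        # weight loader can look it up later.
        for key, value in attrs.items():
            setattr(param, key, value)

    extra_weight_attrs = {}  # placeholder; typically carries the weight_loader hook

    g_idx = torch.nn.Parameter(torch.empty(128, dtype=torch.int32), requires_grad=False)
    set_weight_attrs_sketch(g_idx, {
        **extra_weight_attrs,
        "input_dim": 0,          # g_idx is indexed along the input (K) dimension
        "ignore_warning": True,  # per the comment above: silence fused-layer warnings
    })
    print(g_idx.input_dim, g_idx.ignore_warning)  # 0 True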
set_weight_attrs( g_idx, - { - **extra_weight_attrs, "input_dim": 0, - "ignore_warning": True - }, + {**extra_weight_attrs, "input_dim": 0, "ignore_warning": True}, ) # Scales @@ -291,8 +293,7 @@ def create_weights( layer.input_size_per_partition = input_size_per_partition layer.output_size_per_partition = output_size_per_partition layer.input_size = input_size - layer.is_k_full = marlin_is_k_full(self.quant_config.desc_act, - is_row_parallel) + layer.is_k_full = marlin_is_k_full(self.quant_config.desc_act, is_row_parallel) # Checkpoints are serialized in AutoGPTQ format, which is different from the # marlin format. This function is called after the weights are loaded. @@ -301,8 +302,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: device = layer.qweight.device # Allocate marlin workspace - layer.workspace = marlin_make_workspace( - layer.output_size_per_partition, device) + layer.workspace = marlin_make_workspace(layer.output_size_per_partition, device) # Handle sorting for activation reordering if needed. if self.quant_config.desc_act: @@ -329,8 +329,11 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: # Permute scales from autogptq format to marlin format. marlin_scales = marlin_permute_scales( layer.scales, - size_k=(layer.input_size if self.quant_config.desc_act else - layer.input_size_per_partition), + size_k=( + layer.input_size + if self.quant_config.desc_act + else layer.input_size_per_partition + ), size_n=layer.output_size_per_partition, group_size=self.quant_config.group_size, ) @@ -408,20 +411,16 @@ def create_weights( set_weight_attrs(w2_qweight, extra_weight_attrs) # up_proj scales w13_scales = torch.nn.Parameter( - torch.empty(num_experts, - scales_size13, - 2 * intermediate_size, - dtype=params_dtype), + torch.empty( + num_experts, scales_size13, 2 * intermediate_size, dtype=params_dtype + ), requires_grad=False, ) layer.register_parameter("w13_scales", w13_scales) set_weight_attrs(w13_scales, extra_weight_attrs) # down_proj scales w2_scales = torch.nn.Parameter( - torch.empty(num_experts, - scales_size2, - hidden_size, - dtype=params_dtype), + torch.empty(num_experts, scales_size2, hidden_size, dtype=params_dtype), requires_grad=False, ) layer.register_parameter("w2_scales", w2_scales) @@ -454,8 +453,7 @@ def create_weights( ), requires_grad=False, ) - layer.register_parameter("w13_g_idx_sort_indices", - w13_g_idx_sort_indices) + layer.register_parameter("w13_g_idx_sort_indices", w13_g_idx_sort_indices) set_weight_attrs(w13_g_idx_sort_indices, extra_weight_attrs) w2_g_idx_sort_indices = torch.nn.Parameter( torch.empty( @@ -465,8 +463,7 @@ def create_weights( ), requires_grad=False, ) - layer.register_parameter("w2_g_idx_sort_indices", - w2_g_idx_sort_indices) + layer.register_parameter("w2_g_idx_sort_indices", w2_g_idx_sort_indices) set_weight_attrs(w2_g_idx_sort_indices, extra_weight_attrs) layer.marlin_state = GPTQMarlinState.REPACK @@ -482,42 +479,36 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: w13_sorted_g_idx = torch.empty_like(layer.w13_g_idx) w2_sorted_g_idx = torch.empty_like(layer.w2_g_idx) for e in range(num_experts): - w13_g_idx_sort_indices[e] = torch.argsort( - layer.w13_g_idx[e]).to(torch.int32) + w13_g_idx_sort_indices[e] = torch.argsort(layer.w13_g_idx[e]).to( + torch.int32 + ) w2_g_idx_sort_indices[e] = torch.argsort(layer.w2_g_idx[e]).to( - torch.int32) - w13_sorted_g_idx[e] = layer.w13_g_idx[e][ - w13_g_idx_sort_indices[e]] - w2_sorted_g_idx[e] = layer.w2_g_idx[e][ - 
w2_g_idx_sort_indices[e]] + torch.int32 + ) + w13_sorted_g_idx[e] = layer.w13_g_idx[e][w13_g_idx_sort_indices[e]] + w2_sorted_g_idx[e] = layer.w2_g_idx[e][w2_g_idx_sort_indices[e]] replace_tensor(layer, "w13_g_idx", w13_sorted_g_idx) replace_tensor(layer, "w2_g_idx", w2_sorted_g_idx) - replace_tensor(layer, "w13_g_idx_sort_indices", - w13_g_idx_sort_indices) - replace_tensor(layer, "w2_g_idx_sort_indices", - w2_g_idx_sort_indices) + replace_tensor(layer, "w13_g_idx_sort_indices", w13_g_idx_sort_indices) + replace_tensor(layer, "w2_g_idx_sort_indices", w2_g_idx_sort_indices) else: # Reset g_idx related tensors num_experts = layer.w13_g_idx.shape[0] device = layer.w13_g_idx.device layer.w13_g_idx = torch.nn.Parameter( - torch.empty((num_experts, 0), dtype=torch.int32, - device=device), + torch.empty((num_experts, 0), dtype=torch.int32, device=device), requires_grad=False, ) layer.w2_g_idx = torch.nn.Parameter( - torch.empty((num_experts, 0), dtype=torch.int32, - device=device), + torch.empty((num_experts, 0), dtype=torch.int32, device=device), requires_grad=False, ) layer.w13_g_idx_sort_indices = torch.nn.Parameter( - torch.empty((num_experts, 0), dtype=torch.int32, - device=device), + torch.empty((num_experts, 0), dtype=torch.int32, device=device), requires_grad=False, ) layer.w2_g_idx_sort_indices = torch.nn.Parameter( - torch.empty((num_experts, 0), dtype=torch.int32, - device=device), + torch.empty((num_experts, 0), dtype=torch.int32, device=device), requires_grad=False, ) # Repack weights @@ -530,24 +521,34 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: ) replace_tensor(layer, "w13_qweight", marlin_w13_qweight) marlin_w2_qweight = ops.gptq_marlin_moe_repack( - layer.w2_qweight, layer.w2_g_idx_sort_indices, + layer.w2_qweight, + layer.w2_g_idx_sort_indices, layer.w2_qweight.shape[1] * self.quant_config.pack_factor, - layer.w2_qweight.shape[2]) + layer.w2_qweight.shape[2], + ) replace_tensor(layer, "w2_qweight", marlin_w2_qweight) # Repack scales marlin_w13_scales = marlin_moe_permute_scales( s=layer.w13_scales, - size_k=(layer.input_size if self.quant_config.desc_act else - layer.input_size_per_partition), + size_k=( + layer.input_size + if self.quant_config.desc_act + else layer.input_size_per_partition + ), size_n=layer.w13_scales.shape[2], - group_size=self.quant_config.group_size) + group_size=self.quant_config.group_size, + ) replace_tensor(layer, "w13_scales", marlin_w13_scales) marlin_w2_scales = marlin_moe_permute_scales( s=layer.w2_scales, size_k=layer.w2_scales.shape[1] * self.quant_config.pack_factor, - size_n=(layer.input_size if self.quant_config.desc_act else - layer.input_size_per_partition), - group_size=self.quant_config.group_size) + size_n=( + layer.input_size + if self.quant_config.desc_act + else layer.input_size_per_partition + ), + group_size=self.quant_config.group_size, + ) replace_tensor(layer, "w2_scales", marlin_w2_scales) def apply( @@ -561,7 +562,7 @@ def apply( num_expert_group: Optional[int] = None, topk_group: Optional[int] = None, ) -> torch.Tensor: - return fused_moe_gptq( + return fused_moe_marlin( x, layer.w13_qweight, layer.w2_qweight, From a6710af0ab3bd8e1bf030d8a4f1ff14eb9afab37 Mon Sep 17 00:00:00 2001 From: DhruvaBansal00 Date: Thu, 8 Aug 2024 10:46:20 -0700 Subject: [PATCH 05/96] Undo formatting changes --- vllm/model_executor/layers/fused_moe/layer.py | 31 ++++++------------- 1 file changed, 10 insertions(+), 21 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/layer.py 
b/vllm/model_executor/layers/fused_moe/layer.py index 913d6a93b0cd5..566a0a4f23be6 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -54,15 +54,9 @@ def apply( class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): """MoE method without quantization.""" - def create_weights( - self, - layer: torch.nn.Module, - num_experts: int, - hidden_size: int, - intermediate_size: int, - params_dtype: torch.dtype, - **extra_weight_attrs, - ): + def create_weights(self, layer: torch.nn.Module, num_experts: int, + hidden_size: int, intermediate_size: int, + params_dtype: torch.dtype, **extra_weight_attrs): # Fused gate_up_proj (column parallel) w13_weight = torch.nn.Parameter( torch.empty(num_experts, @@ -158,11 +152,10 @@ def forward_tpu( assert topk_group is None return fused_moe(x, w1, w2, router_logits, top_k, renormalize) - class FusedMoE(torch.nn.Module): """FusedMoE layer for MoE models. - This layer contains both MergedColumnParallel weights (gate_up_proj / + This layer contains both MergedColumnParallel weights (gate_up_proj / w13) and RowParallelLinear weights (down_proj/ w2). Note: Mixtral uses w1, w2, and w3 for gate, up, and down_proj. We @@ -267,8 +260,8 @@ def weight_loader( else: # Input scales can be loaded directly and should be equal. if "input_scale" in weight_name: - if (param_data[expert_id] != 1 and - (param_data[expert_id] - loaded_weight).abs() > 1e-5): + if (param_data[expert_id] != 1 and (param_data[expert_id] - + loaded_weight).abs() > 1e-5): raise ValueError( "input_scales of w1 and w3 of a layer " f"must be equal. But got {param_data[expert_id]} " @@ -322,8 +315,7 @@ def forward(self, hidden_states: torch.Tensor, renormalize=self.renormalize, use_grouped_topk=self.use_grouped_topk, num_expert_group=self.num_expert_group, - topk_group=self.topk_group, - ) + topk_group=self.topk_group) if self.reduce_results and self.tp_size > 1: final_hidden_states = tensor_model_parallel_all_reduce( @@ -333,12 +325,9 @@ def forward(self, hidden_states: torch.Tensor, @classmethod def make_expert_params_mapping( - cls, - ckpt_gate_proj_name: str, - ckpt_down_proj_name: str, - ckpt_up_proj_name: str, - num_experts: int, - ) -> List[Tuple[str, str, int, int]]: + cls, ckpt_gate_proj_name: str, ckpt_down_proj_name: str, + ckpt_up_proj_name: str, + num_experts: int) -> List[Tuple[str, str, int, str]]: gate_up = [ckpt_gate_proj_name, ckpt_up_proj_name] gate_down_up = [ ckpt_gate_proj_name, ckpt_down_proj_name, ckpt_up_proj_name From e29107f1f13c73d57f02bcee7ac8a433536c831d Mon Sep 17 00:00:00 2001 From: DhruvaBansal00 Date: Thu, 8 Aug 2024 10:47:48 -0700 Subject: [PATCH 06/96] Final formatting change --- vllm/model_executor/layers/fused_moe/layer.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 566a0a4f23be6..25c6214318692 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -146,7 +146,6 @@ def forward_tpu( topk_group: Optional[int], ) -> torch.Tensor: from vllm.model_executor.layers.fused_moe.moe_pallas import fused_moe - assert not use_grouped_topk assert num_expert_group is None assert topk_group is None From 099d61e73f5ee36e3ebf1f2c753970efecbba39b Mon Sep 17 00:00:00 2001 From: DhruvaBansal00 Date: Mon, 12 Aug 2024 13:50:02 -0700 Subject: [PATCH 07/96] Switching to mixtral file for quantized mixtral --- vllm/model_executor/models/__init__.py | 2 +- 1 file 
changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py index 94c3cea98be7b..329df4830af41 100644 --- a/vllm/model_executor/models/__init__.py +++ b/vllm/model_executor/models/__init__.py @@ -48,7 +48,7 @@ "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"), "MistralForCausalLM": ("llama", "LlamaForCausalLM"), "MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"), - "QuantMixtralForCausalLM": ("mixtral_quant", "MixtralForCausalLM"), + "QuantMixtralForCausalLM": ("mixtral", "MixtralForCausalLM"), # transformers's mpt class has lower case "MptForCausalLM": ("mpt", "MPTForCausalLM"), "MPTForCausalLM": ("mpt", "MPTForCausalLM"), From bdf6bdc31d9ba050b207589315fb0f3389da389b Mon Sep 17 00:00:00 2001 From: DhruvaBansal00 Date: Mon, 12 Aug 2024 14:42:45 -0700 Subject: [PATCH 08/96] Bug fixes --- tests/kernels/test_moe.py | 42 +++--- .../layers/fused_moe/fused_moe_marlin.py | 26 ++-- vllm/model_executor/layers/fused_moe/layer.py | 11 +- .../layers/quantization/gptq_marlin.py | 139 +++++++++--------- vllm/model_executor/models/mixtral.py | 10 +- 5 files changed, 122 insertions(+), 106 deletions(-) diff --git a/tests/kernels/test_moe.py b/tests/kernels/test_moe.py index 856ee7c56e598..e657581df05a0 100644 --- a/tests/kernels/test_moe.py +++ b/tests/kernels/test_moe.py @@ -13,8 +13,7 @@ from vllm.model_executor.layers.fused_moe import fused_moe, single_marlin_moe from vllm.model_executor.layers.fused_moe.fused_moe_marlin import fused_moe_marlin from vllm.model_executor.layers.quantization.utils.marlin_utils_test import ( - marlin_quantize, -) + marlin_quantize, ) from vllm.model_executor.models.mixtral import MixtralMoE @@ -29,12 +28,10 @@ def torch_moe(a, w1, w2, score, topk): for i in range(w1.shape[0]): mask = topk_ids == i if mask.sum(): - out[mask] = SiluAndMul()(a[mask] @ w1[i].transpose(0, 1)) @ w2[i].transpose( - 0, 1 - ) - return ( - out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1).to(out.dtype) - ).sum(dim=1) + out[mask] = SiluAndMul()( + a[mask] @ w1[i].transpose(0, 1)) @ w2[i].transpose(0, 1) + return (out.view(B, -1, w2.shape[1]) * + topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1) def torch_moe_single(a, w, score, topk): @@ -75,7 +72,8 @@ def test_fused_moe( assert torch.allclose(triton_output, torch_output, atol=1e-2, rtol=0) -@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("dtype", + [torch.float32, torch.float16, torch.bfloat16]) @torch.inference_mode() def test_mixtral_moe(dtype: torch.dtype): """Make sure our Mixtral MoE implementation agrees with the one from @@ -96,7 +94,8 @@ def test_mixtral_moe(dtype: torch.dtype): # Load the weights vllm_moe.gate.weight.data[:] = hf_moe.gate.weight.data for i in range(config.num_local_experts): - weights = (hf_moe.experts[i].w1.weight.data, hf_moe.experts[i].w3.weight.data) + weights = (hf_moe.experts[i].w1.weight.data, + hf_moe.experts[i].w3.weight.data) vllm_moe.experts.w13_weight[i][:] = torch.cat(weights, dim=0) vllm_moe.experts.w2_weight[i][:] = hf_moe.experts[i].w2.weight.data @@ -130,8 +129,7 @@ def stack_and_dev(tensors: List[torch.Tensor]): def compute_max_diff(output, output_ref): return torch.mean(torch.abs(output - output_ref)) / torch.mean( - torch.abs(output_ref) - ) + torch.abs(output_ref)) # TODO: make sure this test works @@ -184,8 +182,7 @@ def test_fused_marlin_moe( for i in range(w1.shape[0]): test_perm = torch.randperm(k) w_ref1, qweight1, scales1, g_idx1, sort_indices1, _ = 
marlin_quantize( - w1[i].transpose(1, 0), num_bits, group_size, act_order, test_perm - ) + w1[i].transpose(1, 0), num_bits, group_size, act_order, test_perm) w_ref1_l.append(w_ref1) qweight1_l.append(qweight1) scales1_l.append(scales1) @@ -207,8 +204,7 @@ def test_fused_marlin_moe( for i in range(w2.shape[0]): test_perm = torch.randperm(n) w_ref2, qweight2, scales2, g_idx2, sort_indices2, _ = marlin_quantize( - w2[i].transpose(1, 0), num_bits, group_size, act_order, test_perm - ) + w2[i].transpose(1, 0), num_bits, group_size, act_order, test_perm) w_ref2_l.append(w_ref2) qweight2_l.append(qweight2) scales2_l.append(scales2) @@ -294,8 +290,7 @@ def test_single_marlin_moe( for i in range(w.shape[0]): test_perm = torch.randperm(k) w_ref, qweight, scales, g_idx, sort_indices, _ = marlin_quantize( - w[i].transpose(1, 0), num_bits, group_size, act_order, test_perm - ) + w[i].transpose(1, 0), num_bits, group_size, act_order, test_perm) w_ref_l.append(w_ref) qweights_l.append(qweight) scales_l.append(scales) @@ -309,9 +304,14 @@ def test_single_marlin_moe( sort_indices = stack_and_dev(sort_indices_l) score = torch.randn((m, e), device="cuda", dtype=dtype) - marlin_output = single_marlin_moe( - a, qweight, scales, score, g_idx, sort_indices, topk, renormalize=False - ) + marlin_output = single_marlin_moe(a, + qweight, + scales, + score, + g_idx, + sort_indices, + topk, + renormalize=False) torch_output = torch_moe_single(a, w_ref.transpose(1, 2), score, topk) assert compute_max_diff(marlin_output, torch_output) < 1e-2 diff --git a/vllm/model_executor/layers/fused_moe/fused_moe_marlin.py b/vllm/model_executor/layers/fused_moe/fused_moe_marlin.py index 4ffcda6f85d5e..d84126568d726 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe_marlin.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe_marlin.py @@ -51,19 +51,25 @@ def fused_moe_marlin( - torch.Tensor: The output tensor after applying the MoE layer. """ # Check constraints. 
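One detail worth a worked example in the body that follows is the workspace sizing, max_workspace_size = ((M + 255) // 256) * (max(2 * N, K) // 64) * 16, where 2 * N and K match the width arguments passed to the two marlin_gemm_moe calls. The concrete numbers below are hypothetical; only the formula comes from the code.

    M, N, K = 33, 14336, 4096            # tokens, intermediate size, hidden size (assumed)

    blocks_m = (M + 255) // 256          # ceil(M / 256) -> 1
    width = max(2 * N, K) // 64          # widest GEMM output, in units of 64 -> 448
    max_workspace_size = blocks_m * width * 16
    print(max_workspace_size)            # 7168 entries, allocated with dtype=torch.int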
- assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch" - assert hidden_states.shape[1] == w1.shape[1] * 16, "Hidden size mismatch w1" - assert hidden_states.shape[1] == w2.shape[2] // 2, "Hidden size mismatch w2" + assert hidden_states.shape[0] == gating_output.shape[ + 0], "Number of tokens mismatch" + assert hidden_states.shape[ + 1] == w1.shape[1] * 16, "Hidden size mismatch w1" + assert hidden_states.shape[ + 1] == w2.shape[2] // 2, "Hidden size mismatch w2" assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch" assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" assert w1.is_contiguous(), "Expert weights1 must be contiguous" assert w2.is_contiguous(), "Expert weights2 must be contiguous" - assert hidden_states.dtype in [torch.float32, torch.float16, torch.bfloat16] + assert hidden_states.dtype in [ + torch.float32, torch.float16, torch.bfloat16 + ] M, K = hidden_states.shape E = w1.shape[0] N = w2.shape[1] * 16 - topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk, renormalize) + topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk, + renormalize) get_config_func = functools.partial( try_get_optimal_moe_config, @@ -81,9 +87,10 @@ def fused_moe_marlin( sorted_token_ids, _, _ = moe_align_block_size(topk_ids, block_size_m, E) max_workspace_size = ((M + 255) // 256) * (max(2 * N, K) // 64) * 16 - workspace = torch.zeros( - max_workspace_size, dtype=torch.int, device="cuda", requires_grad=False - ) + workspace = torch.zeros(max_workspace_size, + dtype=torch.int, + device="cuda", + requires_grad=False) intermediate_cache2 = torch.empty( (M * topk_ids.shape[1], N), @@ -135,4 +142,5 @@ def fused_moe_marlin( True, ) - return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), dim=1) + return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), + dim=1) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 25c6214318692..214c40a510dfb 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -55,8 +55,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): """MoE method without quantization.""" def create_weights(self, layer: torch.nn.Module, num_experts: int, - hidden_size: int, intermediate_size: int, - params_dtype: torch.dtype, **extra_weight_attrs): + hidden_size: int, intermediate_size: int, + params_dtype: torch.dtype, **extra_weight_attrs): # Fused gate_up_proj (column parallel) w13_weight = torch.nn.Parameter( torch.empty(num_experts, @@ -151,6 +151,7 @@ def forward_tpu( assert topk_group is None return fused_moe(x, w1, w2, router_logits, top_k, renormalize) + class FusedMoE(torch.nn.Module): """FusedMoE layer for MoE models. @@ -259,8 +260,8 @@ def weight_loader( else: # Input scales can be loaded directly and should be equal. if "input_scale" in weight_name: - if (param_data[expert_id] != 1 and (param_data[expert_id] - - loaded_weight).abs() > 1e-5): + if (param_data[expert_id] != 1 and + (param_data[expert_id] - loaded_weight).abs() > 1e-5): raise ValueError( "input_scales of w1 and w3 of a layer " f"must be equal. 
But got {param_data[expert_id]} " @@ -326,7 +327,7 @@ def forward(self, hidden_states: torch.Tensor, def make_expert_params_mapping( cls, ckpt_gate_proj_name: str, ckpt_down_proj_name: str, ckpt_up_proj_name: str, - num_experts: int) -> List[Tuple[str, str, int, str]]: + num_experts: int) -> List[Tuple[str, str, int, int]]: gate_up = [ckpt_gate_proj_name, ckpt_up_proj_name] gate_down_up = [ ckpt_gate_proj_name, ckpt_down_proj_name, ckpt_up_proj_name diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py index c81517823cbf2..2088177418f1e 100644 --- a/vllm/model_executor/layers/quantization/gptq_marlin.py +++ b/vllm/model_executor/layers/quantization/gptq_marlin.py @@ -62,17 +62,15 @@ def __init__( self.lm_head_quantized = lm_head_quantized # Verify supported on platform. - verify_gptq_marlin_supported( - num_bits=self.weight_bits, group_size=self.group_size, is_sym=self.is_sym - ) + verify_gptq_marlin_supported(num_bits=self.weight_bits, + group_size=self.group_size, + is_sym=self.is_sym) def __repr__(self) -> str: - return ( - f"GPTQMarlinConfig(weight_bits={self.weight_bits}, " - f"group_size={self.group_size}, " - f"desc_act={self.desc_act}, " - f"lm_head_quantized={self.lm_head_quantized})" - ) + return (f"GPTQMarlinConfig(weight_bits={self.weight_bits}, " + f"group_size={self.group_size}, " + f"desc_act={self.desc_act}, " + f"lm_head_quantized={self.lm_head_quantized})") @classmethod def get_name(cls) -> str: @@ -96,40 +94,37 @@ def from_config(cls, config: Dict[str, Any]) -> "GPTQMarlinConfig": group_size = cls.get_from_keys(config, ["group_size"]) desc_act = cls.get_from_keys(config, ["desc_act"]) is_sym = cls.get_from_keys(config, ["sym"]) - lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"], default=False) - return cls(weight_bits, group_size, desc_act, is_sym, lm_head_quantized) + lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"], + default=False) + return cls(weight_bits, group_size, desc_act, is_sym, + lm_head_quantized) @classmethod - def override_quantization_method(cls, hf_quant_cfg, user_quant) -> Optional[str]: + def override_quantization_method(cls, hf_quant_cfg, + user_quant) -> Optional[str]: can_convert = cls.is_gptq_marlin_compatible(hf_quant_cfg) - is_valid_user_quant = ( - user_quant is None or user_quant == "marlin" or user_quant == "gptq_marlin" - ) + is_valid_user_quant = (user_quant is None or user_quant == "marlin" + or user_quant == "gptq_marlin") if can_convert and is_valid_user_quant: - msg = ( - "The model is convertible to {} during runtime." - " Using {} kernel.".format(cls.get_name(), cls.get_name()) - ) + msg = ("The model is convertible to {} during runtime." + " Using {} kernel.".format(cls.get_name(), cls.get_name())) logger.info(msg) return cls.get_name() if can_convert and user_quant == "gptq": - logger.info( - "Detected that the model can run with gptq_marlin" - ", however you specified quantization=gptq explicitly," - " so forcing gptq. Use quantization=gptq_marlin for" - " faster inference" - ) + logger.info("Detected that the model can run with gptq_marlin" + ", however you specified quantization=gptq explicitly," + " so forcing gptq. 
Use quantization=gptq_marlin for" + " faster inference") return None def get_quant_method( self, layer: torch.nn.Module, prefix: str ) -> Optional[Union["GPTQMarlinLinearMethod", "GPTQMarlinMoEMethod"]]: - if isinstance(layer, LinearBase) or ( - isinstance(layer, ParallelLMHead) and self.lm_head_quantized - ): + if isinstance(layer, LinearBase) or (isinstance(layer, ParallelLMHead) + and self.lm_head_quantized): return GPTQMarlinLinearMethod(self) elif isinstance(layer, FusedMoE): return GPTQMarlinMoEMethod(self) @@ -200,9 +195,9 @@ def create_weights( ) # Determine sharding - if marlin_repeat_scales_on_all_ranks( - self.quant_config.desc_act, self.quant_config.group_size, is_row_parallel - ): + if marlin_repeat_scales_on_all_ranks(self.quant_config.desc_act, + self.quant_config.group_size, + is_row_parallel): # By setting scale_dim == None, weight_loader will # repeat the scales on each GPU in TP>1 case. scales_and_zp_input_dim = None @@ -244,7 +239,10 @@ def create_weights( # Ignore warning from fused linear layers such as QKVParallelLinear. set_weight_attrs( g_idx, - {**extra_weight_attrs, "input_dim": 0, "ignore_warning": True}, + { + **extra_weight_attrs, "input_dim": 0, + "ignore_warning": True + }, ) # Scales @@ -293,7 +291,8 @@ def create_weights( layer.input_size_per_partition = input_size_per_partition layer.output_size_per_partition = output_size_per_partition layer.input_size = input_size - layer.is_k_full = marlin_is_k_full(self.quant_config.desc_act, is_row_parallel) + layer.is_k_full = marlin_is_k_full(self.quant_config.desc_act, + is_row_parallel) # Checkpoints are serialized in AutoGPTQ format, which is different from the # marlin format. This function is called after the weights are loaded. @@ -302,7 +301,8 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: device = layer.qweight.device # Allocate marlin workspace - layer.workspace = marlin_make_workspace(layer.output_size_per_partition, device) + layer.workspace = marlin_make_workspace( + layer.output_size_per_partition, device) # Handle sorting for activation reordering if needed. if self.quant_config.desc_act: @@ -329,11 +329,8 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: # Permute scales from autogptq format to marlin format. 
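The repack steps in this method hand their results to replace_tensor, whose earlier inline version documented the intent: resize_() is used so the already-registered parameter keeps the same tensor object and buffer rather than being rebound to a fresh tensor the rest of vLLM would not see. A standalone sketch of that idea with plain tensors (the helper below is a simplified stand-in, not the shared vLLM utility):

    import torch

    def replace_tensor_sketch(layer: torch.nn.Module, name: str,
                              new_t: torch.Tensor) -> None:
        # Keep the registered tensor object; only its shape/contents change.
        getattr(layer, name).resize_(new_t.shape)
        getattr(layer, name).copy_(new_t)
        del new_t

    layer = torch.nn.Module()
    layer.register_parameter(
        "scales", torch.nn.Parameter(torch.zeros(4, 8), requires_grad=False))
    replace_tensor_sketch(layer, "scales", torch.ones(2, 16))
    print(layer.scales.shape)  # torch.Size([2, 16])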
marlin_scales = marlin_permute_scales( layer.scales, - size_k=( - layer.input_size - if self.quant_config.desc_act - else layer.input_size_per_partition - ), + size_k=(layer.input_size if self.quant_config.desc_act else + layer.input_size_per_partition), size_n=layer.output_size_per_partition, group_size=self.quant_config.group_size, ) @@ -411,16 +408,20 @@ def create_weights( set_weight_attrs(w2_qweight, extra_weight_attrs) # up_proj scales w13_scales = torch.nn.Parameter( - torch.empty( - num_experts, scales_size13, 2 * intermediate_size, dtype=params_dtype - ), + torch.empty(num_experts, + scales_size13, + 2 * intermediate_size, + dtype=params_dtype), requires_grad=False, ) layer.register_parameter("w13_scales", w13_scales) set_weight_attrs(w13_scales, extra_weight_attrs) # down_proj scales w2_scales = torch.nn.Parameter( - torch.empty(num_experts, scales_size2, hidden_size, dtype=params_dtype), + torch.empty(num_experts, + scales_size2, + hidden_size, + dtype=params_dtype), requires_grad=False, ) layer.register_parameter("w2_scales", w2_scales) @@ -453,7 +454,8 @@ def create_weights( ), requires_grad=False, ) - layer.register_parameter("w13_g_idx_sort_indices", w13_g_idx_sort_indices) + layer.register_parameter("w13_g_idx_sort_indices", + w13_g_idx_sort_indices) set_weight_attrs(w13_g_idx_sort_indices, extra_weight_attrs) w2_g_idx_sort_indices = torch.nn.Parameter( torch.empty( @@ -463,7 +465,8 @@ def create_weights( ), requires_grad=False, ) - layer.register_parameter("w2_g_idx_sort_indices", w2_g_idx_sort_indices) + layer.register_parameter("w2_g_idx_sort_indices", + w2_g_idx_sort_indices) set_weight_attrs(w2_g_idx_sort_indices, extra_weight_attrs) layer.marlin_state = GPTQMarlinState.REPACK @@ -479,36 +482,42 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: w13_sorted_g_idx = torch.empty_like(layer.w13_g_idx) w2_sorted_g_idx = torch.empty_like(layer.w2_g_idx) for e in range(num_experts): - w13_g_idx_sort_indices[e] = torch.argsort(layer.w13_g_idx[e]).to( - torch.int32 - ) + w13_g_idx_sort_indices[e] = torch.argsort( + layer.w13_g_idx[e]).to(torch.int32) w2_g_idx_sort_indices[e] = torch.argsort(layer.w2_g_idx[e]).to( - torch.int32 - ) - w13_sorted_g_idx[e] = layer.w13_g_idx[e][w13_g_idx_sort_indices[e]] - w2_sorted_g_idx[e] = layer.w2_g_idx[e][w2_g_idx_sort_indices[e]] + torch.int32) + w13_sorted_g_idx[e] = layer.w13_g_idx[e][ + w13_g_idx_sort_indices[e]] + w2_sorted_g_idx[e] = layer.w2_g_idx[e][ + w2_g_idx_sort_indices[e]] replace_tensor(layer, "w13_g_idx", w13_sorted_g_idx) replace_tensor(layer, "w2_g_idx", w2_sorted_g_idx) - replace_tensor(layer, "w13_g_idx_sort_indices", w13_g_idx_sort_indices) - replace_tensor(layer, "w2_g_idx_sort_indices", w2_g_idx_sort_indices) + replace_tensor(layer, "w13_g_idx_sort_indices", + w13_g_idx_sort_indices) + replace_tensor(layer, "w2_g_idx_sort_indices", + w2_g_idx_sort_indices) else: # Reset g_idx related tensors num_experts = layer.w13_g_idx.shape[0] device = layer.w13_g_idx.device layer.w13_g_idx = torch.nn.Parameter( - torch.empty((num_experts, 0), dtype=torch.int32, device=device), + torch.empty((num_experts, 0), dtype=torch.int32, + device=device), requires_grad=False, ) layer.w2_g_idx = torch.nn.Parameter( - torch.empty((num_experts, 0), dtype=torch.int32, device=device), + torch.empty((num_experts, 0), dtype=torch.int32, + device=device), requires_grad=False, ) layer.w13_g_idx_sort_indices = torch.nn.Parameter( - torch.empty((num_experts, 0), dtype=torch.int32, device=device), + torch.empty((num_experts, 0), 
dtype=torch.int32, + device=device), requires_grad=False, ) layer.w2_g_idx_sort_indices = torch.nn.Parameter( - torch.empty((num_experts, 0), dtype=torch.int32, device=device), + torch.empty((num_experts, 0), dtype=torch.int32, + device=device), requires_grad=False, ) # Repack weights @@ -530,11 +539,8 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: # Repack scales marlin_w13_scales = marlin_moe_permute_scales( s=layer.w13_scales, - size_k=( - layer.input_size - if self.quant_config.desc_act - else layer.input_size_per_partition - ), + size_k=(layer.input_size if self.quant_config.desc_act else + layer.input_size_per_partition), size_n=layer.w13_scales.shape[2], group_size=self.quant_config.group_size, ) @@ -542,11 +548,8 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: marlin_w2_scales = marlin_moe_permute_scales( s=layer.w2_scales, size_k=layer.w2_scales.shape[1] * self.quant_config.pack_factor, - size_n=( - layer.input_size - if self.quant_config.desc_act - else layer.input_size_per_partition - ), + size_n=(layer.input_size if self.quant_config.desc_act else + layer.input_size_per_partition), group_size=self.quant_config.group_size, ) replace_tensor(layer, "w2_scales", marlin_w2_scales) diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index 8fbd537a2c031..d5c4256ded522 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -437,7 +437,10 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): param = params_dict[name] weight_loader = param.weight_loader - weight_loader(param, loaded_weight, shard_id) + weight_loader(param, + loaded_weight, + shard_id, + is_quantized=True) break else: for mapping in expert_params_mapping: @@ -454,7 +457,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): loaded_weight, weight_name, shard_id=shard_id, - expert_id=expert_id) + expert_id=expert_id, + is_quantized=True) break else: # Skip loading extra bias for GPTQ models. 
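For readers tracing where these (param_name, weight_name, expert_id, shard_id) tuples come from: they are produced by FusedMoE.make_expert_params_mapping. The self-contained sketch below reproduces just the scales-style entries added in this series, using Mixtral-style w1/w2/w3 checkpoint names as in the surrounding diffs; the helper name is mine, and the real method also emits entries for the other weight tensors.

    from typing import List, Tuple

    def expert_scales_mapping_sketch(
        ckpt_gate_proj_name: str,
        ckpt_down_proj_name: str,
        ckpt_up_proj_name: str,
        num_experts: int,
    ) -> List[Tuple[str, str, int, int]]:
        gate_up = [ckpt_gate_proj_name, ckpt_up_proj_name]
        gate_down_up = [ckpt_gate_proj_name, ckpt_down_proj_name, ckpt_up_proj_name]
        return [
            # gate/up scales map onto the fused w13 tensor, down scales onto w2.
            ("experts.w13_scales" if weight_name in gate_up else "experts.w2_scales",
             f"experts.{expert_id}.{weight_name}.scales",
             expert_id,
             shard_id)
            for expert_id in range(num_experts)
            for shard_id, weight_name in enumerate(gate_down_up)
        ]

    for entry in expert_scales_mapping_sketch("w1", "w2", "w3", num_experts=2)[:3]:
        print(entry)
    # ('experts.w13_scales', 'experts.0.w1.scales', 0, 0)
    # ('experts.w2_scales', 'experts.0.w2.scales', 0, 1)
    # ('experts.w13_scales', 'experts.0.w3.scales', 0, 2)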
@@ -471,4 +475,4 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) - weight_loader(param, loaded_weight) + weight_loader(param, loaded_weight, is_quantized=True) From 19c5c59d82b8329be8e1cff2f0b51e40e83e83a3 Mon Sep 17 00:00:00 2001 From: DhruvaBansal00 Date: Mon, 12 Aug 2024 14:45:20 -0700 Subject: [PATCH 09/96] is quantized change --- vllm/model_executor/models/mixtral.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index d5c4256ded522..c36f86f2d65f0 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -475,4 +475,4 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) - weight_loader(param, loaded_weight, is_quantized=True) + weight_loader(param, loaded_weight) From 3b7cc60a50478c180b0a48744fa37aab3084e413 Mon Sep 17 00:00:00 2001 From: DhruvaBansal00 Date: Mon, 12 Aug 2024 14:56:12 -0700 Subject: [PATCH 10/96] debug stat --- vllm/model_executor/models/mixtral.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index c36f86f2d65f0..341ff806f37d5 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -50,7 +50,8 @@ from .interfaces import SupportsLoRA from .utils import is_pp_missing_parameter, make_layers - +import logging +logger = logging.getLogger(__name__) class MixtralMoE(nn.Module): """A tensor-parallel MoE implementation for Mixtral that shards each expert across all ranks. @@ -451,6 +452,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # Skip layers on other devices. 
if is_pp_missing_parameter(name, self): continue + logger.error(params_dict.keys()) param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, From d2c4754df8418439acb7406acbb51e66461b6aa2 Mon Sep 17 00:00:00 2001 From: DhruvaBansal00 Date: Mon, 12 Aug 2024 14:59:35 -0700 Subject: [PATCH 11/96] replace wiehgt name with param name --- vllm/model_executor/layers/fused_moe/__init__.py | 2 +- vllm/model_executor/models/mixtral.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/__init__.py b/vllm/model_executor/layers/fused_moe/__init__.py index beb94f10a557e..0d871232305ae 100644 --- a/vllm/model_executor/layers/fused_moe/__init__.py +++ b/vllm/model_executor/layers/fused_moe/__init__.py @@ -1,4 +1,4 @@ -from vllm.model_executor.layers.fused_moe.fused_moe_marlin import fused_moe_marlin +sfrom vllm.model_executor.layers.fused_moe.fused_moe_marlin import fused_moe_marlin from vllm.model_executor.layers.fused_moe.fused_moe import single_marlin_moe from vllm.model_executor.layers.fused_moe.layer import FusedMoE, FusedMoEMethodBase from vllm.triton_utils import HAS_TRITON diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index 341ff806f37d5..e41250cf99707 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -457,7 +457,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader = param.weight_loader weight_loader(param, loaded_weight, - weight_name, + param_name, shard_id=shard_id, expert_id=expert_id, is_quantized=True) From f579cb25242aa33514e7fe27068e2000e87d6141 Mon Sep 17 00:00:00 2001 From: DhruvaBansal00 Date: Mon, 12 Aug 2024 15:00:26 -0700 Subject: [PATCH 12/96] typo --- vllm/model_executor/layers/fused_moe/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/fused_moe/__init__.py b/vllm/model_executor/layers/fused_moe/__init__.py index 0d871232305ae..beb94f10a557e 100644 --- a/vllm/model_executor/layers/fused_moe/__init__.py +++ b/vllm/model_executor/layers/fused_moe/__init__.py @@ -1,4 +1,4 @@ -sfrom vllm.model_executor.layers.fused_moe.fused_moe_marlin import fused_moe_marlin +from vllm.model_executor.layers.fused_moe.fused_moe_marlin import fused_moe_marlin from vllm.model_executor.layers.fused_moe.fused_moe import single_marlin_moe from vllm.model_executor.layers.fused_moe.layer import FusedMoE, FusedMoEMethodBase from vllm.triton_utils import HAS_TRITON From 79394eb8e23fb703a12a9a28fdf283dfa2599080 Mon Sep 17 00:00:00 2001 From: DhruvaBansal00 Date: Mon, 12 Aug 2024 15:08:05 -0700 Subject: [PATCH 13/96] debug --- vllm/model_executor/models/mixtral.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index e41250cf99707..40948bfe1f0b5 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -444,6 +444,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): is_quantized=True) break else: + logger.error(expert_params_mapping) for mapping in expert_params_mapping: param_name, weight_name, expert_id, shard_id = mapping if weight_name not in name: From ec75f4ef3c036e2a2e794098cca804956c5c9b96 Mon Sep 17 00:00:00 2001 From: DhruvaBansal00 Date: Mon, 12 Aug 2024 15:09:53 -0700 Subject: [PATCH 14/96] more debug --- vllm/model_executor/models/mixtral.py | 2 ++ 1 file changed, 2 insertions(+) diff --git 
a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index 40948bfe1f0b5..5d2a07a7b50f0 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -449,7 +449,9 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): param_name, weight_name, expert_id, shard_id = mapping if weight_name not in name: continue + logger.error(weight_name, param_name, name, name.replace(weight_name, param_name)) name = name.replace(weight_name, param_name) + logger.error(name in params_dict) # Skip layers on other devices. if is_pp_missing_parameter(name, self): continue From 91ca97078a59d2722965599feb4235804fbd5e1f Mon Sep 17 00:00:00 2001 From: DhruvaBansal00 Date: Mon, 12 Aug 2024 15:11:05 -0700 Subject: [PATCH 15/96] only relevant logging --- vllm/model_executor/models/mixtral.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index 5d2a07a7b50f0..1cfa219f9d268 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -444,7 +444,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): is_quantized=True) break else: - logger.error(expert_params_mapping) + # logger.error(expert_params_mapping) for mapping in expert_params_mapping: param_name, weight_name, expert_id, shard_id = mapping if weight_name not in name: @@ -455,7 +455,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # Skip layers on other devices. if is_pp_missing_parameter(name, self): continue - logger.error(params_dict.keys()) + # logger.error(params_dict.keys()) param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, From 1b9d5bb25d68fee931803789251b93bc35d8fb82 Mon Sep 17 00:00:00 2001 From: DhruvaBansal00 Date: Mon, 12 Aug 2024 15:12:54 -0700 Subject: [PATCH 16/96] log --- vllm/model_executor/models/mixtral.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index 1cfa219f9d268..83277083e24fa 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -449,9 +449,9 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): param_name, weight_name, expert_id, shard_id = mapping if weight_name not in name: continue - logger.error(weight_name, param_name, name, name.replace(weight_name, param_name)) + logger.error(weight_name, param_name, name) name = name.replace(weight_name, param_name) - logger.error(name in params_dict) + logger.error(f"Loading {name} from {weight_name}") # Skip layers on other devices. 
if is_pp_missing_parameter(name, self): continue From ec0671913b4d72ad8bde7f2ee695c91a7ac6c311 Mon Sep 17 00:00:00 2001 From: DhruvaBansal00 Date: Mon, 12 Aug 2024 15:14:03 -0700 Subject: [PATCH 17/96] log --- vllm/model_executor/models/mixtral.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index 83277083e24fa..71e68129e9f80 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -449,7 +449,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): param_name, weight_name, expert_id, shard_id = mapping if weight_name not in name: continue - logger.error(weight_name, param_name, name) + logger.error(f"{weight_name} {param_name} {name}") name = name.replace(weight_name, param_name) logger.error(f"Loading {name} from {weight_name}") # Skip layers on other devices. From 71d82e125319878825e8c9bed84143b95af7cc4a Mon Sep 17 00:00:00 2001 From: DhruvaBansal00 Date: Mon, 12 Aug 2024 15:31:23 -0700 Subject: [PATCH 18/96] removing qzero weights --- .../layers/quantization/gptq_marlin.py | 42 +++++++++---------- 1 file changed, 21 insertions(+), 21 deletions(-) diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py index 2088177418f1e..542c0244f431f 100644 --- a/vllm/model_executor/layers/quantization/gptq_marlin.py +++ b/vllm/model_executor/layers/quantization/gptq_marlin.py @@ -263,31 +263,31 @@ def create_weights( }, ) - # Quantized zero-points - qzeros = Parameter( - torch.empty( - scales_and_zp_size, - output_size_per_partition // self.quant_config.pack_factor, - dtype=torch.int32, - device="meta", - ), - requires_grad=False, - ) - set_weight_attrs( - qzeros, - { - **extra_weight_attrs, - "input_dim": scales_and_zp_input_dim, - "output_dim": 1, - "packed_dim": 1, - "pack_factor": self.quant_config.pack_factor, - }, - ) + # # Quantized zero-points + # qzeros = Parameter( + # torch.empty( + # scales_and_zp_size, + # output_size_per_partition // self.quant_config.pack_factor, + # dtype=torch.int32, + # device="meta", + # ), + # requires_grad=False, + # ) + # set_weight_attrs( + # qzeros, + # { + # **extra_weight_attrs, + # "input_dim": scales_and_zp_input_dim, + # "output_dim": 1, + # "packed_dim": 1, + # "pack_factor": self.quant_config.pack_factor, + # }, + # ) layer.register_parameter("qweight", qweight) layer.register_parameter("g_idx", g_idx) layer.register_parameter("scales", scales) - layer.register_parameter("qzeros", qzeros) + # layer.register_parameter("qzeros", qzeros) layer.input_size_per_partition = input_size_per_partition layer.output_size_per_partition = output_size_per_partition layer.input_size = input_size From d3465d07813a09747d5d439b16b57cdbd7c2d566 Mon Sep 17 00:00:00 2001 From: DhruvaBansal00 Date: Mon, 12 Aug 2024 16:02:10 -0700 Subject: [PATCH 19/96] Qzeors in expert mapping --- vllm/model_executor/layers/fused_moe/layer.py | 11 ++++ .../layers/quantization/gptq_marlin.py | 62 ++++++++++++------- 2 files changed, 52 insertions(+), 21 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 214c40a510dfb..d89867090ec49 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -388,4 +388,15 @@ def make_expert_params_mapping( shard_id, ) for expert_id in range(num_experts) for shard_id, weight_name in enumerate(gate_down_up) + ] + 
[ + # These are the qzeros for the experts + # (param_name, weight_name, expert_id, shard_id) + ( + "experts.w13_qzeros" + if weight_name in gate_up else "experts.w2_qzeros", + f"experts.{expert_id}.{weight_name}.qzeros", + expert_id, + shard_id, + ) for expert_id in range(num_experts) + for shard_id, weight_name in enumerate(gate_down_up) ]) diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py index 542c0244f431f..9b196477fbd21 100644 --- a/vllm/model_executor/layers/quantization/gptq_marlin.py +++ b/vllm/model_executor/layers/quantization/gptq_marlin.py @@ -263,31 +263,31 @@ def create_weights( }, ) - # # Quantized zero-points - # qzeros = Parameter( - # torch.empty( - # scales_and_zp_size, - # output_size_per_partition // self.quant_config.pack_factor, - # dtype=torch.int32, - # device="meta", - # ), - # requires_grad=False, - # ) - # set_weight_attrs( - # qzeros, - # { - # **extra_weight_attrs, - # "input_dim": scales_and_zp_input_dim, - # "output_dim": 1, - # "packed_dim": 1, - # "pack_factor": self.quant_config.pack_factor, - # }, - # ) + # Quantized zero-points + qzeros = Parameter( + torch.empty( + scales_and_zp_size, + output_size_per_partition // self.quant_config.pack_factor, + dtype=torch.int32, + device="meta", + ), + requires_grad=False, + ) + set_weight_attrs( + qzeros, + { + **extra_weight_attrs, + "input_dim": scales_and_zp_input_dim, + "output_dim": 1, + "packed_dim": 1, + "pack_factor": self.quant_config.pack_factor, + }, + ) layer.register_parameter("qweight", qweight) layer.register_parameter("g_idx", g_idx) layer.register_parameter("scales", scales) - # layer.register_parameter("qzeros", qzeros) + layer.register_parameter("qzeros", qzeros) layer.input_size_per_partition = input_size_per_partition layer.output_size_per_partition = output_size_per_partition layer.input_size = input_size @@ -426,6 +426,26 @@ def create_weights( ) layer.register_parameter("w2_scales", w2_scales) set_weight_attrs(w2_scales, extra_weight_attrs) + # up_proj scales + w13_qzeros = torch.nn.Parameter( + torch.empty(num_experts, + scales_size13, + 2 * intermediate_size // self.quant_config.pack_factor, + dtype=params_dtype), + requires_grad=False, + ) + layer.register_parameter("w13_qzeros", w13_qzeros) + set_weight_attrs(w13_qzeros, extra_weight_attrs) + # down_proj scales + w2_qzeros = torch.nn.Parameter( + torch.empty(num_experts, + scales_size2, + hidden_size // self.quant_config.pack_factor, + dtype=params_dtype), + requires_grad=False, + ) + layer.register_parameter("w2_qzeros", w2_qzeros) + set_weight_attrs(w2_qzeros, extra_weight_attrs) w13_g_idx = torch.nn.Parameter( torch.empty( num_experts, From 226ee265d6daf964cad6a22536cbf23e3fcc68da Mon Sep 17 00:00:00 2001 From: DhruvaBansal00 Date: Mon, 12 Aug 2024 16:04:26 -0700 Subject: [PATCH 20/96] Debug --- vllm/model_executor/models/mixtral.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index 71e68129e9f80..054f8a48593c6 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -444,7 +444,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): is_quantized=True) break else: - # logger.error(expert_params_mapping) + logger.error(expert_params_mapping) for mapping in expert_params_mapping: param_name, weight_name, expert_id, shard_id = mapping if weight_name not in name: @@ -455,7 +455,7 @@ def 
load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # Skip layers on other devices. if is_pp_missing_parameter(name, self): continue - # logger.error(params_dict.keys()) + logger.error(params_dict.keys()) param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, From 21d7d27de837d05c0368addfd6d19fcaf23ea7ee Mon Sep 17 00:00:00 2001 From: DhruvaBansal00 Date: Mon, 12 Aug 2024 16:07:09 -0700 Subject: [PATCH 21/96] Load qzero --- vllm/model_executor/layers/fused_moe/layer.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index d89867090ec49..24120b2d5b2f2 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -255,6 +255,11 @@ def weight_loader( raise ValueError(f"Invalid weight name: {weight_name}: " "must contain 'w13' or 'w2'.") param_data[expert_id] = loaded_weight + elif "qzeros" in weight_name: + if "w13" not in weight_name and "w2" not in weight_name: + raise ValueError(f"Invalid weight name: {weight_name}: " + "must contain 'w13' or 'w2'.") + param_data[expert_id] = loaded_weight else: raise ValueError(f"Invalid weight name: {weight_name}.") else: From 2dabb4b9aec04ec601c0d4879a21dd0d5fcce0a7 Mon Sep 17 00:00:00 2001 From: DhruvaBansal00 Date: Mon, 12 Aug 2024 16:12:29 -0700 Subject: [PATCH 22/96] rm 2x --- vllm/model_executor/layers/quantization/gptq_marlin.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py index 9b196477fbd21..efc854fe72e36 100644 --- a/vllm/model_executor/layers/quantization/gptq_marlin.py +++ b/vllm/model_executor/layers/quantization/gptq_marlin.py @@ -430,7 +430,7 @@ def create_weights( w13_qzeros = torch.nn.Parameter( torch.empty(num_experts, scales_size13, - 2 * intermediate_size // self.quant_config.pack_factor, + intermediate_size // self.quant_config.pack_factor, dtype=params_dtype), requires_grad=False, ) From 63669768bfcabafe8eda18ff60432337ee6a2889 Mon Sep 17 00:00:00 2001 From: DhruvaBansal00 Date: Mon, 12 Aug 2024 16:15:05 -0700 Subject: [PATCH 23/96] Mapping for scales --- vllm/model_executor/layers/fused_moe/layer.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 24120b2d5b2f2..60e9912c6b6ad 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -360,6 +360,17 @@ def make_expert_params_mapping( shard_id, ) for expert_id in range(num_experts) for shard_id, weight_name in enumerate(gate_down_up) + ] + [ + # These are the weights for the experts + # (param_name, weight_name, expert_id, shard_id) + ( + "experts.w13_scales" + if weight_name in gate_up else "experts.w2_scales", + f"experts.{expert_id}.{weight_name}.scales", + expert_id, + shard_id, + ) for expert_id in range(num_experts) + for shard_id, weight_name in enumerate(gate_down_up) ] + [ # These are the weight scales for the experts # (param_name, weight_name, expert_id, shard_id) From d63c0966ed8d8d30a97fdd259c150dcd4a56fdb8 Mon Sep 17 00:00:00 2001 From: DhruvaBansal00 Date: Mon, 12 Aug 2024 16:15:54 -0700 Subject: [PATCH 24/96] rm logging --- vllm/model_executor/models/mixtral.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/vllm/model_executor/models/mixtral.py 
b/vllm/model_executor/models/mixtral.py index 054f8a48593c6..c9f0d872e7cc2 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -444,18 +444,18 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): is_quantized=True) break else: - logger.error(expert_params_mapping) + # logger.error(expert_params_mapping) for mapping in expert_params_mapping: param_name, weight_name, expert_id, shard_id = mapping if weight_name not in name: continue - logger.error(f"{weight_name} {param_name} {name}") + # logger.error(f"{weight_name} {param_name} {name}") name = name.replace(weight_name, param_name) - logger.error(f"Loading {name} from {weight_name}") + # logger.error(f"Loading {name} from {weight_name}") # Skip layers on other devices. if is_pp_missing_parameter(name, self): continue - logger.error(params_dict.keys()) + # logger.error(params_dict.keys()) param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, From 360fef4273a05b936a06f92b6c7edf81a36ac2d9 Mon Sep 17 00:00:00 2001 From: DhruvaBansal00 Date: Mon, 12 Aug 2024 16:16:14 -0700 Subject: [PATCH 25/96] Adding lyaer wise logging --- vllm/model_executor/models/mixtral.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index c9f0d872e7cc2..ec7019a39d218 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -451,7 +451,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): continue # logger.error(f"{weight_name} {param_name} {name}") name = name.replace(weight_name, param_name) - # logger.error(f"Loading {name} from {weight_name}") + logger.error(f"Loading {name} from {weight_name}") # Skip layers on other devices. 
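# [Editor's sketch, not part of the patch series] The expert_params_mapping entries
# added a few patches back (19/96 and 23/96) all share one layout:
# (param_name, weight_name, expert_id, shard_id), where weight_name matches a
# per-expert checkpoint tensor such as "experts.0.w1.qzeros" and param_name is the
# fused buffer it lands in, e.g. "experts.w13_qzeros". A minimal illustration of how
# load_weights consumes such a mapping; the weight_loader call mirrors the one used
# later in mixtral_quant.py, and the helper name itself is hypothetical:

def route_expert_tensor(params_dict, expert_params_mapping, name, loaded_weight):
    """Copy one checkpoint tensor into the matching fused MoE parameter (sketch)."""
    for param_name, weight_name, expert_id, shard_id in expert_params_mapping:
        if weight_name not in name:
            continue
        fused_name = name.replace(weight_name, param_name)
        param = params_dict[fused_name]
        # The FusedMoE weight_loader writes into param.data[expert_id] and, for the
        # fused w13 buffer, into the w1 half (shard 0) or the w3 half (shard 1/2).
        param.weight_loader(param, loaded_weight, fused_name, shard_id, expert_id, True)
        return True
    return False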
if is_pp_missing_parameter(name, self): continue From c23d6169261c40136b5501957d9ff71064ef5bd8 Mon Sep 17 00:00:00 2001 From: DhruvaBansal00 Date: Mon, 12 Aug 2024 16:40:25 -0700 Subject: [PATCH 26/96] shard ids --- vllm/model_executor/models/mixtral.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index ec7019a39d218..4a76689855123 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -444,7 +444,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): is_quantized=True) break else: - # logger.error(expert_params_mapping) + logger.error(expert_params_mapping) for mapping in expert_params_mapping: param_name, weight_name, expert_id, shard_id = mapping if weight_name not in name: From 8d81d14fc3e9b462410f46a369ea656e27883a59 Mon Sep 17 00:00:00 2001 From: DhruvaBansal00 Date: Mon, 12 Aug 2024 16:45:51 -0700 Subject: [PATCH 27/96] Loading qzero correctly --- vllm/model_executor/layers/fused_moe/layer.py | 11 +++-------- .../model_executor/layers/quantization/gptq_marlin.py | 2 +- 2 files changed, 4 insertions(+), 9 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 60e9912c6b6ad..41b24bbc5ee41 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -235,16 +235,16 @@ def weight_loader( param_data = param.data if is_quantized: - if "_qweight" in weight_name or "_scales" in weight_name: + if ["_qweight", "_scales", "_qzeros"] in weight_name: if "w13" in weight_name: shard_size = self.intermediate_size_per_partition if shard_id == 0: param_data[expert_id, :, :shard_size] = loaded_weight - elif shard_id == 1: + elif shard_id == 2: param_data[expert_id, :, shard_size:] = loaded_weight else: raise ValueError(f"Invalid shard_id: {shard_id}: " - "must be 0 or 1.") + "must be 0 or 2.") elif "w2" in weight_name: param_data[expert_id][:] = loaded_weight else: @@ -255,11 +255,6 @@ def weight_loader( raise ValueError(f"Invalid weight name: {weight_name}: " "must contain 'w13' or 'w2'.") param_data[expert_id] = loaded_weight - elif "qzeros" in weight_name: - if "w13" not in weight_name and "w2" not in weight_name: - raise ValueError(f"Invalid weight name: {weight_name}: " - "must contain 'w13' or 'w2'.") - param_data[expert_id] = loaded_weight else: raise ValueError(f"Invalid weight name: {weight_name}.") else: diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py index efc854fe72e36..9b196477fbd21 100644 --- a/vllm/model_executor/layers/quantization/gptq_marlin.py +++ b/vllm/model_executor/layers/quantization/gptq_marlin.py @@ -430,7 +430,7 @@ def create_weights( w13_qzeros = torch.nn.Parameter( torch.empty(num_experts, scales_size13, - intermediate_size // self.quant_config.pack_factor, + 2 * intermediate_size // self.quant_config.pack_factor, dtype=params_dtype), requires_grad=False, ) From 22e1aa7b6f1b1ced2323a7c2c12437fb7208d36e Mon Sep 17 00:00:00 2001 From: DhruvaBansal00 Date: Mon, 12 Aug 2024 16:46:33 -0700 Subject: [PATCH 28/96] List operand --- vllm/model_executor/layers/fused_moe/layer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 41b24bbc5ee41..a1508bb4b34e2 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py 
+++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -235,7 +235,7 @@ def weight_loader( param_data = param.data if is_quantized: - if ["_qweight", "_scales", "_qzeros"] in weight_name: + if weight_name in ["_qweight", "_scales", "_qzeros"]: if "w13" in weight_name: shard_size = self.intermediate_size_per_partition if shard_id == 0: From 81e01f383bc84aed8ce4fae093cb7d3fe50188a3 Mon Sep 17 00:00:00 2001 From: DhruvaBansal00 Date: Mon, 12 Aug 2024 16:49:09 -0700 Subject: [PATCH 29/96] If clause --- vllm/model_executor/layers/fused_moe/layer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index a1508bb4b34e2..4dbcdd57f8337 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -235,7 +235,7 @@ def weight_loader( param_data = param.data if is_quantized: - if weight_name in ["_qweight", "_scales", "_qzeros"]: + if "_qweight" in weight_name or "_scales" in weight_name or "_qzeros" in weight_name: if "w13" in weight_name: shard_size = self.intermediate_size_per_partition if shard_id == 0: From dcfd32d1aeccb98dec208d474d5020c7d21e7cf1 Mon Sep 17 00:00:00 2001 From: Dhruva Bansal Date: Mon, 12 Aug 2024 23:57:20 +0000 Subject: [PATCH 30/96] Able to load layers --- vllm/model_executor/layers/fused_moe/layer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 4dbcdd57f8337..686d0415c5e5c 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -237,7 +237,7 @@ def weight_loader( if is_quantized: if "_qweight" in weight_name or "_scales" in weight_name or "_qzeros" in weight_name: if "w13" in weight_name: - shard_size = self.intermediate_size_per_partition + shard_size = loaded_weight.size()[-1] if shard_id == 0: param_data[expert_id, :, :shard_size] = loaded_weight elif shard_id == 2: From f04cbeaaaceb777515528a7cc66657499c478b17 Mon Sep 17 00:00:00 2001 From: DhruvaBansal00 Date: Mon, 12 Aug 2024 16:58:44 -0700 Subject: [PATCH 31/96] Setting load quant to false --- vllm/model_executor/models/mixtral.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index 4a76689855123..2043580ab8f1c 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -440,8 +440,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader = param.weight_loader weight_loader(param, loaded_weight, - shard_id, - is_quantized=True) + shard_id) break else: logger.error(expert_params_mapping) From a56821d352438f24dfedfc2413901c9c511d88f1 Mon Sep 17 00:00:00 2001 From: DhruvaBansal00 Date: Mon, 12 Aug 2024 16:59:44 -0700 Subject: [PATCH 32/96] Disabling logging --- vllm/model_executor/models/mixtral.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index 2043580ab8f1c..011b10583c528 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -443,14 +443,14 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): shard_id) break else: - logger.error(expert_params_mapping) + # logger.error(expert_params_mapping) for mapping in expert_params_mapping: param_name, weight_name, 
expert_id, shard_id = mapping if weight_name not in name: continue # logger.error(f"{weight_name} {param_name} {name}") name = name.replace(weight_name, param_name) - logger.error(f"Loading {name} from {weight_name}") + # logger.error(f"Loading {name} from {weight_name}") # Skip layers on other devices. if is_pp_missing_parameter(name, self): continue From 7f961c6d4d013b281a22dac79b6b073fdf8a122c Mon Sep 17 00:00:00 2001 From: DhruvaBansal00 Date: Mon, 12 Aug 2024 17:19:12 -0700 Subject: [PATCH 33/96] Removing *2 in marlin moe repack --- vllm/_custom_ops.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 048ab9195d24e..f9f4b9c725dda 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -283,7 +283,7 @@ def gptq_marlin_moe_repack(b_q_weight: torch.Tensor, perm: torch.Tensor, size_k: int, size_n: int, num_bits: int) -> torch.Tensor: num_experts = b_q_weight.shape[0] - output = torch.empty((num_experts, size_k // 16, size_n * 2), + output = torch.empty((num_experts, size_k // 16, size_n), device=b_q_weight.device, dtype=b_q_weight.dtype) for e in range(num_experts): From 4a6c7ffc776b1c8580ada27c3148172d24d1c902 Mon Sep 17 00:00:00 2001 From: DhruvaBansal00 Date: Mon, 12 Aug 2024 17:28:11 -0700 Subject: [PATCH 34/96] *4 in marlin moe repack --- vllm/_custom_ops.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index f9f4b9c725dda..1458103fa1b6d 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -283,7 +283,7 @@ def gptq_marlin_moe_repack(b_q_weight: torch.Tensor, perm: torch.Tensor, size_k: int, size_n: int, num_bits: int) -> torch.Tensor: num_experts = b_q_weight.shape[0] - output = torch.empty((num_experts, size_k // 16, size_n), + output = torch.empty((num_experts, size_k // 16, size_n * 4), device=b_q_weight.device, dtype=b_q_weight.dtype) for e in range(num_experts): From e6cd286d44d27db250737bc22a086e8a50082117 Mon Sep 17 00:00:00 2001 From: DhruvaBansal00 Date: Mon, 12 Aug 2024 17:29:36 -0700 Subject: [PATCH 35/96] bits --- vllm/_custom_ops.py | 2 +- vllm/model_executor/layers/quantization/gptq_marlin.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 1458103fa1b6d..048ab9195d24e 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -283,7 +283,7 @@ def gptq_marlin_moe_repack(b_q_weight: torch.Tensor, perm: torch.Tensor, size_k: int, size_n: int, num_bits: int) -> torch.Tensor: num_experts = b_q_weight.shape[0] - output = torch.empty((num_experts, size_k // 16, size_n * 4), + output = torch.empty((num_experts, size_k // 16, size_n * 2), device=b_q_weight.device, dtype=b_q_weight.dtype) for e in range(num_experts): diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py index 9b196477fbd21..6133a4e172f4c 100644 --- a/vllm/model_executor/layers/quantization/gptq_marlin.py +++ b/vllm/model_executor/layers/quantization/gptq_marlin.py @@ -545,7 +545,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: layer.w13_qweight, layer.w13_g_idx_sort_indices, layer.w13_qweight.shape[1] * self.quant_config.pack_factor, - layer.w13_qweight.shape[2], + layer.w13_qweight.shape[2] * 2, self.quant_config.weight_bits, ) replace_tensor(layer, "w13_qweight", marlin_w13_qweight) @@ -554,6 +554,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: layer.w2_g_idx_sort_indices, 
layer.w2_qweight.shape[1] * self.quant_config.pack_factor, layer.w2_qweight.shape[2], + self.quant_config.weight_bits, ) replace_tensor(layer, "w2_qweight", marlin_w2_qweight) # Repack scales From 90241c4e4aa7532cf0f29211cf7aa3c700974843 Mon Sep 17 00:00:00 2001 From: DhruvaBansal00 Date: Mon, 12 Aug 2024 17:30:28 -0700 Subject: [PATCH 36/96] *4 --- vllm/_custom_ops.py | 2 +- vllm/model_executor/layers/quantization/gptq_marlin.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 048ab9195d24e..1458103fa1b6d 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -283,7 +283,7 @@ def gptq_marlin_moe_repack(b_q_weight: torch.Tensor, perm: torch.Tensor, size_k: int, size_n: int, num_bits: int) -> torch.Tensor: num_experts = b_q_weight.shape[0] - output = torch.empty((num_experts, size_k // 16, size_n * 2), + output = torch.empty((num_experts, size_k // 16, size_n * 4), device=b_q_weight.device, dtype=b_q_weight.dtype) for e in range(num_experts): diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py index 6133a4e172f4c..ca756d652f928 100644 --- a/vllm/model_executor/layers/quantization/gptq_marlin.py +++ b/vllm/model_executor/layers/quantization/gptq_marlin.py @@ -545,7 +545,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: layer.w13_qweight, layer.w13_g_idx_sort_indices, layer.w13_qweight.shape[1] * self.quant_config.pack_factor, - layer.w13_qweight.shape[2] * 2, + layer.w13_qweight.shape[2], self.quant_config.weight_bits, ) replace_tensor(layer, "w13_qweight", marlin_w13_qweight) From 67409e93237025e6db6027934711da1da87fd878 Mon Sep 17 00:00:00 2001 From: DhruvaBansal00 Date: Mon, 12 Aug 2024 17:38:25 -0700 Subject: [PATCH 37/96] intermediate size --- vllm/model_executor/layers/quantization/gptq_marlin.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py index ca756d652f928..70b31a066740b 100644 --- a/vllm/model_executor/layers/quantization/gptq_marlin.py +++ b/vllm/model_executor/layers/quantization/gptq_marlin.py @@ -560,8 +560,8 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: # Repack scales marlin_w13_scales = marlin_moe_permute_scales( s=layer.w13_scales, - size_k=(layer.input_size if self.quant_config.desc_act else - layer.input_size_per_partition), + size_k=(layer.intermediate_size if self.quant_config.desc_act else + layer.intermediate_size_per_partition), size_n=layer.w13_scales.shape[2], group_size=self.quant_config.group_size, ) @@ -569,8 +569,8 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: marlin_w2_scales = marlin_moe_permute_scales( s=layer.w2_scales, size_k=layer.w2_scales.shape[1] * self.quant_config.pack_factor, - size_n=(layer.input_size if self.quant_config.desc_act else - layer.input_size_per_partition), + size_k=(layer.intermediate_size if self.quant_config.desc_act else + layer.intermediate_size_per_partition), group_size=self.quant_config.group_size, ) replace_tensor(layer, "w2_scales", marlin_w2_scales) From 539032ef58cfad750d93365f1e91c64464505081 Mon Sep 17 00:00:00 2001 From: DhruvaBansal00 Date: Mon, 12 Aug 2024 17:38:59 -0700 Subject: [PATCH 38/96] repeat keyword --- vllm/model_executor/layers/quantization/gptq_marlin.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py index 70b31a066740b..9adeac6f96417 100644 --- a/vllm/model_executor/layers/quantization/gptq_marlin.py +++ b/vllm/model_executor/layers/quantization/gptq_marlin.py @@ -569,7 +569,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: marlin_w2_scales = marlin_moe_permute_scales( s=layer.w2_scales, size_k=layer.w2_scales.shape[1] * self.quant_config.pack_factor, - size_k=(layer.intermediate_size if self.quant_config.desc_act else + size_n=(layer.intermediate_size if self.quant_config.desc_act else layer.intermediate_size_per_partition), group_size=self.quant_config.group_size, ) From 57b1cbe81846613b46f55fcbf7f41712a1ad4556 Mon Sep 17 00:00:00 2001 From: DhruvaBansal00 Date: Mon, 12 Aug 2024 17:41:06 -0700 Subject: [PATCH 39/96] hidden size --- vllm/model_executor/layers/quantization/gptq_marlin.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py index 9adeac6f96417..78378b271f1de 100644 --- a/vllm/model_executor/layers/quantization/gptq_marlin.py +++ b/vllm/model_executor/layers/quantization/gptq_marlin.py @@ -569,8 +569,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: marlin_w2_scales = marlin_moe_permute_scales( s=layer.w2_scales, size_k=layer.w2_scales.shape[1] * self.quant_config.pack_factor, - size_n=(layer.intermediate_size if self.quant_config.desc_act else - layer.intermediate_size_per_partition), + size_n=layer.hidden_size, group_size=self.quant_config.group_size, ) replace_tensor(layer, "w2_scales", marlin_w2_scales) From 87f1dd4cd3a2679b764105850ba60e2bbb1e1999 Mon Sep 17 00:00:00 2001 From: DhruvaBansal00 Date: Mon, 12 Aug 2024 17:50:47 -0700 Subject: [PATCH 40/96] intermediate size back --- vllm/model_executor/layers/quantization/gptq_marlin.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py index 78378b271f1de..9adeac6f96417 100644 --- a/vllm/model_executor/layers/quantization/gptq_marlin.py +++ b/vllm/model_executor/layers/quantization/gptq_marlin.py @@ -569,7 +569,8 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: marlin_w2_scales = marlin_moe_permute_scales( s=layer.w2_scales, size_k=layer.w2_scales.shape[1] * self.quant_config.pack_factor, - size_n=layer.hidden_size, + size_n=(layer.intermediate_size if self.quant_config.desc_act else + layer.intermediate_size_per_partition), group_size=self.quant_config.group_size, ) replace_tensor(layer, "w2_scales", marlin_w2_scales) From 4c073c2a99c9030c705fb9870f73e3d09560f74f Mon Sep 17 00:00:00 2001 From: DhruvaBansal00 Date: Mon, 12 Aug 2024 17:59:44 -0700 Subject: [PATCH 41/96] permute scales w3 --- vllm/model_executor/layers/quantization/gptq_marlin.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py index 9adeac6f96417..52bef66f57294 100644 --- a/vllm/model_executor/layers/quantization/gptq_marlin.py +++ b/vllm/model_executor/layers/quantization/gptq_marlin.py @@ -568,7 +568,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: replace_tensor(layer, "w13_scales", marlin_w13_scales) marlin_w2_scales = marlin_moe_permute_scales( 
s=layer.w2_scales, - size_k=layer.w2_scales.shape[1] * self.quant_config.pack_factor, + size_k=layer.w2_scales.shape[1], size_n=(layer.intermediate_size if self.quant_config.desc_act else layer.intermediate_size_per_partition), group_size=self.quant_config.group_size, From d73249346a57999c6639b689d90085f2088f1586 Mon Sep 17 00:00:00 2001 From: DhruvaBansal00 Date: Mon, 12 Aug 2024 18:00:34 -0700 Subject: [PATCH 42/96] *2 --- vllm/model_executor/layers/quantization/gptq_marlin.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py index 52bef66f57294..617628d268f11 100644 --- a/vllm/model_executor/layers/quantization/gptq_marlin.py +++ b/vllm/model_executor/layers/quantization/gptq_marlin.py @@ -568,7 +568,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: replace_tensor(layer, "w13_scales", marlin_w13_scales) marlin_w2_scales = marlin_moe_permute_scales( s=layer.w2_scales, - size_k=layer.w2_scales.shape[1], + size_k=2 * layer.w2_scales.shape[1] * self.quant_config.pack_factor, size_n=(layer.intermediate_size if self.quant_config.desc_act else layer.intermediate_size_per_partition), group_size=self.quant_config.group_size, From fdc22c4ef1a3bc8afce41c25ad8ec19acc4bcac7 Mon Sep 17 00:00:00 2001 From: DhruvaBansal00 Date: Mon, 12 Aug 2024 18:06:12 -0700 Subject: [PATCH 43/96] log --- vllm/model_executor/layers/quantization/gptq_marlin.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py index 617628d268f11..fe2a8c7d2f6ca 100644 --- a/vllm/model_executor/layers/quantization/gptq_marlin.py +++ b/vllm/model_executor/layers/quantization/gptq_marlin.py @@ -566,9 +566,10 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: group_size=self.quant_config.group_size, ) replace_tensor(layer, "w13_scales", marlin_w13_scales) + logger.error(f"{layer.w2_scales.size()}, {layer.intermediate_size_per_partition}, {self.quant_config.group_size}") marlin_w2_scales = marlin_moe_permute_scales( s=layer.w2_scales, - size_k=2 * layer.w2_scales.shape[1] * self.quant_config.pack_factor, + size_k=layer.w2_scales.shape[1] * self.quant_config.pack_factor, size_n=(layer.intermediate_size if self.quant_config.desc_act else layer.intermediate_size_per_partition), group_size=self.quant_config.group_size, From 272822eafbec55b0f39415770f009177339c7986 Mon Sep 17 00:00:00 2001 From: DhruvaBansal00 Date: Mon, 12 Aug 2024 18:10:05 -0700 Subject: [PATCH 44/96] shape as 2 --- vllm/model_executor/layers/quantization/gptq_marlin.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py index fe2a8c7d2f6ca..911fbb154fce6 100644 --- a/vllm/model_executor/layers/quantization/gptq_marlin.py +++ b/vllm/model_executor/layers/quantization/gptq_marlin.py @@ -570,8 +570,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: marlin_w2_scales = marlin_moe_permute_scales( s=layer.w2_scales, size_k=layer.w2_scales.shape[1] * self.quant_config.pack_factor, - size_n=(layer.intermediate_size if self.quant_config.desc_act else - layer.intermediate_size_per_partition), + size_n=layer.w2_scales.shape[2], group_size=self.quant_config.group_size, ) replace_tensor(layer, "w2_scales", marlin_w2_scales) From 
3ce045e5c19fd562d7f69d5a0ae06fc422317017 Mon Sep 17 00:00:00 2001 From: DhruvaBansal00 Date: Mon, 12 Aug 2024 18:21:08 -0700 Subject: [PATCH 45/96] test --- vllm/_custom_ops.py | 2 +- .../layers/quantization/gptq_marlin.py | 16 ++++++++-------- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 1458103fa1b6d..048ab9195d24e 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -283,7 +283,7 @@ def gptq_marlin_moe_repack(b_q_weight: torch.Tensor, perm: torch.Tensor, size_k: int, size_n: int, num_bits: int) -> torch.Tensor: num_experts = b_q_weight.shape[0] - output = torch.empty((num_experts, size_k // 16, size_n * 4), + output = torch.empty((num_experts, size_k // 16, size_n * 2), device=b_q_weight.device, dtype=b_q_weight.dtype) for e in range(num_experts): diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py index 911fbb154fce6..70773505dcb2d 100644 --- a/vllm/model_executor/layers/quantization/gptq_marlin.py +++ b/vllm/model_executor/layers/quantization/gptq_marlin.py @@ -541,14 +541,14 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: requires_grad=False, ) # Repack weights - marlin_w13_qweight = ops.gptq_marlin_moe_repack( - layer.w13_qweight, - layer.w13_g_idx_sort_indices, - layer.w13_qweight.shape[1] * self.quant_config.pack_factor, - layer.w13_qweight.shape[2], - self.quant_config.weight_bits, - ) - replace_tensor(layer, "w13_qweight", marlin_w13_qweight) + # marlin_w13_qweight = ops.gptq_marlin_moe_repack( + # layer.w13_qweight, + # layer.w13_g_idx_sort_indices, + # layer.w13_qweight.shape[1] * self.quant_config.pack_factor, + # layer.w13_qweight.shape[2], + # self.quant_config.weight_bits, + # ) + # replace_tensor(layer, "w13_qweight", marlin_w13_qweight) marlin_w2_qweight = ops.gptq_marlin_moe_repack( layer.w2_qweight, layer.w2_g_idx_sort_indices, From c4ba4779bd44f147ea8b745106524f45b3db32f1 Mon Sep 17 00:00:00 2001 From: DhruvaBansal00 Date: Mon, 12 Aug 2024 23:17:10 -0700 Subject: [PATCH 46/96] Increasing to 4 and changing assert --- vllm/_custom_ops.py | 2 +- .../layers/fused_moe/fused_moe_marlin.py | 2 +- .../layers/quantization/gptq_marlin.py | 16 ++++++++-------- 3 files changed, 10 insertions(+), 10 deletions(-) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 048ab9195d24e..1458103fa1b6d 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -283,7 +283,7 @@ def gptq_marlin_moe_repack(b_q_weight: torch.Tensor, perm: torch.Tensor, size_k: int, size_n: int, num_bits: int) -> torch.Tensor: num_experts = b_q_weight.shape[0] - output = torch.empty((num_experts, size_k // 16, size_n * 2), + output = torch.empty((num_experts, size_k // 16, size_n * 4), device=b_q_weight.device, dtype=b_q_weight.dtype) for e in range(num_experts): diff --git a/vllm/model_executor/layers/fused_moe/fused_moe_marlin.py b/vllm/model_executor/layers/fused_moe/fused_moe_marlin.py index d84126568d726..75e02af9d77af 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe_marlin.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe_marlin.py @@ -56,7 +56,7 @@ def fused_moe_marlin( assert hidden_states.shape[ 1] == w1.shape[1] * 16, "Hidden size mismatch w1" assert hidden_states.shape[ - 1] == w2.shape[2] // 2, "Hidden size mismatch w2" + 1] == w2.shape[2] // 4, "Hidden size mismatch w2" assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch" assert hidden_states.is_contiguous(), "Hidden_states must 
be contiguous" assert w1.is_contiguous(), "Expert weights1 must be contiguous" diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py index 70773505dcb2d..911fbb154fce6 100644 --- a/vllm/model_executor/layers/quantization/gptq_marlin.py +++ b/vllm/model_executor/layers/quantization/gptq_marlin.py @@ -541,14 +541,14 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: requires_grad=False, ) # Repack weights - # marlin_w13_qweight = ops.gptq_marlin_moe_repack( - # layer.w13_qweight, - # layer.w13_g_idx_sort_indices, - # layer.w13_qweight.shape[1] * self.quant_config.pack_factor, - # layer.w13_qweight.shape[2], - # self.quant_config.weight_bits, - # ) - # replace_tensor(layer, "w13_qweight", marlin_w13_qweight) + marlin_w13_qweight = ops.gptq_marlin_moe_repack( + layer.w13_qweight, + layer.w13_g_idx_sort_indices, + layer.w13_qweight.shape[1] * self.quant_config.pack_factor, + layer.w13_qweight.shape[2], + self.quant_config.weight_bits, + ) + replace_tensor(layer, "w13_qweight", marlin_w13_qweight) marlin_w2_qweight = ops.gptq_marlin_moe_repack( layer.w2_qweight, layer.w2_g_idx_sort_indices, From 2ea8370e136d58c085b703f779f6f029649f8fe0 Mon Sep 17 00:00:00 2001 From: DhruvaBansal00 Date: Tue, 13 Aug 2024 12:00:07 -0700 Subject: [PATCH 47/96] logging --- vllm/model_executor/layers/quantization/gptq_marlin.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py index 911fbb154fce6..c07d7b10ef193 100644 --- a/vllm/model_executor/layers/quantization/gptq_marlin.py +++ b/vllm/model_executor/layers/quantization/gptq_marlin.py @@ -540,6 +540,8 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: device=device), requires_grad=False, ) + logger.error(f"W13 qweight size - {layer.w13_qweight.size()}") + logger.error(f"Quant Config: {self.quant_config}") # Repack weights marlin_w13_qweight = ops.gptq_marlin_moe_repack( layer.w13_qweight, From 8287025224cda4529d624a8c2aceebaecec4ddd6 Mon Sep 17 00:00:00 2001 From: DhruvaBansal00 Date: Tue, 13 Aug 2024 12:32:13 -0700 Subject: [PATCH 48/96] marlin moe repack change --- vllm/_custom_ops.py | 8 +++++++- vllm/model_executor/layers/fused_moe/fused_moe_marlin.py | 2 +- vllm/model_executor/layers/quantization/gptq_marlin.py | 8 ++++---- 3 files changed, 12 insertions(+), 6 deletions(-) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 1458103fa1b6d..37787aa6d3b2f 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -283,7 +283,13 @@ def gptq_marlin_moe_repack(b_q_weight: torch.Tensor, perm: torch.Tensor, size_k: int, size_n: int, num_bits: int) -> torch.Tensor: num_experts = b_q_weight.shape[0] - output = torch.empty((num_experts, size_k // 16, size_n * 4), + # output = torch.empty((num_experts, size_k // 16, size_n * 2), + # device=b_q_weight.device, + # dtype=b_q_weight.dtype) + # for e in range(num_experts): + # output[e] = torch.ops._C.gptq_marlin_repack(b_q_weight[e], perm[e], + # size_k, size_n, num_bits) + output = torch.empty((num_experts, size_k, size_n), device=b_q_weight.device, dtype=b_q_weight.dtype) for e in range(num_experts): diff --git a/vllm/model_executor/layers/fused_moe/fused_moe_marlin.py b/vllm/model_executor/layers/fused_moe/fused_moe_marlin.py index 75e02af9d77af..d84126568d726 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe_marlin.py +++ 
b/vllm/model_executor/layers/fused_moe/fused_moe_marlin.py @@ -56,7 +56,7 @@ def fused_moe_marlin( assert hidden_states.shape[ 1] == w1.shape[1] * 16, "Hidden size mismatch w1" assert hidden_states.shape[ - 1] == w2.shape[2] // 4, "Hidden size mismatch w2" + 1] == w2.shape[2] // 2, "Hidden size mismatch w2" assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch" assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" assert w1.is_contiguous(), "Expert weights1 must be contiguous" diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py index c07d7b10ef193..ac5e018193a97 100644 --- a/vllm/model_executor/layers/quantization/gptq_marlin.py +++ b/vllm/model_executor/layers/quantization/gptq_marlin.py @@ -491,7 +491,6 @@ def create_weights( layer.marlin_state = GPTQMarlinState.REPACK def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - layer.marlin_state = GPTQMarlinState.READY # Process act_order if self.quant_config.desc_act: @@ -546,7 +545,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: marlin_w13_qweight = ops.gptq_marlin_moe_repack( layer.w13_qweight, layer.w13_g_idx_sort_indices, - layer.w13_qweight.shape[1] * self.quant_config.pack_factor, + layer.w13_qweight.shape[1], layer.w13_qweight.shape[2], self.quant_config.weight_bits, ) @@ -554,7 +553,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: marlin_w2_qweight = ops.gptq_marlin_moe_repack( layer.w2_qweight, layer.w2_g_idx_sort_indices, - layer.w2_qweight.shape[1] * self.quant_config.pack_factor, + layer.w2_qweight.shape[1], layer.w2_qweight.shape[2], self.quant_config.weight_bits, ) @@ -568,7 +567,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: group_size=self.quant_config.group_size, ) replace_tensor(layer, "w13_scales", marlin_w13_scales) - logger.error(f"{layer.w2_scales.size()}, {layer.intermediate_size_per_partition}, {self.quant_config.group_size}") + # logger.error(f"{layer.w2_scales.size()}, {layer.intermediate_size_per_partition}, {self.quant_config.group_size}") marlin_w2_scales = marlin_moe_permute_scales( s=layer.w2_scales, size_k=layer.w2_scales.shape[1] * self.quant_config.pack_factor, @@ -576,6 +575,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: group_size=self.quant_config.group_size, ) replace_tensor(layer, "w2_scales", marlin_w2_scales) + layer.marlin_state = GPTQMarlinState.READY def apply( self, From 53b23b9525c474b655dd0fe1d0c6787a72bec60e Mon Sep 17 00:00:00 2001 From: DhruvaBansal00 Date: Tue, 13 Aug 2024 12:42:59 -0700 Subject: [PATCH 49/96] mult qweight shape by pack factor --- vllm/model_executor/layers/quantization/gptq_marlin.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py index ac5e018193a97..7b376a79004b2 100644 --- a/vllm/model_executor/layers/quantization/gptq_marlin.py +++ b/vllm/model_executor/layers/quantization/gptq_marlin.py @@ -545,7 +545,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: marlin_w13_qweight = ops.gptq_marlin_moe_repack( layer.w13_qweight, layer.w13_g_idx_sort_indices, - layer.w13_qweight.shape[1], + layer.w13_qweight.shape[1] * self.quant_config.pack_factor, layer.w13_qweight.shape[2], self.quant_config.weight_bits, ) @@ -553,7 +553,7 @@ def process_weights_after_loading(self, 
layer: torch.nn.Module) -> None: marlin_w2_qweight = ops.gptq_marlin_moe_repack( layer.w2_qweight, layer.w2_g_idx_sort_indices, - layer.w2_qweight.shape[1], + layer.w2_qweight.shape[1] * self.quant_config.pack_factor, layer.w2_qweight.shape[2], self.quant_config.weight_bits, ) From bc407861ecb5011f390507d31ed1b13318a08f67 Mon Sep 17 00:00:00 2001 From: DhruvaBansal00 Date: Tue, 13 Aug 2024 12:59:52 -0700 Subject: [PATCH 50/96] Potential support for 8 bit --- vllm/_custom_ops.py | 8 +------- vllm/model_executor/layers/fused_moe/fused_moe_marlin.py | 9 +++++---- vllm/model_executor/layers/quantization/gptq_marlin.py | 1 + 3 files changed, 7 insertions(+), 11 deletions(-) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 37787aa6d3b2f..92c846cd685f6 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -283,13 +283,7 @@ def gptq_marlin_moe_repack(b_q_weight: torch.Tensor, perm: torch.Tensor, size_k: int, size_n: int, num_bits: int) -> torch.Tensor: num_experts = b_q_weight.shape[0] - # output = torch.empty((num_experts, size_k // 16, size_n * 2), - # device=b_q_weight.device, - # dtype=b_q_weight.dtype) - # for e in range(num_experts): - # output[e] = torch.ops._C.gptq_marlin_repack(b_q_weight[e], perm[e], - # size_k, size_n, num_bits) - output = torch.empty((num_experts, size_k, size_n), + output = torch.empty((num_experts, size_k // 16, size_n * (num_bits // 2)), device=b_q_weight.device, dtype=b_q_weight.dtype) for e in range(num_experts): diff --git a/vllm/model_executor/layers/fused_moe/fused_moe_marlin.py b/vllm/model_executor/layers/fused_moe/fused_moe_marlin.py index d84126568d726..44ea7299fe447 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe_marlin.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe_marlin.py @@ -23,6 +23,7 @@ def fused_moe_marlin( use_fp8: bool = False, w1_scale: Optional[torch.Tensor] = None, w2_scale: Optional[torch.Tensor] = None, + num_bits: int = 8, ) -> torch.Tensor: """ This function computes a Mixture of Experts (MoE) layer using two sets of @@ -56,7 +57,7 @@ def fused_moe_marlin( assert hidden_states.shape[ 1] == w1.shape[1] * 16, "Hidden size mismatch w1" assert hidden_states.shape[ - 1] == w2.shape[2] // 2, "Hidden size mismatch w2" + 1] == w2.shape[2] // (num_bits // 2), "Hidden size mismatch w2" assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch" assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" assert w1.is_contiguous(), "Expert weights1 must be contiguous" @@ -86,7 +87,7 @@ def fused_moe_marlin( sorted_token_ids, _, _ = moe_align_block_size(topk_ids, block_size_m, E) - max_workspace_size = ((M + 255) // 256) * (max(2 * N, K) // 64) * 16 + max_workspace_size = ((M + 255) // 256) * (max((num_bits // 2) * N, K) // 64) * 16 workspace = torch.zeros(max_workspace_size, dtype=torch.int, device="cuda", @@ -109,7 +110,7 @@ def fused_moe_marlin( rand_perm1, workspace, M, - 2 * N, + (num_bits // 2) * N, K, True, E, @@ -119,7 +120,7 @@ def fused_moe_marlin( False, ) - ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, 2 * N)) + ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, (num_bits // 2) * N)) intermediate_cache3 = torch.ops._moe_C.marlin_gemm_moe( intermediate_cache2, diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py index 7b376a79004b2..1f178d828ec94 100644 --- a/vllm/model_executor/layers/quantization/gptq_marlin.py +++ 
b/vllm/model_executor/layers/quantization/gptq_marlin.py @@ -601,4 +601,5 @@ def apply( renormalize=renormalize, w1_scale=layer.w13_scales, w2_scale=layer.w2_scales, + num_bits=self.quant_config.weight_bits, ) From bea13de41d2b4f7db75ea67893615c04f83a10f1 Mon Sep 17 00:00:00 2001 From: DhruvaBansal00 Date: Tue, 13 Aug 2024 13:05:40 -0700 Subject: [PATCH 51/96] undo change --- vllm/model_executor/layers/fused_moe/fused_moe_marlin.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe_marlin.py b/vllm/model_executor/layers/fused_moe/fused_moe_marlin.py index 44ea7299fe447..efafcef2f1ee7 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe_marlin.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe_marlin.py @@ -87,7 +87,7 @@ def fused_moe_marlin( sorted_token_ids, _, _ = moe_align_block_size(topk_ids, block_size_m, E) - max_workspace_size = ((M + 255) // 256) * (max((num_bits // 2) * N, K) // 64) * 16 + max_workspace_size = ((M + 255) // 256) * (max(2 * N, K) // 64) * 16 workspace = torch.zeros(max_workspace_size, dtype=torch.int, device="cuda", @@ -110,7 +110,7 @@ def fused_moe_marlin( rand_perm1, workspace, M, - (num_bits // 2) * N, + 2 * N, K, True, E, From a3a9114b00109833a780e8b6f121a24a622d42b7 Mon Sep 17 00:00:00 2001 From: DhruvaBansal00 Date: Tue, 13 Aug 2024 14:22:50 -0700 Subject: [PATCH 52/96] qzeros --- vllm/model_executor/layers/fused_moe/layer.py | 11 -- vllm/model_executor/models/mixtral.py | 7 +- vllm/model_executor/models/mixtral_quant.py | 144 +++++++++++++----- 3 files changed, 112 insertions(+), 50 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 686d0415c5e5c..53c682385e840 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -399,15 +399,4 @@ def make_expert_params_mapping( shard_id, ) for expert_id in range(num_experts) for shard_id, weight_name in enumerate(gate_down_up) - ] + [ - # These are the qzeros for the experts - # (param_name, weight_name, expert_id, shard_id) - ( - "experts.w13_qzeros" - if weight_name in gate_up else "experts.w2_qzeros", - f"experts.{expert_id}.{weight_name}.qzeros", - expert_id, - shard_id, - ) for expert_id in range(num_experts) - for shard_id, weight_name in enumerate(gate_down_up) ]) diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index 011b10583c528..38b9f4ee24c0c 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -446,7 +446,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # logger.error(expert_params_mapping) for mapping in expert_params_mapping: param_name, weight_name, expert_id, shard_id = mapping - if weight_name not in name: + if weight_name not in name or ".qzeros" in name: continue # logger.error(f"{weight_name} {param_name} {name}") name = name.replace(weight_name, param_name) @@ -466,7 +466,10 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): break else: # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: + if name.endswith(".bias") and name not in params_dict: + continue + + if ".qzeros" in name: continue # Skip layers on other devices. 
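# [Editor's sketch, not part of the patch series] PATCH 50 generalises the earlier
# hard-coded "* 2" / "* 4" experiments: gptq_marlin_moe_repack now allocates
# (num_experts, size_k // 16, size_n * (num_bits // 2)), and fused_moe_marlin asserts
# hidden_size == w2.shape[2] // (num_bits // 2). A small, dependency-free sanity check
# of that shape rule, assuming only the formulas shown in the diff (the example
# dimensions are Mixtral-8x7B's hidden/intermediate sizes):

def marlin_repacked_qweight_shape(num_experts: int, size_k: int, size_n: int,
                                  num_bits: int) -> tuple:
    # GPTQ packs 32 // num_bits values per int32, so after the Marlin 16-row tiling
    # the packed width grows by num_bits // 2 (x2 for 4-bit, x4 for 8-bit).
    assert num_bits in (4, 8), "GPTQ-Marlin kernels support 4- and 8-bit weights"
    return (num_experts, size_k // 16, size_n * (num_bits // 2))

# w2 of one Mixtral layer: K = intermediate size (14336), N = hidden size (4096)
assert marlin_repacked_qweight_shape(8, 14336, 4096, 4) == (8, 896, 8192)
assert marlin_repacked_qweight_shape(8, 14336, 4096, 8) == (8, 896, 16384)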
if is_pp_missing_parameter(name, self): diff --git a/vllm/model_executor/models/mixtral_quant.py b/vllm/model_executor/models/mixtral_quant.py index cdfd24874b974..85dafd55bbcf8 100644 --- a/vllm/model_executor/models/mixtral_quant.py +++ b/vllm/model_executor/models/mixtral_quant.py @@ -21,6 +21,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only Mixtral model.""" +import re from typing import Iterable, List, Optional, Tuple import numpy as np @@ -34,6 +35,7 @@ from vllm.distributed import (get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce) +from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (QKVParallelLinear, ReplicatedLinear, @@ -94,10 +96,13 @@ class MixtralMoE(nn.Module): def __init__( self, config: MixtralConfig, + use_fused_moe: bool, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.config = config + self.use_fused_moe = use_fused_moe + self.quant_config = quant_config self.rank = get_tensor_model_parallel_rank() self.tp_size = get_tensor_model_parallel_world_size() self.num_total_experts = config.num_local_experts @@ -113,14 +118,26 @@ def __init__( raise ValueError( f"Rank {self.rank} has no experts assigned to it.") - self.experts = nn.ModuleList([ - MixtralMLP(self.num_total_experts, - config.hidden_size, - config.intermediate_size, - quant_config=quant_config) - if idx in self.expert_indicies else None - for idx in range(self.num_total_experts) - ]) + if self.use_fused_moe: + params_dtype = torch.float16 + self.experts = FusedMoE(num_experts=self.num_total_experts, + top_k=self.top_k, + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + params_dtype=params_dtype, + reduce_results=True, + renormalize=True, + quant_config=quant_config, + tp_size=self.tp_size) + else: + self.experts = nn.ModuleList([ + MixtralMLP(self.num_total_experts, + config.hidden_size, + config.intermediate_size, + quant_config=quant_config) + if idx in self.expert_indicies else None + for idx in range(self.num_total_experts) + ]) self.gate = ReplicatedLinear(config.hidden_size, self.num_total_experts, @@ -132,28 +149,34 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: hidden_states = hidden_states.view(-1, hidden_dim) router_logits, _ = self.gate(hidden_states) - routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) - routing_weights, selected_experts = torch.topk(routing_weights, - self.top_k, - dim=-1) - routing_weights /= routing_weights.sum(dim=-1, keepdim=True) - - final_hidden_states = None - for expert_idx in self.expert_indicies: - expert_layer = self.experts[expert_idx] - expert_mask = (selected_experts == expert_idx) - expert_weights = (routing_weights * expert_mask).sum(dim=-1, - keepdim=True) - - current_hidden_states = expert_layer(hidden_states).mul_( - expert_weights) - if final_hidden_states is None: - final_hidden_states = current_hidden_states - else: - final_hidden_states.add_(current_hidden_states) - - return tensor_model_parallel_all_reduce(final_hidden_states).view( - num_tokens, hidden_dim) + if self.use_fused_moe: + ret = self.experts(hidden_states.half(), router_logits) + return ret.bfloat16() + else: + routing_weights = F.softmax(router_logits, + dim=1, + dtype=torch.float) + routing_weights, selected_experts = torch.topk(routing_weights, + self.top_k, + 
dim=-1) + routing_weights /= routing_weights.sum(dim=-1, keepdim=True) + + final_hidden_states = None + for expert_idx in self.expert_indicies: + expert_layer = self.experts[expert_idx] + expert_mask = (selected_experts == expert_idx) + expert_weights = (routing_weights * expert_mask).sum( + dim=-1, keepdim=True) + + current_hidden_states = expert_layer(hidden_states).mul_( + expert_weights) + if final_hidden_states is None: + final_hidden_states = current_hidden_states + else: + final_hidden_states.add_(current_hidden_states) + + return tensor_model_parallel_all_reduce(final_hidden_states).view( + num_tokens, hidden_dim) class MixtralAttention(nn.Module): @@ -238,6 +261,7 @@ class MixtralDecoderLayer(nn.Module): def __init__( self, config: MixtralConfig, + use_fused_moe: bool, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ) -> None: @@ -254,6 +278,7 @@ def __init__( cache_config=cache_config, quant_config=quant_config) self.block_sparse_moe = MixtralMoE(config=config, + use_fused_moe=use_fused_moe, quant_config=quant_config) self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -294,6 +319,7 @@ class MixtralModel(nn.Module): def __init__( self, config: MixtralConfig, + use_fused_moe: bool, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ) -> None: @@ -307,6 +333,7 @@ def __init__( ) self.layers = nn.ModuleList([ MixtralDecoderLayer(config, + use_fused_moe, cache_config, quant_config=quant_config) for _ in range(config.num_hidden_layers) @@ -343,10 +370,12 @@ def __init__( super().__init__() # TODO check runs with dtype=float16 + self.use_fused_moe = (config.torch_dtype != torch.float8_e4m3fn) self.config = config self.quant_config = quant_config - self.model = MixtralModel(config, cache_config, quant_config) + self.model = MixtralModel(config, self.use_fused_moe, cache_config, + quant_config) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size, quant_config=quant_config) @@ -407,9 +436,50 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): if name.endswith(".bias") and name not in params_dict: continue - if ("block_sparse_moe.experts." in name - and name not in params_dict): - continue - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, loaded_weight) + if self.use_fused_moe: + if ("block_sparse_moe.experts." in name + and ".w1." not in name and ".w2." not in name + and ".w3." not in name + and name not in params_dict): + continue + + if (".qzeros" in name): + continue + + shard_id = None + expert_id = 0 + + has_any_numbered = (".qweight" in name or ".scales" in name + or ".g_idx" in name) + if (has_any_numbered and (".w1." in name)): + name = name.replace(".w1.", ".w13_") + shard_id = 0 + if (has_any_numbered and (".w2." in name)): + name = name.replace(".w2.", ".w2_") + shard_id = 0 + if (has_any_numbered and (".w3." in name)): + name = name.replace(".w3.", ".w13_") + shard_id = 1 + + exp_string = re.search(r"\.experts\.\d+.", name) + if exp_string: + exp_string = exp_string.group(0) + expert_id = int(exp_string.split(".")[2]) + name = name.replace(exp_string, ".experts.") + + else: + if ("block_sparse_moe.experts." 
in name + and name not in params_dict): + continue + + param = params_dict[name] + + if self.use_fused_moe and shard_id is not None: + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight, name, shard_id, + expert_id, True) + else: + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) From eb916f9584f7056e01872e8219fa3399c356640e Mon Sep 17 00:00:00 2001 From: DhruvaBansal00 Date: Tue, 13 Aug 2024 14:25:12 -0700 Subject: [PATCH 53/96] switching traffic to mixtral quant --- vllm/model_executor/layers/quantization/gptq_marlin.py | 4 ++-- vllm/model_executor/models/__init__.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py index 1f178d828ec94..90762efef8108 100644 --- a/vllm/model_executor/layers/quantization/gptq_marlin.py +++ b/vllm/model_executor/layers/quantization/gptq_marlin.py @@ -539,8 +539,8 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: device=device), requires_grad=False, ) - logger.error(f"W13 qweight size - {layer.w13_qweight.size()}") - logger.error(f"Quant Config: {self.quant_config}") + # logger.error(f"W13 qweight size - {layer.w13_qweight.size()}") + # logger.error(f"Quant Config: {self.quant_config}") # Repack weights marlin_w13_qweight = ops.gptq_marlin_moe_repack( layer.w13_qweight, diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py index 329df4830af41..94c3cea98be7b 100644 --- a/vllm/model_executor/models/__init__.py +++ b/vllm/model_executor/models/__init__.py @@ -48,7 +48,7 @@ "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"), "MistralForCausalLM": ("llama", "LlamaForCausalLM"), "MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"), - "QuantMixtralForCausalLM": ("mixtral", "MixtralForCausalLM"), + "QuantMixtralForCausalLM": ("mixtral_quant", "MixtralForCausalLM"), # transformers's mpt class has lower case "MptForCausalLM": ("mpt", "MPTForCausalLM"), "MPTForCausalLM": ("mpt", "MPTForCausalLM"), From 017d6f80f1d3078ca1e705c948d2804ac9db0e94 Mon Sep 17 00:00:00 2001 From: DhruvaBansal00 Date: Tue, 13 Aug 2024 14:28:16 -0700 Subject: [PATCH 54/96] compat --- vllm/model_executor/layers/fused_moe/layer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 53c682385e840..0b06ee86a308d 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -240,11 +240,11 @@ def weight_loader( shard_size = loaded_weight.size()[-1] if shard_id == 0: param_data[expert_id, :, :shard_size] = loaded_weight - elif shard_id == 2: + elif shard_id == 2 or shard_id == 1: param_data[expert_id, :, shard_size:] = loaded_weight else: raise ValueError(f"Invalid shard_id: {shard_id}: " - "must be 0 or 2.") + "must be 0, 1, or 2.") elif "w2" in weight_name: param_data[expert_id][:] = loaded_weight else: From eb9c0870afc85ad1db73f52d215164cf9a388cb0 Mon Sep 17 00:00:00 2001 From: DhruvaBansal00 Date: Tue, 13 Aug 2024 14:42:31 -0700 Subject: [PATCH 55/96] Passing intermediate tensor into mixtral in quant file --- vllm/model_executor/models/mixtral.py | 2 +- vllm/model_executor/models/mixtral_quant.py | 16 +++++++++++++++- 2 files changed, 16 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/models/mixtral.py 
b/vllm/model_executor/models/mixtral.py index 38b9f4ee24c0c..a70b31f82c070 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -341,7 +341,7 @@ def __init__( self.config = config self.lora_config = lora_config - + self.quant_config = quant_config self.model = MixtralModel(config, cache_config, quant_config, diff --git a/vllm/model_executor/models/mixtral_quant.py b/vllm/model_executor/models/mixtral_quant.py index 85dafd55bbcf8..9c8e2b395992f 100644 --- a/vllm/model_executor/models/mixtral_quant.py +++ b/vllm/model_executor/models/mixtral_quant.py @@ -391,8 +391,22 @@ def forward( intermediate_tensors: Optional[IntermediateTensors] = None, ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata) + attn_metadata, intermediate_tensors) return hidden_states + + def make_empty_intermediate_tensors( + self, batch_size: int, dtype: torch.dtype, + device: torch.device) -> IntermediateTensors: + return IntermediateTensors({ + "hidden_states": + torch.zeros((batch_size, self.config.hidden_size), + dtype=dtype, + device=device), + "residual": + torch.zeros((batch_size, self.config.hidden_size), + dtype=dtype, + device=device), + }) def compute_logits(self, hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata) -> torch.Tensor: From ea3cf18c457305bc4b257b3dadf8047eea607f90 Mon Sep 17 00:00:00 2001 From: DhruvaBansal00 Date: Tue, 13 Aug 2024 14:45:06 -0700 Subject: [PATCH 56/96] Removing intemediate tensors from forward --- vllm/model_executor/models/mixtral.py | 23 ++++++--------------- vllm/model_executor/models/mixtral_quant.py | 2 +- 2 files changed, 7 insertions(+), 18 deletions(-) diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index a70b31f82c070..0c729d96d5707 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -283,25 +283,14 @@ def forward( positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, - intermediate_tensors: Optional[IntermediateTensors], ) -> torch.Tensor: - if get_pp_group().is_first_rank: - hidden_states = self.embed_tokens(input_ids) - residual = None - else: - assert intermediate_tensors is not None - hidden_states = intermediate_tensors["hidden_states"] - residual = intermediate_tensors["residual"] - for i in range(self.start_layer, self.end_layer): + hidden_states = self.embed_tokens(input_ids) + residual = None + for i in range(len(self.layers)): layer = self.layers[i] hidden_states, residual = layer(positions, hidden_states, - kv_caches[i - self.start_layer], - attn_metadata, residual) - if not get_pp_group().is_last_rank: - return IntermediateTensors({ - "hidden_states": hidden_states, - "residual": residual - }) + kv_caches[i], attn_metadata, + residual) hidden_states, _ = self.norm(hidden_states, residual) return hidden_states @@ -373,7 +362,7 @@ def forward( intermediate_tensors: Optional[IntermediateTensors] = None, ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors) + attn_metadata) return hidden_states def compute_logits(self, hidden_states: torch.Tensor, diff --git a/vllm/model_executor/models/mixtral_quant.py b/vllm/model_executor/models/mixtral_quant.py index 9c8e2b395992f..a379c3e0f1bb8 100644 --- a/vllm/model_executor/models/mixtral_quant.py +++ b/vllm/model_executor/models/mixtral_quant.py @@ -391,7 +391,7 @@ def forward( intermediate_tensors: Optional[IntermediateTensors] = 
None, ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors) + attn_metadata) return hidden_states def make_empty_intermediate_tensors( From 4f6b4caaf07c3b5af3e8cbc8a359fde77aee02b1 Mon Sep 17 00:00:00 2001 From: DhruvaBansal00 Date: Tue, 13 Aug 2024 14:50:24 -0700 Subject: [PATCH 57/96] load weights from quant --- vllm/model_executor/models/mixtral.py | 115 +++++++++++--------- vllm/model_executor/models/mixtral_quant.py | 14 --- 2 files changed, 61 insertions(+), 68 deletions(-) diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index 0c729d96d5707..fbe8a3530116f 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -283,14 +283,25 @@ def forward( positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors], ) -> torch.Tensor: - hidden_states = self.embed_tokens(input_ids) - residual = None - for i in range(len(self.layers)): + if get_pp_group().is_first_rank: + hidden_states = self.embed_tokens(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + for i in range(self.start_layer, self.end_layer): layer = self.layers[i] hidden_states, residual = layer(positions, hidden_states, - kv_caches[i], attn_metadata, - residual) + kv_caches[i - self.start_layer], + attn_metadata, residual) + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) hidden_states, _ = self.norm(hidden_states, residual) return hidden_states @@ -362,7 +373,7 @@ def forward( intermediate_tensors: Optional[IntermediateTensors] = None, ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata) + attn_metadata, intermediate_tensors) return hidden_states def compute_logits(self, hidden_states: torch.Tensor, @@ -401,19 +412,10 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): ("qkv_proj", "v_proj", "v"), ] - # Params for weights, fp8 weight scales, fp8 activation scales - # (param_name, weight_name, expert_id, shard_id) - expert_params_mapping = FusedMoE.make_expert_params_mapping( - ckpt_gate_proj_name="w1", - ckpt_down_proj_name="w2", - ckpt_up_proj_name="w3", - num_experts=self.config.num_local_experts) - params_dict = dict(self.named_parameters()) for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue - for (param_name, weight_name, shard_id) in stacked_params_mapping: if weight_name not in name: continue @@ -421,54 +423,59 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue - # Skip layers on other devices. - if is_pp_missing_parameter(name, self): - continue - param = params_dict[name] weight_loader = param.weight_loader - weight_loader(param, - loaded_weight, - shard_id) + weight_loader(param, loaded_weight, shard_id) break else: - # logger.error(expert_params_mapping) - for mapping in expert_params_mapping: - param_name, weight_name, expert_id, shard_id = mapping - if weight_name not in name or ".qzeros" in name: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + + if self.use_fused_moe: + if ("block_sparse_moe.experts." 
in name + and ".w1." not in name and ".w2." not in name + and ".w3." not in name + and name not in params_dict): continue - # logger.error(f"{weight_name} {param_name} {name}") - name = name.replace(weight_name, param_name) - # logger.error(f"Loading {name} from {weight_name}") - # Skip layers on other devices. - if is_pp_missing_parameter(name, self): + + if (".qzeros" in name): continue - # logger.error(params_dict.keys()) - param = params_dict[name] - weight_loader = param.weight_loader - weight_loader(param, - loaded_weight, - param_name, - shard_id=shard_id, - expert_id=expert_id, - is_quantized=True) - break + + shard_id = None + expert_id = 0 + + has_any_numbered = (".qweight" in name or ".scales" in name + or ".g_idx" in name) + if (has_any_numbered and (".w1." in name)): + name = name.replace(".w1.", ".w13_") + shard_id = 0 + if (has_any_numbered and (".w2." in name)): + name = name.replace(".w2.", ".w2_") + shard_id = 0 + if (has_any_numbered and (".w3." in name)): + name = name.replace(".w3.", ".w13_") + shard_id = 1 + + exp_string = re.search(r"\.experts\.\d+.", name) + if exp_string: + exp_string = exp_string.group(0) + expert_id = int(exp_string.split(".")[2]) + name = name.replace(exp_string, ".experts.") + else: - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - - if ".qzeros" in name: - continue - # Skip layers on other devices. - if is_pp_missing_parameter(name, self): - continue - # Remapping the name of FP8 kv-scale. - name = maybe_remap_kv_scale_name(name, params_dict) - if name is None: + if ("block_sparse_moe.experts." in name + and name not in params_dict): continue - param = params_dict[name] + param = params_dict[name] + + if self.use_fused_moe and shard_id is not None: + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight, name, shard_id, + expert_id, True) + else: weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) diff --git a/vllm/model_executor/models/mixtral_quant.py b/vllm/model_executor/models/mixtral_quant.py index a379c3e0f1bb8..85dafd55bbcf8 100644 --- a/vllm/model_executor/models/mixtral_quant.py +++ b/vllm/model_executor/models/mixtral_quant.py @@ -393,20 +393,6 @@ def forward( hidden_states = self.model(input_ids, positions, kv_caches, attn_metadata) return hidden_states - - def make_empty_intermediate_tensors( - self, batch_size: int, dtype: torch.dtype, - device: torch.device) -> IntermediateTensors: - return IntermediateTensors({ - "hidden_states": - torch.zeros((batch_size, self.config.hidden_size), - dtype=dtype, - device=device), - "residual": - torch.zeros((batch_size, self.config.hidden_size), - dtype=dtype, - device=device), - }) def compute_logits(self, hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata) -> torch.Tensor: From 7ec27d9722c404b1cd3c7872ba7e92c9a722ff96 Mon Sep 17 00:00:00 2001 From: DhruvaBansal00 Date: Tue, 13 Aug 2024 14:53:09 -0700 Subject: [PATCH 58/96] Mixtral load weights change: --- vllm/model_executor/models/__init__.py | 2 +- vllm/model_executor/models/mixtral.py | 79 +++++++++++--------------- 2 files changed, 35 insertions(+), 46 deletions(-) diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py index 94c3cea98be7b..329df4830af41 100644 --- a/vllm/model_executor/models/__init__.py +++ b/vllm/model_executor/models/__init__.py @@ -48,7 +48,7 @@ "LLaMAForCausalLM": ("llama", 
"LlamaForCausalLM"), "MistralForCausalLM": ("llama", "LlamaForCausalLM"), "MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"), - "QuantMixtralForCausalLM": ("mixtral_quant", "MixtralForCausalLM"), + "QuantMixtralForCausalLM": ("mixtral", "MixtralForCausalLM"), # transformers's mpt class has lower case "MptForCausalLM": ("mpt", "MPTForCausalLM"), "MPTForCausalLM": ("mpt", "MPTForCausalLM"), diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index fbe8a3530116f..1b59b1f7c2852 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -22,7 +22,7 @@ # limitations under the License. """Inference-only Mixtral model.""" from typing import Iterable, List, Optional, Tuple - +import re import torch from torch import nn from transformers import MixtralConfig @@ -432,50 +432,39 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): if name.endswith(".bias") and name not in params_dict: continue - if self.use_fused_moe: - if ("block_sparse_moe.experts." in name - and ".w1." not in name and ".w2." not in name - and ".w3." not in name - and name not in params_dict): - continue - - if (".qzeros" in name): - continue - - shard_id = None - expert_id = 0 - - has_any_numbered = (".qweight" in name or ".scales" in name - or ".g_idx" in name) - if (has_any_numbered and (".w1." in name)): - name = name.replace(".w1.", ".w13_") - shard_id = 0 - if (has_any_numbered and (".w2." in name)): - name = name.replace(".w2.", ".w2_") - shard_id = 0 - if (has_any_numbered and (".w3." in name)): - name = name.replace(".w3.", ".w13_") - shard_id = 1 - - exp_string = re.search(r"\.experts\.\d+.", name) - if exp_string: - exp_string = exp_string.group(0) - expert_id = int(exp_string.split(".")[2]) - name = name.replace(exp_string, ".experts.") - - else: - if ("block_sparse_moe.experts." in name - and name not in params_dict): - continue + if ("block_sparse_moe.experts." in name + and ".w1." not in name and ".w2." not in name + and ".w3." not in name + and name not in params_dict): + continue + + if (".qzeros" in name): + continue + + shard_id = None + expert_id = 0 + + has_any_numbered = (".qweight" in name or ".scales" in name + or ".g_idx" in name) + if (has_any_numbered and (".w1." in name)): + name = name.replace(".w1.", ".w13_") + shard_id = 0 + if (has_any_numbered and (".w2." in name)): + name = name.replace(".w2.", ".w2_") + shard_id = 0 + if (has_any_numbered and (".w3." 
in name)): + name = name.replace(".w3.", ".w13_") + shard_id = 1 + + exp_string = re.search(r"\.experts\.\d+.", name) + if exp_string: + exp_string = exp_string.group(0) + expert_id = int(exp_string.split(".")[2]) + name = name.replace(exp_string, ".experts.") param = params_dict[name] - if self.use_fused_moe and shard_id is not None: - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, loaded_weight, name, shard_id, - expert_id, True) - else: - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, loaded_weight) + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight, name, shard_id, + expert_id, True) From aa1fe77b71954434747ac8a27dcbb03446fd4a8a Mon Sep 17 00:00:00 2001 From: DhruvaBansal00 Date: Tue, 13 Aug 2024 14:54:44 -0700 Subject: [PATCH 59/96] none shard id change --- vllm/model_executor/models/mixtral.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index 1b59b1f7c2852..18d377922ef03 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -464,7 +464,12 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, loaded_weight, name, shard_id, - expert_id, True) + if shard_id is not None: + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight, name, shard_id, + expert_id, True) + else: + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) From ae8fb1542b9cac9c32a610a28f606f63da38ab4a Mon Sep 17 00:00:00 2001 From: DhruvaBansal00 Date: Thu, 15 Aug 2024 10:37:12 -0700 Subject: [PATCH 60/96] Use class from mixtral_quant --- vllm/model_executor/models/__init__.py | 2 +- vllm/model_executor/models/mixtral.py | 403 +++++++++++++++++-------- 2 files changed, 279 insertions(+), 126 deletions(-) diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py index 329df4830af41..72a13d13eb0d6 100644 --- a/vllm/model_executor/models/__init__.py +++ b/vllm/model_executor/models/__init__.py @@ -48,7 +48,7 @@ "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"), "MistralForCausalLM": ("llama", "LlamaForCausalLM"), "MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"), - "QuantMixtralForCausalLM": ("mixtral", "MixtralForCausalLM"), + "QuantMixtralForCausalLM": ("mixtral", "QuantizedMixtralForCausalLM"), # transformers's mpt class has lower case "MptForCausalLM": ("mpt", "MPTForCausalLM"), "MPTForCausalLM": ("mpt", "MPTForCausalLM"), diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index 18d377922ef03..05bbabe3e9278 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -32,26 +32,31 @@ from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (QKVParallelLinear, - ReplicatedLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear, +) from vllm.model_executor.layers.logits_processor 
import LogitsProcessor -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) +from vllm.model_executor.layers.quantization.base_config import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) + DEFAULT_VOCAB_PADDING_SIZE, + ParallelLMHead, + VocabParallelEmbedding, +) from vllm.model_executor.model_loader.weight_utils import ( - default_weight_loader, maybe_remap_kv_scale_name) + default_weight_loader, + maybe_remap_kv_scale_name, +) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors, SamplerOutput from .interfaces import SupportsLoRA from .utils import is_pp_missing_parameter, make_layers -import logging -logger = logging.getLogger(__name__) + class MixtralMoE(nn.Module): """A tensor-parallel MoE implementation for Mixtral that shards each expert across all ranks. @@ -61,36 +66,42 @@ class MixtralMoE(nn.Module): across ranks. """ - def __init__(self, - num_experts: int, - top_k: int, - hidden_size: int, - intermediate_size: int, - params_dtype: Optional[torch.dtype] = None, - quant_config: Optional[QuantizationConfig] = None, - tp_size: Optional[int] = None, - prefix: str = ""): + def __init__( + self, + num_experts: int, + top_k: int, + hidden_size: int, + intermediate_size: int, + params_dtype: Optional[torch.dtype] = None, + quant_config: Optional[QuantizationConfig] = None, + tp_size: Optional[int] = None, + prefix: str = "", + ): super().__init__() self.hidden_size = hidden_size # Gate always runs at half / full precision for now. - self.gate = ReplicatedLinear(hidden_size, - num_experts, - bias=False, - params_dtype=params_dtype, - quant_config=None, - prefix=f"{prefix}.gate") - - self.experts = FusedMoE(num_experts=num_experts, - top_k=top_k, - hidden_size=hidden_size, - intermediate_size=intermediate_size, - params_dtype=params_dtype, - reduce_results=True, - renormalize=True, - quant_config=quant_config, - tp_size=tp_size, - prefix=f"{prefix}.experts") + self.gate = ReplicatedLinear( + hidden_size, + num_experts, + bias=False, + params_dtype=params_dtype, + quant_config=None, + prefix=f"{prefix}.gate", + ) + + self.experts = FusedMoE( + num_experts=num_experts, + top_k=top_k, + hidden_size=hidden_size, + intermediate_size=intermediate_size, + params_dtype=params_dtype, + reduce_results=True, + renormalize=True, + quant_config=quant_config, + tp_size=tp_size, + prefix=f"{prefix}.experts", + ) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: # NOTE: hidden_states can have either 1D or 2D shape. 
@@ -103,7 +114,6 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: class MixtralAttention(nn.Module): - def __init__( self, hidden_size: int, @@ -160,12 +170,14 @@ def __init__( base=int(self.rope_theta), is_neox_style=True, ) - self.attn = Attention(self.num_heads, - self.head_dim, - self.scaling, - num_kv_heads=self.num_kv_heads, - cache_config=cache_config, - quant_config=quant_config) + self.attn = Attention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + ) def forward( self, @@ -183,7 +195,6 @@ def forward( class MixtralDecoderLayer(nn.Module): - def __init__( self, config: MixtralConfig, @@ -203,18 +214,20 @@ def __init__( rope_theta=rope_theta, cache_config=cache_config, quant_config=quant_config, - prefix=f"{prefix}.self_attn") + prefix=f"{prefix}.self_attn", + ) self.block_sparse_moe = MixtralMoE( num_experts=config.num_local_experts, top_k=config.num_experts_per_tok, hidden_size=config.hidden_size, intermediate_size=config.intermediate_size, quant_config=quant_config, - prefix=f"{prefix}.block_sparse_moe") - self.input_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.post_attention_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) + prefix=f"{prefix}.block_sparse_moe", + ) + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) def forward( self, @@ -229,8 +242,7 @@ def forward( residual = hidden_states hidden_states = self.input_layernorm(hidden_states) else: - hidden_states, residual = self.input_layernorm( - hidden_states, residual) + hidden_states, residual = self.input_layernorm(hidden_states, residual) hidden_states = self.self_attn( positions=positions, hidden_states=hidden_states, @@ -239,14 +251,12 @@ def forward( ) # Fully Connected - hidden_states, residual = self.post_attention_layernorm( - hidden_states, residual) + hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) hidden_states = self.block_sparse_moe(hidden_states) return hidden_states, residual class MixtralModel(nn.Module): - def __init__( self, config: MixtralConfig, @@ -257,8 +267,11 @@ def __init__( ) -> None: super().__init__() self.padding_idx = config.pad_token_id - lora_vocab = (lora_config.lora_extra_vocab_size * - (lora_config.max_loras or 1)) if lora_config else 0 + lora_vocab = ( + (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1)) + if lora_config + else 0 + ) self.vocab_size = config.vocab_size + lora_vocab self.org_vocab_size = config.vocab_size @@ -273,7 +286,8 @@ def __init__( lambda prefix: MixtralDecoderLayer( config, cache_config, quant_config=quant_config, prefix=prefix ), - prefix=f"{prefix}.layers") + prefix=f"{prefix}.layers", + ) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -294,14 +308,17 @@ def forward( residual = intermediate_tensors["residual"] for i in range(self.start_layer, self.end_layer): layer = self.layers[i] - hidden_states, residual = layer(positions, hidden_states, - kv_caches[i - self.start_layer], - attn_metadata, residual) + hidden_states, residual = layer( + positions, + hidden_states, + kv_caches[i - self.start_layer], + attn_metadata, + residual, + ) if not get_pp_group().is_last_rank: - return IntermediateTensors({ - "hidden_states": hidden_states, - "residual": residual - }) + return IntermediateTensors( + {"hidden_states": 
hidden_states, "residual": residual} + ) hidden_states, _ = self.norm(hidden_states, residual) return hidden_states @@ -341,12 +358,10 @@ def __init__( self.config = config self.lora_config = lora_config - self.quant_config = quant_config - self.model = MixtralModel(config, - cache_config, - quant_config, - lora_config=lora_config, - prefix="model") + + self.model = MixtralModel( + config, cache_config, quant_config, lora_config=lora_config, prefix="model" + ) self.unpadded_vocab_size = config.vocab_size if lora_config: self.unpadded_vocab_size += lora_config.lora_extra_vocab_size @@ -360,8 +375,9 @@ def __init__( if not lora_config else lora_config.lora_vocab_padding_size, quant_config=quant_config, ) - self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, - config.vocab_size) + self.logits_processor = LogitsProcessor( + self.unpadded_vocab_size, config.vocab_size + ) self.sampler = Sampler() def forward( @@ -372,29 +388,30 @@ def forward( attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, ) -> torch.Tensor: - hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors) + hidden_states = self.model( + input_ids, positions, kv_caches, attn_metadata, intermediate_tensors + ) return hidden_states - def compute_logits(self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata) -> torch.Tensor: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) + def compute_logits( + self, hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata + ) -> torch.Tensor: + logits = self.logits_processor(self.lm_head, hidden_states, sampling_metadata) return logits def make_empty_intermediate_tensors( - self, batch_size: int, dtype: torch.dtype, - device: torch.device) -> IntermediateTensors: - return IntermediateTensors({ - "hidden_states": - torch.zeros((batch_size, self.config.hidden_size), - dtype=dtype, - device=device), - "residual": - torch.zeros((batch_size, self.config.hidden_size), - dtype=dtype, - device=device), - }) + self, batch_size: int, dtype: torch.dtype, device: torch.device + ) -> IntermediateTensors: + return IntermediateTensors( + { + "hidden_states": torch.zeros( + (batch_size, self.config.hidden_size), dtype=dtype, device=device + ), + "residual": torch.zeros( + (batch_size, self.config.hidden_size), dtype=dtype, device=device + ), + } + ) def sample( self, @@ -412,11 +429,137 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): ("qkv_proj", "v_proj", "v"), ] + # Params for weights, fp8 weight scales, fp8 activation scales + # (param_name, weight_name, expert_id, shard_id) + expert_params_mapping = FusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="w1", + ckpt_down_proj_name="w2", + ckpt_up_proj_name="w3", + num_experts=self.config.num_local_experts, + ) + params_dict = dict(self.named_parameters()) for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue - for (param_name, weight_name, shard_id) in stacked_params_mapping: + + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + # Skip layers on other devices. 
+ if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id, is_quantized=True) + break + else: + for mapping in expert_params_mapping: + param_name, weight_name, expert_id, shard_id = mapping + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip layers on other devices. + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader( + param, + loaded_weight, + weight_name, + shard_id=shard_id, + expert_id=expert_id, + is_quantized=True, + ) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + # Skip layers on other devices. + if is_pp_missing_parameter(name, self): + continue + # Remapping the name of FP8 kv-scale. + name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + + param = params_dict[name] + weight_loader = getattr( + param, "weight_loader", default_weight_loader + ) + weight_loader(param, loaded_weight) + + +class QuantizedMixtralForCausalLM(nn.Module): + fall_back_to_pt_during_load = False + + def __init__( + self, + config: MixtralConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + + # TODO check runs with dtype=float16 + self.use_fused_moe = config.torch_dtype != torch.float8_e4m3fn + + self.config = config + self.quant_config = quant_config + self.model = MixtralModel( + config, self.use_fused_moe, cache_config, quant_config + ) + self.lm_head = ParallelLMHead( + config.vocab_size, config.hidden_size, quant_config=quant_config + ) + self.logits_processor = LogitsProcessor(config.vocab_size) + self.sampler = Sampler() + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + ) -> torch.Tensor: + hidden_states = self.model(input_ids, positions, kv_caches, attn_metadata) + return hidden_states + + def compute_logits( + self, hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata + ) -> torch.Tensor: + logits = self.logits_processor(self.lm_head, hidden_states, sampling_metadata) + return logits + + def sample( + self, + logits: Optional[torch.Tensor], + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ] + + params_dict = dict(self.named_parameters()) + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue name = name.replace(weight_name, param_name) @@ -432,44 +575,54 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): if name.endswith(".bias") and name not in params_dict: continue - if ("block_sparse_moe.experts." in name - and ".w1." not in name and ".w2." not in name + if self.use_fused_moe: + if ( + "block_sparse_moe.experts." in name + and ".w1." not in name + and ".w2." not in name and ".w3." 
not in name - and name not in params_dict): - continue + and name not in params_dict + ): + continue + + if ".qzeros" in name: + continue + + shard_id = None + expert_id = 0 + + has_any_numbered = ( + ".qweight" in name or ".scales" in name or ".g_idx" in name + ) + if has_any_numbered and (".w1." in name): + name = name.replace(".w1.", ".w13_") + shard_id = 0 + if has_any_numbered and (".w2." in name): + name = name.replace(".w2.", ".w2_") + shard_id = 0 + if has_any_numbered and (".w3." in name): + name = name.replace(".w3.", ".w13_") + shard_id = 1 + + exp_string = re.search(r"\.experts\.\d+.", name) + if exp_string: + exp_string = exp_string.group(0) + expert_id = int(exp_string.split(".")[2]) + name = name.replace(exp_string, ".experts.") - if (".qzeros" in name): - continue - - shard_id = None - expert_id = 0 - - has_any_numbered = (".qweight" in name or ".scales" in name - or ".g_idx" in name) - if (has_any_numbered and (".w1." in name)): - name = name.replace(".w1.", ".w13_") - shard_id = 0 - if (has_any_numbered and (".w2." in name)): - name = name.replace(".w2.", ".w2_") - shard_id = 0 - if (has_any_numbered and (".w3." in name)): - name = name.replace(".w3.", ".w13_") - shard_id = 1 - - exp_string = re.search(r"\.experts\.\d+.", name) - if exp_string: - exp_string = exp_string.group(0) - expert_id = int(exp_string.split(".")[2]) - name = name.replace(exp_string, ".experts.") + else: + if "block_sparse_moe.experts." in name and name not in params_dict: + continue param = params_dict[name] - if shard_id is not None: - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, loaded_weight, name, shard_id, - expert_id, True) + if self.use_fused_moe and shard_id is not None: + weight_loader = getattr( + param, "weight_loader", default_weight_loader + ) + weight_loader(param, loaded_weight, name, shard_id, expert_id, True) else: - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr( + param, "weight_loader", default_weight_loader + ) weight_loader(param, loaded_weight) From b863981837d61a6d465e28c0c46e2304cb47333b Mon Sep 17 00:00:00 2001 From: DhruvaBansal00 Date: Thu, 15 Aug 2024 10:39:27 -0700 Subject: [PATCH 61/96] Removing lora from mixtral model init --- vllm/model_executor/models/mixtral.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index 05bbabe3e9278..652ad21d12ab4 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -514,7 +514,7 @@ def __init__( self.config = config self.quant_config = quant_config self.model = MixtralModel( - config, self.use_fused_moe, cache_config, quant_config + config, cache_config, quant_config, None, prefix="model" ) self.lm_head = ParallelLMHead( config.vocab_size, config.hidden_size, quant_config=quant_config From 5556d284abbaa7faacd627ea32b433ad5b702f5b Mon Sep 17 00:00:00 2001 From: DhruvaBansal00 Date: Thu, 15 Aug 2024 10:41:14 -0700 Subject: [PATCH 62/96] Adding empty intermediate tensors --- vllm/model_executor/models/mixtral.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index 652ad21d12ab4..fa74ec94b9644 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -530,7 +530,7 @@ def forward( attn_metadata: AttentionMetadata, intermediate_tensors: 
Optional[IntermediateTensors] = None, ) -> torch.Tensor: - hidden_states = self.model(input_ids, positions, kv_caches, attn_metadata) + hidden_states = self.model(input_ids, positions, kv_caches, attn_metadata, intermediate_tensors) return hidden_states def compute_logits( @@ -539,6 +539,20 @@ def compute_logits( logits = self.logits_processor(self.lm_head, hidden_states, sampling_metadata) return logits + def make_empty_intermediate_tensors( + self, batch_size: int, dtype: torch.dtype, device: torch.device + ) -> IntermediateTensors: + return IntermediateTensors( + { + "hidden_states": torch.zeros( + (batch_size, self.config.hidden_size), dtype=dtype, device=device + ), + "residual": torch.zeros( + (batch_size, self.config.hidden_size), dtype=dtype, device=device + ), + } + ) + def sample( self, logits: Optional[torch.Tensor], From c484a3766a79bcbd0f7ed2e5e2f63efde620f0c0 Mon Sep 17 00:00:00 2001 From: DhruvaBansal00 Date: Thu, 15 Aug 2024 11:13:10 -0700 Subject: [PATCH 63/96] Building quantMixtralModel --- vllm/model_executor/models/mixtral.py | 76 +++++++++++++++++++++------ 1 file changed, 60 insertions(+), 16 deletions(-) diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index fa74ec94b9644..43fadf22396b5 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -323,6 +323,50 @@ def forward( return hidden_states +class QuantMixtralModel(nn.Module): + def __init__( + self, + config: MixtralConfig, + use_fused_moe: bool, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + ) + self.layers = nn.ModuleList( + [ + MixtralDecoderLayer( + config, use_fused_moe, cache_config, quant_config=quant_config + ) + for _ in range(config.num_hidden_layers) + ] + ) + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + hidden_states = self.embed_tokens(input_ids) + residual = None + for i in range(len(self.layers)): + layer = self.layers[i] + hidden_states, residual = layer( + positions, hidden_states, kv_caches[i], attn_metadata, residual + ) + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + class MixtralForCausalLM(nn.Module, SupportsLoRA): fall_back_to_pt_during_load = False @@ -513,8 +557,8 @@ def __init__( self.config = config self.quant_config = quant_config - self.model = MixtralModel( - config, cache_config, quant_config, None, prefix="model" + self.model = QuantMixtralModel( + config, self.use_fused_moe, cache_config, quant_config ) self.lm_head = ParallelLMHead( config.vocab_size, config.hidden_size, quant_config=quant_config @@ -530,7 +574,7 @@ def forward( attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, ) -> torch.Tensor: - hidden_states = self.model(input_ids, positions, kv_caches, attn_metadata, intermediate_tensors) + hidden_states = self.model(input_ids, positions, kv_caches, attn_metadata) return hidden_states def compute_logits( @@ -539,19 +583,19 @@ def compute_logits( logits = self.logits_processor(self.lm_head, hidden_states, sampling_metadata) return logits - def make_empty_intermediate_tensors( 
- self, batch_size: int, dtype: torch.dtype, device: torch.device - ) -> IntermediateTensors: - return IntermediateTensors( - { - "hidden_states": torch.zeros( - (batch_size, self.config.hidden_size), dtype=dtype, device=device - ), - "residual": torch.zeros( - (batch_size, self.config.hidden_size), dtype=dtype, device=device - ), - } - ) + # def make_empty_intermediate_tensors( + # self, batch_size: int, dtype: torch.dtype, device: torch.device + # ) -> IntermediateTensors: + # return IntermediateTensors( + # { + # "hidden_states": torch.zeros( + # (batch_size, self.config.hidden_size), dtype=dtype, device=device + # ), + # "residual": torch.zeros( + # (batch_size, self.config.hidden_size), dtype=dtype, device=device + # ), + # } + # ) def sample( self, From 0344e72750fea4e1916bfbeed57c0db11f51fff8 Mon Sep 17 00:00:00 2001 From: DhruvaBansal00 Date: Thu, 15 Aug 2024 11:22:15 -0700 Subject: [PATCH 64/96] fused moe test --- vllm/model_executor/models/mixtral_quant.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/models/mixtral_quant.py b/vllm/model_executor/models/mixtral_quant.py index 85dafd55bbcf8..bf3f15d072eff 100644 --- a/vllm/model_executor/models/mixtral_quant.py +++ b/vllm/model_executor/models/mixtral_quant.py @@ -51,7 +51,8 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors, SamplerOutput - +import logging +logger = logging.getLogger(__name__) class MixtralMLP(nn.Module): def __init__( @@ -371,7 +372,7 @@ def __init__( # TODO check runs with dtype=float16 self.use_fused_moe = (config.torch_dtype != torch.float8_e4m3fn) - + logger.error(f"Using fused MoE: {self.use_fused_moe}") self.config = config self.quant_config = quant_config self.model = MixtralModel(config, self.use_fused_moe, cache_config, From 8c8b3fa774e8eed90f23d896afabfe9fb0c81ab7 Mon Sep 17 00:00:00 2001 From: DhruvaBansal00 Date: Thu, 15 Aug 2024 11:26:26 -0700 Subject: [PATCH 65/96] Lora enabled mixtral --- vllm/model_executor/models/__init__.py | 2 +- vllm/model_executor/models/mixtral_quant.py | 177 +++++++++++++++++++- 2 files changed, 176 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py index 72a13d13eb0d6..2954d8874c9a1 100644 --- a/vllm/model_executor/models/__init__.py +++ b/vllm/model_executor/models/__init__.py @@ -48,7 +48,7 @@ "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"), "MistralForCausalLM": ("llama", "LlamaForCausalLM"), "MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"), - "QuantMixtralForCausalLM": ("mixtral", "QuantizedMixtralForCausalLM"), + "QuantMixtralForCausalLM": ("mixtral_quant", "LoRAEnabledMixtralForCausalLM"), # transformers's mpt class has lower case "MptForCausalLM": ("mpt", "MPTForCausalLM"), "MPTForCausalLM": ("mpt", "MPTForCausalLM"), diff --git a/vllm/model_executor/models/mixtral_quant.py b/vllm/model_executor/models/mixtral_quant.py index bf3f15d072eff..1801d1b5e0290 100644 --- a/vllm/model_executor/models/mixtral_quant.py +++ b/vllm/model_executor/models/mixtral_quant.py @@ -30,8 +30,9 @@ from torch import nn from transformers import MixtralConfig +from .interfaces import SupportsLoRA from vllm.attention import Attention, AttentionMetadata -from vllm.config import CacheConfig +from vllm.config import CacheConfig, LoRAConfig from vllm.distributed import (get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce) @@ -46,7 +47,7 @@ from 
vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( - ParallelLMHead, VocabParallelEmbedding) + ParallelLMHead, VocabParallelEmbedding, DEFAULT_VOCAB_PADDING_SIZE) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors, SamplerOutput @@ -484,3 +485,175 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) +class LoRAEnabledMixtralForCausalLM(nn.Module, SupportsLoRA): + fall_back_to_pt_during_load = False + + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + } + + # LoRA specific attributes + supported_lora_modules = [ + "qkv_proj", + "o_proj", + "embed_tokens", + "lm_head", + ] + embedding_modules = { + "embed_tokens": "input_embeddings", + "lm_head": "output_embeddings", + } + embedding_padding_modules = ["lm_head"] + + def __init__( + self, + config: MixtralConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + lora_config: Optional[LoRAConfig] = None, + ) -> None: + super().__init__() + + self.config = config + self.lora_config = lora_config + + self.model = MixtralModel( + config, cache_config, quant_config, lora_config=lora_config, prefix="model" + ) + self.unpadded_vocab_size = config.vocab_size + if lora_config: + self.unpadded_vocab_size += lora_config.lora_extra_vocab_size + self.lm_head = ParallelLMHead( + self.unpadded_vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + padding_size=DEFAULT_VOCAB_PADDING_SIZE + # We need bigger padding if using lora for kernel + # compatibility + if not lora_config else lora_config.lora_vocab_padding_size, + quant_config=quant_config, + ) + self.logits_processor = LogitsProcessor( + self.unpadded_vocab_size, config.vocab_size + ) + self.sampler = Sampler() + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + ) -> torch.Tensor: + hidden_states = self.model( + input_ids, positions, kv_caches, attn_metadata, intermediate_tensors + ) + return hidden_states + + def compute_logits( + self, hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata + ) -> torch.Tensor: + logits = self.logits_processor(self.lm_head, hidden_states, sampling_metadata) + return logits + + def make_empty_intermediate_tensors( + self, batch_size: int, dtype: torch.dtype, device: torch.device + ) -> IntermediateTensors: + return IntermediateTensors( + { + "hidden_states": torch.zeros( + (batch_size, self.config.hidden_size), dtype=dtype, device=device + ), + "residual": torch.zeros( + (batch_size, self.config.hidden_size), dtype=dtype, device=device + ), + } + ) + + def sample( + self, + logits: Optional[torch.Tensor], + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ] + + params_dict = 
dict(self.named_parameters()) + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + + if self.use_fused_moe: + if ("block_sparse_moe.experts." in name + and ".w1." not in name and ".w2." not in name + and ".w3." not in name + and name not in params_dict): + continue + + if (".qzeros" in name): + continue + + shard_id = None + expert_id = 0 + + has_any_numbered = (".qweight" in name or ".scales" in name + or ".g_idx" in name) + if (has_any_numbered and (".w1." in name)): + name = name.replace(".w1.", ".w13_") + shard_id = 0 + if (has_any_numbered and (".w2." in name)): + name = name.replace(".w2.", ".w2_") + shard_id = 0 + if (has_any_numbered and (".w3." in name)): + name = name.replace(".w3.", ".w13_") + shard_id = 1 + + exp_string = re.search(r"\.experts\.\d+.", name) + if exp_string: + exp_string = exp_string.group(0) + expert_id = int(exp_string.split(".")[2]) + name = name.replace(exp_string, ".experts.") + + else: + if ("block_sparse_moe.experts." in name + and name not in params_dict): + continue + + param = params_dict[name] + + if self.use_fused_moe and shard_id is not None: + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight, name, shard_id, + expert_id, True) + else: + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) \ No newline at end of file From dff59cdbcec1dc975dd1c0809ee4e1929c1e5cf5 Mon Sep 17 00:00:00 2001 From: DhruvaBansal00 Date: Thu, 15 Aug 2024 11:28:38 -0700 Subject: [PATCH 66/96] LoRAMixtralModel compat --- vllm/model_executor/models/mixtral_quant.py | 74 ++++++++++++++++++++- 1 file changed, 72 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/models/mixtral_quant.py b/vllm/model_executor/models/mixtral_quant.py index 1801d1b5e0290..ae5559d4297d9 100644 --- a/vllm/model_executor/models/mixtral_quant.py +++ b/vllm/model_executor/models/mixtral_quant.py @@ -31,9 +31,11 @@ from transformers import MixtralConfig from .interfaces import SupportsLoRA +from .utils import is_pp_missing_parameter, make_layers + from vllm.attention import Attention, AttentionMetadata from vllm.config import CacheConfig, LoRAConfig -from vllm.distributed import (get_tensor_model_parallel_rank, +from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce) from vllm.model_executor.layers.fused_moe import FusedMoE @@ -360,6 +362,74 @@ def forward( return hidden_states + +class LoRAMixtralModel(nn.Module): + def __init__( + self, + config: MixtralConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + lora_config: Optional[LoRAConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.padding_idx = config.pad_token_id + lora_vocab = ( + (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1)) + if lora_config + else 0 + ) + self.vocab_size = config.vocab_size + 
lora_vocab + self.org_vocab_size = config.vocab_size + + self.embed_tokens = VocabParallelEmbedding( + self.vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + ) + + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: MixtralDecoderLayer( + config, cache_config, quant_config=quant_config, prefix=prefix + ), + prefix=f"{prefix}.layers", + ) + + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors], + ) -> torch.Tensor: + if get_pp_group().is_first_rank: + hidden_states = self.embed_tokens(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + for i in range(self.start_layer, self.end_layer): + layer = self.layers[i] + hidden_states, residual = layer( + positions, + hidden_states, + kv_caches[i - self.start_layer], + attn_metadata, + residual, + ) + if not get_pp_group().is_last_rank: + return IntermediateTensors( + {"hidden_states": hidden_states, "residual": residual} + ) + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + class MixtralForCausalLM(nn.Module): fall_back_to_pt_during_load = False @@ -521,7 +591,7 @@ def __init__( self.config = config self.lora_config = lora_config - self.model = MixtralModel( + self.model = LoRAMixtralModel( config, cache_config, quant_config, lora_config=lora_config, prefix="model" ) self.unpadded_vocab_size = config.vocab_size From 33f7e515a50041ab943f66ee88687c8c2de7d673 Mon Sep 17 00:00:00 2001 From: DhruvaBansal00 Date: Thu, 15 Aug 2024 11:30:17 -0700 Subject: [PATCH 67/96] remove prefix --- vllm/model_executor/models/mixtral_quant.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/models/mixtral_quant.py b/vllm/model_executor/models/mixtral_quant.py index ae5559d4297d9..08049770d5e6b 100644 --- a/vllm/model_executor/models/mixtral_quant.py +++ b/vllm/model_executor/models/mixtral_quant.py @@ -391,7 +391,7 @@ def __init__( self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, lambda prefix: MixtralDecoderLayer( - config, cache_config, quant_config=quant_config, prefix=prefix + config, cache_config, quant_config=quant_config ), prefix=f"{prefix}.layers", ) From fdba91766bae9f6d234709f6433aeecfbc07f737 Mon Sep 17 00:00:00 2001 From: DhruvaBansal00 Date: Thu, 15 Aug 2024 11:31:11 -0700 Subject: [PATCH 68/96] use fused moe --- vllm/model_executor/models/mixtral_quant.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/models/mixtral_quant.py b/vllm/model_executor/models/mixtral_quant.py index 08049770d5e6b..056ba105531d9 100644 --- a/vllm/model_executor/models/mixtral_quant.py +++ b/vllm/model_executor/models/mixtral_quant.py @@ -590,7 +590,7 @@ def __init__( self.config = config self.lora_config = lora_config - + self.use_fused_moe = (config.torch_dtype != torch.float8_e4m3fn) self.model = LoRAMixtralModel( config, cache_config, quant_config, lora_config=lora_config, prefix="model" ) From 780471ebf1263109923fffe134f7c22bf3bc53f7 Mon Sep 17 00:00:00 2001 From: DhruvaBansal00 Date: Thu, 15 Aug 2024 12:25:15 -0700 Subject: [PATCH 69/96] remove org num embeddings --- vllm/model_executor/models/mixtral_quant.py | 
3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/vllm/model_executor/models/mixtral_quant.py b/vllm/model_executor/models/mixtral_quant.py index 056ba105531d9..3374530086499 100644 --- a/vllm/model_executor/models/mixtral_quant.py +++ b/vllm/model_executor/models/mixtral_quant.py @@ -600,7 +600,6 @@ def __init__( self.lm_head = ParallelLMHead( self.unpadded_vocab_size, config.hidden_size, - org_num_embeddings=config.vocab_size, padding_size=DEFAULT_VOCAB_PADDING_SIZE # We need bigger padding if using lora for kernel # compatibility @@ -608,7 +607,7 @@ def __init__( quant_config=quant_config, ) self.logits_processor = LogitsProcessor( - self.unpadded_vocab_size, config.vocab_size + self.unpadded_vocab_size ) self.sampler = Sampler() From c0970f1d30a11eef359f8c7210492e06381215cf Mon Sep 17 00:00:00 2001 From: DhruvaBansal00 Date: Thu, 15 Aug 2024 12:29:43 -0700 Subject: [PATCH 70/96] pass use fused moe into decoder --- vllm/model_executor/models/mixtral_quant.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/models/mixtral_quant.py b/vllm/model_executor/models/mixtral_quant.py index 3374530086499..c0837802ce318 100644 --- a/vllm/model_executor/models/mixtral_quant.py +++ b/vllm/model_executor/models/mixtral_quant.py @@ -391,7 +391,7 @@ def __init__( self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, lambda prefix: MixtralDecoderLayer( - config, cache_config, quant_config=quant_config + config, use_fused_moe=True, cache_config=cache_config, quant_config=quant_config ), prefix=f"{prefix}.layers", ) @@ -592,7 +592,7 @@ def __init__( self.lora_config = lora_config self.use_fused_moe = (config.torch_dtype != torch.float8_e4m3fn) self.model = LoRAMixtralModel( - config, cache_config, quant_config, lora_config=lora_config, prefix="model" + config=config, cache_config=cache_config, quant_config=quant_config, lora_config=lora_config, prefix="model" ) self.unpadded_vocab_size = config.vocab_size if lora_config: @@ -600,6 +600,7 @@ def __init__( self.lm_head = ParallelLMHead( self.unpadded_vocab_size, config.hidden_size, + org_num_embeddings=config.vocab_size, padding_size=DEFAULT_VOCAB_PADDING_SIZE # We need bigger padding if using lora for kernel # compatibility @@ -607,7 +608,7 @@ def __init__( quant_config=quant_config, ) self.logits_processor = LogitsProcessor( - self.unpadded_vocab_size + self.unpadded_vocab_size, config.vocab_size ) self.sampler = Sampler() From 6a1a8387346868fe1be6df24cb61648ac49ff6ec Mon Sep 17 00:00:00 2001 From: DhruvaBansal00 Date: Thu, 15 Aug 2024 13:15:06 -0700 Subject: [PATCH 71/96] Mixtral for causal lm load func --- vllm/model_executor/models/__init__.py | 2 +- vllm/model_executor/models/mixtral.py | 169 ++++++++++++++++++------- 2 files changed, 123 insertions(+), 48 deletions(-) diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py index 2954d8874c9a1..329df4830af41 100644 --- a/vllm/model_executor/models/__init__.py +++ b/vllm/model_executor/models/__init__.py @@ -48,7 +48,7 @@ "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"), "MistralForCausalLM": ("llama", "LlamaForCausalLM"), "MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"), - "QuantMixtralForCausalLM": ("mixtral_quant", "LoRAEnabledMixtralForCausalLM"), + "QuantMixtralForCausalLM": ("mixtral", "MixtralForCausalLM"), # transformers's mpt class has lower case "MptForCausalLM": ("mpt", "MPTForCausalLM"), "MPTForCausalLM": ("mpt", "MPTForCausalLM"), diff --git 
a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index 43fadf22396b5..e5f12d2eb8648 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -402,7 +402,7 @@ def __init__( self.config = config self.lora_config = lora_config - + self.use_fused_moe = (config.torch_dtype != torch.float8_e4m3fn) self.model = MixtralModel( config, cache_config, quant_config, lora_config=lora_config, prefix="model" ) @@ -465,6 +465,80 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens + # def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + # stacked_params_mapping = [ + # # (param_name, shard_name, shard_id) + # ("qkv_proj", "q_proj", "q"), + # ("qkv_proj", "k_proj", "k"), + # ("qkv_proj", "v_proj", "v"), + # ] + + # # Params for weights, fp8 weight scales, fp8 activation scales + # # (param_name, weight_name, expert_id, shard_id) + # expert_params_mapping = FusedMoE.make_expert_params_mapping( + # ckpt_gate_proj_name="w1", + # ckpt_down_proj_name="w2", + # ckpt_up_proj_name="w3", + # num_experts=self.config.num_local_experts, + # ) + + # params_dict = dict(self.named_parameters()) + # for name, loaded_weight in weights: + # if "rotary_emb.inv_freq" in name: + # continue + + # for param_name, weight_name, shard_id in stacked_params_mapping: + # if weight_name not in name: + # continue + # name = name.replace(weight_name, param_name) + # # Skip loading extra bias for GPTQ models. + # if name.endswith(".bias") and name not in params_dict: + # continue + # # Skip layers on other devices. + # if is_pp_missing_parameter(name, self): + # continue + + # param = params_dict[name] + # weight_loader = param.weight_loader + # weight_loader(param, loaded_weight, shard_id, is_quantized=True) + # break + # else: + # for mapping in expert_params_mapping: + # param_name, weight_name, expert_id, shard_id = mapping + # if weight_name not in name: + # continue + # name = name.replace(weight_name, param_name) + # # Skip layers on other devices. + # if is_pp_missing_parameter(name, self): + # continue + # param = params_dict[name] + # weight_loader = param.weight_loader + # weight_loader( + # param, + # loaded_weight, + # weight_name, + # shard_id=shard_id, + # expert_id=expert_id, + # is_quantized=True, + # ) + # break + # else: + # # Skip loading extra bias for GPTQ models. + # if name.endswith(".bias") and name not in params_dict: + # continue + # # Skip layers on other devices. + # if is_pp_missing_parameter(name, self): + # continue + # # Remapping the name of FP8 kv-scale. 
+ # name = maybe_remap_kv_scale_name(name, params_dict) + # if name is None: + # continue + + # param = params_dict[name] + # weight_loader = getattr( + # param, "weight_loader", default_weight_loader + # ) + # weight_loader(param, loaded_weight) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ # (param_name, shard_name, shard_id) @@ -473,71 +547,72 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): ("qkv_proj", "v_proj", "v"), ] - # Params for weights, fp8 weight scales, fp8 activation scales - # (param_name, weight_name, expert_id, shard_id) - expert_params_mapping = FusedMoE.make_expert_params_mapping( - ckpt_gate_proj_name="w1", - ckpt_down_proj_name="w2", - ckpt_up_proj_name="w3", - num_experts=self.config.num_local_experts, - ) - params_dict = dict(self.named_parameters()) for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue - - for param_name, weight_name, shard_id in stacked_params_mapping: + for (param_name, weight_name, shard_id) in stacked_params_mapping: if weight_name not in name: continue name = name.replace(weight_name, param_name) # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue - # Skip layers on other devices. - if is_pp_missing_parameter(name, self): - continue - param = params_dict[name] weight_loader = param.weight_loader - weight_loader(param, loaded_weight, shard_id, is_quantized=True) + weight_loader(param, loaded_weight, shard_id) break else: - for mapping in expert_params_mapping: - param_name, weight_name, expert_id, shard_id = mapping - if weight_name not in name: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + + if self.use_fused_moe: + if ("block_sparse_moe.experts." in name + and ".w1." not in name and ".w2." not in name + and ".w3." not in name + and name not in params_dict): continue - name = name.replace(weight_name, param_name) - # Skip layers on other devices. - if is_pp_missing_parameter(name, self): + + if (".qzeros" in name): continue - param = params_dict[name] - weight_loader = param.weight_loader - weight_loader( - param, - loaded_weight, - weight_name, - shard_id=shard_id, - expert_id=expert_id, - is_quantized=True, - ) - break + + shard_id = None + expert_id = 0 + + has_any_numbered = (".qweight" in name or ".scales" in name + or ".g_idx" in name) + if (has_any_numbered and (".w1." in name)): + name = name.replace(".w1.", ".w13_") + shard_id = 0 + if (has_any_numbered and (".w2." in name)): + name = name.replace(".w2.", ".w2_") + shard_id = 0 + if (has_any_numbered and (".w3." in name)): + name = name.replace(".w3.", ".w13_") + shard_id = 1 + + exp_string = re.search(r"\.experts\.\d+.", name) + if exp_string: + exp_string = exp_string.group(0) + expert_id = int(exp_string.split(".")[2]) + name = name.replace(exp_string, ".experts.") + else: - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - # Skip layers on other devices. - if is_pp_missing_parameter(name, self): - continue - # Remapping the name of FP8 kv-scale. - name = maybe_remap_kv_scale_name(name, params_dict) - if name is None: + if ("block_sparse_moe.experts." 
in name + and name not in params_dict): continue - param = params_dict[name] - weight_loader = getattr( - param, "weight_loader", default_weight_loader - ) + param = params_dict[name] + + if self.use_fused_moe and shard_id is not None: + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight, name, shard_id, + expert_id, True) + else: + weight_loader = getattr(param, "weight_loader", + default_weight_loader) weight_loader(param, loaded_weight) From 5c3e857163c53e0c1985bec070233b8420a9a407 Mon Sep 17 00:00:00 2001 From: DhruvaBansal00 Date: Thu, 15 Aug 2024 13:27:51 -0700 Subject: [PATCH 72/96] Copying over quant mixtral --- vllm/model_executor/models/mixtral.py | 383 +++++++++++++++++--------- 1 file changed, 256 insertions(+), 127 deletions(-) diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index e5f12d2eb8648..8f577a0beb420 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -24,12 +24,14 @@ from typing import Iterable, List, Optional, Tuple import re import torch +import numpy as np +import torch.nn.functional as F from torch import nn from transformers import MixtralConfig from vllm.attention import Attention, AttentionMetadata from vllm.config import CacheConfig, LoRAConfig -from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size, get_tensor_model_parallel_rank, tensor_model_parallel_all_reduce from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import ( @@ -112,6 +114,130 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: final_hidden_states = self.experts(hidden_states, router_logits) return final_hidden_states.view(orig_shape) +class MixtralMLP(nn.Module): + + def __init__( + self, + num_experts: int, + hidden_size: int, + intermediate_size: int, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.num_experts = num_experts + self.ffn_dim = intermediate_size + self.hidden_dim = hidden_size + + self.w1 = ReplicatedLinear(self.hidden_dim, + self.ffn_dim, + bias=False, + quant_config=quant_config) + self.w2 = ReplicatedLinear(self.ffn_dim, + self.hidden_dim, + bias=False, + quant_config=quant_config) + self.w3 = ReplicatedLinear(self.hidden_dim, + self.ffn_dim, + bias=False, + quant_config=quant_config) + + # TODO: Use vllm's SiluAndMul + self.act_fn = nn.SiLU() + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + w1_out, _ = self.w1(hidden_states) + w1_out = self.act_fn(w1_out) + w3_out, _ = self.w3(hidden_states) + current_hidden_states = w1_out * w3_out + current_hidden_states, _ = self.w2(current_hidden_states) + return current_hidden_states +class QuantMixtralMoE(nn.Module): + + def __init__( + self, + config: MixtralConfig, + use_fused_moe: bool, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.config = config + self.use_fused_moe = use_fused_moe + self.quant_config = quant_config + self.rank = get_tensor_model_parallel_rank() + self.tp_size = get_tensor_model_parallel_world_size() + self.num_total_experts = config.num_local_experts + self.top_k = config.num_experts_per_tok + if self.tp_size > self.num_total_experts: + raise ValueError( + f"Tensor parallel size {self.tp_size} is greater than " + f"the number of experts 
{self.num_total_experts}.") + # Split experts equally between ranks + self.expert_indicies = np.array_split(range( + self.num_total_experts), self.tp_size)[self.rank].tolist() + if not self.expert_indicies: + raise ValueError( + f"Rank {self.rank} has no experts assigned to it.") + + if self.use_fused_moe: + params_dtype = torch.float16 + self.experts = FusedMoE(num_experts=self.num_total_experts, + top_k=self.top_k, + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + params_dtype=params_dtype, + reduce_results=True, + renormalize=True, + quant_config=quant_config, + tp_size=self.tp_size) + else: + self.experts = nn.ModuleList([ + MixtralMLP(self.num_total_experts, + config.hidden_size, + config.intermediate_size, + quant_config=quant_config) + if idx in self.expert_indicies else None + for idx in range(self.num_total_experts) + ]) + + self.gate = ReplicatedLinear(config.hidden_size, + self.num_total_experts, + bias=False, + quant_config=None) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + num_tokens, hidden_dim = hidden_states.shape + hidden_states = hidden_states.view(-1, hidden_dim) + router_logits, _ = self.gate(hidden_states) + + if self.use_fused_moe: + ret = self.experts(hidden_states.half(), router_logits) + return ret.bfloat16() + else: + routing_weights = F.softmax(router_logits, + dim=1, + dtype=torch.float) + routing_weights, selected_experts = torch.topk(routing_weights, + self.top_k, + dim=-1) + routing_weights /= routing_weights.sum(dim=-1, keepdim=True) + + final_hidden_states = None + for expert_idx in self.expert_indicies: + expert_layer = self.experts[expert_idx] + expert_mask = (selected_experts == expert_idx) + expert_weights = (routing_weights * expert_mask).sum( + dim=-1, keepdim=True) + + current_hidden_states = expert_layer(hidden_states).mul_( + expert_weights) + if final_hidden_states is None: + final_hidden_states = current_hidden_states + else: + final_hidden_states.add_(current_hidden_states) + + return tensor_model_parallel_all_reduce(final_hidden_states).view( + num_tokens, hidden_dim) + class MixtralAttention(nn.Module): def __init__( @@ -216,13 +342,16 @@ def __init__( quant_config=quant_config, prefix=f"{prefix}.self_attn", ) - self.block_sparse_moe = MixtralMoE( - num_experts=config.num_local_experts, - top_k=config.num_experts_per_tok, - hidden_size=config.hidden_size, - intermediate_size=config.intermediate_size, - quant_config=quant_config, - prefix=f"{prefix}.block_sparse_moe", + # self.block_sparse_moe = MixtralMoE( + # num_experts=config.num_local_experts, + # top_k=config.num_experts_per_tok, + # hidden_size=config.hidden_size, + # intermediate_size=config.intermediate_size, + # quant_config=quant_config, + # prefix=f"{prefix}.block_sparse_moe", + # ) + self.block_sparse_moe = QuantMixtralMoE( + config, use_fused_moe=True, quant_config=quant_config ) self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = RMSNorm( @@ -465,80 +594,6 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - # def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): - # stacked_params_mapping = [ - # # (param_name, shard_name, shard_id) - # ("qkv_proj", "q_proj", "q"), - # ("qkv_proj", "k_proj", "k"), - # ("qkv_proj", "v_proj", "v"), - # ] - - # # Params for weights, fp8 weight scales, fp8 activation scales - # # (param_name, weight_name, expert_id, shard_id) - # expert_params_mapping = 
FusedMoE.make_expert_params_mapping( - # ckpt_gate_proj_name="w1", - # ckpt_down_proj_name="w2", - # ckpt_up_proj_name="w3", - # num_experts=self.config.num_local_experts, - # ) - - # params_dict = dict(self.named_parameters()) - # for name, loaded_weight in weights: - # if "rotary_emb.inv_freq" in name: - # continue - - # for param_name, weight_name, shard_id in stacked_params_mapping: - # if weight_name not in name: - # continue - # name = name.replace(weight_name, param_name) - # # Skip loading extra bias for GPTQ models. - # if name.endswith(".bias") and name not in params_dict: - # continue - # # Skip layers on other devices. - # if is_pp_missing_parameter(name, self): - # continue - - # param = params_dict[name] - # weight_loader = param.weight_loader - # weight_loader(param, loaded_weight, shard_id, is_quantized=True) - # break - # else: - # for mapping in expert_params_mapping: - # param_name, weight_name, expert_id, shard_id = mapping - # if weight_name not in name: - # continue - # name = name.replace(weight_name, param_name) - # # Skip layers on other devices. - # if is_pp_missing_parameter(name, self): - # continue - # param = params_dict[name] - # weight_loader = param.weight_loader - # weight_loader( - # param, - # loaded_weight, - # weight_name, - # shard_id=shard_id, - # expert_id=expert_id, - # is_quantized=True, - # ) - # break - # else: - # # Skip loading extra bias for GPTQ models. - # if name.endswith(".bias") and name not in params_dict: - # continue - # # Skip layers on other devices. - # if is_pp_missing_parameter(name, self): - # continue - # # Remapping the name of FP8 kv-scale. - # name = maybe_remap_kv_scale_name(name, params_dict) - # if name is None: - # continue - - # param = params_dict[name] - # weight_loader = getattr( - # param, "weight_loader", default_weight_loader - # ) - # weight_loader(param, loaded_weight) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ # (param_name, shard_name, shard_id) @@ -547,73 +602,147 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): ("qkv_proj", "v_proj", "v"), ] + # Params for weights, fp8 weight scales, fp8 activation scales + # (param_name, weight_name, expert_id, shard_id) + expert_params_mapping = FusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="w1", + ckpt_down_proj_name="w2", + ckpt_up_proj_name="w3", + num_experts=self.config.num_local_experts, + ) + params_dict = dict(self.named_parameters()) for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue - for (param_name, weight_name, shard_id) in stacked_params_mapping: + + for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue name = name.replace(weight_name, param_name) # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue + # Skip layers on other devices. + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] weight_loader = param.weight_loader - weight_loader(param, loaded_weight, shard_id) + weight_loader(param, loaded_weight, shard_id, is_quantized=True) break else: - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - - if self.use_fused_moe: - if ("block_sparse_moe.experts." in name - and ".w1." not in name and ".w2." not in name - and ".w3." 
not in name - and name not in params_dict): + for mapping in expert_params_mapping: + param_name, weight_name, expert_id, shard_id = mapping + if weight_name not in name: continue - - if (".qzeros" in name): + name = name.replace(weight_name, param_name) + # Skip layers on other devices. + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader( + param, + loaded_weight, + weight_name, + shard_id=shard_id, + expert_id=expert_id, + is_quantized=True, + ) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + # Skip layers on other devices. + if is_pp_missing_parameter(name, self): + continue + # Remapping the name of FP8 kv-scale. + name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: continue - shard_id = None - expert_id = 0 + param = params_dict[name] + weight_loader = getattr( + param, "weight_loader", default_weight_loader + ) + weight_loader(param, loaded_weight) + # def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + # stacked_params_mapping = [ + # # (param_name, shard_name, shard_id) + # ("qkv_proj", "q_proj", "q"), + # ("qkv_proj", "k_proj", "k"), + # ("qkv_proj", "v_proj", "v"), + # ] - has_any_numbered = (".qweight" in name or ".scales" in name - or ".g_idx" in name) - if (has_any_numbered and (".w1." in name)): - name = name.replace(".w1.", ".w13_") - shard_id = 0 - if (has_any_numbered and (".w2." in name)): - name = name.replace(".w2.", ".w2_") - shard_id = 0 - if (has_any_numbered and (".w3." in name)): - name = name.replace(".w3.", ".w13_") - shard_id = 1 + # params_dict = dict(self.named_parameters()) + # for name, loaded_weight in weights: + # if "rotary_emb.inv_freq" in name: + # continue + # for (param_name, weight_name, shard_id) in stacked_params_mapping: + # if weight_name not in name: + # continue + # name = name.replace(weight_name, param_name) + # # Skip loading extra bias for GPTQ models. + # if name.endswith(".bias") and name not in params_dict: + # continue + # param = params_dict[name] + # weight_loader = param.weight_loader + # weight_loader(param, loaded_weight, shard_id) + # break + # else: + # # Skip loading extra bias for GPTQ models. + # if name.endswith(".bias") and name not in params_dict: + # continue - exp_string = re.search(r"\.experts\.\d+.", name) - if exp_string: - exp_string = exp_string.group(0) - expert_id = int(exp_string.split(".")[2]) - name = name.replace(exp_string, ".experts.") + # if self.use_fused_moe: + # if ("block_sparse_moe.experts." in name + # and ".w1." not in name and ".w2." not in name + # and ".w3." not in name + # and name not in params_dict): + # continue - else: - if ("block_sparse_moe.experts." in name - and name not in params_dict): - continue + # if (".qzeros" in name): + # continue - param = params_dict[name] + # shard_id = None + # expert_id = 0 + + # has_any_numbered = (".qweight" in name or ".scales" in name + # or ".g_idx" in name) + # if (has_any_numbered and (".w1." in name)): + # name = name.replace(".w1.", ".w13_") + # shard_id = 0 + # if (has_any_numbered and (".w2." in name)): + # name = name.replace(".w2.", ".w2_") + # shard_id = 0 + # if (has_any_numbered and (".w3." 
in name)): + # name = name.replace(".w3.", ".w13_") + # shard_id = 1 + + # exp_string = re.search(r"\.experts\.\d+.", name) + # if exp_string: + # exp_string = exp_string.group(0) + # expert_id = int(exp_string.split(".")[2]) + # name = name.replace(exp_string, ".experts.") - if self.use_fused_moe and shard_id is not None: - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, loaded_weight, name, shard_id, - expert_id, True) - else: - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, loaded_weight) + # else: + # if ("block_sparse_moe.experts." in name + # and name not in params_dict): + # continue + + # param = params_dict[name] + + # if self.use_fused_moe and shard_id is not None: + # weight_loader = getattr(param, "weight_loader", + # default_weight_loader) + # weight_loader(param, loaded_weight, name, shard_id, + # expert_id, True) + # else: + # weight_loader = getattr(param, "weight_loader", + # default_weight_loader) + # weight_loader(param, loaded_weight) class QuantizedMixtralForCausalLM(nn.Module): From 8d327ded68e6eb992b083af003433cc739638ae3 Mon Sep 17 00:00:00 2001 From: DhruvaBansal00 Date: Thu, 15 Aug 2024 13:31:26 -0700 Subject: [PATCH 73/96] Passing prefix --- vllm/model_executor/models/mixtral.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index 8f577a0beb420..983a74be14513 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -158,6 +158,7 @@ def __init__( config: MixtralConfig, use_fused_moe: bool, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ): super().__init__() self.config = config @@ -188,7 +189,8 @@ def __init__( reduce_results=True, renormalize=True, quant_config=quant_config, - tp_size=self.tp_size) + tp_size=self.tp_size, + prefix=f"{prefix}.experts") else: self.experts = nn.ModuleList([ MixtralMLP(self.num_total_experts, @@ -202,7 +204,8 @@ def __init__( self.gate = ReplicatedLinear(config.hidden_size, self.num_total_experts, bias=False, - quant_config=None) + quant_config=None, + prefix=f"{prefix}.gate") def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: num_tokens, hidden_dim = hidden_states.shape @@ -351,7 +354,7 @@ def __init__( # prefix=f"{prefix}.block_sparse_moe", # ) self.block_sparse_moe = QuantMixtralMoE( - config, use_fused_moe=True, quant_config=quant_config + config, use_fused_moe=True, quant_config=quant_config, prefix=f"{prefix}.block_sparse_moe", ) self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = RMSNorm( From d337aeab98b5968d6c35cea6d004caa83d1d27cf Mon Sep 17 00:00:00 2001 From: DhruvaBansal00 Date: Thu, 15 Aug 2024 13:32:33 -0700 Subject: [PATCH 74/96] Weight load --- vllm/model_executor/models/mixtral.py | 365 ++++++++++++++------------ 1 file changed, 192 insertions(+), 173 deletions(-) diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index 983a74be14513..a656782307d47 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -31,7 +31,12 @@ from vllm.attention import Attention, AttentionMetadata from vllm.config import CacheConfig, LoRAConfig -from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size, get_tensor_model_parallel_rank, tensor_model_parallel_all_reduce +from vllm.distributed import ( + 
get_pp_group, + get_tensor_model_parallel_world_size, + get_tensor_model_parallel_rank, + tensor_model_parallel_all_reduce, +) from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import ( @@ -114,8 +119,8 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: final_hidden_states = self.experts(hidden_states, router_logits) return final_hidden_states.view(orig_shape) -class MixtralMLP(nn.Module): +class MixtralMLP(nn.Module): def __init__( self, num_experts: int, @@ -128,18 +133,15 @@ def __init__( self.ffn_dim = intermediate_size self.hidden_dim = hidden_size - self.w1 = ReplicatedLinear(self.hidden_dim, - self.ffn_dim, - bias=False, - quant_config=quant_config) - self.w2 = ReplicatedLinear(self.ffn_dim, - self.hidden_dim, - bias=False, - quant_config=quant_config) - self.w3 = ReplicatedLinear(self.hidden_dim, - self.ffn_dim, - bias=False, - quant_config=quant_config) + self.w1 = ReplicatedLinear( + self.hidden_dim, self.ffn_dim, bias=False, quant_config=quant_config + ) + self.w2 = ReplicatedLinear( + self.ffn_dim, self.hidden_dim, bias=False, quant_config=quant_config + ) + self.w3 = ReplicatedLinear( + self.hidden_dim, self.ffn_dim, bias=False, quant_config=quant_config + ) # TODO: Use vllm's SiluAndMul self.act_fn = nn.SiLU() @@ -151,8 +153,9 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: current_hidden_states = w1_out * w3_out current_hidden_states, _ = self.w2(current_hidden_states) return current_hidden_states -class QuantMixtralMoE(nn.Module): + +class QuantMixtralMoE(nn.Module): def __init__( self, config: MixtralConfig, @@ -171,41 +174,51 @@ def __init__( if self.tp_size > self.num_total_experts: raise ValueError( f"Tensor parallel size {self.tp_size} is greater than " - f"the number of experts {self.num_total_experts}.") + f"the number of experts {self.num_total_experts}." 
+ ) # Split experts equally between ranks - self.expert_indicies = np.array_split(range( - self.num_total_experts), self.tp_size)[self.rank].tolist() + self.expert_indicies = np.array_split( + range(self.num_total_experts), self.tp_size + )[self.rank].tolist() if not self.expert_indicies: - raise ValueError( - f"Rank {self.rank} has no experts assigned to it.") + raise ValueError(f"Rank {self.rank} has no experts assigned to it.") if self.use_fused_moe: params_dtype = torch.float16 - self.experts = FusedMoE(num_experts=self.num_total_experts, - top_k=self.top_k, - hidden_size=config.hidden_size, - intermediate_size=config.intermediate_size, - params_dtype=params_dtype, - reduce_results=True, - renormalize=True, - quant_config=quant_config, - tp_size=self.tp_size, - prefix=f"{prefix}.experts") + self.experts = FusedMoE( + num_experts=self.num_total_experts, + top_k=self.top_k, + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + params_dtype=params_dtype, + reduce_results=True, + renormalize=True, + quant_config=quant_config, + tp_size=self.tp_size, + prefix=f"{prefix}.experts", + ) else: - self.experts = nn.ModuleList([ - MixtralMLP(self.num_total_experts, - config.hidden_size, - config.intermediate_size, - quant_config=quant_config) - if idx in self.expert_indicies else None - for idx in range(self.num_total_experts) - ]) - - self.gate = ReplicatedLinear(config.hidden_size, - self.num_total_experts, - bias=False, - quant_config=None, - prefix=f"{prefix}.gate") + self.experts = nn.ModuleList( + [ + MixtralMLP( + self.num_total_experts, + config.hidden_size, + config.intermediate_size, + quant_config=quant_config, + ) + if idx in self.expert_indicies + else None + for idx in range(self.num_total_experts) + ] + ) + + self.gate = ReplicatedLinear( + config.hidden_size, + self.num_total_experts, + bias=False, + quant_config=None, + prefix=f"{prefix}.gate", + ) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: num_tokens, hidden_dim = hidden_states.shape @@ -216,30 +229,29 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: ret = self.experts(hidden_states.half(), router_logits) return ret.bfloat16() else: - routing_weights = F.softmax(router_logits, - dim=1, - dtype=torch.float) - routing_weights, selected_experts = torch.topk(routing_weights, - self.top_k, - dim=-1) + routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) + routing_weights, selected_experts = torch.topk( + routing_weights, self.top_k, dim=-1 + ) routing_weights /= routing_weights.sum(dim=-1, keepdim=True) final_hidden_states = None for expert_idx in self.expert_indicies: expert_layer = self.experts[expert_idx] - expert_mask = (selected_experts == expert_idx) + expert_mask = selected_experts == expert_idx expert_weights = (routing_weights * expert_mask).sum( - dim=-1, keepdim=True) + dim=-1, keepdim=True + ) - current_hidden_states = expert_layer(hidden_states).mul_( - expert_weights) + current_hidden_states = expert_layer(hidden_states).mul_(expert_weights) if final_hidden_states is None: final_hidden_states = current_hidden_states else: final_hidden_states.add_(current_hidden_states) return tensor_model_parallel_all_reduce(final_hidden_states).view( - num_tokens, hidden_dim) + num_tokens, hidden_dim + ) class MixtralAttention(nn.Module): @@ -354,7 +366,10 @@ def __init__( # prefix=f"{prefix}.block_sparse_moe", # ) self.block_sparse_moe = QuantMixtralMoE( - config, use_fused_moe=True, quant_config=quant_config, prefix=f"{prefix}.block_sparse_moe", + 
config, + use_fused_moe=True, + quant_config=quant_config, + prefix=f"{prefix}.block_sparse_moe", ) self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = RMSNorm( @@ -534,7 +549,7 @@ def __init__( self.config = config self.lora_config = lora_config - self.use_fused_moe = (config.torch_dtype != torch.float8_e4m3fn) + self.use_fused_moe = config.torch_dtype != torch.float8_e4m3fn self.model = MixtralModel( config, cache_config, quant_config, lora_config=lora_config, prefix="model" ) @@ -597,6 +612,80 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens + # def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + # stacked_params_mapping = [ + # # (param_name, shard_name, shard_id) + # ("qkv_proj", "q_proj", "q"), + # ("qkv_proj", "k_proj", "k"), + # ("qkv_proj", "v_proj", "v"), + # ] + + # # Params for weights, fp8 weight scales, fp8 activation scales + # # (param_name, weight_name, expert_id, shard_id) + # expert_params_mapping = FusedMoE.make_expert_params_mapping( + # ckpt_gate_proj_name="w1", + # ckpt_down_proj_name="w2", + # ckpt_up_proj_name="w3", + # num_experts=self.config.num_local_experts, + # ) + + # params_dict = dict(self.named_parameters()) + # for name, loaded_weight in weights: + # if "rotary_emb.inv_freq" in name: + # continue + + # for param_name, weight_name, shard_id in stacked_params_mapping: + # if weight_name not in name: + # continue + # name = name.replace(weight_name, param_name) + # # Skip loading extra bias for GPTQ models. + # if name.endswith(".bias") and name not in params_dict: + # continue + # # Skip layers on other devices. + # if is_pp_missing_parameter(name, self): + # continue + + # param = params_dict[name] + # weight_loader = param.weight_loader + # weight_loader(param, loaded_weight, shard_id, is_quantized=True) + # break + # else: + # for mapping in expert_params_mapping: + # param_name, weight_name, expert_id, shard_id = mapping + # if weight_name not in name: + # continue + # name = name.replace(weight_name, param_name) + # # Skip layers on other devices. + # if is_pp_missing_parameter(name, self): + # continue + # param = params_dict[name] + # weight_loader = param.weight_loader + # weight_loader( + # param, + # loaded_weight, + # weight_name, + # shard_id=shard_id, + # expert_id=expert_id, + # is_quantized=True, + # ) + # break + # else: + # # Skip loading extra bias for GPTQ models. + # if name.endswith(".bias") and name not in params_dict: + # continue + # # Skip layers on other devices. + # if is_pp_missing_parameter(name, self): + # continue + # # Remapping the name of FP8 kv-scale. 
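
To summarize the fused path set up above: use_fused_moe is keyed off the checkpoint dtype (FP8 checkpoints keep the unfused experts), the gate stays unquantized, and activations are bridged between dtypes around the fused GPTQ kernels. A condensed, illustrative sketch of that forward flow, where gate and experts stand in for the layer's submodules:

    def quant_moe_forward(gate, experts, hidden_states):
        # Mirrors QuantMixtralMoE.forward: route in the model dtype,
        # run the fused experts in float16, hand bfloat16 back to the caller.
        hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
        router_logits, _ = gate(hidden_states)
        return experts(hidden_states.half(), router_logits).bfloat16()
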
+ # name = maybe_remap_kv_scale_name(name, params_dict) + # if name is None: + # continue + + # param = params_dict[name] + # weight_loader = getattr( + # param, "weight_loader", default_weight_loader + # ) + # weight_loader(param, loaded_weight) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ # (param_name, shard_name, shard_id) @@ -605,20 +694,10 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): ("qkv_proj", "v_proj", "v"), ] - # Params for weights, fp8 weight scales, fp8 activation scales - # (param_name, weight_name, expert_id, shard_id) - expert_params_mapping = FusedMoE.make_expert_params_mapping( - ckpt_gate_proj_name="w1", - ckpt_down_proj_name="w2", - ckpt_up_proj_name="w3", - num_experts=self.config.num_local_experts, - ) - params_dict = dict(self.named_parameters()) for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue - for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue @@ -626,126 +705,66 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue - # Skip layers on other devices. - if is_pp_missing_parameter(name, self): - continue - param = params_dict[name] weight_loader = param.weight_loader - weight_loader(param, loaded_weight, shard_id, is_quantized=True) + weight_loader(param, loaded_weight, shard_id) break else: - for mapping in expert_params_mapping: - param_name, weight_name, expert_id, shard_id = mapping - if weight_name not in name: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + + if self.use_fused_moe: + if ( + "block_sparse_moe.experts." in name + and ".w1." not in name + and ".w2." not in name + and ".w3." not in name + and name not in params_dict + ): continue - name = name.replace(weight_name, param_name) - # Skip layers on other devices. - if is_pp_missing_parameter(name, self): + + if ".qzeros" in name: continue - param = params_dict[name] - weight_loader = param.weight_loader - weight_loader( - param, - loaded_weight, - weight_name, - shard_id=shard_id, - expert_id=expert_id, - is_quantized=True, + + shard_id = None + expert_id = 0 + + has_any_numbered = ( + ".qweight" in name or ".scales" in name or ".g_idx" in name ) - break + if has_any_numbered and (".w1." in name): + name = name.replace(".w1.", ".w13_") + shard_id = 0 + if has_any_numbered and (".w2." in name): + name = name.replace(".w2.", ".w2_") + shard_id = 0 + if has_any_numbered and (".w3." in name): + name = name.replace(".w3.", ".w13_") + shard_id = 1 + + exp_string = re.search(r"\.experts\.\d+.", name) + if exp_string: + exp_string = exp_string.group(0) + expert_id = int(exp_string.split(".")[2]) + name = name.replace(exp_string, ".experts.") + else: - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - # Skip layers on other devices. - if is_pp_missing_parameter(name, self): - continue - # Remapping the name of FP8 kv-scale. - name = maybe_remap_kv_scale_name(name, params_dict) - if name is None: + if "block_sparse_moe.experts." 
in name and name not in params_dict: continue - param = params_dict[name] + param = params_dict[name] + + if self.use_fused_moe and shard_id is not None: + weight_loader = getattr( + param, "weight_loader", default_weight_loader + ) + weight_loader(param, loaded_weight, name, shard_id, expert_id, True) + else: weight_loader = getattr( param, "weight_loader", default_weight_loader ) weight_loader(param, loaded_weight) - # def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): - # stacked_params_mapping = [ - # # (param_name, shard_name, shard_id) - # ("qkv_proj", "q_proj", "q"), - # ("qkv_proj", "k_proj", "k"), - # ("qkv_proj", "v_proj", "v"), - # ] - - # params_dict = dict(self.named_parameters()) - # for name, loaded_weight in weights: - # if "rotary_emb.inv_freq" in name: - # continue - # for (param_name, weight_name, shard_id) in stacked_params_mapping: - # if weight_name not in name: - # continue - # name = name.replace(weight_name, param_name) - # # Skip loading extra bias for GPTQ models. - # if name.endswith(".bias") and name not in params_dict: - # continue - # param = params_dict[name] - # weight_loader = param.weight_loader - # weight_loader(param, loaded_weight, shard_id) - # break - # else: - # # Skip loading extra bias for GPTQ models. - # if name.endswith(".bias") and name not in params_dict: - # continue - - # if self.use_fused_moe: - # if ("block_sparse_moe.experts." in name - # and ".w1." not in name and ".w2." not in name - # and ".w3." not in name - # and name not in params_dict): - # continue - - # if (".qzeros" in name): - # continue - - # shard_id = None - # expert_id = 0 - - # has_any_numbered = (".qweight" in name or ".scales" in name - # or ".g_idx" in name) - # if (has_any_numbered and (".w1." in name)): - # name = name.replace(".w1.", ".w13_") - # shard_id = 0 - # if (has_any_numbered and (".w2." in name)): - # name = name.replace(".w2.", ".w2_") - # shard_id = 0 - # if (has_any_numbered and (".w3." in name)): - # name = name.replace(".w3.", ".w13_") - # shard_id = 1 - - # exp_string = re.search(r"\.experts\.\d+.", name) - # if exp_string: - # exp_string = exp_string.group(0) - # expert_id = int(exp_string.split(".")[2]) - # name = name.replace(exp_string, ".experts.") - - # else: - # if ("block_sparse_moe.experts." 
in name - # and name not in params_dict): - # continue - - # param = params_dict[name] - - # if self.use_fused_moe and shard_id is not None: - # weight_loader = getattr(param, "weight_loader", - # default_weight_loader) - # weight_loader(param, loaded_weight, name, shard_id, - # expert_id, True) - # else: - # weight_loader = getattr(param, "weight_loader", - # default_weight_loader) - # weight_loader(param, loaded_weight) class QuantizedMixtralForCausalLM(nn.Module): From 379f3e82cb4f1c5a60f2ae6a229a07e41a6c4aeb Mon Sep 17 00:00:00 2001 From: DhruvaBansal00 Date: Thu, 15 Aug 2024 13:40:08 -0700 Subject: [PATCH 75/96] Weight load back --- vllm/model_executor/models/mixtral.py | 414 +++++++------------------- 1 file changed, 112 insertions(+), 302 deletions(-) diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index a656782307d47..64572444ee0f7 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -469,51 +469,6 @@ def forward( hidden_states, _ = self.norm(hidden_states, residual) return hidden_states - -class QuantMixtralModel(nn.Module): - def __init__( - self, - config: MixtralConfig, - use_fused_moe: bool, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - ) -> None: - super().__init__() - self.padding_idx = config.pad_token_id - self.vocab_size = config.vocab_size - - self.embed_tokens = VocabParallelEmbedding( - config.vocab_size, - config.hidden_size, - ) - self.layers = nn.ModuleList( - [ - MixtralDecoderLayer( - config, use_fused_moe, cache_config, quant_config=quant_config - ) - for _ in range(config.num_hidden_layers) - ] - ) - self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, - ) -> torch.Tensor: - hidden_states = self.embed_tokens(input_ids) - residual = None - for i in range(len(self.layers)): - layer = self.layers[i] - hidden_states, residual = layer( - positions, hidden_states, kv_caches[i], attn_metadata, residual - ) - hidden_states, _ = self.norm(hidden_states, residual) - return hidden_states - - class MixtralForCausalLM(nn.Module, SupportsLoRA): fall_back_to_pt_during_load = False @@ -612,80 +567,6 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - # def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): - # stacked_params_mapping = [ - # # (param_name, shard_name, shard_id) - # ("qkv_proj", "q_proj", "q"), - # ("qkv_proj", "k_proj", "k"), - # ("qkv_proj", "v_proj", "v"), - # ] - - # # Params for weights, fp8 weight scales, fp8 activation scales - # # (param_name, weight_name, expert_id, shard_id) - # expert_params_mapping = FusedMoE.make_expert_params_mapping( - # ckpt_gate_proj_name="w1", - # ckpt_down_proj_name="w2", - # ckpt_up_proj_name="w3", - # num_experts=self.config.num_local_experts, - # ) - - # params_dict = dict(self.named_parameters()) - # for name, loaded_weight in weights: - # if "rotary_emb.inv_freq" in name: - # continue - - # for param_name, weight_name, shard_id in stacked_params_mapping: - # if weight_name not in name: - # continue - # name = name.replace(weight_name, param_name) - # # Skip loading extra bias for GPTQ models. - # if name.endswith(".bias") and name not in params_dict: - # continue - # # Skip layers on other devices. 
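
The manual remapping performed on expert checkpoint names in the preceding patches can be summarized as a small helper: gate/up projections (w1/w3) fold into the fused w13_* parameter, the down projection (w2) maps to w2_*, and the expert index is pulled out of the tensor name with a regex. A standalone sketch of that logic, assuming GPTQ tensor suffixes (qweight/scales/g_idx):

    import re

    def remap_expert_weight_name(name: str):
        """Illustrative sketch of the renaming done in load_weights above."""
        shard_id, expert_id = None, 0
        if any(s in name for s in (".qweight", ".scales", ".g_idx")):
            if ".w1." in name:
                name, shard_id = name.replace(".w1.", ".w13_"), 0
            elif ".w2." in name:
                name, shard_id = name.replace(".w2.", ".w2_"), 0
            elif ".w3." in name:
                name, shard_id = name.replace(".w3.", ".w13_"), 1
        match = re.search(r"\.experts\.\d+.", name)
        if match:
            expert_id = int(match.group(0).split(".")[2])
            name = name.replace(match.group(0), ".experts.")
        return name, shard_id, expert_id

    # e.g. "model.layers.0.block_sparse_moe.experts.3.w1.qweight"
    #  ->  ("model.layers.0.block_sparse_moe.experts.w13_qweight", 0, 3)
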
- # if is_pp_missing_parameter(name, self): - # continue - - # param = params_dict[name] - # weight_loader = param.weight_loader - # weight_loader(param, loaded_weight, shard_id, is_quantized=True) - # break - # else: - # for mapping in expert_params_mapping: - # param_name, weight_name, expert_id, shard_id = mapping - # if weight_name not in name: - # continue - # name = name.replace(weight_name, param_name) - # # Skip layers on other devices. - # if is_pp_missing_parameter(name, self): - # continue - # param = params_dict[name] - # weight_loader = param.weight_loader - # weight_loader( - # param, - # loaded_weight, - # weight_name, - # shard_id=shard_id, - # expert_id=expert_id, - # is_quantized=True, - # ) - # break - # else: - # # Skip loading extra bias for GPTQ models. - # if name.endswith(".bias") and name not in params_dict: - # continue - # # Skip layers on other devices. - # if is_pp_missing_parameter(name, self): - # continue - # # Remapping the name of FP8 kv-scale. - # name = maybe_remap_kv_scale_name(name, params_dict) - # if name is None: - # continue - - # param = params_dict[name] - # weight_loader = getattr( - # param, "weight_loader", default_weight_loader - # ) - # weight_loader(param, loaded_weight) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ # (param_name, shard_name, shard_id) @@ -694,10 +575,20 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): ("qkv_proj", "v_proj", "v"), ] + # Params for weights, fp8 weight scales, fp8 activation scales + # (param_name, weight_name, expert_id, shard_id) + expert_params_mapping = FusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="w1", + ckpt_down_proj_name="w2", + ckpt_up_proj_name="w3", + num_experts=self.config.num_local_experts, + ) + params_dict = dict(self.named_parameters()) for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue + for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue @@ -705,208 +596,127 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue + # Skip layers on other devices. + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] weight_loader = param.weight_loader - weight_loader(param, loaded_weight, shard_id) + weight_loader(param, loaded_weight, shard_id, is_quantized=True) break else: - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - - if self.use_fused_moe: - if ( - "block_sparse_moe.experts." in name - and ".w1." not in name - and ".w2." not in name - and ".w3." not in name - and name not in params_dict - ): + for mapping in expert_params_mapping: + param_name, weight_name, expert_id, shard_id = mapping + if weight_name not in name: continue - - if ".qzeros" in name: + name = name.replace(weight_name, param_name) + # Skip layers on other devices. + if is_pp_missing_parameter(name, self): continue - - shard_id = None - expert_id = 0 - - has_any_numbered = ( - ".qweight" in name or ".scales" in name or ".g_idx" in name + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader( + param, + loaded_weight, + weight_name, + shard_id=shard_id, + expert_id=expert_id, + is_quantized=True, ) - if has_any_numbered and (".w1." in name): - name = name.replace(".w1.", ".w13_") - shard_id = 0 - if has_any_numbered and (".w2." 
in name): - name = name.replace(".w2.", ".w2_") - shard_id = 0 - if has_any_numbered and (".w3." in name): - name = name.replace(".w3.", ".w13_") - shard_id = 1 - - exp_string = re.search(r"\.experts\.\d+.", name) - if exp_string: - exp_string = exp_string.group(0) - expert_id = int(exp_string.split(".")[2]) - name = name.replace(exp_string, ".experts.") - + break else: - if "block_sparse_moe.experts." in name and name not in params_dict: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + # Skip layers on other devices. + if is_pp_missing_parameter(name, self): + continue + # Remapping the name of FP8 kv-scale. + name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: continue - param = params_dict[name] - - if self.use_fused_moe and shard_id is not None: - weight_loader = getattr( - param, "weight_loader", default_weight_loader - ) - weight_loader(param, loaded_weight, name, shard_id, expert_id, True) - else: + param = params_dict[name] weight_loader = getattr( param, "weight_loader", default_weight_loader ) weight_loader(param, loaded_weight) + # def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + # stacked_params_mapping = [ + # # (param_name, shard_name, shard_id) + # ("qkv_proj", "q_proj", "q"), + # ("qkv_proj", "k_proj", "k"), + # ("qkv_proj", "v_proj", "v"), + # ] + # params_dict = dict(self.named_parameters()) + # for name, loaded_weight in weights: + # if "rotary_emb.inv_freq" in name: + # continue + # for param_name, weight_name, shard_id in stacked_params_mapping: + # if weight_name not in name: + # continue + # name = name.replace(weight_name, param_name) + # # Skip loading extra bias for GPTQ models. + # if name.endswith(".bias") and name not in params_dict: + # continue + # param = params_dict[name] + # weight_loader = param.weight_loader + # weight_loader(param, loaded_weight, shard_id) + # break + # else: + # # Skip loading extra bias for GPTQ models. 
+ # if name.endswith(".bias") and name not in params_dict: + # continue -class QuantizedMixtralForCausalLM(nn.Module): - fall_back_to_pt_during_load = False - - def __init__( - self, - config: MixtralConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - ) -> None: - super().__init__() - - # TODO check runs with dtype=float16 - self.use_fused_moe = config.torch_dtype != torch.float8_e4m3fn - - self.config = config - self.quant_config = quant_config - self.model = QuantMixtralModel( - config, self.use_fused_moe, cache_config, quant_config - ) - self.lm_head = ParallelLMHead( - config.vocab_size, config.hidden_size, quant_config=quant_config - ) - self.logits_processor = LogitsProcessor(config.vocab_size) - self.sampler = Sampler() - - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, - intermediate_tensors: Optional[IntermediateTensors] = None, - ) -> torch.Tensor: - hidden_states = self.model(input_ids, positions, kv_caches, attn_metadata) - return hidden_states - - def compute_logits( - self, hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata - ) -> torch.Tensor: - logits = self.logits_processor(self.lm_head, hidden_states, sampling_metadata) - return logits - - # def make_empty_intermediate_tensors( - # self, batch_size: int, dtype: torch.dtype, device: torch.device - # ) -> IntermediateTensors: - # return IntermediateTensors( - # { - # "hidden_states": torch.zeros( - # (batch_size, self.config.hidden_size), dtype=dtype, device=device - # ), - # "residual": torch.zeros( - # (batch_size, self.config.hidden_size), dtype=dtype, device=device - # ), - # } - # ) - - def sample( - self, - logits: Optional[torch.Tensor], - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens - - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): - stacked_params_mapping = [ - # (param_name, shard_name, shard_id) - ("qkv_proj", "q_proj", "q"), - ("qkv_proj", "k_proj", "k"), - ("qkv_proj", "v_proj", "v"), - ] - - params_dict = dict(self.named_parameters()) - for name, loaded_weight in weights: - if "rotary_emb.inv_freq" in name: - continue - for param_name, weight_name, shard_id in stacked_params_mapping: - if weight_name not in name: - continue - name = name.replace(weight_name, param_name) - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - param = params_dict[name] - weight_loader = param.weight_loader - weight_loader(param, loaded_weight, shard_id) - break - else: - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - - if self.use_fused_moe: - if ( - "block_sparse_moe.experts." in name - and ".w1." not in name - and ".w2." not in name - and ".w3." not in name - and name not in params_dict - ): - continue + # if self.use_fused_moe: + # if ( + # "block_sparse_moe.experts." in name + # and ".w1." not in name + # and ".w2." not in name + # and ".w3." not in name + # and name not in params_dict + # ): + # continue - if ".qzeros" in name: - continue + # if ".qzeros" in name: + # continue - shard_id = None - expert_id = 0 + # shard_id = None + # expert_id = 0 - has_any_numbered = ( - ".qweight" in name or ".scales" in name or ".g_idx" in name - ) - if has_any_numbered and (".w1." 
in name): - name = name.replace(".w1.", ".w13_") - shard_id = 0 - if has_any_numbered and (".w2." in name): - name = name.replace(".w2.", ".w2_") - shard_id = 0 - if has_any_numbered and (".w3." in name): - name = name.replace(".w3.", ".w13_") - shard_id = 1 - - exp_string = re.search(r"\.experts\.\d+.", name) - if exp_string: - exp_string = exp_string.group(0) - expert_id = int(exp_string.split(".")[2]) - name = name.replace(exp_string, ".experts.") + # has_any_numbered = ( + # ".qweight" in name or ".scales" in name or ".g_idx" in name + # ) + # if has_any_numbered and (".w1." in name): + # name = name.replace(".w1.", ".w13_") + # shard_id = 0 + # if has_any_numbered and (".w2." in name): + # name = name.replace(".w2.", ".w2_") + # shard_id = 0 + # if has_any_numbered and (".w3." in name): + # name = name.replace(".w3.", ".w13_") + # shard_id = 1 + + # exp_string = re.search(r"\.experts\.\d+.", name) + # if exp_string: + # exp_string = exp_string.group(0) + # expert_id = int(exp_string.split(".")[2]) + # name = name.replace(exp_string, ".experts.") - else: - if "block_sparse_moe.experts." in name and name not in params_dict: - continue + # else: + # if "block_sparse_moe.experts." in name and name not in params_dict: + # continue - param = params_dict[name] + # param = params_dict[name] - if self.use_fused_moe and shard_id is not None: - weight_loader = getattr( - param, "weight_loader", default_weight_loader - ) - weight_loader(param, loaded_weight, name, shard_id, expert_id, True) - else: - weight_loader = getattr( - param, "weight_loader", default_weight_loader - ) - weight_loader(param, loaded_weight) + # if self.use_fused_moe and shard_id is not None: + # weight_loader = getattr( + # param, "weight_loader", default_weight_loader + # ) + # weight_loader(param, loaded_weight, name, shard_id, expert_id, True) + # else: + # weight_loader = getattr( + # param, "weight_loader", default_weight_loader + # ) + # weight_loader(param, loaded_weight) \ No newline at end of file From a5d356ec8c09a5033f3222ffdeaf682c791fadcd Mon Sep 17 00:00:00 2001 From: DhruvaBansal00 Date: Thu, 15 Aug 2024 13:44:11 -0700 Subject: [PATCH 76/96] Load with name not weight name --- vllm/model_executor/models/mixtral.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index 64572444ee0f7..0a700a4c57f89 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -618,7 +618,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader( param, loaded_weight, - weight_name, + name, shard_id=shard_id, expert_id=expert_id, is_quantized=True, From 62c0135d186dba941c2e35ccb5eb4c49891ef3e5 Mon Sep 17 00:00:00 2001 From: DhruvaBansal00 Date: Thu, 15 Aug 2024 13:46:13 -0700 Subject: [PATCH 77/96] params dict should load from old name --- vllm/model_executor/models/mixtral.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index 0a700a4c57f89..84b150a245ef0 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -592,6 +592,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue + param = params_dict[name] name = name.replace(weight_name, param_name) # Skip loading extra bias for GPTQ models. 
if name.endswith(".bias") and name not in params_dict: @@ -600,7 +601,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): if is_pp_missing_parameter(name, self): continue - param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id, is_quantized=True) break @@ -609,11 +609,11 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): param_name, weight_name, expert_id, shard_id = mapping if weight_name not in name: continue - name = name.replace(weight_name, param_name) # Skip layers on other devices. if is_pp_missing_parameter(name, self): continue param = params_dict[name] + name = name.replace(weight_name, param_name) weight_loader = param.weight_loader weight_loader( param, From d23c00c63796eaaa0b96dd703e993d7113a22bbb Mon Sep 17 00:00:00 2001 From: DhruvaBansal00 Date: Thu, 15 Aug 2024 13:47:33 -0700 Subject: [PATCH 78/96] logging name and parmas --- vllm/model_executor/models/mixtral.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index 84b150a245ef0..b1fd87dd080c0 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -62,7 +62,8 @@ from .interfaces import SupportsLoRA from .utils import is_pp_missing_parameter, make_layers - +import logging +logger = logging.getLogger(__name__) class MixtralMoE(nn.Module): """A tensor-parallel MoE implementation for Mixtral that shards each expert @@ -585,6 +586,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): ) params_dict = dict(self.named_parameters()) + logger.error(params_dict) for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue @@ -612,6 +614,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # Skip layers on other devices. if is_pp_missing_parameter(name, self): continue + logger.error(name) param = params_dict[name] name = name.replace(weight_name, param_name) weight_loader = param.weight_loader From 6dda4475a2cd8936353eb33086989651dd424674 Mon Sep 17 00:00:00 2001 From: DhruvaBansal00 Date: Thu, 15 Aug 2024 13:48:59 -0700 Subject: [PATCH 79/96] log expert parmas map --- vllm/model_executor/models/mixtral.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index b1fd87dd080c0..2265d1ac590bd 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -586,7 +586,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): ) params_dict = dict(self.named_parameters()) - logger.error(params_dict) + logger.error(params_dict.keys()) for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue @@ -594,8 +594,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue - param = params_dict[name] name = name.replace(weight_name, param_name) + param = params_dict[name] # Skip loading extra bias for GPTQ models. 
if name.endswith(".bias") and name not in params_dict: continue @@ -607,6 +607,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader(param, loaded_weight, shard_id, is_quantized=True) break else: + logger.error(expert_params_mapping) for mapping in expert_params_mapping: param_name, weight_name, expert_id, shard_id = mapping if weight_name not in name: @@ -614,9 +615,9 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # Skip layers on other devices. if is_pp_missing_parameter(name, self): continue + name = name.replace(weight_name, param_name) logger.error(name) param = params_dict[name] - name = name.replace(weight_name, param_name) weight_loader = param.weight_loader weight_loader( param, From 67ce7b65309a7cd9791957db114e7eb0403190a5 Mon Sep 17 00:00:00 2001 From: DhruvaBansal00 Date: Thu, 15 Aug 2024 13:58:38 -0700 Subject: [PATCH 80/96] parity with prev commits --- vllm/model_executor/models/mixtral.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index 2265d1ac590bd..a11f2d8f31f8c 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -595,7 +595,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): if weight_name not in name: continue name = name.replace(weight_name, param_name) - param = params_dict[name] # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue @@ -603,6 +602,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): if is_pp_missing_parameter(name, self): continue + param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id, is_quantized=True) break @@ -613,10 +613,9 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): if weight_name not in name: continue # Skip layers on other devices. 
+ name = name.replace(weight_name, param_name) if is_pp_missing_parameter(name, self): continue - name = name.replace(weight_name, param_name) - logger.error(name) param = params_dict[name] weight_loader = param.weight_loader weight_loader( From bd933c975d8a327a79f0242996551ae55155b1d5 Mon Sep 17 00:00:00 2001 From: DhruvaBansal00 Date: Thu, 15 Aug 2024 14:02:09 -0700 Subject: [PATCH 81/96] Adding qzeros to mapping --- vllm/model_executor/layers/fused_moe/layer.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 0b06ee86a308d..34bad93f052dc 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -399,4 +399,15 @@ def make_expert_params_mapping( shard_id, ) for expert_id in range(num_experts) for shard_id, weight_name in enumerate(gate_down_up) + ] + [ + # These are the g_idx and g_idx_sort_indices scales for the experts + # (param_name, weight_name, expert_id, shard_id) + ( + "experts.w13_qzeros" + if weight_name in gate_up else "experts.w2_qzeros", + f"experts.{expert_id}.{weight_name}.qzeros", + expert_id, + shard_id, + ) for expert_id in range(num_experts) + for shard_id, weight_name in enumerate(gate_down_up) ]) From 77cd09561efa4798c26788dfff1352cfa08b39fa Mon Sep 17 00:00:00 2001 From: DhruvaBansal00 Date: Thu, 15 Aug 2024 14:05:14 -0700 Subject: [PATCH 82/96] Remove log --- vllm/model_executor/models/mixtral.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index a11f2d8f31f8c..fc170cee534cf 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -586,7 +586,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): ) params_dict = dict(self.named_parameters()) - logger.error(params_dict.keys()) for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue @@ -607,7 +606,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader(param, loaded_weight, shard_id, is_quantized=True) break else: - logger.error(expert_params_mapping) for mapping in expert_params_mapping: param_name, weight_name, expert_id, shard_id = mapping if weight_name not in name: From 529191eb7a1b8fe55a9d56bee373b348f825c4df Mon Sep 17 00:00:00 2001 From: DhruvaBansal00 Date: Thu, 15 Aug 2024 14:05:47 -0700 Subject: [PATCH 83/96] Remove is quantized --- vllm/model_executor/models/mixtral.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index fc170cee534cf..fce9305008abd 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -603,7 +603,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): param = params_dict[name] weight_loader = param.weight_loader - weight_loader(param, loaded_weight, shard_id, is_quantized=True) + weight_loader(param, loaded_weight, shard_id) break else: for mapping in expert_params_mapping: From 2450543233960ca5bfaaa3f21d38a863d2ece851 Mon Sep 17 00:00:00 2001 From: DhruvaBansal00 Date: Thu, 15 Aug 2024 14:21:32 -0700 Subject: [PATCH 84/96] Assume fused true --- vllm/model_executor/models/mixtral.py | 194 +++----------------------- 1 file changed, 18 insertions(+), 176 deletions(-) diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index 
fce9305008abd..a6016b916c883 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -120,53 +120,15 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: final_hidden_states = self.experts(hidden_states, router_logits) return final_hidden_states.view(orig_shape) - -class MixtralMLP(nn.Module): - def __init__( - self, - num_experts: int, - hidden_size: int, - intermediate_size: int, - quant_config: Optional[QuantizationConfig] = None, - ) -> None: - super().__init__() - self.num_experts = num_experts - self.ffn_dim = intermediate_size - self.hidden_dim = hidden_size - - self.w1 = ReplicatedLinear( - self.hidden_dim, self.ffn_dim, bias=False, quant_config=quant_config - ) - self.w2 = ReplicatedLinear( - self.ffn_dim, self.hidden_dim, bias=False, quant_config=quant_config - ) - self.w3 = ReplicatedLinear( - self.hidden_dim, self.ffn_dim, bias=False, quant_config=quant_config - ) - - # TODO: Use vllm's SiluAndMul - self.act_fn = nn.SiLU() - - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - w1_out, _ = self.w1(hidden_states) - w1_out = self.act_fn(w1_out) - w3_out, _ = self.w3(hidden_states) - current_hidden_states = w1_out * w3_out - current_hidden_states, _ = self.w2(current_hidden_states) - return current_hidden_states - - class QuantMixtralMoE(nn.Module): def __init__( self, config: MixtralConfig, - use_fused_moe: bool, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", ): super().__init__() self.config = config - self.use_fused_moe = use_fused_moe self.quant_config = quant_config self.rank = get_tensor_model_parallel_rank() self.tp_size = get_tensor_model_parallel_world_size() @@ -184,35 +146,7 @@ def __init__( if not self.expert_indicies: raise ValueError(f"Rank {self.rank} has no experts assigned to it.") - if self.use_fused_moe: - params_dtype = torch.float16 - self.experts = FusedMoE( - num_experts=self.num_total_experts, - top_k=self.top_k, - hidden_size=config.hidden_size, - intermediate_size=config.intermediate_size, - params_dtype=params_dtype, - reduce_results=True, - renormalize=True, - quant_config=quant_config, - tp_size=self.tp_size, - prefix=f"{prefix}.experts", - ) - else: - self.experts = nn.ModuleList( - [ - MixtralMLP( - self.num_total_experts, - config.hidden_size, - config.intermediate_size, - quant_config=quant_config, - ) - if idx in self.expert_indicies - else None - for idx in range(self.num_total_experts) - ] - ) - + params_dtype = torch.float16 self.gate = ReplicatedLinear( config.hidden_size, self.num_total_experts, @@ -220,40 +154,27 @@ def __init__( quant_config=None, prefix=f"{prefix}.gate", ) + self.experts = FusedMoE( + num_experts=self.num_total_experts, + top_k=self.top_k, + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + params_dtype=params_dtype, + reduce_results=True, + renormalize=True, + quant_config=quant_config, + tp_size=self.tp_size, + prefix=f"{prefix}.experts", + ) + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - num_tokens, hidden_dim = hidden_states.shape + _, hidden_dim = hidden_states.shape hidden_states = hidden_states.view(-1, hidden_dim) router_logits, _ = self.gate(hidden_states) - if self.use_fused_moe: - ret = self.experts(hidden_states.half(), router_logits) - return ret.bfloat16() - else: - routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) - routing_weights, selected_experts = torch.topk( - routing_weights, self.top_k, dim=-1 - ) - routing_weights /= 
routing_weights.sum(dim=-1, keepdim=True) - - final_hidden_states = None - for expert_idx in self.expert_indicies: - expert_layer = self.experts[expert_idx] - expert_mask = selected_experts == expert_idx - expert_weights = (routing_weights * expert_mask).sum( - dim=-1, keepdim=True - ) - - current_hidden_states = expert_layer(hidden_states).mul_(expert_weights) - if final_hidden_states is None: - final_hidden_states = current_hidden_states - else: - final_hidden_states.add_(current_hidden_states) - - return tensor_model_parallel_all_reduce(final_hidden_states).view( - num_tokens, hidden_dim - ) - + ret = self.experts(hidden_states.half(), router_logits) + return ret.bfloat16() class MixtralAttention(nn.Module): def __init__( @@ -641,83 +562,4 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader = getattr( param, "weight_loader", default_weight_loader ) - weight_loader(param, loaded_weight) - # def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): - # stacked_params_mapping = [ - # # (param_name, shard_name, shard_id) - # ("qkv_proj", "q_proj", "q"), - # ("qkv_proj", "k_proj", "k"), - # ("qkv_proj", "v_proj", "v"), - # ] - - # params_dict = dict(self.named_parameters()) - # for name, loaded_weight in weights: - # if "rotary_emb.inv_freq" in name: - # continue - # for param_name, weight_name, shard_id in stacked_params_mapping: - # if weight_name not in name: - # continue - # name = name.replace(weight_name, param_name) - # # Skip loading extra bias for GPTQ models. - # if name.endswith(".bias") and name not in params_dict: - # continue - # param = params_dict[name] - # weight_loader = param.weight_loader - # weight_loader(param, loaded_weight, shard_id) - # break - # else: - # # Skip loading extra bias for GPTQ models. - # if name.endswith(".bias") and name not in params_dict: - # continue - - # if self.use_fused_moe: - # if ( - # "block_sparse_moe.experts." in name - # and ".w1." not in name - # and ".w2." not in name - # and ".w3." not in name - # and name not in params_dict - # ): - # continue - - # if ".qzeros" in name: - # continue - - # shard_id = None - # expert_id = 0 - - # has_any_numbered = ( - # ".qweight" in name or ".scales" in name or ".g_idx" in name - # ) - # if has_any_numbered and (".w1." in name): - # name = name.replace(".w1.", ".w13_") - # shard_id = 0 - # if has_any_numbered and (".w2." in name): - # name = name.replace(".w2.", ".w2_") - # shard_id = 0 - # if has_any_numbered and (".w3." in name): - # name = name.replace(".w3.", ".w13_") - # shard_id = 1 - - # exp_string = re.search(r"\.experts\.\d+.", name) - # if exp_string: - # exp_string = exp_string.group(0) - # expert_id = int(exp_string.split(".")[2]) - # name = name.replace(exp_string, ".experts.") - - # else: - # if "block_sparse_moe.experts." 
in name and name not in params_dict: - # continue - - # param = params_dict[name] - - # if self.use_fused_moe and shard_id is not None: - # weight_loader = getattr( - # param, "weight_loader", default_weight_loader - # ) - # weight_loader(param, loaded_weight, name, shard_id, expert_id, True) - # else: - # weight_loader = getattr( - # param, "weight_loader", default_weight_loader - # ) - # weight_loader(param, loaded_weight) \ No newline at end of file + weight_loader(param, loaded_weight) \ No newline at end of file From 8cba45e2de78fd78b9aef7676ed70deef9b4d5c4 Mon Sep 17 00:00:00 2001 From: DhruvaBansal00 Date: Thu, 15 Aug 2024 14:22:10 -0700 Subject: [PATCH 85/96] rm fused true --- vllm/model_executor/models/mixtral.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index a6016b916c883..5861cf6df9bc8 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -289,7 +289,6 @@ def __init__( # ) self.block_sparse_moe = QuantMixtralMoE( config, - use_fused_moe=True, quant_config=quant_config, prefix=f"{prefix}.block_sparse_moe", ) @@ -426,7 +425,6 @@ def __init__( self.config = config self.lora_config = lora_config - self.use_fused_moe = config.torch_dtype != torch.float8_e4m3fn self.model = MixtralModel( config, cache_config, quant_config, lora_config=lora_config, prefix="model" ) From 10940a5c503d850f8a15f08359a0e88593cabdbd Mon Sep 17 00:00:00 2001 From: DhruvaBansal00 Date: Thu, 15 Aug 2024 14:40:31 -0700 Subject: [PATCH 86/96] Switching to mixtral moe --- vllm/model_executor/models/mixtral.py | 32 +++++++++------------------ 1 file changed, 11 insertions(+), 21 deletions(-) diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index 5861cf6df9bc8..adc94a226c9d1 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -134,17 +134,6 @@ def __init__( self.tp_size = get_tensor_model_parallel_world_size() self.num_total_experts = config.num_local_experts self.top_k = config.num_experts_per_tok - if self.tp_size > self.num_total_experts: - raise ValueError( - f"Tensor parallel size {self.tp_size} is greater than " - f"the number of experts {self.num_total_experts}." 
- ) - # Split experts equally between ranks - self.expert_indicies = np.array_split( - range(self.num_total_experts), self.tp_size - )[self.rank].tolist() - if not self.expert_indicies: - raise ValueError(f"Rank {self.rank} has no experts assigned to it.") params_dtype = torch.float16 self.gate = ReplicatedLinear( @@ -279,19 +268,20 @@ def __init__( quant_config=quant_config, prefix=f"{prefix}.self_attn", ) - # self.block_sparse_moe = MixtralMoE( - # num_experts=config.num_local_experts, - # top_k=config.num_experts_per_tok, - # hidden_size=config.hidden_size, - # intermediate_size=config.intermediate_size, - # quant_config=quant_config, - # prefix=f"{prefix}.block_sparse_moe", - # ) - self.block_sparse_moe = QuantMixtralMoE( - config, + self.block_sparse_moe = MixtralMoE( + num_experts=config.num_local_experts, + top_k=config.num_experts_per_tok, + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, quant_config=quant_config, + tp_size=get_tensor_model_parallel_world_size(), prefix=f"{prefix}.block_sparse_moe", ) + # self.block_sparse_moe = QuantMixtralMoE( + # config, + # quant_config=quant_config, + # prefix=f"{prefix}.block_sparse_moe", + # ) self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = RMSNorm( config.hidden_size, eps=config.rms_norm_eps From 895ffbe2704794574c38a7dfc6352e810b0d538e Mon Sep 17 00:00:00 2001 From: DhruvaBansal00 Date: Thu, 15 Aug 2024 14:49:17 -0700 Subject: [PATCH 87/96] Precision changes --- vllm/model_executor/models/mixtral.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index adc94a226c9d1..7f2109e337e27 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -93,7 +93,6 @@ def __init__( hidden_size, num_experts, bias=False, - params_dtype=params_dtype, quant_config=None, prefix=f"{prefix}.gate", ) @@ -117,8 +116,8 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: hidden_states = hidden_states.view(-1, self.hidden_size) # router_logits: (num_tokens, n_experts) router_logits, _ = self.gate(hidden_states) - final_hidden_states = self.experts(hidden_states, router_logits) - return final_hidden_states.view(orig_shape) + final_hidden_states = self.experts(hidden_states.half(), router_logits) + return final_hidden_states.view(orig_shape).bfloat16() class QuantMixtralMoE(nn.Module): def __init__( @@ -275,6 +274,7 @@ def __init__( intermediate_size=config.intermediate_size, quant_config=quant_config, tp_size=get_tensor_model_parallel_world_size(), + params_dtype=torch.float16, prefix=f"{prefix}.block_sparse_moe", ) # self.block_sparse_moe = QuantMixtralMoE( From e54b2e4d47e9d748013fc4aede3ac93efd8278ec Mon Sep 17 00:00:00 2001 From: DhruvaBansal00 Date: Thu, 15 Aug 2024 14:58:19 -0700 Subject: [PATCH 88/96] Cleanup --- vllm/model_executor/models/mixtral.py | 51 --------------------------- 1 file changed, 51 deletions(-) diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index 7f2109e337e27..9f01f9349335e 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -118,52 +118,6 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: router_logits, _ = self.gate(hidden_states) final_hidden_states = self.experts(hidden_states.half(), router_logits) return final_hidden_states.view(orig_shape).bfloat16() - -class QuantMixtralMoE(nn.Module): 
- def __init__( - self, - config: MixtralConfig, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - ): - super().__init__() - self.config = config - self.quant_config = quant_config - self.rank = get_tensor_model_parallel_rank() - self.tp_size = get_tensor_model_parallel_world_size() - self.num_total_experts = config.num_local_experts - self.top_k = config.num_experts_per_tok - - params_dtype = torch.float16 - self.gate = ReplicatedLinear( - config.hidden_size, - self.num_total_experts, - bias=False, - quant_config=None, - prefix=f"{prefix}.gate", - ) - self.experts = FusedMoE( - num_experts=self.num_total_experts, - top_k=self.top_k, - hidden_size=config.hidden_size, - intermediate_size=config.intermediate_size, - params_dtype=params_dtype, - reduce_results=True, - renormalize=True, - quant_config=quant_config, - tp_size=self.tp_size, - prefix=f"{prefix}.experts", - ) - - - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - _, hidden_dim = hidden_states.shape - hidden_states = hidden_states.view(-1, hidden_dim) - router_logits, _ = self.gate(hidden_states) - - ret = self.experts(hidden_states.half(), router_logits) - return ret.bfloat16() - class MixtralAttention(nn.Module): def __init__( self, @@ -277,11 +231,6 @@ def __init__( params_dtype=torch.float16, prefix=f"{prefix}.block_sparse_moe", ) - # self.block_sparse_moe = QuantMixtralMoE( - # config, - # quant_config=quant_config, - # prefix=f"{prefix}.block_sparse_moe", - # ) self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = RMSNorm( config.hidden_size, eps=config.rms_norm_eps From b4f23dc6b0a4677fc3bf137576afe93e25b2184b Mon Sep 17 00:00:00 2001 From: DhruvaBansal00 Date: Thu, 15 Aug 2024 15:03:19 -0700 Subject: [PATCH 89/96] Mixtral quant parity: --- vllm/model_executor/models/mixtral_quant.py | 241 -------------------- 1 file changed, 241 deletions(-) diff --git a/vllm/model_executor/models/mixtral_quant.py b/vllm/model_executor/models/mixtral_quant.py index c0837802ce318..c9143552224f5 100644 --- a/vllm/model_executor/models/mixtral_quant.py +++ b/vllm/model_executor/models/mixtral_quant.py @@ -361,75 +361,6 @@ def forward( hidden_states, _ = self.norm(hidden_states, residual) return hidden_states - - -class LoRAMixtralModel(nn.Module): - def __init__( - self, - config: MixtralConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - lora_config: Optional[LoRAConfig] = None, - prefix: str = "", - ) -> None: - super().__init__() - self.padding_idx = config.pad_token_id - lora_vocab = ( - (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1)) - if lora_config - else 0 - ) - self.vocab_size = config.vocab_size + lora_vocab - self.org_vocab_size = config.vocab_size - - self.embed_tokens = VocabParallelEmbedding( - self.vocab_size, - config.hidden_size, - org_num_embeddings=config.vocab_size, - ) - - self.start_layer, self.end_layer, self.layers = make_layers( - config.num_hidden_layers, - lambda prefix: MixtralDecoderLayer( - config, use_fused_moe=True, cache_config=cache_config, quant_config=quant_config - ), - prefix=f"{prefix}.layers", - ) - - self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, - intermediate_tensors: Optional[IntermediateTensors], - ) -> torch.Tensor: - if get_pp_group().is_first_rank: - hidden_states 
= self.embed_tokens(input_ids) - residual = None - else: - assert intermediate_tensors is not None - hidden_states = intermediate_tensors["hidden_states"] - residual = intermediate_tensors["residual"] - for i in range(self.start_layer, self.end_layer): - layer = self.layers[i] - hidden_states, residual = layer( - positions, - hidden_states, - kv_caches[i - self.start_layer], - attn_metadata, - residual, - ) - if not get_pp_group().is_last_rank: - return IntermediateTensors( - {"hidden_states": hidden_states, "residual": residual} - ) - hidden_states, _ = self.norm(hidden_states, residual) - return hidden_states - - class MixtralForCausalLM(nn.Module): fall_back_to_pt_during_load = False @@ -488,178 +419,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): ("qkv_proj", "v_proj", "v"), ] - params_dict = dict(self.named_parameters()) - for name, loaded_weight in weights: - if "rotary_emb.inv_freq" in name: - continue - for (param_name, weight_name, shard_id) in stacked_params_mapping: - if weight_name not in name: - continue - name = name.replace(weight_name, param_name) - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - param = params_dict[name] - weight_loader = param.weight_loader - weight_loader(param, loaded_weight, shard_id) - break - else: - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - - if self.use_fused_moe: - if ("block_sparse_moe.experts." in name - and ".w1." not in name and ".w2." not in name - and ".w3." not in name - and name not in params_dict): - continue - - if (".qzeros" in name): - continue - - shard_id = None - expert_id = 0 - - has_any_numbered = (".qweight" in name or ".scales" in name - or ".g_idx" in name) - if (has_any_numbered and (".w1." in name)): - name = name.replace(".w1.", ".w13_") - shard_id = 0 - if (has_any_numbered and (".w2." in name)): - name = name.replace(".w2.", ".w2_") - shard_id = 0 - if (has_any_numbered and (".w3." in name)): - name = name.replace(".w3.", ".w13_") - shard_id = 1 - - exp_string = re.search(r"\.experts\.\d+.", name) - if exp_string: - exp_string = exp_string.group(0) - expert_id = int(exp_string.split(".")[2]) - name = name.replace(exp_string, ".experts.") - - else: - if ("block_sparse_moe.experts." 
in name - and name not in params_dict): - continue - - param = params_dict[name] - - if self.use_fused_moe and shard_id is not None: - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, loaded_weight, name, shard_id, - expert_id, True) - else: - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, loaded_weight) -class LoRAEnabledMixtralForCausalLM(nn.Module, SupportsLoRA): - fall_back_to_pt_during_load = False - - packed_modules_mapping = { - "qkv_proj": [ - "q_proj", - "k_proj", - "v_proj", - ], - } - - # LoRA specific attributes - supported_lora_modules = [ - "qkv_proj", - "o_proj", - "embed_tokens", - "lm_head", - ] - embedding_modules = { - "embed_tokens": "input_embeddings", - "lm_head": "output_embeddings", - } - embedding_padding_modules = ["lm_head"] - - def __init__( - self, - config: MixtralConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - lora_config: Optional[LoRAConfig] = None, - ) -> None: - super().__init__() - - self.config = config - self.lora_config = lora_config - self.use_fused_moe = (config.torch_dtype != torch.float8_e4m3fn) - self.model = LoRAMixtralModel( - config=config, cache_config=cache_config, quant_config=quant_config, lora_config=lora_config, prefix="model" - ) - self.unpadded_vocab_size = config.vocab_size - if lora_config: - self.unpadded_vocab_size += lora_config.lora_extra_vocab_size - self.lm_head = ParallelLMHead( - self.unpadded_vocab_size, - config.hidden_size, - org_num_embeddings=config.vocab_size, - padding_size=DEFAULT_VOCAB_PADDING_SIZE - # We need bigger padding if using lora for kernel - # compatibility - if not lora_config else lora_config.lora_vocab_padding_size, - quant_config=quant_config, - ) - self.logits_processor = LogitsProcessor( - self.unpadded_vocab_size, config.vocab_size - ) - self.sampler = Sampler() - - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, - intermediate_tensors: Optional[IntermediateTensors] = None, - ) -> torch.Tensor: - hidden_states = self.model( - input_ids, positions, kv_caches, attn_metadata, intermediate_tensors - ) - return hidden_states - - def compute_logits( - self, hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata - ) -> torch.Tensor: - logits = self.logits_processor(self.lm_head, hidden_states, sampling_metadata) - return logits - - def make_empty_intermediate_tensors( - self, batch_size: int, dtype: torch.dtype, device: torch.device - ) -> IntermediateTensors: - return IntermediateTensors( - { - "hidden_states": torch.zeros( - (batch_size, self.config.hidden_size), dtype=dtype, device=device - ), - "residual": torch.zeros( - (batch_size, self.config.hidden_size), dtype=dtype, device=device - ), - } - ) - - def sample( - self, - logits: Optional[torch.Tensor], - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): - stacked_params_mapping = [ - # (param_name, shard_name, shard_id) - ("qkv_proj", "q_proj", "q"), - ("qkv_proj", "k_proj", "k"), - ("qkv_proj", "v_proj", "v"), - ] - params_dict = dict(self.named_parameters()) for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: From d59fe3b166ad597517f451fb4d30d5223851517f Mon Sep 17 00:00:00 2001 From: DhruvaBansal00 Date: 
Thu, 15 Aug 2024 15:08:08 -0700 Subject: [PATCH 90/96] fixing tests --- tests/kernels/test_moe.py | 1 + vllm/model_executor/layers/fused_moe/fused_moe_marlin.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/kernels/test_moe.py b/tests/kernels/test_moe.py index e657581df05a0..b53f578988214 100644 --- a/tests/kernels/test_moe.py +++ b/tests/kernels/test_moe.py @@ -239,6 +239,7 @@ def test_fused_marlin_moe( renormalize=False, w1_scale=scales1, w2_scale=scales2, + num_bits=num_bits, ) assert compute_max_diff(marlin_output, triton_output) < 4e-2 diff --git a/vllm/model_executor/layers/fused_moe/fused_moe_marlin.py b/vllm/model_executor/layers/fused_moe/fused_moe_marlin.py index efafcef2f1ee7..3e080a1393e16 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe_marlin.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe_marlin.py @@ -120,7 +120,7 @@ def fused_moe_marlin( False, ) - ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, (num_bits // 2) * N)) + ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, 2 * N)) intermediate_cache3 = torch.ops._moe_C.marlin_gemm_moe( intermediate_cache2, From 0d9cbdc95a12524b750ab3f39aa4fffa7f692d29 Mon Sep 17 00:00:00 2001 From: Dhruva Bansal Date: Thu, 15 Aug 2024 23:24:02 +0000 Subject: [PATCH 91/96] Tests working and correctness verified --- vllm/model_executor/models/mixtral.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index 9f01f9349335e..9d99606b49ec3 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -80,20 +80,21 @@ def __init__( top_k: int, hidden_size: int, intermediate_size: int, - params_dtype: Optional[torch.dtype] = None, + params_dtype: Optional[torch.dtype] = torch.float16, quant_config: Optional[QuantizationConfig] = None, tp_size: Optional[int] = None, prefix: str = "", ): super().__init__() self.hidden_size = hidden_size - + self.params_dtype = params_dtype # Gate always runs at half / full precision for now. self.gate = ReplicatedLinear( hidden_size, num_experts, bias=False, quant_config=None, + params_dtype=params_dtype, prefix=f"{prefix}.gate", ) @@ -112,12 +113,12 @@ def __init__( def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: # NOTE: hidden_states can have either 1D or 2D shape. 
- orig_shape = hidden_states.shape - hidden_states = hidden_states.view(-1, self.hidden_size) + orig_shape, orig_type = hidden_states.shape, hidden_states.dtype + hidden_states = hidden_states.view(-1, self.hidden_size).to(self.params_dtype) # router_logits: (num_tokens, n_experts) router_logits, _ = self.gate(hidden_states) - final_hidden_states = self.experts(hidden_states.half(), router_logits) - return final_hidden_states.view(orig_shape).bfloat16() + final_hidden_states = self.experts(hidden_states, router_logits) + return final_hidden_states.view(orig_shape).to(orig_type) class MixtralAttention(nn.Module): def __init__( self, From 112aa40fd31fcfe431581716c25f732f1b482c04 Mon Sep 17 00:00:00 2001 From: DhruvaBansal00 Date: Thu, 15 Aug 2024 16:29:51 -0700 Subject: [PATCH 92/96] Formating --- .../layers/fused_moe/fused_moe_marlin.py | 4 +- .../layers/quantization/gptq_marlin.py | 3 - vllm/model_executor/models/mixtral.py | 96 ++++++++++--------- vllm/model_executor/models/mixtral_quant.py | 7 +- 4 files changed, 59 insertions(+), 51 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe_marlin.py b/vllm/model_executor/layers/fused_moe/fused_moe_marlin.py index 3e080a1393e16..7e834e3250f74 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe_marlin.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe_marlin.py @@ -56,8 +56,8 @@ def fused_moe_marlin( 0], "Number of tokens mismatch" assert hidden_states.shape[ 1] == w1.shape[1] * 16, "Hidden size mismatch w1" - assert hidden_states.shape[ - 1] == w2.shape[2] // (num_bits // 2), "Hidden size mismatch w2" + assert hidden_states.shape[1] == w2.shape[2] // ( + num_bits // 2), "Hidden size mismatch w2" assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch" assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" assert w1.is_contiguous(), "Expert weights1 must be contiguous" diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py index 90762efef8108..64453d61145ef 100644 --- a/vllm/model_executor/layers/quantization/gptq_marlin.py +++ b/vllm/model_executor/layers/quantization/gptq_marlin.py @@ -539,8 +539,6 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: device=device), requires_grad=False, ) - # logger.error(f"W13 qweight size - {layer.w13_qweight.size()}") - # logger.error(f"Quant Config: {self.quant_config}") # Repack weights marlin_w13_qweight = ops.gptq_marlin_moe_repack( layer.w13_qweight, @@ -567,7 +565,6 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: group_size=self.quant_config.group_size, ) replace_tensor(layer, "w13_scales", marlin_w13_scales) - # logger.error(f"{layer.w2_scales.size()}, {layer.intermediate_size_per_partition}, {self.quant_config.group_size}") marlin_w2_scales = marlin_moe_permute_scales( s=layer.w2_scales, size_k=layer.w2_scales.shape[1] * self.quant_config.pack_factor, diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index 9d99606b49ec3..f157a21fd27ae 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -63,8 +63,10 @@ from .interfaces import SupportsLoRA from .utils import is_pp_missing_parameter, make_layers import logging + logger = logging.getLogger(__name__) + class MixtralMoE(nn.Module): """A tensor-parallel MoE implementation for Mixtral that shards each expert across all ranks. 
@@ -114,12 +116,16 @@ def __init__( def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: # NOTE: hidden_states can have either 1D or 2D shape. orig_shape, orig_type = hidden_states.shape, hidden_states.dtype - hidden_states = hidden_states.view(-1, self.hidden_size).to(self.params_dtype) + hidden_states = hidden_states.view(-1, self.hidden_size).to( + self.params_dtype) # router_logits: (num_tokens, n_experts) router_logits, _ = self.gate(hidden_states) final_hidden_states = self.experts(hidden_states, router_logits) return final_hidden_states.view(orig_shape).to(orig_type) + + class MixtralAttention(nn.Module): + def __init__( self, hidden_size: int, @@ -201,6 +207,7 @@ def forward( class MixtralDecoderLayer(nn.Module): + def __init__( self, config: MixtralConfig, @@ -232,10 +239,10 @@ def __init__( params_dtype=torch.float16, prefix=f"{prefix}.block_sparse_moe", ) - self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.post_attention_layernorm = RMSNorm( - config.hidden_size, eps=config.rms_norm_eps - ) + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) def forward( self, @@ -250,7 +257,8 @@ def forward( residual = hidden_states hidden_states = self.input_layernorm(hidden_states) else: - hidden_states, residual = self.input_layernorm(hidden_states, residual) + hidden_states, residual = self.input_layernorm( + hidden_states, residual) hidden_states = self.self_attn( positions=positions, hidden_states=hidden_states, @@ -259,12 +267,14 @@ def forward( ) # Fully Connected - hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) hidden_states = self.block_sparse_moe(hidden_states) return hidden_states, residual class MixtralModel(nn.Module): + def __init__( self, config: MixtralConfig, @@ -275,11 +285,8 @@ def __init__( ) -> None: super().__init__() self.padding_idx = config.pad_token_id - lora_vocab = ( - (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1)) - if lora_config - else 0 - ) + lora_vocab = ((lora_config.lora_extra_vocab_size * + (lora_config.max_loras or 1)) if lora_config else 0) self.vocab_size = config.vocab_size + lora_vocab self.org_vocab_size = config.vocab_size @@ -324,12 +331,14 @@ def forward( residual, ) if not get_pp_group().is_last_rank: - return IntermediateTensors( - {"hidden_states": hidden_states, "residual": residual} - ) + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) hidden_states, _ = self.norm(hidden_states, residual) return hidden_states + class MixtralForCausalLM(nn.Module, SupportsLoRA): fall_back_to_pt_during_load = False @@ -365,9 +374,11 @@ def __init__( self.config = config self.lora_config = lora_config - self.model = MixtralModel( - config, cache_config, quant_config, lora_config=lora_config, prefix="model" - ) + self.model = MixtralModel(config, + cache_config, + quant_config, + lora_config=lora_config, + prefix="model") self.unpadded_vocab_size = config.vocab_size if lora_config: self.unpadded_vocab_size += lora_config.lora_extra_vocab_size @@ -381,9 +392,8 @@ def __init__( if not lora_config else lora_config.lora_vocab_padding_size, quant_config=quant_config, ) - self.logits_processor = LogitsProcessor( - self.unpadded_vocab_size, config.vocab_size - ) + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + 
config.vocab_size) self.sampler = Sampler() def forward( @@ -394,30 +404,29 @@ def forward( attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, ) -> torch.Tensor: - hidden_states = self.model( - input_ids, positions, kv_caches, attn_metadata, intermediate_tensors - ) + hidden_states = self.model(input_ids, positions, kv_caches, + attn_metadata, intermediate_tensors) return hidden_states - def compute_logits( - self, hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata - ) -> torch.Tensor: - logits = self.logits_processor(self.lm_head, hidden_states, sampling_metadata) + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) return logits def make_empty_intermediate_tensors( - self, batch_size: int, dtype: torch.dtype, device: torch.device - ) -> IntermediateTensors: - return IntermediateTensors( - { - "hidden_states": torch.zeros( - (batch_size, self.config.hidden_size), dtype=dtype, device=device - ), - "residual": torch.zeros( - (batch_size, self.config.hidden_size), dtype=dtype, device=device - ), - } - ) + self, batch_size: int, dtype: torch.dtype, + device: torch.device) -> IntermediateTensors: + return IntermediateTensors({ + "hidden_states": + torch.zeros((batch_size, self.config.hidden_size), + dtype=dtype, + device=device), + "residual": + torch.zeros((batch_size, self.config.hidden_size), + dtype=dtype, + device=device), + }) def sample( self, @@ -497,7 +506,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): continue param = params_dict[name] - weight_loader = getattr( - param, "weight_loader", default_weight_loader - ) - weight_loader(param, loaded_weight) \ No newline at end of file + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) diff --git a/vllm/model_executor/models/mixtral_quant.py b/vllm/model_executor/models/mixtral_quant.py index c9143552224f5..3ff70d222b518 100644 --- a/vllm/model_executor/models/mixtral_quant.py +++ b/vllm/model_executor/models/mixtral_quant.py @@ -55,7 +55,10 @@ from vllm.sequence import IntermediateTensors, SamplerOutput import logging + logger = logging.getLogger(__name__) + + class MixtralMLP(nn.Module): def __init__( @@ -361,6 +364,7 @@ def forward( hidden_states, _ = self.norm(hidden_states, residual) return hidden_states + class MixtralForCausalLM(nn.Module): fall_back_to_pt_during_load = False @@ -374,7 +378,6 @@ def __init__( # TODO check runs with dtype=float16 self.use_fused_moe = (config.torch_dtype != torch.float8_e4m3fn) - logger.error(f"Using fused MoE: {self.use_fused_moe}") self.config = config self.quant_config = quant_config self.model = MixtralModel(config, self.use_fused_moe, cache_config, @@ -485,4 +488,4 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): else: weight_loader = getattr(param, "weight_loader", default_weight_loader) - weight_loader(param, loaded_weight) \ No newline at end of file + weight_loader(param, loaded_weight) From 1ca90987b6c8a2266217a8bc863d9d4e834ba012 Mon Sep 17 00:00:00 2001 From: DhruvaBansal00 Date: Mon, 19 Aug 2024 10:32:59 -0700 Subject: [PATCH 93/96] Moving single marlin alongside fused marlin --- tests/kernels/test_moe.py | 4 +- .../layers/fused_moe/fused_moe.py | 78 ------------------- .../layers/fused_moe/fused_moe_marlin.py | 77 ++++++++++++++++++ 3 files changed, 79 insertions(+), 80 deletions(-) 
diff --git a/tests/kernels/test_moe.py b/tests/kernels/test_moe.py index b53f578988214..41d5478c2e5d0 100644 --- a/tests/kernels/test_moe.py +++ b/tests/kernels/test_moe.py @@ -10,8 +10,8 @@ from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock from vllm.model_executor.layers.activation import SiluAndMul -from vllm.model_executor.layers.fused_moe import fused_moe, single_marlin_moe -from vllm.model_executor.layers.fused_moe.fused_moe_marlin import fused_moe_marlin +from vllm.model_executor.layers.fused_moe import fused_moe +from vllm.model_executor.layers.fused_moe.fused_moe_marlin import fused_moe_marlin, single_marlin_moe from vllm.model_executor.layers.quantization.utils.marlin_utils_test import ( marlin_quantize, ) from vllm.model_executor.models.mixtral import MixtralMoE diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 9ae5859c4da0c..797bbfe5c71c1 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -627,81 +627,3 @@ def fused_moe( w2_scale=w2_scale, a1_scale=a1_scale, a2_scale=a2_scale) - - -def single_marlin_moe( - hidden_states: torch.Tensor, - w: torch.Tensor, - scales: torch.Tensor, - gating_output: torch.Tensor, - g_idx: torch.Tensor, - rand_perm: torch.Tensor, - topk: int, - renormalize: bool, - override_config: Optional[Dict[str, Any]] = None, - use_fp8: bool = False, -) -> torch.Tensor: - """ - This function computes a Marlin MoE MMM using weights w - and top-k gating mechanism. It is meant for testing and debugging. - - Parameters: - - hidden_states (torch.Tensor): The input tensor to the MoE layer. - - w (torch.Tensor): The first set of expert weights. - - gating_output (torch.Tensor): The output of the gating operation - (before softmax). - - topk (int): The number of top-k experts to select. - - renormalize (bool): If True, renormalize the top-k weights to sum to 1. - - inplace (bool): If True, perform the operation in-place. - Defaults to False. - - override_config (Optional[Dict[str, Any]]): Optional override - for the kernel configuration. - - use_fp8 (bool): If True, use fp8 arithmetic to compute the inner - products for w and w2. Defaults to False. - - Returns: - - torch.Tensor: The output tensor after applying the MoE layer. - """ - # Check constraints. 
- assert hidden_states.shape[0] == gating_output.shape[0], ( - "Number of tokens mismatch") - assert hidden_states.shape[1] == w.shape[1] * 16, "Hidden size mismatch" - assert gating_output.shape[1] == w.shape[0], "Number of experts mismatch" - assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" - assert w.is_contiguous(), "Expert weights must be contiguous" - assert hidden_states.dtype in [ - torch.float32, torch.float16, torch.bfloat16 - ] - M, K = hidden_states.shape - E = w.shape[0] - N = w.shape[2] // 2 - - topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk, - renormalize) - - # This might not be an optimal config for a single MMM - get_config_func = functools.partial(try_get_optimal_moe_config, - w.shape, - w.shape, - topk_ids.shape[1], - "float8" if use_fp8 else None, - override_config=override_config, - is_marlin=True) - config = get_config_func(M) - - block_size_m = config['BLOCK_SIZE_M'] - - sorted_token_ids, _, _ = moe_align_block_size(topk_ids, block_size_m, E) - - max_workspace_size = (N // 64) * 16 - workspace = torch.zeros(max_workspace_size, - dtype=torch.int, - device="cuda", - requires_grad=False) - - intermediate_cache = torch.ops._moe_C.marlin_gemm_moe( - hidden_states, w, sorted_token_ids, topk_weights, topk_ids, scales, - g_idx, rand_perm, workspace, M, N, K, True, E, topk, block_size_m, - True, False) - - return torch.sum(intermediate_cache.view(*intermediate_cache.shape), dim=1) \ No newline at end of file diff --git a/vllm/model_executor/layers/fused_moe/fused_moe_marlin.py b/vllm/model_executor/layers/fused_moe/fused_moe_marlin.py index 7e834e3250f74..48760daba2e41 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe_marlin.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe_marlin.py @@ -8,6 +8,83 @@ from .fused_moe import fused_topk, moe_align_block_size, try_get_optimal_moe_config +def single_marlin_moe( + hidden_states: torch.Tensor, + w: torch.Tensor, + scales: torch.Tensor, + gating_output: torch.Tensor, + g_idx: torch.Tensor, + rand_perm: torch.Tensor, + topk: int, + renormalize: bool, + override_config: Optional[Dict[str, Any]] = None, + use_fp8: bool = False, +) -> torch.Tensor: + """ + This function computes a Marlin MoE MMM using weights w + and top-k gating mechanism. It is meant for testing and debugging. + + Parameters: + - hidden_states (torch.Tensor): The input tensor to the MoE layer. + - w (torch.Tensor): The first set of expert weights. + - gating_output (torch.Tensor): The output of the gating operation + (before softmax). + - topk (int): The number of top-k experts to select. + - renormalize (bool): If True, renormalize the top-k weights to sum to 1. + - inplace (bool): If True, perform the operation in-place. + Defaults to False. + - override_config (Optional[Dict[str, Any]]): Optional override + for the kernel configuration. + - use_fp8 (bool): If True, use fp8 arithmetic to compute the inner + products for w and w2. Defaults to False. + + Returns: + - torch.Tensor: The output tensor after applying the MoE layer. + """ + # Check constraints. 
+ assert hidden_states.shape[0] == gating_output.shape[0], ( + "Number of tokens mismatch") + assert hidden_states.shape[1] == w.shape[1] * 16, "Hidden size mismatch" + assert gating_output.shape[1] == w.shape[0], "Number of experts mismatch" + assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" + assert w.is_contiguous(), "Expert weights must be contiguous" + assert hidden_states.dtype in [ + torch.float32, torch.float16, torch.bfloat16 + ] + M, K = hidden_states.shape + E = w.shape[0] + N = w.shape[2] // 2 + + topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk, + renormalize) + + # This might not be an optimal config for a single MMM + get_config_func = functools.partial(try_get_optimal_moe_config, + w.shape, + w.shape, + topk_ids.shape[1], + "float8" if use_fp8 else None, + override_config=override_config, + is_marlin=True) + config = get_config_func(M) + + block_size_m = config['BLOCK_SIZE_M'] + + sorted_token_ids, _, _ = moe_align_block_size(topk_ids, block_size_m, E) + + max_workspace_size = (N // 64) * 16 + workspace = torch.zeros(max_workspace_size, + dtype=torch.int, + device="cuda", + requires_grad=False) + + intermediate_cache = torch.ops._moe_C.marlin_gemm_moe( + hidden_states, w, sorted_token_ids, topk_weights, topk_ids, scales, + g_idx, rand_perm, workspace, M, N, K, True, E, topk, block_size_m, + True, False) + + return torch.sum(intermediate_cache.view(*intermediate_cache.shape), dim=1) + def fused_moe_marlin( hidden_states: torch.Tensor, w1: torch.Tensor, From 4d414252481214c5a66b5defc9245dbc69ae45d8 Mon Sep 17 00:00:00 2001 From: DhruvaBansal00 Date: Mon, 19 Aug 2024 10:36:11 -0700 Subject: [PATCH 94/96] Removing unused imports --- vllm/model_executor/layers/fused_moe/layer.py | 3 --- vllm/model_executor/models/mixtral.py | 7 +------ vllm/model_executor/models/mixtral_quant.py | 9 +++------ 3 files changed, 4 insertions(+), 15 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 34bad93f052dc..825236a6e3bf0 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -1,11 +1,8 @@ -import enum from abc import abstractmethod -from enum import Enum from typing import List, Optional, Tuple import torch -from vllm import _custom_ops as ops from vllm.distributed import ( get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index f157a21fd27ae..dc46e2d91284d 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -22,10 +22,7 @@ # limitations under the License. 
"""Inference-only Mixtral model.""" from typing import Iterable, List, Optional, Tuple -import re import torch -import numpy as np -import torch.nn.functional as F from torch import nn from transformers import MixtralConfig @@ -33,9 +30,7 @@ from vllm.config import CacheConfig, LoRAConfig from vllm.distributed import ( get_pp_group, - get_tensor_model_parallel_world_size, - get_tensor_model_parallel_rank, - tensor_model_parallel_all_reduce, + get_tensor_model_parallel_world_size ) from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.layernorm import RMSNorm diff --git a/vllm/model_executor/models/mixtral_quant.py b/vllm/model_executor/models/mixtral_quant.py index 3ff70d222b518..2bbde985ecf0e 100644 --- a/vllm/model_executor/models/mixtral_quant.py +++ b/vllm/model_executor/models/mixtral_quant.py @@ -30,12 +30,9 @@ from torch import nn from transformers import MixtralConfig -from .interfaces import SupportsLoRA -from .utils import is_pp_missing_parameter, make_layers - from vllm.attention import Attention, AttentionMetadata -from vllm.config import CacheConfig, LoRAConfig -from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, +from vllm.config import CacheConfig +from vllm.distributed import (get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce) from vllm.model_executor.layers.fused_moe import FusedMoE @@ -49,7 +46,7 @@ from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( - ParallelLMHead, VocabParallelEmbedding, DEFAULT_VOCAB_PADDING_SIZE) + ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors, SamplerOutput From 4907f43ecefd851deaace753f904cfa4e1c4d368 Mon Sep 17 00:00:00 2001 From: DhruvaBansal00 Date: Mon, 19 Aug 2024 11:15:59 -0700 Subject: [PATCH 95/96] single marlin moe import --- vllm/model_executor/layers/fused_moe/__init__.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/__init__.py b/vllm/model_executor/layers/fused_moe/__init__.py index beb94f10a557e..212c3ac846e55 100644 --- a/vllm/model_executor/layers/fused_moe/__init__.py +++ b/vllm/model_executor/layers/fused_moe/__init__.py @@ -1,5 +1,4 @@ -from vllm.model_executor.layers.fused_moe.fused_moe_marlin import fused_moe_marlin -from vllm.model_executor.layers.fused_moe.fused_moe import single_marlin_moe +from vllm.model_executor.layers.fused_moe.fused_moe_marlin import fused_moe_marlin, single_marlin_moe from vllm.model_executor.layers.fused_moe.layer import FusedMoE, FusedMoEMethodBase from vllm.triton_utils import HAS_TRITON From 315e3b605df96d1edf425d921f7a2ccdbdd93ac5 Mon Sep 17 00:00:00 2001 From: Eliza Wszola Date: Wed, 21 Aug 2024 14:02:49 +0000 Subject: [PATCH 96/96] Unify shard_id to be of str w[1-3] format --- vllm/model_executor/layers/fused_moe/layer.py | 21 +++++++++---------- .../layers/quantization/experts_int8.py | 2 +- 2 files changed, 11 insertions(+), 12 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 484b22d3ad027..160f6948648af 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -234,9 +234,9 @@ 
def weight_loader( or "_qzeros" in weight_name): if "w13" in weight_name: shard_size = loaded_weight.size()[-1] - if shard_id == 0: + if shard_id == "w1": param_data[expert_id, :, :shard_size] = loaded_weight - elif shard_id == 2 or shard_id == 1: + elif shard_id == "w3" or shard_id == "w2": param_data[expert_id, :, shard_size:] = loaded_weight else: raise ValueError(f"Invalid shard_id: {shard_id}: " @@ -357,12 +357,11 @@ def forward(self, hidden_states: torch.Tensor, def make_expert_params_mapping( cls, ckpt_gate_proj_name: str, ckpt_down_proj_name: str, ckpt_up_proj_name: str, - num_experts: int) -> List[Tuple[str, str, int, int]]: + num_experts: int) -> List[Tuple[str, str, int, str]]: gate_up = [ckpt_gate_proj_name, ckpt_up_proj_name] gate_down_up = [ ckpt_gate_proj_name, ckpt_down_proj_name, ckpt_up_proj_name ] - return ([ # These are the weight scales for the experts # (param_name, weight_name, expert_id, shard_id) @@ -371,7 +370,7 @@ def make_expert_params_mapping( if weight_name in gate_up else "experts.w2_scale", f"experts.{expert_id}.{weight_name}.weight_scale", expert_id, - shard_id, + f"w{shard_id + 1}", ) for expert_id in range(num_experts) for shard_id, weight_name in enumerate(gate_down_up) ] + [ @@ -382,7 +381,7 @@ def make_expert_params_mapping( if weight_name in gate_up else "experts.w2_weight", f"experts.{expert_id}.{weight_name}.weight", expert_id, - shard_id, + f"w{shard_id + 1}", ) for expert_id in range(num_experts) for shard_id, weight_name in enumerate(gate_down_up) ] + [ @@ -393,7 +392,7 @@ def make_expert_params_mapping( if weight_name in gate_up else "experts.w2_scales", f"experts.{expert_id}.{weight_name}.scales", expert_id, - shard_id, + f"w{shard_id + 1}", ) for expert_id in range(num_experts) for shard_id, weight_name in enumerate(gate_down_up) ] + [ @@ -404,7 +403,7 @@ def make_expert_params_mapping( if weight_name in gate_up else "experts.a2_scale", f"experts.{expert_id}.{weight_name}.input_scale", expert_id, - shard_id, + f"a{shard_id + 1}", ) for expert_id in range(num_experts) for shard_id, weight_name in enumerate(gate_down_up) ] + [ @@ -415,7 +414,7 @@ def make_expert_params_mapping( if weight_name in gate_up else "experts.w2_qweight", f"experts.{expert_id}.{weight_name}.qweight", expert_id, - shard_id, + f"w{shard_id + 1}", ) for expert_id in range(num_experts) for shard_id, weight_name in enumerate(gate_down_up) ] + [ @@ -426,7 +425,7 @@ def make_expert_params_mapping( if weight_name in gate_up else "experts.w2_g_idx", f"experts.{expert_id}.{weight_name}.g_idx", expert_id, - shard_id, + f"w{shard_id + 1}", ) for expert_id in range(num_experts) for shard_id, weight_name in enumerate(gate_down_up) ] + [ @@ -437,7 +436,7 @@ def make_expert_params_mapping( if weight_name in gate_up else "experts.w2_qzeros", f"experts.{expert_id}.{weight_name}.qzeros", expert_id, - shard_id, + f"w{shard_id + 1}", ) for expert_id in range(num_experts) for shard_id, weight_name in enumerate(gate_down_up) ]) diff --git a/vllm/model_executor/layers/quantization/experts_int8.py b/vllm/model_executor/layers/quantization/experts_int8.py index dabf17df78fef..153bccc303ef1 100644 --- a/vllm/model_executor/layers/quantization/experts_int8.py +++ b/vllm/model_executor/layers/quantization/experts_int8.py @@ -157,7 +157,7 @@ def quantize_and_call_weight_loader(param: torch.nn.Parameter, layer.w2_scale.data[expert_id, :].copy_(scales[:, 0]) else: raise ValueError( - f"Shard id must be in [0,1,2] but got {shard_id}") + f"Shard id must be in ['w1','w2','w3'] but got {shard_id}") 
weight_loader(param, loaded_weight, weight_name, shard_id, expert_id)
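
The last patch above replaces the integer shard index emitted by FusedMoE.make_expert_params_mapping (and consumed by weight_loader and experts_int8.py) with the string form f"w{shard_id + 1}", plus an "a"-prefixed form for the per-expert input scales, so loaders now branch on "w1"/"w2"/"w3". The sketch below is only an illustrative, self-contained approximation of the qweight part of that mapping, not the actual vLLM implementation; the two-expert count and the "w1"/"w2"/"w3" checkpoint names are assumptions chosen to mirror Mixtral-style checkpoints.

# Minimal, self-contained sketch of the per-expert parameter mapping after the
# shard_id change in the last patch. example_expert_params_mapping is a
# simplified stand-in for FusedMoE.make_expert_params_mapping (qweight entries
# only); num_experts=2 and the "w1"/"w2"/"w3" names are illustrative assumptions.
from typing import List, Tuple


def example_expert_params_mapping(
        ckpt_gate_proj_name: str, ckpt_down_proj_name: str,
        ckpt_up_proj_name: str,
        num_experts: int) -> List[Tuple[str, str, int, str]]:
    gate_up = [ckpt_gate_proj_name, ckpt_up_proj_name]
    gate_down_up = [
        ckpt_gate_proj_name, ckpt_down_proj_name, ckpt_up_proj_name
    ]
    return [
        # gate (w1) and up (w3) projections map to the fused w13 parameter,
        # down (w2) stays separate; the shard id is now the string "w1",
        # "w2" or "w3" instead of an integer index.
        (
            "experts.w13_qweight"
            if weight_name in gate_up else "experts.w2_qweight",
            f"experts.{expert_id}.{weight_name}.qweight",
            expert_id,
            f"w{shard_id + 1}",
        ) for expert_id in range(num_experts)
        for shard_id, weight_name in enumerate(gate_down_up)
    ]


if __name__ == "__main__":
    for entry in example_expert_params_mapping("w1", "w2", "w3", num_experts=2):
        print(entry)
    # ('experts.w13_qweight', 'experts.0.w1.qweight', 0, 'w1')
    # ('experts.w2_qweight', 'experts.0.w2.qweight', 0, 'w2')
    # ('experts.w13_qweight', 'experts.0.w3.qweight', 0, 'w3')
    # ... and the same three entries for expert 1.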