diff --git a/vllm/model_executor/layers/fused_moe/__init__.py b/vllm/model_executor/layers/fused_moe/__init__.py
index 452c1a4b40f21..65a9b78a118c3 100644
--- a/vllm/model_executor/layers/fused_moe/__init__.py
+++ b/vllm/model_executor/layers/fused_moe/__init__.py
@@ -2,7 +2,6 @@
     fused_moe_marlin, single_moe_marlin)
 from vllm.model_executor.layers.fused_moe.layer import (
     FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported)
-
 from vllm.triton_utils import HAS_TRITON
 
 __all__ = [
diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py
index a5144b4242601..e54008cecde79 100644
--- a/vllm/model_executor/layers/fused_moe/layer.py
+++ b/vllm/model_executor/layers/fused_moe/layer.py
@@ -225,90 +225,6 @@ def __init__(
             weight_loader=self.weight_loader,
         )
 
-    # def weight_loader(
-    #     self,
-    #     param: torch.nn.Parameter,
-    #     loaded_weight: torch.Tensor,
-    #     weight_name: str,
-    #     shard_id: str,
-    #     expert_id: int,
-    #     is_quantized: bool = False,
-    # ):
-    #     param_data = param.data
-
-    #     if is_quantized:
-    #         if ("_qweight" in weight_name or "_scales" in weight_name
-    #                 or "_qzeros" in weight_name):
-    #             if "w13" in weight_name:
-    #                 shard_size = loaded_weight.size()[-1]
-    #                 if shard_id == "w1":
-    #                     param_data[expert_id, :, :shard_size] = loaded_weight
-    #                 elif shard_id == "w3" or shard_id == "w2":
-    #                     param_data[expert_id, :, shard_size:] = loaded_weight
-    #                 else:
-    #                     raise ValueError(f"Invalid shard_id: {shard_id}: "
-    #                                      "must be 0, 1, or 2.")
-    #             elif "w2" in weight_name:
-    #                 param_data[expert_id][:] = loaded_weight
-    #             else:
-    #                 raise ValueError(f"Invalid weight name: {weight_name}: "
-    #                                  "must contain 'w13' or 'w2'.")
-    #         elif "_g_idx" in weight_name:
-    #             if "w13" not in weight_name and "w2" not in weight_name:
-    #                 raise ValueError(f"Invalid weight name: {weight_name}: "
-    #                                  "must contain 'w13' or 'w2'.")
-    #             param_data[expert_id] = loaded_weight
-    #         else:
-    #             raise ValueError(f"Invalid weight name: {weight_name}.")
-    #     else:
-    #         if shard_id not in ("w1", "w2", "w3"):
-    #             raise ValueError(f"shard_id must be ['w1','w2','w3'] but "
-    #                              f"got {shard_id}.")
-
-    #         # Special case for fp8 scales.
-    #         if getattr(param, "is_fp8_scale", False):
-    #             self._load_fp8_scale(param.data, loaded_weight, weight_name,
-    #                                  shard_id, expert_id)
-    #             return
-
-    #         expert_data = param.data[expert_id]
-    #         tp_rank = get_tensor_model_parallel_rank()
-
-    #         # If transposed, weight is saved as [input_dim, output_dim]
-    #         # Otherwise, weight is saved as [output_dim, input_dim]
-    #         # Default is not transposed/input dim is dim 1
-    #         input_dim = getattr(param, "input_dim", 1)
-    #         output_dim = getattr(param, "output_dim", 0)
-
-    #         # Index the loaded weight for tp sharding.
-    #         # down_proj: "RowParallel" so tp sharding on input_dim
-    #         if shard_id == "w2":
-    #             shard_dim = input_dim
-    #             shard_size = expert_data.shape[shard_dim]
-    #         # gate_up_proj: "MergedColumnParallel", so tp sharding on output_dim
-    #         elif shard_id in ("w1", "w3"):
-    #             shard_dim = output_dim
-    #             shard_size = expert_data.shape[output_dim] // 2
-    #         offset = shard_size * tp_rank
-    #         loaded_weight = loaded_weight.narrow(shard_dim, offset, shard_size)
-
-    #         # Narrow parameter and load.
-    #         # w1, gate_proj: Load into first logical weight of w13.
-    #         if shard_id == "w1":
-    #             expert_data = expert_data.narrow(shard_dim, 0, shard_size)
-    #             expert_data.copy_(loaded_weight)
-    #         # w3, up_proj: Load into second logical weight of w13.
-    #         elif shard_id == "w3":
-    #             expert_data = expert_data.narrow(shard_dim, shard_size,
-    #                                              shard_size)
-    #             expert_data.copy_(loaded_weight)
-    #         # w2, down_proj: Load into only logical weight of w2.
-    #         elif shard_id == "w2":
-    #             expert_data.copy_(loaded_weight)
-    #         else:
-    #             raise ValueError(
-    #                 f"Expected shard_id w1, w2 or w3 but got {shard_id}")
-
     def _load_per_tensor_weight_scale(self, shard_id: str,
                                       param: torch.nn.Parameter,
                                       loaded_weight: torch.Tensor,
@@ -395,9 +311,41 @@ def _load_single_value(self, param: torch.nn.Parameter,
         # Input scales can be loaded directly and should be equal.
         param_data[expert_id] = loaded_weight
 
-    def weight_loader(self, param: torch.nn.Parameter,
-                      loaded_weight: torch.Tensor, weight_name: str,
-                      shard_id: str, expert_id: int) -> None:
+    def weight_loader(
+        self,
+        param: torch.nn.Parameter,
+        loaded_weight: torch.Tensor,
+        weight_name: str,
+        shard_id: str,
+        expert_id: int,
+        is_gptq: bool = False,
+    ):
+        if is_gptq:
+            param_data = param.data
+            if ("_qweight" in weight_name or "_scales" in weight_name
+                    or "_qzeros" in weight_name):
+                if "w13" in weight_name:
+                    shard_size = loaded_weight.size()[-1]
+                    if shard_id == "w1":
+                        param_data[expert_id, :, :shard_size] = loaded_weight
+                    elif shard_id == "w3" or shard_id == "w2":
+                        param_data[expert_id, :, shard_size:] = loaded_weight
+                    else:
+                        raise ValueError(f"Invalid shard_id: {shard_id}: "
+                                         "must be 0, 1, or 2.")
+                elif "w2" in weight_name:
+                    param_data[expert_id][:] = loaded_weight
+                else:
+                    raise ValueError(f"Invalid weight name: {weight_name}: "
+                                     "must contain 'w13' or 'w2'.")
+            elif "_g_idx" in weight_name:
+                if "w13" not in weight_name and "w2" not in weight_name:
+                    raise ValueError(f"Invalid weight name: {weight_name}: "
+                                     "must contain 'w13' or 'w2'.")
+                param_data[expert_id] = loaded_weight
+            else:
+                raise ValueError(f"Invalid weight name: {weight_name}.")
+            return
 
         if shard_id not in ("w1", "w2", "w3"):
             raise ValueError(f"shard_id must be ['w1','w2','w3'] but "
@@ -550,8 +498,8 @@ def make_expert_params_mapping(
             # These are the weight scales for the experts
             # (param_name, weight_name, expert_id, shard_id)
             (
-                "experts.w13_scale"
-                if weight_name in gate_up else "experts.w2_scale",
+                "experts.w13_weight_scale"
+                if weight_name in gate_up else "experts.w2_weight_scale",
                 f"experts.{expert_id}.{weight_name}.weight_scale",
                 expert_id,
                 f"w{shard_id + 1}",
@@ -625,18 +573,6 @@ def make_expert_params_mapping(
             for shard_id, weight_name in enumerate(gate_down_up)
         ])
 
-        # return [
-        #     # (param_name, weight_name, expert_id, shard_id)
-        #     ("experts.w13_" if weight_name
-        #      in [ckpt_gate_proj_name, ckpt_up_proj_name] else "experts.w2_",
-        #      f"experts.{expert_id}.{weight_name}.", expert_id, shard_id)
-        #     for expert_id in range(num_experts) for shard_id, weight_name in [
-        #         ("w1", ckpt_gate_proj_name),
-        #         ("w2", ckpt_down_proj_name),
-        #         ("w3", ckpt_up_proj_name),
-        #     ]
-        # ]
-
     def _load_fp8_scale(self, param: torch.nn.Parameter,
                         loaded_weight: torch.Tensor, weight_name: str,
                         shard_id: str, expert_id: int) -> None:
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
index 448de19971b41..ba4f719a3f97f 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
@@ -269,15 +269,18 @@ def apply(self,
         from vllm.model_executor.layers.fused_moe.fused_moe_marlin import (
             fused_moe_marlin)
 
-        return fused_moe_marlin(x,
-                                layer.w13_weight_packed,
-                                layer.w2_weight_packed,
-                                router_logits,
-                                layer.w13_g_idx,
-                                layer.w2_g_idx,
-                                layer.w13_g_idx_sort_indices,
-                                layer.w2_g_idx_sort_indices,
-                                top_k,
-                                renormalize=renormalize,
-                                w1_scale=layer.w13_weight_scale,
-                                w2_scale=layer.w2_weight_scale)
+        return fused_moe_marlin(
+            x,
+            layer.w13_weight_packed,
+            layer.w2_weight_packed,
+            router_logits,
+            layer.w13_g_idx,
+            layer.w2_g_idx,
+            layer.w13_g_idx_sort_indices,
+            layer.w2_g_idx_sort_indices,
+            top_k,
+            renormalize=renormalize,
+            w1_scale=layer.w13_weight_scale,
+            w2_scale=layer.w2_weight_scale,
+            num_bits=self.num_bits,
+        )
diff --git a/vllm/model_executor/layers/quantization/experts_int8.py b/vllm/model_executor/layers/quantization/experts_int8.py
index 153bccc303ef1..6a2ed7704d13f 100644
--- a/vllm/model_executor/layers/quantization/experts_int8.py
+++ b/vllm/model_executor/layers/quantization/experts_int8.py
@@ -88,13 +88,13 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int,
                                                    2 * intermediate_size,
                                                    dtype=torch.float32),
                                        requires_grad=False)
-        layer.register_parameter("w13_scale", w13_scale)
+        layer.register_parameter("w13_weight_scale", w13_scale)
 
         w2_scale = torch.nn.Parameter(torch.zeros(num_experts,
                                                   hidden_size,
                                                   dtype=torch.float32),
                                       requires_grad=False)
-        layer.register_parameter("w2_scale", w2_scale)
+        layer.register_parameter("w2_weight_scale", w2_scale)
 
     def apply(self,
               layer: torch.nn.Module,
diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py
index 56685c872c447..d3471959a1766 100644
--- a/vllm/model_executor/models/mixtral.py
+++ b/vllm/model_executor/models/mixtral.py
@@ -42,6 +42,8 @@
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
+from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import (  # noqa: E501
+    CompressedTensorsConfig)
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -450,9 +452,11 @@ def __init__(
         super().__init__()
         # TODO keep the fused mixtral_quant codepath around as long as we don't
         # support all quant_types
-        self.use_fused_moe = (quant_config.quant_type == scalar_types.uint4b8
-                              or quant_config.quant_type
-                              == scalar_types.uint8b128)
+        self.is_compressed = isinstance(quant_config, CompressedTensorsConfig)
+        self.use_fused_moe = (
+            self.is_compressed
+            or quant_config.quant_type == scalar_types.uint4b8
+            or quant_config.quant_type == scalar_types.uint8b128)
         self.config = config
         self.lora_config = lora_config
         self.model = MixtralModel(self.use_fused_moe,
@@ -579,7 +583,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                     name,
                     shard_id=shard_id,
                     expert_id=expert_id,
-                    # is_quantized=True,
+                    is_gptq=not self.is_compressed,
                 )
                 break
             else:
diff --git a/vllm/model_executor/models/mixtral_quant.py b/vllm/model_executor/models/mixtral_quant.py
index cd8890b15303b..8bdd52b343175 100644
--- a/vllm/model_executor/models/mixtral_quant.py
+++ b/vllm/model_executor/models/mixtral_quant.py
@@ -21,8 +21,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Inference-only Mixtral model."""
-import logging
-import re
 from typing import Iterable, List, Optional, Tuple
 
 import numpy as np
@@ -36,7 +34,6 @@
 from vllm.distributed import (get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size,
                               tensor_model_parallel_all_reduce)
-from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (QKVParallelLinear,
                                                ReplicatedLinear,
@@ -52,8 +49,6 @@
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors, SamplerOutput
 
-logger = logging.getLogger(__name__)
-
 
 class MixtralMLP(nn.Module):
 
@@ -99,13 +94,10 @@ class MixtralMoE(nn.Module):
     def __init__(
         self,
         config: MixtralConfig,
-        use_fused_moe: bool,
         quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
         self.config = config
-        self.use_fused_moe = use_fused_moe
-        self.quant_config = quant_config
         self.rank = get_tensor_model_parallel_rank()
         self.tp_size = get_tensor_model_parallel_world_size()
         self.num_total_experts = config.num_local_experts
@@ -121,27 +113,14 @@ def __init__(
             raise ValueError(
                 f"Rank {self.rank} has no experts assigned to it.")
 
-        if self.use_fused_moe:
-            params_dtype = torch.float16
-            self.experts = FusedMoE(num_experts=self.num_total_experts,
-                                    top_k=self.top_k,
-                                    hidden_size=config.hidden_size,
-                                    intermediate_size=config.intermediate_size,
-                                    params_dtype=params_dtype,
-                                    reduce_results=True,
-                                    renormalize=True,
-                                    quant_config=quant_config,
-                                    tp_size=self.tp_size)
-        else:
-            self.experts = nn.ModuleList([
-                MixtralMLP(self.num_total_experts,
-                           config.hidden_size,
-                           config.intermediate_size,
-                           quant_config=quant_config)
-                if idx in self.expert_indicies else None
-                for idx in range(self.num_total_experts)
-            ])
-
+        self.experts = nn.ModuleList([
+            MixtralMLP(self.num_total_experts,
+                       config.hidden_size,
+                       config.intermediate_size,
+                       quant_config=quant_config)
+            if idx in self.expert_indicies else None
+            for idx in range(self.num_total_experts)
+        ])
         self.gate = ReplicatedLinear(config.hidden_size,
                                      self.num_total_experts,
                                      bias=False,
@@ -150,36 +129,31 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         num_tokens, hidden_dim = hidden_states.shape
         hidden_states = hidden_states.view(-1, hidden_dim)
+        # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(hidden_states)
-        if self.use_fused_moe:
-            ret = self.experts(hidden_states.half(), router_logits)
-            return ret.bfloat16()
-        else:
-            routing_weights = F.softmax(router_logits,
-                                        dim=1,
-                                        dtype=torch.float)
-            routing_weights, selected_experts = torch.topk(routing_weights,
-                                                           self.top_k,
-                                                           dim=-1)
-            routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
-
-            final_hidden_states = None
-            for expert_idx in self.expert_indicies:
-                expert_layer = self.experts[expert_idx]
-                expert_mask = (selected_experts == expert_idx)
-                expert_weights = (routing_weights * expert_mask).sum(
-                    dim=-1, keepdim=True)
-
-                current_hidden_states = expert_layer(hidden_states).mul_(
-                    expert_weights)
-                if final_hidden_states is None:
-                    final_hidden_states = current_hidden_states
-                else:
-                    final_hidden_states.add_(current_hidden_states)
-
-            return tensor_model_parallel_all_reduce(final_hidden_states).view(
-                num_tokens, hidden_dim)
+        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
+        routing_weights, selected_experts = torch.topk(routing_weights,
+                                                       self.top_k,
+                                                       dim=-1)
+        routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
+
+        final_hidden_states = None
+        for expert_idx in self.expert_indicies:
+            expert_layer = self.experts[expert_idx]
+            expert_mask = (selected_experts == expert_idx)
+            expert_weights = (routing_weights * expert_mask).sum(dim=-1,
+                                                                 keepdim=True)
+
+            current_hidden_states = expert_layer(hidden_states).mul_(
+                expert_weights)
+            if final_hidden_states is None:
+                final_hidden_states = current_hidden_states
+            else:
+                final_hidden_states.add_(current_hidden_states)
+
+        return tensor_model_parallel_all_reduce(final_hidden_states).view(
+            num_tokens, hidden_dim)
 
 
 class MixtralAttention(nn.Module):
@@ -264,7 +238,6 @@ class MixtralDecoderLayer(nn.Module):
     def __init__(
         self,
         config: MixtralConfig,
-        use_fused_moe: bool,
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
@@ -281,7 +254,6 @@ def __init__(
             cache_config=cache_config,
             quant_config=quant_config)
         self.block_sparse_moe = MixtralMoE(config=config,
-                                           use_fused_moe=use_fused_moe,
                                            quant_config=quant_config)
         self.input_layernorm = RMSNorm(config.hidden_size,
                                        eps=config.rms_norm_eps)
@@ -322,7 +294,6 @@ class MixtralModel(nn.Module):
     def __init__(
         self,
         config: MixtralConfig,
-        use_fused_moe: bool,
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
@@ -336,7 +307,6 @@ def __init__(
         )
         self.layers = nn.ModuleList([
             MixtralDecoderLayer(config,
-                                use_fused_moe,
                                 cache_config,
                                 quant_config=quant_config)
             for _ in range(config.num_hidden_layers)
@@ -371,13 +341,9 @@ def __init__(
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
-
-        # TODO check runs with dtype=float16
-        self.use_fused_moe = (config.torch_dtype != torch.float8_e4m3fn)
         self.config = config
         self.quant_config = quant_config
-        self.model = MixtralModel(config, self.use_fused_moe, cache_config,
-                                  quant_config)
+        self.model = MixtralModel(config, cache_config, quant_config)
         self.lm_head = ParallelLMHead(config.vocab_size,
                                       config.hidden_size,
                                       quant_config=quant_config)
@@ -442,51 +408,11 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
-
-                if self.use_fused_moe:
-                    if ("block_sparse_moe.experts." in name
-                            and ".w1." not in name and ".w2." not in name
-                            and ".w3." not in name
-                            and name not in params_dict):
-                        continue
-
-                    if (".qzeros" in name):
-                        continue
-
-                    shard_id = None
-                    expert_id = 0
-
-                    has_any_numbered = (".qweight" in name or ".scales" in name
-                                        or ".g_idx" in name)
-                    if (has_any_numbered and (".w1." in name)):
-                        name = name.replace(".w1.", ".w13_")
-                        shard_id = 0
-                    if (has_any_numbered and (".w2." in name)):
-                        name = name.replace(".w2.", ".w2_")
-                        shard_id = 0
-                    if (has_any_numbered and (".w3." in name)):
-                        name = name.replace(".w3.", ".w13_")
-                        shard_id = 1
-
-                    exp_string = re.search(r"\.experts\.\d+.", name)
-                    if exp_string:
-                        exp_string = exp_string.group(0)
-                        expert_id = int(exp_string.split(".")[2])
-                        name = name.replace(exp_string, ".experts.")
-
-                else:
-                    if ("block_sparse_moe.experts." in name
-                            and name not in params_dict):
-                        continue
-
+                # Skip experts that are not assigned to this worker.
+                if ("block_sparse_moe.experts." in name
+                        and name not in params_dict):
+                    continue
                 param = params_dict[name]
-
-                if self.use_fused_moe and shard_id is not None:
-                    weight_loader = getattr(param, "weight_loader",
-                                            default_weight_loader)
-                    weight_loader(param, loaded_weight, name, shard_id,
-                                  expert_id, True)
-                else:
-                    weight_loader = getattr(param, "weight_loader",
-                                            default_weight_loader)
-                    weight_loader(param, loaded_weight)
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                weight_loader(param, loaded_weight)
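
For reference (not part of the patch): a minimal, self-contained sketch of what the new `is_gptq=True` branch in `FusedMoE.weight_loader` does for the fused `w13` GPTQ tensors, where `w1` (gate_proj) shards fill the first half of the last dimension and `w3` (up_proj) shards fill the second half. The helper name and tensor shapes below are illustrative assumptions, not values from the patch; the real loader also handles `w2`, `_g_idx`, and the unquantized path shown above.

import torch

# Assumed, illustrative shapes: 8 experts, one packed [K, N] tensor per
# projection. Real GPTQ checkpoints derive these from the model config.
num_experts, packed_k, shard_n = 8, 64, 128

# Fused per-expert parameter: w1 and w3 live side by side on the last dim.
w13_qweight = torch.zeros(num_experts, packed_k, 2 * shard_n, dtype=torch.int32)


def load_w13_gptq_shard(param_data: torch.Tensor, loaded: torch.Tensor,
                        shard_id: str, expert_id: int) -> None:
    # Mirrors the w13 slicing in the is_gptq branch of FusedMoE.weight_loader:
    # w1 goes into the first half of the last dim, w3 into the second half.
    shard_size = loaded.size()[-1]
    if shard_id == "w1":
        param_data[expert_id, :, :shard_size] = loaded
    elif shard_id in ("w2", "w3"):
        param_data[expert_id, :, shard_size:] = loaded
    else:
        raise ValueError(f"Invalid shard_id: {shard_id}")


# Load expert 3's gate_proj (w1) and up_proj (w3) checkpoint shards.
gate = torch.randint(0, 1 << 30, (packed_k, shard_n), dtype=torch.int32)
up = torch.randint(0, 1 << 30, (packed_k, shard_n), dtype=torch.int32)
load_w13_gptq_shard(w13_qweight, gate, "w1", expert_id=3)
load_w13_gptq_shard(w13_qweight, up, "w3", expert_id=3)
assert torch.equal(w13_qweight[3, :, :shard_n], gate)
assert torch.equal(w13_qweight[3, :, shard_n:], up)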