From ad3e4f16199a51862d72845f5f7ea53cc92442d2 Mon Sep 17 00:00:00 2001
From: Lianmin Zheng
Date: Tue, 13 Aug 2024 15:44:25 -0700
Subject: [PATCH] Update the mixtral to use the better FusedMoE layer (#1081)

---
 docs/en/model_support.md                  |   2 +-
 python/sglang/srt/models/mixtral.py       | 308 ++++------------------
 python/sglang/srt/models/mixtral_quant.py |   3 -
 test/srt/test_moe_serving_throughput.py   |   2 +-
 4 files changed, 57 insertions(+), 258 deletions(-)

diff --git a/docs/en/model_support.md b/docs/en/model_support.md
index e46e99e85c8..1d720acf5cf 100644
--- a/docs/en/model_support.md
+++ b/docs/en/model_support.md
@@ -5,7 +5,7 @@ To support a new model in SGLang, you only need to add a single file under [SGLa
 Another valuable resource is the [vLLM Models Directory](https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/models). vLLM has extensive coverage of models, and SGLang has reused vLLM for most parts of the model implementations. This similarity makes it easy to port many models from vLLM to SGLang.
 
 To port a model from vLLM to SGLang, you can compare these two files [SGLang LLaMA Implementation](https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/models/llama2.py) and [vLLM LLaMA Implementation](https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/llama.py). This comparison will help you understand how to convert a model implementation from vLLM to SGLang. The major difference is the replacement of PagedAttention with RadixAttention. The other parts are almost identical. Specifically,
- - Replace vllm's `Attention` with `RadixAttention`.
+ - Replace vllm's `Attention` with `RadixAttention`. Note that you need to pass `layer_id` all the way to `RadixAttention`.
  - Replace vllm's `LogitsProcessor` with SGLang's `LogitsProcessor`.
  - Remove `Sample`.
  - Change `forward()` functions, and add `input_metadata`.
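
The porting checklist in the doc hunk above maps to a small amount of code. The sketch below is an illustration, not part of the patch: it follows the llama2.py pattern referenced in the doc, and the exact `RadixAttention` constructor arguments and the shape of `input_metadata` are assumptions that should be verified against the SGLang version in use.

```python
# Hypothetical sketch of porting a vLLM-style attention block to SGLang.
# Assumes RadixAttention(num_heads, head_dim, scaling, num_kv_heads, layer_id),
# as in llama2.py; check the signature for your SGLang version.
import torch
from torch import nn

from sglang.srt.layers.radix_attention import RadixAttention


class PortedAttention(nn.Module):
    def __init__(self, num_heads: int, head_dim: int, num_kv_heads: int, layer_id: int):
        super().__init__()
        # layer_id must be threaded down from the decoder layer into RadixAttention.
        self.attn = RadixAttention(
            num_heads, head_dim, head_dim**-0.5, num_kv_heads, layer_id
        )

    def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, input_metadata):
        # Where vLLM's Attention.forward takes kv_cache/attn_metadata,
        # SGLang's RadixAttention takes the input_metadata built by the model runner.
        return self.attn(q, k, v, input_metadata)
```

The remaining steps are mechanical: swap vLLM's `LogitsProcessor` for `sglang.srt.layers.logits_processor.LogitsProcessor`, drop the `Sample` stage, and add `input_metadata` to every `forward()` signature along the call chain.
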
diff --git a/python/sglang/srt/models/mixtral.py b/python/sglang/srt/models/mixtral.py
index 876c7a09d48..d11f6c95198 100644
--- a/python/sglang/srt/models/mixtral.py
+++ b/python/sglang/srt/models/mixtral.py
@@ -18,34 +18,25 @@
 """Inference-only Mixtral model."""
 from typing import Iterable, Optional, Tuple
 
-import numpy as np
 import torch
-import torch.nn.functional as F
 from torch import nn
 from transformers import MixtralConfig
-from vllm import _custom_ops as ops
 from vllm.config import CacheConfig
-from vllm.distributed import (
-    get_tensor_model_parallel_rank,
-    get_tensor_model_parallel_world_size,
-    tensor_model_parallel_all_reduce,
-)
-from vllm.model_executor.layers.fused_moe import fused_moe
+from vllm.distributed import get_tensor_model_parallel_world_size
+from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.model_executor.layers.linear import (
     QKVParallelLinear,
     ReplicatedLinear,
     RowParallelLinear,
 )
 from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
-from vllm.model_executor.layers.quantization.fp8 import Fp8Config
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import (
+    DEFAULT_VOCAB_PADDING_SIZE,
     ParallelLMHead,
     VocabParallelEmbedding,
 )
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
-from vllm.model_executor.utils import set_weight_attrs
-from vllm.utils import print_warning_once
 
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
@@ -69,216 +60,44 @@ def __init__(
         hidden_size: int,
         intermediate_size: int,
         params_dtype: Optional[torch.dtype] = None,
-        tp_size: Optional[int] = None,
         quant_config: Optional[QuantizationConfig] = None,
+        tp_size: Optional[int] = None,
+        prefix: str = "",
     ):
         super().__init__()
-        self.tp_size = tp_size or get_tensor_model_parallel_world_size()
-        self.num_total_experts = num_experts
-        self.top_k = top_k
         self.hidden_size = hidden_size
-        self.intermediate_size = intermediate_size // self.tp_size
-        self.quant_config = quant_config
-
-        # FIXME(pcmoritz): Make this more general to support different
-        # quantization schemes
-        self.use_fp8 = isinstance(quant_config, Fp8Config)
-
-        if params_dtype is None:
-            params_dtype = torch.get_default_dtype()
-        self.params_dtype = params_dtype
 
         # Gate always runs at half / full precision for now.
         self.gate = ReplicatedLinear(
-            self.hidden_size,
-            self.num_total_experts,
+            hidden_size,
+            num_experts,
             bias=False,
-            params_dtype=self.params_dtype,
+            params_dtype=params_dtype,
             quant_config=None,
+            prefix=f"{prefix}.gate",
         )
 
-        if self.use_fp8 and self.quant_config.is_checkpoint_fp8_serialized:
-            params_dtype = torch.float8_e4m3fn
-
-        self.w13_weight = nn.Parameter(
-            torch.empty(
-                self.num_total_experts,
-                2 * self.intermediate_size,
-                self.hidden_size,
-                dtype=params_dtype,
-            )
-        )
-        self.w2_weight = nn.Parameter(
-            torch.empty(
-                self.num_total_experts,
-                self.hidden_size,
-                self.intermediate_size,
-                dtype=params_dtype,
-            )
-        )
-
-        set_weight_attrs(
-            self.w13_weight,
-            {
-                "weight_loader": self.weight_loader,
-            },
-        )
-        set_weight_attrs(
-            self.w2_weight,
-            {
-                "weight_loader": self.weight_loader,
-            },
+        self.experts = FusedMoE(
+            num_experts=num_experts,
+            top_k=top_k,
+            hidden_size=hidden_size,
+            intermediate_size=intermediate_size,
+            params_dtype=params_dtype,
+            reduce_results=True,
+            renormalize=True,
+            quant_config=quant_config,
+            tp_size=tp_size,
+            prefix=f"{prefix}.experts",
         )
 
-        # Used for fp8.
-        self.w13_scale = None
-        self.w2_scale = None
-        self.a13_scale = None
-        self.a2_scale = None
-
-        if self.use_fp8:
-            # WEIGHT_SCALE (for fp8)
-            self.w13_scale = nn.Parameter(
-                torch.ones(self.num_total_experts, dtype=torch.float32),
-                requires_grad=False,
-            )
-            self.w2_scale = nn.Parameter(
-                torch.ones(self.num_total_experts, dtype=torch.float32),
-                requires_grad=False,
-            )
-
-            # If loading fp8 checkpoint, pass the weight loaders.
-            # If loading an fp16 checkpoint, do not (we will quantize in
-            # process_weights_after_loading()
-            if quant_config.is_checkpoint_fp8_serialized:
-                set_weight_attrs(
-                    self.w13_scale,
-                    {
-                        "weight_loader": self.weight_loader,
-                    },
-                )
-                set_weight_attrs(
-                    self.w2_scale,
-                    {
-                        "weight_loader": self.weight_loader,
-                    },
-                )
-
-            # ACT_SCALE (for fp8)
-            if quant_config.activation_scheme == "static":
-                if not quant_config.is_checkpoint_fp8_serialized:
-                    raise ValueError(
-                        "Found static activation scheme for checkpoint that "
-                        "was not serialized fp8."
-                    )
-                self.a13_scale = nn.Parameter(
-                    torch.zeros(self.num_total_experts, dtype=torch.float32),
-                    requires_grad=False,
-                )
-                self.a2_scale = nn.Parameter(
-                    torch.zeros(self.num_total_experts, dtype=torch.float32),
-                    requires_grad=False,
-                )
-
-                set_weight_attrs(
-                    self.a13_scale,
-                    {
-                        "weight_loader": self.weight_loader,
-                    },
-                )
-                set_weight_attrs(
-                    self.a2_scale,
-                    {
-                        "weight_loader": self.weight_loader,
-                    },
-                )
-
-    def weight_loader(
-        self,
-        param: nn.Parameter,
-        loaded_weight: torch.Tensor,
-        weight_name: str,
-        expert_id: int,
-    ):
-        tp_rank = get_tensor_model_parallel_rank()
-        param_data = param.data
-        shard_size = self.intermediate_size
-        shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
-        if weight_name.endswith("w1.weight"):
-            param_data[expert_id, 0:shard_size, :] = loaded_weight[shard, :]
-        if weight_name.endswith("w3.weight"):
-            param_data[expert_id, shard_size : 2 * shard_size, :] = loaded_weight[
-                shard, :
-            ]
-        if weight_name.endswith("w2.weight"):
-            param_data[expert_id, :, :] = loaded_weight[:, shard]
-        if "act_scale" in weight_name or "weight_scale" in weight_name:
-            param_data[expert_id] = loaded_weight
-
-    def process_weights_after_loading(self):
-        # Fp8 is the only case where we need to process after loading.
-        if not self.use_fp8:
-            return
-
-        # If checkpoint is fp16, quantize here.
-        if not self.quant_config.is_checkpoint_fp8_serialized:
-            w13_weight = torch.empty_like(
-                self.w13_weight.data, dtype=torch.float8_e4m3fn
-            )
-            w2_weight = torch.empty_like(self.w2_weight.data, dtype=torch.float8_e4m3fn)
-            for expert in range(self.num_total_experts):
-                w13_weight[expert, :, :], self.w13_scale[expert] = ops.scaled_fp8_quant(
-                    self.w13_weight.data[expert, :, :]
-                )
-                w2_weight[expert, :, :], self.w2_scale[expert] = ops.scaled_fp8_quant(
-                    self.w2_weight.data[expert, :, :]
-                )
-            self.w13_weight = nn.Parameter(w13_weight, requires_grad=False)
-            self.w2_weight = nn.Parameter(w2_weight, requires_grad=False)
-
-        # If checkpoint is fp8 + static, cleanup act_scales.
-        # Since state_dict has an act_scale per expert but our kernels
-        # are passed one act_scale shared across all experts.
-        elif self.quant_config.activation_scheme == "static":
-            if self.a13_scale is None or self.a2_scale is None:
-                raise ValueError(
-                    "QuantConfig has static quantization, but found "
-                    "activation scales are None."
-                )
-
-            if not all_close_1d(self.a13_scale) or not all_close_1d(self.a2_scale):
-                print_warning_once(
-                    "Found act_scales that are not equal for fp8 MoE layer. "
-                    "Using the maximum across experts for each layer. "
-                )
-
-            self.a13_scale = nn.Parameter(self.a13_scale.max(), requires_grad=False)
-            self.a2_scale = nn.Parameter(self.a2_scale.max(), requires_grad=False)
-
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        num_tokens, hidden_size = hidden_states.shape
+        # NOTE: hidden_states can have either 1D or 2D shape.
+        orig_shape = hidden_states.shape
         hidden_states = hidden_states.view(-1, self.hidden_size)
         # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(hidden_states)
-        final_hidden_states = fused_moe(
-            hidden_states,
-            self.w13_weight,
-            self.w2_weight,
-            router_logits,
-            self.top_k,
-            renormalize=True,
-            inplace=True,
-            use_fp8=self.use_fp8,
-            w1_scale=self.w13_scale,
-            w2_scale=self.w2_scale,
-            a1_scale=self.a13_scale,
-            a2_scale=self.a2_scale,
-        )
-
-        if self.tp_size > 1:
-            final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
-
-        return final_hidden_states.view(num_tokens, hidden_size)
+        final_hidden_states = self.experts(hidden_states, router_logits)
+        return final_hidden_states.view(orig_shape)
 
 
 class MixtralAttention(nn.Module):
@@ -291,7 +110,7 @@ def __init__(
         max_position: int = 4096 * 32,
         rope_theta: float = 10000,
         quant_config: Optional[QuantizationConfig] = None,
-        sliding_window: Optional[int] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.hidden_size = hidden_size
@@ -314,7 +133,6 @@ def __init__(
         self.kv_size = self.num_kv_heads * self.head_dim
         self.scaling = self.head_dim**-0.5
         self.rope_theta = rope_theta
-        self.sliding_window = sliding_window
 
         self.qkv_proj = QKVParallelLinear(
             hidden_size,
@@ -323,12 +141,14 @@ def __init__(
             self.total_num_kv_heads,
             bias=False,
             quant_config=quant_config,
+            prefix=f"{prefix}.qkv_proj",
         )
         self.o_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             hidden_size,
             bias=False,
             quant_config=quant_config,
+            prefix=f"{prefix}.o_proj",
         )
         self.rotary_emb = get_rope(
             self.head_dim,
@@ -365,6 +185,7 @@ def __init__(
         config: MixtralConfig,
         layer_id: int = 0,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.hidden_size = config.hidden_size
@@ -377,8 +198,8 @@ def __init__(
             num_kv_heads=config.num_key_value_heads,
             layer_id=layer_id,
             rope_theta=rope_theta,
-            sliding_window=config.sliding_window,
             quant_config=quant_config,
+            prefix=f"{prefix}.self_attn",
         )
         self.block_sparse_moe = MixtralMoE(
             num_experts=config.num_local_experts,
@@ -386,6 +207,7 @@ def __init__(
             hidden_size=config.hidden_size,
             intermediate_size=config.intermediate_size,
             quant_config=quant_config,
+            prefix=f"{prefix}.block_sparse_moe",
         )
         self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attention_layernorm = RMSNorm(
@@ -422,6 +244,7 @@ def __init__(
         self,
         config: MixtralConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.padding_idx = config.pad_token_id
@@ -431,10 +254,11 @@ def __init__(
             config.vocab_size,
             config.hidden_size,
         )
-        # config.num_hidden_layers=16
         self.layers = nn.ModuleList(
             [
-                MixtralDecoderLayer(config, i, quant_config=quant_config)
+                MixtralDecoderLayer(
+                    config, i, quant_config=quant_config, prefix=f"{prefix}.layers"
+                )
                 for i in range(config.num_hidden_layers)
             ]
         )
@@ -462,6 +286,7 @@ def forward(
 
 
 class MixtralForCausalLM(nn.Module):
+
     def __init__(
         self,
         config: MixtralConfig,
@@ -471,11 +296,10 @@ def __init__(
         super().__init__()
         self.config = config
         self.quant_config = quant_config
-        self.model = MixtralModel(config, quant_config=quant_config)
+        self.model = MixtralModel(config, quant_config=quant_config, prefix="model")
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)
 
-    @torch.no_grad()
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -496,40 +320,13 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             ("qkv_proj", "v_proj", "v"),
         ]
 
-        expert_params_mapping = (
-            [
-                # These are the weight scales for the experts
-                # (param_name, weight_name, expert_id)
-                (
-                    "w13_scale" if weight_name in ["w1", "w3"] else "w2_scale",
-                    f"experts.{expert_id}.{weight_name}.weight_scale",
-                    expert_id,
-                )
-                for expert_id in range(self.config.num_local_experts)
-                for weight_name in ["w1", "w2", "w3"]
-            ]
-            + [
-                # These are the weights for the experts
-                # (param_name, weight_name, expert_id)
-                (
-                    "w13_weight" if weight_name in ["w1", "w3"] else "w2_weight",
-                    f"experts.{expert_id}.{weight_name}.weight",
-                    expert_id,
-                )
-                for expert_id in range(self.config.num_local_experts)
-                for weight_name in ["w1", "w2", "w3"]
-            ]
-            + [
-                # These are the activation scales for the experts
-                # (param_name, weight_name, expert_id)
-                (
-                    "a13_scale" if weight_name in ["w1", "w3"] else "a2_scale",
-                    f"experts.{expert_id}.{weight_name}.act_scale",
-                    expert_id,
-                )
-                for expert_id in range(self.config.num_local_experts)
-                for weight_name in ["w1", "w2", "w3"]
-            ]
+        # Params for weights, fp8 weight scales, fp8 activation scales
+        # (param_name, weight_name, expert_id, shard_id)
+        expert_params_mapping = FusedMoE.make_expert_params_mapping(
+            ckpt_gate_proj_name="w1",
+            ckpt_down_proj_name="w2",
+            ckpt_up_proj_name="w3",
+            num_experts=self.config.num_local_experts,
         )
 
         params_dict = dict(self.named_parameters())
@@ -544,25 +341,35 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
+
                 param = params_dict[name]
                 weight_loader = param.weight_loader
                 weight_loader(param, loaded_weight, shard_id)
                 break
             else:
-                for param_name, weight_name, expert_id in expert_params_mapping:
+                for mapping in expert_params_mapping:
+                    param_name, weight_name, expert_id, shard_id = mapping
                     if weight_name not in name:
                         continue
                     name = name.replace(weight_name, param_name)
+
                     param = params_dict[name]
                     weight_loader = param.weight_loader
                     weight_loader(
-                        param, loaded_weight, weight_name, expert_id=expert_id
+                        param,
+                        loaded_weight,
+                        weight_name,
+                        shard_id=shard_id,
+                        expert_id=expert_id,
                     )
                     break
                 else:
                     # Skip loading extra bias for GPTQ models.
                     if name.endswith(".bias") and name not in params_dict:
                         continue
+                    if name is None:
+                        continue
+
                     param = params_dict[name]
                     weight_loader = getattr(
                         param, "weight_loader", default_weight_loader
@@ -570,9 +377,4 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                     weight_loader(param, loaded_weight)
 
 
-def all_close_1d(x: torch.Tensor) -> bool:
-    assert len(x.shape) == 1
-    return all(torch.allclose(x[0], x[i]) for i in range(x.shape[0]))
-
-
 EntryClass = MixtralForCausalLM
diff --git a/python/sglang/srt/models/mixtral_quant.py b/python/sglang/srt/models/mixtral_quant.py
index 115fce1d6d7..b02e925c5a0 100644
--- a/python/sglang/srt/models/mixtral_quant.py
+++ b/python/sglang/srt/models/mixtral_quant.py
@@ -160,7 +160,6 @@ def __init__(
         max_position: int = 4096 * 32,
         rope_theta: float = 10000,
         quant_config: Optional[QuantizationConfig] = None,
-        sliding_window: Optional[int] = None,
     ) -> None:
         super().__init__()
         self.hidden_size = hidden_size
@@ -183,7 +182,6 @@ def __init__(
         self.kv_size = self.num_kv_heads * self.head_dim
         self.scaling = self.head_dim**-0.5
         self.rope_theta = rope_theta
-        self.sliding_window = sliding_window
 
         self.qkv_proj = QKVParallelLinear(
             hidden_size,
@@ -246,7 +244,6 @@ def __init__(
             num_kv_heads=config.num_key_value_heads,
             layer_id=layer_id,
             rope_theta=rope_theta,
-            sliding_window=config.sliding_window,
             quant_config=quant_config,
         )
         self.block_sparse_moe = MixtralMoE(config=config, quant_config=quant_config)
diff --git a/test/srt/test_moe_serving_throughput.py b/test/srt/test_moe_serving_throughput.py
index da223e80b9d..48798c5d5f0 100644
--- a/test/srt/test_moe_serving_throughput.py
+++ b/test/srt/test_moe_serving_throughput.py
@@ -84,7 +84,7 @@ def test_default_without_radix_cache(self):
 
         if os.getenv("SGLANG_IS_IN_CI", "false") == "true":
             # A100 (PCIE) performance
-            assert res["output_throughput"] > 950
+            assert res["output_throughput"] > 940
 
     def test_default_with_chunked_prefill(self):
         res = self.run_test(
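
For readers who want the net effect of the mixtral.py changes in one place, the sketch below condenses the new `MixtralMoE` structure and the checkpoint-name mapping. It is illustrative only: the class name `FusedMoEBlock` and the example `num_experts=8` are made up for the example, while the `FusedMoE` keyword arguments are the ones used in the diff above (they may change across vLLM versions).

```python
# Illustrative recap (not part of the patch): FusedMoE replaces the hand-written
# w13/w2 expert parameters, the fused_moe() call, and the manual TP all-reduce.
import torch
from torch import nn
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.linear import ReplicatedLinear


class FusedMoEBlock(nn.Module):
    def __init__(self, num_experts: int, top_k: int, hidden_size: int,
                 intermediate_size: int, prefix: str = ""):
        super().__init__()
        self.hidden_size = hidden_size
        # The router (gate) stays unquantized; the experts live inside FusedMoE.
        self.gate = ReplicatedLinear(hidden_size, num_experts, bias=False,
                                     quant_config=None, prefix=f"{prefix}.gate")
        self.experts = FusedMoE(
            num_experts=num_experts,
            top_k=top_k,
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
            reduce_results=True,  # all-reduce across TP ranks inside the layer
            renormalize=True,
            prefix=f"{prefix}.experts",
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        orig_shape = hidden_states.shape
        hidden_states = hidden_states.view(-1, self.hidden_size)
        router_logits, _ = self.gate(hidden_states)
        return self.experts(hidden_states, router_logits).view(orig_shape)


# Checkpoint names w1/w3 (gate/up projections) and w2 (down projection) map onto
# FusedMoE's fused parameters; each entry is (param_name, weight_name, expert_id,
# shard_id), which is exactly what load_weights() unpacks above.
expert_params_mapping = FusedMoE.make_expert_params_mapping(
    ckpt_gate_proj_name="w1",
    ckpt_down_proj_name="w2",
    ckpt_up_proj_name="w3",
    num_experts=8,  # Mixtral-8x7B has 8 local experts; use config.num_local_experts
)
```
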