diff --git a/examples/quantizing_moe/deepseek_moe_w4a16.py b/examples/quantizing_moe/deepseek_moe_w4a16.py
index f9f0d37d3..31e08fb81 100644
--- a/examples/quantizing_moe/deepseek_moe_w4a16.py
+++ b/examples/quantizing_moe/deepseek_moe_w4a16.py
@@ -1,3 +1,4 @@
+from llmcompressor.transformers.tracing.deepseek_v2.configuration_deepseek import DeepseekV2Config
 import torch
 from datasets import load_dataset
 from transformers import AutoModelForCausalLM, AutoTokenizer
@@ -18,14 +19,20 @@
 device_map = calculate_offload_device_map(
     MODEL_ID,
     reserve_for_hessians=True,
-    num_gpus=2,
+    num_gpus=1,
     torch_dtype=torch.bfloat16,
     trust_remote_code=True,
 )

 #model = AutoModelForCausalLM.from_pretrained(
+config = DeepseekV2Config.from_pretrained(MODEL_ID)
+config.moe_top_k_activation = True
 model = TraceableDeepseekV2ForCausalLM.from_pretrained(
-    MODEL_ID, device_map=device_map, torch_dtype=torch.bfloat16, trust_remote_code=True
+    MODEL_ID,
+    device_map=device_map,
+    torch_dtype=torch.bfloat16,
+    trust_remote_code=True,
+    config=config,
 )

 tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
diff --git a/src/llmcompressor/transformers/tracing/deepseek_v2/configuration_deepseek.py b/src/llmcompressor/transformers/tracing/deepseek_v2/configuration_deepseek.py
index 82e0f5d9d..833d86ba5 100644
--- a/src/llmcompressor/transformers/tracing/deepseek_v2/configuration_deepseek.py
+++ b/src/llmcompressor/transformers/tracing/deepseek_v2/configuration_deepseek.py
@@ -154,6 +154,8 @@ def __init__(
         rope_scaling=None,
         attention_bias=False,
         attention_dropout=0.0,
+        # TRACING: add calibration options
+        moe_top_k_activation=True,
         **kwargs,
     ):
         self.vocab_size = vocab_size
@@ -196,6 +198,7 @@
         self.rope_scaling = rope_scaling
         self.attention_bias = attention_bias
         self.attention_dropout = attention_dropout
+        self.moe_top_k_activation = moe_top_k_activation

         super().__init__(
             pad_token_id=pad_token_id,
diff --git a/src/llmcompressor/transformers/tracing/deepseek_v2/modeling_deepseek.py b/src/llmcompressor/transformers/tracing/deepseek_v2/modeling_deepseek.py
index 4d31770d2..d79f6c2de 100644
--- a/src/llmcompressor/transformers/tracing/deepseek_v2/modeling_deepseek.py
+++ b/src/llmcompressor/transformers/tracing/deepseek_v2/modeling_deepseek.py
@@ -466,7 +466,7 @@ def forward(self, hidden_states):
         else:
             topk_weight = topk_weight * self.routed_scaling_factor
         ### expert-level computation auxiliary loss
-        # TRACING:
+        # TRACING: This only affects the backward pass, but is needed to avoid typing issues on `aux_loss`
         #if self.training and self.alpha > 0.0:
         if True:
             scores_for_aux = scores
@@ -573,7 +573,7 @@ def forward(self, hidden_states):
         topk_idx, topk_weight, aux_loss = self.gate(hidden_states)
         hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
         flat_topk_idx = topk_idx.view(-1)
-        # TRACING:
+        # TRACING: pass activations to all experts
         #if self.training:
         if True:
             hidden_states = hidden_states.repeat_interleave(
@@ -585,7 +585,11 @@
             y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
             y = y.to(hidden_states.dtype).view(*orig_shape)
             y = AddAuxiliaryLoss.apply(y, aux_loss)
-        else:
+        # TRACING: Give the option to calibrate with top_k experts, as if at inference time
+        #else:
+        if self.config.moe_top_k_activation:
+            # `hidden_states` was repeat_interleaved above; recover the original per-token activations
+            hidden_states = identity.view(-1, identity.shape[-1])
             y = self.moe_infer(hidden_states, topk_idx, topk_weight).view(*orig_shape)
         if self.config.n_shared_experts is not None:
             y = y + self.shared_experts(identity)
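
For context, here is a toy sketch of the calibration strategy the patched DeepseekV2MoE.forward uses. It is hypothetical and standalone (the variable names below are illustrative, not the library's): each token is repeated once per routing slot and dispatched inside a loop over all expert modules, so every expert executes and can be traced and calibrated.

import torch

num_experts, top_k, hidden, tokens = 4, 2, 8, 3
experts = torch.nn.ModuleList(
    torch.nn.Linear(hidden, hidden) for _ in range(num_experts)
)

x = torch.randn(tokens, hidden)
scores = torch.randn(tokens, num_experts).softmax(dim=-1)
topk_weight, topk_idx = scores.topk(top_k, dim=-1)

# Calibration path (mirrors the `if True:` branch): repeat each token once
# per routing slot, then let every expert process the rows routed to it.
x_rep = x.repeat_interleave(top_k, dim=0)
flat_idx = topk_idx.view(-1)
y = torch.empty_like(x_rep)
for i, expert in enumerate(experts):
    y[flat_idx == i] = expert(x_rep[flat_idx == i])

# Recombine the per-slot expert outputs with the routing weights.
y = (y.view(tokens, top_k, hidden) * topk_weight.unsqueeze(-1)).sum(dim=1)

With moe_top_k_activation=True, the patched forward then discards this combined output and recomputes it with moe_infer on the original (un-repeated) activations, so calibration statistics are still gathered for every expert while the output downstream layers observe matches the top_k routing used at inference.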