add moe_top_k_activation
Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
kylesayrs committed Jan 14, 2025
commit 2d4791c (1 parent: 74ac982)
Showing 3 changed files with 17 additions and 5 deletions.
examples/quantizing_moe/deepseek_moe_w4a16.py: 9 additions & 2 deletions
@@ -1,3 +1,4 @@
+from llmcompressor.transformers.tracing.deepseek_v2.configuration_deepseek import DeepseekV2Config
 import torch
 from datasets import load_dataset
 from transformers import AutoModelForCausalLM, AutoTokenizer
@@ -18,14 +19,20 @@
 device_map = calculate_offload_device_map(
     MODEL_ID,
     reserve_for_hessians=True,
-    num_gpus=2,
+    num_gpus=1,
     torch_dtype=torch.bfloat16,
     trust_remote_code=True,
 )
 
 #model = AutoModelForCausalLM.from_pretrained(
+config = DeepseekV2Config.from_pretrained(MODEL_ID)
+config.moe_top_k_activation = True
 model = TraceableDeepseekV2ForCausalLM.from_pretrained(
-    MODEL_ID, device_map=device_map, torch_dtype=torch.bfloat16, trust_remote_code=True
+    MODEL_ID,
+    device_map=device_map,
+    torch_dtype=torch.bfloat16,
+    trust_remote_code=True,
+    config=config
 )
 tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

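As a usage note beyond the diff: leaving the new flag off keeps the previous tracing behavior, in which the calibration forward pass only takes the dense all-expert route. A minimal sketch, reusing MODEL_ID, device_map, torch, and TraceableDeepseekV2ForCausalLM from the example above:

config = DeepseekV2Config.from_pretrained(MODEL_ID)
config.moe_top_k_activation = False  # keep the dense (all-expert) calibration output

model = TraceableDeepseekV2ForCausalLM.from_pretrained(
    MODEL_ID,
    device_map=device_map,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
    config=config,
)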
llmcompressor/transformers/tracing/deepseek_v2/configuration_deepseek.py: 3 additions
@@ -154,6 +154,8 @@ def __init__(
         rope_scaling=None,
         attention_bias=False,
         attention_dropout=0.0,
+        # TRACING: add calibration options
+        moe_top_k_activation=True,
         **kwargs,
     ):
         self.vocab_size = vocab_size
@@ -196,6 +198,7 @@ def __init__(
         self.rope_scaling = rope_scaling
         self.attention_bias = attention_bias
         self.attention_dropout = attention_dropout
+        self.moe_top_k_activation = moe_top_k_activation
 
         super().__init__(
             pad_token_id=pad_token_id,
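Because the option is an ordinary config field that defaults to True, it can also be set directly at construction time. A small sketch, assuming only the patched DeepseekV2Config above:

from llmcompressor.transformers.tracing.deepseek_v2.configuration_deepseek import DeepseekV2Config

# Defaults to True, so configs built from this class opt into top-k calibration
# unless the flag is explicitly disabled.
cfg = DeepseekV2Config()
assert cfg.moe_top_k_activation is True

cfg_dense = DeepseekV2Config(moe_top_k_activation=False)
assert cfg_dense.moe_top_k_activation is False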
Traceable DeepseekV2 modeling file (llmcompressor/transformers/tracing/deepseek_v2): 5 additions & 3 deletions
@@ -466,7 +466,7 @@ def forward(self, hidden_states):
         else:
             topk_weight = topk_weight * self.routed_scaling_factor
         ### expert-level computation auxiliary loss
-        # TRACING:
+        # TRACING: This only affects the backwards pass, but needed to avoid typing issues on `aux_loss`
         #if self.training and self.alpha > 0.0:
         if True:
             scores_for_aux = scores
@@ -573,7 +573,7 @@ def forward(self, hidden_states):
         topk_idx, topk_weight, aux_loss = self.gate(hidden_states)
         hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
         flat_topk_idx = topk_idx.view(-1)
-        # TRACING:
+        # TRACING: pass activations to all experts
         #if self.training:
         if True:
             hidden_states = hidden_states.repeat_interleave(
@@ -585,7 +585,9 @@ def forward(self, hidden_states):
             y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
             y = y.to(hidden_states.dtype).view(*orig_shape)
             y = AddAuxiliaryLoss.apply(y, aux_loss)
-        else:
+        # TRACING: Give option to calibrate with top_k experts, as if in inference time
+        #else:
+        if self.config.moe_top_k_activation:
             y = self.moe_infer(hidden_states, topk_idx, topk_weight).view(*orig_shape)
         if self.config.n_shared_experts is not None:
             y = y + self.shared_experts(identity)

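To make the control flow above concrete, here is a self-contained toy sketch (not the repository's code) of the two dispatch patterns the flag selects between: the dense expert loop that the traced forward always runs, and a per-token top-k recombination analogous to what moe_infer returns when moe_top_k_activation is enabled. In this toy both dispatches produce the same weighted output; only the computation pattern differs.

import torch
import torch.nn as nn

num_experts, top_k, hidden = 4, 2, 8
experts = nn.ModuleList(nn.Linear(hidden, hidden) for _ in range(num_experts))

tokens = torch.randn(5, hidden)                       # (num_tokens, hidden)
scores = torch.randn(5, num_experts).softmax(dim=-1)  # stand-in gate scores
topk_weight, topk_idx = torch.topk(scores, k=top_k, dim=-1)

# Dense dispatch (mirrors the `if True:` branch): repeat each token top_k times,
# let every expert process the rows routed to it, then take the weighted sum.
flat_idx = topk_idx.reshape(-1)
repeated = tokens.repeat_interleave(top_k, dim=0)
y_dense = torch.empty_like(repeated)
for i, expert in enumerate(experts):
    y_dense[flat_idx == i] = expert(repeated[flat_idx == i])
y_dense = (y_dense.view(5, top_k, hidden) * topk_weight.unsqueeze(-1)).sum(dim=1)

# Top-k recombination (inference-like): only the selected experts contribute to
# each token's output, which is what the flag swaps in as the layer output.
y_topk = torch.stack([
    sum(topk_weight[t, j] * experts[int(topk_idx[t, j])](tokens[t]) for j in range(top_k))
    for t in range(tokens.shape[0])
])

print(torch.allclose(y_dense, y_topk, atol=1e-5))  # same result, different dispatch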