DeepSeek R1: add static-moe option and remove unnecessary unpad #822

Merged: 1 commit, Feb 14, 2025
scripts/run_example_tp.py: 9 changes (7 additions, 2 deletions)
@@ -164,9 +164,13 @@ def sample_gsm8k_requests(
"The capital of France is",
"The future of AI is",
]
if args.nprompts > 4:
prompts += random.choices(prompts, k=args.nprompts - 4)
elif args.nprompts < 4:
prompts = prompts[: args.nprompts]
gt = None
# Create a sampling params object.
sampling_params = SamplingParams(temperature=0, max_tokens=args.osl)
sampling_params = SamplingParams(temperature=0, max_tokens=args.osl, ignore_eos=True)
model = args.model
if args.tp_size == 1:
llm = LLM(
@@ -205,4 +209,5 @@ def sample_gsm8k_requests(
print(f"Prompt: {prompt!r}")
print(f"Generated text: {generated_text!r}")
print(f"Ground truth: {gt_i!r}")
print("====================================")
print("====================================")
del llm
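
The new nprompts handling above either pads the built-in prompt list by resampling it or truncates it, so the benchmark always runs exactly the requested number of prompts. A minimal standalone sketch of that behavior (the helper name is a placeholder, and only the two prompt strings visible in the diff are taken from the script):

import random

def adjust_prompts(prompts, nprompts):
    # Pad by sampling existing prompts with replacement, or truncate,
    # so that exactly `nprompts` prompts are returned.
    if nprompts > len(prompts):
        prompts = prompts + random.choices(prompts, k=nprompts - len(prompts))
    elif nprompts < len(prompts):
        prompts = prompts[:nprompts]
    return prompts

base = ["The capital of France is", "The future of AI is"]
print(len(adjust_prompts(base, 6)))   # 6
print(adjust_prompts(base, 1))        # ['The capital of France is']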
scripts/run_static-online-i1k-o1k-ep8-bestperf.sh: 1 change (1 addition, 0 deletions)
@@ -25,6 +25,7 @@ model="/data/models/DeepSeek-R1/"
tokenizer="/data/models/DeepSeek-R1/"
model_name="DeepSeek-R1"

#VLLM_USE_STATIC_MOE=1 \
HABANA_VISIBLE_DEVICES="ALL" \
VLLM_MOE_N_SLICE=${moe_n_slice} \
VLLM_EP_SIZE=${ep_size} \
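
The VLLM_USE_STATIC_MOE line is added commented out, so this benchmark keeps the default dynamic MoE path unless the flag is switched on. Based on the parsing added to fp8.py below, only the exact strings "1" or "true" enable it. A rough sketch of enabling it from Python instead of the shell (the tensor-parallel size and the generate call are illustrative; only the model path and sampling arguments come from the scripts in this PR):

import os

# Set before vLLM builds its FP8 MoE layers: fp8.py reads the flag once
# in __init__ via os.environ.get; note the check is case-sensitive.
os.environ["VLLM_USE_STATIC_MOE"] = "1"      # "1" or "true" -> static MoE path

from vllm import LLM, SamplingParams

llm = LLM(model="/data/models/DeepSeek-R1/", tensor_parallel_size=8)  # assumed TP size
outputs = llm.generate(["The capital of France is"],
                       SamplingParams(temperature=0, max_tokens=32, ignore_eos=True))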
vllm/model_executor/layers/quantization/fp8.py: 88 changes (60 additions, 28 deletions)
@@ -420,6 +420,7 @@ def __init__(self, quant_config: Fp8Config):
self.quant_config = quant_config
self.block_quant = self.quant_config.weight_block_size is not None
self.moe_n_slice = int(os.environ.get("VLLM_MOE_N_SLICE", 4))
self.use_static_moe = os.environ.get("VLLM_USE_STATIC_MOE", "0") in ["1", "true"]

def create_weights(self, layer: Module, num_experts: int, hidden_size: int,
intermediate_size_per_partition: int,
@@ -766,9 +767,9 @@ def forward_hpu(
num_experts = layer.w13_weight.shape[0]
n_expert_slice = layer.w13_weight.shape[0] // self.moe_n_slice
assert n_expert_slice * self.moe_n_slice == num_experts

x = x.view(-1, hidden_dim)
if seq_len == 1 and (num_experts == router_logits.size(-1)) and (batch_size * top_k <= 64):
total_num_experts = router_logits.size(-1)
if seq_len == 1 and (num_experts == total_num_experts) and (batch_size * top_k <= 64):
# num_tokens * top_k (experts selected per token) is
# less than the total number of experts,
# so we can safely load only the selected experts' weights
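
This fast path only fires when batch_size * top_k is at most 64 and no expert-parallel split hides experts from the router logits, so the selected expert weights can be gathered directly instead of dequantizing every expert. A toy illustration of the index_select gather it relies on (all shapes below are invented; the real tensors are the layer's stacked FP8 expert weights and their block scales):

import torch

# Toy stand-ins: 8 experts, hidden size 4, intermediate size 6.
num_experts, hidden, inter = 8, 4, 6
w13 = torch.randn(num_experts, 2 * inter, hidden)   # fused gate+up weights, one slice per expert
topk_ids = torch.tensor([[1, 5], [2, 5]])            # 2 tokens, top_k = 2

# index_select over the expert dimension keeps one weight slice per
# (token, expert) pair, so at most num_tokens * top_k slices are touched.
w13_selected = w13.index_select(0, topk_ids.view(-1))
print(w13_selected.shape)                            # torch.Size([4, 12, 4])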
@@ -793,31 +794,50 @@
orig_M_w2 = layer.orig_M_w2.data
orig_N_w2 = layer.orig_N_w2.data
ep_shift = ep_rank * num_experts

if use_partial_experts:
w13_weight_fp8 = layer.w13_weight.index_select(0, topk_ids.view(-1))
w13_weight_scale_inv_fp8 = layer.w13_weight_scale_inv.index_select(0, topk_ids.view(-1))
w2_weight_fp8 = layer.w2_weight.index_select(0, topk_ids.view(-1))
w2_weight_scale_inv_fp8 = layer.w2_weight_scale_inv.index_select(0, topk_ids.view(-1))

if seq_len > 1:
use_static_moe = False
else:
w13_weight_fp8 = layer.w13_weight
w13_weight_scale_inv_fp8 = layer.w13_weight_scale_inv
w2_weight_fp8 = layer.w2_weight
w2_weight_scale_inv_fp8 = layer.w2_weight_scale_inv
use_static_moe = self.use_static_moe

w13_weight = dequant_block_fp8_weight_naive(w13_weight_fp8,
w13_weight_scale_inv_fp8,
block_size=self.quant_config.weight_block_size,
dtype=x.dtype,
original_M=orig_M_w13,
original_N=orig_N_w13)
w2_weight = dequant_block_fp8_weight_naive(w2_weight_fp8,
def do_static_moe(x, topk_ids, topk_weights, w13_weight_fp8, w2_weight_fp8, total_num_experts, num_experts, w13_weight_scale_inv_fp8=None, w2_weight_scale_inv_fp8=None):
final_hidden_states = torch.zeros_like(x)
# experts_mask has shape (total_num_experts, num_tokens) after the transpose below
experts_mask = torch.zeros((x.size(0), total_num_experts), dtype=x.dtype, device=x.device)
experts_mask.scatter_(-1, topk_ids, topk_weights)
experts_mask = experts_mask.transpose(0, 1)

for i in range(num_experts):
w13_weight_fp8_slice = w13_weight_fp8[i, ...]
w2_weight_fp8_slice = w2_weight_fp8[i, ...]
w13_scale_fp8_slice = w13_weight_scale_inv_fp8[i, ...]
w2_scale_fp8_slice = w2_weight_scale_inv_fp8[i, ...]

w13_weight = dequant_block_fp8_weight_naive(w13_weight_fp8_slice, w13_scale_fp8_slice, self.quant_config.weight_block_size, x.dtype, orig_M_w13, orig_N_w13)
up_gate_states = torch.matmul(x, w13_weight.transpose(0, 1))
d = up_gate_states.shape[-1] // 2
tmp_states = F.silu(up_gate_states[..., :d]) * up_gate_states[..., d:]

w2_weight = dequant_block_fp8_weight_naive(w2_weight_fp8_slice, w2_scale_fp8_slice, self.quant_config.weight_block_size, x.dtype, orig_M_w2, orig_N_w2)
current_hidden_states = torch.matmul(tmp_states, w2_weight.transpose(0, 1))
padded_weight = experts_mask[i + ep_shift].unsqueeze(1)
final_hidden_states += current_hidden_states * padded_weight

return final_hidden_states

def do_moe(x, topk_ids, topk_weights, w13_weight_fp8, w2_weight_fp8, moe_n_slice, n_expert_slice, w13_weight_scale_inv_fp8=None, w2_weight_scale_inv_fp8=None):
w13_weight = dequant_block_fp8_weight_naive(w13_weight_fp8,
w13_weight_scale_inv_fp8,
block_size=self.quant_config.weight_block_size,
dtype=x.dtype,
original_M=orig_M_w13,
original_N=orig_N_w13)
w2_weight = dequant_block_fp8_weight_naive(w2_weight_fp8,
w2_weight_scale_inv_fp8,
block_size=self.quant_config.weight_block_size,
dtype=x.dtype,
original_M=orig_M_w2,
original_N=orig_N_w2)
def do_moe(x, topk_ids, topk_weights, w13_weight, w2_weight, moe_n_slice, n_expert_slice):
final_hidden_states = torch.zeros_like(x)
for i in range(moe_n_slice):
min_expert = i * n_expert_slice
@@ -840,16 +860,28 @@ def do_moe(x, topk_ids, topk_weights, w13_weight, w2_weight, moe_n_slice, n_expe
return final_hidden_states

if use_partial_experts:
if w13_weight.size(0) >= 64:
moe_n_slice = 4
w13_weight_fp8 = layer.w13_weight.index_select(0, topk_ids.view(-1))
w13_weight_scale_inv_fp8 = layer.w13_weight_scale_inv.index_select(0, topk_ids.view(-1))
w2_weight_fp8 = layer.w2_weight.index_select(0, topk_ids.view(-1))
w2_weight_scale_inv_fp8 = layer.w2_weight_scale_inv.index_select(0, topk_ids.view(-1))
new_total_experts = w13_weight_fp8.size(0)
topk_ids_dense = torch.arange(new_total_experts, device=topk_ids.device).view(topk_ids.size(0), topk_ids.size(1))
if use_static_moe:
final_hidden_states = do_static_moe(x, topk_ids_dense, topk_weights, w13_weight_fp8, w2_weight_fp8, new_total_experts, new_total_experts, w13_weight_scale_inv_fp8, w2_weight_scale_inv_fp8)
else:
moe_n_slice = 1
n_expert_slice = w13_weight.size(0) // moe_n_slice
topk_ids_dummy = torch.arange(topk_ids.size(1) * topk_ids.size(0), device=topk_ids.device).view(topk_ids.size(0), topk_ids.size(1))
final_hidden_states = do_moe(x, topk_ids_dummy, topk_weights, w13_weight, w2_weight, moe_n_slice, n_expert_slice)

moe_n_slice = 4 if new_total_experts >= 64 else 1
n_expert_slice = new_total_experts // moe_n_slice
final_hidden_states = do_moe(x, topk_ids_dense, topk_weights, w13_weight_fp8, w2_weight_fp8, moe_n_slice, n_expert_slice, w13_weight_scale_inv_fp8, w2_weight_scale_inv_fp8)
else:
final_hidden_states = do_moe(x, topk_ids, topk_weights, w13_weight, w2_weight, self.moe_n_slice, n_expert_slice)
w13_weight_fp8 = layer.w13_weight
w13_weight_scale_inv_fp8 = layer.w13_weight_scale_inv
w2_weight_fp8 = layer.w2_weight
w2_weight_scale_inv_fp8 = layer.w2_weight_scale_inv

if use_static_moe:
final_hidden_states = do_static_moe(x, topk_ids, topk_weights, w13_weight_fp8, w2_weight_fp8, total_num_experts, num_experts, w13_weight_scale_inv_fp8, w2_weight_scale_inv_fp8)
else:
final_hidden_states = do_moe(x, topk_ids, topk_weights, w13_weight_fp8, w2_weight_fp8, self.moe_n_slice, n_expert_slice, w13_weight_scale_inv_fp8, w2_weight_scale_inv_fp8)


return final_hidden_states.view(-1, x.shape[1])
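
For readers new to the static path: instead of gathering tokens per expert, do_static_moe builds a dense (total_num_experts, num_tokens) routing-weight mask with scatter_ and runs every expert over all tokens, accumulating each expert's SwiGLU output scaled by its row of the mask. A self-contained toy version of the same computation in plain float (no FP8, no block scales, expert-parallel ep_shift omitted; every size here is made up):

import torch
import torch.nn.functional as F

torch.manual_seed(0)
num_tokens, hidden, inter, num_experts, top_k = 3, 4, 8, 6, 2
x = torch.randn(num_tokens, hidden)
w13 = torch.randn(num_experts, 2 * inter, hidden)    # fused gate+up projection per expert
w2 = torch.randn(num_experts, hidden, inter)         # down projection per expert
topk_ids = torch.randint(0, num_experts, (num_tokens, top_k))
topk_weights = torch.rand(num_tokens, top_k)

# Dense routing mask: (num_experts, num_tokens), zero where an expert was not picked.
experts_mask = torch.zeros(num_tokens, num_experts).scatter_(-1, topk_ids, topk_weights).t()

final_hidden_states = torch.zeros_like(x)
for e in range(num_experts):
    up_gate = x @ w13[e].t()                              # (num_tokens, 2 * inter)
    d = up_gate.shape[-1] // 2
    gated = F.silu(up_gate[..., :d]) * up_gate[..., d:]   # SwiGLU activation
    final_hidden_states += (gated @ w2[e].t()) * experts_mask[e].unsqueeze(1)
print(final_hidden_states.shape)                          # torch.Size([3, 4])

Every expert's matmul runs with the same shapes regardless of routing, which appears to be the motivation for the option on HPU: the per-step work is static, at the cost of also computing experts that receive zero routing weight.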
vllm/model_executor/layers/quantization/utils/fp8_utils.py: 11 changes (6 additions, 5 deletions)
@@ -70,8 +70,8 @@ def pad_weight(weight, block_size):

def unpad_weight(weight, original_M, original_N, keep_first_dim=False):
"""Removes padding from the matrix to restore its original shape."""
# if weight.shape[-2] == original_M and weight.shape[-1] == original_N:
# return weight
if (weight.shape[-2] == original_M) and (weight.shape[-1] == original_N):
return weight
if keep_first_dim:
return weight[:, :original_M, :original_N]
else:
@@ -94,7 +94,8 @@ def pad_block_fp8_weight_naive(weight, weight_scale, block_size):


def dequant_block_fp8_weight_naive(weight, weight_scale, block_size, dtype, original_M, original_N, do_unpad=False):

if weight_scale is None:
return weight
assert len(block_size) == 2

weight_shape_len = len(weight.shape)
@@ -139,8 +140,8 @@ def apply_block_fp8_linear_hpu(
assert input_scale is None
# View input as 2D matrix for fp8 methods
input_2d = input.view(-1, input.shape[-1])
original_M = original_M.data
original_N = original_N.data
original_M = original_M.data.item()
original_N = original_N.data.item()
output_shape = [*input.shape[:-1], original_M]
dequant_weight = dequant_block_fp8_weight_naive(weight, weight_scale, block_size, input_2d.dtype, original_M, original_N, do_unpad)
output = torch.nn.functional.linear(input_2d, dequant_weight, bias=None)
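
The helpers in this file pad block-quantized weights up to multiples of the block size and strip that padding again after dequantization; the changes above make unpad_weight a no-op when the shape already matches and let dequant_block_fp8_weight_naive pass the weight through when no scale is supplied. A rough standalone sketch of what a naive block dequant plus unpad does (the tile layout, block size, and shapes are assumptions for illustration, not the library's exact layout):

import torch

def naive_block_dequant(weight, scale, block, original_M, original_N):
    # weight: (M_pad, N_pad), padded to multiples of `block`;
    # scale: one value per (block x block) tile.
    if scale is None:                        # mirrors the new early return for unscaled weights
        return weight
    M_pad, N_pad = weight.shape
    w = weight.float().view(M_pad // block, block, N_pad // block, block)
    w = (w * scale[:, None, :, None]).view(M_pad, N_pad)
    if w.shape == (original_M, original_N):  # mirrors the new unpad early return
        return w
    return w[:original_M, :original_N]       # otherwise strip the padding

w = torch.randn(256, 384)                    # stand-in for already-padded quantized data
s = torch.rand(2, 3)                         # one scale per 128x128 tile
print(naive_block_dequant(w, s, 128, 250, 380).shape)   # torch.Size([250, 380])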