Move output type conversion to gptq method as well
ElizaWszola committed Sep 7, 2024
1 parent 8886423 commit ab27497
Showing 2 changed files with 3 additions and 3 deletions.
3 changes: 2 additions & 1 deletion vllm/model_executor/layers/quantization/gptq_marlin.py
@@ -585,6 +585,7 @@ def apply(
             fused_marlin_moe)
 
         # The input must currently be float16
+        orig_dtype = x.dtype
         x = x.half()
 
         topk_weights, topk_ids = FusedMoE.select_experts(
@@ -610,4 +611,4 @@ def apply(
             topk_ids,
             w1_scale=layer.w13_scales,
             w2_scale=layer.w2_scales,
-        )
+        ).to(orig_dtype)
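Taken together, the two hunks above record the caller's dtype before the required float16 cast and restore it on the kernel output, inside the GPTQ Marlin method itself. A minimal sketch of that round-trip pattern, using a hypothetical fused_kernel stand-in rather than the real fused_marlin_moe call:

import torch

def apply_moe(x: torch.Tensor, fused_kernel) -> torch.Tensor:
    # Remember the dtype the caller passed in (e.g. bfloat16).
    orig_dtype = x.dtype
    # The fused Marlin MoE kernel currently requires float16 inputs.
    x = x.half()
    out = fused_kernel(x)          # stand-in for the real fused MoE call
    # Convert back so the caller sees its original dtype again.
    return out.to(orig_dtype)

# Usage: a bfloat16 activation goes in and a bfloat16 result comes back,
# even though the "kernel" itself ran in float16.
x = torch.randn(4, 16, dtype=torch.bfloat16)
y = apply_moe(x, lambda t: t * 2)  # dummy kernel for the sketch
assert y.dtype == torch.bfloat16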
3 changes: 1 addition & 2 deletions vllm/model_executor/models/mixtral.py
@@ -95,12 +95,11 @@ def __init__(self,
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         # NOTE: hidden_states can have either 1D or 2D shape.
         orig_shape = hidden_states.shape
-        orig_dtype = hidden_states.dtype
         hidden_states = hidden_states.view(-1, self.hidden_size)
         # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(hidden_states)
         final_hidden_states = self.experts(hidden_states, router_logits)
-        return final_hidden_states.view(orig_shape).to(orig_dtype)
+        return final_hidden_states.view(orig_shape)
 
 
 class MixtralAttention(nn.Module):
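With the conversion handled inside the quantized apply() path, the Mixtral MoE layer no longer needs to track orig_dtype itself: any module that routes through the GPTQ Marlin MoE method now gets its output back in the dtype it passed in, without duplicating the cast at every call site.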
