
Cleanup, compressed tensors compatibility
ElizaWszola committed Aug 29, 2024
1 parent f875842 commit d8feb8d
Showing 6 changed files with 101 additions and 233 deletions.
1 change: 0 additions & 1 deletion vllm/model_executor/layers/fused_moe/__init__.py
@@ -2,7 +2,6 @@
fused_moe_marlin, single_moe_marlin)
from vllm.model_executor.layers.fused_moe.layer import (
FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported)

from vllm.triton_utils import HAS_TRITON

__all__ = [
138 changes: 37 additions & 101 deletions vllm/model_executor/layers/fused_moe/layer.py
@@ -225,90 +225,6 @@ def __init__(
weight_loader=self.weight_loader,
)

# def weight_loader(
# self,
# param: torch.nn.Parameter,
# loaded_weight: torch.Tensor,
# weight_name: str,
# shard_id: str,
# expert_id: int,
# is_quantized: bool = False,
# ):
# param_data = param.data

# if is_quantized:
# if ("_qweight" in weight_name or "_scales" in weight_name
# or "_qzeros" in weight_name):
# if "w13" in weight_name:
# shard_size = loaded_weight.size()[-1]
# if shard_id == "w1":
# param_data[expert_id, :, :shard_size] = loaded_weight
# elif shard_id == "w3" or shard_id == "w2":
# param_data[expert_id, :, shard_size:] = loaded_weight
# else:
# raise ValueError(f"Invalid shard_id: {shard_id}: "
# "must be 0, 1, or 2.")
# elif "w2" in weight_name:
# param_data[expert_id][:] = loaded_weight
# else:
# raise ValueError(f"Invalid weight name: {weight_name}: "
# "must contain 'w13' or 'w2'.")
# elif "_g_idx" in weight_name:
# if "w13" not in weight_name and "w2" not in weight_name:
# raise ValueError(f"Invalid weight name: {weight_name}: "
# "must contain 'w13' or 'w2'.")
# param_data[expert_id] = loaded_weight
# else:
# raise ValueError(f"Invalid weight name: {weight_name}.")
# else:
# if shard_id not in ("w1", "w2", "w3"):
# raise ValueError(f"shard_id must be ['w1','w2','w3'] but "
# f"got {shard_id}.")

# # Special case for fp8 scales.
# if getattr(param, "is_fp8_scale", False):
# self._load_fp8_scale(param.data, loaded_weight, weight_name,
# shard_id, expert_id)
# return

# expert_data = param.data[expert_id]
# tp_rank = get_tensor_model_parallel_rank()

# # If transposed, weight is saved as [input_dim, output_dim]
# # Otherwise, weight is saved as [output_dim, input_dim]
# # Default is not transposed/input dim is dim 1
# input_dim = getattr(param, "input_dim", 1)
# output_dim = getattr(param, "output_dim", 0)

# # Index the loaded weight for tp sharding.
# # down_proj: "RowParallel" so tp sharding on input_dim
# if shard_id == "w2":
# shard_dim = input_dim
# shard_size = expert_data.shape[shard_dim]
# # gate_up_proj: "MergedColumnParallel", so tp sharding on output_dim
# elif shard_id in ("w1", "w3"):
# shard_dim = output_dim
# shard_size = expert_data.shape[output_dim] // 2
# offset = shard_size * tp_rank
# loaded_weight = loaded_weight.narrow(shard_dim, offset, shard_size)

# # Narrow parameter and load.
# # w1, gate_proj: Load into first logical weight of w13.
# if shard_id == "w1":
# expert_data = expert_data.narrow(shard_dim, 0, shard_size)
# expert_data.copy_(loaded_weight)
# # w3, up_proj: Load into second logical weight of w13.
# elif shard_id == "w3":
# expert_data = expert_data.narrow(shard_dim, shard_size,
# shard_size)
# expert_data.copy_(loaded_weight)
# # w2, down_proj: Load into only logical weight of w2.
# elif shard_id == "w2":
# expert_data.copy_(loaded_weight)
# else:
# raise ValueError(
# f"Expected shard_id w1, w2 or w3 but got {shard_id}")

def _load_per_tensor_weight_scale(self, shard_id: str,
param: torch.nn.Parameter,
loaded_weight: torch.Tensor,
@@ -395,9 +311,41 @@ def _load_single_value(self, param: torch.nn.Parameter,
# Input scales can be loaded directly and should be equal.
param_data[expert_id] = loaded_weight

def weight_loader(self, param: torch.nn.Parameter,
loaded_weight: torch.Tensor, weight_name: str,
shard_id: str, expert_id: int) -> None:
def weight_loader(
self,
param: torch.nn.Parameter,
loaded_weight: torch.Tensor,
weight_name: str,
shard_id: str,
expert_id: int,
is_gptq: bool = False,
):
if is_gptq:
param_data = param.data
if ("_qweight" in weight_name or "_scales" in weight_name
or "_qzeros" in weight_name):
if "w13" in weight_name:
shard_size = loaded_weight.size()[-1]
if shard_id == "w1":
param_data[expert_id, :, :shard_size] = loaded_weight
elif shard_id == "w3" or shard_id == "w2":
param_data[expert_id, :, shard_size:] = loaded_weight
else:
raise ValueError(f"Invalid shard_id: {shard_id}: "
"must be 0, 1, or 2.")
elif "w2" in weight_name:
param_data[expert_id][:] = loaded_weight
else:
raise ValueError(f"Invalid weight name: {weight_name}: "
"must contain 'w13' or 'w2'.")
elif "_g_idx" in weight_name:
if "w13" not in weight_name and "w2" not in weight_name:
raise ValueError(f"Invalid weight name: {weight_name}: "
"must contain 'w13' or 'w2'.")
param_data[expert_id] = loaded_weight
else:
raise ValueError(f"Invalid weight name: {weight_name}.")
return

if shard_id not in ("w1", "w2", "w3"):
raise ValueError(f"shard_id must be ['w1','w2','w3'] but "
@@ -550,8 +498,8 @@ def make_expert_params_mapping(
# These are the weight scales for the experts
# (param_name, weight_name, expert_id, shard_id)
(
"experts.w13_scale"
if weight_name in gate_up else "experts.w2_scale",
"experts.w13_weight_scale"
if weight_name in gate_up else "experts.w2_weight_scale",
f"experts.{expert_id}.{weight_name}.weight_scale",
expert_id,
f"w{shard_id + 1}",
@@ -625,18 +573,6 @@ def make_expert_params_mapping(
for shard_id, weight_name in enumerate(gate_down_up)
])

# return [
# # (param_name, weight_name, expert_id, shard_id)
# ("experts.w13_" if weight_name
# in [ckpt_gate_proj_name, ckpt_up_proj_name] else "experts.w2_",
# f"experts.{expert_id}.{weight_name}.", expert_id, shard_id)
# for expert_id in range(num_experts) for shard_id, weight_name in [
# ("w1", ckpt_gate_proj_name),
# ("w2", ckpt_down_proj_name),
# ("w3", ckpt_up_proj_name),
# ]
# ]

def _load_fp8_scale(self, param: torch.nn.Parameter,
loaded_weight: torch.Tensor, weight_name: str,
shard_id: str, expert_id: int) -> None:
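The commented-out quantized loading path deleted above is now live in weight_loader behind the new is_gptq flag. Below is a minimal standalone sketch (plain PyTorch, hypothetical shapes and helper name, not vLLM code) of the slicing that branch performs for "_qweight"/"_scales"/"_qzeros" tensors whose name contains "w13": each expert's w1 and w3 checkpoint tensors are copied into the two halves of the fused w13 parameter along the last dimension.

import torch

num_experts, in_dim, shard_out = 4, 32, 16   # hypothetical sizes

# Fused per-expert parameter holding both gate_proj (w1) and up_proj (w3).
w13_qweight = torch.zeros(num_experts, in_dim, 2 * shard_out)

def load_gptq_shard(param_data, loaded, shard_id, expert_id):
    # Mirrors the slicing in the new is_gptq branch of FusedMoE.weight_loader.
    shard_size = loaded.size()[-1]
    if shard_id == "w1":
        param_data[expert_id, :, :shard_size] = loaded   # gate_proj half
    elif shard_id == "w3":
        param_data[expert_id, :, shard_size:] = loaded   # up_proj half
    else:
        raise ValueError(f"Invalid shard_id: {shard_id}")

load_gptq_shard(w13_qweight, torch.randn(in_dim, shard_out), "w1", expert_id=0)
load_gptq_shard(w13_qweight, torch.randn(in_dim, shard_out), "w3", expert_id=0)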
@@ -269,15 +269,18 @@ def apply(self,
from vllm.model_executor.layers.fused_moe.fused_moe_marlin import (
fused_moe_marlin)

return fused_moe_marlin(x,
layer.w13_weight_packed,
layer.w2_weight_packed,
router_logits,
layer.w13_g_idx,
layer.w2_g_idx,
layer.w13_g_idx_sort_indices,
layer.w2_g_idx_sort_indices,
top_k,
renormalize=renormalize,
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale)
return fused_moe_marlin(
x,
layer.w13_weight_packed,
layer.w2_weight_packed,
router_logits,
layer.w13_g_idx,
layer.w2_g_idx,
layer.w13_g_idx_sort_indices,
layer.w2_g_idx_sort_indices,
top_k,
renormalize=renormalize,
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
num_bits=self.num_bits,
)
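The new call site forwards num_bits into fused_moe_marlin, so the Marlin MoE kernel can be configured for either 4-bit (uint4b8) or 8-bit (uint8b128) weights. As a rough illustration only (not vLLM API, and an assumption about the packing convention): GPTQ-style kernels typically store quantized weights packed into int32 along the input dimension, so num_bits fixes the pack factor and therefore the packed qweight shape.

def packed_qweight_shape(k: int, n: int, num_bits: int) -> tuple:
    # Sketch: with int32 storage the pack factor is 32 // num_bits, so a
    # [k, n] weight packs to [k // pack_factor, n].
    assert num_bits in (4, 8), "this sketch only considers uint4b8 / uint8b128"
    pack_factor = 32 // num_bits
    assert k % pack_factor == 0
    return (k // pack_factor, n)

print(packed_qweight_shape(4096, 14336, num_bits=4))  # (512, 14336)
print(packed_qweight_shape(4096, 14336, num_bits=8))  # (1024, 14336)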
4 changes: 2 additions & 2 deletions vllm/model_executor/layers/quantization/experts_int8.py
@@ -88,13 +88,13 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int,
2 * intermediate_size,
dtype=torch.float32),
requires_grad=False)
layer.register_parameter("w13_scale", w13_scale)
layer.register_parameter("w13_weight_scale", w13_scale)

w2_scale = torch.nn.Parameter(torch.zeros(num_experts,
hidden_size,
dtype=torch.float32),
requires_grad=False)
layer.register_parameter("w2_scale", w2_scale)
layer.register_parameter("w2_weight_scale", w2_scale)

def apply(self,
layer: torch.nn.Module,
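The per-expert scale parameters are renamed from w13_scale/w2_scale to w13_weight_scale/w2_weight_scale, matching the "experts.w13_weight_scale"/"experts.w2_weight_scale" entries that make_expert_params_mapping now produces in layer.py. A minimal sketch (plain PyTorch, hypothetical sizes, stand-in module name) of the registration:

import torch

class ExpertScales(torch.nn.Module):
    # Stand-in for the layer that experts_int8 attaches scales to.
    def __init__(self, num_experts: int, hidden_size: int, intermediate_size: int):
        super().__init__()
        w13_scale = torch.nn.Parameter(
            torch.zeros(num_experts, 2 * intermediate_size, dtype=torch.float32),
            requires_grad=False)
        self.register_parameter("w13_weight_scale", w13_scale)
        w2_scale = torch.nn.Parameter(
            torch.zeros(num_experts, hidden_size, dtype=torch.float32),
            requires_grad=False)
        self.register_parameter("w2_weight_scale", w2_scale)

layer = ExpertScales(num_experts=8, hidden_size=4096, intermediate_size=14336)
print(layer.w13_weight_scale.shape, layer.w2_weight_scale.shape)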
12 changes: 8 additions & 4 deletions vllm/model_executor/models/mixtral.py
@@ -42,6 +42,8 @@
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import ( # noqa: E501
CompressedTensorsConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -450,9 +452,11 @@ def __init__(
super().__init__()
# TODO keep the fused mixtral_quant codepath around as long as we don't
# support all quant_types
self.use_fused_moe = (quant_config.quant_type == scalar_types.uint4b8
or quant_config.quant_type
== scalar_types.uint8b128)
self.is_compressed = isinstance(quant_config, CompressedTensorsConfig)
self.use_fused_moe = (
self.is_compressed
or quant_config.quant_type == scalar_types.uint4b8
or quant_config.quant_type == scalar_types.uint8b128)
self.config = config
self.lora_config = lora_config
self.model = MixtralModel(self.use_fused_moe,
@@ -579,7 +583,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
name,
shard_id=shard_id,
expert_id=expert_id,
# is_quantized=True,
is_gptq=not self.is_compressed,
)
break
else:
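MixtralForCausalLM now takes the fused-MoE path both for compressed-tensors checkpoints and for GPTQ-style uint4b8/uint8b128 checkpoints, and only flags the loader with is_gptq when the checkpoint is not compressed-tensors. A hedged sketch of that dispatch with stand-in config classes (hypothetical names, not the real vLLM classes), so it runs on its own:

class FakeCompressedTensorsConfig:          # stands in for CompressedTensorsConfig
    pass

class FakeGPTQMarlinConfig:                 # stands in for a GPTQ-style quant config
    quant_type = "uint4b8"                  # the real code compares scalar_types

def select_moe_path(quant_config):
    is_compressed = isinstance(quant_config, FakeCompressedTensorsConfig)
    use_fused_moe = (is_compressed
                     or getattr(quant_config, "quant_type", None)
                     in ("uint4b8", "uint8b128"))
    # The weight loader only takes the GPTQ slicing path for non-compressed
    # checkpoints, mirroring `is_gptq=not self.is_compressed`.
    return use_fused_moe, not is_compressed

print(select_moe_path(FakeCompressedTensorsConfig()))  # (True, False)
print(select_moe_path(FakeGPTQMarlinConfig()))         # (True, True)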
(Diff for the remaining changed file not shown.)
