Adapt to model format prepared with transformers
ElizaWszola committed Oct 24, 2024
1 parent 05f1fce commit 9972c88
Showing 6 changed files with 82 additions and 124 deletions.
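Note: "model format prepared with transformers" refers to HQQ checkpoints exported through Hugging Face transformers, which store per-layer tensors named W_q / scale / zero and record nbits and group_size in the model's quantization_config, rather than the nested per-layer dicts produced by the standalone hqq package. A minimal sketch of producing such a checkpoint, assuming a transformers build with HQQ support and the hqq package installed (model name, bit width, and group size are placeholder values):

from transformers import AutoModelForCausalLM, HqqConfig

quant_config = HqqConfig(nbits=4, group_size=64)  # the same keys this diff reads later
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B-Instruct",   # placeholder model
    torch_dtype="auto",
    device_map="cuda",
    quantization_config=quant_config,
)
model.save_pretrained("llama-hqq-4bit")  # serialization details depend on the transformers/hqq versions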
vllm/model_executor/layers/linear.py (4 changes: 2 additions & 2 deletions)
@@ -380,7 +380,7 @@ def forward(self, input_):
def extra_repr(self) -> str:
s = f"in_features={self.input_size}"
s += f", output_features={self.output_size_per_partition}"
s += f", bias={hasattr(self, 'bias') and self.bias is not None}"
s += f", bias={self.bias is not None}"
s += f", tp_size={get_tensor_model_parallel_world_size()}"
s += f", gather_output={self.gather_output}"
return s
@@ -1092,7 +1092,7 @@ def forward(self, input_):
def extra_repr(self) -> str:
s = f"input_features={self.input_size_per_partition}"
s += f", output_features={self.output_size}"
s += f", bias={hasattr(self, 'bias') and self.bias is not None}"
s += f", bias={self.bias is not None}"
s += f", tp_size={self.tp_size}"
s += f", reduce_results={self.reduce_results}"
return s
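Dropping the hasattr() guard is safe because these linear layers always register a bias attribute, setting it to None when bias is disabled, so self.bias is not None already covers both cases. A minimal sketch of that pattern, using an illustrative module rather than vLLM's actual classes:

import torch
from torch import nn

class TinyLinear(nn.Module):
    def __init__(self, in_features: int, out_features: int, bias: bool = True):
        super().__init__()
        self.weight = nn.Parameter(torch.empty(out_features, in_features))
        if bias:
            self.bias = nn.Parameter(torch.empty(out_features))
        else:
            # Registering None keeps the attribute present, mirroring the vLLM layers.
            self.register_parameter("bias", None)

    def extra_repr(self) -> str:
        return f"in_features={self.weight.shape[1]}, bias={self.bias is not None}"

print(TinyLinear(16, 32, bias=False))  # TinyLinear(in_features=16, bias=False)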
vllm/model_executor/layers/quantization/__init__.py (2 changes: 1 addition & 1 deletion)
@@ -49,7 +49,7 @@
"compressed-tensors": CompressedTensorsConfig,
"bitsandbytes": BitsAndBytesConfig,
"qqq": QQQConfig,
"hqq_marlin": HQQMarlinConfig,
"hqq": HQQMarlinConfig,
"experts_int8": ExpertsInt8Config,
"neuron_quant": NeuronQuantConfig,
"ipex": IPEXConfig,
vllm/model_executor/layers/quantization/hqq_marlin.py (38 changes: 19 additions & 19 deletions)
@@ -44,7 +44,7 @@ def __repr__(self) -> str:

@classmethod
def get_name(cls) -> str:
return "hqq_marlin"
return "hqq"

@classmethod
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
@@ -60,7 +60,7 @@ def get_config_filenames(cls) -> List[str]:

@classmethod
def from_config(cls, config: Dict[str, Any]) -> "HQQMarlinConfig":
weight_bits = cls.get_from_keys(config, ["bits"])
weight_bits = cls.get_from_keys(config, ["nbits"])
group_size = cls.get_from_keys(config, ["group_size"])
return cls(weight_bits, group_size)
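from_config now reads the key names HQQ itself uses ("nbits" rather than "bits"); the dict is assembled in weight_utils.get_quant_config from the checkpoint's quantization_config instead of being hardcoded. A small illustration with example values:

config = {"nbits": 4, "group_size": 64}   # keys follow HQQ's own naming
weight_bits = config["nbits"]             # what get_from_keys(config, ["nbits"]) resolves to
group_size = config["group_size"]
# -> HQQMarlinConfig(weight_bits=4, group_size=64)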

@@ -106,8 +106,8 @@ def create_weights(

weight_loader = extra_weight_attrs.get("weight_loader")

scales_and_zp_size = (input_size_per_partition //
self.quant_config.group_size)
self.scales_and_zp_size = (input_size_per_partition //
self.quant_config.group_size)

# Quantized weights
qweight = ModelWeightParameter(data=torch.empty(
@@ -121,7 +121,7 @@

zeros = GroupQuantScaleParameter(data=torch.empty(
self.output_size_per_partition,
scales_and_zp_size,
self.scales_and_zp_size,
dtype=params_dtype,
),
input_dim=1,
@@ -130,20 +130,20 @@

scales = GroupQuantScaleParameter(data=torch.empty(
self.output_size_per_partition,
scales_and_zp_size,
self.scales_and_zp_size,
dtype=params_dtype,
),
input_dim=1,
output_dim=0,
weight_loader=weight_loader)

layer.register_parameter("qweight", qweight)
layer.register_parameter("zeros", zeros)
layer.register_parameter("scales", scales)
layer.register_parameter("W_q", qweight)
layer.register_parameter("zero", zeros)
layer.register_parameter("scale", scales)

def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
dev = layer.qweight.device
qweight_t = layer.qweight.transpose(1, 0)
dev = layer.W_q.device
qweight_t = layer.W_q.transpose(1, 0)

gptq_w_q = gptq_pack(qweight_t, 4, self.input_size_per_partition,
self.output_size_per_partition)
@@ -156,14 +156,14 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
self.output_size_per_partition,
4,
).to(dev)
marlin_s = marlin_permute_scales(layer.scales.transpose(1, 0),
self.input_size_per_partition,
self.output_size_per_partition,
64).to(dev)
marlin_zp = marlin_permute_scales(layer.zeros.transpose(1, 0),
self.input_size_per_partition,
self.output_size_per_partition,
64).to(dev)
marlin_s = marlin_permute_scales(
layer.scale.reshape(-1, self.scales_and_zp_size).transpose(1, 0),
self.input_size_per_partition, self.output_size_per_partition,
self.quant_config.group_size).to(dev)
marlin_zp = marlin_permute_scales(
layer.zero.reshape(-1, self.scales_and_zp_size).transpose(1, 0),
self.input_size_per_partition, self.output_size_per_partition,
self.quant_config.group_size).to(dev)

layer.g_idx = marlin_make_empty_g_idx(dev)
layer.g_idx_sort_indices = marlin_make_empty_g_idx(dev)
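The parameters are now registered under the names used in the transformers-prepared checkpoint (W_q, zero, scale), and the scale/zero permutation uses the configured group_size instead of a hardcoded 64. A hedged shape sketch of what process_weights_after_loading sees, with illustrative sizes and assuming HQQ's default layout where scale and zero are stored flat with one value per group:

import torch

out_features, in_features, group_size = 11008, 4096, 64
scales_and_zp_size = in_features // group_size            # groups per output row

W_q_unpacked = torch.zeros(out_features, in_features, dtype=torch.uint8)   # after unpacking in llama.py
scale_flat = torch.zeros(out_features * scales_and_zp_size, 1, dtype=torch.bfloat16)

scale = scale_flat.reshape(-1, scales_and_zp_size)        # (11008, 64), the reshape done above
print(W_q_unpacked.transpose(1, 0).shape)                 # torch.Size([4096, 11008])
print(scale.transpose(1, 0).shape)                        # torch.Size([64, 11008])
# gptq_pack and the marlin permutations then convert these into the Marlin kernel layout.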
vllm/model_executor/model_loader/weight_utils.py (10 changes: 7 additions & 3 deletions)
@@ -128,9 +128,13 @@ def get_quant_config(model_config: ModelConfig,
if model_config.quantization == "gguf":
return quant_cls.from_config({})

if model_config.quantization == "hqq_marlin":
# TODO don't hardcode params
return quant_cls.from_config({"bits": 4, "group_size": 64})
if model_config.quantization == "hqq":
wq_params = (model_config.hf_config.quantization_config["quant_config"]
["weight_quant_params"])
return quant_cls.from_config({
"nbits": wq_params["nbits"],
"group_size": wq_params["group_size"]
})

# Read the quantization config from the HF model config, if available.
hf_quant_config = getattr(model_config.hf_config, "quantization_config",
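get_quant_config now pulls nbits and group_size from the checkpoint's quantization_config instead of hardcoding them. A hedged sketch of the relevant slice of such a config, showing only the keys this code reads (the quant_method field and any other entries are assumptions and may differ between checkpoints):

quantization_config = {
    "quant_method": "hqq",                      # assumed; not read by this code path
    "quant_config": {
        "weight_quant_params": {
            "nbits": 4,
            "group_size": 64,
        },
    },
}
wq_params = quantization_config["quant_config"]["weight_quant_params"]
assert (wq_params["nbits"], wq_params["group_size"]) == (4, 64)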
vllm/model_executor/models/llama.py (89 changes: 40 additions & 49 deletions)
@@ -371,12 +371,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
(".gate_up_proj", ".up_proj", 1),
]

hqq_map = [
(".qweight", "W_q", False),
(".zeros", "zero", True),
(".scales", "scale", True),
]

# unpack function from https://github.com/mobiusml/hqq
def unpack_4bit_u8(
W_q: torch.Tensor,
@@ -389,48 +383,14 @@ def unpack_4bit_u8(
tmp[step:] = W_q & 0b00001111
return tmp

def rescale_hqq_wq(loaded_weight: torch.Tensor, param) -> torch.Tensor:
# TODO don't hardcode type
return unpack_4bit_u8(loaded_weight, dtype=torch.bfloat16).reshape(
(-1, param.shape[1])).to(torch.uint8)

params_dict = dict(self.named_parameters())
for name, loaded_weight in weights:

if self.is_hqq:
pick_shard_id = None
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
pick_shard_id = shard_id
break
if name.endswith("_proj"):
to_shape = loaded_weight["shape"]
group_size = loaded_weight["group_size"]
for c, k, should_scale in hqq_map:
new_name = name + c
if new_name not in params_dict:
continue
param = params_dict[new_name]
weight_loader = param.weight_loader
if should_scale:
loaded = loaded_weight[k].reshape(
-1, to_shape[1] // group_size)
else:
# TODO should we unpack inside the quantization
# method / kernel?
loaded = unpack_4bit_u8(
loaded_weight[k],
dtype=torch.bfloat16).reshape(to_shape).to(
torch.uint8)

if pick_shard_id is not None:
weight_loader(param, loaded, pick_shard_id)
else:
weight_loader(param, loaded)
else:
name = name + ".weight"
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight["weight"])
continue
for name, loaded_weight in weights:

if "rotary_emb.inv_freq" in name:
continue
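unpack_4bit_u8 (taken from https://github.com/mobiusml/hqq) reverses HQQ's 4-bit packing: two 4-bit values share one uint8, with the first half of the rows in the high nibble and the second half in the low nibble. A small round-trip sketch; pack_4bit_u8 is written here only for illustration, and the body of unpack_4bit_u8 is reconstructed from the hqq package rather than copied from this diff:

import torch

def pack_4bit_u8(w: torch.Tensor) -> torch.Tensor:
    step = w.shape[0] // 2
    return (w[:step] << 4) | w[step:]                      # high nibble | low nibble

def unpack_4bit_u8(w_q: torch.Tensor, dtype=torch.bfloat16) -> torch.Tensor:
    step = w_q.shape[0]
    tmp = torch.empty((2 * step, w_q.shape[1]), dtype=dtype, device=w_q.device)
    tmp[:step] = (w_q & 0b11110000) >> 4
    tmp[step:] = w_q & 0b00001111
    return tmp

w = torch.randint(0, 16, (8, 4), dtype=torch.uint8)        # 4-bit values stored one per byte
packed = pack_4bit_u8(w)                                   # shape (4, 4), two values per byte
restored = unpack_4bit_u8(packed).to(torch.uint8)          # shape (8, 4)
assert torch.equal(restored, w)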
@@ -458,9 +418,27 @@
if is_pp_missing_parameter(name, self):
continue

# TODO should input/output dim in hqq_marlin.py depend on this?
ignore_hqq = (".axis", ".channel_wise", ".compute_dtype",
".encoded_state_dict", ".group_size", ".nbits",
".offload_meta", ".optimize", ".packing",
".quant_scale", ".quant_zero", ".round_zero",
".shape", ".stores_quant_config",
".unpack_view_dtype", ".view_as_float")
if name.endswith(ignore_hqq) and name not in params_dict:
continue

param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
if self.is_hqq and name.endswith(".W_q"):
weight_loader(param, rescale_hqq_wq(loaded_weight, param),
shard_id)
elif self.is_hqq and name.endswith((".scale", ".zero")):
weight_loader(param,
loaded_weight.reshape(-1, param.shape[1]),
shard_id)
else:
weight_loader(param, loaded_weight, shard_id)

break
else:
@@ -475,13 +453,26 @@
if is_pp_missing_parameter(name, self):
continue

if name not in params_dict:
# TODO should input/output dim in hqq_marlin.py depend on this?
ignore_hqq = (".axis", ".channel_wise", ".compute_dtype",
".encoded_state_dict", ".group_size", ".nbits",
".offload_meta", ".optimize", ".packing",
".quant_scale", ".quant_zero", ".round_zero",
".shape", ".stores_quant_config",
".unpack_view_dtype", ".view_as_float")
if name.endswith(ignore_hqq) and name not in params_dict:
continue

param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
if self.is_hqq and name.endswith(".W_q"):
weight_loader(param, rescale_hqq_wq(loaded_weight, param))
elif self.is_hqq and name.endswith((".scale", ".zero")):
weight_loader(param,
loaded_weight.reshape(-1, param.shape[1]))
else:
weight_loader(param, loaded_weight)

# If this function is called, it should always initialize KV cache scale
# factors (or else raise an exception). Thus, handled exceptions should
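For reference, a hedged sketch of the per-layer entries the loader above now expects from a transformers-prepared HQQ checkpoint and how each is handled (the layer name is a placeholder; the metadata suffixes come from the ignore_hqq tuple):

hqq_layer_entries = {
    "model.layers.0.self_attn.q_proj.W_q":   "packed 4-bit weights -> rescale_hqq_wq, then the shard-aware weight_loader",
    "model.layers.0.self_attn.q_proj.scale": "group scales -> reshaped to (-1, param.shape[1]) before loading",
    "model.layers.0.self_attn.q_proj.zero":  "group zero points -> reshaped like scale",
    "model.layers.0.self_attn.q_proj.nbits": "metadata (like .group_size, .axis, .shape, ...) -> skipped",
}
for key, handling in hqq_layer_entries.items():
    print(f"{key}: {handling}")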
vllm/model_executor/models/utils.py (63 changes: 13 additions & 50 deletions)
@@ -103,8 +103,6 @@ def _groupby_prefix(

for prefix, group in itertools.groupby(weights_by_parts,
key=lambda x: x[0][0]):
# for parts, weights_data in group:
# print("part: ", parts, weights_data)
yield (
prefix,
# Because maxsplit=1 in weight_name.split(...),
@@ -135,52 +133,24 @@ def _load_param(
weights: Iterable[Tuple[str, torch.Tensor]],
) -> Iterable[str]:
for weight_name, weight_data in weights:
weight_qualname = self._get_qualname(base_prefix, weight_name)

if torch.is_tensor(weight_data):
weight_qualname = self._get_qualname(base_prefix, weight_name)

if self._can_skip(weight_qualname):
continue

if weight_name != "":
if not self._can_ignore_unexpected(weight_qualname):
raise ValueError(
f"Attempted to load nested weight "
f"'{weight_qualname}' "
f"into a single parameter '{base_prefix}'")
continue

weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, weight_data)
yield weight_qualname
else:
# TODO remove this when we get a new hqq dataset format
for wn, wd in weight_data.items():

weight_qualname = self._get_qualname(base_prefix, wn)

if self._can_skip(weight_qualname):
continue
if self._can_skip(weight_qualname):
continue

weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, wd)
if weight_name != "":
if not self._can_ignore_unexpected(weight_qualname):
raise ValueError(
f"Attempted to load nested weight '{weight_qualname}' "
f"into a single parameter '{base_prefix}'")

yield weight_qualname
continue

def _load_one_param(
self,
base_prefix: str,
param: nn.Parameter,
weight_name: str,
weight_data: torch.Tensor,
) -> Iterable[str]:
weight_qualname = self._get_qualname(base_prefix, weight_name)
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, weight_data)
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, weight_data)

yield weight_qualname
yield weight_qualname

def _load_module(
self,
@@ -208,13 +178,6 @@ def _load_module(
if self._can_skip(prefix):
continue

# TODO remove this when we get a new hqq dataset format
if child_prefix == "" and isinstance(child_params, dict):
for _, c_weight in child_params.items():
yield from self._load_param(prefix, c_weight,
child_weights)
continue

if child_prefix in child_modules:
yield from self._load_module(prefix,
child_modules[child_prefix],
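The dict-handling path removed here existed because the previous HQQ flow yielded (name, dict) pairs; with the transformers-prepared format every weight arrives as a flat (name, tensor) pair, so _load_param and _load_module only deal with tensors again. An illustrative contrast (names are placeholders, tensor contents elided):

# Old nested format handled by the removed branches:
old_style = ("model.layers.0.self_attn.q_proj",
             {"W_q": "<uint8 tensor>", "scale": "<tensor>", "zero": "<tensor>", "shape": (4096, 4096)})

# Transformers-prepared format: one flat entry per tensor.
new_style = [
    ("model.layers.0.self_attn.q_proj.W_q", "<uint8 tensor>"),
    ("model.layers.0.self_attn.q_proj.scale", "<tensor>"),
    ("model.layers.0.self_attn.q_proj.zero", "<tensor>"),
]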
