Skip to content

Commit

Permalink
remove hardcoded type
Browse files · Browse the repository at this point in the history
  • Loading branch information
ElizaWszola committed Oct 25, 2024
1 parent 5ef4b80 commit 5340ce8
Showing 1 changed file with 8 additions and 8 deletions.
16 changes: 8 additions & 8 deletions vllm/model_executor/models/llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,11 +383,6 @@ def unpack_4bit_u8(
tmp[step:] = W_q & 0b00001111
return tmp

def rescale_hqq_wq(loaded_weight: torch.Tensor,
                   param,
                   dtype: torch.dtype = torch.bfloat16) -> torch.Tensor:
    """Unpack HQQ 4-bit packed weights and reshape them to match *param*.

    Args:
        loaded_weight: Packed weight tensor as loaded from the checkpoint
            (presumably uint8 with two 4-bit values per byte — confirm
            against ``unpack_4bit_u8``).
        param: Destination parameter; only ``param.shape[1]`` is read, to
            fix the column count of the unpacked tensor.
        dtype: Intermediate dtype used during unpacking. Previously
            hardcoded to ``torch.bfloat16`` (see old TODO); now a keyword
            parameter with the same default, so existing callers are
            unaffected.

    Returns:
        The unpacked weights reshaped to ``(-1, param.shape[1])`` and cast
        to ``torch.uint8``.
    """
    unpacked = unpack_4bit_u8(loaded_weight, dtype=dtype)
    # Final cast back to uint8: downstream weight loaders expect the raw
    # quantized representation, not the intermediate float dtype.
    return unpacked.reshape((-1, param.shape[1])).to(torch.uint8)

params_dict = dict(self.named_parameters())
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
Expand Down Expand Up @@ -429,8 +424,10 @@ def rescale_hqq_wq(loaded_weight: torch.Tensor, param) -> torch.Tensor:
param = params_dict[name]
weight_loader = param.weight_loader
if self.is_hqq and name.endswith(".W_q"):
weight_loader(param, rescale_hqq_wq(loaded_weight, param),
shard_id)
weight_loader(
param,
unpack_4bit_u8(loaded_weight).reshape(
-1, param.shape[1]), shard_id)
elif self.is_hqq and name.endswith((".scale", ".zero")):
weight_loader(param,
loaded_weight.reshape(-1, param.shape[1]),
Expand Down Expand Up @@ -465,7 +462,10 @@ def rescale_hqq_wq(loaded_weight: torch.Tensor, param) -> torch.Tensor:
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
if self.is_hqq and name.endswith(".W_q"):
weight_loader(param, rescale_hqq_wq(loaded_weight, param))
weight_loader(
param,
unpack_4bit_u8(loaded_weight).reshape(
-1, param.shape[1]))
elif self.is_hqq and name.endswith((".scale", ".zero")):
weight_loader(param,
loaded_weight.reshape(-1, param.shape[1]))
Expand Down

0 comments on commit 5340ce8

Please sign in to comment.