From 5340ce84afa13bae53a9843fa6cfe44c7f516dd1 Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Fri, 25 Oct 2024 08:27:10 -0400 Subject: [PATCH] remove hardcoded type --- vllm/model_executor/models/llama.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index d4be3818544d5..3b2fb6d8ad131 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -383,11 +383,6 @@ def unpack_4bit_u8( tmp[step:] = W_q & 0b00001111 return tmp - def rescale_hqq_wq(loaded_weight: torch.Tensor, param) -> torch.Tensor: - # TODO don't hardcode type - return unpack_4bit_u8(loaded_weight, dtype=torch.bfloat16).reshape( - (-1, param.shape[1])).to(torch.uint8) - params_dict = dict(self.named_parameters()) for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: @@ -429,8 +424,10 @@ def rescale_hqq_wq(loaded_weight: torch.Tensor, param) -> torch.Tensor: param = params_dict[name] weight_loader = param.weight_loader if self.is_hqq and name.endswith(".W_q"): - weight_loader(param, rescale_hqq_wq(loaded_weight, param), - shard_id) + weight_loader( + param, + unpack_4bit_u8(loaded_weight).reshape( + -1, param.shape[1]), shard_id) elif self.is_hqq and name.endswith((".scale", ".zero")): weight_loader(param, loaded_weight.reshape(-1, param.shape[1]), @@ -465,7 +462,10 @@ def rescale_hqq_wq(loaded_weight: torch.Tensor, param) -> torch.Tensor: weight_loader = getattr(param, "weight_loader", default_weight_loader) if self.is_hqq and name.endswith(".W_q"): - weight_loader(param, rescale_hqq_wq(loaded_weight, param)) + weight_loader( + param, + unpack_4bit_u8(loaded_weight).reshape( + -1, param.shape[1])) elif self.is_hqq and name.endswith((".scale", ".zero")): weight_loader(param, loaded_weight.reshape(-1, param.shape[1]))