Commit

try different shapes
ElizaWszola committed Oct 14, 2024
1 parent 117eff6 commit ee54bca
Showing 5 changed files with 308 additions and 68 deletions.
3 changes: 0 additions & 3 deletions vllm/attention/layer.py
@@ -49,8 +49,6 @@ def __init__(
if num_kv_heads is None:
num_kv_heads = num_heads

print("QUANT CONFIG:", quant_config)

# The default k/v_scale is set to 1.0. This is ignored
# when kv-cache is not fp8, and should be used with
# kv-cache in fp8_e5m2. For kv-cache in fp8_e4m3, we
@@ -61,7 +59,6 @@ def __init__(
self._v_scale = 1.0
quant_method = quant_config.get_quant_method(
self, prefix=prefix) if quant_config else None
print("QUANT METHOD:", quant_method)
if quant_method is not None:
assert isinstance(quant_method, BaseKVCacheMethod)
# TODO (mgoin): kv cache dtype should be specified in the FP8
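
The comment above describes the contract the removed debug prints were inspecting: k/v scales default to 1.0, are ignored for non-fp8 kv-caches, can be used as-is for fp8_e5m2, and are expected to be overwritten for fp8_e4m3. A minimal standalone sketch of that pattern, not part of this commit and not vLLM's actual classes (TinyAttention and FakeKVCacheMethod are hypothetical):

from typing import Optional

class FakeKVCacheMethod:
    """Hypothetical stand-in for a kv-cache quantization method."""
    def load_scales(self, layer: "TinyAttention") -> None:
        # A real method would read calibrated scales from the checkpoint.
        layer.k_scale, layer.v_scale = 0.03, 0.05

class TinyAttention:
    def __init__(self, kv_cache_dtype: str,
                 quant_method: Optional[FakeKVCacheMethod] = None) -> None:
        # Defaults of 1.0: ignored for non-fp8 caches, usable as-is for
        # fp8_e5m2, replaced from the checkpoint for fp8_e4m3.
        self.k_scale, self.v_scale = 1.0, 1.0
        if kv_cache_dtype.startswith("fp8") and quant_method is not None:
            quant_method.load_scales(self)

print(TinyAttention("fp8_e4m3", FakeKVCacheMethod()).k_scale)  # 0.03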
60 changes: 51 additions & 9 deletions vllm/model_executor/layers/linear.py
@@ -114,6 +114,8 @@ def apply(self,
class UnquantizedLinearMethod(LinearMethodBase):
"""Linear method without quantization."""

global_print_ctr = 0

def create_weights(self, layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: List[int], input_size: int,
@@ -132,6 +134,13 @@ def apply(self,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:

# if UnquantizedLinearMethod.global_print_ctr < 3:
# torch.set_printoptions(edgeitems=128)
# torch.set_printoptions(sci_mode=False)

# print("apply to weight:", layer.weight, layer.weight.shape)
# # # print("and to bias:", bias)
# UnquantizedLinearMethod.global_print_ctr += 1
return F.linear(x, layer.weight, bias)


@@ -371,12 +380,16 @@ def forward(self, input_):
else:
output = output_parallel
output_bias = self.bias if self.skip_bias_add else None
# print("=== ColumnParallelLinear ===")
# print("forward's io:", input_.shape, output.shape)
# print("for input:", input_)
# print("got output:", output)
return output, output_bias

def extra_repr(self) -> str:
s = f"in_features={self.input_size}"
s += f", output_features={self.output_size_per_partition}"
s += f", bias={self.bias is not None}"
s += f", bias={hasattr(self, 'bias') and self.bias is not None}"
s += f", tp_size={get_tensor_model_parallel_world_size()}"
s += f", gather_output={self.gather_output}"
return s
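
The extra_repr change above swaps `self.bias is not None` for a `hasattr` guard; a layer constructed with bias disabled may never get a `bias` attribute, in which case the unguarded check raises AttributeError. A minimal sketch of the difference, using a hypothetical TinyLinear rather than vLLM's layer:

import torch

class TinyLinear(torch.nn.Module):
    def __init__(self, in_features: int, out_features: int, bias: bool = True):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.empty(out_features, in_features))
        if bias:
            self.bias = torch.nn.Parameter(torch.empty(out_features))
        # With bias=False nothing named "bias" is registered, so
        # `self.bias is not None` alone would raise AttributeError in repr.

    def extra_repr(self) -> str:
        return f"bias={hasattr(self, 'bias') and self.bias is not None}"

print(TinyLinear(4, 4, bias=False))  # TinyLinear(bias=False)
print(TinyLinear(4, 4, bias=True))   # TinyLinear(bias=True)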
@@ -431,6 +444,8 @@ def weight_loader(self,
loaded_weight: torch.Tensor,
loaded_shard_id: Optional[int] = None):

# print("weight loader", param.shape, loaded_weight.shape, loaded_shard_id)

# Special case for GGUF
# initialize GGUF param after we know the quantize type
is_gguf_weight = getattr(param, "is_gguf_weight", False)
@@ -506,23 +521,31 @@ def weight_loader(self,
# Special case for quantization.
# If quantized, we need to adjust the offset and size to account
# for the packing.
# print("shard_size1:", shard_size)
packed_dim = getattr(param, "packed_dim", None)
if packed_dim == output_dim:
shard_size = shard_size // param.pack_factor
shard_offset = shard_offset // param.pack_factor
# print(vars(param))
pack_factor = getattr(param, "packed_factor", None)
if pack_factor is None:
pack_factor = param.pack_factor
shard_size = shard_size // pack_factor
shard_offset = shard_offset // pack_factor
# Special case for Marlin.
shard_size, shard_offset = adjust_marlin_shard(
param, shard_size, shard_offset)

# print("shard_size2:", shard_size)
use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit",
False)
if use_bitsandbytes_4bit:
shard_size = loaded_weight.shape[output_dim]
shard_offset = loaded_weight.shape[output_dim] * \
loaded_shard_id

# print("shard_size3:", shard_size)
# print("param data pre:", param_data.shape)
param_data = param_data.narrow(output_dim, shard_offset,
shard_size)
# print("param data post:", param_data.shape)
start_idx = tp_rank * shard_size
# bitsandbytes loads the weights of the specific portion
# no need to narrow here
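
The packing branch in this hunk now resolves the divisor through `getattr(param, "packed_factor", None)` and only falls back to `param.pack_factor` when that attribute is absent. A toy illustration of the shard arithmetic, with made-up numbers (not taken from this diff):

def packed_shard(shard_size: int, shard_offset: int, param) -> tuple:
    # Same fallback as in the diff: prefer `packed_factor`, else `pack_factor`.
    pack_factor = getattr(param, "packed_factor", None)
    if pack_factor is None:
        pack_factor = param.pack_factor
    # With e.g. eight 4-bit values packed per int32, element counts along
    # the packed dim shrink by the pack factor.
    return shard_size // pack_factor, shard_offset // pack_factor

class _Param:  # stand-in carrying only the attribute we need
    pack_factor = 8

print(packed_shard(4096, 8192, _Param()))  # (512, 1024)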
@@ -549,6 +572,8 @@ def weight_loader(self,
"MergedColumnParallelLinear, assume the weight is "
"the same for all partitions.")

# if param_data.shape != loaded_weight.shape:
# print("FAIL", param_data.shape, loaded_weight.shape)
assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight)

@@ -767,6 +792,8 @@ def weight_loader(self,
loaded_weight: torch.Tensor,
loaded_shard_id: Optional[str] = None):

# print("LOAD", param.shape, loaded_weight.shape, loaded_shard_id)

# Special case for GGUF
# initialize GGUF param after we know the quantize type
is_gguf_weight = getattr(param, "is_gguf_weight", False)
@@ -859,8 +886,11 @@ def weight_loader(self,
# for the packing.
packed_dim = getattr(param, "packed_dim", None)
if packed_dim == output_dim:
shard_size = shard_size // param.pack_factor
shard_offset = shard_offset // param.pack_factor
pack_factor = getattr(param, "packed_factor", None)
if pack_factor is None:
pack_factor = param.pack_factor
shard_size = shard_size // pack_factor
shard_offset = shard_offset // pack_factor

# Special case for Marlin.
shard_size, shard_offset = adjust_marlin_shard(
@@ -916,6 +946,7 @@ def weight_loader(self,
"QKVParallelLinear, assume the weight is the same "
"for all partitions.")

# print(param_data.shape, loaded_weight.shape)
assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight)

@@ -959,6 +990,8 @@ def __init__(self,
super().__init__(input_size, output_size, skip_bias_add, params_dtype,
quant_config, prefix)

# print("RPL", input_size, output_size)

self.input_is_parallel = input_is_parallel
self.reduce_results = reduce_results

@@ -968,8 +1001,6 @@ def __init__(self,
self.input_size_per_partition = divide(input_size, self.tp_size)
assert self.quant_method is not None

print("rowpar", self.quant_method)

self.quant_method.create_weights(
layer=self,
input_size_per_partition=self.input_size_per_partition,
@@ -1014,11 +1045,15 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
param.materialize(tuple(weight_shape), dtype=loaded_weight.dtype)

param_data = param.data
# print("PRE", param_data.shape, loaded_weight.shape, input_dim)

# bitsandbytes loads the weights of the specific portion
# no need to narrow here
if input_dim is not None and not use_bitsandbytes_4bit:
shard_size = param_data.shape[input_dim]
start_idx = tp_rank * shard_size
# print("input_dim:", input_dim, "start_idx:", start_idx,
# "shard_size:", shard_size)
loaded_weight = loaded_weight.narrow(input_dim, start_idx,
shard_size)
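
The branch above narrows the checkpoint tensor along the input dimension by tensor-parallel rank; bitsandbytes 4-bit weights skip it because they are stored per-rank already. A worked example of the narrow with assumed shapes (illustrative only, not part of this commit):

import torch

tp_size, tp_rank = 2, 1
full_weight = torch.randn(1024, 8192)   # [output_size, input_size]
input_dim = 1                            # RowParallelLinear shards the input dim

shard_size = full_weight.shape[input_dim] // tp_size  # 4096 columns per rank
start_idx = tp_rank * shard_size                      # rank 1 starts at 4096
local_weight = full_weight.narrow(input_dim, start_idx, shard_size)
print(local_weight.shape)  # torch.Size([1024, 4096])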

@@ -1027,6 +1062,8 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
if len(loaded_weight.shape) == 0:
loaded_weight = loaded_weight.reshape(1)

# print("POST", param_data.shape, loaded_weight.shape)

assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight)

@@ -1065,12 +1102,17 @@ def forward(self, input_):

output_bias = self.bias if self.skip_bias_add else None

# print("=== RowParallelLinear ===")
# print("forward's io:", input_.shape, output.shape)
# print("for input:", input_)
# print("got output:", output)

return output, output_bias

def extra_repr(self) -> str:
s = f"input_features={self.input_size_per_partition}"
s += f", output_features={self.output_size}"
s += f", bias={self.bias is not None}"
s += f", bias={hasattr(self, 'bias') and self.bias is not None}"
s += f", tp_size={self.tp_size}"
s += f", reduce_results={self.reduce_results}"
return s
