Skip to content

Commit

Permalink
get compressed pathway to work without using torch.cat
Browse files Browse the repository at this point in the history
Signed-off-by: Dipika <dipikasikka1@gmail.com>
  • Loading branch information
dsikka committed Dec 4, 2024
1 parent 34a84a4 commit 1e4e98c
Showing 1 changed file with 23 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -175,13 +175,9 @@ def apply_weights(self,
return out

def _decompress_layer_weight(self, layer: torch.nn.Module) -> torch.Tensor:

sparse_24_packed_weight = layer.sparse_24_packed_weight.data
meta = layer.meta.data

split_weights = None
split_meta = None

def _process_split(input_weight, input_meta):
weight_data = {
"sparse_24_packed_weight": input_weight,
Expand All @@ -190,17 +186,29 @@ def _process_split(input_weight, input_meta):
decompress = self.model_compressor.sparsity_compressor.decompress_weight(weight_data)
return decompress

if isinstance(layer, (QKVParallelLinear, MergedColumnParallelLinear)):
split_weights = torch.split(sparse_24_packed_weight, layer.logical_widths)
split_meta = torch.split(meta, layer.logical_widths)

if split_weights:
all_compress = []
for i in range(len(split_weights)):
compress_i = _process_split(split_weights[i], split_meta[i])
all_compress.append(compress_i)

decompressed = torch.cat(all_compress)

if len(layer.logical_widths) > 1:
# preallocated buffer for the full decompressed weight; each shard is copied into its row slice below
weight_to_compress_p = torch.empty(
sparse_24_packed_weight.shape[0],
sparse_24_packed_weight.shape[1] * 2,
dtype=torch.float8_e4m3fn,
device="cuda"
)

shard_offset = 0
for i in range(len(layer.logical_widths)):
if i > 0:
shard_offset += layer.logical_widths[i-1]

sparse_data = sparse_24_packed_weight.narrow(0, shard_offset, layer.logical_widths[i])
meta_data = meta.narrow(0, shard_offset, layer.logical_widths[i])
decompress_i = _process_split(sparse_data, meta_data)

param_data = weight_to_compress_p.narrow(0, shard_offset, layer.logical_widths[i])
param_data.copy_(decompress_i)

decompressed = weight_to_compress_p
else:
decompressed = _process_split(sparse_24_packed_weight, meta)

Expand Down

0 comments on commit 1e4e98c

Please sign in to comment.