Skip to content

Commit

Permalink
clearer delete_from_weights_map
Browse files Browse the repository at this point in the history
  • Loading branch information
kylesayrs committed Dec 19, 2024
1 parent df3e186 commit 665c987
Showing 1 changed file with 30 additions and 47 deletions.
77 changes: 30 additions & 47 deletions src/compressed_tensors/utils/offload.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,24 +213,7 @@ def delete_offload_parameter(module: torch.nn.Module, name: str):

if has_offloaded_params(module):
weights_map = module._hf_hook.weights_map

# for upstreaming, better to add write capabilities to weight map classes first
if isinstance(weights_map, PrefixedDataset):
dataset = weights_map.dataset
prefix = weights_map.prefix
if dataset is not None:
del dataset[f"{prefix}{name}"]

elif isinstance(weights_map, dict):
del weights_map[name]

elif isinstance(weights_map, OffloadedWeightsLoader):
raise NotImplementedError()

elif weights_map is not None:
raise NotImplementedError(
f"Cannot delete parameter from weights_map of type {type(weights_map)}"
)
delete_from_weights_map(weights_map, name)


@check_accelerate(fallback=contextlib.nullcontext())
Expand Down Expand Up @@ -286,35 +269,6 @@ def offload_to_weights_map(
raise NotImplementedError(
"Updating weights_map with disk offloading is not implemented yet"
)
# TODO: below may not be correct and has not been tested
# FUTURE: upstream as OffloadedWeightsLoader.__set_item__
# use_index = "safetensors_file" in next(iter(weights_map.values()))
# if use_index:
# if key not in weights_map:
# weights_map.index[key] = {
# "safetensors_file": ???,
# "weight_name": key,
# "dtype": str(value.dtype)
# }

# weight_info = weights_map.index[key]
# file_path = weight_info["safetensors_file"]
# with safetensors.create_file(file_path) as file:
# file.write(value)

# else:
# assert self.save_folder is not None
# weight_file = os.path.join(self.save_folder, f"{key}.dat")
# need_index_update = not os.path.exists(weight_file)
# offload_weight(
# value,
# key,
# weights_map.save_folder,
# weights_map.index
# )

# if need_index_update:
# save_offload_index(weights_map.index, weights_map.save_folder)

elif isinstance(weights_map, dict):
if key in weights_map:
Expand All @@ -332,6 +286,35 @@ def offload_to_weights_map(
)


@check_accelerate(fallback=None)
def delete_from_weights_map(
weights_map: Union[PrefixedDataset, Dict, OffloadedWeightsLoader],
key: str,
):
if isinstance(weights_map, PrefixedDataset):
dataset = weights_map.dataset
key = f"{weights_map.prefix}{key}"
delete_from_weights_map(dataset, key)

elif isinstance(weights_map, OffloadedWeightsLoader):
if len(weights_map.index) <= 0:
delete_from_weights_map(weights_map.state_dict, key)

else:
raise NotImplementedError(
"Delete from weights_map with disk offloading is not implemented yet"
)

elif isinstance(weights_map, dict):
del weights_map[key]

else:
raise NotImplementedError(
"Updating offload data not implemented for weights_map of type "
f"{type(weights_map)}"
)


""" Upstreamed Functions """


Expand Down

0 comments on commit 665c987

Please sign in to comment.