diff --git a/src/compressed_tensors/utils/offload.py b/src/compressed_tensors/utils/offload.py
index aec8eeb8..620292cf 100644
--- a/src/compressed_tensors/utils/offload.py
+++ b/src/compressed_tensors/utils/offload.py
@@ -213,24 +213,7 @@ def delete_offload_parameter(module: torch.nn.Module, name: str):
     if has_offloaded_params(module):
         weights_map = module._hf_hook.weights_map
-
-        # for upstreaming, better to add write capabilities to weight map classes first
-        if isinstance(weights_map, PrefixedDataset):
-            dataset = weights_map.dataset
-            prefix = weights_map.prefix
-            if dataset is not None:
-                del dataset[f"{prefix}{name}"]
-
-        elif isinstance(weights_map, dict):
-            del weights_map[name]
-
-        elif isinstance(weights_map, OffloadedWeightsLoader):
-            raise NotImplementedError()
-
-        elif weights_map is not None:
-            raise NotImplementedError(
-                f"Cannot delete parameter from weights_map of type {type(weights_map)}"
-            )
+        delete_from_weights_map(weights_map, name)
 
 
 @check_accelerate(fallback=contextlib.nullcontext())
@@ -286,35 +269,6 @@ def offload_to_weights_map(
             raise NotImplementedError(
                 "Updating weights_map with disk offloading is not implemented yet"
             )
-            # TODO: below may not be correct and has not been tested
-            # FUTURE: upstream as OffloadedWeightsLoader.__set_item__
-            # use_index = "safetensors_file" in next(iter(weights_map.values()))
-            # if use_index:
-            #     if key not in weights_map:
-            #         weights_map.index[key] = {
-            #             "safetensors_file": ???,
-            #             "weight_name": key,
-            #             "dtype": str(value.dtype)
-            #         }
-
-            #     weight_info = weights_map.index[key]
-            #     file_path = weight_info["safetensors_file"]
-            #     with safetensors.create_file(file_path) as file:
-            #         file.write(value)
-
-            # else:
-            #     assert self.save_folder is not None
-            #     weight_file = os.path.join(self.save_folder, f"{key}.dat")
-            #     need_index_update = not os.path.exists(weight_file)
-            #     offload_weight(
-            #         value,
-            #         key,
-            #         weights_map.save_folder,
-            #         weights_map.index
-            #     )
-
-            #     if need_index_update:
-            #         save_offload_index(weights_map.index, weights_map.save_folder)
 
     elif isinstance(weights_map, dict):
         if key in weights_map:
@@ -332,6 +286,35 @@ def offload_to_weights_map(
         )
 
 
+@check_accelerate(fallback=None)
+def delete_from_weights_map(
+    weights_map: Union[PrefixedDataset, Dict, OffloadedWeightsLoader],
+    key: str,
+):
+    if isinstance(weights_map, PrefixedDataset):
+        dataset = weights_map.dataset
+        key = f"{weights_map.prefix}{key}"
+        delete_from_weights_map(dataset, key)
+
+    elif isinstance(weights_map, OffloadedWeightsLoader):
+        if len(weights_map.index) <= 0:
+            delete_from_weights_map(weights_map.state_dict, key)
+
+        else:
+            raise NotImplementedError(
+                "Delete from weights_map with disk offloading is not implemented yet"
+            )
+
+    elif isinstance(weights_map, dict):
+        del weights_map[key]
+
+    else:
+        raise NotImplementedError(
+            "Updating offload data not implemented for weights_map of type "
+            f"{type(weights_map)}"
+        )
+
+
 """ Upstreamed Functions """
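
Not part of the patch, but for reviewers: a minimal usage sketch of the new `delete_from_weights_map` helper covering the two in-memory cases. The `"layer."` prefix, the dummy tensors, and the import path are illustrative assumptions, not taken from the repository; `PrefixedDataset` is assumed to come from `accelerate.utils`, matching how the existing code imports it, and accelerate must be installed since the helper is decorated with `@check_accelerate(fallback=None)`.

```python
# Illustrative sketch only -- not part of the diff. Assumes accelerate is
# installed and that delete_from_weights_map is importable from
# compressed_tensors.utils.offload once this change lands.
import torch
from accelerate.utils import PrefixedDataset

from compressed_tensors.utils.offload import delete_from_weights_map

# Case 1: plain dict weights map -- the key is deleted directly.
weights_map = {"weight": torch.empty(4), "bias": torch.empty(4)}
delete_from_weights_map(weights_map, "bias")
assert "bias" not in weights_map

# Case 2: PrefixedDataset wrapping a dict -- the helper prepends the dataset's
# prefix ("layer." here, a hypothetical module prefix) and recurses into the
# underlying dict.
dataset = {"layer.weight": torch.empty(4), "layer.bias": torch.empty(4)}
prefixed = PrefixedDataset(dataset, "layer.")
delete_from_weights_map(prefixed, "bias")
assert "layer.bias" not in dataset
```

The disk-offloaded case (an `OffloadedWeightsLoader` with a non-empty index) still raises `NotImplementedError`, so the sketch covers only the dict-backed paths.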