diff --git a/src/compressed_tensors/utils/offload.py b/src/compressed_tensors/utils/offload.py index 43ec43ce..3d92b9ba 100644 --- a/src/compressed_tensors/utils/offload.py +++ b/src/compressed_tensors/utils/offload.py @@ -135,8 +135,9 @@ def update_prefix_dict(module: torch.nn.Module, key: str, data: torch.Tensor): """ if not has_offloaded_params(module): raise ValueError("Prefix dict is only applicable to offloaded modules") - prefix_dict = module._hf_hook.weights_map - prefix_dict.dataset[f"{prefix_dict.prefix}{key}"] = data + + weights_map = module._hf_hook.weights_map + offload_to_weights_map(weights_map, key, data) def update_parameter_data(