Skip to content

Commit

Permalink
implement offload_to_weights_map with recursive definition
Browse files Browse the repository at this point in the history
Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
  • Loading branch information
kylesayrs committed Dec 16, 2024
1 parent b8ae387 commit 870095e
Show file tree
Hide file tree
Showing 2 changed files with 127 additions and 36 deletions.
112 changes: 76 additions & 36 deletions src/compressed_tensors/utils/offload.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@

import contextlib
from functools import wraps
from typing import Any, Callable, Optional
from typing import Any, Callable, Optional, Union

import torch

Expand All @@ -47,6 +47,12 @@
_has_accelerate = True
except ImportError:
_has_accelerate = False
AlignDevicesHook = None
add_hook_to_module = None
remove_hook_from_module = None
OffloadedWeightsLoader = None
PrefixedDataset = None
set_module_tensor_to_device = None


__all__ = [
Expand Down Expand Up @@ -195,41 +201,7 @@ def update_offload_parameter(
# update offload dict
if has_offloaded_params(module):
weights_map = module._hf_hook.weights_map

# for upstreaming, better to add write capabilities to weight map classes first
if isinstance(weights_map, PrefixedDataset):
dataset = getattr(weights_map, "dataset", None)
if dataset is not None:
prefix = module._hf_hook.weights_map.prefix
key = f"{prefix}{name}"

offload_device = (
dataset[key].device
if key in dataset
else next(iter(dataset.values())).device
)

if isinstance(dataset, OffloadedWeightsLoader):
dataset.state_dict[key] = data.to(device=offload_device)
else:
dataset[key] = data.to(device=offload_device)

elif isinstance(weights_map, dict):
offload_device = (
weights_map[name].device
if name in weights_map
else next(iter(weights_map.values())).device
)
weights_map[name] = data.to(device=offload_device)

elif isinstance(weights_map, OffloadedWeightsLoader):
raise NotImplementedError()

else:
raise NotImplementedError(
"Updating offload data not implemented for weights_map of type "
f"{type(weights_map)}"
)
offload_to_weights_map(weights_map, name, data)


def delete_offload_parameter(module: torch.nn.Module, name: str):
Expand Down Expand Up @@ -285,6 +257,74 @@ def collect_hooks(module):
add_hook_to_module(submodule, hook)


def offload_to_weights_map(
    weights_map: Union[PrefixedDataset, dict, OffloadedWeightsLoader],
    key: str,
    value: torch.Tensor,
    default_device: torch.device = torch.device("cpu"),
):
    """
    Write a tensor into an offloaded weights map, recursing through the
    accelerate wrapper types (``PrefixedDataset`` -> its underlying dataset,
    ``OffloadedWeightsLoader`` -> its in-memory ``state_dict``) until a plain
    ``dict`` backing store is reached.

    :param weights_map: offloaded weights container to update
    :param key: name under which the weight is stored
    :param value: tensor to store
    :param default_device: device to offload to when the map is empty and no
        existing entry indicates where offloaded tensors live
    :raises NotImplementedError: for disk-backed ``OffloadedWeightsLoader``
        instances or unrecognized weights_map types
    """
    # Base case first: a plain dict is the ultimate backing store. Checking it
    # before the accelerate wrapper types also keeps this path usable when
    # accelerate is not installed (those fallback names are None, and
    # isinstance against None raises TypeError).
    if isinstance(weights_map, dict):
        if key in weights_map:
            # keep the tensor on the device it is already offloaded to
            offload_device = weights_map[key].device
        else:
            # infer the offload device from any existing entry, falling back
            # to default_device for an empty map
            tens = next(iter(weights_map.values()), None)
            offload_device = tens.device if tens is not None else default_device

        weights_map[key] = value.to(device=offload_device)

    elif isinstance(weights_map, PrefixedDataset):
        # strip the wrapper: apply the prefix and recurse into the dataset
        dataset = weights_map.dataset
        key = f"{weights_map.prefix}{key}"
        offload_to_weights_map(dataset, key, value, default_device)

    elif isinstance(weights_map, OffloadedWeightsLoader):
        if key not in weights_map.all_keys:
            weights_map.all_keys.append(key)

        if len(weights_map.index) <= 0:
            # empty index means weights are held in memory; recurse into the
            # in-memory state dict
            offload_to_weights_map(weights_map.state_dict, key, value, default_device)

        else:
            # TODO: support disk offloading via weights_map.index /
            # save_offload_index; FUTURE: upstream as
            # OffloadedWeightsLoader.__setitem__
            raise NotImplementedError(
                "Updating weights_map with disk offloading is not implemented yet"
            )

    else:
        raise NotImplementedError(
            "Updating offload data not implemented for weights_map of type "
            f"{type(weights_map)}"
        )


""" Upstreamed Functions """


Expand Down
51 changes: 51 additions & 0 deletions tests/test_utils/test_offload.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
register_offload_parameter,
update_offload_parameter,
)
from compressed_tensors.utils.offload import offload_to_weights_map
from tests.testing_utils import requires_accelerate


Expand Down Expand Up @@ -181,3 +182,53 @@ def test_disable_hf_hook_model_recurse():
assert hasattr(module0, "_hf_hook")
assert hasattr(module1, "_hf_hook")
assert hasattr(module2, "_hf_hook")


@requires_accelerate()
def test_offload_to_weights_map():
    from accelerate.utils import OffloadedWeightsLoader, PrefixedDataset

    name = "name"
    old_value = torch.tensor(0.0)
    new_value = torch.tensor(1.0)
    prefix = "prefix"

    def _write_and_check(weights_map):
        # write `new_value` under `name`, then confirm it reads back
        offload_to_weights_map(weights_map, name, new_value)
        assert weights_map[name] == new_value

    # plain dicts, empty and pre-populated
    _write_and_check({})
    _write_and_check({name: old_value})

    # OffloadedWeightsLoader wrapping a dict
    _write_and_check(OffloadedWeightsLoader({}))
    _write_and_check(OffloadedWeightsLoader({name: old_value}))

    # PrefixedDataset wrapping a dict
    _write_and_check(PrefixedDataset({}, prefix))
    _write_and_check(PrefixedDataset({name: old_value}, prefix))

    # PrefixedDataset wrapping an OffloadedWeightsLoader
    _write_and_check(PrefixedDataset(OffloadedWeightsLoader({}), prefix))
    _write_and_check(PrefixedDataset(OffloadedWeightsLoader({name: old_value}), prefix))

0 comments on commit 870095e

Please sign in to comment.