Skip to content

Commit

Permalink
rename
Browse files Browse the repository at this point in the history
  • Loading branch information
kylesayrs committed Dec 6, 2024
1 parent e7e1d81 commit 9af736f
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 8 deletions.
8 changes: 4 additions & 4 deletions src/compressed_tensors/utils/offload.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
"update_prefix_dict",
"update_parameter_data",
"register_offload_parameter",
"update_offload_data",
"update_offload_parameter",
"delete_offload_parameter",
"has_offloaded_params",
"disable_hf_hook",
Expand Down Expand Up @@ -131,7 +131,7 @@ def update_parameter_data(
:param new_param_data: tensor to update parameter with
:param param_name: name of module parameter to update
"""
update_offload_data(module, param_name, new_param_data)
update_offload_parameter(module, param_name, new_param_data)


""" Candidates for Upstreaming """
Expand All @@ -151,15 +151,15 @@ def register_offload_parameter(
"""
if has_offloaded_params(module):
module.register_parameter(name, parameter)
update_offload_data(module, name, parameter.data)
update_offload_parameter(module, name, parameter.data)
set_module_tensor_to_device(module, name, "meta")
else:
device = next(module.parameters()).device
parameter = parameter.to(device)
module.register_parameter(name, parameter)


def update_offload_data(
def update_offload_parameter(
module: torch.nn.Module,
name: str,
data: Optional[torch.Tensor],
Expand Down
8 changes: 4 additions & 4 deletions tests/test_utils/test_offload.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
disable_hf_hook,
has_offloaded_params,
register_offload_parameter,
update_offload_data,
update_offload_parameter,
)
from tests.testing_utils import requires_accelerate

Expand Down Expand Up @@ -82,20 +82,20 @@ def test_register_offload_parameter():


@requires_accelerate()
def test_update_offload_data():
def test_update_offload_parameter():
from accelerate.hooks import attach_align_device_hook

module = ExampleModule()
param_a = torch.nn.Parameter(torch.tensor(1.0))
param_b = torch.nn.Parameter(torch.tensor(2.0))

# can update modules which are not offloaded
update_offload_data(module, "a", param_a)
update_offload_parameter(module, "a", param_a)
assert module.a == param_a

# can update modules which are offloaded
attach_align_device_hook(module, offload=True, weights_map=module.state_dict())
update_offload_data(module, "b", param_b)
update_offload_parameter(module, "b", param_b)
assert module.b.device == torch.device("meta")
assert module._hf_hook.weights_map["b"] == param_b.data

Expand Down

0 comments on commit 9af736f

Please sign in to comment.