From 9af736f533534f0590bc312e0981e8e1a1ab97cd Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Fri, 6 Dec 2024 07:15:04 +0000 Subject: [PATCH] rename --- src/compressed_tensors/utils/offload.py | 8 ++++---- tests/test_utils/test_offload.py | 8 ++++---- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/src/compressed_tensors/utils/offload.py b/src/compressed_tensors/utils/offload.py index 7b7cc864..ee59933b 100644 --- a/src/compressed_tensors/utils/offload.py +++ b/src/compressed_tensors/utils/offload.py @@ -43,7 +43,7 @@ "update_prefix_dict", "update_parameter_data", "register_offload_parameter", - "update_offload_data", + "update_offload_parameter", "delete_offload_parameter", "has_offloaded_params", "disable_hf_hook", @@ -131,7 +131,7 @@ def update_parameter_data( :param new_param_data: tensor to update parameter with :param param_name: name of module parameter to update """ - update_offload_data(module, param_name, new_param_data) + update_offload_parameter(module, param_name, new_param_data) """ Candidates for Upstreaming """ @@ -151,7 +151,7 @@ def register_offload_parameter( """ if has_offloaded_params(module): module.register_parameter(name, parameter) - update_offload_data(module, name, parameter.data) + update_offload_parameter(module, name, parameter.data) set_module_tensor_to_device(module, name, "meta") else: device = next(module.parameters()).device @@ -159,7 +159,7 @@ def register_offload_parameter( module.register_parameter(name, parameter) -def update_offload_data( +def update_offload_parameter( module: torch.nn.Module, name: str, data: Optional[torch.Tensor], diff --git a/tests/test_utils/test_offload.py b/tests/test_utils/test_offload.py index c127dd98..80cb55f7 100644 --- a/tests/test_utils/test_offload.py +++ b/tests/test_utils/test_offload.py @@ -18,7 +18,7 @@ disable_hf_hook, has_offloaded_params, register_offload_parameter, - update_offload_data, + update_offload_parameter, ) from tests.testing_utils import requires_accelerate @@ -82,7 +82,7 @@ def test_register_offload_parameter(): @requires_accelerate() -def test_update_offload_data(): +def test_update_offload_parameter(): from accelerate.hooks import attach_align_device_hook module = ExampleModule() @@ -90,12 +90,12 @@ def test_update_offload_data(): param_b = torch.nn.Parameter(torch.tensor(2.0)) # can update modules which are not offloaded - update_offload_data(module, "a", param_a) + update_offload_parameter(module, "a", param_a) assert module.a == param_a # can update modules which are offloaded attach_align_device_hook(module, offload=True, weights_map=module.state_dict()) - update_offload_data(module, "b", param_b) + update_offload_parameter(module, "b", param_b) assert module.b.device == torch.device("meta") assert module._hf_hook.weights_map["b"] == param_b.data