diff --git a/src/compressed_tensors/utils/helpers.py b/src/compressed_tensors/utils/helpers.py
index fe034126..910436eb 100644
--- a/src/compressed_tensors/utils/helpers.py
+++ b/src/compressed_tensors/utils/helpers.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import warnings
+from functools import wraps
 from typing import Any, Callable, Dict, Optional
 
 import torch
@@ -174,6 +175,7 @@ def decorator(func: Callable[[Any], Any]):
             if future_name is not None:
                 message += f". Please use {future_name} instead."
 
+        @wraps(func)
         def wrapped(*args, **kwargs):
             warnings.warn(message, DeprecationWarning, stacklevel=2)
             return func(*args, **kwargs)
diff --git a/src/compressed_tensors/utils/offload.py b/src/compressed_tensors/utils/offload.py
index 0df00ec2..0d7b0bbe 100644
--- a/src/compressed_tensors/utils/offload.py
+++ b/src/compressed_tensors/utils/offload.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from functools import wraps
 from typing import Any, Callable, Optional
 
 import torch
@@ -47,7 +48,12 @@
 
     def decorator(func: Callable[[Any], Any]):
         if not _has_accelerate:
-            return lambda *args, **kwargs: fallback
+
+            @wraps(func)
+            def fallback_fn(*args, **kwargs):
+                return fallback
+
+            return fallback_fn
 
         return func
 
@@ -193,7 +199,7 @@ def update_offload_data(
 
 def delete_offload_parameter(module: torch.nn.Module, name: str):
     """
-    Delete a module from a module which may be offloaded
+    Delete a parameter from a module which may be offloaded
 
     :param module: maybe offloaded module
     :param name: name of parameter being deleted