Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

additional fixes for HFQuantizer compatibility #136

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/compressed_tensors/compressors/model_compressor.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,10 @@ def parse_quantization_config(
if hasattr(compression_config, QUANTIZATION_CONFIG_NAME):
# for loaded HFQuantizer config
return getattr(compression_config, QUANTIZATION_CONFIG_NAME)
elif isinstance(compression_config, dict) and (
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This change has already been merged.

QUANTIZATION_CONFIG_NAME in compression_config
):
return compression_config[QUANTIZATION_CONFIG_NAME]

if QUANTIZATION_CONFIG_NAME in compression_config:
# for loaded HFQuantizer config from dict
Expand Down
16 changes: 14 additions & 2 deletions src/compressed_tensors/utils/offload.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.

import torch
from torch.nn import Module
from torch.nn import Module, Parameter


__all__ = [
Expand Down Expand Up @@ -106,7 +106,19 @@ def update_parameter_data(
raise ValueError("Attempted to update uninitialized parameter")

dtype = parameter.dtype
parameter.data = new_param_data.to(device).to(dtype)
try:
parameter.data = new_param_data.to(device).to(dtype)
except RuntimeError:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is better handled by #193

# a RuntimeError may occur when overwriting the data of a parameter on the
# meta device; in that case, replace the parameter on the module directly
setattr(
module,
param_name,
Parameter(
data=new_param_data.to(device).to(dtype),
requires_grad=parameter.requires_grad,
),
)

if offloaded:
prefix_dict = module._hf_hook.weights_map.dataset
Expand Down
4 changes: 4 additions & 0 deletions src/compressed_tensors/utils/safetensors_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,10 @@ def get_safetensors_folder(
model will be searched for in the default TRANSFORMERS_CACHE
:return: local folder containing model data
"""
if isinstance(pretrained_model_name_or_path, list):
# assume sharded files, referencing first file is sufficient
pretrained_model_name_or_path = pretrained_model_name_or_path[0]

if os.path.exists(pretrained_model_name_or_path):
# argument is a path to a local folder
return os.path.abspath(pretrained_model_name_or_path)
Expand Down
Loading