diff --git a/src/compressed_tensors/utils/safetensors_load.py b/src/compressed_tensors/utils/safetensors_load.py index 9cdac782..b31c6f77 100644 --- a/src/compressed_tensors/utils/safetensors_load.py +++ b/src/compressed_tensors/utils/safetensors_load.py @@ -47,6 +47,10 @@ def get_safetensors_folder( model will be searched for in the default TRANSFORMERS_CACHE :return: local folder containing model data """ + if isinstance(pretrained_model_name_or_path, list): + # assume sharded files, referencing first file is sufficient + pretrained_model_name_or_path = pretrained_model_name_or_path[0] + if os.path.exists(pretrained_model_name_or_path): # argument is a path to a local folder return os.path.abspath(pretrained_model_name_or_path)