Skip to content

Commit

Permalink
sharded files fix
Browse files Browse the repository at this point in the history
  • Loading branch information
Benjamin committed Aug 20, 2024
1 parent 3546d74 commit bc0caac
Showing 1 changed file with 4 additions and 0 deletions.
4 changes: 4 additions & 0 deletions src/compressed_tensors/utils/safetensors_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,10 @@ def get_safetensors_folder(
model will be searched for in the default TRANSFORMERS_CACHE
:return: local folder containing model data
"""
if isinstance(pretrained_model_name_or_path, list):
# assume sharded files, referencing first file is sufficient
pretrained_model_name_or_path = pretrained_model_name_or_path[0]

if os.path.exists(pretrained_model_name_or_path):
# argument is a path to a local folder
return os.path.abspath(pretrained_model_name_or_path)
Expand Down

0 comments on commit bc0caac

Please sign in to comment.