From 98798897d92d04911806636f4e6149b511a18d53 Mon Sep 17 00:00:00 2001 From: maxreciprocate <56548574+maxreciprocate@users.noreply.github.com> Date: Mon, 4 Dec 2023 19:24:16 +0200 Subject: [PATCH] fix(modeling_base): loading sharded safetensors --- tests/test_peft.py | 2 +- trlx/models/modeling_base.py | 71 ++++++++++++++++++------------------ 2 files changed, 37 insertions(+), 36 deletions(-) diff --git a/tests/test_peft.py b/tests/test_peft.py index ffe6f8bcb..37dd164af 100644 --- a/tests/test_peft.py +++ b/tests/test_peft.py @@ -10,7 +10,7 @@ import torch import transformers from peft import get_peft_config, get_peft_model -from peft.utils.config import PeftType, TaskType +from peft.utils import PeftType, TaskType from transformers import AutoConfig, AutoModelForCausalLM from trlx.data.configs import TokenizerConfig diff --git a/trlx/models/modeling_base.py b/trlx/models/modeling_base.py index 26aa7a876..330e28495 100644 --- a/trlx/models/modeling_base.py +++ b/trlx/models/modeling_base.py @@ -24,7 +24,7 @@ import torch import torch.nn as nn import transformers -from huggingface_hub import hf_hub_download +import huggingface_hub import trlx.utils.logging as logging from trlx.utils import is_peft_available @@ -155,8 +155,10 @@ def from_pretrained( # noqa: max-complexity call (e.g. `transformers.AutoModelForCausalLM.from_pretrained`) and the specific instance of the wrapped model. + NOTE: You must pass in arguments specific to the wrapped model as keyword arguments. 
""" + if kwargs is not None: peft_from_pretrained_kwargs = kwargs.pop("peft_from_pretrained_kwargs", {}) peft_int8_kwargs = kwargs.pop("peft_int8_kwargs", {}) @@ -273,42 +275,41 @@ def from_pretrained( # noqa: max-complexity model = cls(base_model, **wrapped_model_kwargs) if isinstance(pretrained_model_name_or_path, str): - filename = os.path.join(pretrained_model_name_or_path, "pytorch_model.bin") - sharded_index_filename = os.path.join(pretrained_model_name_or_path, "pytorch_model.bin.index.json") - is_sharded = False - - if not os.path.exists(filename): + if not os.path.exists(pretrained_model_name_or_path): try: - filename = hf_hub_download(pretrained_model_name_or_path, "pytorch_model.bin", revision=revision) - # Sharded - except Exception: - if os.path.exists(sharded_index_filename): - index_file_name = sharded_index_filename - else: - index_file_name = hf_hub_download( - pretrained_model_name_or_path, - "pytorch_model.bin.index.json", - revision=revision, - ) - with open(index_file_name, "r") as f: - index = json.load(f) - - # Load all weights from the shards - files_to_download = set(index["weight_map"].values()) - is_sharded = True - - if is_sharded: - # Merge each shard into a state dict - # TODO: Optimize this to avoid wasting RAM - state_dict = {} - for shard_file in files_to_download: - filename = os.path.join(pretrained_model_name_or_path, shard_file) - # Download if shard file doesn't exist locally - if not os.path.exists(filename): - filename = hf_hub_download(pretrained_model_name_or_path, shard_file, revision=revision) - state_dict.update(torch.load(filename, map_location="cpu")) + pretrained_model_name_or_path = huggingface_hub.snapshot_download(pretrained_model_name_or_path, revision=revision) + except huggingface_hub.utils._errors.RepositoryNotFoundError: + raise ValueError("Invalid `pretrained_model_name_or_path`. 
It should be a local path or a repository name.") + + sharded_index_filename = os.path.join(pretrained_model_name_or_path, "pytorch_model.bin.index.json") + sharded_safetensors_index_filename = os.path.join(pretrained_model_name_or_path, "model.safetensors.index.json") + + if os.path.exists(sharded_index_filename): + with open(sharded_index_filename, "r") as f: + index = json.load(f) + shards = set(index["weight_map"].values()) + elif os.path.exists(sharded_safetensors_index_filename): + with open(sharded_safetensors_index_filename, "r") as f: + index = json.load(f) + shards = set(index["weight_map"].values()) else: - state_dict = torch.load(filename, map_location="cpu") + shard_filename = os.path.join(pretrained_model_name_or_path, "pytorch_model.bin") + shard_safetensors_filename = os.path.join(pretrained_model_name_or_path, "model.safetensors") + if os.path.exists(shard_filename): + shards = [shard_filename] + elif os.path.exists(shard_safetensors_filename): + shards = [shard_safetensors_filename] + + state_dict = {} + for shard in shards: + if shard.endswith(".safetensors"): + import safetensors + with safetensors.safe_open(os.path.join(pretrained_model_name_or_path, shard), framework="pt") as f: + for k in f.keys(): + state_dict[k] = f.get_tensor(k) + else: + state_dict.update(torch.load(shard, map_location="cpu")) + else: state_dict = pretrained_model_name_or_path.state_dict()