Skip to content

Commit

Permalink
save_torch_state_dict replaces shard_checkpoint
Browse files Browse the repository at this point in the history
  • Loading branch information
dtlzhuangz committed Dec 3, 2024
1 parent 81e0b14 commit 358cb8c
Showing 1 changed file with 8 additions and 22 deletions.
30 changes: 8 additions & 22 deletions python/eetq/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,7 @@
from typing import List, Union, Dict
from safetensors.torch import save_file
from typing_extensions import Doc, Annotated
from huggingface_hub import snapshot_download
from transformers.modeling_utils import shard_checkpoint
from huggingface_hub import snapshot_download, save_torch_state_dict
from transformers import (
AutoConfig,
PreTrainedModel,
Expand Down Expand Up @@ -138,26 +137,13 @@ def forward(self, x):

# model_name has no extension, add it when saving state_dict
model_name = "model.safetensors" if safetensors else "pytorch_model.bin"

# shard checkpoint into chunks (10GB default)
shards, index = shard_checkpoint(
self.model.state_dict(), max_shard_size=shard_size, weights_name=model_name
)

for shard_file, shard in shards.items():
if safetensors:
# safetensors must be in the same memory, so we duplicate and use contiguous memory
shard = {k: v.clone().contiguous() for k, v in shard.items()}
save_file(
shard, os.path.join(save_dir, shard_file), metadata={"format": "pt"}
)
else:
torch.save(shard, os.path.join(save_dir, shard_file))

# save shard index
if index is not None:
with open(f"{save_dir}/{model_name}.index.json", "w+") as file:
file.write(json.dumps(index, indent=4))
save_torch_state_dict(
state_dict=self.model.state_dict(),
save_directory=save_dir,
max_shard_size=shard_size,
safe_serialization=safetensors,
force_contiguous=True,
)


@classmethod
Expand Down

0 comments on commit 358cb8c

Please sign in to comment.