Skip to content

Commit

Permalink
Update checkpoint handling in QuantizationRecipe to use checkpointer.…
Browse files Browse the repository at this point in the history
…save_checkpoint
  • Loading branch information
Ankur-singh committed Jan 13, 2025
1 parent e79ab8b commit 66b7f09
Showing 1 changed file with 7 additions and 17 deletions.
24 changes: 7 additions & 17 deletions recipes/quantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,8 @@
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import os
import sys
import time
from pathlib import Path
from typing import Any, Dict

import torch
Expand Down Expand Up @@ -53,6 +51,11 @@ def __init__(self, cfg: DictConfig) -> None:
training.set_seed(seed=cfg.seed)

def load_checkpoint(self, checkpointer_cfg: DictConfig) -> Dict[str, Any]:
logger.info(
"Setting safe_serialization to False. TorchAO quantization is compatible "
"only with HuggingFace's non-safetensor serialization and deserialization."
)
checkpointer_cfg.safe_serialization = False
self._checkpointer = config.instantiate(checkpointer_cfg)
checkpoint_dict = self._checkpointer.load_checkpoint()
return checkpoint_dict
Expand Down Expand Up @@ -95,21 +98,8 @@ def quantize(self, cfg: DictConfig):
logger.info(f"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB")

def save_checkpoint(self, cfg: DictConfig):
ckpt_dict = self._model.state_dict()
file_name = cfg.checkpointer.checkpoint_files[0].split(".")[0]

output_dir = Path(cfg.checkpointer.output_dir)
output_dir.mkdir(exist_ok=True)
checkpoint_file = Path.joinpath(
output_dir, f"{file_name}-{self._quantization_mode}".rstrip("-qat")
).with_suffix(".pt")

torch.save(ckpt_dict, checkpoint_file)
logger.info(
"Model checkpoint of size "
f"{os.path.getsize(checkpoint_file) / 1024**3:.2f} GiB "
f"saved to {checkpoint_file}"
)
ckpt_dict = {training.MODEL_KEY: self._model.state_dict()}
self._checkpointer.save_checkpoint(ckpt_dict, epoch=0)


@config.parse
Expand Down

0 comments on commit 66b7f09

Please sign in to comment.