From 66b7f091e5dcb07e132418556dabc2166fff170d Mon Sep 17 00:00:00 2001 From: Ankur-singh Date: Mon, 13 Jan 2025 05:38:19 -0800 Subject: [PATCH] Update checkpoint handling in QuantizationRecipe to use checkpointer.save_checkpoint --- recipes/quantize.py | 24 +++++++----------------- 1 file changed, 7 insertions(+), 17 deletions(-) diff --git a/recipes/quantize.py b/recipes/quantize.py index bb28d45b87..130a675c6d 100644 --- a/recipes/quantize.py +++ b/recipes/quantize.py @@ -3,10 +3,8 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -import os import sys import time -from pathlib import Path from typing import Any, Dict import torch @@ -53,6 +51,11 @@ def __init__(self, cfg: DictConfig) -> None: training.set_seed(seed=cfg.seed) def load_checkpoint(self, checkpointer_cfg: DictConfig) -> Dict[str, Any]: + logger.info( + "Setting safe_serialization to False. TorchAO quantization is compatible " + "only with HuggingFace's non-safetensor serialization and deserialization." + ) + checkpointer_cfg.safe_serialization = False self._checkpointer = config.instantiate(checkpointer_cfg) checkpoint_dict = self._checkpointer.load_checkpoint() return checkpoint_dict @@ -95,21 +98,8 @@ def quantize(self, cfg: DictConfig): logger.info(f"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB") def save_checkpoint(self, cfg: DictConfig): - ckpt_dict = self._model.state_dict() - file_name = cfg.checkpointer.checkpoint_files[0].split(".")[0] - - output_dir = Path(cfg.checkpointer.output_dir) - output_dir.mkdir(exist_ok=True) - checkpoint_file = Path.joinpath( - output_dir, f"{file_name}-{self._quantization_mode}".rstrip("-qat") - ).with_suffix(".pt") - - torch.save(ckpt_dict, checkpoint_file) - logger.info( - "Model checkpoint of size " - f"{os.path.getsize(checkpoint_file) / 1024**3:.2f} GiB " - f"saved to {checkpoint_file}" - ) + ckpt_dict = {training.MODEL_KEY: self._model.state_dict()} + self._checkpointer.save_checkpoint(ckpt_dict, epoch=0) @config.parse