diff --git a/src/compressed_tensors/utils/converters/converters.py b/src/compressed_tensors/utils/converters/converters.py index 2c49fc21..57898e88 100644 --- a/src/compressed_tensors/utils/converters/converters.py +++ b/src/compressed_tensors/utils/converters/converters.py @@ -18,7 +18,7 @@ from abc import ABC, abstractmethod from enum import Enum from pathlib import Path -from typing import Callable, Dict, Iterable, Union +from typing import Callable, Dict, Iterable, Iterator, Tuple, Union import torch from compressed_tensors.registry.registry import RegistryMixin @@ -77,22 +77,34 @@ def convert_from_safetensors(cls, filepath: str, save_dir: str = None) -> str: save_dir_.mkdir(exist_ok=True, parents=True) metadata = {"format": "pt", "source": "Created by SparseML"} - # transform and save the state_dict if filepath_.is_dir(): for file in filepath_.glob("*.safetensors"): _LOGGER.info(f"Loading file: {file}") - state_dict: StateDictType = load_safetensors_state_dict(file) - new_state_dict = cls.translate(state_dict=state_dict) - save_file( - new_state_dict, filename=save_dir_ / file.name, metadata=metadata + new_state_dict = {} + state_dict: Iterable[StateDictType] = load_safetensors_state_dict( + file, by_layers=True ) + for layer_state_dict in state_dict: + new_state_dict.update(cls.translate(state_dict=layer_state_dict)) + + if new_state_dict: + save_file( + new_state_dict, + filename=save_dir_ / file.name, + metadata=metadata, + ) _copy_non_safetensor_files_(filepath_, save_dir_) _update_quantization_config(filepath_, save_dir_) elif filepath_.is_file(): - state_dict: StateDictType = load_safetensors_state_dict(filepath) - new_state_dict = cls.translate(state_dict=state_dict) + new_state_dict = {} + state_dict: Iterable[StateDictType] = load_safetensors_state_dict( + file, by_layers=True + ) + for layer_state_dict in state_dict: + new_state_dict.update(cls.translate(state_dict=layer_state_dict)) + save_file( new_state_dict, save_path=save_dir_ / filepath_.name, metadata=metadata ) @@ -177,12 +189,32 @@ def _update_quantization_config(source_dir: Path, dest_dir: Path): config.save_pretrained(dest_dir) -def load_safetensors_state_dict(file_path: str) -> Dict[str, torch.Tensor]: +def load_safetensors_state_dict( + file_path: str, by_layers: bool = True +) -> Iterator[Tuple[str, Dict[str, torch.Tensor]]]: """ Load a safetensors file from disk :param file_path: path to the safetensors file - :return: dictionary of safetensors data + :param by_layers: if True, return a iterator with dictionary of safetensors + data by layers + :return: Iterator of dictionary of safetensors data or iterator of + dictionaries by layers """ with safe_open(file_path, framework="pt", device="cpu") as f: - return {key: f.get_tensor(key) for key in f.keys()} + if by_layers: + current_layer = None + layer_data = {} + for key in sorted(f.keys()): + layer_name, param_name = key.split(".", 1) + if current_layer is None: + current_layer = layer_name + elif layer_name != current_layer: + yield current_layer, layer_data + current_layer = layer_name + layer_data = {} + layer_data[key] = f.get_tensor(key) + if layer_data: + yield layer_data + else: + yield {key: f.get_tensor(key) for key in f.keys()}