diff --git a/src/compressed_tensors/utils/converters/converters.py b/src/compressed_tensors/utils/converters/converters.py index ebe106a5..27349a40 100644 --- a/src/compressed_tensors/utils/converters/converters.py +++ b/src/compressed_tensors/utils/converters/converters.py @@ -28,14 +28,16 @@ ) from safetensors import safe_open from safetensors.torch import save_file +from tqdm import tqdm StateDictType = Union[Dict[str, torch.Tensor], str, Path] TransformationType = Callable[[Dict[str, torch.Tensor]], Dict[str, torch.Tensor]] + _LOGGER: logging.Logger = logging.getLogger(__name__) -class ConverterNames(Enum): +class ConverterNames(str, Enum): EXLLAMA_TO_COMPRESSED_TENSOR = "exllama_to_compressed_tensor" @@ -73,7 +75,7 @@ def convert_from_safetensors( filepath_: Path = Path(filepath) if not save_dir: - save_dir = "compressed_tensors_model" + save_dir: str = "compressed_tensors_model" save_dir_: Path = Path(save_dir) save_dir_.mkdir(exist_ok=True, parents=True) @@ -81,13 +83,19 @@ def convert_from_safetensors( metadata = {"format": "pt", "source": "Created by SparseML"} # transform and save the state_dict if filepath_.is_dir(): + tqdm.write(f"Converting directory: {filepath}") + tqdm.write(f"Found: {len(list(filepath_.glob('*.safetensors')))} .safetensors files") for file in filepath_.glob("*.safetensors"): - _LOGGER.info(f"Loading file: {file}") + tqdm.write(f"Converting file: {file.name}") new_state_dict = {} state_dict: Iterable[StateDictType] = load_safetensors_state_dict( file, by_layers=True ) - for layer_state_dict in state_dict: + layer_progress_bar = tqdm(state_dict, total=layer_count(file), desc="Converting layers") + for layer_state_dict in layer_progress_bar: + layer_name = list(layer_state_dict.keys())[0][:len("model.layers.0")] + layer_progress_bar.set_description(f"Converting layer {layer_name}") + layer_progress_bar.update() new_state_dict.update( cls.translate(state_dict=layer_state_dict, **kwargs) ) @@ -126,7 +134,7 @@ def transformations(cls) -> Iterable[TransformationType]: raise NotImplementedError() -@BaseConverter.register(name=ConverterNames.EXLLAMA_TO_COMPRESSED_TENSOR.value) +@BaseConverter.register(name=ConverterNames.EXLLAMA_TO_COMPRESSED_TENSOR) class ExllamaToCompressedTensorConverter(BaseConverter): """ A converter that applies transformations to the state_dict of a autogptq @@ -183,16 +191,49 @@ def _update_quantization_config(source_dir: Path, dest_dir: Path): :param source_dir: The directory containing the original config.json file :param dest_dir: The directory to save the updated config.json file """ - from sparseml.transformers import SparseAutoConfig + from transformers import AutoConfig - config = SparseAutoConfig.from_pretrained(source_dir) + config = AutoConfig.from_pretrained(source_dir) if hasattr(config, "quantization_config"): _LOGGER.info("Updating quantization config...") - delattr(config, "quantization_config") + quantization_config = config.quantization_config + config.quantization_config = _convert_to_compressed_tensors_config(quantization_config) config.save_pretrained(dest_dir) +def _convert_to_compressed_tensors_config(quantization_config): + """ + Converts the quantization_config attribute from a config.json file + to a dictionary + + :param quantization_config: The quantization_config attribute from a config.json file + :return: The quantization_config as a dictionary + """ + compressed_tensor_config = ... + return compressed_tensor_config + +def layer_count(file_path: str) -> int: + """ + Count the number of layers in a safetensors file + + :param file_path: path to the safetensors file + :return: number of layers in the safetensors file + """ + with safe_open(file_path, framework="pt", device="cpu") as f: + keys = sorted(f.keys()) + + last_layer_name = None + layer_count = 0 + for key in keys: + layer_name = key[:len("model.layers.0")] + if layer_name != last_layer_name: + last_layer_name = layer_name + layer_count += 1 + return layer_count + + + def load_safetensors_state_dict( file_path: str, by_layers: bool = True ) -> Iterator[Tuple[str, Dict[str, torch.Tensor]]]: @@ -201,7 +242,7 @@ def load_safetensors_state_dict( :param file_path: path to the safetensors file :param by_layers: if True, return a iterator with dictionary of safetensors - data by layers + data by layers. Default is True :return: Iterator of dictionary of safetensors data or iterator of dictionaries by layers """ @@ -210,11 +251,11 @@ def load_safetensors_state_dict( current_layer = None layer_data = {} for key in sorted(f.keys()): - layer_name, param_name = key.split(".", 1) + layer_name = key[:len("model.layers.0")] if current_layer is None: current_layer = layer_name elif layer_name != current_layer: - yield current_layer, layer_data + yield layer_data current_layer = layer_name layer_data = {} layer_data[key] = f.get_tensor(key) diff --git a/src/compressed_tensors/utils/converters/main.py b/src/compressed_tensors/utils/converters/main.py index 3089849c..1f04fc41 100644 --- a/src/compressed_tensors/utils/converters/main.py +++ b/src/compressed_tensors/utils/converters/main.py @@ -15,12 +15,11 @@ from compressed_tensors.utils.converters.converters import BaseConverter, ConverterNames - __all__ = ["convert_autogptq_checkpoint"] def convert_autogptq_checkpoint( - old_checkpoint_path, new_checkpoint_path, **kwargs + old_checkpoint_path, new_checkpoint_path ,**kwargs ) -> str: """ Convert an autogptq checkpoint to a compressed tensor checkpoint diff --git a/src/compressed_tensors/utils/converters/transformations.py b/src/compressed_tensors/utils/converters/transformations.py index 86371c8e..fae69137 100644 --- a/src/compressed_tensors/utils/converters/transformations.py +++ b/src/compressed_tensors/utils/converters/transformations.py @@ -29,9 +29,9 @@ def _log_transformation(func): @functools.wraps(func) def wrapper(*args, **kwargs): - _LOGGER.info("Applying transformation: %s", func.__name__.upper()) + _LOGGER.debug("Applying transformation: %s", func.__name__.upper()) return_value = func(*args, **kwargs) - _LOGGER.info("Transformation: %s complete", func.__name__.upper()) + _LOGGER.debug("Transformation: %s complete", func.__name__.upper()) return return_value return wrapper