From b3de9b437776c4015a1d4f0d4a2a90bcaa955ce4 Mon Sep 17 00:00:00 2001
From: Rahul Tuli
Date: Thu, 13 Jun 2024 13:13:04 +0000
Subject: [PATCH 1/6] fix some typos

---
 src/compressed_tensors/quantization/lifecycle/forward.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/compressed_tensors/quantization/lifecycle/forward.py b/src/compressed_tensors/quantization/lifecycle/forward.py
index 263796eb..739dac0c 100644
--- a/src/compressed_tensors/quantization/lifecycle/forward.py
+++ b/src/compressed_tensors/quantization/lifecycle/forward.py
@@ -190,7 +190,7 @@ def _process_quantization(
         if columns >= group_size:
             if columns % group_size != 0:
                 raise ValueError(
-                    "tesnor column shape must be divisble "
+                    "tensor column shape must be divisible "
                     f"by the given group_size {group_size}"
                 )
             for i in range(ceil(columns / group_size)):

From f0da4c8afffd4ee9bcdf9d351fcf2a95940c74a0 Mon Sep 17 00:00:00 2001
From: Rahul Tuli
Date: Thu, 13 Jun 2024 14:03:42 +0000
Subject: [PATCH 2/6] Add Converters Scaffolding

---
 src/compressed_tensors/utils/__init__.py     |   1 +
 .../utils/converters/__init__.py             |  17 ++
 .../utils/converters/converters.py           | 188 +++++++++++++++
 .../utils/converters/main.py                 |  37 +++
 .../utils/converters/transformations.py      | 224 ++++++++++++++++++
 tests/test_utils/converters/__init__.py      |  13 +
 tests/test_utils/converters/test_imports.py  |  57 +++++
 7 files changed, 537 insertions(+)
 create mode 100644 src/compressed_tensors/utils/converters/__init__.py
 create mode 100644 src/compressed_tensors/utils/converters/converters.py
 create mode 100644 src/compressed_tensors/utils/converters/main.py
 create mode 100644 src/compressed_tensors/utils/converters/transformations.py
 create mode 100644 tests/test_utils/converters/__init__.py
 create mode 100644 tests/test_utils/converters/test_imports.py

diff --git a/src/compressed_tensors/utils/__init__.py b/src/compressed_tensors/utils/__init__.py
index 5bc0fec2..f0b267b7 100644
--- a/src/compressed_tensors/utils/__init__.py
+++ b/src/compressed_tensors/utils/__init__.py
@@ -13,4 +13,5 @@
 # limitations under the License.
 
 # flake8: noqa
+from .converters import *
 from .safetensors_load import *
diff --git a/src/compressed_tensors/utils/converters/__init__.py b/src/compressed_tensors/utils/converters/__init__.py
new file mode 100644
index 00000000..540036b0
--- /dev/null
+++ b/src/compressed_tensors/utils/converters/__init__.py
@@ -0,0 +1,17 @@
+# flake8: noqa
+
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .main import *
diff --git a/src/compressed_tensors/utils/converters/converters.py b/src/compressed_tensors/utils/converters/converters.py
new file mode 100644
index 00000000..2c49fc21
--- /dev/null
+++ b/src/compressed_tensors/utils/converters/converters.py
@@ -0,0 +1,188 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import copy
+import logging
+import shutil
+from abc import ABC, abstractmethod
+from enum import Enum
+from pathlib import Path
+from typing import Callable, Dict, Iterable, Union
+
+import torch
+from compressed_tensors.registry.registry import RegistryMixin
+from compressed_tensors.utils.converters.transformations import (
+    transform_autogptq_weights_and_reshape_tensors,
+    transform_exllama_names,
+)
+from safetensors import safe_open
+from safetensors.torch import save_file
+
+
+StateDictType = Union[Dict[str, torch.Tensor], str, Path]
+TransformationType = Callable[[Dict[str, torch.Tensor]], Dict[str, torch.Tensor]]
+_LOGGER: logging.Logger = logging.getLogger(__name__)
+
+
+class ConverterNames(Enum):
+    EXLLAMA_TO_COMPRESSED_TENSOR = "exllama_to_compressed_tensor"
+
+
+class BaseConverter(ABC, RegistryMixin):
+    @classmethod
+    def translate(cls, state_dict: StateDictType, **kwargs) -> StateDictType:
+        """
+        Applies transformations to the state_dict
+
+        :param state_dict: The state_dict to apply transformations to
+        :param kwargs: Additional arguments to pass to the transformations
+        :return: The transformed state_dict
+        """
+        _LOGGER.info("Applying transformations...")
+        new_state_dict = copy.copy(state_dict)
+        for transformation in cls.transformations():
+            new_state_dict = transformation(new_state_dict, **kwargs)
+        return new_state_dict
+
+    @classmethod
+    def convert_from_safetensors(cls, filepath: str, save_dir: str = None) -> str:
+        """
+        Convert a .safetensors file or directory of .safetensors files, applying
+        transformations to the state_dict and saving the new state_dict to a new
+        directory
+
+        :param filepath: The file path to the .safetensors file or directory
+            containing .safetensors files to convert
+        :param save_dir: The directory to save the converted state_dict to
+        :return: The directory where the converted state_dict was saved
+        """
+        _validate_safetensors_file_path(filepath)
+
+        filepath_: Path = Path(filepath)
+        if not save_dir:
+            save_dir = "compressed_tensors_model"
+
+        save_dir_: Path = Path(save_dir)
+        save_dir_.mkdir(exist_ok=True, parents=True)
+
+        metadata = {"format": "pt", "source": "Created by SparseML"}
+
+        # transform and save the state_dict
+        if filepath_.is_dir():
+            for file in filepath_.glob("*.safetensors"):
+                _LOGGER.info(f"Loading file: {file}")
+                state_dict: StateDictType = load_safetensors_state_dict(file)
+                new_state_dict = cls.translate(state_dict=state_dict)
+                save_file(
+                    new_state_dict, filename=save_dir_ / file.name, metadata=metadata
+                )
+            _copy_non_safetensor_files_(filepath_, save_dir_)
+            _update_quantization_config(filepath_, save_dir_)
+
+        elif filepath_.is_file():
+            state_dict: StateDictType = load_safetensors_state_dict(filepath)
+            new_state_dict = cls.translate(state_dict=state_dict)
+            save_file(
+                new_state_dict, filename=save_dir_ / filepath_.name, metadata=metadata
+            )
+
+        return str(save_dir_)
+
+    @classmethod
+    @abstractmethod
+    def transformations(cls) -> 
Iterable[TransformationType]: + """ + Returns an iterable of transformations that are applied in the converter, + each transformation should be a callable that takes a state_dict and returns + a transformed state_dict + """ + raise NotImplementedError() + + +@BaseConverter.register(name=ConverterNames.EXLLAMA_TO_COMPRESSED_TENSOR.value) +class ExllamaToCompressedTensorConverter(BaseConverter): + """ + A converter that applies transformations to the state_dict of a autogptq + quantized model to convert it to a compressed tensor model, which can be + loaded by the SparseAutoModel classes + """ + + @classmethod + def transformations(cls): + return (transform_autogptq_weights_and_reshape_tensors, transform_exllama_names) + + +def _validate_safetensors_file_path(filepath: str): + """ + Given a file path, it is valid if: + - The file exists + - The file is either a single .safetensors file or a + directory containing .safetensors files + + :param filepath: A string file path to validate + """ + + filepath_: Path = Path(filepath) + + if not filepath_.exists(): + raise FileNotFoundError(f"File not found: {filepath}") + + if filepath_.is_dir() and not any(filepath_.glob("*.safetensors")): + raise FileNotFoundError(f"No .safetensors files found in directory: {filepath}") + + if filepath_.is_file() and not filepath_.suffix == ".safetensors": + raise ValueError(f"File must be a .safetensors file: {filepath}") + + +def _copy_non_safetensor_files_(source_dir: Path, dest_dir: Path): + """ + A helper function to copy all auxillary files in a directory that are + not .safetensors files, for example (config.json, recipe.yaml, ...) + + :param source_dir: The directory to copy files from + :param dest_dir: The directory to copy files to + """ + for file in source_dir.glob("*"): + if file.suffix != ".safetensors": + _LOGGER.info(f"Copying file: {file} to {dest_dir}") + shutil.copy(file, dest_dir / file.name) + + +def _update_quantization_config(source_dir: Path, dest_dir: Path): + """ + Updates config.json file in the destination directory by removing the + quantization_config attribute + + :param source_dir: The directory containing the original config.json file + :param dest_dir: The directory to save the updated config.json file + """ + from sparseml.transformers import SparseAutoConfig + + config = SparseAutoConfig.from_pretrained(source_dir) + + if hasattr(config, "quantization_config"): + _LOGGER.info("Updating quantization config...") + delattr(config, "quantization_config") + config.save_pretrained(dest_dir) + + +def load_safetensors_state_dict(file_path: str) -> Dict[str, torch.Tensor]: + """ + Load a safetensors file from disk + + :param file_path: path to the safetensors file + :return: dictionary of safetensors data + """ + with safe_open(file_path, framework="pt", device="cpu") as f: + return {key: f.get_tensor(key) for key in f.keys()} diff --git a/src/compressed_tensors/utils/converters/main.py b/src/compressed_tensors/utils/converters/main.py new file mode 100644 index 00000000..122400a3 --- /dev/null +++ b/src/compressed_tensors/utils/converters/main.py @@ -0,0 +1,37 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from compressed_tensors.utils.converters.converters import BaseConverter, ConverterNames
+
+
+__all__ = ["convert_autogptq_checkpoint"]
+
+
+def convert_autogptq_checkpoint(old_checkpoint_path, new_checkpoint_path) -> str:
+    """
+    Convert an autogptq checkpoint to a compressed tensor checkpoint
+
+    :param old_checkpoint_path: the path to the autogptq checkpoint
+    :param new_checkpoint_path: the path to save the converted compressed
+        tensor checkpoint
+    :return: the path to the new checkpoint
+    """
+    converter: BaseConverter = BaseConverter.load_from_registry(
+        ConverterNames.EXLLAMA_TO_COMPRESSED_TENSOR
+    )
+    checkpoint_path = converter.convert_from_safetensors(
+        old_checkpoint_path, new_checkpoint_path
+    )
+    return checkpoint_path
diff --git a/src/compressed_tensors/utils/converters/transformations.py b/src/compressed_tensors/utils/converters/transformations.py
new file mode 100644
index 00000000..86371c8e
--- /dev/null
+++ b/src/compressed_tensors/utils/converters/transformations.py
@@ -0,0 +1,224 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# flake8: noqa: F821
+
+import functools
+import logging
+from typing import Dict
+
+import numpy
+import numpy as np
+import torch
+from torch import Tensor
+
+
+_LOGGER = logging.getLogger(__name__)
+
+
+def _log_transformation(func):
+    @functools.wraps(func)
+    def wrapper(*args, **kwargs):
+        _LOGGER.info("Applying transformation: %s", func.__name__.upper())
+        return_value = func(*args, **kwargs)
+        _LOGGER.info("Transformation: %s complete", func.__name__.upper())
+        return return_value
+
+    return wrapper
+
+
+def is_gptq_quantization_target(key: str) -> bool:
+    """
+    Assumes self_attn and mlp are the only quantization targets
+    in model layers of the state_dict.
+    :param key: The key of the state_dict
+    :return: True if the key is a quantization target, False otherwise
+    """
+    return "model.layers" in key and ("self_attn" in key or "mlp" in key)
+
+
+@_log_transformation
+def transform_exllama_names(state_dict: Dict[str, Tensor]) -> Dict[str, Tensor]:
+    """
+    Transforms the exllama state_dict keys to be compatible with
+    SparseAutoModel classes.
+
+    The renames include:
+        - scales -> weight_scale
+        - qzeros -> weight_zero_point
+        - qweight -> weight
+
+    Note: does not transform the actual tensor values
+
+    :pre-condition: The state_dict should be for a quantized model
+    :pre-condition: Targets only the weights of the self_attn and mlp nodes
+    :param state_dict: The quantized state_dict to be transformed
+    :return: The transformed state_dict
+    """
+
+    name_map: Dict[str, str] = {
+        ".scales": ".weight_scale",
+        ".qzeros": ".weight_zero_point",
+        ".qweight": ".weight",
+    }
+
+    updated_state_dict = {}
+    for key, tensor in state_dict.items():
+        if any(key.endswith(target_suffix := suffix) for suffix in name_map):
+            updated_key = key.replace(target_suffix, name_map[target_suffix])
+            updated_state_dict[updated_key] = tensor
+        else:
+            updated_state_dict[key] = tensor
+    return updated_state_dict
+
+
+@_log_transformation
+def transform_autogptq_weights_and_reshape_tensors(
+    state_dict: Dict[str, Tensor]
+) -> Dict[str, Tensor]:
+    """
+    Transforms weights into their required shapes and types for Exllama
+    to CompressedTensors conversion
+
+    The transformations include:
+    - Unpack ad dequantize the weight tensor using the scales, zeros, and g_idx tensors
+    - Squeeze the scales tensor to [x] from [1, x]
+
+    :pre-condition: The state_dict should be for a quantized model
+    :pre-condition: The state_dict should have the bias and g_idx tensors added
+
+    :param state_dict: The state_dict to be transformed
+    :return: The transformed state_dict, with repacked and reshaped tensors
+    """
+
+    transformed_state_dict: Dict[str, Tensor] = {}
+
+    # auxiliary dict to store transformed weights
+    transformed_weights_dict: Dict[str, Tensor] = {}
+
+    # quantize qweights before scales, and qzeros
+    # because the ordering in which tensors are fetched
+    # is not guaranteed by our implementation
+    for key, tensor in state_dict.items():
+        if is_gptq_quantization_target(key) and key.endswith(".qweight"):
+            # quantize the weight tensor
+            scales = state_dict[key.replace("qweight", "scales")]
+            qzeros = state_dict[key.replace("qweight", "qzeros")]
+            g_idx = state_dict[key.replace("qweight", "g_idx")]
+
+            zeros = unpack_zeros(qzeros)
+            qweight = unpack_int32_into_fp32(
+                qweight=tensor,
+                scales=scales,
+                zeros=zeros,
+                g_idx=g_idx,
+            )
+            transformed_weights_dict[key] = qweight
+
+    # transform scales
+    for key, tensor in state_dict.items():
+        if is_gptq_quantization_target(key) and key.endswith(".scales"):
+            # scales [1, x] should be reshaped to [x]
+            scales = tensor.squeeze(0)
+            transformed_state_dict[key] = scales
+        else:
+            transformed_state_dict[key] = tensor
+
+    # overwrite old weights with the new quantized weights
+    transformed_state_dict.update(transformed_weights_dict)
+
+    # auxiliary weights_dict not needed anymore
+    del transformed_weights_dict
+
+    return transformed_state_dict
+
+
+def unpack_zeros(qzeros):
+    """
+    Unpack the quantized zero points tensor from 32 bit integers into 4 bit integers.
+ + :param qzeros: The quantized zero points tensor of int32 dtype and shape [1, 8x] + """ + bits = 4 + qzeros = qzeros.numpy().astype(np.uint32) + intzeros = np.zeros( + (qzeros.shape[0], qzeros.shape[1] * 32 // bits), dtype=np.uint32 + ) + + i = 0 + col = 0 + while col < intzeros.shape[1]: + if bits in [4]: + for j in range(i, min(i + (32 // bits), intzeros.shape[1])): + intzeros[:, j] = (qzeros[:, col] >> (bits * (j - i))) & 0xF + i += 32 // bits + col += 1 + else: + raise NotImplementedError("Only 4 bits are supported.") + + intzeros = intzeros.astype(np.int32) + intzeros = torch.from_numpy(intzeros) + + return intzeros + + +def unpack_int32_into_fp32( + qweight: Tensor, scales: Tensor, zeros: Tensor, g_idx: Tensor +) -> Tensor: + """ + Unpack the quantized weight tensor from 32 bit integers into 4 bit integers, + and then dequantize them using the scales, zeros, and g_idx tensors. + + :param qweight: The quantized weight tensor of int32 dtype and shape [x, y] + :param scales: The scales tensor + :param zeros: The zero points tensor + :param g_idx: The group index tensor + :return: The dequantized weight tensor of shape [x, 8y] + """ + bits = 4 + qweight = qweight.numpy().astype(numpy.uint32) + intweight = numpy.zeros( + (qweight.shape[0] * 32 // bits, qweight.shape[1]), dtype=numpy.uint32 + ) + + i = 0 + row = 0 + while row < intweight.shape[0]: + if bits in [4]: + for j in range(i, min(i + (32 // bits), intweight.shape[0])): + intweight[j] = (qweight[row] >> (bits * (j - i))) & 0xF + i += 32 // bits + row += 1 + else: + raise NotImplementedError("Only 4 bits are supported.") + + intweight = torch.from_numpy(intweight.astype(numpy.int32)) + intweight = intweight.t().contiguous() + + scales = scales.t().contiguous() + zeros = zeros.t().contiguous() + scale_zeros = zeros * scales + scales = scales.clone().half() + + weight = [] + infeatures = intweight.shape[1] + for idx in range(infeatures): + weight.append( + ( + intweight[:, idx].float() * scales[:, g_idx[idx]] + - scale_zeros[:, g_idx[idx]] + )[:, None] + ) + weight = torch.cat(weight, dim=1) + + return weight diff --git a/tests/test_utils/converters/__init__.py b/tests/test_utils/converters/__init__.py new file mode 100644 index 00000000..0c44f887 --- /dev/null +++ b/tests/test_utils/converters/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/test_utils/converters/test_imports.py b/tests/test_utils/converters/test_imports.py new file mode 100644 index 00000000..0a3a6202 --- /dev/null +++ b/tests/test_utils/converters/test_imports.py @@ -0,0 +1,57 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import pytest
+
+
+@pytest.fixture(scope="module")
+def conversion_function():
+    from compressed_tensors.utils.converters import convert_autogptq_checkpoint
+
+    return convert_autogptq_checkpoint
+
+
+@pytest.mark.parametrize(
+    "parent_module, function_name",
+    [
+        ("compressed_tensors.utils", "convert_autogptq_checkpoint"),
+        ("compressed_tensors.utils.converters.main", "convert_autogptq_checkpoint"),
+    ],
+)
+def test_convert_function_is_importable(parent_module, function_name):
+    import importlib
+
+    module = importlib.import_module(parent_module)
+    assert hasattr(
+        module, function_name
+    ), f"{function_name} is not found in {parent_module}"
+
+
+def test_conversion_function_accepts_correct_arguments(conversion_function):
+    import inspect
+
+    sig = inspect.signature(conversion_function)
+    params = sig.parameters
+    assert (
+        "old_checkpoint_path" in params
+    ), "Function does not accept 'old_checkpoint_path' argument"
+    assert (
+        "new_checkpoint_path" in params
+    ), "Function does not accept 'new_checkpoint_path' argument"
+
+    # check keyword arguments are also accepted
+    # (might be needed to configure specific transformations)
+    assert any(
+        param.kind == param.VAR_KEYWORD for param in params.values()
+    ), "Function does not accept **kwargs"

From 87b61bade9ad4a23bce236aa4340a8c679b03108 Mon Sep 17 00:00:00 2001
From: Rahul Tuli
Date: Thu, 13 Jun 2024 14:25:20 +0000
Subject: [PATCH 3/6] Update code to read in state dict layer by layer

---
 .../utils/converters/converters.py | 54 +++++++++++++++----
 1 file changed, 43 insertions(+), 11 deletions(-)

diff --git a/src/compressed_tensors/utils/converters/converters.py b/src/compressed_tensors/utils/converters/converters.py
index 2c49fc21..57898e88 100644
--- a/src/compressed_tensors/utils/converters/converters.py
+++ b/src/compressed_tensors/utils/converters/converters.py
@@ -18,7 +18,7 @@
 from abc import ABC, abstractmethod
 from enum import Enum
 from pathlib import Path
-from typing import Callable, Dict, Iterable, Union
+from typing import Callable, Dict, Iterable, Iterator, Tuple, Union
 
 import torch
 from compressed_tensors.registry.registry import RegistryMixin
@@ -77,22 +77,34 @@ def convert_from_safetensors(cls, filepath: str, save_dir: str = None) -> str:
         save_dir_.mkdir(exist_ok=True, parents=True)
 
         metadata = {"format": "pt", "source": "Created by SparseML"}
-
         # transform and save the state_dict
         if filepath_.is_dir():
             for file in filepath_.glob("*.safetensors"):
                 _LOGGER.info(f"Loading file: {file}")
-                state_dict: StateDictType = load_safetensors_state_dict(file)
-                new_state_dict = cls.translate(state_dict=state_dict)
-                save_file(
-                    new_state_dict, filename=save_dir_ / file.name, metadata=metadata
+                new_state_dict = {}
+                state_dict: Iterable[StateDictType] = load_safetensors_state_dict(
+                    file, by_layers=True
                 )
+                for layer_state_dict in state_dict:
+                    new_state_dict.update(cls.translate(state_dict=layer_state_dict))
+
+                if new_state_dict:
+                    save_file(
+                        new_state_dict,
+                        filename=save_dir_ / file.name,
+                        metadata=metadata,
+                    )
             _copy_non_safetensor_files_(filepath_, save_dir_)
             _update_quantization_config(filepath_, save_dir_)
 
         elif filepath_.is_file():
-            state_dict: StateDictType = load_safetensors_state_dict(filepath)
-            new_state_dict = cls.translate(state_dict=state_dict)
+            new_state_dict = {}
+            state_dict: Iterable[StateDictType] = load_safetensors_state_dict(
+                filepath, by_layers=True
+            )
+            for layer_state_dict in state_dict:
+                new_state_dict.update(cls.translate(state_dict=layer_state_dict))
+
             save_file(
                 new_state_dict, filename=save_dir_ / filepath_.name, metadata=metadata
             )
@@ -177,12 +189,32 @@ def _update_quantization_config(source_dir: Path, dest_dir: Path):
     config.save_pretrained(dest_dir)
 
 
-def load_safetensors_state_dict(file_path: str) -> Dict[str, torch.Tensor]:
+def load_safetensors_state_dict(
+    file_path: str, by_layers: bool = True
+) -> Iterator[Tuple[str, Dict[str, torch.Tensor]]]:
     """
     Load a safetensors file from disk
 
     :param file_path: path to the safetensors file
-    :return: dictionary of safetensors data
+    :param by_layers: if True, return a iterator with dictionary of safetensors
+        data by layers
+    :return: Iterator of dictionary of safetensors data or iterator of
+        dictionaries by layers
     """
     with safe_open(file_path, framework="pt", device="cpu") as f:
-        return {key: f.get_tensor(key) for key in f.keys()}
+        if by_layers:
+            current_layer = None
+            layer_data = {}
+            for key in sorted(f.keys()):
+                layer_name, param_name = key.split(".", 1)
+                if current_layer is None:
+                    current_layer = layer_name
+                elif layer_name != current_layer:
+                    yield current_layer, layer_data
+                    current_layer = layer_name
+                    layer_data = {}
+                layer_data[key] = f.get_tensor(key)
+            if layer_data:
+                yield layer_data
+        else:
+            yield {key: f.get_tensor(key) for key in f.keys()}

From d1cda42837f6ab91e8e7e4e06242fc1683f857e1 Mon Sep 17 00:00:00 2001
From: Rahul Tuli
Date: Thu, 13 Jun 2024 15:08:03 +0000
Subject: [PATCH 4/6] allow passing in kwargs

---
 src/compressed_tensors/utils/converters/converters.py | 8 ++++++--
 src/compressed_tensors/utils/converters/main.py       | 7 +++++--
 2 files changed, 11 insertions(+), 4 deletions(-)

diff --git a/src/compressed_tensors/utils/converters/converters.py b/src/compressed_tensors/utils/converters/converters.py
index 57898e88..ebe106a5 100644
--- a/src/compressed_tensors/utils/converters/converters.py
+++ b/src/compressed_tensors/utils/converters/converters.py
@@ -56,7 +56,9 @@ def translate(cls, state_dict: StateDictType, **kwargs) -> StateDictType:
         return new_state_dict
 
     @classmethod
-    def convert_from_safetensors(cls, filepath: str, save_dir: str = None) -> str:
+    def convert_from_safetensors(
+        cls, filepath: str, save_dir: str = None, **kwargs
+    ) -> str:
         """
         Convert a .safetensors file or directory of .safetensors files, applying
         transformations to the state_dict and saving the new state_dict to a new
@@ -86,7 +88,9 @@ def convert_from_safetensors(cls, filepath: str, save_dir: str = None) -> str:
                     file, by_layers=True
                 )
                 for layer_state_dict in state_dict:
-                    new_state_dict.update(cls.translate(state_dict=layer_state_dict))
+                    new_state_dict.update(
+                        cls.translate(state_dict=layer_state_dict, **kwargs)
+                    )
 
                 if new_state_dict:
                     save_file(
diff --git a/src/compressed_tensors/utils/converters/main.py b/src/compressed_tensors/utils/converters/main.py
index 122400a3..3089849c 100644
--- a/src/compressed_tensors/utils/converters/main.py
+++ b/src/compressed_tensors/utils/converters/main.py
@@ -19,19 +19,22 @@
 __all__ = ["convert_autogptq_checkpoint"]
 
 
-def convert_autogptq_checkpoint(old_checkpoint_path, new_checkpoint_path) -> str:
+def convert_autogptq_checkpoint(
+    old_checkpoint_path, new_checkpoint_path, 
**kwargs +) -> str: """ Convert an autogptq checkpoint to a compressed tensor checkpoint :param old_checkpoint_path: the path to the autogptq checkpoint :param new_checkpoint_path: the path to save the converted compressed tensor checkpoint + :param kwargs: additional arguments to pass to the transformations :return: the path to the new checkpoint """ converter: BaseConverter = BaseConverter.load_from_registry( ConverterNames.EXLLAMA_TO_COMPRESSED_TENSOR ) checkpoint_path = converter.convert_from_safetensors( - old_checkpoint_path, new_checkpoint_path + old_checkpoint_path, new_checkpoint_path, **kwargs ) return checkpoint_path From 343aa8d8e7daa78e129382cc964149bc5de210e1 Mon Sep 17 00:00:00 2001 From: Rahul Tuli Date: Fri, 14 Jun 2024 13:25:09 +0000 Subject: [PATCH 5/6] Add progress --- .../utils/converters/converters.py | 63 +++++++++++++++---- .../utils/converters/main.py | 3 +- .../utils/converters/transformations.py | 4 +- 3 files changed, 55 insertions(+), 15 deletions(-) diff --git a/src/compressed_tensors/utils/converters/converters.py b/src/compressed_tensors/utils/converters/converters.py index ebe106a5..27349a40 100644 --- a/src/compressed_tensors/utils/converters/converters.py +++ b/src/compressed_tensors/utils/converters/converters.py @@ -28,14 +28,16 @@ ) from safetensors import safe_open from safetensors.torch import save_file +from tqdm import tqdm StateDictType = Union[Dict[str, torch.Tensor], str, Path] TransformationType = Callable[[Dict[str, torch.Tensor]], Dict[str, torch.Tensor]] + _LOGGER: logging.Logger = logging.getLogger(__name__) -class ConverterNames(Enum): +class ConverterNames(str, Enum): EXLLAMA_TO_COMPRESSED_TENSOR = "exllama_to_compressed_tensor" @@ -73,7 +75,7 @@ def convert_from_safetensors( filepath_: Path = Path(filepath) if not save_dir: - save_dir = "compressed_tensors_model" + save_dir: str = "compressed_tensors_model" save_dir_: Path = Path(save_dir) save_dir_.mkdir(exist_ok=True, parents=True) @@ -81,13 +83,19 @@ def convert_from_safetensors( metadata = {"format": "pt", "source": "Created by SparseML"} # transform and save the state_dict if filepath_.is_dir(): + tqdm.write(f"Converting directory: {filepath}") + tqdm.write(f"Found: {len(list(filepath_.glob('*.safetensors')))} .safetensors files") for file in filepath_.glob("*.safetensors"): - _LOGGER.info(f"Loading file: {file}") + tqdm.write(f"Converting file: {file.name}") new_state_dict = {} state_dict: Iterable[StateDictType] = load_safetensors_state_dict( file, by_layers=True ) - for layer_state_dict in state_dict: + layer_progress_bar = tqdm(state_dict, total=layer_count(file), desc="Converting layers") + for layer_state_dict in layer_progress_bar: + layer_name = list(layer_state_dict.keys())[0][:len("model.layers.0")] + layer_progress_bar.set_description(f"Converting layer {layer_name}") + layer_progress_bar.update() new_state_dict.update( cls.translate(state_dict=layer_state_dict, **kwargs) ) @@ -126,7 +134,7 @@ def transformations(cls) -> Iterable[TransformationType]: raise NotImplementedError() -@BaseConverter.register(name=ConverterNames.EXLLAMA_TO_COMPRESSED_TENSOR.value) +@BaseConverter.register(name=ConverterNames.EXLLAMA_TO_COMPRESSED_TENSOR) class ExllamaToCompressedTensorConverter(BaseConverter): """ A converter that applies transformations to the state_dict of a autogptq @@ -183,16 +191,49 @@ def _update_quantization_config(source_dir: Path, dest_dir: Path): :param source_dir: The directory containing the original config.json file :param dest_dir: The directory to save the updated 
config.json file """ - from sparseml.transformers import SparseAutoConfig + from transformers import AutoConfig - config = SparseAutoConfig.from_pretrained(source_dir) + config = AutoConfig.from_pretrained(source_dir) if hasattr(config, "quantization_config"): _LOGGER.info("Updating quantization config...") - delattr(config, "quantization_config") + quantization_config = config.quantization_config + config.quantization_config = _convert_to_compressed_tensors_config(quantization_config) config.save_pretrained(dest_dir) +def _convert_to_compressed_tensors_config(quantization_config): + """ + Converts the quantization_config attribute from a config.json file + to a dictionary + + :param quantization_config: The quantization_config attribute from a config.json file + :return: The quantization_config as a dictionary + """ + compressed_tensor_config = ... + return compressed_tensor_config + +def layer_count(file_path: str) -> int: + """ + Count the number of layers in a safetensors file + + :param file_path: path to the safetensors file + :return: number of layers in the safetensors file + """ + with safe_open(file_path, framework="pt", device="cpu") as f: + keys = sorted(f.keys()) + + last_layer_name = None + layer_count = 0 + for key in keys: + layer_name = key[:len("model.layers.0")] + if layer_name != last_layer_name: + last_layer_name = layer_name + layer_count += 1 + return layer_count + + + def load_safetensors_state_dict( file_path: str, by_layers: bool = True ) -> Iterator[Tuple[str, Dict[str, torch.Tensor]]]: @@ -201,7 +242,7 @@ def load_safetensors_state_dict( :param file_path: path to the safetensors file :param by_layers: if True, return a iterator with dictionary of safetensors - data by layers + data by layers. Default is True :return: Iterator of dictionary of safetensors data or iterator of dictionaries by layers """ @@ -210,11 +251,11 @@ def load_safetensors_state_dict( current_layer = None layer_data = {} for key in sorted(f.keys()): - layer_name, param_name = key.split(".", 1) + layer_name = key[:len("model.layers.0")] if current_layer is None: current_layer = layer_name elif layer_name != current_layer: - yield current_layer, layer_data + yield layer_data current_layer = layer_name layer_data = {} layer_data[key] = f.get_tensor(key) diff --git a/src/compressed_tensors/utils/converters/main.py b/src/compressed_tensors/utils/converters/main.py index 3089849c..1f04fc41 100644 --- a/src/compressed_tensors/utils/converters/main.py +++ b/src/compressed_tensors/utils/converters/main.py @@ -15,12 +15,11 @@ from compressed_tensors.utils.converters.converters import BaseConverter, ConverterNames - __all__ = ["convert_autogptq_checkpoint"] def convert_autogptq_checkpoint( - old_checkpoint_path, new_checkpoint_path, **kwargs + old_checkpoint_path, new_checkpoint_path ,**kwargs ) -> str: """ Convert an autogptq checkpoint to a compressed tensor checkpoint diff --git a/src/compressed_tensors/utils/converters/transformations.py b/src/compressed_tensors/utils/converters/transformations.py index 86371c8e..fae69137 100644 --- a/src/compressed_tensors/utils/converters/transformations.py +++ b/src/compressed_tensors/utils/converters/transformations.py @@ -29,9 +29,9 @@ def _log_transformation(func): @functools.wraps(func) def wrapper(*args, **kwargs): - _LOGGER.info("Applying transformation: %s", func.__name__.upper()) + _LOGGER.debug("Applying transformation: %s", func.__name__.upper()) return_value = func(*args, **kwargs) - _LOGGER.info("Transformation: %s complete", func.__name__.upper()) + 
_LOGGER.debug("Transformation: %s complete", func.__name__.upper()) return return_value return wrapper From ed95dfc8357123a6f02e1587ce95ab560e2edf67 Mon Sep 17 00:00:00 2001 From: Rahul Tuli Date: Thu, 27 Jun 2024 14:14:27 +0000 Subject: [PATCH 6/6] Move validate safetensors function Add tests for the same --- .../utils/converters/converters.py | 86 ++++++++++--------- .../utils/converters/main.py | 5 +- .../utils/converters/transformations.py | 34 ++++++-- .../utils/safetensors_load.py | 24 ++++++ tests/test_utils/test_safetensors_load.py | 63 ++++++++++++++ 5 files changed, 161 insertions(+), 51 deletions(-) create mode 100644 tests/test_utils/test_safetensors_load.py diff --git a/src/compressed_tensors/utils/converters/converters.py b/src/compressed_tensors/utils/converters/converters.py index 27349a40..bbedc7d2 100644 --- a/src/compressed_tensors/utils/converters/converters.py +++ b/src/compressed_tensors/utils/converters/converters.py @@ -23,9 +23,11 @@ import torch from compressed_tensors.registry.registry import RegistryMixin from compressed_tensors.utils.converters.transformations import ( + remove_unused_tensors, transform_autogptq_weights_and_reshape_tensors, transform_exllama_names, ) +from compressed_tensors.utils.safetensors_load import validate_safetensors_file_path from safetensors import safe_open from safetensors.torch import save_file from tqdm import tqdm @@ -38,7 +40,7 @@ class ConverterNames(str, Enum): - EXLLAMA_TO_COMPRESSED_TENSOR = "exllama_to_compressed_tensor" + AutoGPTQConverter: str = "exllama_to_compressed_tensor" class BaseConverter(ABC, RegistryMixin): @@ -71,7 +73,7 @@ def convert_from_safetensors( :param save_dir: The directory to save the converted state_dict to :return: The directory where the converted state_dict was saved """ - _validate_safetensors_file_path(filepath) + validate_safetensors_file_path(filepath) filepath_: Path = Path(filepath) if not save_dir: @@ -84,16 +86,23 @@ def convert_from_safetensors( # transform and save the state_dict if filepath_.is_dir(): tqdm.write(f"Converting directory: {filepath}") - tqdm.write(f"Found: {len(list(filepath_.glob('*.safetensors')))} .safetensors files") + tqdm.write( + f"Found: {len(list(filepath_.glob('*.safetensors')))} " + ".safetensors files" + ) for file in filepath_.glob("*.safetensors"): tqdm.write(f"Converting file: {file.name}") new_state_dict = {} state_dict: Iterable[StateDictType] = load_safetensors_state_dict( file, by_layers=True ) - layer_progress_bar = tqdm(state_dict, total=layer_count(file), desc="Converting layers") + layer_progress_bar = tqdm( + state_dict, total=layer_count(file), desc="Converting layers" + ) for layer_state_dict in layer_progress_bar: - layer_name = list(layer_state_dict.keys())[0][:len("model.layers.0")] + layer_name = list(layer_state_dict.keys())[0][ + : len("model.layers.0") + ] layer_progress_bar.set_description(f"Converting layer {layer_name}") layer_progress_bar.update() new_state_dict.update( @@ -101,13 +110,18 @@ def convert_from_safetensors( ) if new_state_dict: + # compress before saving + # compressor = Compressor.load_from_registry( + # name=CompressionFormat.pack_quantized.value + # ) + # new_state_dict = compressor.compress(new_state_dict) save_file( new_state_dict, filename=save_dir_ / file.name, metadata=metadata, ) _copy_non_safetensor_files_(filepath_, save_dir_) - _update_quantization_config(filepath_, save_dir_) + # _update_quantization_config(filepath_, save_dir_) elif filepath_.is_file(): new_state_dict = {} @@ -134,39 +148,28 @@ def 
transformations(cls) -> Iterable[TransformationType]: raise NotImplementedError() -@BaseConverter.register(name=ConverterNames.EXLLAMA_TO_COMPRESSED_TENSOR) -class ExllamaToCompressedTensorConverter(BaseConverter): +@BaseConverter.register(name=ConverterNames.AutoGPTQConverter) +class AutoGPTQConverter(BaseConverter): """ A converter that applies transformations to the state_dict of a autogptq - quantized model to convert it to a compressed tensor model, which can be - loaded by the SparseAutoModel classes - """ - - @classmethod - def transformations(cls): - return (transform_autogptq_weights_and_reshape_tensors, transform_exllama_names) + quantized model to convert it to a compressed tensor model + Transformations made: -def _validate_safetensors_file_path(filepath: str): - """ - Given a file path, it is valid if: - - The file exists - - The file is either a single .safetensors file or a - directory containing .safetensors files - - :param filepath: A string file path to validate + -> Unpack autogptq 4 bit weight packing + -> Translate exllama names to compressed tensor names + -> Pack 4 bit weights with compressed tensor format + -> Remove unused tensors + -> Update quantization config in config.json file """ - filepath_: Path = Path(filepath) - - if not filepath_.exists(): - raise FileNotFoundError(f"File not found: {filepath}") - - if filepath_.is_dir() and not any(filepath_.glob("*.safetensors")): - raise FileNotFoundError(f"No .safetensors files found in directory: {filepath}") - - if filepath_.is_file() and not filepath_.suffix == ".safetensors": - raise ValueError(f"File must be a .safetensors file: {filepath}") + @classmethod + def transformations(cls): + return ( + transform_autogptq_weights_and_reshape_tensors, + transform_exllama_names, + remove_unused_tensors, + ) def _copy_non_safetensor_files_(source_dir: Path, dest_dir: Path): @@ -178,7 +181,7 @@ def _copy_non_safetensor_files_(source_dir: Path, dest_dir: Path): :param dest_dir: The directory to copy files to """ for file in source_dir.glob("*"): - if file.suffix != ".safetensors": + if file.suffix != ".safetensors" and file.name != "config.json": _LOGGER.info(f"Copying file: {file} to {dest_dir}") shutil.copy(file, dest_dir / file.name) @@ -198,7 +201,9 @@ def _update_quantization_config(source_dir: Path, dest_dir: Path): if hasattr(config, "quantization_config"): _LOGGER.info("Updating quantization config...") quantization_config = config.quantization_config - config.quantization_config = _convert_to_compressed_tensors_config(quantization_config) + config.quantization_config = _convert_to_compressed_tensors_config( + quantization_config + ) config.save_pretrained(dest_dir) @@ -207,12 +212,14 @@ def _convert_to_compressed_tensors_config(quantization_config): Converts the quantization_config attribute from a config.json file to a dictionary - :param quantization_config: The quantization_config attribute from a config.json file + :param quantization_config: The quantization_config + attribute from a config.json file :return: The quantization_config as a dictionary """ compressed_tensor_config = ... 
return compressed_tensor_config + def layer_count(file_path: str) -> int: """ Count the number of layers in a safetensors file @@ -222,16 +229,15 @@ def layer_count(file_path: str) -> int: """ with safe_open(file_path, framework="pt", device="cpu") as f: keys = sorted(f.keys()) - + last_layer_name = None layer_count = 0 for key in keys: - layer_name = key[:len("model.layers.0")] + layer_name = key[: len("model.layers.0")] if layer_name != last_layer_name: last_layer_name = layer_name layer_count += 1 return layer_count - def load_safetensors_state_dict( @@ -251,7 +257,7 @@ def load_safetensors_state_dict( current_layer = None layer_data = {} for key in sorted(f.keys()): - layer_name = key[:len("model.layers.0")] + layer_name = key[: len("model.layers.0")] if current_layer is None: current_layer = layer_name elif layer_name != current_layer: diff --git a/src/compressed_tensors/utils/converters/main.py b/src/compressed_tensors/utils/converters/main.py index 1f04fc41..dd7516b4 100644 --- a/src/compressed_tensors/utils/converters/main.py +++ b/src/compressed_tensors/utils/converters/main.py @@ -15,11 +15,12 @@ from compressed_tensors.utils.converters.converters import BaseConverter, ConverterNames + __all__ = ["convert_autogptq_checkpoint"] def convert_autogptq_checkpoint( - old_checkpoint_path, new_checkpoint_path ,**kwargs + old_checkpoint_path, new_checkpoint_path, **kwargs ) -> str: """ Convert an autogptq checkpoint to a compressed tensor checkpoint @@ -31,7 +32,7 @@ def convert_autogptq_checkpoint( :return: the path to the new checkpoint """ converter: BaseConverter = BaseConverter.load_from_registry( - ConverterNames.EXLLAMA_TO_COMPRESSED_TENSOR + ConverterNames.AutoGPTQConverter ) checkpoint_path = converter.convert_from_safetensors( old_checkpoint_path, new_checkpoint_path, **kwargs diff --git a/src/compressed_tensors/utils/converters/transformations.py b/src/compressed_tensors/utils/converters/transformations.py index fae69137..1bebfdae 100644 --- a/src/compressed_tensors/utils/converters/transformations.py +++ b/src/compressed_tensors/utils/converters/transformations.py @@ -69,7 +69,7 @@ def transform_exllama_names(state_dict: Dict[str, Tensor]) -> Dict[str, Tensor]: name_map: Dict[str, str] = { ".scales": ".weight_scale", ".qzeros": ".weight_zero_point", - ".qweight": ".weight", + ".qweight": ".weight_packed", } updated_state_dict = {} @@ -91,7 +91,7 @@ def transform_autogptq_weights_and_reshape_tensors( to CompressedTensors conversion The transformations include: - - Unpack ad dequantize the weight tensor using the scales, zeros, and g_idx tensors + - Unpack and dequantize the weight tensor using the scales, zeros, and g_idx tensors - Squeeze the scales tensor to [x] from [1, x] :pre-condition: The state_dict should be for a quantized model @@ -117,13 +117,15 @@ def transform_autogptq_weights_and_reshape_tensors( g_idx = state_dict[key.replace("qweight", "g_idx")] zeros = unpack_zeros(qzeros) - qweight = unpack_int32_into_fp32( - qweight=tensor, - scales=scales, - zeros=zeros, - g_idx=g_idx, - ) - transformed_weights_dict[key] = qweight + # qweight = unpack_int32_into_fp32( + # qweight=tensor, + # scales=scales, + # zeros=zeros, + # g_idx=g_idx, + # ) + new_shape = torch.tensor([tensor.shape[0] * 8, tensor.shape[1]]) + transformed_weights_dict[key] = tensor + transformed_weights_dict[key.replace("qweight", "weight_shape")] = new_shape # transform scales for key, tensor in state_dict.items(): @@ -222,3 +224,17 @@ def unpack_int32_into_fp32( weight = torch.cat(weight, dim=1) return 
weight + + +def remove_unused_tensors(state_dict: Dict[str, Tensor]) -> Dict[str, Tensor]: + """ + Remove unused tensors from the state_dict + + :param state_dict: The state_dict to be cleaned + :return: The cleaned state_dict + """ + return { + key: tensor + for key, tensor in state_dict.items() + if is_gptq_quantization_target(key) and not key.endswith(".g_idx") + } diff --git a/src/compressed_tensors/utils/safetensors_load.py b/src/compressed_tensors/utils/safetensors_load.py index 9cdac782..5c229c39 100644 --- a/src/compressed_tensors/utils/safetensors_load.py +++ b/src/compressed_tensors/utils/safetensors_load.py @@ -16,6 +16,7 @@ import os import re import struct +from pathlib import Path from typing import Dict, List, Optional from safetensors import safe_open @@ -32,6 +33,7 @@ "get_nested_weight_mappings", "get_quantization_state_dict", "is_quantization_param", + "validate_safetensors_file_path", ] @@ -236,3 +238,25 @@ def is_quantization_param(name: str) -> bool: return True return False + + +def validate_safetensors_file_path(filepath: str): + """ + Given a file path, it is valid if: + - The file exists + - The file is either a single .safetensors file or a + directory containing .safetensors files + + :param filepath: A string file path to validate + """ + + filepath_: Path = Path(filepath) + + if not filepath_.exists(): + raise FileNotFoundError(f"File not found: {filepath}") + + if filepath_.is_dir() and not any(filepath_.glob("*.safetensors")): + raise FileNotFoundError(f"No .safetensors files found in directory: {filepath}") + + if filepath_.is_file() and not filepath_.suffix == ".safetensors": + raise ValueError(f"File must be a .safetensors file: {filepath}") diff --git a/tests/test_utils/test_safetensors_load.py b/tests/test_utils/test_safetensors_load.py new file mode 100644 index 00000000..df53796e --- /dev/null +++ b/tests/test_utils/test_safetensors_load.py @@ -0,0 +1,63 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+import pytest
+from compressed_tensors.utils.safetensors_load import validate_safetensors_file_path
+
+
+@pytest.fixture
+def temp_dir(tmp_path):
+    return tmp_path / "subdirectory"
+
+
+@pytest.fixture
+def safetensors_file(temp_dir):
+    temp_dir.mkdir(exist_ok=True)
+    safetensors_filepath = temp_dir / "test.safetensors"
+    safetensors_filepath.write_text("content")
+    return safetensors_filepath
+
+
+@pytest.fixture
+def non_safetensors_file(temp_dir):
+    temp_dir.mkdir(exist_ok=True)
+    non_safetensors_filepath = temp_dir / "test.txt"
+    non_safetensors_filepath.write_text("content")
+    return non_safetensors_filepath
+
+
+def test_validate_safetensors_file_path_file_not_found():
+    with pytest.raises(FileNotFoundError):
+        validate_safetensors_file_path("nonexistent_file.safetensors")
+
+
+def test_validate_safetensors_file_path_no_safetensors_files_in_directory(temp_dir):
+    temp_dir.mkdir()
+    with pytest.raises(FileNotFoundError):
+        validate_safetensors_file_path(str(temp_dir))
+
+
+def test_validate_safetensors_file_path_file_is_not_safetensors(non_safetensors_file):
+    with pytest.raises(ValueError):
+        validate_safetensors_file_path(str(non_safetensors_file))
+
+
+def test_validate_safetensors_file_path_valid_safetensors_file(safetensors_file):
+    validate_safetensors_file_path(str(safetensors_file))
+
+
+def test_validate_safetensors_file_path_valid_directory_with_safetensors_files(
+    temp_dir, safetensors_file
+):
+    validate_safetensors_file_path(str(temp_dir))
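Usage sketch: with this series applied, the conversion entrypoint exported from main.py can be driven as below. This is a minimal sketch, not part of the patches themselves; the checkpoint paths are illustrative assumptions, and any autogptq (exllama-packed) checkpoint directory containing .safetensors files should work.

    import logging

    from compressed_tensors.utils.converters import convert_autogptq_checkpoint

    # surface the converters' progress logging
    logging.basicConfig(level=logging.INFO)

    # paths are hypothetical placeholders
    new_path = convert_autogptq_checkpoint(
        old_checkpoint_path="./llama-7b-autogptq",
        new_checkpoint_path="./llama-7b-compressed-tensors",
    )
    print(f"converted checkpoint written to: {new_path}")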