Skip to content

Commit

Permalink
Move validate safetensors function
Browse files Browse the repository at this point in the history
Add tests for the same
  • Loading branch information
rahul-tuli committed Jun 27, 2024
1 parent 9fb97cd commit 282d96e
Show file tree
Hide file tree
Showing 5 changed files with 160 additions and 51 deletions.
86 changes: 46 additions & 40 deletions src/compressed_tensors/utils/converters/converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,11 @@
import torch
from compressed_tensors.registry.registry import RegistryMixin
from compressed_tensors.utils.converters.transformations import (
remove_unused_tensors,
transform_autogptq_weights_and_reshape_tensors,
transform_exllama_names,
)
from compressed_tensors.utils.safetensors_load import validate_safetensors_file_path
from safetensors import safe_open
from safetensors.torch import save_file
from tqdm import tqdm
Expand All @@ -38,7 +40,7 @@


class ConverterNames(str, Enum):
EXLLAMA_TO_COMPRESSED_TENSOR = "exllama_to_compressed_tensor"
AutoGPTQConverter: str = "exllama_to_compressed_tensor"


class BaseConverter(ABC, RegistryMixin):
Expand Down Expand Up @@ -71,7 +73,7 @@ def convert_from_safetensors(
:param save_dir: The directory to save the converted state_dict to
:return: The directory where the converted state_dict was saved
"""
_validate_safetensors_file_path(filepath)
validate_safetensors_file_path(filepath)

filepath_: Path = Path(filepath)
if not save_dir:
Expand All @@ -84,30 +86,42 @@ def convert_from_safetensors(
# transform and save the state_dict
if filepath_.is_dir():
tqdm.write(f"Converting directory: {filepath}")
tqdm.write(f"Found: {len(list(filepath_.glob('*.safetensors')))} .safetensors files")
tqdm.write(
f"Found: {len(list(filepath_.glob('*.safetensors')))} "
".safetensors files"
)
for file in filepath_.glob("*.safetensors"):
tqdm.write(f"Converting file: {file.name}")
new_state_dict = {}
state_dict: Iterable[StateDictType] = load_safetensors_state_dict(
file, by_layers=True
)
layer_progress_bar = tqdm(state_dict, total=layer_count(file), desc="Converting layers")
layer_progress_bar = tqdm(
state_dict, total=layer_count(file), desc="Converting layers"
)
for layer_state_dict in layer_progress_bar:
layer_name = list(layer_state_dict.keys())[0][:len("model.layers.0")]
layer_name = list(layer_state_dict.keys())[0][
: len("model.layers.0")
]
layer_progress_bar.set_description(f"Converting layer {layer_name}")
layer_progress_bar.update()
new_state_dict.update(
cls.translate(state_dict=layer_state_dict, **kwargs)
)

if new_state_dict:
# compress before saving
# compressor = Compressor.load_from_registry(
# name=CompressionFormat.pack_quantized.value
# )
# new_state_dict = compressor.compress(new_state_dict)
save_file(
new_state_dict,
filename=save_dir_ / file.name,
metadata=metadata,
)
_copy_non_safetensor_files_(filepath_, save_dir_)
_update_quantization_config(filepath_, save_dir_)
# _update_quantization_config(filepath_, save_dir_)

elif filepath_.is_file():
new_state_dict = {}
Expand All @@ -134,39 +148,28 @@ def transformations(cls) -> Iterable[TransformationType]:
raise NotImplementedError()


@BaseConverter.register(name=ConverterNames.EXLLAMA_TO_COMPRESSED_TENSOR)
class ExllamaToCompressedTensorConverter(BaseConverter):
@BaseConverter.register(name=ConverterNames.AutoGPTQConverter)
class AutoGPTQConverter(BaseConverter):
"""
A converter that applies transformations to the state_dict of a autogptq
quantized model to convert it to a compressed tensor model, which can be
loaded by the SparseAutoModel classes
"""

@classmethod
def transformations(cls):
return (transform_autogptq_weights_and_reshape_tensors, transform_exllama_names)
quantized model to convert it to a compressed tensor model
Transformations made:
def _validate_safetensors_file_path(filepath: str):
"""
Given a file path, it is valid if:
- The file exists
- The file is either a single .safetensors file or a
directory containing .safetensors files
:param filepath: A string file path to validate
-> Unpack autogptq 4 bit weight packing
-> Translate exllama names to compressed tensor names
-> Pack 4 bit weights with compressed tensor format
-> Remove unused tensors
-> Update quantization config in config.json file
"""

filepath_: Path = Path(filepath)

if not filepath_.exists():
raise FileNotFoundError(f"File not found: {filepath}")

if filepath_.is_dir() and not any(filepath_.glob("*.safetensors")):
raise FileNotFoundError(f"No .safetensors files found in directory: {filepath}")

if filepath_.is_file() and not filepath_.suffix == ".safetensors":
raise ValueError(f"File must be a .safetensors file: {filepath}")
@classmethod
def transformations(cls):
return (
transform_autogptq_weights_and_reshape_tensors,
transform_exllama_names,
remove_unused_tensors,
)


def _copy_non_safetensor_files_(source_dir: Path, dest_dir: Path):
Expand All @@ -178,7 +181,7 @@ def _copy_non_safetensor_files_(source_dir: Path, dest_dir: Path):
:param dest_dir: The directory to copy files to
"""
for file in source_dir.glob("*"):
if file.suffix != ".safetensors":
if file.suffix != ".safetensors" and file.name != "config.json":
_LOGGER.info(f"Copying file: {file} to {dest_dir}")
shutil.copy(file, dest_dir / file.name)

Expand All @@ -198,7 +201,9 @@ def _update_quantization_config(source_dir: Path, dest_dir: Path):
if hasattr(config, "quantization_config"):
_LOGGER.info("Updating quantization config...")
quantization_config = config.quantization_config
config.quantization_config = _convert_to_compressed_tensors_config(quantization_config)
config.quantization_config = _convert_to_compressed_tensors_config(
quantization_config
)
config.save_pretrained(dest_dir)


Expand All @@ -207,12 +212,14 @@ def _convert_to_compressed_tensors_config(quantization_config):
Converts the quantization_config attribute from a config.json file
to a dictionary
:param quantization_config: The quantization_config attribute from a config.json file
:param quantization_config: The quantization_config
attribute from a config.json file
:return: The quantization_config as a dictionary
"""
compressed_tensor_config = ...
return compressed_tensor_config


def layer_count(file_path: str) -> int:
"""
Count the number of layers in a safetensors file
Expand All @@ -222,16 +229,15 @@ def layer_count(file_path: str) -> int:
"""
with safe_open(file_path, framework="pt", device="cpu") as f:
keys = sorted(f.keys())

last_layer_name = None
layer_count = 0
for key in keys:
layer_name = key[:len("model.layers.0")]
layer_name = key[: len("model.layers.0")]
if layer_name != last_layer_name:
last_layer_name = layer_name
layer_count += 1
return layer_count



def load_safetensors_state_dict(
Expand All @@ -251,7 +257,7 @@ def load_safetensors_state_dict(
current_layer = None
layer_data = {}
for key in sorted(f.keys()):
layer_name = key[:len("model.layers.0")]
layer_name = key[: len("model.layers.0")]
if current_layer is None:
current_layer = layer_name
elif layer_name != current_layer:
Expand Down
5 changes: 3 additions & 2 deletions src/compressed_tensors/utils/converters/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,12 @@

from compressed_tensors.utils.converters.converters import BaseConverter, ConverterNames


__all__ = ["convert_autogptq_checkpoint"]


def convert_autogptq_checkpoint(
old_checkpoint_path, new_checkpoint_path ,**kwargs
old_checkpoint_path, new_checkpoint_path, **kwargs
) -> str:
"""
Convert an autogptq checkpoint to a compressed tensor checkpoint
Expand All @@ -31,7 +32,7 @@ def convert_autogptq_checkpoint(
:return: the path to the new checkpoint
"""
converter: BaseConverter = BaseConverter.load_from_registry(
ConverterNames.EXLLAMA_TO_COMPRESSED_TENSOR
ConverterNames.AutoGPTQConverter
)
checkpoint_path = converter.convert_from_safetensors(
old_checkpoint_path, new_checkpoint_path, **kwargs
Expand Down
34 changes: 25 additions & 9 deletions src/compressed_tensors/utils/converters/transformations.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def transform_exllama_names(state_dict: Dict[str, Tensor]) -> Dict[str, Tensor]:
name_map: Dict[str, str] = {
".scales": ".weight_scale",
".qzeros": ".weight_zero_point",
".qweight": ".weight",
".qweight": ".weight_packed",
}

updated_state_dict = {}
Expand All @@ -91,7 +91,7 @@ def transform_autogptq_weights_and_reshape_tensors(
to CompressedTensors conversion
The transformations include:
- Unpack ad dequantize the weight tensor using the scales, zeros, and g_idx tensors
- Unpack and dequantize the weight tensor using the scales, zeros, and g_idx tensors
- Squeeze the scales tensor to [x] from [1, x]
:pre-condition: The state_dict should be for a quantized model
Expand All @@ -117,13 +117,15 @@ def transform_autogptq_weights_and_reshape_tensors(
g_idx = state_dict[key.replace("qweight", "g_idx")]

zeros = unpack_zeros(qzeros)
qweight = unpack_int32_into_fp32(
qweight=tensor,
scales=scales,
zeros=zeros,
g_idx=g_idx,
)
transformed_weights_dict[key] = qweight
# qweight = unpack_int32_into_fp32(
# qweight=tensor,
# scales=scales,
# zeros=zeros,
# g_idx=g_idx,
# )
new_shape = torch.tensor([tensor.shape[0] * 8, tensor.shape[1]])
transformed_weights_dict[key] = tensor
transformed_weights_dict[key.replace("qweight", "weight_shape")] = new_shape

# transform scales
for key, tensor in state_dict.items():
Expand Down Expand Up @@ -222,3 +224,17 @@ def unpack_int32_into_fp32(
weight = torch.cat(weight, dim=1)

return weight


def remove_unused_tensors(state_dict: Dict[str, Tensor]) -> Dict[str, Tensor]:
"""
Remove unused tensors from the state_dict
:param state_dict: The state_dict to be cleaned
:return: The cleaned state_dict
"""
return {
key: tensor
for key, tensor in state_dict.items()
if is_gptq_quantization_target(key) and not key.endswith(".g_idx")
}
23 changes: 23 additions & 0 deletions src/compressed_tensors/utils/safetensors_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import os
import re
import struct
from pathlib import Path
from typing import Dict, List, Optional

from safetensors import safe_open
Expand Down Expand Up @@ -236,3 +237,25 @@ def is_quantization_param(name: str) -> bool:
return True

return False


def validate_safetensors_file_path(filepath: str):
"""
Given a file path, it is valid if:
- The file exists
- The file is either a single .safetensors file or a
directory containing .safetensors files
:param filepath: A string file path to validate
"""

filepath_: Path = Path(filepath)

if not filepath_.exists():
raise FileNotFoundError(f"File not found: {filepath}")

if filepath_.is_dir() and not any(filepath_.glob("*.safetensors")):
raise FileNotFoundError(f"No .safetensors files found in directory: {filepath}")

if filepath_.is_file() and not filepath_.suffix == ".safetensors":
raise ValueError(f"File must be a .safetensors file: {filepath}")
63 changes: 63 additions & 0 deletions tests/test_utils/test_safetensors_load.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest
from compressed_tensors.utils.safetensors_load import validate_safetensors_file_path


@pytest.fixture
def temp_dir(tmp_path):
return tmp_path / "subdirectory"


@pytest.fixture
def safetensors_file(temp_dir):
temp_dir.mkdir(exists_ok=True)
safetensors_filepath = temp_dir / "test.safetensors"
safetensors_filepath.write_text("content")
return safetensors_filepath


@pytest.fixture
def non_safetensors_file(temp_dir):
temp_dir.mkdir(exists_ok=True)
non_safetensors_filepath = temp_dir / "test.txt"
non_safetensors_filepath.write_text("content")
return non_safetensors_filepath


def test_validate_safetensors_file_path_file_not_found():
with pytest.raises(FileNotFoundError):
validate_safetensors_file_path("nonexistent_file.safetensors")


def test_validate_safetensors_file_path_no_safetensors_files_in_directory(temp_dir):
temp_dir.mkdir()
with pytest.raises(FileNotFoundError):
validate_safetensors_file_path(str(temp_dir))


def test_validate_safetensors_file_path_file_is_not_safetensors(non_safetensors_file):
with pytest.raises(ValueError):
validate_safetensors_file_path(str(non_safetensors_file))


def test_validate_safetensors_file_path_valid_safetensors_file(safetensors_file):
validate_safetensors_file_path(str(safetensors_file))


def test_validate_safetensors_file_path_valid_directory_with_safetensors_files(
temp_dir, safetensors_file
):
validate_safetensors_file_path(str(temp_dir))

0 comments on commit 282d96e

Please sign in to comment.