diff --git a/src/compressed_tensors/quantization/lifecycle/forward.py b/src/compressed_tensors/quantization/lifecycle/forward.py index 263796eb..739dac0c 100644 --- a/src/compressed_tensors/quantization/lifecycle/forward.py +++ b/src/compressed_tensors/quantization/lifecycle/forward.py @@ -190,7 +190,7 @@ def _process_quantization( if columns >= group_size: if columns % group_size != 0: raise ValueError( - "tesnor column shape must be divisble " + "tensor column shape must be divisible " f"by the given group_size {group_size}" ) for i in range(ceil(columns / group_size)): diff --git a/src/compressed_tensors/utils/__init__.py b/src/compressed_tensors/utils/__init__.py index 5bc0fec2..f0b267b7 100644 --- a/src/compressed_tensors/utils/__init__.py +++ b/src/compressed_tensors/utils/__init__.py @@ -13,4 +13,5 @@ # limitations under the License. # flake8: noqa +from .converters import * from .safetensors_load import * diff --git a/src/compressed_tensors/utils/converters/__init__.py b/src/compressed_tensors/utils/converters/__init__.py new file mode 100644 index 00000000..540036b0 --- /dev/null +++ b/src/compressed_tensors/utils/converters/__init__.py @@ -0,0 +1,17 @@ +# flake8: noqa + +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .main import * diff --git a/src/compressed_tensors/utils/converters/converters.py b/src/compressed_tensors/utils/converters/converters.py new file mode 100644 index 00000000..bbedc7d2 --- /dev/null +++ b/src/compressed_tensors/utils/converters/converters.py @@ -0,0 +1,271 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+import copy
+import logging
+import shutil
+from abc import ABC, abstractmethod
+from enum import Enum
+from pathlib import Path
+from typing import Callable, Dict, Iterable, Iterator, Union
+
+import torch
+from compressed_tensors.registry.registry import RegistryMixin
+from compressed_tensors.utils.converters.transformations import (
+    remove_unused_tensors,
+    transform_autogptq_weights_and_reshape_tensors,
+    transform_exllama_names,
+)
+from compressed_tensors.utils.safetensors_load import validate_safetensors_file_path
+from safetensors import safe_open
+from safetensors.torch import save_file
+from tqdm import tqdm
+
+
+StateDictType = Union[Dict[str, torch.Tensor], str, Path]
+TransformationType = Callable[[Dict[str, torch.Tensor]], Dict[str, torch.Tensor]]
+
+_LOGGER: logging.Logger = logging.getLogger(__name__)
+
+
+class ConverterNames(str, Enum):
+    AutoGPTQConverter: str = "exllama_to_compressed_tensor"
+
+
+class BaseConverter(ABC, RegistryMixin):
+    @classmethod
+    def translate(cls, state_dict: StateDictType, **kwargs) -> StateDictType:
+        """
+        Applies transformations to the state_dict
+
+        :param state_dict: The state_dict to apply transformations to
+        :param kwargs: Additional arguments to pass to the transformations
+        :return: The transformed state_dict
+        """
+        _LOGGER.info("Applying transformations...")
+        new_state_dict = copy.copy(state_dict)
+        for transformation in cls.transformations():
+            new_state_dict = transformation(new_state_dict, **kwargs)
+        return new_state_dict
+
+    @classmethod
+    def convert_from_safetensors(
+        cls, filepath: str, save_dir: str = None, **kwargs
+    ) -> str:
+        """
+        Convert a .safetensors file or directory of .safetensors files, applying
+        transformations to the state_dict and saving the new state_dict to a new
+        directory
+
+        :param filepath: The file path to the .safetensors file or directory
+            containing .safetensors files to convert
+        :param save_dir: The directory to save the converted state_dict to
+        :return: The directory where the converted state_dict was saved
+        """
+        validate_safetensors_file_path(filepath)
+
+        filepath_: Path = Path(filepath)
+        if not save_dir:
+            save_dir: str = "compressed_tensors_model"
+
+        save_dir_: Path = Path(save_dir)
+        save_dir_.mkdir(exist_ok=True, parents=True)
+
+        metadata = {"format": "pt", "source": "Created by SparseML"}
+        # transform and save the state_dict
+        if filepath_.is_dir():
+            tqdm.write(f"Converting directory: {filepath}")
+            tqdm.write(
+                f"Found: {len(list(filepath_.glob('*.safetensors')))} "
+                ".safetensors files"
+            )
+            for file in filepath_.glob("*.safetensors"):
+                tqdm.write(f"Converting file: {file.name}")
+                new_state_dict = {}
+                state_dict: Iterator[Dict[str, torch.Tensor]] = (
+                    load_safetensors_state_dict(file, by_layers=True)
+                )
+                layer_progress_bar = tqdm(
+                    state_dict, total=layer_count(file), desc="Converting layers"
+                )
+                for layer_state_dict in layer_progress_bar:
+                    # iterating over the tqdm wrapper already advances the bar;
+                    # only the description needs to be refreshed here
+                    layer_name = _layer_prefix(list(layer_state_dict.keys())[0])
+                    layer_progress_bar.set_description(f"Converting layer {layer_name}")
+                    new_state_dict.update(
+                        cls.translate(state_dict=layer_state_dict, **kwargs)
+                    )
+
+                if new_state_dict:
+                    # compress before saving
+                    # compressor = Compressor.load_from_registry(
+                    #    name=CompressionFormat.pack_quantized.value
+                    # )
+                    # new_state_dict = compressor.compress(new_state_dict)
+                    save_file(
+                        new_state_dict,
+                        filename=save_dir_ / file.name,
+                        metadata=metadata,
+                    )
+            _copy_non_safetensor_files_(filepath_, save_dir_)
+            # _update_quantization_config(filepath_, save_dir_)
+
+        elif filepath_.is_file():
+            new_state_dict = {}
+            state_dict: Iterator[Dict[str, torch.Tensor]] = (
+                load_safetensors_state_dict(filepath_, by_layers=True)
+            )
+            for layer_state_dict in state_dict:
+                new_state_dict.update(cls.translate(state_dict=layer_state_dict))
+
+            save_file(
+                new_state_dict, filename=save_dir_ / filepath_.name, metadata=metadata
+            )
+
+        return str(save_dir_)
+
+    @classmethod
+    @abstractmethod
+    def transformations(cls) -> Iterable[TransformationType]:
+        """
+        Returns an iterable of transformations that are applied in the converter;
+        each transformation should be a callable that takes a state_dict and
+        returns a transformed state_dict
+        """
+        raise NotImplementedError()
+
+
+@BaseConverter.register(name=ConverterNames.AutoGPTQConverter)
+class AutoGPTQConverter(BaseConverter):
+    """
+    A converter that applies transformations to the state_dict of an AutoGPTQ
+    quantized model to convert it to a compressed-tensors model
+
+    Transformations made:
+
+    -> Unpack AutoGPTQ 4 bit weight packing
+    -> Translate exllama names to compressed-tensors names
+    -> Pack 4 bit weights in the compressed-tensors format
+    -> Remove unused tensors
+    -> Update quantization config in config.json file
+
+    (the re-packing and config-update steps are currently commented out in
+    convert_from_safetensors, pending the compressor integration)
+    """
+
+    @classmethod
+    def transformations(cls):
+        return (
+            transform_autogptq_weights_and_reshape_tensors,
+            transform_exllama_names,
+            remove_unused_tensors,
+        )
+
+
+def _layer_prefix(key: str) -> str:
+    """
+    Return the "model.layers.N" prefix of a state_dict key. A fixed-width
+    slice such as key[: len("model.layers.0")] breaks for two-digit layer
+    indices (it merges layers 1 and 10), so the prefix is derived from the
+    dot-separated components instead
+
+    :param key: a state_dict key, e.g. "model.layers.12.self_attn.q_proj.qweight"
+    :return: the layer prefix, e.g. "model.layers.12"
+    """
+    return ".".join(key.split(".")[:3])
+
+
+def _copy_non_safetensor_files_(source_dir: Path, dest_dir: Path):
+    """
+    A helper function to copy all auxiliary files in a directory that are
+    not .safetensors files, for example (recipe.yaml, tokenizer files, ...).
+    config.json is deliberately skipped because it is rewritten by
+    _update_quantization_config
+
+    :param source_dir: The directory to copy files from
+    :param dest_dir: The directory to copy files to
+    """
+    for file in source_dir.glob("*"):
+        if file.suffix != ".safetensors" and file.name != "config.json":
+            _LOGGER.info(f"Copying file: {file} to {dest_dir}")
+            shutil.copy(file, dest_dir / file.name)
+
+
+def _update_quantization_config(source_dir: Path, dest_dir: Path):
+    """
+    Updates the config.json file in the destination directory by replacing
+    the quantization_config attribute
+
+    :param source_dir: The directory containing the original config.json file
+    :param dest_dir: The directory to save the updated config.json file
+    """
+    from transformers import AutoConfig
+
+    config = AutoConfig.from_pretrained(source_dir)
+
+    if hasattr(config, "quantization_config"):
+        _LOGGER.info("Updating quantization config...")
+        quantization_config = config.quantization_config
+        config.quantization_config = _convert_to_compressed_tensors_config(
+            quantization_config
+        )
+    config.save_pretrained(dest_dir)
+
+
+def _convert_to_compressed_tensors_config(quantization_config):
+    """
+    Converts the quantization_config attribute from a config.json file
+    to a dictionary
+
+    :param quantization_config: The quantization_config
+        attribute from a config.json file
+    :return: The quantization_config as a dictionary
+    """
+    # TODO: not yet implemented; see the sketch below for one possible shape
+    compressed_tensor_config = ...
+    return compressed_tensor_config
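+
+
+# NOTE: an illustrative sketch only (not used by the code above) of what the
+# converted quantization_config could look like, assuming a 4-bit,
+# group-quantized, pack-quantized compressed-tensors layout; the exact schema
+# is an assumption here and should be taken from compressed_tensors.quantization:
+#
+#     {
+#         "quant_method": "compressed-tensors",
+#         "format": "pack-quantized",
+#         "config_groups": {
+#             "group_0": {
+#                 "targets": ["Linear"],
+#                 "weights": {"num_bits": 4, "group_size": 128, "symmetric": False},
+#             }
+#         },
+#     }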
+
+
+def layer_count(file_path: str) -> int:
+    """
+    Count the number of layers in a safetensors file
+
+    :param file_path: path to the safetensors file
+    :return: number of layers in the safetensors file
+    """
+    with safe_open(file_path, framework="pt", device="cpu") as f:
+        keys = sorted(f.keys())
+
+        last_layer_name = None
+        layer_count = 0
+        for key in keys:
+            layer_name = _layer_prefix(key)
+            if layer_name != last_layer_name:
+                last_layer_name = layer_name
+                layer_count += 1
+        return layer_count
+
+
+def load_safetensors_state_dict(
+    file_path: str, by_layers: bool = True
+) -> Iterator[Dict[str, torch.Tensor]]:
+    """
+    Load a safetensors file from disk
+
+    :param file_path: path to the safetensors file
+    :param by_layers: if True, yield the safetensors data as one dictionary
+        per layer. Default is True
+    :return: an iterator over dictionaries of safetensors data, either one
+        per layer or a single dictionary for the whole file
+    """
+    with safe_open(file_path, framework="pt", device="cpu") as f:
+        if by_layers:
+            current_layer = None
+            layer_data = {}
+            for key in sorted(f.keys()):
+                layer_name = _layer_prefix(key)
+                if current_layer is None:
+                    current_layer = layer_name
+                elif layer_name != current_layer:
+                    yield layer_data
+                    current_layer = layer_name
+                    layer_data = {}
+                layer_data[key] = f.get_tensor(key)
+            if layer_data:
+                yield layer_data
+        else:
+            yield {key: f.get_tensor(key) for key in f.keys()}
diff --git a/src/compressed_tensors/utils/converters/main.py b/src/compressed_tensors/utils/converters/main.py
new file mode 100644
index 00000000..dd7516b4
--- /dev/null
+++ b/src/compressed_tensors/utils/converters/main.py
@@ -0,0 +1,40 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from compressed_tensors.utils.converters.converters import BaseConverter, ConverterNames
+
+
+__all__ = ["convert_autogptq_checkpoint"]
+
+
+def convert_autogptq_checkpoint(
+    old_checkpoint_path: str, new_checkpoint_path: str, **kwargs
+) -> str:
+    """
+    Convert an AutoGPTQ checkpoint to a compressed-tensors checkpoint
+
+    :param old_checkpoint_path: the path to the AutoGPTQ checkpoint
+    :param new_checkpoint_path: the path to save the converted
+        compressed-tensors checkpoint
+    :param kwargs: additional arguments to pass to the transformations
+    :return: the path to the new checkpoint
+    """
+    converter: BaseConverter = BaseConverter.load_from_registry(
+        ConverterNames.AutoGPTQConverter
+    )
+    checkpoint_path = converter.convert_from_safetensors(
+        old_checkpoint_path, new_checkpoint_path, **kwargs
+    )
+    return checkpoint_path
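+
+
+# Example usage (an illustrative sketch; the paths below are hypothetical):
+#
+#   from compressed_tensors.utils import convert_autogptq_checkpoint
+#
+#   new_dir = convert_autogptq_checkpoint(
+#       old_checkpoint_path="path/to/autogptq-checkpoint",
+#       new_checkpoint_path="path/to/converted-checkpoint",
+#   )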
diff --git a/src/compressed_tensors/utils/converters/transformations.py b/src/compressed_tensors/utils/converters/transformations.py
new file mode 100644
index 00000000..1bebfdae
--- /dev/null
+++ b/src/compressed_tensors/utils/converters/transformations.py
@@ -0,0 +1,240 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# flake8: noqa: F821
+
+import functools
+import logging
+from typing import Dict
+
+import numpy as np
+import torch
+from torch import Tensor
+
+
+_LOGGER = logging.getLogger(__name__)
+
+
+def _log_transformation(func):
+    @functools.wraps(func)
+    def wrapper(*args, **kwargs):
+        _LOGGER.debug("Applying transformation: %s", func.__name__.upper())
+        return_value = func(*args, **kwargs)
+        _LOGGER.debug("Transformation: %s complete", func.__name__.upper())
+        return return_value
+
+    return wrapper
+
+
+def is_gptq_quantization_target(key: str) -> bool:
+    """
+    Assumes self_attn and mlp are the only quantization targets
+    in model layers of the state_dict.
+
+    :param key: The key of the state_dict
+    :return: True if the key is a quantization target, False otherwise
+    """
+    return "model.layers" in key and ("self_attn" in key or "mlp" in key)
+
+
+@_log_transformation
+def transform_exllama_names(state_dict: Dict[str, Tensor]) -> Dict[str, Tensor]:
+    """
+    Transforms the exllama state_dict keys to be compatible with
+    SparseAutoModel classes.
+
+    The renames include:
+        - scales -> weight_scale
+        - qzeros -> weight_zero_point
+        - qweight -> weight_packed
+
+    Note: does not transform the actual tensor values
+
+    :pre-condition: The state_dict should be for a quantized model
+    :pre-condition: Targets only the weights of the self_attn and mlp nodes
+    :param state_dict: The quantized state_dict to be transformed
+    :return: The transformed state_dict
+    """
+
+    name_map: Dict[str, str] = {
+        ".scales": ".weight_scale",
+        ".qzeros": ".weight_zero_point",
+        ".qweight": ".weight_packed",
+    }
+
+    updated_state_dict = {}
+    for key, tensor in state_dict.items():
+        for suffix, replacement in name_map.items():
+            if key.endswith(suffix):
+                updated_state_dict[key.replace(suffix, replacement)] = tensor
+                break
+        else:
+            # no suffix matched; carry the entry over unchanged
+            updated_state_dict[key] = tensor
+    return updated_state_dict
+
+
+@_log_transformation
+def transform_autogptq_weights_and_reshape_tensors(
+    state_dict: Dict[str, Tensor]
+) -> Dict[str, Tensor]:
+    """
+    Transforms weights into the shapes and types required for the exllama
+    to compressed-tensors conversion
+
+    The transformations include:
+        - Record the unpacked shape of each packed qweight tensor as a
+          weight_shape tensor (the full unpack/dequantize step via
+          unpack_int32_into_fp32 is currently disabled below)
+        - Squeeze the scales tensor to [x] from [1, x]
+
+    :pre-condition: The state_dict should be for a quantized model
+    :pre-condition: The state_dict should have the bias and g_idx tensors added
+
+    :param state_dict: The state_dict to be transformed
+    :return: The transformed state_dict, with repacked and reshaped tensors
+    """
+
+    transformed_state_dict: Dict[str, Tensor] = {}
+
+    # auxiliary dict to store transformed weights
+    transformed_weights_dict: Dict[str, Tensor] = {}
+
+    # process qweights before scales and qzeros, because the order in which
+    # tensors are fetched is not guaranteed by our implementation
+    for key, tensor in state_dict.items():
+        if is_gptq_quantization_target(key) and key.endswith(".qweight"):
+            # fetch the tensors needed for dequantization; only the zero
+            # points are unpacked for now, the dequantization call is disabled
+            scales = state_dict[key.replace("qweight", "scales")]
+            qzeros = state_dict[key.replace("qweight", "qzeros")]
+            g_idx = state_dict[key.replace("qweight", "g_idx")]
+
+            zeros = unpack_zeros(qzeros)
+            # qweight = unpack_int32_into_fp32(
+            #    qweight=tensor,
+            #    scales=scales,
+            #    zeros=zeros,
+            #    g_idx=g_idx,
+            # )
+            # each packed int32 row holds 32 // 4 = 8 four-bit values
+            new_shape = torch.tensor([tensor.shape[0] * 8, tensor.shape[1]])
+            transformed_weights_dict[key] = tensor
+            transformed_weights_dict[key.replace("qweight", "weight_shape")] = new_shape
+
+    # transform scales
+    for key, tensor in state_dict.items():
+        if is_gptq_quantization_target(key) and key.endswith(".scales"):
+            # scales [1, x] should be reshaped to [x]
+            scales = tensor.squeeze(0)
+            transformed_state_dict[key] = scales
+        else:
+            transformed_state_dict[key] = tensor
+
+    # overwrite old weights with the new quantized weights
+    transformed_state_dict.update(transformed_weights_dict)
+
+    # auxiliary weights_dict not needed anymore
+    del transformed_weights_dict
+
+    return transformed_state_dict
+
+
+def unpack_zeros(qzeros: Tensor) -> Tensor:
+    """
+    Unpack the quantized zero points tensor from 32 bit integers into
+    4 bit integers.
+
+    :param qzeros: The packed zero points tensor of int32 dtype and shape [1, x]
+    :return: The unpacked zero points tensor of int32 dtype and shape [1, 8x]
+    """
+    bits = 4
+    qzeros = qzeros.numpy().astype(np.uint32)
+    intzeros = np.zeros(
+        (qzeros.shape[0], qzeros.shape[1] * 32 // bits), dtype=np.uint32
+    )
+
+    i = 0
+    col = 0
+    while col < intzeros.shape[1]:
+        if bits in [4]:
+            for j in range(i, min(i + (32 // bits), intzeros.shape[1])):
+                intzeros[:, j] = (qzeros[:, col] >> (bits * (j - i))) & 0xF
+            i += 32 // bits
+            col += 1
+        else:
+            raise NotImplementedError("Only 4 bits are supported.")
+
+    intzeros = intzeros.astype(np.int32)
+    intzeros = torch.from_numpy(intzeros)
+
+    return intzeros
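+
+
+# Worked example of the nibble unpacking above (illustrative): for a packed
+# int32 value 0x76543210, the expression (value >> (4 * j)) & 0xF for
+# j = 0..7 yields [0, 1, 2, 3, 4, 5, 6, 7], i.e. the least significant
+# nibble becomes the first unpacked element.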
+
+
+def unpack_int32_into_fp32(
+    qweight: Tensor, scales: Tensor, zeros: Tensor, g_idx: Tensor
+) -> Tensor:
+    """
+    Unpack the quantized weight tensor from 32 bit integers into 4 bit integers,
+    and then dequantize them using the scales, zeros, and g_idx tensors.
+
+    :param qweight: The quantized weight tensor of int32 dtype and shape [x, y]
+    :param scales: The scales tensor
+    :param zeros: The zero points tensor
+    :param g_idx: The group index tensor
+    :return: The dequantized weight tensor of shape [y, 8x]
+    """
+    bits = 4
+    qweight = qweight.numpy().astype(np.uint32)
+    intweight = np.zeros(
+        (qweight.shape[0] * 32 // bits, qweight.shape[1]), dtype=np.uint32
+    )
+
+    i = 0
+    row = 0
+    while row < intweight.shape[0]:
+        if bits in [4]:
+            for j in range(i, min(i + (32 // bits), intweight.shape[0])):
+                intweight[j] = (qweight[row] >> (bits * (j - i))) & 0xF
+            i += 32 // bits
+            row += 1
+        else:
+            raise NotImplementedError("Only 4 bits are supported.")
+
+    intweight = torch.from_numpy(intweight.astype(np.int32))
+    intweight = intweight.t().contiguous()
+
+    scales = scales.t().contiguous()
+    zeros = zeros.t().contiguous()
+    scale_zeros = zeros * scales
+    scales = scales.clone().half()
+
+    # dequantize column by column: w = q * scale[g] - zero[g] * scale[g]
+    weight = []
+    infeatures = intweight.shape[1]
+    for idx in range(infeatures):
+        weight.append(
+            (
+                intweight[:, idx].float() * scales[:, g_idx[idx]]
+                - scale_zeros[:, g_idx[idx]]
+            )[:, None]
+        )
+    weight = torch.cat(weight, dim=1)
+
+    return weight
+
+
+def remove_unused_tensors(state_dict: Dict[str, Tensor]) -> Dict[str, Tensor]:
+    """
+    Remove tensors that are unused after conversion (currently the GPTQ
+    g_idx tensors), keeping everything else
+
+    :param state_dict: The state_dict to be cleaned
+    :return: The cleaned state_dict
+    """
+    # keep every entry except the g_idx tensors of the quantization targets,
+    # which are no longer needed once the weights have been transformed
+    return {
+        key: tensor
+        for key, tensor in state_dict.items()
+        if not (is_gptq_quantization_target(key) and key.endswith(".g_idx"))
+    }
diff --git a/src/compressed_tensors/utils/safetensors_load.py b/src/compressed_tensors/utils/safetensors_load.py
index 9cdac782..5c229c39 100644
--- a/src/compressed_tensors/utils/safetensors_load.py
+++ b/src/compressed_tensors/utils/safetensors_load.py
@@ -16,6 +16,7 @@
 import os
 import re
 import struct
+from pathlib import Path
 from typing import Dict, List, Optional
 
 from safetensors import safe_open
@@ -32,6 +33,7 @@
     "get_nested_weight_mappings",
     "get_quantization_state_dict",
     "is_quantization_param",
+    "validate_safetensors_file_path",
 ]
@@ -236,3 +238,25 @@ def is_quantization_param(name: str) -> bool:
             return True
 
     return False
+
+
+def validate_safetensors_file_path(filepath: str):
+    """
+    Given a file path, it is valid if:
+        - The file exists
+        - The file is either a single .safetensors file or a
+          directory containing .safetensors files
+
+    :param filepath: A string file path to validate
+    """
+
+    filepath_: Path = Path(filepath)
+
+    if not filepath_.exists():
+        raise FileNotFoundError(f"File not found: {filepath}")
+
+    if filepath_.is_dir() and not any(filepath_.glob("*.safetensors")):
+        raise FileNotFoundError(f"No .safetensors files found in directory: {filepath}")
+
+    if filepath_.is_file() and filepath_.suffix != ".safetensors":
+        raise ValueError(f"File must be a .safetensors file: {filepath}")
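+
+
+# Illustrative behavior (hypothetical paths, each assumed to exist):
+#   validate_safetensors_file_path("model.safetensors")   # passes
+#   validate_safetensors_file_path("checkpoint_dir")      # passes if it holds *.safetensors
+#   validate_safetensors_file_path("weights.bin")         # raises ValueError
+#   validate_safetensors_file_path("missing.safetensors") # raises FileNotFoundError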
diff --git a/tests/test_utils/converters/__init__.py b/tests/test_utils/converters/__init__.py
new file mode 100644
index 00000000..0c44f887
--- /dev/null
+++ b/tests/test_utils/converters/__init__.py
@@ -0,0 +1,13 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/tests/test_utils/converters/test_imports.py b/tests/test_utils/converters/test_imports.py
new file mode 100644
index 00000000..0a3a6202
--- /dev/null
+++ b/tests/test_utils/converters/test_imports.py
@@ -0,0 +1,57 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import pytest
+
+
+@pytest.fixture(scope="module")
+def conversion_function():
+    from compressed_tensors.utils.converters import convert_autogptq_checkpoint
+
+    return convert_autogptq_checkpoint
+
+
+@pytest.mark.parametrize(
+    "parent_module, function_name",
+    [
+        ("compressed_tensors.utils", "convert_autogptq_checkpoint"),
+        ("compressed_tensors.utils.converters.main", "convert_autogptq_checkpoint"),
+    ],
+)
+def test_convert_function_is_importable(parent_module, function_name):
+    import importlib
+
+    module = importlib.import_module(parent_module)
+    assert hasattr(
+        module, function_name
+    ), f"{function_name} is not found in {parent_module}"
+
+
+def test_conversion_function_accepts_correct_arguments(conversion_function):
+    import inspect
+
+    sig = inspect.signature(conversion_function)
+    params = sig.parameters
+    assert (
+        "old_checkpoint_path" in params
+    ), "Function does not accept 'old_checkpoint_path' argument"
+    assert (
+        "new_checkpoint_path" in params
+    ), "Function does not accept 'new_checkpoint_path' argument"
+
+    # check keyword arguments are also accepted
+    # (might be needed to configure specific transformations)
+    assert any(
+        param.kind == param.VAR_KEYWORD for param in params.values()
+    ), "Function does not accept **kwargs"
diff --git a/tests/test_utils/test_safetensors_load.py b/tests/test_utils/test_safetensors_load.py
new file mode 100644
index 00000000..df53796e
--- /dev/null
+++ b/tests/test_utils/test_safetensors_load.py
@@ -0,0 +1,63 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import pytest
+from compressed_tensors.utils.safetensors_load import validate_safetensors_file_path
+
+
+@pytest.fixture
+def temp_dir(tmp_path):
+    return tmp_path / "subdirectory"
+
+
+@pytest.fixture
+def safetensors_file(temp_dir):
+    temp_dir.mkdir(exist_ok=True)
+    safetensors_filepath = temp_dir / "test.safetensors"
+    safetensors_filepath.write_text("content")
+    return safetensors_filepath
+
+
+@pytest.fixture
+def non_safetensors_file(temp_dir):
+    temp_dir.mkdir(exist_ok=True)
+    non_safetensors_filepath = temp_dir / "test.txt"
+    non_safetensors_filepath.write_text("content")
+    return non_safetensors_filepath
+
+
+def test_validate_safetensors_file_path_file_not_found():
+    with pytest.raises(FileNotFoundError):
+        validate_safetensors_file_path("nonexistent_file.safetensors")
+
+
+def test_validate_safetensors_file_path_no_safetensors_files_in_directory(temp_dir):
+    temp_dir.mkdir()
+    with pytest.raises(FileNotFoundError):
+        validate_safetensors_file_path(str(temp_dir))
+
+
+def test_validate_safetensors_file_path_file_is_not_safetensors(non_safetensors_file):
+    with pytest.raises(ValueError):
+        validate_safetensors_file_path(str(non_safetensors_file))
+
+
+def test_validate_safetensors_file_path_valid_safetensors_file(safetensors_file):
+    validate_safetensors_file_path(str(safetensors_file))
+
+
+def test_validate_safetensors_file_path_valid_directory_with_safetensors_files(
+    temp_dir, safetensors_file
+):
+    validate_safetensors_file_path(str(temp_dir))