Skip to content

Commit

Permalink
Add progress
Browse files Browse the repository at this point in the history
  • Loading branch information
rahul-tuli committed Jun 14, 2024
1 parent ea731b4 commit 9fb97cd
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 15 deletions.
63 changes: 52 additions & 11 deletions src/compressed_tensors/utils/converters/converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,16 @@
)
from safetensors import safe_open
from safetensors.torch import save_file
from tqdm import tqdm


StateDictType = Union[Dict[str, torch.Tensor], str, Path]
TransformationType = Callable[[Dict[str, torch.Tensor]], Dict[str, torch.Tensor]]

_LOGGER: logging.Logger = logging.getLogger(__name__)


class ConverterNames(Enum):
class ConverterNames(str, Enum):
EXLLAMA_TO_COMPRESSED_TENSOR = "exllama_to_compressed_tensor"


Expand Down Expand Up @@ -73,21 +75,27 @@ def convert_from_safetensors(

filepath_: Path = Path(filepath)
if not save_dir:
save_dir = "compressed_tensors_model"
save_dir: str = "compressed_tensors_model"

save_dir_: Path = Path(save_dir)
save_dir_.mkdir(exist_ok=True, parents=True)

metadata = {"format": "pt", "source": "Created by SparseML"}
# transform and save the state_dict
if filepath_.is_dir():
tqdm.write(f"Converting directory: {filepath}")
tqdm.write(f"Found: {len(list(filepath_.glob('*.safetensors')))} .safetensors files")
for file in filepath_.glob("*.safetensors"):
_LOGGER.info(f"Loading file: {file}")
tqdm.write(f"Converting file: {file.name}")
new_state_dict = {}
state_dict: Iterable[StateDictType] = load_safetensors_state_dict(
file, by_layers=True
)
for layer_state_dict in state_dict:
layer_progress_bar = tqdm(state_dict, total=layer_count(file), desc="Converting layers")
for layer_state_dict in layer_progress_bar:
layer_name = list(layer_state_dict.keys())[0][:len("model.layers.0")]
layer_progress_bar.set_description(f"Converting layer {layer_name}")
layer_progress_bar.update()
new_state_dict.update(
cls.translate(state_dict=layer_state_dict, **kwargs)
)
Expand Down Expand Up @@ -126,7 +134,7 @@ def transformations(cls) -> Iterable[TransformationType]:
raise NotImplementedError()


@BaseConverter.register(name=ConverterNames.EXLLAMA_TO_COMPRESSED_TENSOR.value)
@BaseConverter.register(name=ConverterNames.EXLLAMA_TO_COMPRESSED_TENSOR)
class ExllamaToCompressedTensorConverter(BaseConverter):
"""
A converter that applies transformations to the state_dict of a autogptq
Expand Down Expand Up @@ -183,16 +191,49 @@ def _update_quantization_config(source_dir: Path, dest_dir: Path):
:param source_dir: The directory containing the original config.json file
:param dest_dir: The directory to save the updated config.json file
"""
from sparseml.transformers import SparseAutoConfig
from transformers import AutoConfig

config = SparseAutoConfig.from_pretrained(source_dir)
config = AutoConfig.from_pretrained(source_dir)

if hasattr(config, "quantization_config"):
_LOGGER.info("Updating quantization config...")
delattr(config, "quantization_config")
quantization_config = config.quantization_config
config.quantization_config = _convert_to_compressed_tensors_config(quantization_config)
config.save_pretrained(dest_dir)


def _convert_to_compressed_tensors_config(quantization_config):
"""
Converts the quantization_config attribute from a config.json file
to a dictionary
:param quantization_config: The quantization_config attribute from a config.json file
:return: The quantization_config as a dictionary
"""
compressed_tensor_config = ...
return compressed_tensor_config

def layer_count(file_path: str) -> int:
"""
Count the number of layers in a safetensors file
:param file_path: path to the safetensors file
:return: number of layers in the safetensors file
"""
with safe_open(file_path, framework="pt", device="cpu") as f:
keys = sorted(f.keys())

last_layer_name = None
layer_count = 0
for key in keys:
layer_name = key[:len("model.layers.0")]
if layer_name != last_layer_name:
last_layer_name = layer_name
layer_count += 1
return layer_count



def load_safetensors_state_dict(
file_path: str, by_layers: bool = True
) -> Iterator[Tuple[str, Dict[str, torch.Tensor]]]:
Expand All @@ -201,7 +242,7 @@ def load_safetensors_state_dict(
:param file_path: path to the safetensors file
:param by_layers: if True, return a iterator with dictionary of safetensors
data by layers
data by layers. Default is True
:return: Iterator of dictionary of safetensors data or iterator of
dictionaries by layers
"""
Expand All @@ -210,11 +251,11 @@ def load_safetensors_state_dict(
current_layer = None
layer_data = {}
for key in sorted(f.keys()):
layer_name, param_name = key.split(".", 1)
layer_name = key[:len("model.layers.0")]
if current_layer is None:
current_layer = layer_name
elif layer_name != current_layer:
yield current_layer, layer_data
yield layer_data
current_layer = layer_name
layer_data = {}
layer_data[key] = f.get_tensor(key)
Expand Down
3 changes: 1 addition & 2 deletions src/compressed_tensors/utils/converters/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,11 @@

from compressed_tensors.utils.converters.converters import BaseConverter, ConverterNames


__all__ = ["convert_autogptq_checkpoint"]


def convert_autogptq_checkpoint(
old_checkpoint_path, new_checkpoint_path, **kwargs
old_checkpoint_path, new_checkpoint_path ,**kwargs
) -> str:
"""
Convert an autogptq checkpoint to a compressed tensor checkpoint
Expand Down
4 changes: 2 additions & 2 deletions src/compressed_tensors/utils/converters/transformations.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,9 @@
def _log_transformation(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
_LOGGER.info("Applying transformation: %s", func.__name__.upper())
_LOGGER.debug("Applying transformation: %s", func.__name__.upper())
return_value = func(*args, **kwargs)
_LOGGER.info("Transformation: %s complete", func.__name__.upper())
_LOGGER.debug("Transformation: %s complete", func.__name__.upper())
return return_value

return wrapper
Expand Down

0 comments on commit 9fb97cd

Please sign in to comment.