diff --git a/src/sparsetensors/__init__.py b/src/sparsetensors/__init__.py index 4504af03..3eefa2c8 100644 --- a/src/sparsetensors/__init__.py +++ b/src/sparsetensors/__init__.py @@ -12,8 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -# flake8: noqa -SPARSITY_CONFIG_NAME = "sparsity_config" +from .base import * +# flake8: noqa from .compressors import * from .config import * +from .utils import * diff --git a/src/sparsetensors/base.py b/src/sparsetensors/base.py new file mode 100644 index 00000000..f01a055f --- /dev/null +++ b/src/sparsetensors/base.py @@ -0,0 +1,15 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +SPARSITY_CONFIG_NAME = "sparsity_config" diff --git a/src/sparsetensors/compressors/__init__.py b/src/sparsetensors/compressors/__init__.py index e8a36527..1c7362eb 100644 --- a/src/sparsetensors/compressors/__init__.py +++ b/src/sparsetensors/compressors/__init__.py @@ -16,4 +16,4 @@ from .base import ModelCompressor from .dense import DenseCompressor -from .sparse_bitmask import BitmaskCompressor +from .sparse_bitmask import BitmaskCompressor, BitmaskTensor diff --git a/src/sparsetensors/compressors/base.py b/src/sparsetensors/compressors/base.py index 0487efd9..7b013827 100644 --- a/src/sparsetensors/compressors/base.py +++ b/src/sparsetensors/compressors/base.py @@ -15,13 +15,13 @@ import operator from typing import Dict, Generator, Tuple +from sparsetensors.base import SPARSITY_CONFIG_NAME +from sparsetensors.config import CompressionConfig from sparsezoo.utils.registry import RegistryMixin from torch import Tensor from torch.nn import Module, Parameter from tqdm import tqdm -from . import SPARSITY_CONFIG_NAME - __all__ = ["ModelCompressor"] @@ -33,7 +33,7 @@ class ModelCompressor(RegistryMixin): :param config: config specifying compression parameters """ - def __init__(self, config: "CompressionConfig"): # noqa + def __init__(self, config: CompressionConfig): self.config = config def compress(self, model_state: Dict[str, Tensor]) -> Dict[str, Tensor]: @@ -66,17 +66,21 @@ def replace_layer(param_name: str, data: Tensor, model: Module): :param model: pytorch model to insert data into """ model_device = operator.attrgetter(param_name)(model).device - set_layer(param_name, Parameter(data.to(model_device)), model) # noqa TODO - - def overwrite_weights(self, pretrained_model_name_or_path: str, model: Module): + new_param = Parameter(data.to(model_device)) + # TODO: Two for loops? + for name, param in model.named_parameters(): + if name == param_name: + param.data = new_param.data + return + + def overwrite_weights(self, model_path: str, model: Module): """ - Overwrites the weights in model with weights decompressed from - pretrained_model_name_or_path + Overwrites the weights in model with weights decompressed from model_path - :param pretrained_model_name_or_path: path to compressed weights + :param model_path: path to compressed weights :param model: pytorch model to load decompressed weights into """ - dense_gen = self.decompress(pretrained_model_name_or_path) + dense_gen = self.decompress(model_path) for name, data in tqdm(dense_gen, desc="Decompressing model"): ModelCompressor.replace_layer(name, data, model) setattr(model, SPARSITY_CONFIG_NAME, self.config) diff --git a/src/sparsetensors/compressors/utils/helpers.py b/src/sparsetensors/compressors/utils/helpers.py deleted file mode 100644 index 24863f0e..00000000 --- a/src/sparsetensors/compressors/utils/helpers.py +++ /dev/null @@ -1,64 +0,0 @@ -# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -from typing import Optional - -from sparsetensors.compressors import ModelCompressor -from sparsetensors.config import CompressionConfig -from torch.nn import Module -from transformers import AutoConfig - -from . import SPARSITY_CONFIG_NAME - - -__all__ = ["infer_compressor_from_model_config", "set_layer"] - - -def infer_compressor_from_model_config( - pretrained_model_name_or_path: str, -) -> Optional[ModelCompressor]: - """ - Given a path to a model config, extract a sparsity config if it exists and return - the associated ModelCompressor - - :param pretrained_model_name_or_path: path to model config on disk or HF hub - :return: matching compressor if config contains a sparsity config - """ - config = AutoConfig.from_pretrained(pretrained_model_name_or_path) - sparsity_config = getattr(config, SPARSITY_CONFIG_NAME, None) - if sparsity_config is None: - return None - - format = sparsity_config.get("format") - sparsity_config = CompressionConfig.load_from_registry(format, **sparsity_config) - compressor = ModelCompressor.load_from_registry(format, config=sparsity_config) - return compressor - - -def set_layer(target: str, layer: Module, module: Module) -> Module: - target = fix_fsdp_module_name(target) # noqa TODO - with summon_full_params_context(module): # noqa TODO - # importing here to avoid circular import - from sparseml.utils.fsdp.helpers import maybe_get_wrapped # noqa TODO - - parent_target = ".".join(target.split(".")[:-1]) - if parent_target != "": - parent_layer = get_layer(parent_target, module)[1] # noqa TODO - else: - parent_layer = maybe_get_wrapped(module) - old_layer = getattr(parent_layer, target.split(".")[-1]) - setattr(parent_layer, target.split(".")[-1], layer) - - return old_layer diff --git a/src/sparsetensors/config/__init__.py b/src/sparsetensors/config/__init__.py index 6465c3c6..ff83f5af 100644 --- a/src/sparsetensors/config/__init__.py +++ b/src/sparsetensors/config/__init__.py @@ -13,7 +13,6 @@ # limitations under the License. # flake8: noqa - -from .base import CompressionConfig -from .dense import DenseSparsityConfig -from .sparse_bitmask import BitmaskConfig +from .base import * +from .dense import * +from .sparse_bitmask import * diff --git a/src/sparsetensors/config/base.py b/src/sparsetensors/config/base.py index ddf75962..7f428fc7 100644 --- a/src/sparsetensors/config/base.py +++ b/src/sparsetensors/config/base.py @@ -15,8 +15,7 @@ from typing import Optional from pydantic import BaseModel -from sparsezoo.utils.registry import ModuleSparsificationInfo, RegistryMixin -from torch.nn import Module +from sparsezoo.utils.registry import RegistryMixin __all__ = ["CompressionConfig"] @@ -35,55 +34,3 @@ class CompressionConfig(RegistryMixin, BaseModel): format: str global_sparsity: Optional[float] = 0.0 sparsity_structure: Optional[str] = "unstructured" - - @staticmethod - def infer_global_sparsity(model: Module) -> float: - """ - Calculates the global percentage of sparse zero weights in the model - - :param model: pytorch model to infer sparsity of - :return: global sparsity of model - """ - info = ModuleSparsificationInfo(model) - global_sparsity = info.params_sparse_percent - return global_sparsity - - # TODO: Move infer_sparsity_structure to sparseml - - @staticmethod - def infer_config_from_model( - model: Module, compress: bool = False - ) -> Optional["CompressionConfig"]: - """ - Determines compression type and informational parameters for a given model - - :param model: pytorch model to calculate sparsity config for - :param compress: whether or not to compress the model on disk - :return: compression config inferred from the model - """ - - global_sparsity = CompressionConfig.infer_global_sparsity(model) - - if global_sparsity < 0.05: - return None - - sparsity_structure = CompressionConfig.infer_sparsity_structure() - if compress: - format = "sparse_bitmask" - else: - format = "dense_sparsity" - - return CompressionConfig.load_from_registry( - format, - global_sparsity=global_sparsity, - sparsity_structure=sparsity_structure, - ) - - def fill_config_details(self, model: Module): - """ - Fills in informational sparsity parameters from a given model - - :param model: pytorch model to infer config parameters from - """ - self.global_sparsity = CompressionConfig.infer_global_sparsity(model) - self.sparsity_structure = CompressionConfig.infer_sparsity_structure() diff --git a/src/sparsetensors/config/sparse_bitmask.py b/src/sparsetensors/config/sparse_bitmask.py index c5663a75..d17c6a1a 100644 --- a/src/sparsetensors/config/sparse_bitmask.py +++ b/src/sparsetensors/config/sparse_bitmask.py @@ -14,7 +14,7 @@ from typing import Optional -from sparsetensors.config import CompressionConfig +from sparsetensors.config.base import CompressionConfig __all__ = ["BitmaskConfig"] diff --git a/src/sparsetensors/utils/__init__.py b/src/sparsetensors/utils/__init__.py index 99c02863..5bc0fec2 100644 --- a/src/sparsetensors/utils/__init__.py +++ b/src/sparsetensors/utils/__init__.py @@ -13,5 +13,4 @@ # limitations under the License. # flake8: noqa -from .compress_save import * from .safetensors_load import * diff --git a/src/sparsetensors/utils/compress_save.py b/src/sparsetensors/utils/compress_save.py deleted file mode 100644 index 8fcb657a..00000000 --- a/src/sparsetensors/utils/compress_save.py +++ /dev/null @@ -1,140 +0,0 @@ -# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import json -import logging -import os -import weakref -from functools import wraps -from typing import Optional - -from sparsetensors.compressors import ModelCompressor -from sparsetensors.config import CompressionConfig -from transformers import PreTrainedModel -from transformers.file_utils import CONFIG_NAME - -from . import SPARSITY_CONFIG_NAME - - -_LOGGER = logging.getLogger(__name__) - -__all__ = ["modify_save_pretrained"] - - -def modify_save_pretrained(model: PreTrainedModel): - """ - Overrides a PreTrainedModel's save_pretrained() method with a wrapped version that - supports compression - """ - - def save_pretrained_compressed(save_pretrained_method): - if getattr(save_pretrained_method, "_overridden", False): - # `model.save_pretrained` has already been replaced, return. - return save_pretrained_method - - # Keep a weak reference to the model class and unbound save_pretrained - # method so we can call the original - model_ref = weakref.ref(save_pretrained_method.__self__) - original_save_pretrained = save_pretrained_method.__func__ - model_class = model_ref().__class__ - del save_pretrained_method - - @wraps(original_save_pretrained) - def save_pretrained_wrapper( - save_directory: str, - sparsity_config: Optional[CompressionConfig] = None, - save_compressed: bool = False, - skip_compression_stats: bool = False, - **kwargs, - ): - """ - Wrapper around PreTrainedModel.save_pretrained(), adds functionality for - saving models in a compressed format on disk. The compression format is - saved to the model's config file - - :param save_directory: output directory to save model to - :param sparsity_config: optional sparsity config to compress model with, - if no config is provided it will be inferred from the model - :param save_compresed: whether or not to compress the model on disk - :param skip_compression_stats: whether to skip the calculation of - compression statistics (such as global sparsity and sparsity structure) when - saving a model in dense format - :param kwargs: additional kwargs to pass on to model.save_pretrained - """ - model = model_ref() - - if qat_active(model): # noqa TODO - _LOGGER.info( - "Compression for quantized models is not yet supported. Save will " - "be run without compression and no sparsity statistics will be " - "calculated." - ) - return original_save_pretrained.__get__(model, model_class)( - save_directory, **kwargs - ) - - if sparsity_config is not None: - sparsity_config.fill_config_details(model) - elif not skip_compression_stats: - # try to infer a sparsity config from the model if none is provided - _LOGGER.info( - "Inferring a sparsity configuration requires a global sparsity " - "calculation. This can be costly for large models. To skip the " - "calculation of compression statistics set " - "skip_compression_stats=True" - ) - sparsity_config = CompressionConfig.infer_config_from_model( - model, compress=save_compressed - ) - - if sparsity_config is None: - # model is not sparse, save as dense - return original_save_pretrained.__get__(model, model_class)( - save_directory, **kwargs - ) - - # if we've gotten to this point we have a config so we can run compression - kwargs["safe_serialization"] = True - compressor = ModelCompressor.load_from_registry( - sparsity_config.format, config=sparsity_config - ) - - # state_dict gets passed in as a kwarg for FSDP models - state_dict = kwargs.get("state_dict", None) - if state_dict is None: - state_dict = model.state_dict() - - # make sure we're on the main process when saving - if state_dict is not None and len(state_dict) > 0: - compressed_state_dict = compressor.compress(state_dict) - kwargs["state_dict"] = compressed_state_dict - - original_save_pretrained.__get__(model, model_class)( - save_directory, **kwargs - ) - sparsity_config_data = sparsity_config.dict() - config_file_path = os.path.join(save_directory, CONFIG_NAME) - - # add the sparsity config to the model's config file - with open(config_file_path, "r") as config_file: - config_data = json.load(config_file) - config_data[SPARSITY_CONFIG_NAME] = sparsity_config_data - with open(config_file_path, "w") as config_file: - json.dump(config_data, config_file, indent=2, sort_keys=True) - - save_pretrained_wrapper._overriden = True - return save_pretrained_wrapper - - # wrap save_pretrained - model.save_pretrained = save_pretrained_compressed(model.save_pretrained) diff --git a/src/sparsetensors/utils/safetensors_load.py b/src/sparsetensors/utils/safetensors_load.py index 7defda7e..c82d3e43 100644 --- a/src/sparsetensors/utils/safetensors_load.py +++ b/src/sparsetensors/utils/safetensors_load.py @@ -16,13 +16,14 @@ import os import re import struct -from typing import Dict, List +from typing import Dict, List, Optional -from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME +from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, cached_file __all__ = [ "get_safetensors_header", + "get_safetensors_folder", "match_param_name", "merge_names", "get_weight_mappings", @@ -147,3 +148,45 @@ def get_nested_weight_mappings( nested_weight_mappings[dense_param][param_name] = weight_mappings[key] return nested_weight_mappings + + +def get_safetensors_folder( + pretrained_model_name_or_path: str, cache_dir: Optional[str] = None +) -> str: + """ + Given a Hugging Face stub or a local path, return the folder containing the + safetensors weight files + + :param pretrained_model_name_or_path: local path to model or HF stub + :param cache_dir: optional cache dir to search through, if none is specified the + model will be searched for in the default TRANSFORMERS_CACHE + :return: local folder containing model data + """ + if os.path.exists(pretrained_model_name_or_path): + # argument is a path to a local folder + return pretrained_model_name_or_path + + safetensors_path = cached_file( + pretrained_model_name_or_path, + SAFE_WEIGHTS_NAME, + cache_dir=cache_dir, + _raise_exceptions_for_missing_entries=False, + ) + index_path = cached_file( + pretrained_model_name_or_path, + SAFE_WEIGHTS_INDEX_NAME, + cache_dir=cache_dir, + _raise_exceptions_for_missing_entries=False, + ) + if safetensors_path is not None: + # found a single cached safetensors file + return os.path.split(safetensors_path)[0] + if index_path is not None: + # found a cached safetensors weight index file + return os.path.split(index_path)[0] + + # model weights could not be found locally or cached from HF Hub + raise ValueError( + "Could not locate safetensors weight or index file from " + f"{pretrained_model_name_or_path}." + ) diff --git a/tests/test_bitmask.py b/tests/test_bitmask.py index d56124a1..b5bca142 100644 --- a/tests/test_bitmask.py +++ b/tests/test_bitmask.py @@ -18,8 +18,7 @@ import pytest import torch from safetensors.torch import save_file -from sparsetensors import BitmaskCompressor, BitmaskConfig -from sparsetensors.compressors.sparse_bitmask import BitmaskTensor +from sparsetensors import BitmaskCompressor, BitmaskConfig, BitmaskTensor @pytest.mark.parametrize( @@ -33,7 +32,7 @@ ) def test_bitmask_sizes(shape, sparsity, dtype): test_tensor = torch.rand(shape, dtype=dtype) - mask = (test_tensor.abs() < (1 - sparsity)).int() # noqa + mask = (test_tensor.abs() < (1 - sparsity)).int() test_tensor *= mask dense_state_dict = {"dummy.weight": test_tensor} diff --git a/tests/test_registries.py b/tests/test_registry.py similarity index 100% rename from tests/test_registries.py rename to tests/test_registry.py