some more migrations
dbogunowicz committed Apr 8, 2024
1 parent 4f606f4 commit 8008cf5
Showing 13 changed files with 85 additions and 282 deletions.
5 changes: 3 additions & 2 deletions src/sparsetensors/__init__.py
@@ -12,8 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.

# flake8: noqa
SPARSITY_CONFIG_NAME = "sparsity_config"
from .base import *

# flake8: noqa
from .compressors import *
from .config import *
from .utils import *
15 changes: 15 additions & 0 deletions src/sparsetensors/base.py
@@ -0,0 +1,15 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

SPARSITY_CONFIG_NAME = "sparsity_config"
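
For reference, a minimal sketch (not part of the commit) of how the relocated constant can be consumed once the wildcard re-export in __init__.py picks it up:

from sparsetensors import SPARSITY_CONFIG_NAME  # re-exported via `from .base import *`

assert SPARSITY_CONFIG_NAME == "sparsity_config"
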
2 changes: 1 addition & 1 deletion src/sparsetensors/compressors/__init__.py
@@ -16,4 +16,4 @@

from .base import ModelCompressor
from .dense import DenseCompressor
from .sparse_bitmask import BitmaskCompressor
from .sparse_bitmask import BitmaskCompressor, BitmaskTensor
24 changes: 14 additions & 10 deletions src/sparsetensors/compressors/base.py
@@ -15,13 +15,13 @@
import operator
from typing import Dict, Generator, Tuple

from sparsetensors.base import SPARSITY_CONFIG_NAME
from sparsetensors.config import CompressionConfig
from sparsezoo.utils.registry import RegistryMixin
from torch import Tensor
from torch.nn import Module, Parameter
from tqdm import tqdm

from . import SPARSITY_CONFIG_NAME


__all__ = ["ModelCompressor"]

@@ -33,7 +33,7 @@ class ModelCompressor(RegistryMixin):
:param config: config specifying compression parameters
"""

def __init__(self, config: "CompressionConfig"): # noqa
def __init__(self, config: CompressionConfig):
self.config = config

def compress(self, model_state: Dict[str, Tensor]) -> Dict[str, Tensor]:
@@ -66,17 +66,21 @@ def replace_layer(param_name: str, data: Tensor, model: Module):
:param model: pytorch model to insert data into
"""
model_device = operator.attrgetter(param_name)(model).device
set_layer(param_name, Parameter(data.to(model_device)), model) # noqa TODO

def overwrite_weights(self, pretrained_model_name_or_path: str, model: Module):
new_param = Parameter(data.to(model_device))
# TODO: Two for loops?
for name, param in model.named_parameters():
if name == param_name:
param.data = new_param.data
return

def overwrite_weights(self, model_path: str, model: Module):
"""
Overwrites the weights in model with weights decompressed from
pretrained_model_name_or_path
Overwrites the weights in model with weights decompressed from model_path
:param pretrained_model_name_or_path: path to compressed weights
:param model_path: path to compressed weights
:param model: pytorch model to load decompressed weights into
"""
dense_gen = self.decompress(pretrained_model_name_or_path)
dense_gen = self.decompress(model_path)
for name, data in tqdm(dense_gen, desc="Decompressing model"):
ModelCompressor.replace_layer(name, data, model)
setattr(model, SPARSITY_CONFIG_NAME, self.config)
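
For context, a hedged usage sketch of the reworked decompression path; the compressor and config names come from this diff, while the model and weights path are illustrative placeholders, not part of the commit:

import torch

from sparsetensors.compressors import BitmaskCompressor
from sparsetensors.config import BitmaskConfig

model = torch.nn.Linear(8, 8)  # placeholder model, purely illustrative
compressor = BitmaskCompressor(config=BitmaskConfig(format="sparse_bitmask"))

# decompress() yields (param_name, dense_tensor) pairs; replace_layer() writes each
# tensor back into the matching parameter, then the config is attached to the model.
compressor.overwrite_weights(model_path="/path/to/compressed/weights", model=model)
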
64 changes: 0 additions & 64 deletions src/sparsetensors/compressors/utils/helpers.py

This file was deleted.

7 changes: 3 additions & 4 deletions src/sparsetensors/config/__init__.py
@@ -13,7 +13,6 @@
# limitations under the License.

# flake8: noqa

from .base import CompressionConfig
from .dense import DenseSparsityConfig
from .sparse_bitmask import BitmaskConfig
from .base import *
from .dense import *
from .sparse_bitmask import *
55 changes: 1 addition & 54 deletions src/sparsetensors/config/base.py
@@ -15,8 +15,7 @@
from typing import Optional

from pydantic import BaseModel
from sparsezoo.utils.registry import ModuleSparsificationInfo, RegistryMixin
from torch.nn import Module
from sparsezoo.utils.registry import RegistryMixin


__all__ = ["CompressionConfig"]
@@ -35,55 +34,3 @@ class CompressionConfig(RegistryMixin, BaseModel):
format: str
global_sparsity: Optional[float] = 0.0
sparsity_structure: Optional[str] = "unstructured"

@staticmethod
def infer_global_sparsity(model: Module) -> float:
"""
Calculates the global percentage of sparse zero weights in the model
:param model: pytorch model to infer sparsity of
:return: global sparsity of model
"""
info = ModuleSparsificationInfo(model)
global_sparsity = info.params_sparse_percent
return global_sparsity

# TODO: Move infer_sparsity_structure to sparseml

@staticmethod
def infer_config_from_model(
model: Module, compress: bool = False
) -> Optional["CompressionConfig"]:
"""
Determines compression type and informational parameters for a given model
:param model: pytorch model to calculate sparsity config for
:param compress: whether or not to compress the model on disk
:return: compression config inferred from the model
"""

global_sparsity = CompressionConfig.infer_global_sparsity(model)

if global_sparsity < 0.05:
return None

sparsity_structure = CompressionConfig.infer_sparsity_structure()
if compress:
format = "sparse_bitmask"
else:
format = "dense_sparsity"

return CompressionConfig.load_from_registry(
format,
global_sparsity=global_sparsity,
sparsity_structure=sparsity_structure,
)

def fill_config_details(self, model: Module):
"""
Fills in informational sparsity parameters from a given model
:param model: pytorch model to infer config parameters from
"""
self.global_sparsity = CompressionConfig.infer_global_sparsity(model)
self.sparsity_structure = CompressionConfig.infer_sparsity_structure()
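
With the inference helpers removed (the TODO above suggests they migrate to sparseml), a caller would build a config directly through the registry. A hedged sketch using only names that appear in this diff, with illustrative sparsity values:

from sparsetensors.config import CompressionConfig

# "sparse_bitmask" selects BitmaskConfig; "dense_sparsity" selects DenseSparsityConfig.
config = CompressionConfig.load_from_registry(
    "sparse_bitmask",
    global_sparsity=0.5,  # illustrative; previously inferred from the model
    sparsity_structure="unstructured",
)
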
2 changes: 1 addition & 1 deletion src/sparsetensors/config/sparse_bitmask.py
@@ -14,7 +14,7 @@

from typing import Optional

from sparsetensors.config import CompressionConfig
from sparsetensors.config.base import CompressionConfig


__all__ = ["BitmaskConfig"]
1 change: 0 additions & 1 deletion src/sparsetensors/utils/__init__.py
@@ -13,5 +13,4 @@
# limitations under the License.
# flake8: noqa

from .compress_save import *
from .safetensors_load import *
140 changes: 0 additions & 140 deletions src/sparsetensors/utils/compress_save.py

This file was deleted.
