Merge branch 'main' of github.com:neuralmagic/compressed-tensors
horheynm committed Oct 4, 2024
2 parents 3fbb73e + 0067091 commit 065fd90
Showing 37 changed files with 671 additions and 279 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/build-test.yml
@@ -7,7 +7,7 @@ on:
wf_category:
description: "workflow category: NIGHTLY, RELEASE"
type: string
default: RELEASE
default: NIGHTLY
push_to_pypi:
description: "When set to true, built whl and tar.gz will be pushed to public pypi if all tests pass"
type: boolean
2 changes: 1 addition & 1 deletion .github/workflows/build.yml
@@ -92,7 +92,7 @@ jobs:

- name: build
id: build
uses: neuralmagic/nm-actions/actions/build-ml-whl@v1.2.0
uses: neuralmagic/nm-actions/actions/build-ml-whl@v1.6.0
with:
dev: false
release: ${{ inputs.wf_category == 'RELEASE' }}
10 changes: 5 additions & 5 deletions .github/workflows/trigger-all.yml
@@ -8,12 +8,12 @@ on:
workflow_dispatch:
inputs:
wf_category:
description: "workflow category, default is RELEASE"
description: "workflow category, default is NIGHTLY"
type: choice
options:
- NIGHTLY
- RELEASE
default: RELEASE
default: NIGHTLY
push_to_pypi:
description: "when set and tests pass, then '.whl' & '.tar.gz' will be pushed to public pypi"
type: boolean
@@ -27,11 +27,11 @@ jobs:

BUILD-TEST:
uses: ./.github/workflows/build-test.yml
name: ${{ inputs.wf_category || 'RELEASE' }}
name: ${{ inputs.wf_category || 'NIGHTLY' }}
with:
wf_category: ${{ inputs.wf_category || 'RELEASE' }}
wf_category: ${{ inputs.wf_category || 'NIGHTLY' }}
gitref: ${{ inputs.gitref || 'main' }}
push_to_pypi: ${{ inputs.push_to_pypi || false }}
push_to_pypi: ${{ (github.event.schedule == '30 0 * * *') || inputs.push_to_pypi || false }}
test_configs: '[{"python":"3.11.4","label":"ubuntu-22.04","timeout":"40"},
{"python":"3.10.12","label":"ubuntu-20.04","timeout":"40"},
{"python":"3.9.17","label":"k8s-a100-solo","timeout":"40"},
1 change: 1 addition & 0 deletions src/compressed_tensors/base.py
@@ -17,3 +17,4 @@
COMPRESSION_CONFIG_NAME = "compression_config"
KV_CACHE_SCHEME_NAME = "kv_cache_scheme"
COMPRESSION_VERSION_NAME = "version"
QUANTIZATION_METHOD_NAME = "quant_method"
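These module-level constants name the keys under which compression metadata is serialized next to a model's config. A minimal sketch of how the new `QUANTIZATION_METHOD_NAME` key might sit alongside the existing ones; the enclosing dictionary and the concrete values are illustrative assumptions, not taken from this diff:

```python
from compressed_tensors.base import (
    COMPRESSION_CONFIG_NAME,
    COMPRESSION_VERSION_NAME,
    KV_CACHE_SCHEME_NAME,
    QUANTIZATION_METHOD_NAME,
)

# Hypothetical serialized metadata block built from the constant names above;
# the values are placeholders for illustration only.
example_metadata = {
    COMPRESSION_CONFIG_NAME: {                            # "compression_config"
        QUANTIZATION_METHOD_NAME: "compressed-tensors",   # "quant_method"
        COMPRESSION_VERSION_NAME: "0.6.0",                # "version"
        KV_CACHE_SCHEME_NAME: None,                       # "kv_cache_scheme"
    }
}
```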
18 changes: 6 additions & 12 deletions src/compressed_tensors/compressors/__init__.py
@@ -14,15 +14,9 @@

# flake8: noqa

from .base import Compressor
from .dense import DenseCompressor
from .helpers import load_compressed, save_compressed, save_compressed_model
from .marlin_24 import Marlin24Compressor
from .model_compressor import ModelCompressor, map_modules_to_quant_args
from .naive_quantized import (
FloatQuantizationCompressor,
IntQuantizationCompressor,
QuantizationCompressor,
)
from .pack_quantized import PackedQuantizationCompressor
from .sparse_bitmask import BitmaskCompressor, BitmaskTensor
from .base import *
from .helpers import *
from .model_compressors import *
from .quantized_compressors import *
from .sparse_compressors import *
from .sparse_quantized_compressors import *
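The flat module files are replaced by subpackages (`model_compressors`, `quantized_compressors`, `sparse_compressors`, `sparse_quantized_compressors`), and the top-level namespace is repopulated through wildcard re-exports. A small sketch of the resulting import surface; the exact set of re-exported names is an assumption based on the removed explicit imports:

```python
# Flat imports are assumed to keep resolving through the new wildcard re-exports.
from compressed_tensors.compressors import BaseCompressor, ModelCompressor

# The reorganized subpackages can also be imported directly.
from compressed_tensors.compressors.model_compressors import ModelCompressor as MC

assert MC is ModelCompressor  # same class, reached via two paths
```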
140 changes: 38 additions & 102 deletions src/compressed_tensors/compressors/base.py
@@ -12,26 +12,21 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from abc import ABC, abstractmethod
from typing import Dict, Generator, Optional, Tuple, Union

import torch
from compressed_tensors.config import SparsityCompressionConfig
from compressed_tensors.quantization import QuantizationArgs, QuantizationConfig
from compressed_tensors.registry import RegistryMixin
from compressed_tensors.utils import get_nested_weight_mappings, merge_names
from safetensors import safe_open
from torch import Tensor
from torch.nn.modules import Module
from tqdm import tqdm
from torch.nn import Module


_LOGGER: logging.Logger = logging.getLogger(__name__)
__all__ = ["BaseCompressor"]

__all__ = ["Compressor"]


class Compressor(RegistryMixin):
class BaseCompressor(RegistryMixin, ABC):
"""
Base class representing a model compression algorithm. Each child class should
implement compression_param_info, compress_weight and decompress_weight.
@@ -42,19 +37,18 @@ class Compressor(RegistryMixin):
Model Load Lifecycle (run_compressed=False):
- ModelCompressor.decompress()
- apply_quantization_config()
- Compressor.decompress()
- Compressor.decompress_weight()
- BaseCompressor.decompress()
Model Save Lifecycle:
- ModelCompressor.compress()
- Compressor.compress()
- Compressor.compress_weight()
- BaseCompressor.compress()
Module Lifecycle (run_compressed=True):
- apply_quantization_config()
- compressed_module = CompressedLinear(module)
- initialize_module_for_quantization()
- Compressor.compression_param_info()
- BaseCompressor.compression_param_info()
- register_parameters()
- compressed_module.forward()
- compressed_module.decompress()
@@ -83,61 +77,27 @@ def compression_param_info(
"""
raise NotImplementedError()

@abstractmethod
def compress(
self,
model_state: Dict[str, Tensor],
names_to_scheme: Dict[str, QuantizationArgs],
**kwargs,
) -> Dict[str, Tensor]:
"""
Compresses a dense state dict
:param model_state: state dict of uncompressed model
:param names_to_scheme: quantization args for each quantized weight, needed for
quantize function to calculate bit depth
:param kwargs: additional arguments for compression
:return: compressed state dict
"""
compressed_dict = {}
weight_suffix = ".weight"
_LOGGER.debug(
f"Compressing model with {len(model_state)} parameterized layers..."
)

for name, value in tqdm(model_state.items(), desc="Compressing model"):
if name.endswith(weight_suffix):
prefix = name[: -(len(weight_suffix))]
scale = model_state.get(merge_names(prefix, "weight_scale"), None)
zp = model_state.get(merge_names(prefix, "weight_zero_point"), None)
g_idx = model_state.get(merge_names(prefix, "weight_g_idx"), None)
if scale is not None:
# weight is quantized, compress it
quant_args = names_to_scheme[prefix]
compressed_data = self.compress_weight(
weight=value,
scale=scale,
zero_point=zp,
g_idx=g_idx,
quantization_args=quant_args,
device="cpu",
)
for key, value in compressed_data.items():
compressed_dict[merge_names(prefix, key)] = value
else:
compressed_dict[name] = value.to("cpu")
elif name.endswith("zero_point") and torch.all(value == 0):
continue
elif name.endswith("g_idx") and torch.any(value <= -1):
continue
else:
compressed_dict[name] = value.to("cpu")

return compressed_dict
raise NotImplementedError()

@abstractmethod
def decompress(
self,
path_to_model_or_tensors: str,
names_to_scheme: Dict[str, QuantizationArgs],
device: str = "cpu",
**kwargs,
) -> Generator[Tuple[str, Tensor], None, None]:
"""
Reads a compressed state dict located at path_to_model_or_tensors
@@ -150,55 +110,6 @@ def decompress(
:param device: optional device to load intermediate weights into
:return: compressed state dict
"""
weight_mappings = get_nested_weight_mappings(
path_to_model_or_tensors, self.COMPRESSION_PARAM_NAMES
)
for weight_name in weight_mappings.keys():
weight_data = {}
for param_name, safe_path in weight_mappings[weight_name].items():
full_name = merge_names(weight_name, param_name)
with safe_open(safe_path, framework="pt", device=device) as f:
weight_data[param_name] = f.get_tensor(full_name)

if "weight_scale" in weight_data:
quant_args = names_to_scheme[weight_name]
decompressed = self.decompress_weight(
compressed_data=weight_data, quantization_args=quant_args
)
yield merge_names(weight_name, "weight"), decompressed

def compress_weight(
self,
weight: Tensor,
scale: Tensor,
zero_point: Optional[Tensor] = None,
g_idx: Optional[torch.Tensor] = None,
quantization_args: Optional[QuantizationArgs] = None,
) -> Dict[str, torch.Tensor]:
"""
Compresses a single uncompressed weight
:param weight: uncompressed weight tensor
:param scale: quantization scale for weight
:param zero_point: quantization zero point for weight
:param g_idx: optional mapping from column index to group index
:param quantization_args: quantization parameters for weight
:return: dictionary of compressed weight data
"""
raise NotImplementedError()

def decompress_weight(
self,
compressed_data: Dict[str, Tensor],
quantization_args: Optional[QuantizationArgs] = None,
) -> torch.Tensor:
"""
Decompresses a single compressed weight
:param compressed_data: dictionary of data needed for decompression
:param quantization_args: quantization parameters for the weight
:return: tensor of the decompressed weight
"""
raise NotImplementedError()

def compress_module(self, module: Module) -> Optional[Dict[str, torch.Tensor]]:
@@ -228,6 +139,19 @@ def compress_module(self, module: Module) -> Optional[Dict[str, torch.Tensor]]:
quantization_args=quantization_args,
)

def compress_weight(
self,
weight: Tensor,
**kwargs,
) -> Dict[str, torch.Tensor]:
"""
Compresses a single uncompressed weight
:param weight: uncompressed weight tensor
:param kwargs: additional arguments for compression
"""
raise NotImplementedError()

def decompress_module(self, module: Module):
"""
Decompresses a single compressed leaf PyTorch module. If the module is not
@@ -250,3 +174,15 @@ def decompress_module(self, module: Module):
return self.decompress_weight(
compressed_data=compressed_data, quantization_args=quantization_args
)

def decompress_weight(
self, compressed_data: Dict[str, Tensor], **kwargs
) -> torch.Tensor:
"""
Decompresses a single compressed weight
:param compressed_data: dictionary of data needed for decompression
:param kwargs: additional arguments for decompression
:return: tensor of the decompressed weight
"""
raise NotImplementedError()
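With `compress()` and `decompress()` now declared `@abstractmethod`, the dense state-dict walking logic leaves the base class (it moves into the quantized-compressor subclasses elsewhere in this commit), while `compress_module()`/`decompress_module()` keep delegating to the per-weight `compress_weight()`/`decompress_weight()` hooks. A minimal sketch of a subclass written against the new interface; the registry name, the decorator usage, and the passthrough behavior are illustrative assumptions, not part of this diff:

```python
from typing import Dict, Generator, Tuple

from compressed_tensors.compressors import BaseCompressor
from torch import Tensor


@BaseCompressor.register(name="identity")  # hypothetical format name
class IdentityCompressor(BaseCompressor):
    """Toy compressor that leaves weights unchanged, for illustration only."""

    def compress(self, model_state: Dict[str, Tensor], **kwargs) -> Dict[str, Tensor]:
        # A real compressor would quantize or pack here; this sketch just
        # moves every tensor to CPU and returns the dict unchanged.
        return {name: value.to("cpu") for name, value in model_state.items()}

    def decompress(
        self, path_to_model_or_tensors: str, device: str = "cpu", **kwargs
    ) -> Generator[Tuple[str, Tensor], None, None]:
        # A real compressor would stream decompressed weights back from disk;
        # this sketch yields nothing.
        yield from ()
```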
12 changes: 6 additions & 6 deletions src/compressed_tensors/compressors/helpers.py
@@ -16,7 +16,7 @@
from typing import Dict, Generator, Optional, Tuple, Union

import torch
from compressed_tensors.compressors import Compressor
from compressed_tensors.compressors import BaseCompressor
from compressed_tensors.config import CompressionFormat, SparsityCompressionConfig
from compressed_tensors.utils.safetensors_load import get_weight_mappings
from safetensors import safe_open
@@ -52,16 +52,16 @@ def save_compressed(
compression_format = compression_format or CompressionFormat.dense.value

if not (
compression_format in Compressor.registered_names()
or compression_format in Compressor.registered_aliases()
compression_format in BaseCompressor.registered_names()
or compression_format in BaseCompressor.registered_aliases()
):
raise ValueError(
f"Unknown compression format: {compression_format}. "
f"Must be one of {set(Compressor.registered_names() + Compressor.registered_aliases())}" # noqa E501
f"Must be one of {set(BaseCompressor.registered_names() + BaseCompressor.registered_aliases())}" # noqa E501
)

# compress
compressor = Compressor.load_from_registry(compression_format)
compressor = BaseCompressor.load_from_registry(compression_format)
# save compressed tensors
compressed_tensors = compressor.compress(tensors)
save_file(compressed_tensors, save_path)
@@ -102,7 +102,7 @@ def load_compressed(
else:
# decompress tensors
compression_format = compression_config.format
compressor = Compressor.load_from_registry(
compressor = BaseCompressor.load_from_registry(
compression_format, config=compression_config
)
yield from compressor.decompress(compressed_tensors, device=device)
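Beyond the rename, the helpers keep the same flow: `save_compressed()` looks the format up in the `BaseCompressor` registry, compresses, and writes a safetensors file, while `load_compressed()` yields the tensors back. A usage sketch under the signatures implied by this diff; the tensor contents, file name, and keyword-argument names are assumptions:

```python
import torch
from compressed_tensors.compressors import load_compressed, save_compressed
from compressed_tensors.config import CompressionFormat

# Hypothetical tensors and output path; any format registered on BaseCompressor
# (here the dense default) can be passed as compression_format.
tensors = {"layer.weight": torch.randn(64, 64)}
save_compressed(
    tensors,
    save_path="model.safetensors",
    compression_format=CompressionFormat.dense.value,
)

# Without a compression_config, load_compressed is assumed to read the dense
# tensors back as saved; it yields (name, tensor) pairs lazily.
for name, tensor in load_compressed("model.safetensors"):
    print(name, tuple(tensor.shape))
```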
17 changes: 17 additions & 0 deletions src/compressed_tensors/compressors/model_compressors/__init__.py
@@ -0,0 +1,17 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# flake8: noqa


from .model_compressor import *