Skip to content

Commit

Permalink
Update compressors folder structure (#166)
Browse files Browse the repository at this point in the history
* Update folder structure
Move tests

Remove unused import

* Apply suggestions from code review

Typos caught by @mgoin

Co-authored-by: Michael Goin <michael@neuralmagic.com>

---------

Co-authored-by: Michael Goin <michael@neuralmagic.com>
  • Loading branch information
rahul-tuli and mgoin authored Oct 3, 2024
1 parent 1f6a056 commit 4da7887
Show file tree
Hide file tree
Showing 25 changed files with 151 additions and 36 deletions.
18 changes: 6 additions & 12 deletions src/compressed_tensors/compressors/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,9 @@

# flake8: noqa

from .base import BaseCompressor
from .dense import DenseCompressor
from .helpers import load_compressed, save_compressed, save_compressed_model
from .marlin_24 import Marlin24Compressor
from .model_compressor import ModelCompressor, map_modules_to_quant_args
from .naive_quantized import (
FloatQuantizationCompressor,
IntQuantizationCompressor,
QuantizationCompressor,
)
from .pack_quantized import PackedQuantizationCompressor
from .sparse_bitmask import BitmaskCompressor, BitmaskTensor
from .base import *
from .helpers import *
from .model_compressors import *
from .quantized_compressors import *
from .sparse_compressors import *
from .sparse_quantized_compressors import *
6 changes: 3 additions & 3 deletions src/compressed_tensors/compressors/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,18 +37,18 @@ class BaseCompressor(RegistryMixin, ABC):
Model Load Lifecycle (run_compressed=False):
- ModelCompressor.decompress()
- apply_quantization_config()
- Compressor.decompress()
- BaseCompressor.decompress()
Model Save Lifecycle:
- ModelCompressor.compress()
- Compressor.compress()
- BaseCompressor.compress()
Module Lifecycle (run_compressed=True):
- apply_quantization_config()
- compressed_module = CompressedLinear(module)
- initialize_module_for_quantization()
- Compressor.compression_param_info()
- BaseCompressor.compression_param_info()
- register_parameters()
- compressed_module.forward()
- compressed_module.decompress()
Expand Down
2 changes: 1 addition & 1 deletion src/compressed_tensors/compressors/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from typing import Dict, Generator, Optional, Tuple, Union

import torch
from compressed_tensors.compressors.base import BaseCompressor
from compressed_tensors.compressors import BaseCompressor
from compressed_tensors.config import CompressionFormat, SparsityCompressionConfig
from compressed_tensors.utils.safetensors_load import get_weight_mappings
from safetensors import safe_open
Expand Down
17 changes: 17 additions & 0 deletions src/compressed_tensors/compressors/model_compressors/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# flake8: noqa


from .model_compressor import *
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# flake8: noqa

from .base import *
from .naive_quantized import *
from .pack_quantized import *
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,10 @@
from tqdm import tqdm


__all__ = ["BaseQuantizationCompressor"]

_LOGGER: logging.Logger = logging.getLogger(__name__)

__all__ = ["BaseQuantizationCompressor"]


class BaseQuantizationCompressor(BaseCompressor):
"""
Expand All @@ -40,19 +40,19 @@ class BaseQuantizationCompressor(BaseCompressor):
Model Load Lifecycle (run_compressed=False):
- ModelCompressor.decompress()
- apply_quantization_config()
- Compressor.decompress()
- Compressor.decompress_weight()
- BaseQuantizationCompressor.decompress()
- BaseQuantizationCompressor.decompress_weight()
Model Save Lifecycle:
- ModelCompressor.compress()
- Compressor.compress()
- Compressor.compress_weight()
- BaseQuantizationCompressor.compress()
- BaseQuantizationCompressor.compress_weight()
Module Lifecycle (run_compressed=True):
- apply_quantization_config()
- compressed_module = CompressedLinear(module)
- initialize_module_for_quantization()
- Compressor.compression_param_info()
- BaseQuantizationCompressor.compression_param_info()
- register_parameters()
- compressed_module.forward()
- compressed_module.decompress()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

import torch
from compressed_tensors.compressors.base import BaseCompressor
from compressed_tensors.compressors.base_quantization_compressor import (
from compressed_tensors.compressors.quantized_compressors.base import (
BaseQuantizationCompressor,
)
from compressed_tensors.config import CompressionFormat
Expand All @@ -27,14 +27,14 @@


__all__ = [
"QuantizationCompressor",
"NaiveQuantizationCompressor",
"IntQuantizationCompressor",
"FloatQuantizationCompressor",
]


@BaseCompressor.register(name=CompressionFormat.naive_quantized.value)
class QuantizationCompressor(BaseQuantizationCompressor):
class NaiveQuantizationCompressor(BaseQuantizationCompressor):
"""
Implements naive compression for quantized models. Weight of each
quantized layer is converted from its original float type to the closest Pytorch
Expand Down Expand Up @@ -123,7 +123,7 @@ def decompress_weight(


@BaseCompressor.register(name=CompressionFormat.int_quantized.value)
class IntQuantizationCompressor(QuantizationCompressor):
class IntQuantizationCompressor(NaiveQuantizationCompressor):
"""
Alias for integer quantized models
"""
Expand All @@ -132,7 +132,7 @@ class IntQuantizationCompressor(QuantizationCompressor):


@BaseCompressor.register(name=CompressionFormat.float_quantized.value)
class FloatQuantizationCompressor(QuantizationCompressor):
class FloatQuantizationCompressor(NaiveQuantizationCompressor):
"""
Alias for fp quantized models
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import numpy as np
import torch
from compressed_tensors.compressors.base import BaseCompressor
from compressed_tensors.compressors.base_quantization_compressor import (
from compressed_tensors.compressors.quantized_compressors.base import (
BaseQuantizationCompressor,
)
from compressed_tensors.config import CompressionFormat
Expand Down
18 changes: 18 additions & 0 deletions src/compressed_tensors/compressors/sparse_compressors/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# flake8: noqa

from .base import *
from .dense import *
from .sparse_bitmask import *
Original file line number Diff line number Diff line change
Expand Up @@ -38,19 +38,19 @@ class BaseSparseCompressor(BaseCompressor):
Model Load Lifecycle (run_compressed=False):
- ModelCompressor.decompress()
- apply_quantization_config()
- Compressor.decompress()
- Compressor.decompress_weight()
- BaseSparseCompressor.decompress()
- BaseSparseCompressor.decompress_weight()
Model Save Lifecycle:
- ModelCompressor.compress()
- Compressor.compress()
- Compressor.compress_weight()
- BaseSparseCompressor.compress()
- BaseSparseCompressor.compress_weight()
Module Lifecycle (run_compressed=True):
- apply_quantization_config()
- compressed_module = CompressedLinear(module)
- initialize_module_for_quantization()
- Compressor.compression_param_info()
- BaseSparseCompressor.compression_param_info()
- register_parameters()
- compressed_module.forward()
- compressed_module.decompress()
Expand Down
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import numpy
import torch
from compressed_tensors.compressors.base import BaseCompressor
from compressed_tensors.compressors.base_sparsity_compressor import BaseSparseCompressor
from compressed_tensors.compressors.sparse_compressors.base import BaseSparseCompressor
from compressed_tensors.config import CompressionFormat
from compressed_tensors.utils import merge_names
from torch import Tensor
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# flake8: noqa

from .marlin_24 import Marlin24Compressor
13 changes: 13 additions & 0 deletions tests/test_compressors/model_compressors/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
13 changes: 13 additions & 0 deletions tests/test_compressors/quantized_compressors/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
File renamed without changes.
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import pytest
import torch
from compressed_tensors import PackedQuantizationCompressor
from compressed_tensors.compressors.pack_quantized import (
from compressed_tensors.compressors.quantized_compressors.pack_quantized import (
pack_to_int32,
unpack_from_int32,
)
Expand Down
13 changes: 13 additions & 0 deletions tests/test_compressors/sparse_compressors/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
File renamed without changes.
13 changes: 13 additions & 0 deletions tests/test_compressors/sparse_quantized_compressors/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
File renamed without changes.

0 comments on commit 4da7887

Please sign in to comment.