From 27011f633dea8c8a4ef0308fbe12191fd10c1492 Mon Sep 17 00:00:00 2001 From: Rahul Tuli Date: Wed, 2 Oct 2024 19:58:47 +0000 Subject: [PATCH] Review Comments from @mgoin --- src/compressed_tensors/compressors/base.py | 5 ++++- .../compressors/base_quantization_compressor.py | 2 +- .../compressors/base_sparsity_compressor.py | 2 +- 3 files changed, 6 insertions(+), 3 deletions(-) diff --git a/src/compressed_tensors/compressors/base.py b/src/compressed_tensors/compressors/base.py index d1604fab..f63cab37 100644 --- a/src/compressed_tensors/compressors/base.py +++ b/src/compressed_tensors/compressors/base.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from abc import ABC, abstractmethod from typing import Dict, Generator, Optional, Tuple, Union import torch @@ -25,7 +26,7 @@ __all__ = ["BaseCompressor"] -class BaseCompressor(RegistryMixin): +class BaseCompressor(RegistryMixin, ABC): """ Base class representing a model compression algorithm. Each child class should implement compression_param_info, compress_weight and decompress_weight. @@ -76,6 +77,7 @@ def compression_param_info( """ raise NotImplementedError() + @abstractmethod def compress( self, model_state: Dict[str, Tensor], @@ -90,6 +92,7 @@ def compress( """ raise NotImplementedError() + @abstractmethod def decompress( self, path_to_model_or_tensors: str, diff --git a/src/compressed_tensors/compressors/base_quantization_compressor.py b/src/compressed_tensors/compressors/base_quantization_compressor.py index e039d6c4..563ecf56 100644 --- a/src/compressed_tensors/compressors/base_quantization_compressor.py +++ b/src/compressed_tensors/compressors/base_quantization_compressor.py @@ -55,7 +55,7 @@ class BaseQuantizationCompressor(BaseCompressor): - Compressor.compression_param_info() - register_parameters() - compressed_module.forward() - -compressed_module.decompress() + - compressed_module.decompress() :param config: config specifying compression parameters diff --git a/src/compressed_tensors/compressors/base_sparsity_compressor.py b/src/compressed_tensors/compressors/base_sparsity_compressor.py index 15087cd8..f7ebbc2f 100644 --- a/src/compressed_tensors/compressors/base_sparsity_compressor.py +++ b/src/compressed_tensors/compressors/base_sparsity_compressor.py @@ -53,7 +53,7 @@ class BaseSparseCompressor(BaseCompressor): - Compressor.compression_param_info() - register_parameters() - compressed_module.forward() - -compressed_module.decompress() + - compressed_module.decompress() :param config: config specifying compression parameters