diff --git a/src/compressed_tensors/compressors/model_compressors/model_compressor.py b/src/compressed_tensors/compressors/model_compressors/model_compressor.py index e5725189..ac15fdaa 100644 --- a/src/compressed_tensors/compressors/model_compressors/model_compressor.py +++ b/src/compressed_tensors/compressors/model_compressors/model_compressor.py @@ -30,7 +30,7 @@ QUANTIZATION_METHOD_NAME, SPARSITY_CONFIG_NAME, ) -from compressed_tensors.compressors import Compressor +from compressed_tensors.compressors.base import BaseCompressor from compressed_tensors.config import CompressionFormat, SparsityCompressionConfig from compressed_tensors.quantization import ( DEFAULT_QUANTIZATION_METHOD, @@ -247,11 +247,11 @@ def __init__( self.sparsity_config = None if sparsity_config is not None: - self.sparsity_compressor = Compressor.load_from_registry( + self.sparsity_compressor = BaseCompressor.load_from_registry( sparsity_config.format, config=sparsity_config ) if quantization_config is not None: - self.quantization_compressor = Compressor.load_from_registry( + self.quantization_compressor = BaseCompressor.load_from_registry( quantization_config.format, config=quantization_config ) @@ -262,7 +262,7 @@ def compress( Compresses a dense state dict or model with sparsity and/or quantization :param model: uncompressed model to compress - :param model_state: optional uncompressed state_dict to insert into model + :param state_dict: optional uncompressed state_dict to insert into model :return: compressed state dict """ if state_dict is None: @@ -393,4 +393,4 @@ def new_dtype_byte_size(dtype): if bit_search is None: raise ValueError(f"`dtype` is not a valid dtype: {dtype}.") bit_size = int(bit_search.groups()[0]) - return bit_size // 8 \ No newline at end of file + return bit_size // 8 diff --git a/tests/test_compressors/model_compressors/test_model_compressor.py b/tests/test_compressors/model_compressors/test_model_compressor.py index 00c1641f..3f6940a9 100644 --- a/tests/test_compressors/model_compressors/test_model_compressor.py +++ b/tests/test_compressors/model_compressors/test_model_compressor.py @@ -19,7 +19,6 @@ from compressed_tensors.config.base import SparsityCompressionConfig from compressed_tensors.quantization.quant_config import QuantizationConfig from tests.testing_utils import requires_hf_quantizer -from compressed_tensors.compressors.model_compressors import ModelCompressor def sparsity_config():