From f826555cddf590d160cfb8c87f3a52467fa35b9e Mon Sep 17 00:00:00 2001
From: Kyle Sayers
Date: Thu, 5 Dec 2024 00:56:40 +0000
Subject: [PATCH] replace deprecated pydantic functions
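
Pydantic v2 renamed several v1-era methods; the old spellings still
work under v2 but emit DeprecationWarning and are documented for
removal in a future major release. This patch applies the standard
replacements:

    Model.parse_obj(data)  ->  Model.model_validate(data)
    instance.dict()        ->  instance.model_dump()

For illustration, a minimal round-trip with the new API (the Example
model below is hypothetical, not code from this repo):

    from pydantic import BaseModel

    class Example(BaseModel):
        name: str

    obj = Example.model_validate({"name": "sparse"})  # was Example.parse_obj(...)
    assert obj.model_dump() == {"name": "sparse"}     # was obj.dict()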
---
 .../compressors/model_compressors/model_compressor.py     | 6 +++---
 src/compressed_tensors/quantization/quant_config.py       | 2 +-
 .../model_compressors/test_model_compressor.py            | 4 ++--
 tests/test_quantization/lifecycle/test_apply.py           | 2 +-
 .../test_quantization/lifecycle/test_dynamic_lifecycle.py | 2 +-
 tests/test_quantization/test_quant_config.py              | 7 +++++++
 6 files changed, 15 insertions(+), 8 deletions(-)

diff --git a/src/compressed_tensors/compressors/model_compressors/model_compressor.py b/src/compressed_tensors/compressors/model_compressors/model_compressor.py
index 68bd52ec..297fc3a6 100644
--- a/src/compressed_tensors/compressors/model_compressors/model_compressor.py
+++ b/src/compressed_tensors/compressors/model_compressors/model_compressor.py
@@ -137,7 +137,7 @@ def from_compression_config(
                 format, **sparsity_config
             )
         if quantization_config is not None:
-            quantization_config = QuantizationConfig.parse_obj(quantization_config)
+            quantization_config = QuantizationConfig.model_validate(quantization_config)

         return cls(
             sparsity_config=sparsity_config, quantization_config=quantization_config
@@ -193,7 +193,7 @@ def parse_sparsity_config(

         if is_compressed_tensors_config(compression_config):
             s_config = compression_config.sparsity_config
-            return s_config.dict() if s_config is not None else None
+            return s_config.model_dump() if s_config is not None else None

         return compression_config.get(SPARSITY_CONFIG_NAME, None)

@@ -214,7 +214,7 @@ def parse_quantization_config(

         if is_compressed_tensors_config(compression_config):
             q_config = compression_config.quantization_config
-            return q_config.dict() if q_config is not None else None
+            return q_config.model_dump() if q_config is not None else None

         quantization_config = deepcopy(compression_config)
         quantization_config.pop(SPARSITY_CONFIG_NAME, None)
diff --git a/src/compressed_tensors/quantization/quant_config.py b/src/compressed_tensors/quantization/quant_config.py
index 1d95aee8..3a80f0cb 100644
--- a/src/compressed_tensors/quantization/quant_config.py
+++ b/src/compressed_tensors/quantization/quant_config.py
@@ -160,7 +160,7 @@ def model_post_init(self, __context):

     def to_dict(self):
         # for compatibility with HFQuantizer
-        return self.dict()
+        return self.model_dump()

     @staticmethod
     def from_pretrained(
diff --git a/tests/test_compressors/model_compressors/test_model_compressor.py b/tests/test_compressors/model_compressors/test_model_compressor.py
index 3f6940a9..4a8327ce 100644
--- a/tests/test_compressors/model_compressors/test_model_compressor.py
+++ b/tests/test_compressors/model_compressors/test_model_compressor.py
@@ -99,8 +99,8 @@ def test_hf_compressor_tensors_config(s_config, q_config, tmp_path):
     )

     q_config = QuantizationConfig(**q_config) if q_config is not None else None
-    s_config_dict = s_config.dict() if s_config is not None else None
-    q_config_dict = q_config.dict() if q_config is not None else None
+    s_config_dict = s_config.model_dump() if s_config is not None else None
+    q_config_dict = q_config.model_dump() if q_config is not None else None

     assert compressor.sparsity_config == s_config
     assert compressor.quantization_config == q_config
diff --git a/tests/test_quantization/lifecycle/test_apply.py b/tests/test_quantization/lifecycle/test_apply.py
index 7268ca27..5ad56b8e 100644
--- a/tests/test_quantization/lifecycle/test_apply.py
+++ b/tests/test_quantization/lifecycle/test_apply.py
@@ -222,7 +222,7 @@ def get_sample_tinyllama_quant_config(status: str = "frozen"):
         },
         "ignore": ["LlamaRotaryEmbedding", "model.layers.1.mlp.down_proj"],
     }
-    return QuantizationConfig.parse_obj(config_dict)
+    return QuantizationConfig.model_validate(config_dict)


 @requires_accelerate()
diff --git a/tests/test_quantization/lifecycle/test_dynamic_lifecycle.py b/tests/test_quantization/lifecycle/test_dynamic_lifecycle.py
index dd700637..3ac91e85 100644
--- a/tests/test_quantization/lifecycle/test_dynamic_lifecycle.py
+++ b/tests/test_quantization/lifecycle/test_dynamic_lifecycle.py
@@ -110,4 +110,4 @@ def get_sample_dynamic_tinyllama_quant_config():
         },
         "ignore": ["LlamaRotaryEmbedding", "model.layers.1.mlp.down_proj"],
     }
-    return QuantizationConfig.parse_obj(config_dict)
+    return QuantizationConfig.model_validate(config_dict)
diff --git a/tests/test_quantization/test_quant_config.py b/tests/test_quantization/test_quant_config.py
index 460db82b..c3830a02 100644
--- a/tests/test_quantization/test_quant_config.py
+++ b/tests/test_quantization/test_quant_config.py
@@ -72,3 +72,10 @@ def test_load_scheme_from_preset(scheme_name: str):
     assert scheme_name in config.config_groups
     assert isinstance(config.config_groups[scheme_name], QuantizationScheme)
     assert config.config_groups[scheme_name].targets == targets
+
+
+def test_to_dict():
+    config_groups = {"group_1": QuantizationScheme(targets=[])}
+    config = QuantizationConfig(config_groups=config_groups)
+    reloaded = QuantizationConfig.model_validate(config.to_dict())
+    assert config == reloaded