Skip to content

Commit

Permalink
replace depreciated pydantic functions
Browse files Browse the repository at this point in the history
  • Loading branch information
kylesayrs committed Dec 5, 2024
1 parent 2dcbc9d commit f826555
Show file tree
Hide file tree
Showing 6 changed files with 15 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ def from_compression_config(
format, **sparsity_config
)
if quantization_config is not None:
quantization_config = QuantizationConfig.parse_obj(quantization_config)
quantization_config = QuantizationConfig.model_validate(quantization_config)

return cls(
sparsity_config=sparsity_config, quantization_config=quantization_config
Expand Down Expand Up @@ -193,7 +193,7 @@ def parse_sparsity_config(

if is_compressed_tensors_config(compression_config):
s_config = compression_config.sparsity_config
return s_config.dict() if s_config is not None else None
return s_config.model_dump() if s_config is not None else None

return compression_config.get(SPARSITY_CONFIG_NAME, None)

Expand All @@ -214,7 +214,7 @@ def parse_quantization_config(

if is_compressed_tensors_config(compression_config):
q_config = compression_config.quantization_config
return q_config.dict() if q_config is not None else None
return q_config.model_dump() if q_config is not None else None

quantization_config = deepcopy(compression_config)
quantization_config.pop(SPARSITY_CONFIG_NAME, None)
Expand Down
2 changes: 1 addition & 1 deletion src/compressed_tensors/quantization/quant_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ def model_post_init(self, __context):

def to_dict(self):
# for compatibility with HFQuantizer
return self.dict()
return self.model_dump()

@staticmethod
def from_pretrained(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,8 +99,8 @@ def test_hf_compressor_tensors_config(s_config, q_config, tmp_path):
)
q_config = QuantizationConfig(**q_config) if q_config is not None else None

s_config_dict = s_config.dict() if s_config is not None else None
q_config_dict = q_config.dict() if q_config is not None else None
s_config_dict = s_config.model_dump() if s_config is not None else None
q_config_dict = q_config.model_dump() if q_config is not None else None

assert compressor.sparsity_config == s_config
assert compressor.quantization_config == q_config
Expand Down
2 changes: 1 addition & 1 deletion tests/test_quantization/lifecycle/test_apply.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ def get_sample_tinyllama_quant_config(status: str = "frozen"):
},
"ignore": ["LlamaRotaryEmbedding", "model.layers.1.mlp.down_proj"],
}
return QuantizationConfig.parse_obj(config_dict)
return QuantizationConfig.model_validate(config_dict)


@requires_accelerate()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -110,4 +110,4 @@ def get_sample_dynamic_tinyllama_quant_config():
},
"ignore": ["LlamaRotaryEmbedding", "model.layers.1.mlp.down_proj"],
}
return QuantizationConfig.parse_obj(config_dict)
return QuantizationConfig.model_validate(config_dict)
7 changes: 7 additions & 0 deletions tests/test_quantization/test_quant_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,3 +72,10 @@ def test_load_scheme_from_preset(scheme_name: str):
assert scheme_name in config.config_groups
assert isinstance(config.config_groups[scheme_name], QuantizationScheme)
assert config.config_groups[scheme_name].targets == targets


def test_to_dict():
config_groups = {"group_1": QuantizationScheme(targets=[])}
config = QuantizationConfig(config_groups=config_groups)
reloaded = QuantizationConfig.model_validate(config.to_dict())
assert config == reloaded

0 comments on commit f826555

Please sign in to comment.