From 6ae1b94de2326839f63d83124147b1acb41bb4c4 Mon Sep 17 00:00:00 2001 From: Rahul Tuli Date: Wed, 1 May 2024 13:53:47 -0400 Subject: [PATCH] [BugFix] Fix Serialization of Computed Properties in BaseModel (#485) * Add failing test * Fix computed field serialization --- src/sparsezoo/analyze_v1/analysis.py | 13 +++- src/sparsezoo/analyze_v1/utils/models.py | 63 +++---------------- .../analyze/test_analysis/test_models.py | 27 ++++++++ 3 files changed, 44 insertions(+), 59 deletions(-) create mode 100644 tests/sparsezoo/analyze/test_analysis/test_models.py diff --git a/src/sparsezoo/analyze_v1/analysis.py b/src/sparsezoo/analyze_v1/analysis.py index 1578cbd2..7214444e 100644 --- a/src/sparsezoo/analyze_v1/analysis.py +++ b/src/sparsezoo/analyze_v1/analysis.py @@ -103,6 +103,13 @@ class YAMLSerializableBaseModel(BaseModel): A BaseModel that adds a .yaml(...) function to all child classes """ + model_config = ConfigDict(protected_namespaces=()) + + def dict(self, *args, **kwargs) -> Dict[str, Any]: + # alias for model_dump for pydantic v2 upgrade + # to allow for easier migration + return self.model_dump(*args, **kwargs) + def yaml(self, file_path: Optional[str] = None) -> Union[str, None]: """ :param file_path: optional file path to save yaml to @@ -111,7 +118,7 @@ def yaml(self, file_path: Optional[str] = None) -> Union[str, None]: """ file_stream = None if file_path is None else open(file_path, "w") ret = yaml.dump( - self.dict(), stream=file_stream, allow_unicode=True, sort_keys=False + self.model_dump(), stream=file_stream, allow_unicode=True, sort_keys=False ) if file_stream is not None: @@ -127,7 +134,7 @@ def parse_yaml_file(cls, file_path: str): """ with open(file_path, "r") as file: dict_obj = yaml.safe_load(file) - return cls.parse_obj(dict_obj) + return cls.model_validate(dict_obj) @classmethod def parse_yaml_raw(cls, yaml_raw: str): @@ -136,7 +143,7 @@ def parse_yaml_raw(cls, yaml_raw: str): :return: instance of ModelAnalysis class """ dict_obj = yaml.safe_load(yaml_raw) # unsafe: needs to load numpy - return cls.parse_obj(dict_obj) + return cls.model_validate(dict_obj) @dataclass diff --git a/src/sparsezoo/analyze_v1/utils/models.py b/src/sparsezoo/analyze_v1/utils/models.py index 68c4d7f5..913b6fc4 100644 --- a/src/sparsezoo/analyze_v1/utils/models.py +++ b/src/sparsezoo/analyze_v1/utils/models.py @@ -15,7 +15,7 @@ import textwrap from typing import ClassVar, Dict, List, Optional, Tuple, Union -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, computed_field __all__ = [ @@ -33,58 +33,6 @@ PrintOrderType = ClassVar[List[str]] -class PropertyBaseModel(BaseModel): - """ - https://github.com/samuelcolvin/pydantic/issues/935#issuecomment-1152457432 - - Workaround for serializing properties with pydantic until - https://github.com/samuelcolvin/pydantic/issues/935 - is solved - """ - - @classmethod - def get_properties(cls): - return [ - prop - for prop in dir(cls) - if isinstance(getattr(cls, prop), property) - and prop not in ("__values__", "fields") - ] - - def dict( - self, - *, - include: Union["AbstractSetIntStr", "MappingIntStrAny"] = None, # noqa: F821 - exclude: Union["AbstractSetIntStr", "MappingIntStrAny"] = None, # noqa: F821 - by_alias: bool = False, - skip_defaults: bool = None, - exclude_unset: bool = False, - exclude_defaults: bool = False, - exclude_none: bool = False, - ) -> "DictStrAny": # noqa: F821 - attribs = super().dict( - include=include, - exclude=exclude, - by_alias=by_alias, - skip_defaults=skip_defaults, - exclude_unset=exclude_unset, - exclude_defaults=exclude_defaults, - exclude_none=exclude_none, - ) - props = self.get_properties() - # Include and exclude properties - if include: - props = [prop for prop in props if prop in include] - if exclude: - props = [prop for prop in props if prop not in exclude] - - # Update the attribute dict with the properties - if props: - attribs.update({prop: getattr(self, prop) for prop in props}) - - return attribs - - class NodeCounts(BaseModel): """ Pydantic model for specifying the number zero and non-zero operations and the @@ -114,7 +62,7 @@ class NodeIO(BaseModel): ) -class ZeroNonZeroParams(PropertyBaseModel): +class ZeroNonZeroParams(BaseModel): """ Pydantic model for specifying the number zero and non-zero operations and the associated sparsity @@ -127,20 +75,22 @@ class ZeroNonZeroParams(PropertyBaseModel): description="The number of parameters whose value is zero", default=0 ) + @computed_field(repr=True, return_type=Union[int, float]) @property def sparsity(self): total_values = self.total if total_values > 0: return self.zero / total_values else: - return 0 + return 0.0 + @computed_field(repr=True, return_type=int) @property def total(self): return self.non_zero + self.zero -class DenseSparseOps(PropertyBaseModel): +class DenseSparseOps(BaseModel): """ Pydantic model for specifying the number dense and sparse operations and the associated operation sparsity @@ -155,6 +105,7 @@ class DenseSparseOps(PropertyBaseModel): default=0, ) + @computed_field(repr=True, return_type=Union[int, float]) @property def sparsity(self): total_ops = self.sparse + self.dense diff --git a/tests/sparsezoo/analyze/test_analysis/test_models.py b/tests/sparsezoo/analyze/test_analysis/test_models.py new file mode 100644 index 00000000..ace7dd59 --- /dev/null +++ b/tests/sparsezoo/analyze/test_analysis/test_models.py @@ -0,0 +1,27 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +from sparsezoo.analyze_v1.utils.models import DenseSparseOps, ZeroNonZeroParams + + +@pytest.mark.parametrize("model", [DenseSparseOps, ZeroNonZeroParams]) +@pytest.mark.parametrize("computed_fields", [["sparsity"]]) +def test_model_dump_has_computed_fields(model, computed_fields): + model = model() + model_dict = model.model_dump() + for computed_field in computed_fields: + assert computed_field in model_dict + assert model_dict[computed_field] == getattr(model, computed_field)