Skip to content

Commit

Permalink
[BugFix] Fix Serialization of Computed Properties in BaseModel (#485)
Browse files Browse the repository at this point in the history
* Add failing test

* Fix computed field serialization
  • Loading branch information
rahul-tuli authored May 1, 2024
1 parent f349af4 commit 6ae1b94
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 59 deletions.
13 changes: 10 additions & 3 deletions src/sparsezoo/analyze_v1/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,13 @@ class YAMLSerializableBaseModel(BaseModel):
A BaseModel that adds a .yaml(...) function to all child classes
"""

model_config = ConfigDict(protected_namespaces=())

def dict(self, *args, **kwargs) -> Dict[str, Any]:
# alias for model_dump for pydantic v2 upgrade
# to allow for easier migration
return self.model_dump(*args, **kwargs)

def yaml(self, file_path: Optional[str] = None) -> Union[str, None]:
"""
:param file_path: optional file path to save yaml to
Expand All @@ -111,7 +118,7 @@ def yaml(self, file_path: Optional[str] = None) -> Union[str, None]:
"""
file_stream = None if file_path is None else open(file_path, "w")
ret = yaml.dump(
self.dict(), stream=file_stream, allow_unicode=True, sort_keys=False
self.model_dump(), stream=file_stream, allow_unicode=True, sort_keys=False
)

if file_stream is not None:
Expand All @@ -127,7 +134,7 @@ def parse_yaml_file(cls, file_path: str):
"""
with open(file_path, "r") as file:
dict_obj = yaml.safe_load(file)
return cls.parse_obj(dict_obj)
return cls.model_validate(dict_obj)

@classmethod
def parse_yaml_raw(cls, yaml_raw: str):
Expand All @@ -136,7 +143,7 @@ def parse_yaml_raw(cls, yaml_raw: str):
:return: instance of ModelAnalysis class
"""
dict_obj = yaml.safe_load(yaml_raw) # unsafe: needs to load numpy
return cls.parse_obj(dict_obj)
return cls.model_validate(dict_obj)


@dataclass
Expand Down
63 changes: 7 additions & 56 deletions src/sparsezoo/analyze_v1/utils/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import textwrap
from typing import ClassVar, Dict, List, Optional, Tuple, Union

from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, computed_field


__all__ = [
Expand All @@ -33,58 +33,6 @@
PrintOrderType = ClassVar[List[str]]


class PropertyBaseModel(BaseModel):
"""
https://github.com/samuelcolvin/pydantic/issues/935#issuecomment-1152457432
Workaround for serializing properties with pydantic until
https://github.com/samuelcolvin/pydantic/issues/935
is solved
"""

@classmethod
def get_properties(cls):
return [
prop
for prop in dir(cls)
if isinstance(getattr(cls, prop), property)
and prop not in ("__values__", "fields")
]

def dict(
self,
*,
include: Union["AbstractSetIntStr", "MappingIntStrAny"] = None, # noqa: F821
exclude: Union["AbstractSetIntStr", "MappingIntStrAny"] = None, # noqa: F821
by_alias: bool = False,
skip_defaults: bool = None,
exclude_unset: bool = False,
exclude_defaults: bool = False,
exclude_none: bool = False,
) -> "DictStrAny": # noqa: F821
attribs = super().dict(
include=include,
exclude=exclude,
by_alias=by_alias,
skip_defaults=skip_defaults,
exclude_unset=exclude_unset,
exclude_defaults=exclude_defaults,
exclude_none=exclude_none,
)
props = self.get_properties()
# Include and exclude properties
if include:
props = [prop for prop in props if prop in include]
if exclude:
props = [prop for prop in props if prop not in exclude]

# Update the attribute dict with the properties
if props:
attribs.update({prop: getattr(self, prop) for prop in props})

return attribs


class NodeCounts(BaseModel):
"""
Pydantic model for specifying the number zero and non-zero operations and the
Expand Down Expand Up @@ -114,7 +62,7 @@ class NodeIO(BaseModel):
)


class ZeroNonZeroParams(PropertyBaseModel):
class ZeroNonZeroParams(BaseModel):
"""
Pydantic model for specifying the number zero and non-zero operations and the
associated sparsity
Expand All @@ -127,20 +75,22 @@ class ZeroNonZeroParams(PropertyBaseModel):
description="The number of parameters whose value is zero", default=0
)

@computed_field(repr=True, return_type=Union[int, float])
@property
def sparsity(self):
total_values = self.total
if total_values > 0:
return self.zero / total_values
else:
return 0
return 0.0

@computed_field(repr=True, return_type=int)
@property
def total(self):
return self.non_zero + self.zero


class DenseSparseOps(PropertyBaseModel):
class DenseSparseOps(BaseModel):
"""
Pydantic model for specifying the number dense and sparse operations and the
associated operation sparsity
Expand All @@ -155,6 +105,7 @@ class DenseSparseOps(PropertyBaseModel):
default=0,
)

@computed_field(repr=True, return_type=Union[int, float])
@property
def sparsity(self):
total_ops = self.sparse + self.dense
Expand Down
27 changes: 27 additions & 0 deletions tests/sparsezoo/analyze/test_analysis/test_models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest

from sparsezoo.analyze_v1.utils.models import DenseSparseOps, ZeroNonZeroParams


@pytest.mark.parametrize("model", [DenseSparseOps, ZeroNonZeroParams])
@pytest.mark.parametrize("computed_fields", [["sparsity"]])
def test_model_dump_has_computed_fields(model, computed_fields):
model = model()
model_dict = model.model_dump()
for computed_field in computed_fields:
assert computed_field in model_dict
assert model_dict[computed_field] == getattr(model, computed_field)

0 comments on commit 6ae1b94

Please sign in to comment.