Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BugFix] Fix Serialization of Computed Properties in BaseModel #485

Merged
merged 2 commits into from
May 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions src/sparsezoo/analyze_v1/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,13 @@ class YAMLSerializableBaseModel(BaseModel):
A BaseModel that adds a .yaml(...) function to all child classes
"""

model_config = ConfigDict(protected_namespaces=())

def dict(self, *args, **kwargs) -> Dict[str, Any]:
# alias for model_dump for pydantic v2 upgrade
# to allow for easier migration
return self.model_dump(*args, **kwargs)

def yaml(self, file_path: Optional[str] = None) -> Union[str, None]:
"""
:param file_path: optional file path to save yaml to
Expand All @@ -111,7 +118,7 @@ def yaml(self, file_path: Optional[str] = None) -> Union[str, None]:
"""
file_stream = None if file_path is None else open(file_path, "w")
ret = yaml.dump(
self.dict(), stream=file_stream, allow_unicode=True, sort_keys=False
self.model_dump(), stream=file_stream, allow_unicode=True, sort_keys=False
)

if file_stream is not None:
Expand All @@ -127,7 +134,7 @@ def parse_yaml_file(cls, file_path: str):
"""
with open(file_path, "r") as file:
dict_obj = yaml.safe_load(file)
return cls.parse_obj(dict_obj)
return cls.model_validate(dict_obj)

@classmethod
def parse_yaml_raw(cls, yaml_raw: str):
Expand All @@ -136,7 +143,7 @@ def parse_yaml_raw(cls, yaml_raw: str):
:return: instance of ModelAnalysis class
"""
dict_obj = yaml.safe_load(yaml_raw) # unsafe: needs to load numpy
return cls.parse_obj(dict_obj)
return cls.model_validate(dict_obj)


@dataclass
Expand Down
63 changes: 7 additions & 56 deletions src/sparsezoo/analyze_v1/utils/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import textwrap
from typing import ClassVar, Dict, List, Optional, Tuple, Union

from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, computed_field


__all__ = [
Expand All @@ -33,58 +33,6 @@
PrintOrderType = ClassVar[List[str]]


class PropertyBaseModel(BaseModel):
"""
https://github.com/samuelcolvin/pydantic/issues/935#issuecomment-1152457432

Workaround for serializing properties with pydantic until
https://github.com/samuelcolvin/pydantic/issues/935
is solved
"""

@classmethod
def get_properties(cls):
return [
prop
for prop in dir(cls)
if isinstance(getattr(cls, prop), property)
and prop not in ("__values__", "fields")
]

def dict(
self,
*,
include: Union["AbstractSetIntStr", "MappingIntStrAny"] = None, # noqa: F821
exclude: Union["AbstractSetIntStr", "MappingIntStrAny"] = None, # noqa: F821
by_alias: bool = False,
skip_defaults: bool = None,
exclude_unset: bool = False,
exclude_defaults: bool = False,
exclude_none: bool = False,
) -> "DictStrAny": # noqa: F821
attribs = super().dict(
include=include,
exclude=exclude,
by_alias=by_alias,
skip_defaults=skip_defaults,
exclude_unset=exclude_unset,
exclude_defaults=exclude_defaults,
exclude_none=exclude_none,
)
props = self.get_properties()
# Include and exclude properties
if include:
props = [prop for prop in props if prop in include]
if exclude:
props = [prop for prop in props if prop not in exclude]

# Update the attribute dict with the properties
if props:
attribs.update({prop: getattr(self, prop) for prop in props})

return attribs


class NodeCounts(BaseModel):
"""
Pydantic model for specifying the number zero and non-zero operations and the
Expand Down Expand Up @@ -114,7 +62,7 @@ class NodeIO(BaseModel):
)


class ZeroNonZeroParams(PropertyBaseModel):
class ZeroNonZeroParams(BaseModel):
"""
Pydantic model for specifying the number zero and non-zero operations and the
associated sparsity
Expand All @@ -127,20 +75,22 @@ class ZeroNonZeroParams(PropertyBaseModel):
description="The number of parameters whose value is zero", default=0
)

@computed_field(repr=True, return_type=Union[int, float])
@property
def sparsity(self):
total_values = self.total
if total_values > 0:
return self.zero / total_values
else:
return 0
return 0.0

@computed_field(repr=True, return_type=int)
@property
def total(self):
return self.non_zero + self.zero


class DenseSparseOps(PropertyBaseModel):
class DenseSparseOps(BaseModel):
"""
Pydantic model for specifying the number dense and sparse operations and the
associated operation sparsity
Expand All @@ -155,6 +105,7 @@ class DenseSparseOps(PropertyBaseModel):
default=0,
)

@computed_field(repr=True, return_type=Union[int, float])
@property
def sparsity(self):
total_ops = self.sparse + self.dense
Expand Down
27 changes: 27 additions & 0 deletions tests/sparsezoo/analyze/test_analysis/test_models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest

from sparsezoo.analyze_v1.utils.models import DenseSparseOps, ZeroNonZeroParams


@pytest.mark.parametrize("model", [DenseSparseOps, ZeroNonZeroParams])
@pytest.mark.parametrize("computed_fields", [["sparsity"]])
def test_model_dump_has_computed_fields(model, computed_fields):
model = model()
model_dict = model.model_dump()
for computed_field in computed_fields:
assert computed_field in model_dict
assert model_dict[computed_field] == getattr(model, computed_field)
Loading