Skip to content

Commit

Permalink
diff implemethed
Browse files Browse the repository at this point in the history
  • Loading branch information
juanbc committed Jan 15, 2024
1 parent dd82a57 commit 4d5e570
Show file tree
Hide file tree
Showing 7 changed files with 129 additions and 98 deletions.
93 changes: 7 additions & 86 deletions skcriteria/agg/_agg_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,11 @@
from ..core import SKCMethodABC
from ..utils import (
Bunch,
DiffEqualityMixin,
deprecated,
doc_inherit,
diff,
doc_inherit,
npdict_all_equals,
DiffEqualityMixin,
)

# =============================================================================
Expand Down Expand Up @@ -83,7 +83,7 @@ def evaluate(self, dm):
# =============================================================================


class ResultABC(object, metaclass=abc.ABCMeta):
class ResultABC(DiffEqualityMixin, metaclass=abc.ABCMeta):
"""Base class to implement different types of results.
Any evaluation of the DecisionMatrix is expected to result in an object
Expand Down Expand Up @@ -187,7 +187,10 @@ def __len__(self):
"""
return len(self._result_series)

def diff(self, other, rtol=1e-05, atol=1e-08, equal_nan=False):
@doc_inherit(DiffEqualityMixin.diff)
def diff(
self, other, rtol=1e-05, atol=1e-08, equal_nan=False, check_dtype=False
):
def array_allclose(left_value, right_value):
return np.allclose(
left_value,
Expand Down Expand Up @@ -219,88 +222,6 @@ def values_equals(self, other):
and "values" not in the_diff.members_diff
)

def aequals(self, other, rtol=1e-05, atol=1e-08, equal_nan=False):
"""Return True if the result are equal within a tolerance.
The tolerance values are positive, typically very small numbers. The
relative difference (`rtol` * abs(`b`)) and the absolute difference
`atol` are added together to compare against the absolute difference
between `a` and `b`.
NaNs are treated as equal if they are in the same place and if
``equal_nan=True``. Infs are treated as equal if they are in the same
place and of the same sign in both arrays.
The proceeds as follows:
- If ``other`` is the same object return ``True``.
- If ``other`` is not instance of 'DecisionMatrix', has different shape
'criteria', 'alternatives' or 'objectives' returns ``False``.
- Next check the 'weights' and the matrix itself using the provided
tolerance.
Parameters
----------
other : Result
Other result to compare.
rtol : float
The relative tolerance parameter
(see Notes in :py:func:`numpy.allclose`).
atol : float
The absolute tolerance parameter
(see Notes in :py:func:`numpy.allclose`).
equal_nan : bool
Whether to compare NaN's as equal. If True, NaN's in dm will be
considered equal to NaN's in `other` in the output array.
Returns
-------
aequals : :py:class:`bool:py:class:`
Returns True if the two result are equal within the given
tolerance; False otherwise.
See Also
--------
equals, :py:func:`numpy.isclose`, :py:func:`numpy.all`,
:py:func:`numpy.any`, :py:func:`numpy.equal`,
:py:func:`numpy.allclose`.
"""
the_diff = self.diff(other, rtol=rtol, atol=atol, equal_nan=equal_nan)
return not the_diff.has_differences

def equals(self, other):
"""Return True if the results are equal.
This method calls `aquals` without tolerance.
Parameters
----------
other : :py:class:`skcriteria.DecisionMatrix`
Other instance to compare.
Returns
-------
equals : :py:class:`bool:py:class:`
Returns True if the two results are equals.
See Also
--------
aequals, :py:func:`numpy.isclose`, :py:func:`numpy.all`,
:py:func:`numpy.any`, :py:func:`numpy.equal`,
:py:func:`numpy.allclose`.
"""
return self.aequals(other, 0, 0, False)

def __eq__(self, other):
"""x.__eq__(y) <==> x == y <==> x.equals(y)."""
return self.equals(other)

def __ne__(self, other):
"""x.__ne__(y) <==> x != y. <==> not x.equals(y)"""
return not self == other

# REPR ====================================================================

def __repr__(self):
Expand Down
8 changes: 7 additions & 1 deletion skcriteria/core/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,13 @@
from .objectives import Objective
from .plot import DecisionMatrixPlotter
from .stats import DecisionMatrixStatsAccessor
from ..utils import deprecated, df_temporal_header, diff, doc_inherit, DiffEqualityMixin
from ..utils import (
DiffEqualityMixin,
deprecated,
df_temporal_header,
diff,
doc_inherit,
)


# =============================================================================
Expand Down
5 changes: 0 additions & 5 deletions skcriteria/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,6 @@
from .utils import hidden

with hidden():
import numpy as np
import numpy.testing as npt

import pandas.testing as pdt

from .agg import ResultABC
from .core import DecisionMatrix
from .cmp import RanksComparator
Expand Down
4 changes: 2 additions & 2 deletions skcriteria/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from .deprecate import deprecated, will_change
from .doctools import doc_inherit
from .npdict_cmp import npdict_all_equals
from .object_diff import diff, DiffEqualityMixin
from .object_diff import DiffEqualityMixin, diff
from .unames import unique_names


Expand All @@ -42,6 +42,6 @@
"unique_names",
"will_change",
"diff",
"WithDiff",
"DiffEqualityMixin",
"npdict_all_equals",
]
2 changes: 1 addition & 1 deletion skcriteria/utils/npdict_cmp.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def npdict_all_equals(left, right, rtol=1e-05, atol=1e-08, equal_nan=False):
key = keys.pop()
left_value, right_value = left[key], right[key]

if type(left_value) != type(right_value):
if type(left_value) is not type(right_value):
is_equal = False

elif isinstance(left_value, np.ndarray):
Expand Down
7 changes: 4 additions & 3 deletions skcriteria/utils/object_diff.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,13 +171,14 @@ class DiffEqualityMixin(abc.ABC):

def __init_subclass__(cls):
"""Validate the creation of a subclass."""
o_params = set(inspect.signature(DiffEqualityMixin.diff).parameters)
params = set(inspect.signature(cls.diff).parameters)
o_params = list(inspect.signature(DiffEqualityMixin.diff).parameters)
params = list(inspect.signature(cls.diff).parameters)
if o_params != params:
o_params.remove("self")
diff_method_name = cls.diff.__qualname__
raise TypeError(
f"{diff_method_name!r} must redefine {o_params!r} parameters"
f"{diff_method_name!r} must redefine exactly "
f"the parameters {o_params!r}"
)

@abc.abstractmethod
Expand Down
108 changes: 108 additions & 0 deletions tests/utils/test_object_diff.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

import numpy as np

import pytest

from skcriteria.utils import object_diff

Expand Down Expand Up @@ -58,6 +59,30 @@ def __init__(self, **kws):
)
assert repr(result) == expected_repr

# reverting the order of the arguments ====================================

result = object_diff.diff(
obj_b, obj_a, a=np.equal, b=np.equal, c=np.equal, d=np.equal
)

assert result.left_type is SomeClass
assert result.right_type is SomeClass
assert result.different_types is False

assert result.has_differences

assert tuple(sorted(result.members_diff)) == ("b", "c")
assert result.members_diff["b"] == (3, 2)
assert result.members_diff["c"] == (4, object_diff.MISSING)

expected_repr = (
"<Difference "
"has_differences=True "
"different_types=False "
"members_diff=('b', 'c')>"
)
assert repr(result) == expected_repr


def test_diff_different_types():
class SomeClass:
Expand Down Expand Up @@ -95,3 +120,86 @@ def __init__(self, **kws):

assert result.different_types is False
assert result.has_differences is False


def test_DiffEqualityMixin():
class SomeClass(object_diff.DiffEqualityMixin):
def __init__(self, **kwargs):
self.__dict__.update(kwargs)

def diff(
self,
other,
rtol=1e-05,
atol=1e-08,
equal_nan=True,
check_dtype=False,
):
return object_diff.diff(
self, other, a=np.equal, b=np.equal, c=np.equal
)

obj_a = SomeClass(a=1, b=2, d=5)
obj_b = SomeClass(a=1, b=3, c=4, d=6)

result = obj_a.diff(obj_b)

assert result.left_type is SomeClass
assert result.right_type is SomeClass
assert result.different_types is False

assert result.has_differences

assert tuple(sorted(result.members_diff)) == ("b", "c")
assert result.members_diff["b"] == (2, 3)
assert result.members_diff["c"] == (object_diff.MISSING, 4)

expected_repr = (
"<Difference "
"has_differences=True "
"different_types=False "
"members_diff=('b', 'c')>"
)
assert repr(result) == expected_repr

# check the aequals method
assert (obj_a == obj_b) is False
assert obj_a != obj_b
assert obj_a.equals(obj_b) is False
assert obj_a.aequals(obj_b) is False

# reverting the order of the arguments ====================================

result = obj_b.diff(obj_a)

assert result.left_type is SomeClass
assert result.right_type is SomeClass
assert result.different_types is False

assert result.has_differences

assert tuple(sorted(result.members_diff)) == ("b", "c")
assert result.members_diff["b"] == (3, 2)
assert result.members_diff["c"] == (4, object_diff.MISSING)

expected_repr = (
"<Difference "
"has_differences=True "
"different_types=False "
"members_diff=('b', 'c')>"
)
assert repr(result) == expected_repr

# check the aequals method
assert (obj_a == obj_b) is False
assert obj_a != obj_b
assert obj_a.equals(obj_b) is False
assert obj_a.aequals(obj_b) is False


def test_DiffEqualityMixin_invalid_diff_parameters():
with pytest.raises(TypeError):

class SomeClass(object_diff.DiffEqualityMixin):
def diff(self, other):
pass

0 comments on commit 4d5e570

Please sign in to comment.