Skip to content

Commit

Permalink
npdict_cmp
Browse files Browse the repository at this point in the history
  • Loading branch information
juanbc committed Jan 13, 2024
1 parent 79723b3 commit ec24860
Show file tree
Hide file tree
Showing 5 changed files with 116 additions and 24 deletions.
58 changes: 37 additions & 21 deletions skcriteria/agg/_agg_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,14 @@
import pandas as pd

from ..core import SKCMethodABC
from ..utils import Bunch, deprecated, doc_inherit
from ..utils import (
Bunch,
deprecated,
doc_inherit,
diff,
npdict_all_equals,
WithDiff,
)

# =============================================================================
# DM BASE
Expand Down Expand Up @@ -76,7 +83,7 @@ def evaluate(self, dm):
# =============================================================================


class ResultABC(metaclass=abc.ABCMeta):
class ResultABC(object, metaclass=abc.ABCMeta):
"""Base class to implement different types of results.
Any evaluation of the DecisionMatrix is expected to result in an object
Expand Down Expand Up @@ -180,6 +187,26 @@ def __len__(self):
"""
return len(self._result_series)

def diff(self, other, rtol=1e-05, atol=1e-08, equal_nan=False):
def array_allclose(left_value, right_value):
return np.allclose(
left_value,
right_value,
rtol=rtol,
atol=atol,
equal_nan=equal_nan,
)

members = {
"method": np.array_equal,
"alternatives": np.array_equal,
"values": array_allclose,
"extra_": dict_equal,
}

the_diff = diff(self, other, **members)
return the_diff

def values_equals(self, other):
"""Check if the alternatives and values are the same.
Expand Down Expand Up @@ -240,24 +267,13 @@ def aequals(self, other, rtol=1e-05, atol=1e-08, equal_nan=False):
"""
if self is other:
return True
is_veq = self.values_equals(other) and set(self._extra) == set(
other._extra
is_veq = self.values_equals(other) and npdict_all_equals(
self.extra_,
other.extra_,
rtol=rtol,
atol=atol,
equal_nan=equal_nan,
)
keys = set(self._extra)
while is_veq and keys:
k = keys.pop()
sv = self._extra[k]
ov = other._extra[k]
if isinstance(ov, np.ndarray):
is_veq = is_veq and np.allclose(
sv,
ov,
rtol=rtol,
atol=atol,
equal_nan=equal_nan,
)
else:
is_veq = is_veq and sv == ov
return is_veq

def equals(self, other):
Expand Down Expand Up @@ -285,11 +301,11 @@ def equals(self, other):
return self.aequals(other, 0, 0, False)

def __eq__(self, other):
"""x.__eq__(y) <==> x == y."""
"""x.__eq__(y) <==> x == y <==> x.equals(y)."""
return self.equals(other)

def __ne__(self, other):
"""x.__eq__(y) <==> x == y."""
"""x.__ne__(y) <==> x != y. <==> not x.equals(y)"""
return not self == other

# REPR ====================================================================
Expand Down
12 changes: 10 additions & 2 deletions skcriteria/core/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
from .objectives import Objective
from .plot import DecisionMatrixPlotter
from .stats import DecisionMatrixStatsAccessor
from ..utils import deprecated, df_temporal_header, diff, doc_inherit
from ..utils import deprecated, df_temporal_header, diff, doc_inherit, WithDiff


# =============================================================================
Expand Down Expand Up @@ -128,7 +128,7 @@ def __getitem__(self, slc):
# =============================================================================


class DecisionMatrix:
class DecisionMatrix(object):
"""Representation of all data needed in the MCDA analysis.
This object gathers everything necessary to represent a data set used
Expand Down Expand Up @@ -773,6 +773,14 @@ def equals(self, other):
other, rtol=0, atol=0, equal_nan=False, check_dtype=True
)

def __eq__(self, other):
"""x.__eq__(y) <==> x == y <==> x.equals(y)."""
return self.equals(other)

def __ne__(self, other):
"""x.__ne__(y) <==> x != y. <==> not x.equals(y)"""
return not self == other

# SLICES ==================================================================

def __getitem__(self, slc):
Expand Down
5 changes: 4 additions & 1 deletion skcriteria/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@
from .cmanagers import df_temporal_header, hidden
from .deprecate import deprecated, will_change
from .doctools import doc_inherit
from .object_diff import diff
from .npdict_cmp import npdict_all_equals
from .object_diff import diff, WithDiff
from .unames import unique_names


Expand All @@ -41,4 +42,6 @@
"unique_names",
"will_change",
"diff",
"WithDiff",
"npdict_all_equals",
]
55 changes: 55 additions & 0 deletions skcriteria/utils/npdict_cmp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# License: BSD-3 (https://tldrlegal.com/license/bsd-3-clause-license-(revised))
# Copyright (c) 2016-2021, Cabral, Juan; Luczywo, Nadia
# Copyright (c) 2022, 2023, 2024 QuatroPe
# All rights reserved.

# =============================================================================
# DOCS
# =============================================================================

"""Utilities to compare two dictionaries with numpy arrays."""

# =============================================================================
# IMPORTS
# =============================================================================

import numpy as np

# =============================================================================
# CLASSES
# =============================================================================


def npdict_all_equals(left, right, rtol=1e-05, atol=1e-08, equal_nan=False):
"""Return True if the two dictionaries are equal within a tolerance."""

# extra keys
keys = set(left).union(right)

# if the keys are not the same on both sides return False
if not (len(keys) == len(left) == len(right)):
return False

is_equal = True # flag to check if all keys are equal, optimist
while is_equal and keys: # loop until all keys are equal
key = keys.pop()
left_value, right_value = left[key], right[key]

if type(left_value) != type(right_value):
is_equal = False

elif isinstance(left_value, np.ndarray):
is_equal = np.allclose(
left_value,
right_value,
rtol=rtol,
atol=atol,
equal_nan=equal_nan,
)

else:
is_equal = left_value == right_value

return is_equal
10 changes: 10 additions & 0 deletions skcriteria/utils/object_diff.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,23 @@
# IMPORTS
# =============================================================================

import abc
from dataclasses import dataclass, field

# =============================================================================
# CLASSES
# =============================================================================


class WithDiff(abc.ABC):
"""Mixin to add a difference attribute."""

@abc.abstractmethod
def diff(self, other):
"""Returns the difference between two objects."""
raise NotImplementedError()


class _Missing(object):
def __new__(cls, *args, **kwargs):
"""Creates a new instance of the class if it does not already exist, \
Expand Down

0 comments on commit ec24860

Please sign in to comment.