From ec24860f53d18618e43b879e65a1d9b724e8c137 Mon Sep 17 00:00:00 2001 From: juanbc Date: Sat, 13 Jan 2024 02:12:59 -0300 Subject: [PATCH] npdict_cmp --- skcriteria/agg/_agg_base.py | 58 +++++++++++++++++++++------------ skcriteria/core/data.py | 12 +++++-- skcriteria/utils/__init__.py | 5 ++- skcriteria/utils/npdict_cmp.py | 55 +++++++++++++++++++++++++++++++ skcriteria/utils/object_diff.py | 10 ++++++ 5 files changed, 116 insertions(+), 24 deletions(-) create mode 100644 skcriteria/utils/npdict_cmp.py diff --git a/skcriteria/agg/_agg_base.py b/skcriteria/agg/_agg_base.py index d6cc272..4e9d2b9 100644 --- a/skcriteria/agg/_agg_base.py +++ b/skcriteria/agg/_agg_base.py @@ -24,7 +24,14 @@ import pandas as pd from ..core import SKCMethodABC -from ..utils import Bunch, deprecated, doc_inherit +from ..utils import ( + Bunch, + deprecated, + doc_inherit, + diff, + npdict_all_equals, + WithDiff, +) # ============================================================================= # DM BASE @@ -76,7 +83,7 @@ def evaluate(self, dm): # ============================================================================= -class ResultABC(metaclass=abc.ABCMeta): +class ResultABC(object, metaclass=abc.ABCMeta): """Base class to implement different types of results. Any evaluation of the DecisionMatrix is expected to result in an object @@ -180,6 +187,26 @@ def __len__(self): """ return len(self._result_series) + def diff(self, other, rtol=1e-05, atol=1e-08, equal_nan=False): + def array_allclose(left_value, right_value): + return np.allclose( + left_value, + right_value, + rtol=rtol, + atol=atol, + equal_nan=equal_nan, + ) + + members = { + "method": np.array_equal, + "alternatives": np.array_equal, + "values": array_allclose, + "extra_": dict_equal, + } + + the_diff = diff(self, other, **members) + return the_diff + def values_equals(self, other): """Check if the alternatives and values are the same. @@ -240,24 +267,13 @@ def aequals(self, other, rtol=1e-05, atol=1e-08, equal_nan=False): """ if self is other: return True - is_veq = self.values_equals(other) and set(self._extra) == set( - other._extra + is_veq = self.values_equals(other) and npdict_all_equals( + self.extra_, + other.extra_, + rtol=rtol, + atol=atol, + equal_nan=equal_nan, ) - keys = set(self._extra) - while is_veq and keys: - k = keys.pop() - sv = self._extra[k] - ov = other._extra[k] - if isinstance(ov, np.ndarray): - is_veq = is_veq and np.allclose( - sv, - ov, - rtol=rtol, - atol=atol, - equal_nan=equal_nan, - ) - else: - is_veq = is_veq and sv == ov return is_veq def equals(self, other): @@ -285,11 +301,11 @@ def equals(self, other): return self.aequals(other, 0, 0, False) def __eq__(self, other): - """x.__eq__(y) <==> x == y.""" + """x.__eq__(y) <==> x == y <==> x.equals(y).""" return self.equals(other) def __ne__(self, other): - """x.__eq__(y) <==> x == y.""" + """x.__ne__(y) <==> x != y. <==> not x.equals(y)""" return not self == other # REPR ==================================================================== diff --git a/skcriteria/core/data.py b/skcriteria/core/data.py index 51fd0c3..2403732 100644 --- a/skcriteria/core/data.py +++ b/skcriteria/core/data.py @@ -35,7 +35,7 @@ from .objectives import Objective from .plot import DecisionMatrixPlotter from .stats import DecisionMatrixStatsAccessor -from ..utils import deprecated, df_temporal_header, diff, doc_inherit +from ..utils import deprecated, df_temporal_header, diff, doc_inherit, WithDiff # ============================================================================= @@ -128,7 +128,7 @@ def __getitem__(self, slc): # ============================================================================= -class DecisionMatrix: +class DecisionMatrix(object): """Representation of all data needed in the MCDA analysis. This object gathers everything necessary to represent a data set used @@ -773,6 +773,14 @@ def equals(self, other): other, rtol=0, atol=0, equal_nan=False, check_dtype=True ) + def __eq__(self, other): + """x.__eq__(y) <==> x == y <==> x.equals(y).""" + return self.equals(other) + + def __ne__(self, other): + """x.__ne__(y) <==> x != y. <==> not x.equals(y)""" + return not self == other + # SLICES ================================================================== def __getitem__(self, slc): diff --git a/skcriteria/utils/__init__.py b/skcriteria/utils/__init__.py index aa8efe2..3161dfe 100644 --- a/skcriteria/utils/__init__.py +++ b/skcriteria/utils/__init__.py @@ -21,7 +21,8 @@ from .cmanagers import df_temporal_header, hidden from .deprecate import deprecated, will_change from .doctools import doc_inherit -from .object_diff import diff +from .npdict_cmp import npdict_all_equals +from .object_diff import diff, WithDiff from .unames import unique_names @@ -41,4 +42,6 @@ "unique_names", "will_change", "diff", + "WithDiff", + "npdict_all_equals", ] diff --git a/skcriteria/utils/npdict_cmp.py b/skcriteria/utils/npdict_cmp.py new file mode 100644 index 0000000..5933725 --- /dev/null +++ b/skcriteria/utils/npdict_cmp.py @@ -0,0 +1,55 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# License: BSD-3 (https://tldrlegal.com/license/bsd-3-clause-license-(revised)) +# Copyright (c) 2016-2021, Cabral, Juan; Luczywo, Nadia +# Copyright (c) 2022, 2023, 2024 QuatroPe +# All rights reserved. + +# ============================================================================= +# DOCS +# ============================================================================= + +"""Utilities to compare two dictionaries with numpy arrays.""" + +# ============================================================================= +# IMPORTS +# ============================================================================= + +import numpy as np + +# ============================================================================= +# CLASSES +# ============================================================================= + + +def npdict_all_equals(left, right, rtol=1e-05, atol=1e-08, equal_nan=False): + """Return True if the two dictionaries are equal within a tolerance.""" + + # extra keys + keys = set(left).union(right) + + # if the keys are not the same on both sides return False + if not (len(keys) == len(left) == len(right)): + return False + + is_equal = True # flag to check if all keys are equal, optimist + while is_equal and keys: # loop until all keys are equal + key = keys.pop() + left_value, right_value = left[key], right[key] + + if type(left_value) != type(right_value): + is_equal = False + + elif isinstance(left_value, np.ndarray): + is_equal = np.allclose( + left_value, + right_value, + rtol=rtol, + atol=atol, + equal_nan=equal_nan, + ) + + else: + is_equal = left_value == right_value + + return is_equal diff --git a/skcriteria/utils/object_diff.py b/skcriteria/utils/object_diff.py index e5b46d7..1242676 100644 --- a/skcriteria/utils/object_diff.py +++ b/skcriteria/utils/object_diff.py @@ -15,6 +15,7 @@ # IMPORTS # ============================================================================= +import abc from dataclasses import dataclass, field # ============================================================================= @@ -22,6 +23,15 @@ # ============================================================================= +class WithDiff(abc.ABC): + """Mixin to add a difference attribute.""" + + @abc.abstractmethod + def diff(self, other): + """Returns the difference between two objects.""" + raise NotImplementedError() + + class _Missing(object): def __new__(cls, *args, **kwargs): """Creates a new instance of the class if it does not already exist, \