From dbdffdc15873dcc05c1ae40d9d7bfe2fdddc6ee3 Mon Sep 17 00:00:00 2001 From: Rodrigo Neto Date: Mon, 13 Jan 2025 19:05:21 -0300 Subject: [PATCH] Add equality checks to the UQ result readers. ASIM-5939 --- src/alfasim_sdk/result_reader/aggregator.py | 14 ++++++- src/alfasim_sdk/result_reader/reader.py | 44 +++++++++++++++++++-- tests/results/test_result_reader.py | 28 +++++++++++++ 3 files changed, 81 insertions(+), 5 deletions(-) diff --git a/src/alfasim_sdk/result_reader/aggregator.py b/src/alfasim_sdk/result_reader/aggregator.py index a2d455ff..156c17e3 100644 --- a/src/alfasim_sdk/result_reader/aggregator.py +++ b/src/alfasim_sdk/result_reader/aggregator.py @@ -114,7 +114,7 @@ class TimeSetInfoItem(NamedTuple): uuid: str -TimeSetInfo = Dict[int, TimeSetInfoItem] +TimeSetInfo = Dict[TimeStepIndex, TimeSetInfoItem] _PROFILE_ID_ATTR = "profile_id" @@ -1957,7 +1957,7 @@ def read_uncertainty_propagation_analyses_meta_data( ) -@attr.s(frozen=True) +@attr.s(frozen=True, eq=False) class UPResult: """ Holder for each uncertainty propagation result. @@ -1967,6 +1967,16 @@ class UPResult: std_result: np.ndarray = attr.ib(default=attr.Factory(lambda: np.array([]))) mean_result: np.ndarray = attr.ib(default=attr.Factory(lambda: np.array([]))) + def __eq__(self, other: Any) -> bool: + if not isinstance(other, UPResult): + return False + + return ( + np.array_equal(self.realization_output, other.realization_output) + and np.array_equal(self.std_result, other.std_result) + and np.array_equal(self.mean_result, other.mean_result) + ) + def read_uncertainty_propagation_results( metadata: UncertaintyPropagationAnalysesMetaData, diff --git a/src/alfasim_sdk/result_reader/reader.py b/src/alfasim_sdk/result_reader/reader.py index 1a3c1492..5aaa2760 100644 --- a/src/alfasim_sdk/result_reader/reader.py +++ b/src/alfasim_sdk/result_reader/reader.py @@ -355,7 +355,7 @@ def validator(inst: Any, attribute: attr.Attribute, value: Any) -> None: return validator -@define(frozen=True) +@define(frozen=True, eq=False) class GlobalSensitivityAnalysisResults: timeset: np.ndarray = attr.field(validator=attr.validators.min_len(1)) coefficients: dict[GSAOutputKey, np.ndarray] = attr.field( @@ -393,6 +393,22 @@ def get_sensitivity_curve( domain = Array(self.timeset, "s") return Curve(image=image, domain=domain) + def __eq__(self, other: Any) -> bool: + if not isinstance(other, GlobalSensitivityAnalysisResults): + return False + + return ( + np.array_equal(self.timeset, other.timeset) + and _all_dict_close(self.coefficients, other.coefficients) + and self.metadata == other.metadata + ) + + +def _all_dict_close(a: dict[Any, np.ndarray], b: dict[Any, np.ndarray]) -> bool: + if a.keys() != b.keys(): + return False + return all(np.array_equal(a[key], b[key]) for key in a) + @define(frozen=True) class _BaseHistoryMatchingResults: @@ -423,7 +439,7 @@ def from_directory(cls, result_dir: Path) -> Self | None: ) -@define(frozen=True) +@define(frozen=True, eq=False) class HistoryMatchingProbabilisticResults(_BaseHistoryMatchingResults): probabilistic_distributions: dict[HMOutputKey, np.ndarray] = attr.field( validator=_non_empty_dict_validator(values_type=np.ndarray) @@ -443,6 +459,18 @@ def from_directory(cls, result_dir: Path) -> Self | None: metadata=metadata, ) + def __eq__(self, other: Any) -> bool: + if not isinstance(other, HistoryMatchingProbabilisticResults): + return False + + return ( + _all_dict_close( + self.probabilistic_distributions, other.probabilistic_distributions + ) + and self.historic_data_curves == other.historic_data_curves + and self.metadata == other.metadata + ) + def _read_curves_data( metadata: HistoryMatchingMetadata, @@ -460,7 +488,7 @@ def _read_curves_data( return result -@define(frozen=True) +@define(frozen=True, eq=False) class UncertaintyPropagationResults: timeset: np.ndarray = attr.field(validator=attr.validators.min_len(1)) results: dict[UPOutputKey, UPResult] = attr.field( @@ -481,3 +509,13 @@ def from_directory(cls, result_dir: Path) -> Self | None: results=read_uncertainty_propagation_results(metadata), metadata=metadata, ) + + def __eq__(self, other: Any) -> bool: + if not isinstance(other, UncertaintyPropagationResults): + return False + + return ( + np.array_equal(self.timeset, other.timeset) + and self.results == other.results + and self.metadata == other.metadata + ) diff --git a/tests/results/test_result_reader.py b/tests/results/test_result_reader.py index 73501ed2..6b934b92 100644 --- a/tests/results/test_result_reader.py +++ b/tests/results/test_result_reader.py @@ -3,6 +3,7 @@ import json from pathlib import Path +import attr import numpy import numpy as np import pytest @@ -188,6 +189,11 @@ def test_global_sensitivity_analysis_results_reader( qoi_data_index=0, ) + # Test equality check. + assert results == attr.evolve(results) + assert results != object() + assert results != attr.evolve(results, timeset=np.array([0.1, 0.2])) + # Ensure the reader can handle a nonexistent result file. results = GlobalSensitivityAnalysisResults.from_directory(Path("foo")) assert results is None @@ -204,6 +210,13 @@ def test_deterministic_reader(self, hm_deterministic_results_dir: Path) -> None: } self._validate_meta_and_historic_curves(results) + # Test equality check. + assert results == attr.evolve(results) + assert results != object() + assert results != attr.evolve( + results, deterministic_values={HMOutputKey("parametric_var_1"): 0.1} + ) + def test_probabilistic_reader(self, hm_probabilistic_results_dir: Path) -> None: results = HistoryMatchingProbabilisticResults.from_directory( hm_probabilistic_results_dir @@ -217,6 +230,16 @@ def test_probabilistic_reader(self, hm_probabilistic_results_dir: Path) -> None: ) self._validate_meta_and_historic_curves(results) + # Test equality check. + assert results == attr.evolve(results) + assert results != object() + assert results != attr.evolve( + results, + probabilistic_distributions={ + HMOutputKey("parametric_var_1"): np.array([0.1, 0.3]) + }, + ) + def test_wrong_result_file( self, hm_probabilistic_results_dir: Path, hm_deterministic_results_dir: Path ) -> None: @@ -331,6 +354,11 @@ def test_uncertainty_propagation_results_reader(up_results_dir: Path) -> None: sample_indexes=[[0, 0], [0, 1], [0, 2], [0, 3], [0, 4]], ) + # Test equality check. + assert reader == attr.evolve(reader) + assert reader != object() + assert reader != attr.evolve(reader, timeset=np.array([0.1, 0.2])) + # Ensure the reader can handle a nonexistent result file. reader = UncertaintyPropagationResults.from_directory(Path("foo")) assert reader is None