Skip to content

Commit

Permalink
Merge pull request #417 from ESSS/fb-ASIM-5939-fix-uq-sub-simulation-…
Browse files Browse the repository at this point in the history
…access

Add equality checks to the UQ result readers.
  • Loading branch information
ro-oliveira95 authored Jan 21, 2025
2 parents f04c1a6 + dbdffdc commit 4b8ea14
Show file tree
Hide file tree
Showing 3 changed files with 81 additions and 5 deletions.
14 changes: 12 additions & 2 deletions src/alfasim_sdk/result_reader/aggregator.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ class TimeSetInfoItem(NamedTuple):
uuid: str


TimeSetInfo = Dict[int, TimeSetInfoItem]
TimeSetInfo = Dict[TimeStepIndex, TimeSetInfoItem]


_PROFILE_ID_ATTR = "profile_id"
Expand Down Expand Up @@ -1957,7 +1957,7 @@ def read_uncertainty_propagation_analyses_meta_data(
)


@attr.s(frozen=True)
@attr.s(frozen=True, eq=False)
class UPResult:
"""
Holder for each uncertainty propagation result.
Expand All @@ -1967,6 +1967,16 @@ class UPResult:
std_result: np.ndarray = attr.ib(default=attr.Factory(lambda: np.array([])))
mean_result: np.ndarray = attr.ib(default=attr.Factory(lambda: np.array([])))

def __eq__(self, other: Any) -> bool:
if not isinstance(other, UPResult):
return False

return (
np.array_equal(self.realization_output, other.realization_output)
and np.array_equal(self.std_result, other.std_result)
and np.array_equal(self.mean_result, other.mean_result)
)


def read_uncertainty_propagation_results(
metadata: UncertaintyPropagationAnalysesMetaData,
Expand Down
44 changes: 41 additions & 3 deletions src/alfasim_sdk/result_reader/reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,7 +355,7 @@ def validator(inst: Any, attribute: attr.Attribute, value: Any) -> None:
return validator


@define(frozen=True)
@define(frozen=True, eq=False)
class GlobalSensitivityAnalysisResults:
timeset: np.ndarray = attr.field(validator=attr.validators.min_len(1))
coefficients: dict[GSAOutputKey, np.ndarray] = attr.field(
Expand Down Expand Up @@ -393,6 +393,22 @@ def get_sensitivity_curve(
domain = Array(self.timeset, "s")
return Curve(image=image, domain=domain)

def __eq__(self, other: Any) -> bool:
if not isinstance(other, GlobalSensitivityAnalysisResults):
return False

return (
np.array_equal(self.timeset, other.timeset)
and _all_dict_close(self.coefficients, other.coefficients)
and self.metadata == other.metadata
)


def _all_dict_close(a: dict[Any, np.ndarray], b: dict[Any, np.ndarray]) -> bool:
if a.keys() != b.keys():
return False
return all(np.array_equal(a[key], b[key]) for key in a)


@define(frozen=True)
class _BaseHistoryMatchingResults:
Expand Down Expand Up @@ -423,7 +439,7 @@ def from_directory(cls, result_dir: Path) -> Self | None:
)


@define(frozen=True)
@define(frozen=True, eq=False)
class HistoryMatchingProbabilisticResults(_BaseHistoryMatchingResults):
probabilistic_distributions: dict[HMOutputKey, np.ndarray] = attr.field(
validator=_non_empty_dict_validator(values_type=np.ndarray)
Expand All @@ -443,6 +459,18 @@ def from_directory(cls, result_dir: Path) -> Self | None:
metadata=metadata,
)

def __eq__(self, other: Any) -> bool:
if not isinstance(other, HistoryMatchingProbabilisticResults):
return False

return (
_all_dict_close(
self.probabilistic_distributions, other.probabilistic_distributions
)
and self.historic_data_curves == other.historic_data_curves
and self.metadata == other.metadata
)


def _read_curves_data(
metadata: HistoryMatchingMetadata,
Expand All @@ -460,7 +488,7 @@ def _read_curves_data(
return result


@define(frozen=True)
@define(frozen=True, eq=False)
class UncertaintyPropagationResults:
timeset: np.ndarray = attr.field(validator=attr.validators.min_len(1))
results: dict[UPOutputKey, UPResult] = attr.field(
Expand All @@ -481,3 +509,13 @@ def from_directory(cls, result_dir: Path) -> Self | None:
results=read_uncertainty_propagation_results(metadata),
metadata=metadata,
)

def __eq__(self, other: Any) -> bool:
if not isinstance(other, UncertaintyPropagationResults):
return False

return (
np.array_equal(self.timeset, other.timeset)
and self.results == other.results
and self.metadata == other.metadata
)
28 changes: 28 additions & 0 deletions tests/results/test_result_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import json
from pathlib import Path

import attr
import numpy
import numpy as np
import pytest
Expand Down Expand Up @@ -188,6 +189,11 @@ def test_global_sensitivity_analysis_results_reader(
qoi_data_index=0,
)

# Test equality check.
assert results == attr.evolve(results)
assert results != object()
assert results != attr.evolve(results, timeset=np.array([0.1, 0.2]))

# Ensure the reader can handle a nonexistent result file.
results = GlobalSensitivityAnalysisResults.from_directory(Path("foo"))
assert results is None
Expand All @@ -204,6 +210,13 @@ def test_deterministic_reader(self, hm_deterministic_results_dir: Path) -> None:
}
self._validate_meta_and_historic_curves(results)

# Test equality check.
assert results == attr.evolve(results)
assert results != object()
assert results != attr.evolve(
results, deterministic_values={HMOutputKey("parametric_var_1"): 0.1}
)

def test_probabilistic_reader(self, hm_probabilistic_results_dir: Path) -> None:
results = HistoryMatchingProbabilisticResults.from_directory(
hm_probabilistic_results_dir
Expand All @@ -217,6 +230,16 @@ def test_probabilistic_reader(self, hm_probabilistic_results_dir: Path) -> None:
)
self._validate_meta_and_historic_curves(results)

# Test equality check.
assert results == attr.evolve(results)
assert results != object()
assert results != attr.evolve(
results,
probabilistic_distributions={
HMOutputKey("parametric_var_1"): np.array([0.1, 0.3])
},
)

def test_wrong_result_file(
self, hm_probabilistic_results_dir: Path, hm_deterministic_results_dir: Path
) -> None:
Expand Down Expand Up @@ -331,6 +354,11 @@ def test_uncertainty_propagation_results_reader(up_results_dir: Path) -> None:
sample_indexes=[[0, 0], [0, 1], [0, 2], [0, 3], [0, 4]],
)

# Test equality check.
assert reader == attr.evolve(reader)
assert reader != object()
assert reader != attr.evolve(reader, timeset=np.array([0.1, 0.2]))

# Ensure the reader can handle a nonexistent result file.
reader = UncertaintyPropagationResults.from_directory(Path("foo"))
assert reader is None

0 comments on commit 4b8ea14

Please sign in to comment.