Skip to content

Commit

Permalink
Major updates and refactoration of the code
Browse files Browse the repository at this point in the history
  • Loading branch information
eisDNV committed Oct 29, 2024
1 parent f0c6cb9 commit a199a6c
Show file tree
Hide file tree
Showing 24 changed files with 3,436 additions and 570 deletions.
89 changes: 44 additions & 45 deletions case_study/assertion.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from sympy import Symbol, sympify, FiniteSet
from sympy import Symbol, sympify


class Assertion:
"""Define Assertion objects for checking expectations with respect to simulation results.
Expand All @@ -15,59 +16,59 @@ class Assertion:
expr (str): The boolean expression definition as string.
Any unknown symbol within the expression is defined as sympy.Symbol and is expected to match a variable.
"""
ns = {}

def __init__(self, expr:str):
self._expr = Assertion.do_sympify( expr)

ns: dict = {}

def __init__(self, expr: str):
self._expr = Assertion.do_sympify(expr)
self._symbols = self.get_symbols()
Assertion.update_namespace( self._symbols)
Assertion.update_namespace(self._symbols)

@property
def expr(self):
return self._expr

@property
def symbols(self):
return self._symbols
def symbol(self, name:str):

def symbol(self, name: str):
try:
return self._symbols[name]
except KeyError as err:
except KeyError:
return None

@staticmethod
def do_sympify( _expr):
def do_sympify(_expr):
"""Evaluate the initial expression as sympy expression.
Return the sympified expression or throw an error if sympification is not possible.
"""
if '==' in _expr:
if "==" in _expr:
raise ValueError("'==' cannot be used to check equivalence. Use 'a-b' and check against 0") from None
try:
expr = sympify( _expr)
expr = sympify(_expr)
except ValueError as err:
raise Exception(f"Something wrong with expression {_expr}: {err}|. Cannot sympify.") from None
return expr

def get_symbols(self):
"""Get the atom symbols used in the expression. Return the symbols as dict of `name : symbol`"""
"""Get the atom symbols used in the expression. Return the symbols as dict of `name : symbol`."""
syms = self._expr.atoms(Symbol)
return { s.name : s for s in syms}
return {s.name: s for s in syms}

@staticmethod
def reset():
"""Reset the global dictionary of symbols used by all Assertions"""
"""Reset the global dictionary of symbols used by all Assertions."""
Assertion.ns = {}

@staticmethod
def update_namespace( sym: dict):
"""Ensure that the symbols of this expression are registered in the global namespace `ns`"""
for n,s in sym.items():
def update_namespace(sym: dict):
"""Ensure that the symbols of this expression are registered in the global namespace `ns`."""
for n, s in sym.items():
if n not in Assertion.ns:
Assertion.ns.update({n:s})

Assertion.ns.update({n: s})

def assert_single(self, subs:list[tuple]):
def assert_single(self, subs: list[tuple]):
"""Perform assertion on a single data point.
Args:
Expand All @@ -78,11 +79,10 @@ def assert_single(self, subs:list[tuple]):
Results:
(bool) result of assertion
"""
_subs = [ (self._symbols[s[0]], s[1]) for s in subs]
return self._expr.subs( _subs)

_subs = [(self._symbols[s[0]], s[1]) for s in subs]
return self._expr.subs(_subs)

def assert_series(self, subs:list[tuple], ret:str='bool'):
def assert_series(self, subs: list[tuple], ret: str = "bool"):
"""Perform assertion on a (time) series.
Args:
Expand All @@ -91,7 +91,7 @@ def assert_series(self, subs:list[tuple], ret:str='bool'):
All required variables for the evaluation shall be listed
The variable-name provided as string is translated to its symbol before evaluation.
ret (str)='bool': Determines how to return the result of the assertion:
`bool` : True if any element of the assertion of the series is evaluated to True
`bool-list` : List of True/False for each data point in the series
`interval` : tuple of interval of indices for which the assertion is True
Expand All @@ -100,32 +100,31 @@ def assert_series(self, subs:list[tuple], ret:str='bool'):
bool, list[bool], tuple[int] or int, depending on `ret` parameter.
Default: True/False on whether at least one record is found where the assertion is True.
"""
_subs = [ (self._symbols[s[0]], s[1]) for s in subs]
length = len( subs[0][1])
result = [False]* length
for i in range( length):
_subs = [(self._symbols[s[0]], s[1]) for s in subs]
length = len(subs[0][1])
result = [False] * length

for i in range(length):
s = []
for k in range( len( _subs)): # number of variables in substitution
s.append( (_subs[k][0], _subs[k][1][i]))
res = self._expr.subs( s)
for k in range(len(_subs)): # number of variables in substitution
s.append((_subs[k][0], _subs[k][1][i]))
res = self._expr.subs(s)
if res:
result[i] = True
if ret == 'bool':
if ret == "bool":
return True in result
elif ret == 'bool-list':
elif ret == "bool-list":
return result
elif ret == 'interval':
elif ret == "interval":
if True in result:
idx0 = result.index(True)
if False in result[idx0:]:
return (idx0, idx0+result[idx0:].index(False))
return (idx0, idx0 + result[idx0:].index(False))
else:
return (idx0, length)
else:
return None
elif ret == 'count':
return sum( x for x in result)
elif ret == "count":
return sum(x for x in result)
else:
raise ValueError(f"Unknown return type '{ret}'") from None

Loading

0 comments on commit a199a6c

Please sign in to comment.