-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Added first draft of assertion module. Work in progress
- Loading branch information
Showing
3 changed files
with
280 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,131 @@ | ||
from sympy import Symbol, sympify, FiniteSet | ||
|
||
class Assertion: | ||
"""Define Assertion objects for checking expectations with respect to simulation results. | ||
The class uses sympy, where the symbols are expected to be results variables, | ||
as defined in the variable definition section of Cases. | ||
These can then be combined to boolean expressions and be checked against | ||
single points of a data series (see `assert_single()` or against a whole series (see `assert_series()`). | ||
The symbols used in the expression are accessible as `.symbols` (dict of `name : symbol`). | ||
All symbols used by all defined Assertion objects are accessible as Assertion.ns | ||
Args: | ||
expr (str): The boolean expression definition as string. | ||
Any unknown symbol within the expression is defined as sympy.Symbol and is expected to match a variable. | ||
""" | ||
ns = {} | ||
|
||
def __init__(self, expr:str): | ||
self._expr = Assertion.do_sympify( expr) | ||
self._symbols = self.get_symbols() | ||
Assertion.update_namespace( self._symbols) | ||
|
||
@property | ||
def expr(self): | ||
return self._expr | ||
|
||
@property | ||
def symbols(self): | ||
return self._symbols | ||
|
||
def symbol(self, name:str): | ||
try: | ||
return self._symbols[name] | ||
except KeyError as err: | ||
return None | ||
|
||
@staticmethod | ||
def do_sympify( _expr): | ||
"""Evaluate the initial expression as sympy expression. | ||
Return the sympified expression or throw an error if sympification is not possible. | ||
""" | ||
if '==' in _expr: | ||
raise ValueError("'==' cannot be used to check equivalence. Use 'a-b' and check against 0") from None | ||
try: | ||
expr = sympify( _expr) | ||
except ValueError as err: | ||
raise Exception(f"Something wrong with expression {_expr}: {err}|. Cannot sympify.") from None | ||
return expr | ||
|
||
def get_symbols(self): | ||
"""Get the atom symbols used in the expression. Return the symbols as dict of `name : symbol`""" | ||
syms = self._expr.atoms(Symbol) | ||
return { s.name : s for s in syms} | ||
|
||
@staticmethod | ||
def reset(): | ||
"""Reset the global dictionary of symbols used by all Assertions""" | ||
Assertion.ns = {} | ||
|
||
@staticmethod | ||
def update_namespace( sym: dict): | ||
"""Ensure that the symbols of this expression are registered in the global namespace `ns`""" | ||
for n,s in sym.items(): | ||
if n not in Assertion.ns: | ||
Assertion.ns.update({n:s}) | ||
|
||
|
||
def assert_single(self, subs:list[tuple]): | ||
"""Perform assertion on a single data point. | ||
Args: | ||
subs (list): list of tuples of `(variable-name, value)`, | ||
where the independent variable (normally the time) shall be listed first. | ||
All required variables for the evaluation shall be listed. | ||
The variable-name provided as string is translated to its symbol before evaluation. | ||
Results: | ||
(bool) result of assertion | ||
""" | ||
_subs = [ (self._symbols[s[0]], s[1]) for s in subs] | ||
return self._expr.subs( _subs) | ||
|
||
|
||
def assert_series(self, subs:list[tuple], ret:str='bool'): | ||
"""Perform assertion on a (time) series. | ||
Args: | ||
subs (list): list of tuples of `(variable-symbol, list-of-values)`, | ||
where the independent variable (normally the time) shall be listed first. | ||
All required variables for the evaluation shall be listed | ||
The variable-name provided as string is translated to its symbol before evaluation. | ||
ret (str)='bool': Determines how to return the result of the assertion: | ||
`bool` : True if any element of the assertion of the series is evaluated to True | ||
`bool-list` : List of True/False for each data point in the series | ||
`interval` : tuple of interval of indices for which the assertion is True | ||
`count` : Count the number of points where the assertion is True | ||
Results: | ||
bool, list[bool], tuple[int] or int, depending on `ret` parameter. | ||
Default: True/False on whether at least one record is found where the assertion is True. | ||
""" | ||
_subs = [ (self._symbols[s[0]], s[1]) for s in subs] | ||
length = len( subs[0][1]) | ||
result = [False]* length | ||
|
||
for i in range( length): | ||
s = [] | ||
for k in range( len( _subs)): # number of variables in substitution | ||
s.append( (_subs[k][0], _subs[k][1][i])) | ||
res = self._expr.subs( s) | ||
if res: | ||
result[i] = True | ||
if ret == 'bool': | ||
return True in result | ||
elif ret == 'bool-list': | ||
return result | ||
elif ret == 'interval': | ||
if True in result: | ||
idx0 = result.index(True) | ||
if False in result[idx0:]: | ||
return (idx0, idx0+result[idx0:].index(False)) | ||
else: | ||
return (idx0, length) | ||
else: | ||
return None | ||
elif ret == 'count': | ||
return sum( x for x in result) | ||
else: | ||
raise ValueError(f"Unknown return type '{ret}'") from None | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
from case_study.assertion import Assertion | ||
from math import sin, cos | ||
import matplotlib.pyplot as plt | ||
import pytest | ||
from sympy import symbols, sympify | ||
|
||
_t = [0.1*float(x) for x in range(100)] | ||
_x = [0.3*sin(t) for t in _t] | ||
_y = [1.0*cos(t) for t in _t] | ||
|
||
def show_data(): | ||
fig, ax = plt.subplots() | ||
ax.plot(_x, _y) | ||
plt.title("Data (_x, _y)", loc="left") | ||
plt.show() | ||
|
||
def test_init(): | ||
Assertion.reset() | ||
t,x,y = symbols("t x y") | ||
ass = Assertion( "t>8") | ||
assert ass.symbols['t'] == t | ||
assert Assertion.ns == {'t':t} | ||
ass = Assertion("(t>8) & (x>0.1)") | ||
assert ass.symbols == {'t':t,'x':x} | ||
assert Assertion.ns == {'t':t, 'x':x} | ||
ass = Assertion( "(y<=4) & (y>=4)") | ||
assert ass.symbols == {'y':y} | ||
assert Assertion.ns == {'t':t, 'x':x, 'y':y} | ||
|
||
|
||
def test_assertion(): | ||
t,x,y = symbols("t x y") | ||
# show_data()print("Analyze", analyze( "t>8 & x>0.1")) | ||
Assertion.reset() | ||
ass = Assertion( "t>8") | ||
assert ass.assert_single( [('t', 9.0)]) | ||
assert not ass.assert_single( [('t', 7)] ) | ||
res = ass.assert_series( [('t',_t)], 'bool-list') | ||
assert True in res, "There is at least one point where the assertion is True" | ||
assert res.index(True) == 81, f"Element {res.index(True)} is True" | ||
assert all( res[i] for i in range(81,100)), "Assertion remains True" | ||
assert ass.assert_series( [('t',_t)], 'bool'), "There is at least one point where the assertion is True" | ||
assert ass.assert_series( [('t',_t)], 'interval') == (81, 100), "Index-interval where the assertion is True" | ||
ass = Assertion("(t>8) & (x>0.1)") | ||
res = ass.assert_series([('t',_t), ('x',_x)]) | ||
assert res, "True at some point" | ||
assert ass.assert_series([('t',_t), ('x',_x)], 'interval') == (81,91) | ||
assert ass.assert_series([('t',_t), ('x',_x)], 'count') == 10 | ||
with pytest.raises(ValueError, match="Unknown return type 'Hello'") as err: | ||
ass.assert_series([('t',_t), ('x',_x)], 'Hello') | ||
# Checking equivalence. '==' does not work | ||
ass = Assertion( "(y<=4) & (y>=4)") | ||
assert ass.symbols == {'y' : y} | ||
assert Assertion.ns == {'t': t, 'x': x, 'y': y} | ||
assert ass.assert_single( [('y', 4)]) | ||
assert not ass.assert_series( [('y', _y)], ret='bool') | ||
with pytest.raises(ValueError, match="'==' cannot be used to check equivalence. Use 'a-b' and check against 0") as err: | ||
ass = Assertion("y==4") | ||
ass = Assertion( "y-4") | ||
assert 0==ass.assert_single( [('y', 4)]) | ||
|
||
|
||
if __name__ == "__main__": | ||
# retcode = pytest.main(["-rA","-v", __file__]) | ||
# assert retcode == 0, f"Non-zero return code {retcode}" | ||
test_init() | ||
test_assertion() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,82 @@ | ||
from math import sqrt | ||
from pathlib import Path | ||
|
||
import pytest | ||
from component_model.example_models.bouncing_ball_3d import BouncingBall3D | ||
from component_model.model import Model | ||
from fmpy import plot_result, simulate_fmu | ||
from shutil import copy | ||
|
||
def nearly_equal(res: tuple, expected: tuple, eps=1e-7): | ||
assert len(res) == len( | ||
expected | ||
), f"Tuples of different lengths cannot be equal. Found {len(res)} != {len(expected)}" | ||
for i, (x, y) in enumerate(zip(res, expected, strict=False)): | ||
assert abs(x - y) < eps, f"Element {i} not nearly equal in {x}, {y}" | ||
|
||
def test_make_fmu():#chdir): | ||
fmu_path = Model.build( | ||
str(Path(__file__).parent / "data" / "BouncingBall3D" / "bouncing_ball_3d.py"), | ||
dest = Path( Path.cwd())) | ||
copy( fmu_path, Path(__file__).parent / "data" / "BouncingBall3D") | ||
|
||
|
||
def test_run_fmpy(show): | ||
"""Test and validate the basic BouncingBall using fmpy and not using OSP or case_study.""" | ||
path = Path("BouncingBall3D.fmu") | ||
assert path.exists(), f"File {path} does not exist" | ||
dt = 0.01 | ||
result = simulate_fmu( | ||
path, | ||
start_time=0.0, | ||
stop_time=3.0, | ||
step_size=dt, | ||
validate=True, | ||
solver="Euler", | ||
debug_logging=False, | ||
visible=True, | ||
logger=print, # fmi_call_logger=print, | ||
start_values={ | ||
"e": 0.71, | ||
"g": 9.81, | ||
}, | ||
) | ||
if show: | ||
plot_result(result) | ||
t_bounce = sqrt(2*10*0.0254 / 9.81) | ||
v_bounce = 9.81 * t_bounce # speed in z-direction | ||
x_bounce = t_bounce/1.0 # x-position where it bounces in m | ||
# Note: default values are reported at time 0! | ||
nearly_equal(result[0], (0, 0, 0, 10, 1, 0, 0, sqrt(2*10/9.81), 0, 0)) # time,pos-3, speed-3, p_bounce-3 | ||
print(result[1]) | ||
arrays_equal(result(bb), (0.01, | ||
0.01, 0, (10*0.0254-0.5*9.81*0.01**2)/0.0254, | ||
1, 0, -9.81*0.01, sqrt(2*10*0.0254/9.81), 0, 0)) | ||
t_before = int(sqrt(2 / 9.81) / dt) * dt # just before bounce | ||
print("BEFORE", t_before, result[int(t_before / dt)]) | ||
nearly_equal( | ||
result[int(t_before / dt)], | ||
(t_before, 1*t_before, 0, 1.0 - 0.5 * 9.81 * t_before * t_before, 1, 0, -9.81 * t_before, x_bounce, 0, 0), | ||
eps=0.003, | ||
) | ||
nearly_equal( | ||
result[int(t_before / dt) + 1], | ||
( | ||
t_before + dt, | ||
v_bounce * 0.71 * (t_before + dt - t_bounce) - 0.5 * 9.81 * (t_before + dt - t_bounce) ** 2, | ||
v_bounce * 0.71 - 9.81 * (t_before + dt - t_bounce), | ||
), | ||
eps=0.03, | ||
) | ||
nearly_equal(result[int(2.5 / dt)], (2.5, 0, 0), eps=0.4) | ||
nearly_equal(result[int(3 / dt)], (3, 0, 0)) | ||
print("RESULT", result[int(t_before / dt) + 1]) | ||
|
||
|
||
if __name__ == "__main__": | ||
# retcode = pytest.main(["-rA", "-v", __file__, "--show", "True"]) | ||
# assert retcode == 0, f"Non-zero return code {retcode}" | ||
import os | ||
os.chdir(Path(__file__).parent.absolute() / "test_working_directory") | ||
test_make_fmu() | ||
test_run_fmpy( show=True) |