Skip to content

Commit

Permalink
Add tests for CVODES and IDAS using existing tests
Browse files Browse the repository at this point in the history
  • Loading branch information
aragilar committed Dec 8, 2024
1 parent 7ac3140 commit 3d336f7
Show file tree
Hide file tree
Showing 5 changed files with 57 additions and 44 deletions.
21 changes: 17 additions & 4 deletions packages/scikits-odes/src/scikits/odes/tests/test_get_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,11 @@ def rhs(x, y, ydot):
#ydot[:] = (np.cos(x) * (x + 0.1) - np.sin(x)) / np.pow((x + 0.1), 2)


class GetInfoTest(unittest.TestCase):
class GetInfoTestCVODE(unittest.TestCase):
solvername = 'cvode'

def setUp(self):
self.ode = ode('cvode', rhs, old_api=False)
self.ode = ode(self.solvername, rhs, old_api=False)
self.solution = self.ode.solve(xs, np.array([1]))

def test_we_integrated_correctly(self):
Expand All @@ -45,13 +47,24 @@ def test_ode_exposes_num_rhs_evals(self):
assert 'NumRhsEvals' in info
assert info['NumRhsEvals'] > 0

class GetInfoTestSpils(unittest.TestCase):

class GetInfoTestCVODES(GetInfoTestCVODE):
solvername = 'cvodes'


class GetInfoTestSpilsCVODE(unittest.TestCase):
solvername = 'cvode'

def setUp(self):
self.ode = ode('cvode', rhs, linsolver="spgmr", old_api=False)
self.ode = ode(self.solvername, rhs, linsolver="spgmr", old_api=False)
self.solution = self.ode.solve(xs, np.array([1]))

def test_ode_exposes_num_njtimes_evals(self):
info = self.ode.get_info()
print("ode.get_info() =\n", info)
assert 'NumJtimesEvals' in info
assert info['NumJtimesEvals'] > 0


class GetInfoTestSpilsCVODES(GetInfoTestSpilsCVODE):
solvername = 'cvodes'
Original file line number Diff line number Diff line change
Expand Up @@ -125,15 +125,16 @@ def ontstop_vc(t, y, solver):

return 0

class TestOn(TestCase):
class TestOnCVODE(TestCase):
"""
Check integrate.dae
"""
solvername = 'cvode'

def test_cvode_rootfn_noroot(self):
#test calling sequence. End is reached before root is found
tspan = np.arange(0, t_end1 + 1, 1.0, DTYPE)
solver = ode('cvode', rhs_fn, nr_rootfns=1, rootfn=root_fn,
solver = ode(self.solvername, rhs_fn, nr_rootfns=1, rootfn=root_fn,
old_api=False)
soln = solver.solve(tspan, y0)
assert soln.flag==StatusEnum.SUCCESS, "ERROR: Error occurred"
Expand All @@ -144,7 +145,7 @@ def test_cvode_rootfn_noroot(self):
def test_cvode_rootfn(self):
#test root finding and stopping: End is reached at a root
tspan = np.arange(0, t_end2 + 1, 1.0, DTYPE)
solver = ode('cvode', rhs_fn, nr_rootfns=1, rootfn=root_fn,
solver = ode(self.solvername, rhs_fn, nr_rootfns=1, rootfn=root_fn,
old_api=False)
soln = solver.solve(tspan, y0)
assert soln.flag==StatusEnum.ROOT_RETURN, "ERROR: Root not found!"
Expand All @@ -155,7 +156,7 @@ def test_cvode_rootfn(self):
def test_cvode_rootfnacc(self):
#test root finding and accumilating: End is reached normally, roots stored
tspan = np.arange(0, t_end2 + 1, 1.0, DTYPE)
solver = ode('cvode', rhs_fn, nr_rootfns=1, rootfn=root_fn,
solver = ode(self.solvername, rhs_fn, nr_rootfns=1, rootfn=root_fn,
onroot=onroot_va,
old_api=False)
soln = solver.solve(tspan, y0)
Expand All @@ -171,7 +172,7 @@ def test_cvode_rootfnacc(self):
def test_cvode_rootfn_stop(self):
#test root finding and stopping: End is reached at a root with a function
tspan = np.arange(0, t_end2 + 1, 1.0, DTYPE)
solver = ode('cvode', rhs_fn, nr_rootfns=1, rootfn=root_fn,
solver = ode(self.solvername, rhs_fn, nr_rootfns=1, rootfn=root_fn,
onroot=onroot_vb,
old_api=False)
soln = solver.solve(tspan, y0)
Expand All @@ -183,7 +184,7 @@ def test_cvode_rootfn_stop(self):
def test_cvode_rootfn_test(self):
#test root finding and accumilating: End is reached after a number of root
tspan = np.arange(0, t_end2 + 1, 1.0, DTYPE)
solver = ode('cvode', rhs_fn, nr_rootfns=1, rootfn=root_fn,
solver = ode(self.solvername, rhs_fn, nr_rootfns=1, rootfn=root_fn,
onroot=onroot_vc,
old_api=False)
soln = solver.solve(tspan, y0)
Expand All @@ -199,7 +200,7 @@ def test_cvode_rootfn_test(self):
def test_cvode_rootfn_two(self):
#test two root finding
tspan = np.arange(0, t_end2 + 1, 1.0, DTYPE)
solver = ode('cvode', rhs_fn, nr_rootfns=2, rootfn=root_fn2,
solver = ode(self.solvername, rhs_fn, nr_rootfns=2, rootfn=root_fn2,
onroot=onroot_vc,
old_api=False)
soln = solver.solve(tspan, y0)
Expand All @@ -215,7 +216,7 @@ def test_cvode_rootfn_two(self):
def test_cvode_rootfn_end(self):
#test root finding with root at endtime
tspan = np.arange(0, 30 + 1, 1.0, DTYPE)
solver = ode('cvode', rhs_fn, nr_rootfns=1, rootfn=root_fn3,
solver = ode(self.solvername, rhs_fn, nr_rootfns=1, rootfn=root_fn3,
onroot=onroot_vc,
old_api=False)
soln = solver.solve(tspan, y0)
Expand All @@ -233,7 +234,7 @@ def test_cvode_tstopfn_notstop(self):
global n
n = 0
tspan = np.arange(0, t_end1 + 1, 1.0, DTYPE)
solver = ode('cvode', rhs_fn, tstop=T1+1, ontstop=ontstop_va,
solver = ode(self.solvername, rhs_fn, tstop=T1+1, ontstop=ontstop_va,
old_api=False)

soln = solver.solve(tspan, y0)
Expand All @@ -247,7 +248,7 @@ def test_cvode_tstopfn(self):
global n
n = 0
tspan = np.arange(0, t_end2 + 1, 1.0, DTYPE)
solver = ode('cvode', rhs_fn, tstop=T1,
solver = ode(self.solvername, rhs_fn, tstop=T1,
old_api=False)
soln = solver.solve(tspan, y0)
assert soln.flag==StatusEnum.TSTOP_RETURN, "ERROR: Tstop not found!"
Expand All @@ -264,7 +265,7 @@ def test_cvode_tstopfnacc(self):
global n
n = 0
tspan = np.arange(0, t_end2 + 1, 1.0, DTYPE)
solver = ode('cvode', rhs_fn, tstop=T1, ontstop=ontstop_va,
solver = ode(self.solvername, rhs_fn, tstop=T1, ontstop=ontstop_va,
old_api=False)
soln = solver.solve(tspan, y0)
assert len(soln.tstop.t) == 9, "ERROR: Did not find all tstop"
Expand All @@ -282,7 +283,7 @@ def test_cvode_tstopfn_stop(self):
global n
n = 0
tspan = np.arange(0, t_end2 + 1, 1.0, DTYPE)
solver = ode('cvode', rhs_fn, tstop=T1, ontstop=ontstop_vb,
solver = ode(self.solvername, rhs_fn, tstop=T1, ontstop=ontstop_vb,
old_api=False)

soln = solver.solve(tspan, y0)
Expand All @@ -302,7 +303,7 @@ def test_cvode_tstopfn_test(self):
global n
n = 0
tspan = np.arange(0, t_end2 + 1, 1.0, DTYPE)
solver = ode('cvode', rhs_fn, tstop=T1, ontstop=ontstop_vc,
solver = ode(self.solvername, rhs_fn, tstop=T1, ontstop=ontstop_vc,
old_api=False)

soln = solver.solve(tspan, y0)
Expand All @@ -315,3 +316,7 @@ def test_cvode_tstopfn_test(self):
assert allclose([soln.tstop.t[-1], soln.tstop.y[-1,0], soln.tstop.y[-1,1]],
[30.0, -1452.5024, -294.30],
atol=atol, rtol=rtol)


class TestOnCVODES(TestOnCVODE):
solvername = 'cvodes'
31 changes: 18 additions & 13 deletions packages/scikits-odes/src/scikits/odes/tests/test_on_funcs_ida.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,15 +127,16 @@ def ontstop_vc(t, y, ydot, solver):

return 0

class TestOn(TestCase):
class TestOnIDA(TestCase):
"""
Check integrate.dae
"""
solvername = 'ida'

def test_ida_rootfn_noroot(self):
#test calling sequence. End is reached before root is found
tspan = np.arange(0, t_end1 + 1, 1.0, DTYPE)
solver = dae('ida', rhs_fn, nr_rootfns=1, rootfn=root_fn,
solver = dae(self.solvername, rhs_fn, nr_rootfns=1, rootfn=root_fn,
old_api=False)
soln = solver.solve(tspan, y0, yp0)
assert soln.flag==StatusEnumIDA.SUCCESS, "ERROR: Error occurred"
Expand All @@ -146,7 +147,7 @@ def test_ida_rootfn_noroot(self):
def test_ida_rootfn(self):
#test root finding and stopping: End is reached at a root
tspan = np.arange(0, t_end2 + 1, 1.0, DTYPE)
solver = dae('ida', rhs_fn, nr_rootfns=1, rootfn=root_fn,
solver = dae(self.solvername, rhs_fn, nr_rootfns=1, rootfn=root_fn,
old_api=False)
soln = solver.solve(tspan, y0, yp0)
assert soln.flag==StatusEnumIDA.ROOT_RETURN, "ERROR: Root not found!"
Expand All @@ -157,7 +158,7 @@ def test_ida_rootfn(self):
def test_ida_rootfnacc(self):
#test root finding and accumilating: End is reached normally, roots stored
tspan = np.arange(0, t_end2 + 1, 1.0, DTYPE)
solver = dae('ida', rhs_fn, nr_rootfns=1, rootfn=root_fn,
solver = dae(self.solvername, rhs_fn, nr_rootfns=1, rootfn=root_fn,
onroot=onroot_va,
old_api=False)
soln = solver.solve(tspan, y0, yp0)
Expand All @@ -173,7 +174,7 @@ def test_ida_rootfnacc(self):
def test_ida_rootfn_stop(self):
#test root finding and stopping: End is reached at a root with a function
tspan = np.arange(0, t_end2 + 1, 1.0, DTYPE)
solver = dae('ida', rhs_fn, nr_rootfns=1, rootfn=root_fn,
solver = dae(self.solvername, rhs_fn, nr_rootfns=1, rootfn=root_fn,
onroot=onroot_vb,
old_api=False)
soln = solver.solve(tspan, y0, yp0)
Expand All @@ -185,7 +186,7 @@ def test_ida_rootfn_stop(self):
def test_ida_rootfn_test(self):
#test root finding and accumilating: End is reached after a number of root
tspan = np.arange(0, t_end2 + 1, 1.0, DTYPE)
solver = dae('ida', rhs_fn, nr_rootfns=1, rootfn=root_fn,
solver = dae(self.solvername, rhs_fn, nr_rootfns=1, rootfn=root_fn,
onroot=onroot_vc,
old_api=False)
soln = solver.solve(tspan, y0, yp0)
Expand All @@ -201,7 +202,7 @@ def test_ida_rootfn_test(self):
def test_ida_rootfn_two(self):
#test two root finding
tspan = np.arange(0, t_end2 + 1, 1.0, DTYPE)
solver = dae('ida', rhs_fn, nr_rootfns=2, rootfn=root_fn2,
solver = dae(self.solvername, rhs_fn, nr_rootfns=2, rootfn=root_fn2,
onroot=onroot_vc,
old_api=False)
soln = solver.solve(tspan, y0, yp0)
Expand All @@ -217,7 +218,7 @@ def test_ida_rootfn_two(self):
def test_ida_rootfn_end(self):
#test root finding with root at endtime
tspan = np.arange(0, 30 + 1, 1.0, DTYPE)
solver = dae('ida', rhs_fn, nr_rootfns=1, rootfn=root_fn3,
solver = dae(self.solvername, rhs_fn, nr_rootfns=1, rootfn=root_fn3,
onroot=onroot_vc,
old_api=False)
soln = solver.solve(tspan, y0, yp0)
Expand All @@ -235,7 +236,7 @@ def test_ida_tstopfn_notstop(self):
global n
n = 0
tspan = np.arange(0, t_end1 + 1, 1.0, DTYPE)
solver = dae('ida', rhs_fn, tstop=T1+1, ontstop=ontstop_va,
solver = dae(self.solvername, rhs_fn, tstop=T1+1, ontstop=ontstop_va,
old_api=False)
soln = solver.solve(tspan, y0, yp0)
assert soln.flag==StatusEnumIDA.SUCCESS, "ERROR: Error occurred"
Expand All @@ -248,7 +249,7 @@ def test_ida_tstopfn(self):
global n
n = 0
tspan = np.arange(0, t_end2 + 1, 1.0, DTYPE)
solver = dae('ida', rhs_fn, tstop=T1,
solver = dae(self.solvername, rhs_fn, tstop=T1,
old_api=False)
soln = solver.solve(tspan, y0, yp0)
assert soln.flag==StatusEnumIDA.TSTOP_RETURN, "ERROR: Tstop not found!"
Expand All @@ -265,7 +266,7 @@ def test_ida_tstopfnacc(self):
global n
n = 0
tspan = np.arange(0, t_end2 + 1, 1.0, DTYPE)
solver = dae('ida', rhs_fn, tstop=T1, ontstop=ontstop_va,
solver = dae(self.solvername, rhs_fn, tstop=T1, ontstop=ontstop_va,
old_api=False)
soln = solver.solve(tspan, y0, yp0)
assert len(soln.tstop.t) == 9, "ERROR: Did not find all tstop"
Expand All @@ -283,7 +284,7 @@ def test_ida_tstopfn_stop(self):
global n
n = 0
tspan = np.arange(0, t_end2 + 1, 1.0, DTYPE)
solver = dae('ida', rhs_fn, tstop=T1, ontstop=ontstop_vb,
solver = dae(self.solvername, rhs_fn, tstop=T1, ontstop=ontstop_vb,
old_api=False)

soln = solver.solve(tspan, y0, yp0)
Expand All @@ -303,7 +304,7 @@ def test_ida_tstopfn_test(self):
global n
n = 0
tspan = np.arange(0, t_end2 + 1, 1.0, DTYPE)
solver = dae('ida', rhs_fn, tstop=T1, ontstop=ontstop_vc,
solver = dae(self.solvername, rhs_fn, tstop=T1, ontstop=ontstop_vc,
old_api=False)

soln = solver.solve(tspan, y0, yp0)
Expand All @@ -316,3 +317,7 @@ def test_ida_tstopfn_test(self):
assert allclose([soln.tstop.t[-1], soln.tstop.y[-1,0], soln.tstop.y[-1,1]],
[30.0, -1452.5024, -294.30],
atol=atol, rtol=rtol)


class TestOnIDAS(TestOnIDA):
solvername = 'idas'
Original file line number Diff line number Diff line change
Expand Up @@ -101,10 +101,7 @@ def jac_vec_error_immediate(v, Jv, t, y):
return -1

class TestCVodeReturn(TestCase):

def __init__(self, *args, **kwargs):
super(TestCVodeReturn, self).__init__(*args, **kwargs)
self.solvername = "cvode"
solvername = "cvode"

def test_normal_rhs(self):
solver = ode(self.solvername, normal_rhs, old_api=False)
Expand Down Expand Up @@ -312,6 +309,4 @@ def test_jac_vec_error_immediate(self):
)

class TestCVodesReturn(TestCVodeReturn):
def __init__(self, *args, **kwargs):
super(TestCVodesReturn, self).__init__(*args, **kwargs)
self.solvername = "cvodes"
solvername = "cvodes"
Original file line number Diff line number Diff line change
Expand Up @@ -77,10 +77,7 @@ def jac_error_immediate(t, y, ydot, residual, cj, J):
return -1

class TestIdaReturn(TestCase):

def __init__(self, *args, **kwargs):
super(TestIdaReturn, self).__init__(*args, **kwargs)
self.solvername = "ida"
solvername = "ida"

def test_normal_rhs(self):
solver = dae(self.solvername, normal_rhs, old_api=False)
Expand Down Expand Up @@ -235,6 +232,4 @@ def test_jac_error_immediate(self):


class TestIdasReturn(TestIdaReturn):
def __init__(self, *args, **kwargs):
super(TestIdasReturn, self).__init__(*args, **kwargs)
self.solvername = "idas"
solvername = "idas"

0 comments on commit 3d336f7

Please sign in to comment.