Skip to content

Commit

Permalink
Test update
Browse files Browse the repository at this point in the history
-`test_fit_model_confidence_intervals` increased the number of bootstraps so it passes
-`test_fit_modelUncert` expanded to test the bootstrapped
  • Loading branch information
HKaras committed Nov 27, 2024
1 parent 76cfed1 commit 6e41ffe
Showing 1 changed file with 8 additions and 2 deletions.
10 changes: 8 additions & 2 deletions test/test_model_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,9 @@ def mock_x():

@pytest.fixture(scope='module')
def mock_data(mock_x):
return bigauss(mock_x,mean1=3,mean2=4,std1=0.5,std2=0.2,amp1=0.5,amp2=0.6)
data = bigauss(mock_x,mean1=3,mean2=4,std1=0.5,std2=0.2,amp1=0.5,amp2=0.6)
data += whitegaussnoise(mock_x,0.01,seed=1)
return data

model_types = ['parametric','semiparametric','semiparametric_vec',
'nonparametric','nonparametric_vec','nonparametric_vec_normalized']
Expand Down Expand Up @@ -511,7 +513,7 @@ def test_fit_model_confidence_intervals(mock_data,model_type,method):
model = _generate_model(model_type, fixed_axis=True)

if method=='bootstrap':
results = fit(model,mock_data + whitegaussnoise(x,0.01,seed=1), bootstrap=3)
results = fit(model,mock_data + whitegaussnoise(x,0.01,seed=1), bootstrap=100)
else:
results = fit(model,mock_data + whitegaussnoise(x,0.01,seed=1))

Expand Down Expand Up @@ -551,6 +553,7 @@ def test_fit_evaluate_model(mock_data,mock_x,model_type):
@pytest.mark.parametrize('method', ['bootstrap','moment'])
@pytest.mark.parametrize('model_type', model_types)
def test_fit_modelUncert(mock_data,mock_x,model_type,method):
"Check that the uncertainty of fit results can be calculated and is the uncertainty of the model is non zero for all but nonparametric models"
model = _generate_model(model_type, fixed_axis=False)

if method=='bootstrap':
Expand All @@ -562,6 +565,9 @@ def test_fit_modelUncert(mock_data,mock_x,model_type,method):
ci_lower = results.modelUncert.ci(95)[:,0]
ci_upper = results.modelUncert.ci(95)[:,1]
assert np.less_equal(ci_lower,ci_upper).all()
if model_type != 'nonparametric' and model_type != 'nonparametric_vec' and model_type != 'nonparametric_vec_normalized':
assert np.all(np.round(ci_lower) <= np.round(results.model)) and np.less(ci_lower.sum(),results.model.sum())
assert np.all(np.round(ci_upper,5) >= np.round(results.model,5)) and np.greater(ci_upper.sum(),results.model.sum())

# ================================================================

Expand Down

0 comments on commit 6e41ffe

Please sign in to comment.