-
Notifications
You must be signed in to change notification settings - Fork 7
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* ENH: small enhancements in README & Doc * ENH: higher TF & TFP version * ADD: few SimpleMNL tests
- Loading branch information
1 parent
24102aa
commit 610165a
Showing
4 changed files
with
47 additions
and
8 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,6 @@ | ||
-e . | ||
numpy==1.24.3 | ||
pandas==1.5.3 | ||
tensorflow==2.13.0 | ||
tensorflow_probability==0.20.1 | ||
tensorflow==2.14.0 | ||
tensorflow_probability==0.22.1 | ||
tqdm==4.65.0 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
"""Tests SimpleMNL.""" | ||
|
||
from choice_learn.datasets import load_swissmetro | ||
from choice_learn.models import SimpleMNL | ||
|
||
dataset = load_swissmetro() | ||
|
||
|
||
def test_simple_mnl_lbfgs_fit_with_lbfgs(): | ||
"""Tests that SimpleMNL can fit with LBFGS.""" | ||
global dataset | ||
|
||
model = SimpleMNL(epochs=20) | ||
model.fit(dataset) | ||
model.evaluate(dataset) | ||
assert model.evaluate(dataset) < 1.0 | ||
|
||
|
||
def test_simple_mnl_lbfgs_fit_with_adam(): | ||
"""Tests that SimpleMNL can fit with Adam.""" | ||
global dataset | ||
|
||
model = SimpleMNL(epochs=20, optimizer="adam", batch_size=256) | ||
model.fit(dataset) | ||
model.evaluate(dataset) | ||
assert model.evaluate(dataset) < 1.0 | ||
|
||
|
||
def test_that_endpoints_run(): | ||
"""Dummy test to check that the endpoints run. | ||
No verification of results. | ||
""" | ||
global dataset | ||
|
||
model = SimpleMNL(epochs=20) | ||
model.fit(dataset) | ||
model.evaluate(dataset) | ||
model.predict_probas(dataset) | ||
assert True |