diff --git a/pyproject.toml b/pyproject.toml index 24cb86b2..f45bb2e5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -79,8 +79,8 @@ classifiers = [ python = "^3.9.0" numpy = "^1.24.3" pandas = "^1.5.3" -tensorflow = "^2.11.0" -tensorflow-probability = "^0.20.1" +tensorflow = "^2.14.0" +tensorflow-probability = "^0.22.1" tqdm = "^4.0.0" ortools = { version = "^9.6", optional = true } gurobipy = { version = "^11.0", optional = true } diff --git a/requirements.txt b/requirements.txt index 07c7368e..c86e4c8f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,6 @@ -e . numpy==1.24.3 pandas==1.5.3 -tensorflow==2.13.0 -tensorflow_probability==0.20.1 +tensorflow==2.14.0 +tensorflow_probability==0.22.1 tqdm==4.65.0 diff --git a/tests/integration_tests/models/test_conditional_logit.py b/tests/integration_tests/models/test_conditional_logit.py index 82b3d9d1..3be2b1c3 100644 --- a/tests/integration_tests/models/test_conditional_logit.py +++ b/tests/integration_tests/models/test_conditional_logit.py @@ -33,7 +33,6 @@ def test_mode_canada_gt(): gt_model.instantiate(canada_dataset) gt_model.trainable_weights = gt_weights - assert (gt_model.evaluate(canada_dataset) * len(canada_dataset)) <= 1874.4, gt_model.evaluate( - canada_dataset - ) * len(canada_dataset) - assert (gt_model.evaluate(canada_dataset) * len(canada_dataset)) >= 1874.1 + total_nll = gt_model.evaluate(canada_dataset) * len(canada_dataset) + assert total_nll <= 1874.4, f"Got NLL: {total_nll}" + assert total_nll >= 1874.1, f"Got NLL: {total_nll}" diff --git a/tests/integration_tests/models/test_simple_mnl.py b/tests/integration_tests/models/test_simple_mnl.py new file mode 100644 index 00000000..c84d901d --- /dev/null +++ b/tests/integration_tests/models/test_simple_mnl.py @@ -0,0 +1,40 @@ +"""Tests SimpleMNL.""" + +from choice_learn.datasets import load_swissmetro +from choice_learn.models import SimpleMNL + +dataset = load_swissmetro() + + +def test_simple_mnl_lbfgs_fit_with_lbfgs(): + """Tests that SimpleMNL can fit with LBFGS.""" + global dataset + + model = SimpleMNL(epochs=20) + model.fit(dataset) + model.evaluate(dataset) + assert model.evaluate(dataset) < 1.0 + + +def test_simple_mnl_lbfgs_fit_with_adam(): + """Tests that SimpleMNL can fit with Adam.""" + global dataset + + model = SimpleMNL(epochs=20, optimizer="adam", batch_size=256) + model.fit(dataset) + model.evaluate(dataset) + assert model.evaluate(dataset) < 1.0 + + +def test_that_endpoints_run(): + """Dummy test to check that the endpoints run. + + No verification of results. + """ + global dataset + + model = SimpleMNL(epochs=20) + model.fit(dataset) + model.evaluate(dataset) + model.predict_probas(dataset) + assert True