Skip to content

Commit

Permalink
Fix/tf tfp compat (#115)
Browse files Browse the repository at this point in the history
* ENH: small enhancements in README & Doc

* ENH: higher TF & TFP version

* ADD: few SimpleMNL tests
  • Loading branch information
VincentAuriau authored Jul 1, 2024
1 parent 24102aa commit 610165a
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 8 deletions.
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,8 @@ classifiers = [
python = "^3.9.0"
numpy = "^1.24.3"
pandas = "^1.5.3"
tensorflow = "^2.11.0"
tensorflow-probability = "^0.20.1"
tensorflow = "^2.14.0"
tensorflow-probability = "^0.22.1"
tqdm = "^4.0.0"
ortools = { version = "^9.6", optional = true }
gurobipy = { version = "^11.0", optional = true }
Expand Down
4 changes: 2 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
-e .
numpy==1.24.3
pandas==1.5.3
tensorflow==2.13.0
tensorflow_probability==0.20.1
tensorflow==2.14.0
tensorflow_probability==0.22.1
tqdm==4.65.0
7 changes: 3 additions & 4 deletions tests/integration_tests/models/test_conditional_logit.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@ def test_mode_canada_gt():
gt_model.instantiate(canada_dataset)

gt_model.trainable_weights = gt_weights
assert (gt_model.evaluate(canada_dataset) * len(canada_dataset)) <= 1874.4, gt_model.evaluate(
canada_dataset
) * len(canada_dataset)
assert (gt_model.evaluate(canada_dataset) * len(canada_dataset)) >= 1874.1
total_nll = gt_model.evaluate(canada_dataset) * len(canada_dataset)
assert total_nll <= 1874.4, f"Got NLL: {total_nll}"
assert total_nll >= 1874.1, f"Got NLL: {total_nll}"
40 changes: 40 additions & 0 deletions tests/integration_tests/models/test_simple_mnl.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
"""Tests SimpleMNL."""

from choice_learn.datasets import load_swissmetro
from choice_learn.models import SimpleMNL

dataset = load_swissmetro()


def test_simple_mnl_lbfgs_fit_with_lbfgs():
"""Tests that SimpleMNL can fit with LBFGS."""
global dataset

model = SimpleMNL(epochs=20)
model.fit(dataset)
model.evaluate(dataset)
assert model.evaluate(dataset) < 1.0


def test_simple_mnl_lbfgs_fit_with_adam():
"""Tests that SimpleMNL can fit with Adam."""
global dataset

model = SimpleMNL(epochs=20, optimizer="adam", batch_size=256)
model.fit(dataset)
model.evaluate(dataset)
assert model.evaluate(dataset) < 1.0


def test_that_endpoints_run():
"""Dummy test to check that the endpoints run.
No verification of results.
"""
global dataset

model = SimpleMNL(epochs=20)
model.fit(dataset)
model.evaluate(dataset)
model.predict_probas(dataset)
assert True

0 comments on commit 610165a

Please sign in to comment.