Skip to content

Commit

Permalink
Enh/tqdm bar (#128)
Browse files Browse the repository at this point in the history
* ENH: description of tqdm bars with losses
  • Loading branch information
VincentAuriau authored Jul 17, 2024
1 parent ed59165 commit 068239f
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 18 deletions.
28 changes: 19 additions & 9 deletions choice_learn/models/base_model.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Base class for choice models."""

import json
import logging
import os
Expand Down Expand Up @@ -315,6 +316,11 @@ def fit(
# Optimization Steps
epoch_losses.append(neg_loglikelihood)

if verbose > 0:
inner_range.set_description(
f"Epoch Negative-LogLikeliHood: {np.sum(epoch_losses):.4f}"
)

# In this case we do not need to batch the sample_weights
else:
if verbose > 0:
Expand Down Expand Up @@ -346,6 +352,11 @@ def fit(
# Optimization Steps
epoch_losses.append(neg_loglikelihood)

if verbose > 0:
inner_range.set_description(
f"Epoch Negative-LogLikeliHood: {np.sum(epoch_losses):.4f}"
)

# Take into account last batch that may have a differnt length into account for
# the computation of the epoch loss.
if batch_size != -1:
Expand All @@ -358,13 +369,13 @@ def fit(
else:
epoch_loss = tf.reduce_mean(epoch_losses)
losses_history["train_loss"].append(epoch_loss)
desc = f"Epoch {epoch_nb} Train Loss {losses_history['train_loss'][-1].numpy()}"
print_loss = losses_history["train_loss"][-1].numpy()
desc = f"Epoch {epoch_nb} Train Loss {print_loss:.4f}"
if verbose > 1:
print(
f"Loop {epoch_nb} Time",
time.time() - t_start,
"Loss:",
tf.reduce_sum(epoch_losses).numpy(),
f"Loop {epoch_nb} Time:",
f"{time.time() - t_start:.4f}",
f"Loss: {print_loss:.4f}",
)

# Test on val_dataset if provided
Expand Down Expand Up @@ -393,7 +404,7 @@ def fit(
test_loss = tf.reduce_mean(test_losses)
if verbose > 1:
print("Test Negative-LogLikelihood:", test_loss.numpy())
desc += f", Test Loss {test_loss.numpy()}"
desc += f", Test Loss {np.round(test_loss.numpy(), 4)}"
losses_history["test_loss"] = losses_history.get("test_loss", []) + [
test_loss.numpy()
]
Expand All @@ -404,9 +415,8 @@ def fit(
if self.stop_training:
print("Early Stopping taking effect")
break
if verbose > 0:
t_range.set_description(desc)
t_range.refresh()
t_range.set_description(desc)
t_range.refresh()

temps_logs = {k: tf.reduce_mean(v) for k, v in train_logs.items()}
self.callbacks.on_train_end(logs=temps_logs)
Expand Down
37 changes: 31 additions & 6 deletions docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -123,14 +123,39 @@ If you consider this package and any of its feature useful for your research, pl

The use of this software is under the MIT license, with no limitation of usage, including for commercial applications.

### Contributors

### Special Thanks

### Affiliations

This package has been developped within the [Artefact Research Center](https://www.artefact.com/data-consulting-transformation/artefact-research-center/) in collaboration with CentraleSupélec, université Paris-Saclay.


[![](./illustrations/logos/logo_arc.png)](https://www.artefact.com/data-consulting-transformation/artefact-research-center/) | [![](./illustrations/logos/artefact_logo.png)](https://www.artefact.com/) | [![](./illustrations/logos/logo_CS.png)](https://mics.centralesupelec.fr/) | [![](./illustrations/logos/logo_paris_saclay.png)](https://www.universite-paris-saclay.fr/)
:-------------------------:|:-------------------------:|:-------------------------:|:-------------------------:
<p align="center">
<a href="https://www.artefact.com/data-consulting-transformation/artefact-research-center/">
<img src="./docs/illustrations/logos/logo_arc.png" height="60" />
</a>
&emsp;
&emsp;
<a href="https://www.artefact.com/">
<img src="docs/illustrations/logos/logo_atf.png" height="65" />
</a>
</p>

<p align="center">
<a href="https://www.universite-paris-saclay.fr/">
<img src="./docs/illustrations/logos/logo_paris_saclay.png" height="60" />
</a>
&emsp;
&emsp;
<a href="https://mics.centralesupelec.fr/">
<img src="docs/illustrations/logos/logo_CS.png" height="60" />
</a>
&emsp;
&emsp;
<a href="https://www.london.edu/">
<img src="docs/illustrations/logos/logo_lbs.jpeg" height="60" />
</a>
&emsp;
&emsp;
<a href="https://www.insead.edu/">
<img src="docs/illustrations/logos/logo_insead.png" height="60" />
</a>
</p>
6 changes: 3 additions & 3 deletions notebooks/introduction/3_model_clogit.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -1211,9 +1211,9 @@
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 2000/2000 [00:09<00:00, 210.33it/s]\n",
"100%|██████████| 4000/4000 [00:16<00:00, 237.61it/s]\n",
"100%|██████████| 20000/20000 [01:24<00:00, 236.20it/s]\n"
"Epoch 1999 Train Loss 0.6801: 100%|██████████| 2000/2000 [00:13<00:00, 148.87it/s]\n",
"Epoch 3999 Train Loss 0.6776: 100%|██████████| 4000/4000 [00:24<00:00, 160.24it/s]\n",
"Epoch 19999 Train Loss 0.6767: 100%|██████████| 20000/20000 [01:59<00:00, 166.89it/s]\n"
]
}
],
Expand Down

0 comments on commit 068239f

Please sign in to comment.