From 068239f72c3ff8e302b3e273cbe09ff7ac6d65f2 Mon Sep 17 00:00:00 2001 From: Vincent Auriau Date: Wed, 17 Jul 2024 09:50:03 +0200 Subject: [PATCH] Enh/tqdm bar (#128) * ENH: description of tqdm bars with losses --- choice_learn/models/base_model.py | 28 +++++++++++----- docs/index.md | 37 +++++++++++++++++---- notebooks/introduction/3_model_clogit.ipynb | 6 ++-- 3 files changed, 53 insertions(+), 18 deletions(-) diff --git a/choice_learn/models/base_model.py b/choice_learn/models/base_model.py index 11efaa4a..24d51843 100644 --- a/choice_learn/models/base_model.py +++ b/choice_learn/models/base_model.py @@ -1,4 +1,5 @@ """Base class for choice models.""" + import json import logging import os @@ -315,6 +316,11 @@ def fit( # Optimization Steps epoch_losses.append(neg_loglikelihood) + if verbose > 0: + inner_range.set_description( + f"Epoch Negative-LogLikeliHood: {np.sum(epoch_losses):.4f}" + ) + # In this case we do not need to batch the sample_weights else: if verbose > 0: @@ -346,6 +352,11 @@ def fit( # Optimization Steps epoch_losses.append(neg_loglikelihood) + if verbose > 0: + inner_range.set_description( + f"Epoch Negative-LogLikeliHood: {np.sum(epoch_losses):.4f}" + ) + # Take into account last batch that may have a different length into account for # the computation of the epoch loss. 
if batch_size != -1: @@ -358,13 +369,13 @@ def fit( else: epoch_loss = tf.reduce_mean(epoch_losses) losses_history["train_loss"].append(epoch_loss) - desc = f"Epoch {epoch_nb} Train Loss {losses_history['train_loss'][-1].numpy()}" + print_loss = losses_history["train_loss"][-1].numpy() + desc = f"Epoch {epoch_nb} Train Loss {print_loss:.4f}" if verbose > 1: print( - f"Loop {epoch_nb} Time", - time.time() - t_start, - "Loss:", - tf.reduce_sum(epoch_losses).numpy(), + f"Loop {epoch_nb} Time:", + f"{time.time() - t_start:.4f}", + f"Loss: {print_loss:.4f}", ) # Test on val_dataset if provided @@ -393,7 +404,7 @@ def fit( test_loss = tf.reduce_mean(test_losses) if verbose > 1: print("Test Negative-LogLikelihood:", test_loss.numpy()) - desc += f", Test Loss {test_loss.numpy()}" + desc += f", Test Loss {np.round(test_loss.numpy(), 4)}" losses_history["test_loss"] = losses_history.get("test_loss", []) + [ test_loss.numpy() ] @@ -404,9 +415,8 @@ def fit( if self.stop_training: print("Early Stopping taking effect") break - if verbose > 0: - t_range.set_description(desc) - t_range.refresh() + t_range.set_description(desc) + t_range.refresh() temps_logs = {k: tf.reduce_mean(v) for k, v in train_logs.items()} self.callbacks.on_train_end(logs=temps_logs) diff --git a/docs/index.md b/docs/index.md index ea12cd28..edd5ecab 100644 --- a/docs/index.md +++ b/docs/index.md @@ -123,14 +123,39 @@ If you consider this package and any of its feature useful for your research, pl The use of this software is under the MIT license, with no limitation of usage, including for commercial applications. -### Contributors - -### Special Thanks - ### Affiliations This package has been developed within the [Artefact Research Center](https://www.artefact.com/data-consulting-transformation/artefact-research-center/) in collaboration with CentraleSupélec, université Paris-Saclay. 
-[![](./illustrations/logos/logo_arc.png)](https://www.artefact.com/data-consulting-transformation/artefact-research-center/) | [![](./illustrations/logos/artefact_logo.png)](https://www.artefact.com/) | [![](./illustrations/logos/logo_CS.png)](https://mics.centralesupelec.fr/) | [![](./illustrations/logos/logo_paris_saclay.png)](https://www.universite-paris-saclay.fr/) -:-------------------------:|:-------------------------:|:-------------------------:|:-------------------------: +

+ + + +   +   + + + +

+ +

+ + + +   +   + + + +   +   + + + +   +   + + + +

diff --git a/notebooks/introduction/3_model_clogit.ipynb b/notebooks/introduction/3_model_clogit.ipynb index 65604c71..4e69e85c 100644 --- a/notebooks/introduction/3_model_clogit.ipynb +++ b/notebooks/introduction/3_model_clogit.ipynb @@ -1211,9 +1211,9 @@ "name": "stderr", "output_type": "stream", "text": [ - "100%|██████████| 2000/2000 [00:09<00:00, 210.33it/s]\n", - "100%|██████████| 4000/4000 [00:16<00:00, 237.61it/s]\n", - "100%|██████████| 20000/20000 [01:24<00:00, 236.20it/s]\n" + "Epoch 1999 Train Loss 0.6801: 100%|██████████| 2000/2000 [00:13<00:00, 148.87it/s]\n", + "Epoch 3999 Train Loss 0.6776: 100%|██████████| 4000/4000 [00:24<00:00, 160.24it/s]\n", + "Epoch 19999 Train Loss 0.6767: 100%|██████████| 20000/20000 [01:59<00:00, 166.89it/s]\n" ] } ],