diff --git a/choice_learn/models/latent_class_base_model.py b/choice_learn/models/latent_class_base_model.py index 1b70f949..adfd7550 100644 --- a/choice_learn/models/latent_class_base_model.py +++ b/choice_learn/models/latent_class_base_model.py @@ -3,6 +3,7 @@ import time import numpy as np +import pandas as pd import tensorflow as tf import tqdm @@ -930,3 +931,37 @@ def get_latent_classes_weights(self): Latent classes weights/probabilities """ return tf.nn.softmax(tf.concat([[tf.constant(0.0)], self.latent_logits], axis=0)) + + def compute_report(self, choice_dataset): + """Compute a report of the estimated weights. + + Parameters + ---------- + choice_dataset : ChoiceDataset + ChoiceDataset used for the estimation of the weights that will be + used to compute the Std Err of this estimation. + + Returns + ------- + pandas.DataFrame + A DF with estimation, Std Err, z_value and p_value for each coefficient. + """ + reports = [] + for i, model in enumerate(self.models): + compute = getattr(model, "compute_report", None) + if callable(compute): + report = model.compute_report(choice_dataset) + report["Latent Class"] = i + reports.append(report) + else: + raise ValueError(f"{i}-th model {model} does not have a compute_report method.") + return pd.concat(reports, axis=0, ignore_index=True)[ + [ + "Latent Class", + "Coefficient Name", + "Coefficient Estimation", + "Std. Err", + "z_value", + "P(.>z)", + ] + ] diff --git a/notebooks/models/latent_class_model.ipynb b/notebooks/models/latent_class_model.ipynb index 62ad7d60..82b4ed8f 100644 --- a/notebooks/models/latent_class_model.ipynb +++ b/notebooks/models/latent_class_model.ipynb @@ -39,6 +39,7 @@ "\n", "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"\"\n", "\n", + "import matplotlib as mpl\n", "import numpy as np\n", "import pandas as pd\n", "\n", @@ -106,9 +107,254 @@ "print(f\"Negative Log-Likelihood: {nll}\")" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "keep_output": true + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Using L-BFGS optimizer, setting up .fit() function\n", + "Using L-BFGS optimizer, setting up .fit() function\n", + "Using L-BFGS optimizer, setting up .fit() function\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/var/folders/zz/r1py7zhj35q75v09h8_42nzh0000gp/T/ipykernel_67121/1263996749.py:4: MatplotlibDeprecationWarning: The get_cmap function was deprecated in Matplotlib 3.7 and will be removed two minor releases later. Use ``matplotlib.colormaps[name]`` or ``matplotlib.colormaps.get_cmap(obj)`` instead.\n", + " cmap = mpl.cm.get_cmap(\"Set1\")\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
 Latent ClassCoefficient NameCoefficient EstimationStd. Errz_valueP(.>z)
00Weights_items_features_0-0.6756450.023987-28.1671090.000000
10Weights_items_features_1-0.0606040.008162-7.4248490.000000
20Weights_items_features_21.8519510.05491433.7245790.000000
30Weights_items_features_31.3225490.04815927.4624200.000000
40Weights_items_features_4-5.8570890.191162-30.6394600.000000
50Weights_items_features_5-6.5132060.195680-33.2850460.000000
61Weights_items_features_0-1.8175660.077771-23.3707960.000000
71Weights_items_features_1-1.7263650.058838-29.3409860.000000
81Weights_items_features_23.6965670.16025823.0664040.000000
91Weights_items_features_34.1118400.15717926.1602250.000000
101Weights_items_features_4-26.6935163.274723-8.1513810.000000
111Weights_items_features_5-14.9258400.634699-23.5164030.000000
122Weights_items_features_0-2.1047910.104296-20.1810090.000000
132Weights_items_features_1-1.6526220.073820-22.3871880.000000
142Weights_items_features_2-5.5542870.245318-22.6411510.000000
152Weights_items_features_3-13.5655550.544168-24.9289650.000000
162Weights_items_features_4-9.7949300.631004-15.5227810.000000
172Weights_items_features_5-12.1266730.681118-17.8040600.000000
\n" + ], + "text/plain": [ + "" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "report = lc_model.compute_report(elec_dataset)\n", + "\n", + "def format_color_groups(df):\n", + " cmap = mpl.cm.get_cmap(\"Set1\")\n", + " colors = [mpl.colors.rgb2hex(cmap(i)) for i in range(cmap.N)]\n", + " x = df.copy()\n", + " factors = list(x['Latent Class'].unique())\n", + " i = 0\n", + " for factor in factors:\n", + " style = f'background-color: {colors[i]}'\n", + " x.loc[x['Latent Class'] == factor, :] = style\n", + " i += 1\n", + " return x\n", + "\n", + "report.style.apply(format_color_groups, axis=None)" + ] + }, { "cell_type": "markdown", - "metadata": {}, + "metadata": { + "keep_output": true + }, "source": [ "## Latent Conditional Logit\n", "We used a very simple MNL. Here we simulate the same MNL, by using the Conditional-Logit formulation.\\\n", @@ -281,7 +527,7 @@ ], "metadata": { "kernelspec": { - "display_name": "tf_env", + "display_name": "choice_learn", "language": "python", "name": "python3" }, @@ -295,7 +541,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.4" + "version": "3.8.18" } }, "nbformat": 4, diff --git a/tests/integration_tests/models/test_latent_class.py b/tests/integration_tests/models/test_latent_class.py index 8747ad76..1c49cf63 100644 --- a/tests/integration_tests/models/test_latent_class.py +++ b/tests/integration_tests/models/test_latent_class.py @@ -2,28 +2,32 @@ import tensorflow as tf -from choice_learn.datasets import load_electricity -from choice_learn.models.latent_class_base_model import BaseLatentClassModel -from choice_learn.models.latent_class_mnl import LatentClassConditionalLogit, LatentClassSimpleMNL -from choice_learn.models.simple_mnl import SimpleMNL +tf.config.run_functions_eagerly(True) + +from choice_learn.datasets import load_electricity # noqa: E402 +from choice_learn.models.latent_class_base_model import BaseLatentClassModel # noqa: E402 +from choice_learn.models.latent_class_mnl import ( # noqa: E402 + LatentClassConditionalLogit, + LatentClassSimpleMNL, +) +from choice_learn.models.simple_mnl import SimpleMNL # noqa: E402 elec_dataset = load_electricity(as_frame=False) def test_latent_simple_mnl(): """Test the simple latent class model fit() method.""" - tf.config.run_functions_eagerly(True) lc_model = LatentClassSimpleMNL( - n_latent_classes=3, fit_method="mle", optimizer="lbfgs", epochs=1000, lbfgs_tolerance=1e-8 + n_latent_classes=2, fit_method="mle", optimizer="lbfgs", epochs=1000, lbfgs_tolerance=1e-12 ) _, _ = lc_model.fit(elec_dataset) + lc_model.compute_report(elec_dataset) assert lc_model.evaluate(elec_dataset).numpy() < 1.15 def test_latent_clogit(): """Test the conditional logit latent class model fit() method.""" - tf.config.run_functions_eagerly(True) lc_model = LatentClassConditionalLogit( n_latent_classes=3, fit_method="mle", optimizer="lbfgs", epochs=40, lbfgs_tolerance=1e-8 ) @@ -52,7 +56,6 @@ def test_latent_clogit(): def test_manual_lc(): """Test manual specification of Latent Class Simple MNL model.""" - tf.config.run_functions_eagerly(True) manual_lc = BaseLatentClassModel( model_class=SimpleMNL, model_parameters={"add_exit_choice": False}, @@ -70,7 +73,6 @@ def test_manual_lc(): def test_manual_lc_gd(): """Test manual specification of Latent Class Simple MNL model with gradient descent.""" - tf.config.run_functions_eagerly(True) manual_lc = BaseLatentClassModel( model_class=SimpleMNL, model_parameters={"add_exit_choice": False},