diff --git a/choice_learn/models/latent_class_base_model.py b/choice_learn/models/latent_class_base_model.py index 1b70f949..adfd7550 100644 --- a/choice_learn/models/latent_class_base_model.py +++ b/choice_learn/models/latent_class_base_model.py @@ -3,6 +3,7 @@ import time import numpy as np +import pandas as pd import tensorflow as tf import tqdm @@ -930,3 +931,37 @@ def get_latent_classes_weights(self): Latent classes weights/probabilities """ return tf.nn.softmax(tf.concat([[tf.constant(0.0)], self.latent_logits], axis=0)) + + def compute_report(self, choice_dataset): + """Compute a report of the estimated weights. + + Parameters + ---------- + choice_dataset : ChoiceDataset + ChoiceDataset used for the estimation of the weights that will be + used to compute the Std Err of this estimation. + + Returns + ------- + pandas.DataFrame + A DF with estimation, Std Err, z_value and p_value for each coefficient. + """ + reports = [] + for i, model in enumerate(self.models): + compute = getattr(model, "compute_report", None) + if callable(compute): + report = model.compute_report(choice_dataset) + report["Latent Class"] = i + reports.append(report) + else: + raise ValueError(f"{i}-th model {model} does not have a compute_report method.") + return pd.concat(reports, axis=0, ignore_index=True)[ + [ + "Latent Class", + "Coefficient Name", + "Coefficient Estimation", + "Std. Err", + "z_value", + "P(.>z)", + ] + ] diff --git a/notebooks/models/latent_class_model.ipynb b/notebooks/models/latent_class_model.ipynb index 62ad7d60..82b4ed8f 100644 --- a/notebooks/models/latent_class_model.ipynb +++ b/notebooks/models/latent_class_model.ipynb @@ -39,6 +39,7 @@ "\n", "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"\"\n", "\n", + "import matplotlib as mpl\n", "import numpy as np\n", "import pandas as pd\n", "\n", @@ -106,9 +107,254 @@ "print(f\"Negative Log-Likelihood: {nll}\")" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "keep_output": true + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Using L-BFGS optimizer, setting up .fit() function\n", + "Using L-BFGS optimizer, setting up .fit() function\n", + "Using L-BFGS optimizer, setting up .fit() function\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/var/folders/zz/r1py7zhj35q75v09h8_42nzh0000gp/T/ipykernel_67121/1263996749.py:4: MatplotlibDeprecationWarning: The get_cmap function was deprecated in Matplotlib 3.7 and will be removed two minor releases later. Use ``matplotlib.colormaps[name]`` or ``matplotlib.colormaps.get_cmap(obj)`` instead.\n", + " cmap = mpl.cm.get_cmap(\"Set1\")\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + "
\n", + " | Latent Class | \n", + "Coefficient Name | \n", + "Coefficient Estimation | \n", + "Std. Err | \n", + "z_value | \n", + "P(.>z) | \n", + "
---|---|---|---|---|---|---|
0 | \n", + "0 | \n", + "Weights_items_features_0 | \n", + "-0.675645 | \n", + "0.023987 | \n", + "-28.167109 | \n", + "0.000000 | \n", + "
1 | \n", + "0 | \n", + "Weights_items_features_1 | \n", + "-0.060604 | \n", + "0.008162 | \n", + "-7.424849 | \n", + "0.000000 | \n", + "
2 | \n", + "0 | \n", + "Weights_items_features_2 | \n", + "1.851951 | \n", + "0.054914 | \n", + "33.724579 | \n", + "0.000000 | \n", + "
3 | \n", + "0 | \n", + "Weights_items_features_3 | \n", + "1.322549 | \n", + "0.048159 | \n", + "27.462420 | \n", + "0.000000 | \n", + "
4 | \n", + "0 | \n", + "Weights_items_features_4 | \n", + "-5.857089 | \n", + "0.191162 | \n", + "-30.639460 | \n", + "0.000000 | \n", + "
5 | \n", + "0 | \n", + "Weights_items_features_5 | \n", + "-6.513206 | \n", + "0.195680 | \n", + "-33.285046 | \n", + "0.000000 | \n", + "
6 | \n", + "1 | \n", + "Weights_items_features_0 | \n", + "-1.817566 | \n", + "0.077771 | \n", + "-23.370796 | \n", + "0.000000 | \n", + "
7 | \n", + "1 | \n", + "Weights_items_features_1 | \n", + "-1.726365 | \n", + "0.058838 | \n", + "-29.340986 | \n", + "0.000000 | \n", + "
8 | \n", + "1 | \n", + "Weights_items_features_2 | \n", + "3.696567 | \n", + "0.160258 | \n", + "23.066404 | \n", + "0.000000 | \n", + "
9 | \n", + "1 | \n", + "Weights_items_features_3 | \n", + "4.111840 | \n", + "0.157179 | \n", + "26.160225 | \n", + "0.000000 | \n", + "
10 | \n", + "1 | \n", + "Weights_items_features_4 | \n", + "-26.693516 | \n", + "3.274723 | \n", + "-8.151381 | \n", + "0.000000 | \n", + "
11 | \n", + "1 | \n", + "Weights_items_features_5 | \n", + "-14.925840 | \n", + "0.634699 | \n", + "-23.516403 | \n", + "0.000000 | \n", + "
12 | \n", + "2 | \n", + "Weights_items_features_0 | \n", + "-2.104791 | \n", + "0.104296 | \n", + "-20.181009 | \n", + "0.000000 | \n", + "
13 | \n", + "2 | \n", + "Weights_items_features_1 | \n", + "-1.652622 | \n", + "0.073820 | \n", + "-22.387188 | \n", + "0.000000 | \n", + "
14 | \n", + "2 | \n", + "Weights_items_features_2 | \n", + "-5.554287 | \n", + "0.245318 | \n", + "-22.641151 | \n", + "0.000000 | \n", + "
15 | \n", + "2 | \n", + "Weights_items_features_3 | \n", + "-13.565555 | \n", + "0.544168 | \n", + "-24.928965 | \n", + "0.000000 | \n", + "
16 | \n", + "2 | \n", + "Weights_items_features_4 | \n", + "-9.794930 | \n", + "0.631004 | \n", + "-15.522781 | \n", + "0.000000 | \n", + "
17 | \n", + "2 | \n", + "Weights_items_features_5 | \n", + "-12.126673 | \n", + "0.681118 | \n", + "-17.804060 | \n", + "0.000000 | \n", + "