Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ADD: Latent Class Report #201

Merged
merged 23 commits into from
Dec 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
92fd73d
Merge branch 'main' of github.com:artefactory/choice-learn-private
VincentAuriau May 17, 2024
8e8bca4
Merge branch 'main' of github.com:artefactory/choice-learn-private
VincentAuriau May 23, 2024
b3455e2
Merge branch 'main' of github.com:artefactory/choice-learn-private
VincentAuriau May 28, 2024
aa35f2f
Merge branch 'main' of github.com:artefactory/choice-learn-private
VincentAuriau May 30, 2024
5fbbaca
Merge branch 'main' of github.com:artefactory/choice-learn-private
VincentAuriau Jun 25, 2024
c8ec7c2
Merge branch 'main' of github.com:artefactory/choice-learn-private
VincentAuriau Jul 2, 2024
18d6c25
ENH: logos in ReadMe
VincentAuriau Jul 3, 2024
19f04af
Merge branch 'main' of github.com:artefactory/choice-learn-private
VincentAuriau Jul 4, 2024
03c5b35
merge
VincentAuriau Jul 8, 2024
57bdf4b
poetry lock [--no-update]Merge branch 'main' of github.com:artefactor…
VincentAuriau Jul 31, 2024
1266ad9
Merge branch 'main' of github.com:artefactory/choice-learn-private
VincentAuriau Aug 21, 2024
053977d
Merge branch 'main' of github.com:artefactory/choice-learn-private
VincentAuriau Sep 11, 2024
52b2356
Merge branch 'main' of github.com:artefactory/choice-learn-private
VincentAuriau Oct 21, 2024
78517ec
Merge branch 'main' of github.com:artefactory/choice-learn-private
VincentAuriau Oct 22, 2024
40db13b
FIX: forgotten remainings of renaming of tolerance argument into lbfg…
VincentAuriau Nov 28, 2024
f5b4171
ADD: updates notebook
VincentAuriau Nov 28, 2024
92288f2
Merge branch 'main' of github.com:artefactory/choice-learn-private
VincentAuriau Dec 20, 2024
d584ae5
ADD: report in Latent Class model
VincentAuriau Dec 20, 2024
1f22fa5
ADD: LC report tests
VincentAuriau Dec 20, 2024
e478395
ADD: LC report tests
VincentAuriau Dec 20, 2024
b632b58
Merge branch 'lc-report' of github.com:artefactory/choice-learn-priva…
VincentAuriau Dec 23, 2024
38f281a
ADD: small improvements
VincentAuriau Dec 23, 2024
ad43b25
few modifs
VincentAuriau Dec 23, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 35 additions & 0 deletions choice_learn/models/latent_class_base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import time

import numpy as np
import pandas as pd
import tensorflow as tf
import tqdm

Expand Down Expand Up @@ -930,3 +931,37 @@ def get_latent_classes_weights(self):
Latent classes weights/probabilities
"""
return tf.nn.softmax(tf.concat([[tf.constant(0.0)], self.latent_logits], axis=0))

def compute_report(self, choice_dataset):
"""Compute a report of the estimated weights.

Parameters
----------
choice_dataset : ChoiceDataset
ChoiceDataset used for the estimation of the weights that will be
used to compute the Std Err of this estimation.

Returns
-------
pandas.DataFrame
A DF with estimation, Std Err, z_value and p_value for each coefficient.
"""
reports = []
for i, model in enumerate(self.models):
compute = getattr(model, "compute_report", None)
if callable(compute):
report = model.compute_report(choice_dataset)
report["Latent Class"] = i
reports.append(report)
else:
raise ValueError(f"{i}-th model {model} does not have a compute_report method.")
return pd.concat(reports, axis=0, ignore_index=True)[
[
"Latent Class",
"Coefficient Name",
"Coefficient Estimation",
"Std. Err",
"z_value",
"P(.>z)",
]
]
252 changes: 249 additions & 3 deletions notebooks/models/latent_class_model.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
"\n",
"os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"\"\n",
"\n",
"import matplotlib as mpl\n",
"import numpy as np\n",
"import pandas as pd\n",
"\n",
Expand Down Expand Up @@ -106,9 +107,254 @@
"print(f\"Negative Log-Likelihood: {nll}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"keep_output": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Using L-BFGS optimizer, setting up .fit() function\n",
"Using L-BFGS optimizer, setting up .fit() function\n",
"Using L-BFGS optimizer, setting up .fit() function\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/var/folders/zz/r1py7zhj35q75v09h8_42nzh0000gp/T/ipykernel_67121/1263996749.py:4: MatplotlibDeprecationWarning: The get_cmap function was deprecated in Matplotlib 3.7 and will be removed two minor releases later. Use ``matplotlib.colormaps[name]`` or ``matplotlib.colormaps.get_cmap(obj)`` instead.\n",
" cmap = mpl.cm.get_cmap(\"Set1\")\n"
]
},
{
"data": {
"text/html": [
"<style type=\"text/css\">\n",
"#T_279d5_row0_col0, #T_279d5_row0_col1, #T_279d5_row0_col2, #T_279d5_row0_col3, #T_279d5_row0_col4, #T_279d5_row0_col5, #T_279d5_row1_col0, #T_279d5_row1_col1, #T_279d5_row1_col2, #T_279d5_row1_col3, #T_279d5_row1_col4, #T_279d5_row1_col5, #T_279d5_row2_col0, #T_279d5_row2_col1, #T_279d5_row2_col2, #T_279d5_row2_col3, #T_279d5_row2_col4, #T_279d5_row2_col5, #T_279d5_row3_col0, #T_279d5_row3_col1, #T_279d5_row3_col2, #T_279d5_row3_col3, #T_279d5_row3_col4, #T_279d5_row3_col5, #T_279d5_row4_col0, #T_279d5_row4_col1, #T_279d5_row4_col2, #T_279d5_row4_col3, #T_279d5_row4_col4, #T_279d5_row4_col5, #T_279d5_row5_col0, #T_279d5_row5_col1, #T_279d5_row5_col2, #T_279d5_row5_col3, #T_279d5_row5_col4, #T_279d5_row5_col5 {\n",
" background-color: #e41a1c;\n",
"}\n",
"#T_279d5_row6_col0, #T_279d5_row6_col1, #T_279d5_row6_col2, #T_279d5_row6_col3, #T_279d5_row6_col4, #T_279d5_row6_col5, #T_279d5_row7_col0, #T_279d5_row7_col1, #T_279d5_row7_col2, #T_279d5_row7_col3, #T_279d5_row7_col4, #T_279d5_row7_col5, #T_279d5_row8_col0, #T_279d5_row8_col1, #T_279d5_row8_col2, #T_279d5_row8_col3, #T_279d5_row8_col4, #T_279d5_row8_col5, #T_279d5_row9_col0, #T_279d5_row9_col1, #T_279d5_row9_col2, #T_279d5_row9_col3, #T_279d5_row9_col4, #T_279d5_row9_col5, #T_279d5_row10_col0, #T_279d5_row10_col1, #T_279d5_row10_col2, #T_279d5_row10_col3, #T_279d5_row10_col4, #T_279d5_row10_col5, #T_279d5_row11_col0, #T_279d5_row11_col1, #T_279d5_row11_col2, #T_279d5_row11_col3, #T_279d5_row11_col4, #T_279d5_row11_col5 {\n",
" background-color: #377eb8;\n",
"}\n",
"#T_279d5_row12_col0, #T_279d5_row12_col1, #T_279d5_row12_col2, #T_279d5_row12_col3, #T_279d5_row12_col4, #T_279d5_row12_col5, #T_279d5_row13_col0, #T_279d5_row13_col1, #T_279d5_row13_col2, #T_279d5_row13_col3, #T_279d5_row13_col4, #T_279d5_row13_col5, #T_279d5_row14_col0, #T_279d5_row14_col1, #T_279d5_row14_col2, #T_279d5_row14_col3, #T_279d5_row14_col4, #T_279d5_row14_col5, #T_279d5_row15_col0, #T_279d5_row15_col1, #T_279d5_row15_col2, #T_279d5_row15_col3, #T_279d5_row15_col4, #T_279d5_row15_col5, #T_279d5_row16_col0, #T_279d5_row16_col1, #T_279d5_row16_col2, #T_279d5_row16_col3, #T_279d5_row16_col4, #T_279d5_row16_col5, #T_279d5_row17_col0, #T_279d5_row17_col1, #T_279d5_row17_col2, #T_279d5_row17_col3, #T_279d5_row17_col4, #T_279d5_row17_col5 {\n",
" background-color: #4daf4a;\n",
"}\n",
"</style>\n",
"<table id=\"T_279d5\">\n",
" <thead>\n",
" <tr>\n",
" <th class=\"blank level0\" >&nbsp;</th>\n",
" <th id=\"T_279d5_level0_col0\" class=\"col_heading level0 col0\" >Latent Class</th>\n",
" <th id=\"T_279d5_level0_col1\" class=\"col_heading level0 col1\" >Coefficient Name</th>\n",
" <th id=\"T_279d5_level0_col2\" class=\"col_heading level0 col2\" >Coefficient Estimation</th>\n",
" <th id=\"T_279d5_level0_col3\" class=\"col_heading level0 col3\" >Std. Err</th>\n",
" <th id=\"T_279d5_level0_col4\" class=\"col_heading level0 col4\" >z_value</th>\n",
" <th id=\"T_279d5_level0_col5\" class=\"col_heading level0 col5\" >P(.>z)</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th id=\"T_279d5_level0_row0\" class=\"row_heading level0 row0\" >0</th>\n",
" <td id=\"T_279d5_row0_col0\" class=\"data row0 col0\" >0</td>\n",
" <td id=\"T_279d5_row0_col1\" class=\"data row0 col1\" >Weights_items_features_0</td>\n",
" <td id=\"T_279d5_row0_col2\" class=\"data row0 col2\" >-0.675645</td>\n",
" <td id=\"T_279d5_row0_col3\" class=\"data row0 col3\" >0.023987</td>\n",
" <td id=\"T_279d5_row0_col4\" class=\"data row0 col4\" >-28.167109</td>\n",
" <td id=\"T_279d5_row0_col5\" class=\"data row0 col5\" >0.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_279d5_level0_row1\" class=\"row_heading level0 row1\" >1</th>\n",
" <td id=\"T_279d5_row1_col0\" class=\"data row1 col0\" >0</td>\n",
" <td id=\"T_279d5_row1_col1\" class=\"data row1 col1\" >Weights_items_features_1</td>\n",
" <td id=\"T_279d5_row1_col2\" class=\"data row1 col2\" >-0.060604</td>\n",
" <td id=\"T_279d5_row1_col3\" class=\"data row1 col3\" >0.008162</td>\n",
" <td id=\"T_279d5_row1_col4\" class=\"data row1 col4\" >-7.424849</td>\n",
" <td id=\"T_279d5_row1_col5\" class=\"data row1 col5\" >0.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_279d5_level0_row2\" class=\"row_heading level0 row2\" >2</th>\n",
" <td id=\"T_279d5_row2_col0\" class=\"data row2 col0\" >0</td>\n",
" <td id=\"T_279d5_row2_col1\" class=\"data row2 col1\" >Weights_items_features_2</td>\n",
" <td id=\"T_279d5_row2_col2\" class=\"data row2 col2\" >1.851951</td>\n",
" <td id=\"T_279d5_row2_col3\" class=\"data row2 col3\" >0.054914</td>\n",
" <td id=\"T_279d5_row2_col4\" class=\"data row2 col4\" >33.724579</td>\n",
" <td id=\"T_279d5_row2_col5\" class=\"data row2 col5\" >0.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_279d5_level0_row3\" class=\"row_heading level0 row3\" >3</th>\n",
" <td id=\"T_279d5_row3_col0\" class=\"data row3 col0\" >0</td>\n",
" <td id=\"T_279d5_row3_col1\" class=\"data row3 col1\" >Weights_items_features_3</td>\n",
" <td id=\"T_279d5_row3_col2\" class=\"data row3 col2\" >1.322549</td>\n",
" <td id=\"T_279d5_row3_col3\" class=\"data row3 col3\" >0.048159</td>\n",
" <td id=\"T_279d5_row3_col4\" class=\"data row3 col4\" >27.462420</td>\n",
" <td id=\"T_279d5_row3_col5\" class=\"data row3 col5\" >0.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_279d5_level0_row4\" class=\"row_heading level0 row4\" >4</th>\n",
" <td id=\"T_279d5_row4_col0\" class=\"data row4 col0\" >0</td>\n",
" <td id=\"T_279d5_row4_col1\" class=\"data row4 col1\" >Weights_items_features_4</td>\n",
" <td id=\"T_279d5_row4_col2\" class=\"data row4 col2\" >-5.857089</td>\n",
" <td id=\"T_279d5_row4_col3\" class=\"data row4 col3\" >0.191162</td>\n",
" <td id=\"T_279d5_row4_col4\" class=\"data row4 col4\" >-30.639460</td>\n",
" <td id=\"T_279d5_row4_col5\" class=\"data row4 col5\" >0.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_279d5_level0_row5\" class=\"row_heading level0 row5\" >5</th>\n",
" <td id=\"T_279d5_row5_col0\" class=\"data row5 col0\" >0</td>\n",
" <td id=\"T_279d5_row5_col1\" class=\"data row5 col1\" >Weights_items_features_5</td>\n",
" <td id=\"T_279d5_row5_col2\" class=\"data row5 col2\" >-6.513206</td>\n",
" <td id=\"T_279d5_row5_col3\" class=\"data row5 col3\" >0.195680</td>\n",
" <td id=\"T_279d5_row5_col4\" class=\"data row5 col4\" >-33.285046</td>\n",
" <td id=\"T_279d5_row5_col5\" class=\"data row5 col5\" >0.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_279d5_level0_row6\" class=\"row_heading level0 row6\" >6</th>\n",
" <td id=\"T_279d5_row6_col0\" class=\"data row6 col0\" >1</td>\n",
" <td id=\"T_279d5_row6_col1\" class=\"data row6 col1\" >Weights_items_features_0</td>\n",
" <td id=\"T_279d5_row6_col2\" class=\"data row6 col2\" >-1.817566</td>\n",
" <td id=\"T_279d5_row6_col3\" class=\"data row6 col3\" >0.077771</td>\n",
" <td id=\"T_279d5_row6_col4\" class=\"data row6 col4\" >-23.370796</td>\n",
" <td id=\"T_279d5_row6_col5\" class=\"data row6 col5\" >0.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_279d5_level0_row7\" class=\"row_heading level0 row7\" >7</th>\n",
" <td id=\"T_279d5_row7_col0\" class=\"data row7 col0\" >1</td>\n",
" <td id=\"T_279d5_row7_col1\" class=\"data row7 col1\" >Weights_items_features_1</td>\n",
" <td id=\"T_279d5_row7_col2\" class=\"data row7 col2\" >-1.726365</td>\n",
" <td id=\"T_279d5_row7_col3\" class=\"data row7 col3\" >0.058838</td>\n",
" <td id=\"T_279d5_row7_col4\" class=\"data row7 col4\" >-29.340986</td>\n",
" <td id=\"T_279d5_row7_col5\" class=\"data row7 col5\" >0.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_279d5_level0_row8\" class=\"row_heading level0 row8\" >8</th>\n",
" <td id=\"T_279d5_row8_col0\" class=\"data row8 col0\" >1</td>\n",
" <td id=\"T_279d5_row8_col1\" class=\"data row8 col1\" >Weights_items_features_2</td>\n",
" <td id=\"T_279d5_row8_col2\" class=\"data row8 col2\" >3.696567</td>\n",
" <td id=\"T_279d5_row8_col3\" class=\"data row8 col3\" >0.160258</td>\n",
" <td id=\"T_279d5_row8_col4\" class=\"data row8 col4\" >23.066404</td>\n",
" <td id=\"T_279d5_row8_col5\" class=\"data row8 col5\" >0.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_279d5_level0_row9\" class=\"row_heading level0 row9\" >9</th>\n",
" <td id=\"T_279d5_row9_col0\" class=\"data row9 col0\" >1</td>\n",
" <td id=\"T_279d5_row9_col1\" class=\"data row9 col1\" >Weights_items_features_3</td>\n",
" <td id=\"T_279d5_row9_col2\" class=\"data row9 col2\" >4.111840</td>\n",
" <td id=\"T_279d5_row9_col3\" class=\"data row9 col3\" >0.157179</td>\n",
" <td id=\"T_279d5_row9_col4\" class=\"data row9 col4\" >26.160225</td>\n",
" <td id=\"T_279d5_row9_col5\" class=\"data row9 col5\" >0.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_279d5_level0_row10\" class=\"row_heading level0 row10\" >10</th>\n",
" <td id=\"T_279d5_row10_col0\" class=\"data row10 col0\" >1</td>\n",
" <td id=\"T_279d5_row10_col1\" class=\"data row10 col1\" >Weights_items_features_4</td>\n",
" <td id=\"T_279d5_row10_col2\" class=\"data row10 col2\" >-26.693516</td>\n",
" <td id=\"T_279d5_row10_col3\" class=\"data row10 col3\" >3.274723</td>\n",
" <td id=\"T_279d5_row10_col4\" class=\"data row10 col4\" >-8.151381</td>\n",
" <td id=\"T_279d5_row10_col5\" class=\"data row10 col5\" >0.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_279d5_level0_row11\" class=\"row_heading level0 row11\" >11</th>\n",
" <td id=\"T_279d5_row11_col0\" class=\"data row11 col0\" >1</td>\n",
" <td id=\"T_279d5_row11_col1\" class=\"data row11 col1\" >Weights_items_features_5</td>\n",
" <td id=\"T_279d5_row11_col2\" class=\"data row11 col2\" >-14.925840</td>\n",
" <td id=\"T_279d5_row11_col3\" class=\"data row11 col3\" >0.634699</td>\n",
" <td id=\"T_279d5_row11_col4\" class=\"data row11 col4\" >-23.516403</td>\n",
" <td id=\"T_279d5_row11_col5\" class=\"data row11 col5\" >0.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_279d5_level0_row12\" class=\"row_heading level0 row12\" >12</th>\n",
" <td id=\"T_279d5_row12_col0\" class=\"data row12 col0\" >2</td>\n",
" <td id=\"T_279d5_row12_col1\" class=\"data row12 col1\" >Weights_items_features_0</td>\n",
" <td id=\"T_279d5_row12_col2\" class=\"data row12 col2\" >-2.104791</td>\n",
" <td id=\"T_279d5_row12_col3\" class=\"data row12 col3\" >0.104296</td>\n",
" <td id=\"T_279d5_row12_col4\" class=\"data row12 col4\" >-20.181009</td>\n",
" <td id=\"T_279d5_row12_col5\" class=\"data row12 col5\" >0.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_279d5_level0_row13\" class=\"row_heading level0 row13\" >13</th>\n",
" <td id=\"T_279d5_row13_col0\" class=\"data row13 col0\" >2</td>\n",
" <td id=\"T_279d5_row13_col1\" class=\"data row13 col1\" >Weights_items_features_1</td>\n",
" <td id=\"T_279d5_row13_col2\" class=\"data row13 col2\" >-1.652622</td>\n",
" <td id=\"T_279d5_row13_col3\" class=\"data row13 col3\" >0.073820</td>\n",
" <td id=\"T_279d5_row13_col4\" class=\"data row13 col4\" >-22.387188</td>\n",
" <td id=\"T_279d5_row13_col5\" class=\"data row13 col5\" >0.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_279d5_level0_row14\" class=\"row_heading level0 row14\" >14</th>\n",
" <td id=\"T_279d5_row14_col0\" class=\"data row14 col0\" >2</td>\n",
" <td id=\"T_279d5_row14_col1\" class=\"data row14 col1\" >Weights_items_features_2</td>\n",
" <td id=\"T_279d5_row14_col2\" class=\"data row14 col2\" >-5.554287</td>\n",
" <td id=\"T_279d5_row14_col3\" class=\"data row14 col3\" >0.245318</td>\n",
" <td id=\"T_279d5_row14_col4\" class=\"data row14 col4\" >-22.641151</td>\n",
" <td id=\"T_279d5_row14_col5\" class=\"data row14 col5\" >0.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_279d5_level0_row15\" class=\"row_heading level0 row15\" >15</th>\n",
" <td id=\"T_279d5_row15_col0\" class=\"data row15 col0\" >2</td>\n",
" <td id=\"T_279d5_row15_col1\" class=\"data row15 col1\" >Weights_items_features_3</td>\n",
" <td id=\"T_279d5_row15_col2\" class=\"data row15 col2\" >-13.565555</td>\n",
" <td id=\"T_279d5_row15_col3\" class=\"data row15 col3\" >0.544168</td>\n",
" <td id=\"T_279d5_row15_col4\" class=\"data row15 col4\" >-24.928965</td>\n",
" <td id=\"T_279d5_row15_col5\" class=\"data row15 col5\" >0.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_279d5_level0_row16\" class=\"row_heading level0 row16\" >16</th>\n",
" <td id=\"T_279d5_row16_col0\" class=\"data row16 col0\" >2</td>\n",
" <td id=\"T_279d5_row16_col1\" class=\"data row16 col1\" >Weights_items_features_4</td>\n",
" <td id=\"T_279d5_row16_col2\" class=\"data row16 col2\" >-9.794930</td>\n",
" <td id=\"T_279d5_row16_col3\" class=\"data row16 col3\" >0.631004</td>\n",
" <td id=\"T_279d5_row16_col4\" class=\"data row16 col4\" >-15.522781</td>\n",
" <td id=\"T_279d5_row16_col5\" class=\"data row16 col5\" >0.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_279d5_level0_row17\" class=\"row_heading level0 row17\" >17</th>\n",
" <td id=\"T_279d5_row17_col0\" class=\"data row17 col0\" >2</td>\n",
" <td id=\"T_279d5_row17_col1\" class=\"data row17 col1\" >Weights_items_features_5</td>\n",
" <td id=\"T_279d5_row17_col2\" class=\"data row17 col2\" >-12.126673</td>\n",
" <td id=\"T_279d5_row17_col3\" class=\"data row17 col3\" >0.681118</td>\n",
" <td id=\"T_279d5_row17_col4\" class=\"data row17 col4\" >-17.804060</td>\n",
" <td id=\"T_279d5_row17_col5\" class=\"data row17 col5\" >0.000000</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n"
],
"text/plain": [
"<pandas.io.formats.style.Styler at 0x7f7d56306970>"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"report = lc_model.compute_report(elec_dataset)\n",
"\n",
"def format_color_groups(df):\n",
" cmap = mpl.cm.get_cmap(\"Set1\")\n",
" colors = [mpl.colors.rgb2hex(cmap(i)) for i in range(cmap.N)]\n",
" x = df.copy()\n",
" factors = list(x['Latent Class'].unique())\n",
" i = 0\n",
" for factor in factors:\n",
" style = f'background-color: {colors[i]}'\n",
" x.loc[x['Latent Class'] == factor, :] = style\n",
" i += 1\n",
" return x\n",
"\n",
"report.style.apply(format_color_groups, axis=None)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"metadata": {
"keep_output": true
},
"source": [
"## Latent Conditional Logit\n",
"We used a very simple MNL. Here we simulate the same MNL, by using the Conditional-Logit formulation.\\\n",
Expand Down Expand Up @@ -281,7 +527,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "tf_env",
"display_name": "choice_learn",
"language": "python",
"name": "python3"
},
Expand All @@ -295,7 +541,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.4"
"version": "3.8.18"
}
},
"nbformat": 4,
Expand Down
Loading
Loading