Skip to content

Commit

Permalink
Add EM vs GS Results
Browse files Browse the repository at this point in the history
  • Loading branch information
TheReconPilot committed May 6, 2022
1 parent c68e487 commit 49a8d66
Show file tree
Hide file tree
Showing 4 changed files with 1,198 additions and 0 deletions.
196 changes: 196 additions & 0 deletions EMvsGS/EMvsGS-Results.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,196 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "b361579d",
"metadata": {},
"source": [
"# EM vs Gibbs Sampler - Results\n",
"\n",
"This is an experiment to compare performance of Expectation Maximization (EM) and Gibbs Sampler (GS) in the context of Gaussian Mixture Models.\n",
"\n",
"- 500 runs each for K = 3 and K = 6 clusters\n",
"- 1000 data points in each\n",
"- Univariate\n",
"\n",
"During Data Generation, Means were generated from a Uniform [-10, 10] distribution. Standard Deviations were generated from a Uniform [0.25, 5] distribution."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a82ede1a",
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"import plotly.express as px\n",
"import plotly.graph_objects as go\n",
"import matplotlib.pyplot as plt\n",
"import seaborn as sns"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "253f623f",
"metadata": {
"lines_to_next_cell": 0
},
"outputs": [],
"source": [
"# Load Data\n",
"gs3 = pd.read_csv(\"gs-k-3.csv\")\n",
"gs6 = pd.read_csv(\"gs-k-6.csv\")\n",
"\n",
"em3 = pd.read_csv(\"em-k-3.csv\")\n",
"em6 = pd.read_csv(\"em-k-6.csv\")"
]
},
{
"cell_type": "markdown",
"id": "61a41812",
"metadata": {},
"source": [
"## The Data\n",
"\n",
"Gibbs Sampler Results have the following data\n",
"\n",
"- RS: Rand Score\n",
"- ARS: Adjusted Rand Score\n",
"- SS: Silhouette Score\n",
"\n",
"for each of the three methods\n",
"\n",
"- Base GS\n",
"- GS with Multiple Initializations\n",
"- GS with Burn In\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6fd66639",
"metadata": {},
"outputs": [],
"source": [
"# GS with K = 3\n",
"gs3"
]
},
{
"cell_type": "markdown",
"id": "f536009d",
"metadata": {},
"source": [
"The dataframe for EM has the Adjusted Rand Score (ARS) results for EM in 2 modes:\n",
"\n",
"- EM with Many Random Initializations (`gmm_mri_ars`)\n",
"- EM with K-Means Initialization (`gmm_kmeans_ars`)\n",
"\n",
"And the final column is the standard K-Means Clustering results (`kmeans_ars`)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c605b756",
"metadata": {},
"outputs": [],
"source": [
"# EM with K = 3\n",
"em3"
]
},
{
"cell_type": "markdown",
"id": "1f0f9611",
"metadata": {},
"source": [
"## Results\n",
"\n",
"The plots are interactive.\n",
"\n",
"### K = 3"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ca13161d",
"metadata": {},
"outputs": [],
"source": [
"fig = go.Figure()\n",
"fig.add_trace(go.Box(y=gs3['gs_base_ars'], name=\"GS Base\"))\n",
"fig.add_trace(go.Box(y=gs3['gs_burnin_ars'], name=\"GS Burn In\"))\n",
"fig.add_trace(go.Box(y=gs3['gs_multi_ars'], name=\"GS Multi\"))\n",
"fig.add_trace(go.Box(y=em3['gmm_mri_ars'], name=\"EM Multi Init\"))\n",
"fig.add_trace(go.Box(y=em3['gmm_kmeans_ars'], name=\"EM K-Means Init\"))\n",
"fig.add_trace(go.Box(y=em3['kmeans_ars'], name=\"Standard K-Means\"))\n",
"fig.update_layout(title_text=\"K = 3\")\n",
"fig.show()"
]
},
{
"cell_type": "markdown",
"id": "c9ad9d65",
"metadata": {},
"source": [
"### K = 6"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d4435abe",
"metadata": {},
"outputs": [],
"source": [
"fig = go.Figure()\n",
"fig.add_trace(go.Box(y=gs6['gs_base_ars'], name=\"GS Base\"))\n",
"fig.add_trace(go.Box(y=gs6['gs_burnin_ars'], name=\"GS Burn In\"))\n",
"fig.add_trace(go.Box(y=gs6['gs_multi_ars'], name=\"GS Multi\"))\n",
"fig.add_trace(go.Box(y=em6['gmm_mri_ars'], name=\"EM Multi Init\"))\n",
"fig.add_trace(go.Box(y=em6['gmm_kmeans_ars'], name=\"EM K-Means Init\"))\n",
"fig.add_trace(go.Box(y=em6['kmeans_ars'], name=\"Standard K-Means\"))\n",
"fig.update_layout(title_text=\"K = 6\")\n",
"fig.show()"
]
},
{
"cell_type": "markdown",
"id": "434a31f7",
"metadata": {},
"source": [
"## Conclusions\n",
"\n",
"Gibbs Sampler with Multi Init, and both of the EM versions perform better than the standard K-Means.\n",
"\n",
"Between GS and EM, GS with Multi Init seems to be performing the best, by a slight margin over EM."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c85e8401",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"jupytext": {
"cell_metadata_filter": "-all",
"main_language": "python",
"notebook_metadata_filter": "-all",
"text_representation": {
"extension": ".py",
"format_name": "percent"
}
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Loading

0 comments on commit 49a8d66

Please sign in to comment.