From c8a82c22b046e2a4ad2dbeac5fa8dc18da705f58 Mon Sep 17 00:00:00 2001 From: VARUNSHIYAM <138989960+Varunshiyam@users.noreply.github.com> Date: Tue, 5 Nov 2024 12:03:15 +0530 Subject: [PATCH 1/2] Fixes Mushroom Classification --- .../mushroom-classification-notebook.ipynb | 1020 +++++++++++++++++ 1 file changed, 1020 insertions(+) create mode 100644 Prediction Models/Mushroom_Classification/mushroom-classification-notebook.ipynb diff --git a/Prediction Models/Mushroom_Classification/mushroom-classification-notebook.ipynb b/Prediction Models/Mushroom_Classification/mushroom-classification-notebook.ipynb new file mode 100644 index 00000000..4860f39c --- /dev/null +++ b/Prediction Models/Mushroom_Classification/mushroom-classification-notebook.ipynb @@ -0,0 +1,1020 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Mushroom Classification\n", + "\n", + "## Modelling Objective\n", + "Build a **Simple** and **Interpretable** Model to Perform **Binary Classification** on Edibility of Mushroom from *Agarcius and Lepiota Family*. " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "execution": { + "iopub.execute_input": "2021-06-04T17:10:45.398443Z", + "iopub.status.busy": "2021-06-04T17:10:45.398122Z", + "iopub.status.idle": "2021-06-04T17:10:45.403054Z", + "shell.execute_reply": "2021-06-04T17:10:45.401917Z", + "shell.execute_reply.started": "2021-06-04T17:10:45.398415Z" + } + }, + "outputs": [], + "source": [ + "import pandas as pd\n", + "import numpy as np\n", + "import scipy.stats as ss\n", + "import matplotlib.pyplot as plt\n", + "import seaborn as sns\n", + "from collections import Counter\n", + "from tqdm import tqdm" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Reading Dataset\n", + "\n", + "The [mushroom dataset](https://archive.ics.uci.edu/ml/datasets/mushroom) includes descriptions of hypothetical samples corresponding to 23\n", + "species of gilled mushrooms in the Agaricus and Lepiota Family (pp. 500-525). \n", + "\n", + "Each species is identified as definitely *edible, definitely poisonous, or of unknown\n", + "edibility and not recommended*. This latter class was combined with the poisonous\n", + "one. \n", + "\n", + "Hence, the task given is a binary classification problem whereby, \n", + "given the features of mushrooms, we are to classify the mushrooms \n", + "into **p=Poisonous** or **e=edible**.\n", + "\n", + "## Data Dictionary\n", + "| Columns | Descriptions |\n", + "| :--- | :--- |\n", + "| class | poisonous=p, edible=e| \n", + "| cap-shape | bell=b,conical=c,convex=x,flat=f, knobbed=k,sunken=s|\n", + "| cap-surface | fibrous=f,grooves=g,scaly=y,smooth=s |\n", + "| cap-color | brown=n,buff=b,cinnamon=c,gray=g,green=r,pink=p,purple=u,red=e,white=w,yellow=y|\n", + "| bruises | bruises=t,no=f |\n", + "| odor | almond=a,anise=l,creosote=c,fishy=y,foul=f, musty=m,none=n,pungent=p,spicy=s |\n", + "| gill-attachment | attached=a,descending=d,free=f,notched=n|\n", + "| gill-spacing | close=c,crowded=w,distant=d|\n", + "| gill-size | broad=b,narrow=n |\n", + "| gill-color | black=k,brown=n,buff=b,chocolate=h,gray=g,green=r,orange=o,pink=p,purple=u,red=e,white=w,yellow=y|\n", + "| stalk-shape | enlarging=e,tapering=t\n", + "| stalk-root | bulbous=b,club=c,cup=u,equal=e, rhizomorphs=z,rooted=r,**missing=?** |\n", + "| stalk-surface-above-ring| fibrous=f,scaly=y,silky=k,smooth=s|\n", + "| stalk-surface-below-ring| fibrous=f,scaly=y,silky=k,smooth=s|\n", + "| stalk-color-above-ring | brown=n,buff=b,cinnamon=c,gray=g,orange=o,pink=p,red=e,white=w,yellow=y|\n", + "| stalk-color-below-ring | brown=n,buff=b,cinnamon=c,gray=g,orange=o,pink=p,red=e,white=w,yellow=y|\n", + "| veil-type | partial=p,universal=u|\n", + "| veil-color | brown=n,orange=o,white=w,yellow=y|\n", + "| ring-number | none=n,one=o,two=t|\n", + "| ring-type | cobwebby=c,evanescent=e,flaring=f,large=l,none=n,pendant=p,sheathing=s,zone=z|\n", + "| spore-print-color | black=k,brown=n,buff=b,chocolate=h,green=r,orange=o,purple=u,white=w,yellow=y|\n", + "| population | abundant=a,clustered=c,numerous=n,scattered=s,several=v,solitary=y|\n", + "| habitat | grasses=g,leaves=l,meadows=m,paths=p,urban=u,waste=w,woods=d|" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "execution": { + "iopub.execute_input": "2021-06-04T17:10:45.410235Z", + "iopub.status.busy": "2021-06-04T17:10:45.409894Z", + "iopub.status.idle": "2021-06-04T17:10:45.458726Z", + "shell.execute_reply": "2021-06-04T17:10:45.457527Z", + "shell.execute_reply.started": "2021-06-04T17:10:45.410206Z" + } + }, + "outputs": [], + "source": [ + "mushroom_df = pd.read_csv(\"../input/mushroom-classification/mushrooms.csv\", \n", + " na_values=\"?\", # masking \"?\" with Null Values\n", + " )\n", + "mushroom_df.rename(columns = {\"class\":\"is-edible\"}, inplace = True)\n", + "mushroom_df.head()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Exploratory Data Analysis\n", + "Understands the dataset and flag out flaws in the dataset." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Descriptive Summaries\n", + "By running `.info()` on our dataframe, the following are the initial observation\n", + "of the dataset.\n", + "\n", + "**Observations**\n", + "\n", + "1. The shape of dataset is `(8124, 23)` whereby there is 8124 observations and 23 columns. \n", + "(22 Features + 1 Target Variable: `\"is-edible\"`)\n", + "2. Datatype of all columns are `object`. However, from the documentation there are numerical feature which is encoded as string. **(e.g. ring-number)**\n", + "3. Missing values is observed in `\"stalk-root\"` columns which is around 30.5% (2480 Missing Values) of the entire dataset." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "execution": { + "iopub.execute_input": "2021-06-04T17:10:45.461125Z", + "iopub.status.busy": "2021-06-04T17:10:45.460712Z", + "iopub.status.idle": "2021-06-04T17:10:45.48465Z", + "shell.execute_reply": "2021-06-04T17:10:45.483776Z", + "shell.execute_reply.started": "2021-06-04T17:10:45.461082Z" + } + }, + "outputs": [], + "source": [ + "mushroom_df.info()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Unique Values Exploration\n", + "Since all our columns are encoded in string, one way to explore the values is \n", + "number of unique values in each columns.\n", + "\n", + "**Observations**\n", + "\n", + "1. Constant Value Column(1 Unique Value): \n", + "\n", + " `\"veil-type\"`\n", + " \n", + " As all datapoints have constant value = p in `\"veil-type\"`, it does not provide \n", + " any information to the target variable.\n", + " \n", + " > One approach is to **Drop the `\"veil-type\"` column entirely**.\n", + "\n", + "2. Binary Columns(2 Unique Values): \n", + "\n", + " `[\"is-edible\"(label), \"bruises\", \"gill-attachment\", \"gill-spacing\", \"gill-size\", \"stalk-shape\"]`\n", + "\n", + "3. Nominal Categorical Columns(>2 Unique Values):\n", + " \n", + " `[\"cap-shape\", \"cap-surface\", \"cap-color\", \"odor\", \"gill-color\",\n", + " \"stalk-root\", \"stalk-surface-above-ring\", \"stalk-surface-below-ring\",\n", + " \"stalk-color-above-ring\", \"stalk-color-below-ring\", \"veil-color\", \"ring-type\", \n", + " \"spore-print-color\", \"population\", \"habitat\"]`\n", + "\n", + " There are 15 Nominal Categorical which describe the characteristics of mushrooms \n", + " including the texture, colors, population and habitat.\n", + "\n", + " > As there are abundance of Nominal Categorical features, creating One-Hot variables \n", + " for all categorical features might create excessive dimensional spaces which can be \n", + " computational expenssive and prone to overfitting ([aka \"The Curse of Dimensionality\"](https://towardsdatascience.com/the-curse-of-dimensionality-50dc6e49aa1e))\n", + "\n", + " > There might be a need to explore further on feature selection and/or [dimensionality reduction](https://towardsdatascience.com/5-must-know-dimensionality-reduction-techniques-via-prince-e6ffb27e55d1) to mitigate the curse of dimensionality.\n", + "4. Discrete Numerical Columns(Countable Values):\n", + "\n", + " `\"ring-number\"`\n", + "\n", + " Although it is technically a numerical column, since the number of unique values is low(`nuique() == 3`), we can treat it as a categorical column during the encoding. " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "execution": { + "iopub.execute_input": "2021-06-04T17:10:45.485734Z", + "iopub.status.busy": "2021-06-04T17:10:45.485493Z", + "iopub.status.idle": "2021-06-04T17:10:45.518231Z", + "shell.execute_reply": "2021-06-04T17:10:45.517145Z", + "shell.execute_reply.started": "2021-06-04T17:10:45.485711Z" + } + }, + "outputs": [], + "source": [ + "mushroom_df.nunique().sort_values()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Countplot\n", + "Countplot is a convenient tool to quickly explore the count of category variables for each unique values.\n", + "Since the dataset entire dataset is made up of categorical variables, we can just generate countplot for all columns.\n", + "\n", + "**Observations**\n", + "\n", + "1. Balanced Label\n", + "\n", + " The class frequency of the target variable `is-edible` is relatively balanced with 4208 instances classified as edible and 3916 instances classified as poisonous.\n", + " \n", + "2. High Cardinality for Categorical Features\n", + "\n", + " For features >2 Unique values, most of them suffer from high cardinality with minority classes. This makes the column of resulting matrix sparse if we were to perform One-Hot Encoding without any feature selection/dimension reduction.\n", + " " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "execution": { + "iopub.execute_input": "2021-06-04T17:10:45.52001Z", + "iopub.status.busy": "2021-06-04T17:10:45.519665Z", + "iopub.status.idle": "2021-06-04T17:10:52.34298Z", + "shell.execute_reply": "2021-06-04T17:10:52.341965Z", + "shell.execute_reply.started": "2021-06-04T17:10:45.519967Z" + } + }, + "outputs": [], + "source": [ + "plt.figure(figsize=(15,10))\n", + "for i, col in enumerate(mushroom_df.columns):\n", + " sns.set_palette(sns.color_palette(\"Paired\"))\n", + " ax = plt.subplot(6,4,i+1)\n", + " sns.countplot(\n", + " x=col, data = mushroom_df, ax = ax, \n", + " order = mushroom_df[col].value_counts(ascending=True).index\n", + " )\n", + " sns.set_style('whitegrid')\n", + " plt.xticks(rotation=90)\n", + " plt.ylabel(\"Median Price\")\n", + " plt.tight_layout()\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**Summary:**\n", + "\n", + "1. Constant Value Column exist for `veil-type` which shall be dropped as it does not bring any information of the target variable.\n", + "2. Data Cleaning/ Imputation is needed to treat missing values for `stalk-root`.\n", + "3. Encode the features into dummy variables for binary column(Unique Values = 2) and One-Hot encoding for nominal categorical columns(Unique Values >=2).\n", + "4. Feature Selection might be required to reduce the dimension of the dataset." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Data Preprocessing\n", + "Preprocess dataset into a format that is digestible by model." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Drop Constant Value Column\n", + "\n", + "`veil-type` column is dropped as it have constant value of \"p\" which does not bring any information about the target variable." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "execution": { + "iopub.execute_input": "2021-06-04T17:10:52.345878Z", + "iopub.status.busy": "2021-06-04T17:10:52.345595Z", + "iopub.status.idle": "2021-06-04T17:10:52.3518Z", + "shell.execute_reply": "2021-06-04T17:10:52.350764Z", + "shell.execute_reply.started": "2021-06-04T17:10:52.34585Z" + } + }, + "outputs": [], + "source": [ + "mushroom_df.drop(columns=\"veil-type\", inplace = True)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Data Cleaning and Imputation\n", + "\n", + "Since there are around 30.2% of missing values observed for `stalk-root` feature, we can tryout the following approaches:\n", + "\n", + "1. Drop the Entire `stalk-root` Column\n", + "2. Impute with Central Tendency(most-frequent = \"b\")\n", + "3. Impute with Advanced Algorithm in SKLearn (e.g. IterativeImputer, KNNImputer)\n", + "\n", + "We will go with the first approach since it is the simplest solution that does not change the underlying distribution of dataset and evaluate the decision based on the model's performance later." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "execution": { + "iopub.execute_input": "2021-06-04T17:10:52.355466Z", + "iopub.status.busy": "2021-06-04T17:10:52.355108Z", + "iopub.status.idle": "2021-06-04T17:10:52.375412Z", + "shell.execute_reply": "2021-06-04T17:10:52.373877Z", + "shell.execute_reply.started": "2021-06-04T17:10:52.355429Z" + } + }, + "outputs": [], + "source": [ + "mushroom_df.dropna(axis=1, inplace=True)\n", + "\n", + "print(\"stalk-root\" in mushroom_df.columns)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "execution": { + "iopub.execute_input": "2021-06-04T17:10:52.377259Z", + "iopub.status.busy": "2021-06-04T17:10:52.376897Z", + "iopub.status.idle": "2021-06-04T17:10:52.382523Z", + "shell.execute_reply": "2021-06-04T17:10:52.381458Z", + "shell.execute_reply.started": "2021-06-04T17:10:52.377228Z" + } + }, + "outputs": [], + "source": [ + "from sklearn.feature_selection import chi2, RFECV\n", + "from sklearn.model_selection import train_test_split, cross_validate, GridSearchCV\n", + "from sklearn.dummy import DummyClassifier\n", + "from sklearn.linear_model import LogisticRegression\n", + "from sklearn.neighbors import KNeighborsClassifier\n", + "from sklearn.svm import SVC, LinearSVC\n", + "from sklearn.tree import DecisionTreeClassifier, plot_tree\n", + "from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier\n", + "from sklearn.metrics import accuracy_score, confusion_matrix, f1_score, auc, roc_curve, plot_roc_curve" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Train-Test-Split\n", + "Splitting Data into Training and Testing Set before in-depth EDA and Preprocessing to avoid data leakage and ensures all decisions make are based on the training set and the testing set is left untouched." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "execution": { + "iopub.execute_input": "2021-06-04T17:10:52.383831Z", + "iopub.status.busy": "2021-06-04T17:10:52.38358Z", + "iopub.status.idle": "2021-06-04T17:10:52.400667Z", + "shell.execute_reply": "2021-06-04T17:10:52.399538Z", + "shell.execute_reply.started": "2021-06-04T17:10:52.383807Z" + } + }, + "outputs": [], + "source": [ + "train_df, test_df = train_test_split(mushroom_df, test_size = 0.3, random_state = 12)\n", + "print(train_df.shape)\n", + "print(test_df.shape)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Feature Selection\n", + "Although we have a total of 107 features after feature encoding, some of the feature might not be useful for modelling as it have too little occurrence, or they are just noises that does not bring any information about the target variable.\n", + "\n", + "For that, feature selection is needed to investigate more on the strong and weak features and how we could perform some feature engineering before we start our modelling.\n", + "\n", + "*All investigation and inference is made with the training set to minimize any data leakage which leads to biased result during model evaluation." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Cramer's V Correlation Matrix\n", + "Cramer's V is a statistical test to calculate correlation in tables which have more than 2x2 rows and columns. It is used as post-test to determine strengths of association after chi-square has determined significance. \n", + "\n", + "$$\n", + "V = \\sqrt{\\frac{\\chi^2/n}{k-1}}\\\\\n", + "\\chi^2 : \\text{chi-square}\\\\\n", + "k : \\text{number of rows or columns in the contingency table}\\\\\n", + "n : \\text{Number of observations}\n", + "\\\\\n", + "(Weak Association)0