Commit

add feature importance plot to experiment artifact

diegomarvid committed Apr 16, 2024
1 parent 03cafba commit 614f61e

Showing 5 changed files with 102 additions and 6 deletions.
24 changes: 24 additions & 0 deletions pipeline_lib/core/data_container.py
@@ -649,6 +649,30 @@ def _encoder(self, value: ColumnTransformer):
"""
self["encoder"] = value

@property
def feature_importance(self) -> pd.DataFrame:
"""
Get the feature_importance from the DataContainer.
Returns
-------
pd.DataFrame
A DataFrame containing feature importance values.
"""
return self["feature_importance"]

@feature_importance.setter
def feature_importance(self, value: pd.DataFrame):
"""
Set the feature_importance in the DataContainer.
Parameters
----------
value
The feature importance DataFrame to be stored in the DataContainer.
"""
self["feature_importance"] = value

def __eq__(self, other) -> bool:
"""
Compare this DataContainer with another for equality.
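For reference, a minimal usage sketch of the new accessor pair (hypothetical: it assumes a DataContainer can be constructed without arguments, and the feature names and values are made up):

import pandas as pd

from pipeline_lib.core.data_container import DataContainer

container = DataContainer()
# Store an importance table via the new setter...
container.feature_importance = pd.DataFrame(
    {"feature": ["age", "income"], "importance": [0.42, 0.17]}
)
# ...and read it back via the property.
print(container.feature_importance)

Because both accessors delegate to the same "feature_importance" key, any step downstream of CalculateReportsStep sees the exact DataFrame that was stored.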
22 changes: 20 additions & 2 deletions pipeline_lib/core/pipeline.py
@@ -6,7 +6,9 @@
from datetime import datetime
from typing import Any, Optional

import matplotlib.pyplot as plt
import mlflow
import pandas as pd
from mlflow.data import from_pandas

from pipeline_lib.core.data_container import DataContainer
@@ -185,6 +187,17 @@ def log_params_from_config(config):
                for key, value in step.get("parameters", {}).items():
                    mlflow.log_param(f"pipeline.steps_{i}.parameters.{key}", value)

        def plot_feature_importance(df: pd.DataFrame) -> None:
            fig, ax = plt.subplots(figsize=(10, 6))
            ax.barh(df["feature"], df["importance"])
            ax.set_xlabel("Importance")
            ax.set_ylabel("Feature")
            ax.set_title("Feature Importance")
            ax.grid(axis="x")  # vertical grid lines make bar lengths easier to compare
            plt.tight_layout()
            mlflow.log_figure(fig, "feature_importance.png")

        mlflow.set_experiment(experiment_name)

        if not run_name:
@@ -205,9 +218,9 @@ def log_params_from_config(config):
            self.logger.info(f"Logging input data to MLflow with dataset name: {dataset_name}")
            mlflow.log_input(from_pandas(data.raw), dataset_name)

-            # Log the training metrics
+            # Log prediction metrics
            if data.metrics:
-                self.logger.debug("Logging the training metrics to MLflow")
+                self.logger.debug("Logging prediction metrics to MLflow")
                for metric_name, metric_value in data.metrics["prediction"].items():
                    mlflow.log_metric(metric_name, metric_value)

@@ -216,6 +229,11 @@ def log_params_from_config(config):
                self.logger.debug("Logging the model to MLflow")
                mlflow.sklearn.log_model(data.model, artifact_path="model")

            # Save feature importance as a png artifact
            if data.feature_importance is not None:
                self.logger.debug("Plotting feature importance for MLflow")
                plot_feature_importance(data.feature_importance)

    def save_run(
        self,
        data: DataContainer,
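The helper pairs a standard matplotlib horizontal bar chart with mlflow.log_figure, which attaches the rendered figure to the active run. A self-contained sketch of the same pattern (the run name and data below are illustrative, not part of the commit):

import matplotlib.pyplot as plt
import mlflow
import pandas as pd

df = pd.DataFrame(
    {"feature": ["age", "income", "tenure"], "importance": [0.42, 0.17, 0.09]}
)

with mlflow.start_run(run_name="feature-importance-demo"):
    fig, ax = plt.subplots(figsize=(10, 6))
    ax.barh(df["feature"], df["importance"])  # one horizontal bar per feature
    ax.set_xlabel("Importance")
    ax.set_ylabel("Feature")
    ax.set_title("Feature Importance")
    ax.grid(axis="x")
    plt.tight_layout()
    mlflow.log_figure(fig, "feature_importance.png")  # saved as a run artifact
    plt.close(fig)

Logging the rendered figure rather than the raw DataFrame makes the ranking viewable directly in the MLflow UI.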
59 changes: 56 additions & 3 deletions pipeline_lib/core/steps/calculate_reports.py
@@ -1,4 +1,9 @@
-from typing import Optional
+import time

+import numpy as np
+import pandas as pd
+import shap
+from tqdm import tqdm

from pipeline_lib.core import DataContainer
from pipeline_lib.core.steps.base import PipelineStep
@@ -10,12 +15,60 @@ class CalculateReportsStep(PipelineStep):
    used_for_prediction = True
    used_for_training = False

-    def __init__(self, config: Optional[dict] = None) -> None:
+    def __init__(self, max_samples: int = 1000) -> None:
        """Initialize CalculateReportsStep."""
-        super().__init__(config=config)
        self.init_logger()
        self.max_samples = max_samples

    def execute(self, data: DataContainer) -> DataContainer:
        """Execute the step."""
        self.logger.info("Calculating reports")

        model = data.model
        if model is None:
            raise ValueError("Model not found in data container.")

        df = data.flow
        if len(df) > self.max_samples:
            # Randomly sample a subset of data points if the dataset is larger than max_samples
            self.logger.info(
                f"Dataset contains {len(df)} data points and max_samples is set to"
                f" {self.max_samples}."
            )
            self.logger.info(f"Sampling {self.max_samples} data points from the dataset.")
            df = df.sample(n=self.max_samples, random_state=42)

        drop_columns = (
            data._drop_columns + ["predictions"] if data._drop_columns else ["predictions"]
        )
        df = df.drop(columns=drop_columns)
        X = df.drop(columns=[data.target])

        # Calculate SHAP values with progress tracking and logging
        explainer = shap.TreeExplainer(model.model)
        shap_values = []
        # shap_base_value = explainer.expected_value
        total_rows = len(X)
        start_time = time.time()
        with tqdm(total=total_rows, desc="Calculating SHAP values") as pbar:
            for i in range(total_rows):
                shap_value = explainer.shap_values(X.iloc[[i]])
                shap_values.append(shap_value[0])  # append only the SHAP array for this row
                pbar.update(1)
                elapsed_time = time.time() - start_time
                remaining_time = (elapsed_time / (i + 1)) * (total_rows - i - 1)
                pbar.set_postfix(elapsed=f"{elapsed_time:.2f}s", remaining=f"{remaining_time:.2f}s")

        shap_values = np.array(shap_values)  # convert the list of per-row arrays to a NumPy array

        feature_names = X.columns.tolist()
        feature_importance = pd.DataFrame(
            list(zip(feature_names, abs(shap_values).mean(0))),
            columns=["feature", "importance"],
        )
        feature_importance = feature_importance.sort_values(by="importance", ascending=True)
        feature_importance.reset_index(drop=True, inplace=True)

        data.feature_importance = feature_importance

        return data
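The importance score computed above is the mean absolute SHAP value per feature. A compact sketch of the same aggregation on a toy regressor (the model, data, and column names are illustrative); note that explainer.shap_values can also score the whole frame in one call, so the row-by-row loop in the step exists only to drive the tqdm progress bar:

import numpy as np
import pandas as pd
import shap
from sklearn.ensemble import RandomForestRegressor

rng = np.random.default_rng(42)
X = pd.DataFrame(rng.random((100, 3)), columns=["f0", "f1", "f2"])
y = 2 * X["f0"] + rng.random(100)  # f0 should dominate the importances
model = RandomForestRegressor(n_estimators=10, random_state=42).fit(X, y)

explainer = shap.TreeExplainer(model)
shap_values = explainer.shap_values(X)  # shape: (n_samples, n_features)

importance = pd.DataFrame(
    {"feature": X.columns, "importance": np.abs(shap_values).mean(axis=0)}
).sort_values("importance", ascending=True)
print(importance)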
2 changes: 1 addition & 1 deletion poetry.lock

Some generated files are not rendered by default.

1 change: 1 addition & 0 deletions pyproject.toml
@@ -28,6 +28,7 @@ xgboost = { version = "^2.0.3", optional = true }
category-encoders = "^2.6.3"
tqdm = "^4.66.2"
mlflow = "^2.11.3"
matplotlib = "^3.8.4"

[tool.poetry.extras]
xgboost = ["xgboost"]
