-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathgenerate_explainers.py
58 lines (52 loc) · 1.71 KB
/
generate_explainers.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
from pathlib import Path
from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
from explainerdashboard import (
ClassifierExplainer,
RegressionExplainer,
ExplainerDashboard,
)
from explainerdashboard.datasets import *
# Output directory for the dumped explainers. It is resolved relative to the
# current working directory, not this script's location, so run the script
# from the intended directory.
pkl_dir = Path.cwd() / "pkls"
# The dump() calls below write straight into this directory. Create it up
# front so a fresh checkout doesn't fail with FileNotFoundError on the first
# dump; this is a no-op if it already exists.
pkl_dir.mkdir(parents=True, exist_ok=True)
# classifier
# Binary classifier on the titanic "survive" dataset; its explainer is dumped
# to pkl_dir/clas_explainer.joblib.
print("Generating titanic explainers")
print("Generating classifier explainer")
# titanic_survive / feature_descriptions presumably come from the
# explainerdashboard.datasets star import at the top of the file.
X_train, y_train, X_test, y_test = titanic_survive()
# Seed the forest so regenerating the pickles yields the same model (and thus
# the same explainer outputs) on every run instead of churning the artifacts.
model = RandomForestClassifier(
    n_estimators=50, max_depth=5, random_state=42
).fit(X_train, y_train)
clas_explainer = ClassifierExplainer(
    model,
    X_test,
    y_test,
    cats=["Sex", "Deck", "Embarked"],
    descriptions=feature_descriptions,
    labels=["Not survived", "Survived"],
)
# The dashboard object is discarded; it is constructed only for its side
# effects on clas_explainer. Presumably this precomputes and caches what the
# dashboard needs before the explainer is dumped. NOTE(review): confirm.
_ = ExplainerDashboard(clas_explainer)
clas_explainer.dump(pkl_dir / "clas_explainer.joblib")
# regression
# Regressor on the titanic "fare" dataset; its explainer is dumped to
# pkl_dir/reg_explainer.joblib.
print("Generating titanic fare explainer")
X_train, y_train, X_test, y_test = titanic_fare()
# Seed the forest so regenerated pickles are reproducible across runs.
model = RandomForestRegressor(
    n_estimators=50, max_depth=5, random_state=42
).fit(X_train, y_train)
reg_explainer = RegressionExplainer(
    model,
    X_test,
    y_test,
    cats=["Sex", "Deck", "Embarked"],
    descriptions=feature_descriptions,
    # Display unit for the target; the target here is the fare.
    units="$",
)
# Built only for its side effects on reg_explainer (result discarded).
# Presumably this warms the explainer's cached computations before the dump.
# NOTE(review): confirm.
_ = ExplainerDashboard(reg_explainer)
reg_explainer.dump(pkl_dir / "reg_explainer.joblib")
# multiclass
# Three-class classifier on the titanic "embarked" dataset; its explainer is
# dumped to pkl_dir/multi_explainer.joblib.
print("Generating titanic embarked multiclass explainer")
X_train, y_train, X_test, y_test = titanic_embarked()
# Seed the forest so regenerated pickles are reproducible across runs.
model = RandomForestClassifier(
    n_estimators=50, max_depth=5, random_state=42
).fit(X_train, y_train)
multi_explainer = ClassifierExplainer(
    model,
    X_test,
    y_test,
    # "Embarked" is omitted here, unlike the other two explainers,
    # presumably because it is the prediction target in this dataset.
    cats=["Sex", "Deck"],
    descriptions=feature_descriptions,
    # Assumes class indices 0, 1, 2 of y map to these ports in this order.
    # TODO confirm against titanic_embarked().
    labels=["Queenstown", "Southampton", "Cherbourg"],
)
# Built only for its side effects on multi_explainer (result discarded).
# Presumably this warms the explainer's cached computations before the dump.
# NOTE(review): confirm.
_ = ExplainerDashboard(multi_explainer)
multi_explainer.dump(pkl_dir / "multi_explainer.joblib")